use core::{
any::type_name,
fmt,
marker::PhantomData,
mem::{align_of, size_of, ManuallyDrop},
ptr::{drop_in_place, read, NonNull},
};
use super::slot::Slot;
/// Capacity, in bytes, of the inline slot that holds an erased closure.
///
/// The largest enabled `at-least-inline-captures-*` feature wins; with neither
/// feature enabled the capacity is 16 bytes.
const MAX_CLOSURE_DATA_SIZE: usize = if cfg!(feature = "at-least-inline-captures-48") {
    48
} else if cfg!(feature = "at-least-inline-captures-32") {
    32
} else {
    16
};
/// Storage for an erased closure: a `Slot` parameterised by `MAX_CLOSURE_DATA_SIZE`.
type ErasedFnOnceSlot = Slot<MAX_CLOSURE_DATA_SIZE>;
/// Vtable entry that takes the slot by value and produces the closure's result `T`.
/// It is `unsafe` to call; `ErasedFnOnce::new` fills it with a function that reads
/// the slot back as the concrete closure type and invokes it.
type CallFn<T> = unsafe fn(ErasedFnOnceSlot) -> T;
/// Vtable entry that receives a raw pointer to the slot. `ErasedFnOnce::new` fills it
/// with a function that drops the stored closure in place without invoking it.
type DropInPlaceFn = unsafe fn(*mut ErasedFnOnceSlot);
/// Function table built once per concrete closure type by `ErasedFnOnce::new`.
/// It is the only place where the erased closure type is still known.
struct ErasedFnOnceVtable<T> {
/// Consumes the slot and invokes the closure stored in it, returning its result.
call_impl: CallFn<T>,
/// Drops the closure stored behind the slot pointer without invoking it.
drop_in_place_impl: DropInPlaceFn,
}
/// A type-erased `FnOnce() -> T` whose closure value is stored inline in a
/// fixed-capacity slot, with a vtable pointer recording how to call or drop it.
///
/// The closure runs at most once: `call` consumes the value, and dropping an
/// uncalled value runs the closure's destructor through the vtable instead.
pub(crate) struct ErasedFnOnce<'a, T = ()> {
// Holds the closure value itself (written by `Slot::new` in `new`).
slot: ErasedFnOnceSlot,
// Points at the call/drop functions generated for the concrete closure type.
// NOTE(review): presumably a raw `NonNull` rather than `&'static` so that `T`
// need not be `'static` — confirm.
vtable: NonNull<ErasedFnOnceVtable<T>>,
// Ties this value to the closure's borrow lifetime `'a` and result type `T`
// without storing a trait object.
_marker: PhantomData<dyn FnOnce() -> T + 'a>,
}
impl<'a, T> ErasedFnOnce<'a, T> {
    /// Erases `fn_once` by moving it into the inline slot and pairing it with
    /// a vtable specialised for the concrete closure type `F`.
    ///
    /// # Panics
    ///
    /// Panics when `F` requires a stricter alignment than the slot type
    /// provides, or when `F` is larger than `MAX_CLOSURE_DATA_SIZE` bytes.
    pub(crate) const fn new<F>(fn_once: F) -> Self
    where
        F: FnOnce() -> T + 'a,
    {
        // Both layout checks run before the closure is moved into the slot.
        assert!(
            align_of::<ErasedFnOnceSlot>() >= align_of::<F>(),
            "tailcall runtime cannot store this closure inline because its alignment exceeds the thunk slot alignment; reduce what the closure captures or move large/over-aligned state behind a pointer",
        );
        assert!(
            MAX_CLOSURE_DATA_SIZE >= size_of::<F>(),
            "tailcall runtime cannot store this closure inline because its captured state exceeds the configured thunk slot capacity; reduce captures, pass state as function arguments, box large captured values, or enable a larger thunk size feature",
        );

        // NOTE(review): the pointer taken here outlives this function, so the
        // borrowed literal must be promoted to a `'static` constant. Keep it a
        // constant expression that is borrowed directly — confirm before
        // restructuring.
        let table: *const ErasedFnOnceVtable<T> = &ErasedFnOnceVtable {
            // Reads the slot back as an `F` and invokes it.
            // NOTE(review): sound only while the slot still holds the `F`
            // written by `Slot::new` below — confirm against `Slot::into_value`.
            call_impl: |erased| unsafe { erased.into_value::<F>()() },
            // Drops the stored `F` in place.
            // NOTE(review): assumes the value lives at the start of the slot,
            // since the slot pointer is cast straight to `*mut F` — confirm
            // against `Slot`'s layout.
            drop_in_place_impl: |erased_ptr| unsafe { drop_in_place(erased_ptr.cast::<F>()) },
        };

        Self {
            slot: Slot::new(fn_once),
            // SAFETY: `table` was derived from a reference just above, so it
            // is never null.
            vtable: unsafe { NonNull::new_unchecked(table.cast_mut()) },
            _marker: PhantomData,
        }
    }

    /// Borrows the vtable recorded for the stored closure.
    #[inline(always)]
    fn vtable(&self) -> &ErasedFnOnceVtable<T> {
        let table = self.vtable;
        // SAFETY: the pointer was created in `new` from a reference to the
        // vtable and is never reassigned afterwards.
        unsafe { table.as_ref() }
    }

    /// Invokes the stored closure, consuming `self`, and returns its result.
    #[inline(always)]
    pub(crate) fn call(self) -> T {
        // Suppress `Drop`: ownership of the closure is handed to `call_impl`,
        // so the destructor must not also run on the slot.
        let this = ManuallyDrop::new(self);
        let invoke = this.vtable().call_impl;
        // SAFETY: `this` is wrapped in `ManuallyDrop`, so this bitwise copy
        // becomes the only copy of the slot that is ever used.
        let payload = unsafe { read(&this.slot) };
        // SAFETY: `invoke` comes from the vtable built in `new` for the same
        // closure type that was written into the slot.
        unsafe { invoke(payload) }
    }
}
impl<T> Drop for ErasedFnOnce<'_, T> {
fn drop(&mut self) {
unsafe { (self.vtable().drop_in_place_impl)(&mut self.slot) }
}
}
impl<T> fmt::Debug for ErasedFnOnce<'_, T> {
    /// Formats as `ErasedFnOnce -> <result type name>`; the erased closure
    /// itself cannot be inspected.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ErasedFnOnce -> ")?;
        f.write_str(type_name::<T>())
    }
}
#[cfg(test)]
mod tests {
    extern crate std;

    use super::ErasedFnOnce;
    use std::cell::Cell;
    use std::panic::{catch_unwind, AssertUnwindSafe};
    use std::rc::Rc;
    use std::{boxed::Box, string::String};

    /// Bumps a shared counter every time an instance is dropped.
    struct DropTracker {
        drops: Rc<Cell<usize>>,
    }

    impl Drop for DropTracker {
        fn drop(&mut self) {
            let seen = self.drops.get();
            self.drops.set(seen + 1);
        }
    }

    /// Builds a zeroed drop counter together with a tracker wired to it.
    fn counted_tracker() -> (Rc<Cell<usize>>, DropTracker) {
        let drops = Rc::new(Cell::new(0));
        let tracker = DropTracker {
            drops: Rc::clone(&drops),
        };
        (drops, tracker)
    }

    /// Extracts the text of a panic payload, whether it was raised with a
    /// string literal or with a formatted `String`.
    fn panic_message(payload: &Box<dyn core::any::Any + Send>) -> &str {
        match payload.downcast_ref::<&'static str>() {
            Some(text) => *text,
            None => match payload.downcast_ref::<String>() {
                Some(text) => text.as_str(),
                None => panic!("unexpected panic payload type"),
            },
        }
    }

    #[test]
    fn sanity() {
        let answer = ErasedFnOnce::new(|| 42).call();
        assert_eq!(42, answer);
    }

    #[test]
    fn with_captures() {
        let (lhs, rhs) = (1, 2);
        let thunk = ErasedFnOnce::new(move || lhs + rhs);
        assert_eq!(3, thunk.call());
    }

    #[test]
    fn with_too_many_captures_reports_actionable_message() {
        // 8 * 8 = 64 bytes of captured state, larger than every configurable
        // slot capacity (16, 32 or 48 bytes).
        let words: [u64; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let _ = ErasedFnOnce::new(move || words.iter().sum::<u64>());
        }));
        let payload = outcome.expect_err("expected oversized closure capture to panic");
        let message = panic_message(&payload);
        assert!(message.contains("captured state exceeds the configured thunk slot capacity"));
        assert!(message.contains("pass state as function arguments"));
    }

    #[test]
    fn dropping_without_call_runs_destructor_once() {
        let (drops, tracker) = counted_tracker();
        let thunk = ErasedFnOnce::new(move || {
            let _tracker = tracker;
        });
        drop(thunk);
        assert_eq!(drops.get(), 1);
    }

    #[test]
    fn calling_runs_destructor_once() {
        let (drops, tracker) = counted_tracker();
        let thunk = ErasedFnOnce::new(move || {
            let _tracker = tracker;
        });
        thunk.call();
        assert_eq!(drops.get(), 1);
    }

    #[test]
    fn panic_during_call_drops_capture_once() {
        let (drops, tracker) = counted_tracker();
        let thunk = ErasedFnOnce::new(move || {
            let _tracker = tracker;
            panic!("boom");
        });
        // The panic is expected; only the drop count matters here.
        let _ = catch_unwind(AssertUnwindSafe(|| thunk.call()));
        assert_eq!(drops.get(), 1);
    }
}