use core::{
marker::PhantomData,
mem::{ManuallyDrop, align_of, size_of},
ptr,
};
/// `N` bytes of inline storage aligned to 1 byte.
///
/// The bytes are held in a `MaybeUninit` because a buffer stores the raw object
/// representation of an arbitrary value: padding bytes may be uninitialised and
/// pointer bytes carry provenance, and neither survives being copied as plain `u8`.
/// `repr(C)` pins the field at offset 0 so the struct's alignment applies to it.
#[derive(Debug)]
#[repr(C, align(1))]
pub struct Align1<const N: usize>(::core::mem::MaybeUninit<[u8; N]>);

/// `N` bytes of inline storage aligned to 2 bytes (see [`Align1`] for the layout rationale).
#[derive(Debug)]
#[repr(C, align(2))]
pub struct Align2<const N: usize>(::core::mem::MaybeUninit<[u8; N]>);

/// `N` bytes of inline storage aligned to 4 bytes (see [`Align1`] for the layout rationale).
#[derive(Debug)]
#[repr(C, align(4))]
pub struct Align4<const N: usize>(::core::mem::MaybeUninit<[u8; N]>);

/// `N` bytes of inline storage aligned to 8 bytes (see [`Align1`] for the layout rationale).
#[derive(Debug)]
#[repr(C, align(8))]
pub struct Align8<const N: usize>(::core::mem::MaybeUninit<[u8; N]>);

/// `N` bytes of inline storage aligned to 16 bytes (see [`Align1`] for the layout rationale).
#[derive(Debug)]
#[repr(C, align(16))]
pub struct Align16<const N: usize>(::core::mem::MaybeUninit<[u8; N]>);

/// `N` bytes of inline storage aligned to 32 bytes (see [`Align1`] for the layout rationale).
#[derive(Debug)]
#[repr(C, align(32))]
pub struct Align32<const N: usize>(::core::mem::MaybeUninit<[u8; N]>);

/// `N` bytes of inline storage aligned to 64 bytes (see [`Align1`] for the layout rationale).
#[derive(Debug)]
#[repr(C, align(64))]
pub struct Align64<const N: usize>(::core::mem::MaybeUninit<[u8; N]>);

/// Fixed-size, fixed-alignment inline storage that can hold one value of any
/// type that fits.
pub trait Buffer {
    /// Moves `bytes` into a new buffer, storing it at the start of the storage.
    ///
    /// The implementations generated in this module fail to compile (when
    /// instantiated) if `T` is larger than the buffer's `N` bytes or needs a
    /// stricter alignment than the buffer provides.
    fn copy_bytes<T>(bytes: T) -> Self;
    /// Returns a pointer to the first byte of the storage.
    fn as_ptr(&self) -> *const u8;
    /// Returns a mutable pointer to the first byte of the storage.
    fn as_mut_ptr(&mut self) -> *mut u8;
}

/// Implements [`Buffer`] for one of the `AlignN<const N: usize>` storage types.
macro_rules! impl_buffer {
    ($ty:ident) => {
        impl<const N: usize> Buffer for $ty<N> {
            fn copy_bytes<T>(bytes: T) -> Self {
                // Compare against `N`, not `size_of::<Self>()`: the latter is `N`
                // rounded up to the alignment, and that tail is struct padding that
                // is outside the field and is not preserved when `Self` is moved.
                const {
                    assert!(size_of::<T>() <= N, "value too large for buffer");
                    assert!(
                        align_of::<T>() <= align_of::<Self>(),
                        "value alignment too large for buffer"
                    );
                }
                // Unused trailing bytes start out zeroed.
                let mut buf = Self(::core::mem::MaybeUninit::zeroed());
                // SAFETY: `as_mut_ptr` points at the `N`-byte field, which `repr(C)`
                // places at offset 0 of a value aligned to `align_of::<Self>()`.
                // The assertions above guarantee one `T` fits in `N` bytes and that
                // this address is suitably aligned for `T`.
                unsafe { buf.as_mut_ptr().cast::<T>().write(bytes) };
                buf
            }

            fn as_ptr(&self) -> *const u8 {
                self.0.as_ptr().cast()
            }

            fn as_mut_ptr(&mut self) -> *mut u8 {
                self.0.as_mut_ptr().cast()
            }
        }
    };
}

impl_buffer!(Align1);
impl_buffer!(Align2);
impl_buffer!(Align4);
impl_buffer!(Align8);
impl_buffer!(Align16);
impl_buffer!(Align32);
impl_buffer!(Align64);
/// A type-erased `Fn(A) -> R` closure stored inline in a `B` buffer.
///
/// The concrete closure type is erased: calling and dropping go through
/// monomorphised function pointers captured in [`OpaqueFn::new`], so every
/// closure with the same `A`, `R` and `B` yields the same wrapper type.
/// `'a` bounds whatever the closure borrows.
#[derive(Debug)]
#[repr(C)]
pub struct OpaqueFn<'a, A, R, B: Buffer> {
    // Raw bytes of the closure. `ManuallyDrop` because the closure itself is
    // destroyed through `drop_fn`, not through `B`.
    buf: ManuallyDrop<B>,
    // Reinterprets the buffer as `&F` and calls it.
    call_fn: fn(*const u8, A) -> R,
    // Runs `F`'s destructor in place.
    drop_fn: fn(*mut u8),
    // `&'a ()` ties the wrapper to the closure's borrows. `*mut ()` opts out of
    // `Send`/`Sync`: the erased closure may be neither (e.g. it captures an `Rc`),
    // and without this the fields above would make the wrapper both.
    _marker: PhantomData<(&'a (), *mut ())>,
}

impl<'a, A, R, B: Buffer> OpaqueFn<'a, A, R, B> {
    /// Moves `f` into an inline buffer of type `B`.
    ///
    /// Fails to compile (when instantiated) if `F` is larger than `B` or needs a
    /// stricter alignment than `B` provides.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(A) -> R + 'a,
    {
        /// Call trampoline for a closure of concrete type `F`.
        fn call_impl<F, A, R>(ptr: *const u8, a: A) -> R
        where
            F: Fn(A) -> R,
        {
            // SAFETY: `ptr` is the buffer of an `OpaqueFn` built from an `F`, so it
            // points at a live, suitably aligned `F`; `Fn` needs only shared access.
            let f = unsafe { &*ptr.cast::<F>() };
            f(a)
        }

        /// Drop glue for a closure of concrete type `F`.
        fn drop_impl<F>(ptr: *mut u8) {
            // SAFETY: `ptr` points at a live `F`; this runs exactly once, from `Drop`.
            unsafe { ptr::drop_in_place(ptr.cast::<F>()) };
        }

        const {
            assert!(
                size_of::<F>() <= size_of::<B>(),
                "OpaqueFn too large for buffer"
            );
            assert!(
                align_of::<F>() <= align_of::<B>(),
                "OpaqueFn alignment too large for buffer"
            );
        }

        Self {
            buf: ManuallyDrop::new(B::copy_bytes(f)),
            call_fn: call_impl::<F, A, R>,
            drop_fn: drop_impl::<F>,
            _marker: PhantomData,
        }
    }

    /// Calls the stored closure with `a`.
    #[inline]
    pub fn call<'call>(&self, a: A) -> R
    where
        A: 'call,
    {
        (self.call_fn)(self.buf.as_ptr(), a)
    }
}

impl<A, R, B: Buffer> Drop for OpaqueFn<'_, A, R, B> {
    fn drop(&mut self) {
        // Destroy the closure first, then the storage that held it.
        (self.drop_fn)(self.buf.as_mut_ptr());
        // SAFETY: `buf` is never accessed again after this point.
        unsafe { ManuallyDrop::drop(&mut self.buf) };
    }
}
/// A type-erased `FnMut(A) -> R` closure stored inline in a `B` buffer.
///
/// [`call`](OpaqueFnMut::call) takes `&self` but must give the closure exclusive
/// (`&mut`) access. The buffer therefore lives in an `UnsafeCell`, and a flag
/// rejects re-entrant calls, which would otherwise create two aliasing `&mut F`.
/// `'a` bounds whatever the closure borrows.
#[derive(Debug)]
#[repr(C)]
pub struct OpaqueFnMut<'a, A, R, B: Buffer> {
    // Raw bytes of the closure. `UnsafeCell` permits mutation through `&self`;
    // `ManuallyDrop` because the closure is destroyed through `drop_fn`.
    buf: core::cell::UnsafeCell<ManuallyDrop<B>>,
    // True while `call` is running; guards against re-entrant calls.
    in_call: core::cell::Cell<bool>,
    // Reinterprets the buffer as `&mut F` and calls it.
    call_fn: fn(*mut u8, A) -> R,
    // Runs `F`'s destructor in place.
    drop_fn: fn(*mut u8),
    // `&'a ()` ties the wrapper to the closure's borrows. `*mut ()` opts out of
    // `Send`/`Sync`, because the erased closure may be neither.
    _marker: PhantomData<(&'a (), *mut ())>,
}

impl<'a, A, R, B: Buffer> OpaqueFnMut<'a, A, R, B> {
    /// Moves `f` into an inline buffer of type `B`.
    ///
    /// Fails to compile (when instantiated) if `F` is larger than `B` or needs a
    /// stricter alignment than `B` provides.
    pub fn new<F>(f: F) -> Self
    where
        F: FnMut(A) -> R + 'a,
    {
        /// Call trampoline for a closure of concrete type `F`.
        fn call_impl<F, A, R>(ptr: *mut u8, a: A) -> R
        where
            F: FnMut(A) -> R,
        {
            // SAFETY: `ptr` points at a live, suitably aligned `F`, and `call`'s
            // `in_call` flag (plus `!Sync`) guarantees this is the only access.
            let f = unsafe { &mut *ptr.cast::<F>() };
            f(a)
        }

        /// Drop glue for a closure of concrete type `F`.
        fn drop_impl<F>(ptr: *mut u8) {
            // SAFETY: `ptr` points at a live `F`; this runs exactly once, from `Drop`.
            unsafe { ptr::drop_in_place(ptr.cast::<F>()) };
        }

        const {
            assert!(
                size_of::<F>() <= size_of::<B>(),
                "OpaqueFnMut too large for buffer"
            );
            assert!(
                align_of::<F>() <= align_of::<B>(),
                "OpaqueFnMut alignment too large for buffer"
            );
        }

        Self {
            buf: core::cell::UnsafeCell::new(ManuallyDrop::new(B::copy_bytes(f))),
            in_call: core::cell::Cell::new(false),
            call_fn: call_impl::<F, A, R>,
            drop_fn: drop_impl::<F>,
            _marker: PhantomData,
        }
    }

    /// Calls the stored closure with `a`.
    ///
    /// The `'call` bounds together amount to requiring `A: 'a`.
    ///
    /// # Panics
    ///
    /// Panics if invoked re-entrantly, i.e. from inside the closure's own call.
    #[inline]
    pub fn call<'call>(&self, a: A) -> R
    where
        A: 'call,
        'call: 'a,
    {
        /// Clears the re-entrancy flag when `call` returns or unwinds.
        struct ResetOnExit<'f>(&'f core::cell::Cell<bool>);
        impl Drop for ResetOnExit<'_> {
            fn drop(&mut self) {
                self.0.set(false);
            }
        }

        if self.in_call.replace(true) {
            panic!("OpaqueFnMut::call invoked re-entrantly");
        }
        let _reset = ResetOnExit(&self.in_call);

        // SAFETY: `in_call` was false, so no other `call` frame is using the buffer,
        // and the type is `!Sync`, so no other thread can be. `Drop` needs `&mut self`
        // and so cannot overlap. This `&mut B` is therefore the only live reference.
        let buf: &mut B = unsafe { &mut **self.buf.get() };
        (self.call_fn)(buf.as_mut_ptr(), a)
    }
}

impl<A, R, B: Buffer> Drop for OpaqueFnMut<'_, A, R, B> {
    fn drop(&mut self) {
        // `&mut self` proves no `call` is in progress, so `get_mut` is exclusive.
        let buf = self.buf.get_mut();
        // Destroy the closure first, then the storage that held it.
        (self.drop_fn)(buf.as_mut_ptr());
        // SAFETY: `buf` is never accessed again after this point.
        unsafe { ManuallyDrop::drop(buf) };
    }
}
/// A type-erased `FnOnce(A) -> R` closure stored inline in a `B` buffer.
///
/// [`call`](OpaqueFnOnce::call) consumes the wrapper and moves the closure out of
/// the buffer. If the wrapper is dropped without being called, the closure is
/// dropped in place instead. `'a` bounds whatever the closure borrows.
#[derive(Debug)]
#[repr(C)]
pub struct OpaqueFnOnce<'a, A, R, B: Buffer> {
    // Raw bytes of the closure. `ManuallyDrop` because the closure is either
    // moved out by `call_fn` or destroyed through `drop_fn`, never through `B`.
    buf: ManuallyDrop<B>,
    // Moves the closure out of the buffer and calls it.
    call_fn: fn(*const u8, A) -> R,
    // Runs `F`'s destructor in place (only if `call` was never invoked).
    drop_fn: fn(*mut u8),
    // `&'a ()` ties the wrapper to the closure's borrows. `*mut ()` opts out of
    // `Send`/`Sync`, because the erased closure may be neither.
    _marker: PhantomData<(&'a (), *mut ())>,
}

impl<'a, A, R, B: Buffer> OpaqueFnOnce<'a, A, R, B> {
    /// Moves `f` into an inline buffer of type `B`.
    ///
    /// Fails to compile (when instantiated) if `F` is larger than `B` or needs a
    /// stricter alignment than `B` provides.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce(A) -> R + 'a,
    {
        /// Call trampoline: moves the `F` out of the buffer and consumes it.
        fn call_impl<F, A, R>(ptr: *const u8, a: A) -> R
        where
            F: FnOnce(A) -> R,
        {
            // SAFETY: `ptr` points at a live, suitably aligned `F`. `call` guarantees
            // the buffer is never read or dropped again after this move.
            let f = unsafe { ptr::read(ptr.cast::<F>()) };
            f(a)
        }

        /// Drop glue for a closure of concrete type `F`.
        fn drop_impl<F>(ptr: *mut u8) {
            // SAFETY: `ptr` points at a live `F` that was never moved out; this runs
            // exactly once, from `Drop`.
            unsafe { ptr::drop_in_place(ptr.cast::<F>()) };
        }

        const {
            assert!(
                size_of::<F>() <= size_of::<B>(),
                "OpaqueFnOnce too large for buffer"
            );
            assert!(
                align_of::<F>() <= align_of::<B>(),
                "OpaqueFnOnce alignment too large for buffer"
            );
        }

        Self {
            buf: ManuallyDrop::new(B::copy_bytes(f)),
            call_fn: call_impl::<F, A, R>,
            drop_fn: drop_impl::<F>,
            _marker: PhantomData,
        }
    }

    /// Consumes the wrapper and calls the stored closure with `a`.
    ///
    /// If the closure panics, its captures are dropped exactly once, during unwinding.
    pub fn call(self, a: A) -> R {
        // Disarm `Drop` *before* calling. `call_fn` moves the closure out of the
        // buffer, so if the closure panics it is already dropped by unwinding, and
        // running `drop_fn` afterwards would drop it a second time.
        let this = ManuallyDrop::new(self);
        (this.call_fn)(this.buf.as_ptr(), a)
    }
}

impl<A, R, B: Buffer> Drop for OpaqueFnOnce<'_, A, R, B> {
    fn drop(&mut self) {
        // Reached only when `call` was never invoked, so the closure is still live.
        (self.drop_fn)(self.buf.as_mut_ptr());
        // SAFETY: `buf` is never accessed again after this point.
        unsafe { ManuallyDrop::drop(&mut self.buf) };
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    extern crate alloc;

    #[test]
    fn simple_call() {
        let f = OpaqueFn::<_, u32, Align1<32>>::new(|x: u32| x + 1);
        assert_eq!(f.call(41), 42);
    }

    #[test]
    fn capture_closure() {
        let y = 10;
        let f = OpaqueFn::<_, u32, Align8<32>>::new(|x: u32| x + y);
        assert_eq!(f.call(5), 15);
    }

    #[test]
    fn zero_sized_closure() {
        let f = OpaqueFn::<_, u32, Align1<8>>::new(|x: u32| x * 2);
        assert_eq!(f.call(21), 42);
    }

    /// The captured value is dropped exactly once, when the wrapper is dropped,
    /// and not by calling it.
    #[test]
    fn drop_runs() {
        use core::cell::Cell;

        struct DropCounter<'a> {
            count: &'a Cell<u32>,
        }

        impl<'a> Drop for DropCounter<'a> {
            fn drop(&mut self) {
                self.count.set(self.count.get() + 1);
            }
        }

        let counter = Cell::new(0);
        {
            let dc = DropCounter { count: &counter };
            let f = OpaqueFn::<_, (), Align8<32>>::new(move |_| {
                let _ = &dc;
            });
            f.call(());
            assert_eq!(counter.get(), 0);
        }
        assert_eq!(counter.get(), 1);
    }

    // Oversized closures are rejected by `const` assertions. These are evaluated when
    // `OpaqueFn::new` is monomorphised, so the failure happens at compile time rather
    // than as a runtime panic. A `#[should_panic]` unit test for that case would stop
    // the whole test crate from compiling; it has to be a `compile_fail` doctest instead.

    /// The closure captures `[u128; 2]` (32 bytes) and so exactly fills `Align16<32>`.
    #[test]
    fn aligned_huge_buffer() {
        let big = [0u128; 2];
        let _f = OpaqueFn::<(), (), Align16<32>>::new(move |_| {
            let _ = &big;
        });
    }

    #[test]
    fn fn_mut_increment() {
        use core::cell::Cell;

        let counter = Cell::new(0);
        let f = OpaqueFnMut::<_, (), Align8<32>>::new({
            let counter_ref = &counter;
            move |_| {
                counter_ref.set(counter_ref.get() + 1);
            }
        });
        f.call(());
        f.call(());
        f.call(());
        assert_eq!(counter.get(), 3);
    }

    #[test]
    fn fn_once_consumed() {
        let captured = alloc::string::String::from("hello");
        let f = OpaqueFnOnce::<_, usize, Align16<64>>::new(move |x: usize| {
            assert_eq!(captured, "hello");
            x + captured.len()
        });
        let result = f.call(5);
        assert_eq!(result, 5 + 5);
    }

    #[test]
    fn fn_once_mutable_capture() {
        let mut s = alloc::string::String::from("foo");
        let f = OpaqueFnOnce::<_, (), Align16<64>>::new(move |_| {
            s.push_str("bar");
            assert_eq!(s, "foobar");
        });
        f.call(());
    }

    /// Each call sees the state left by the previous one (the closure owns a copy of `sum`).
    #[test]
    fn fn_mut_multiple_calls_mutation() {
        let mut sum = 0;
        let f = OpaqueFnMut::<_, u32, Align8<32>>::new(move |x: u32| {
            sum += x;
            sum
        });
        assert_eq!(f.call(1), 1);
        assert_eq!(f.call(2), 3);
        assert_eq!(f.call(3), 6);
    }
}