use core::fmt;
use core::mem::ManuallyDrop;
use core::ops::Deref;
use core::ptr::NonNull;
use objc2::encode::{EncodeArguments, EncodeReturn};
use crate::abi::BlockHeader;
use crate::debug::debug_block_header;
use crate::{ffi, Block, IntoBlock, StackBlock};
/// An owning, reference-counted handle to a heap-allocated block.
///
/// Cloning calls `_Block_copy` and dropping calls `_Block_release` on the
/// wrapped pointer; the handle dereferences to [`Block`] so the block can be
/// called through it.
///
/// `F` is the block's (unsized) function signature, e.g.
/// `dyn Fn(i32) -> i32`.
#[doc(alias = "MallocBlock")]
pub struct RcBlock<F: ?Sized> {
// Never NULL: the constructors go through `NonNull::new`. `NonNull` is
// covariant in its pointee, which keeps `RcBlock` covariant in `F`.
ptr: NonNull<Block<F>>,
}
impl<F: ?Sized> RcBlock<F> {
    /// Construct an `RcBlock` from a raw block pointer, taking ownership of
    /// one reference to the block.
    ///
    /// Returns `None` if `ptr` is NULL. No runtime function is called in
    /// that case.
    ///
    /// # Safety
    ///
    /// - `ptr` must be NULL or point to a valid block whose argument and
    ///   return types match `F`.
    /// - The caller must own one reference (retain count) to the block and
    ///   transfer it to the returned `RcBlock`. Dropping the `RcBlock` calls
    ///   `_Block_release` on the pointer exactly once.
    /// - The block must stay valid while the returned `RcBlock` (or any clone
    ///   of it) is alive.
    /// - The block must be sound to call with any arguments matching `F`,
    ///   because `Deref` exposes it as a `&Block<F>` that can be called
    ///   without further `unsafe`.
    #[inline]
    pub unsafe fn from_raw(ptr: *mut Block<F>) -> Option<Self> {
        NonNull::new(ptr).map(|ptr| Self { ptr })
    }

    /// Construct an `RcBlock` by calling `_Block_copy` on `ptr`.
    ///
    /// Returns `None` if `_Block_copy` returns NULL.
    /// NOTE(review): with the usual blocks runtime this presumably happens
    /// when `ptr` is NULL or the heap copy cannot be allocated. Confirm
    /// against the runtime in use.
    ///
    /// # Safety
    ///
    /// - `ptr` must point to a valid block whose argument and return types
    ///   match `F`.
    /// - The block must be sound to call with any arguments matching `F`
    ///   (see [`RcBlock::from_raw`]).
    /// - The caller keeps whatever reference it already had to `ptr`. The
    ///   returned `RcBlock` owns only the reference produced by `_Block_copy`.
    #[doc(alias = "Block_copy")]
    #[doc(alias = "_Block_copy")]
    #[inline]
    pub unsafe fn copy(ptr: *mut Block<F>) -> Option<Self> {
        // SAFETY: the caller guarantees that `ptr` points to a valid block.
        let ptr: *mut Block<F> = unsafe { ffi::_Block_copy(ptr.cast()) }.cast();
        // SAFETY: the reference returned by `_Block_copy` (or NULL) is owned
        // by nobody else, so its ownership moves into the new `RcBlock`.
        unsafe { Self::from_raw(ptr) }
    }
}
impl<F: ?Sized> RcBlock<F> {
    /// Create a heap-allocated, reference-counted block from a Rust closure.
    ///
    /// The closure is first wrapped in a `StackBlock`. That stack block is
    /// then copied to the heap with `_Block_copy` (via [`RcBlock::copy`]).
    /// The stack original sits in `ManuallyDrop`, so it is never dropped
    /// here.
    /// NOTE(review): ownership of the closure is assumed to pass to the heap
    /// copy, per the `StackBlock::new_no_clone` contract. Confirm there.
    ///
    /// # Panics
    ///
    /// Panics with "failed creating RcBlock" if `_Block_copy` returns NULL.
    pub fn new<'f, A, R, Closure>(closure: Closure) -> Self
    where
        A: EncodeArguments,
        R: EncodeReturn,
        Closure: IntoBlock<'f, A, R, Dyn = F>,
    {
        // SAFETY: the stack block is copied exactly once, below, and is
        // neither used nor dropped afterwards.
        let mut on_stack = ManuallyDrop::new(unsafe { StackBlock::new_no_clone(closure) });
        let stack_ptr: *mut StackBlock<'f, A, R, Closure> = &mut *on_stack;

        // SAFETY: `stack_ptr` points to a live stack block whose signature is
        // `F`, as required by the `IntoBlock<Dyn = F>` bound.
        match unsafe { Self::copy(stack_ptr.cast::<Block<F>>()) } {
            Some(heap_block) => heap_block,
            None => rc_new_fail(),
        }
    }
}
impl<F: ?Sized> Clone for RcBlock<F> {
    /// Produce another handle to the same block by calling `_Block_copy`
    /// on it.
    ///
    /// For a block that is already on the heap this is expected to only bump
    /// its reference count.
    ///
    /// # Panics
    ///
    /// Panics (as unreachable code) if `_Block_copy` returns NULL.
    #[doc(alias = "Block_copy")]
    #[doc(alias = "_Block_copy")]
    #[inline]
    fn clone(&self) -> Self {
        let raw = self.ptr.as_ptr();
        // SAFETY: `raw` comes from a live `RcBlock`, so it points to a valid
        // block whose signature is `F`.
        match unsafe { Self::copy(raw) } {
            Some(another) => another,
            None => rc_clone_fail(),
        }
    }
}
/// Failure path of [`RcBlock::new`], taken when `_Block_copy` returns NULL.
///
/// Kept out of line and marked `#[cold]` so the very unlikely failure branch
/// does not bloat or mis-lay-out the hot, inlined callers.
// NOTE(review): NULL here most likely means the system is out of memory — confirm.
#[cold]
fn rc_new_fail() -> ! {
    panic!("failed creating RcBlock")
}

/// Failure path for copying a block, shared within the crate.
///
/// NOTE(review): not called in the code visible here. Presumably used by
/// `Block::copy` elsewhere in the crate — confirm.
#[cold]
pub(crate) fn block_copy_fail() -> ! {
    panic!("failed copying Block")
}

/// Failure path of `RcBlock::clone`.
///
/// Cloning an existing heap block is expected to only bump its reference
/// count, so reaching this indicates a broken invariant rather than a
/// recoverable error.
#[cold]
fn rc_clone_fail() -> ! {
    unreachable!("cloning a RcBlock bumps the reference count, which should be infallible")
}
impl<F: ?Sized> Deref for RcBlock<F> {
    type Target = Block<F>;

    /// Borrow the underlying block so it can be inspected or called.
    #[inline]
    fn deref(&self) -> &Block<F> {
        let raw: *const Block<F> = self.ptr.as_ptr();
        // SAFETY: `ptr` is non-null and, per the `RcBlock` constructors'
        // contract, points to a block kept alive by the reference this
        // handle owns, for at least as long as `&self`.
        unsafe { &*raw }
    }
}
impl<F: ?Sized> Drop for RcBlock<F> {
    /// Give up the reference this handle owns by calling `_Block_release`
    /// on the block.
    #[doc(alias = "Block_release")]
    #[doc(alias = "_Block_release")]
    #[inline]
    fn drop(&mut self) {
        let raw = self.ptr.as_ptr();
        // SAFETY: this handle owns exactly one reference to a valid block
        // (the `RcBlock` constructors' contract). It is released exactly
        // once, here.
        unsafe { ffi::_Block_release(raw.cast()) };
    }
}
impl<F: ?Sized> fmt::Debug for RcBlock<F> {
    /// Render as a non-exhaustive `RcBlock { .. }`. The fields are filled in
    /// from the block's header by `debug_block_header`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // SAFETY: `ptr` points to a live block (the `RcBlock` constructors'
        // contract). This code relies on every block starting with a
        // `BlockHeader`.
        let header: &BlockHeader = unsafe { self.ptr.cast::<BlockHeader>().as_ref() };
        let mut builder = f.debug_struct("RcBlock");
        debug_block_header(header, &mut builder);
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
use alloc::rc::Rc;
use core::cell::OnceCell;
use super::*;
// An `RcBlock` created inside a function can be returned from it and still
// be called afterwards; the captured `x` travels with the block.
#[test]
fn return_rc_block() {
fn get_adder(x: i32) -> RcBlock<dyn Fn(i32) -> i32> {
RcBlock::new(move |y| y + x)
}
let add2 = get_adder(2);
assert_eq!(add2.call((5,)), 7);
assert_eq!(add2.call((-1,)), 1);
}
// Mostly a compile-time check: `RcBlock::new` accepts closures whose `dyn Fn`
// signature names explicit lifetimes on arguments, on the return type, and on
// the trait object itself. The signatures below are the point of the test.
#[test]
fn rc_block_with_precisely_described_lifetimes() {
fn args<'a, 'b>(
f: impl Fn(&'a i32, &'b i32) + 'static,
) -> RcBlock<dyn Fn(&'a i32, &'b i32) + 'static> {
RcBlock::new(f)
}
fn args_return<'a, 'b>(
f: impl Fn(&'a i32) -> &'b i32 + 'static,
) -> RcBlock<dyn Fn(&'a i32) -> &'b i32 + 'static> {
RcBlock::new(f)
}
fn args_entire<'a, 'b>(f: impl Fn(&'a i32) + 'b) -> RcBlock<dyn Fn(&'a i32) + 'b> {
RcBlock::new(f)
}
fn return_entire<'a, 'b>(
f: impl Fn() -> &'a i32 + 'b,
) -> RcBlock<dyn Fn() -> &'a i32 + 'b> {
RcBlock::new(f)
}
let _ = args(|_, _| {});
let _ = args_return(|x| x);
let _ = args_entire(|_| {});
let _ = return_entire(|| &5);
}
// Compile-only check that `RcBlock` is covariant in `F`: a `'static` trait
// object bound can be shortened to `'f`. Never called, hence `dead_code`.
#[allow(dead_code)]
fn covariant<'f>(b: RcBlock<dyn Fn() + 'static>) -> RcBlock<dyn Fn() + 'f> {
b
}
// A block may call itself recursively. It reaches its own `RcBlock` through a
// shared `OnceCell` that is filled in right after the block is created.
#[test]
fn allow_re_entrancy() {
// The nested `Rc<OnceCell<RcBlock<dyn Fn ...>>>` type is deliberate for this test.
#[allow(clippy::type_complexity)]
let block: Rc<OnceCell<RcBlock<dyn Fn(u32) -> u32>>> = Rc::new(OnceCell::new());
let captured_block = block.clone();
// Naive recursive Fibonacci; each recursive step re-enters the block.
let fibonacci = move |n| {
let captured_fibonacci = captured_block.get().unwrap();
match n {
0 => 0,
1 => 1,
n => captured_fibonacci.call((n - 1,)) + captured_fibonacci.call((n - 2,)),
}
};
let block = block.get_or_init(|| RcBlock::new(fibonacci));
assert_eq!(block.call((0,)), 0);
assert_eq!(block.call((1,)), 1);
assert_eq!(block.call((6,)), 8);
assert_eq!(block.call((10,)), 55);
assert_eq!(block.call((19,)), 4181);
}
}