#![cfg_attr(not(feature = "std"), no_std)]
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, rc::Rc, vec::Vec};
#[cfg(feature = "std")]
use std::rc::Rc;
/// Erases the lifetime bound of a boxed callback, producing a `'static` trait object.
///
/// # Safety
///
/// The returned box still borrows whatever the original callback borrowed for `'a`.
/// The caller must guarantee that the returned value is never called (and not dropped
/// in a way that touches borrowed data) after `'a` has ended, i.e. that every path to it
/// is cut before `'a` ends.
unsafe fn transmute_lifetime<'a, A: 'static, R: 'static>(
    value: Box<dyn FnMut(A) -> R + 'a>,
) -> Box<dyn FnMut(A) -> R + 'static> {
    // SAFETY: source and target are the same boxed trait object and differ only in the
    // object's lifetime bound; keeping that bound honest is the caller's obligation
    // (see `# Safety`). Explicit type arguments pin both sides of the transmute.
    unsafe {
        core::mem::transmute::<Box<dyn FnMut(A) -> R + 'a>, Box<dyn FnMut(A) -> R + 'static>>(
            value,
        )
    }
}
/// Erases the lifetime of a boxed local future, producing a `'static` future.
///
/// # Safety
///
/// The returned future still borrows whatever the original borrowed for `'a`. The
/// caller must guarantee that it is never polled (and not dropped in a way that touches
/// borrowed data) after `'a` has ended.
#[cfg(feature = "async")]
unsafe fn transmute_future_lifetime<'a, T: 'static>(
    future: futures_util::future::LocalBoxFuture<'a, T>,
) -> futures_util::future::LocalBoxFuture<'static, T> {
    // SAFETY: both types are the same pinned boxed trait object and differ only in the
    // lifetime bound; upholding it is the caller's obligation (see `# Safety`).
    unsafe {
        core::mem::transmute::<
            futures_util::future::LocalBoxFuture<'a, T>,
            futures_util::future::LocalBoxFuture<'static, T>,
        >(future)
    }
}
/// A cleanup action that runs at most once.
///
/// The closure runs on the first call to [`Deregister::force`]; every later call is a
/// no-op.
struct Deregister<'a>(core::cell::RefCell<Option<Box<dyn FnOnce() + 'a>>>);

impl<'a> Deregister<'a> {
    /// Wraps `f` so that it runs at most once.
    fn new(f: Box<dyn FnOnce() + 'a>) -> Self {
        Self(core::cell::RefCell::new(Some(f)))
    }

    /// Runs the pending closure if it has not run yet; otherwise does nothing.
    ///
    /// The closure is moved out of the cell and the borrow is released *before* the
    /// closure is called. A closure that (directly or indirectly) calls `force` on this
    /// same value therefore sees an empty cell instead of hitting a `RefCell`
    /// "already borrowed" panic.
    fn force(&self) {
        let pending = self.0.borrow_mut().take();
        if let Some(f) = pending {
            f();
        }
    }
}
impl<'a> Drop for Deregister<'a> {
    /// Runs the pending closure, if it has not been run yet, as this value goes away.
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so the cell can be accessed without a
        // runtime borrow.
        if let Some(pending) = self.0.get_mut().take() {
            pending();
        }
    }
}
/// Guard returned by [`Scope::register`].
///
/// Dropping the guard deregisters the callback immediately. If the guard is leaked
/// (for example with `core::mem::forget`), deregistration still happens when the owning
/// [`Scope`] is dropped.
#[must_use = "dropping a `Registered` immediately deregisters the callback"]
pub struct Registered<'env, 'scope> {
    /// Deregistration shared with the scope's pending list; whichever side forces it
    /// first runs it, the other side then finds it empty.
    deregister: Rc<Deregister<'env>>,
    /// Ties this guard to the borrow of the scope it was created from.
    marker: core::marker::PhantomData<&'scope ()>,
}
impl<'env, 'scope> Drop for Registered<'env, 'scope> {
fn drop(&mut self) {
self.deregister.force()
}
}
/// Region within which callbacks borrowing `'env` data can be handed to APIs that
/// require `'static` callbacks.
///
/// Obtained through [`scope`]. Every registration made through a scope is deregistered
/// when the scope is dropped, if it was not deregistered earlier.
pub struct Scope<'env> {
// Pending deregistrations, forced in registration order when the scope is dropped.
callbacks: core::cell::RefCell<Vec<Rc<Deregister<'env>>>>,
// `&'env mut &'env ()` makes `Scope` invariant in `'env`, so the compiler cannot
// shorten or lengthen `'env` to make a borrow fit.
marker: core::marker::PhantomData<&'env mut &'env ()>,
}
impl<'env> Scope<'env> {
    /// Creates a scope with no pending deregistrations.
    fn new() -> Self {
        Self {
            callbacks: core::cell::RefCell::new(Vec::new()),
            marker: core::marker::PhantomData,
        }
    }

    /// Hands a callback that borrows from `'env` to an API that only accepts `'static`
    /// callbacks.
    ///
    /// `c` is boxed and its lifetime is erased. A `'static` trampoline forwarding to it is
    /// passed to `register`. Whatever `register` returns is later given to `deregister`,
    /// either when the returned [`Registered`] is dropped or when this scope is dropped,
    /// whichever comes first. From then on the trampoline no longer reaches `c`.
    ///
    /// # Panics
    ///
    /// The trampoline panics with "Callback used after scope is unsafe" when invoked
    /// after deregistration. It panics with a `RefCell` borrow error when invoked
    /// re-entrantly from inside `c`.
    ///
    /// If `register` or `deregister` panics, `c` is still detached from the trampoline.
    /// A trampoline that `register` already stored elsewhere therefore cannot reach
    /// `'env` data after unwinding.
    pub fn register<'scope, A: 'static, R: 'static, H: 'static>(
        &'scope self,
        c: impl (FnMut(A) -> R) + 'env,
        register: impl FnOnce(Box<dyn FnMut(A) -> R>) -> H + 'env,
        deregister: impl FnOnce(H) + 'env,
    ) -> Registered<'env, 'scope> {
        // SAFETY: the erased box is only reachable through the shared cell below. `clear`
        // empties that cell on every exit path: when this function unwinds, and when the
        // deregistration is forced. The deregistration is forced by `Registered::drop`
        // or, at the latest, by `Scope::drop`, which runs before `'env` ends unless the
        // scope itself is leaked.
        let c = unsafe { transmute_lifetime(Box::new(c)) };
        let c = Rc::new(core::cell::RefCell::new(Some(c)));

        // Armed before `register` runs. `Deregister`'s `Drop` empties the cell, so a panic
        // in `register` cannot leave a live `'env` callback behind a stored trampoline.
        let clear = Deregister::new(Box::new({
            let c = c.clone();
            move || {
                c.borrow_mut().take();
            }
        }));

        let handle = register(Box::new(move |arg| {
            (c.borrow_mut()
                .as_mut()
                .expect("Callback used after scope is unsafe"))(arg)
        }));

        let deregister = Rc::new(Deregister::new(Box::new(move || {
            deregister(handle);
            // If `deregister` unwinds instead, dropping `clear` performs the same cleanup.
            clear.force();
        })));
        self.callbacks.borrow_mut().push(deregister.clone());
        Registered {
            deregister,
            marker: core::marker::PhantomData,
        }
    }

    /// Ties a future borrowing `'env` to this scope and returns a `'static` handle that
    /// polls it.
    ///
    /// When the scope is dropped, the inner future is dropped too.
    ///
    /// # Panics
    ///
    /// Polling the handle after that point panics with "Future used after scope is unsafe".
    #[cfg(feature = "async")]
    pub fn future<'scope>(
        &'scope self,
        future: futures_util::future::LocalBoxFuture<'env, ()>,
    ) -> impl futures_util::future::Future<Output = ()> + 'static {
        // `core`, not `std`: the crate must also build with `async` but without `std`.
        use core::{cell::RefCell, pin::Pin};
        use futures_util::{
            future::{Future, LocalBoxFuture},
            task::{Context, Poll},
        };

        /// `'static` handle that forwards polls to the erased future while the scope
        /// still holds it.
        struct StaticFuture(Rc<RefCell<Option<LocalBoxFuture<'static, ()>>>>);

        impl Future for StaticFuture {
            type Output = ();
            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                if let Some(future) = self.0.borrow_mut().as_mut() {
                    Future::poll(future.as_mut(), cx)
                } else {
                    panic!("Future used after scope is unsafe")
                }
            }
        }

        // SAFETY: the erased future is only reachable through the shared cell. The
        // deregistration pushed below empties the cell when this scope is dropped, and
        // `StaticFuture` panics instead of polling an emptied cell.
        let future = unsafe { transmute_future_lifetime(future) };
        let future = Rc::new(RefCell::new(Some(future)));
        self.callbacks
            .borrow_mut()
            .push(Rc::new(Deregister::new(Box::new({
                let future = future.clone();
                move || {
                    future.borrow_mut().take();
                }
            }))));
        StaticFuture(future)
    }
}
impl<'env> Drop for Scope<'env> {
    /// Forces every pending deregistration, in the order the registrations were made.
    fn drop(&mut self) {
        // Exclusive access through `&mut self`: no runtime borrow of the list is needed.
        for pending in self.callbacks.get_mut().iter() {
            pending.force();
        }
    }
}
/// Runs `f` with a fresh [`Scope`] and returns its result.
///
/// The scope is dropped before this function returns, including when `f` panics and the
/// stack unwinds. Every callback registered through it is therefore deregistered by then.
pub fn scope<'env, R>(f: impl FnOnce(&Scope<'env>) -> R) -> R {
    let fresh = Scope::<'env>::new();
    f(&fresh)
}
/// Async counterpart of [`scope`]: awaits the future produced by `f` while a fresh
/// [`Scope`] is alive.
///
/// NOTE(review): the scope is held across the `.await`, so it lives inside the returned
/// future's state. Its deregistrations run only when that future is dropped. If the
/// future is leaked after being polled (e.g. `core::mem::forget` on a pinned box),
/// registered callbacks are never cleared even though `'env` may end. Confirm whether
/// this function should be `unsafe`.
///
/// NOTE(review): `Scope` contains `RefCell`/`Rc`, so `&Scope` is not `Send`. A
/// `BoxFuture` (which is `Send`) that captures the scope may not type-check; verify the
/// intended use.
#[cfg(feature = "async")]
pub async fn scope_async<'env, R>(
    f: impl for<'r> FnOnce(&'r Scope<'env>) -> futures_util::future::BoxFuture<'r, R>,
) -> R {
    let fresh = Scope::<'env>::new();
    f(&fresh).await
}
/// Async counterpart of [`scope`] for non-`Send` futures. It awaits the local future
/// produced by `f` while a fresh [`Scope`] is alive.
///
/// NOTE(review): the scope is held across the `.await`, so it lives inside the returned
/// future's state. Its deregistrations run only when that future is dropped. If the
/// future is leaked after being polled (e.g. `core::mem::forget` on a pinned box),
/// registered callbacks are never cleared even though `'env` may end. Confirm whether
/// this function should be `unsafe`.
#[cfg(feature = "async")]
pub async fn scope_async_local<'env, R>(
    f: impl for<'r> FnOnce(&'r Scope<'env>) -> futures_util::future::LocalBoxFuture<'r, R>,
) -> R {
    let fresh = Scope::<'env>::new();
    f(&fresh).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Registration hook whose "handle" is the trampoline itself.
    fn register(callback: Box<dyn FnMut(i32)>) -> Box<dyn FnMut(i32)> {
        callback
    }

    /// Deregistration hook that simply drops the handle it is given.
    fn deregister(_callback: Box<dyn FnMut(i32)>) {}

    /// A callback borrowing a stack local can be registered and deregistered.
    #[test]
    fn it_works() {
        let a = 42;
        scope(|scope| {
            let registered = scope.register(
                |_| {
                    let _b = a * a;
                },
                register,
                deregister,
            );
            drop(registered);
        });
    }

    /// The trampoline handed to `register` forwards arguments and return values.
    #[test]
    fn calling() {
        let stored = Rc::new(core::cell::RefCell::new(None));
        scope(|scope| {
            let registered = scope.register(
                |a| 2 * a,
                |callback| {
                    stored.borrow_mut().replace(callback);
                },
                |_| {},
            );
            assert_eq!((stored.borrow_mut().as_mut().unwrap())(42), 2 * 42);
            drop(registered);
        });
    }

    /// Dropping the `Registered` guard runs the user's deregistration immediately.
    #[test]
    fn drop_registered_causes_deregister() {
        let dropped = Rc::new(core::cell::Cell::new(false));
        scope(|scope| {
            let registered = scope.register(|_| {}, register, {
                let dropped = dropped.clone();
                move |_| dropped.set(true)
            });
            drop(registered);
            assert!(dropped.get());
        });
    }

    /// A leaked `Registered` guard is still deregistered when the scope ends.
    #[test]
    fn leaving_scope_causes_deregister() {
        let dropped = Rc::new(core::cell::Cell::new(false));
        scope(|scope| {
            let registered = scope.register(|_| {}, register, {
                let dropped = dropped.clone();
                move |_| dropped.set(true)
            });
            core::mem::forget(registered);
            assert!(!dropped.get());
        });
        assert!(dropped.get());
    }

    /// After explicit deregistration, a stored trampoline panics instead of calling the
    /// original callback.
    #[test]
    #[cfg(feature = "std")]
    fn calling_static_callback_after_drop_panics() {
        let res = std::panic::catch_unwind(|| {
            let stored = Rc::new(core::cell::RefCell::new(None));
            scope(|scope| {
                let registered = scope.register(
                    |_| {},
                    |callback| {
                        stored.borrow_mut().replace(callback);
                    },
                    |_| {},
                );
                drop(registered);
                (stored.borrow_mut().as_mut().unwrap())(42);
            });
        });
        assert!(res.is_err());
    }

    /// A trampoline that escapes the scope panics when called after the scope ended.
    #[test]
    #[cfg(feature = "std")]
    fn calling_static_callback_after_scope_panics() {
        let res = std::panic::catch_unwind(|| {
            let stored = Rc::new(core::cell::RefCell::new(None));
            scope(|scope| {
                let registered = scope.register(
                    |_| {},
                    |callback| {
                        stored.borrow_mut().replace(callback);
                    },
                    |_| {},
                );
                core::mem::forget(registered);
            });
            (stored.borrow_mut().as_mut().unwrap())(42);
        });
        assert!(res.is_err());
    }

    /// Unwinding out of the scope body still deregisters, so the escaped trampoline
    /// panics when called later.
    #[test]
    #[cfg(feature = "std")]
    fn panic_in_scoped_is_safe() {
        let stored = std::sync::Mutex::new(None);
        let res = std::panic::catch_unwind(|| {
            scope(|scope| {
                let registered = scope.register(
                    |_| {},
                    |callback| {
                        stored.lock().unwrap().replace(callback);
                    },
                    |_| {},
                );
                core::mem::forget(registered);
                panic!()
            });
        });
        assert!(res.is_err());
        let res = std::panic::catch_unwind(|| {
            // Calls the trampoline in place; it stays stored in the mutex.
            (stored.lock().unwrap().as_mut().unwrap())(42);
        });
        assert!(res.is_err());
    }
}