1#![cfg_attr(not(feature = "std"), no_std)]
58
59#[cfg(not(feature = "std"))]
60extern crate alloc;
61
62#[cfg(not(feature = "std"))]
63use alloc::{boxed::Box, rc::Rc, vec::Vec};
64#[cfg(feature = "std")]
65use std::rc::Rc;
66
/// Erases the `'a` bound of a boxed callback, returning it typed as `'static`.
///
/// Only the trait-object lifetime bound changes; the box itself is returned
/// as-is.
///
/// # Safety
///
/// The returned box may still capture data that is only valid for `'a`. The
/// caller must guarantee that it is neither called nor dropped once `'a` has
/// ended.
unsafe fn transmute_lifetime<'a, A: 'static, R: 'static>(
    value: Box<dyn FnMut(A) -> R + 'a>,
) -> Box<dyn FnMut(A) -> R + 'static> {
    // Source and target types are spelled out so the transmute cannot silently
    // start converting between unrelated types if the signature is edited.
    core::mem::transmute::<Box<dyn FnMut(A) -> R + 'a>, Box<dyn FnMut(A) -> R + 'static>>(value)
}
72
/// Erases the `'a` bound of a boxed local future, returning it typed as
/// `'static`.
///
/// # Safety
///
/// The returned future may still borrow data that is only valid for `'a`. The
/// caller must guarantee that it is neither polled nor dropped once `'a` has
/// ended.
#[cfg(feature = "async")]
unsafe fn transmute_future_lifetime<'a, T: 'static>(
    future: futures_util::future::LocalBoxFuture<'a, T>,
) -> futures_util::future::LocalBoxFuture<'static, T> {
    // Source and target types are spelled out so only the lifetime can differ.
    core::mem::transmute::<
        futures_util::future::LocalBoxFuture<'a, T>,
        futures_util::future::LocalBoxFuture<'static, T>,
    >(future)
}
79
/// A cleanup action that runs at most once.
///
/// The action is stored in a `RefCell<Option<..>>` so it can be triggered
/// through a shared reference (see [`Deregister::force`]); whatever has not
/// been triggered explicitly runs when the value is dropped.
struct Deregister<'a>(core::cell::RefCell<Option<Box<dyn FnOnce() + 'a>>>);

impl<'a> Deregister<'a> {
    /// Wraps `f` so that it will be executed exactly once: on the first call
    /// to [`Deregister::force`], or on drop if `force` was never called.
    fn new(f: Box<dyn FnOnce() + 'a>) -> Self {
        Self(core::cell::RefCell::new(Some(f)))
    }

    /// Runs the cleanup action if it has not run yet; later calls do nothing.
    fn force(&self) {
        // Take the action out in a statement of its own. Writing
        // `if let Some(f) = self.0.borrow_mut().take()` would keep the
        // `RefMut` temporary alive for the whole `if let` body, i.e. the cell
        // would still be mutably borrowed while the action runs, and an action
        // that (indirectly) calls `force` on this same value would panic with
        // a `BorrowMutError` instead of observing "already run".
        let action = self.0.borrow_mut().take();
        if let Some(f) = action {
            f();
        }
    }
}

impl<'a> Drop for Deregister<'a> {
    /// Runs the cleanup action if nobody forced it before.
    fn drop(&mut self) {
        self.force();
    }
}
99
100pub struct Registered<'env, 'scope> {
103    deregister: Rc<Deregister<'env>>,
104    marker: core::marker::PhantomData<&'scope ()>,
105}
106
107impl<'env, 'scope> Drop for Registered<'env, 'scope> {
108    fn drop(&mut self) {
109        self.deregister.force()
110    }
111}
112
113pub struct Scope<'env> {
116    callbacks: core::cell::RefCell<Vec<Rc<Deregister<'env>>>>,
117    marker: core::marker::PhantomData<&'env mut &'env ()>,
118}
119
120impl<'env> Scope<'env> {
121    fn new() -> Self {
122        Self {
123            callbacks: core::cell::RefCell::new(Vec::new()),
124            marker: core::marker::PhantomData,
125        }
126    }
127
128    pub fn register<'scope, A: 'static, R: 'static, H: 'static>(
138        &'scope self,
139        c: impl (FnMut(A) -> R) + 'env,
140        register: impl FnOnce(Box<dyn FnMut(A) -> R>) -> H + 'env,
141        deregister: impl FnOnce(H) + 'env,
142    ) -> Registered<'env, 'scope> {
143        let c = unsafe { transmute_lifetime(Box::new(c)) };
144        let c = Rc::new(core::cell::RefCell::new(Some(c)));
145        let handle = {
146            let c = c.clone();
147            register(Box::new(move |arg| {
148                (c.as_ref()
149                    .borrow_mut()
150                    .as_mut()
151                    .expect("Callback used after scope is unsafe"))(arg)
152            }))
153        };
154        let deregister = Rc::new(Deregister::new(Box::new(move || {
155            deregister(handle);
156            c.as_ref().borrow_mut().take();
157        })));
158        self.callbacks.borrow_mut().push(deregister.clone());
159        Registered {
160            deregister,
161            marker: core::marker::PhantomData,
162        }
163    }
164
165    #[cfg(feature = "async")]
168    pub fn future<'scope>(
169        &'scope self,
170        future: futures_util::future::LocalBoxFuture<'env, ()>,
171    ) -> impl futures_util::future::Future<Output = ()> + 'static {
172        use std::{cell::RefCell, pin::Pin};
173
174        let future = unsafe { transmute_future_lifetime(future) };
175        let future = Rc::new(RefCell::new(Some(future)));
176        self.callbacks
177            .borrow_mut()
178            .push(Rc::new(Deregister::new(Box::new({
179                let future = future.clone();
180                move || {
181                    future.as_ref().borrow_mut().take();
182                }
183            }))));
184
185        use futures_util::{
186            future::{Future, LocalBoxFuture},
187            task::{Context, Poll},
188        };
189        struct StaticFuture(Rc<RefCell<Option<LocalBoxFuture<'static, ()>>>>);
190
191        impl Future for StaticFuture {
192            type Output = ();
193
194            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
195                if let Some(future) = self.0.borrow_mut().as_mut() {
196                    Future::poll(future.as_mut(), cx)
197                } else {
198                    panic!("Future used after scope is unsafe")
199                }
200            }
201        }
202        StaticFuture(future)
203    }
204}
205
206impl<'env> Drop for Scope<'env> {
207    fn drop(&mut self) {
208        self.callbacks
209            .borrow()
210            .iter()
211            .for_each(|deregister| deregister.force());
212    }
213}
214
215pub fn scope<'env, R>(f: impl FnOnce(&Scope<'env>) -> R) -> R {
218    f(&Scope::<'env>::new())
219}
220
/// Async variant of [`scope`] for closures returning a `BoxFuture`.
///
/// A fresh [`Scope`] is created, the future produced by `f` is awaited, and
/// the scope is dropped afterwards (or when the returned future is dropped
/// before completion).
#[cfg(feature = "async")]
pub async fn scope_async<'env, R>(
    f: impl for<'r> FnOnce(&'r Scope<'env>) -> futures_util::future::BoxFuture<'r, R>,
) -> R {
    let active = Scope::<'env>::new();
    f(&active).await
}
234
/// Async variant of [`scope`] for closures returning a `LocalBoxFuture`.
///
/// A fresh [`Scope`] is created, the future produced by `f` is awaited, and
/// the scope is dropped afterwards (or when the returned future is dropped
/// before completion).
#[cfg(feature = "async")]
pub async fn scope_async_local<'env, R>(
    f: impl for<'r> FnOnce(&'r Scope<'env>) -> futures_util::future::LocalBoxFuture<'r, R>,
) -> R {
    let active = Scope::<'env>::new();
    f(&active).await
}
242
#[cfg(test)]
mod tests {
    use super::*;

    // "Registers" a callback by handing it straight back as the handle.
    fn register(callback: Box<dyn FnMut(i32)>) -> Box<dyn FnMut(i32)> {
        callback
    }

    // "Deregisters" by simply dropping the handle.
    fn deregister(_callback: Box<dyn FnMut(i32)>) {}

    // A callback borrowing a local can be registered and dropped again.
    #[test]
    fn it_works() {
        let a = 42;
        scope(|scope| {
            let guard = scope.register(
                |_| {
                    let _b = a * a;
                },
                register,
                deregister,
            );

            drop(guard);
        });
    }

    // The trampoline handed to `register` forwards to the scoped callback.
    #[test]
    fn calling() {
        let stored = Rc::new(core::cell::RefCell::new(None));
        scope(|scope| {
            let guard = scope.register(
                |a| 2 * a,
                |callback| {
                    stored.as_ref().borrow_mut().replace(callback);
                },
                |_| {},
            );

            assert_eq!((stored.as_ref().borrow_mut().as_mut().unwrap())(42), 2 * 42);

            drop(guard);
        });
    }

    // Dropping the guard runs `deregister` immediately.
    #[test]
    fn drop_registered_causes_deregister() {
        let deregistered = Rc::new(core::cell::Cell::new(false));
        scope(|scope| {
            let guard = scope.register(|_| {}, register, {
                let deregistered = deregistered.clone();
                move |_| deregistered.as_ref().set(true)
            });

            drop(guard);
            assert!(deregistered.as_ref().get());
        });
    }

    // A leaked guard is cleaned up when the scope ends, not earlier.
    #[test]
    fn leaving_scope_causes_deregister() {
        let deregistered = Rc::new(core::cell::Cell::new(false));
        scope(|scope| {
            let guard = scope.register(|_| {}, register, {
                let deregistered = deregistered.clone();
                move |_| deregistered.as_ref().set(true)
            });

            core::mem::forget(guard);
            assert!(!deregistered.as_ref().get());
        });
        assert!(deregistered.as_ref().get());
    }

    // Calling the trampoline after its guard was dropped must panic.
    #[test]
    #[cfg(feature = "std")]
    fn calling_static_callback_after_drop_panics() {
        let outcome = std::panic::catch_unwind(|| {
            let stored = Rc::new(core::cell::RefCell::new(None));
            scope(|scope| {
                let guard = scope.register(
                    |_| {},
                    |callback| {
                        stored.as_ref().borrow_mut().replace(callback);
                    },
                    |_| {},
                );

                drop(guard);
                (stored.as_ref().borrow_mut().as_mut().unwrap())(42);
            });
        });
        assert!(outcome.is_err());
    }

    // Calling the trampoline after the scope ended must panic, even though the
    // guard was leaked.
    #[test]
    #[cfg(feature = "std")]
    fn calling_static_callback_after_scope_panics() {
        let outcome = std::panic::catch_unwind(|| {
            let stored = Rc::new(core::cell::RefCell::new(None));
            scope(|scope| {
                let guard = scope.register(
                    |_| {},
                    |callback| {
                        stored.as_ref().borrow_mut().replace(callback);
                    },
                    |_| {},
                );

                core::mem::forget(guard);
            });
            (stored.as_ref().borrow_mut().as_mut().unwrap())(42);
        });
        assert!(outcome.is_err());
    }

    // A panic inside the scope still invalidates a leaked registration.
    #[test]
    #[cfg(feature = "std")]
    fn panic_in_scoped_is_safe() {
        let stored = std::sync::Mutex::new(None);

        let scoped = std::panic::catch_unwind(|| {
            scope(|scope| {
                let guard = scope.register(
                    |_| {},
                    |callback| {
                        stored.lock().unwrap().replace(callback);
                    },
                    |_| {},
                );

                core::mem::forget(guard);
                panic!()
            });
        });
        assert!(scoped.is_err());

        let late_call = std::panic::catch_unwind(|| {
            (stored.lock().unwrap().as_mut().unwrap())(42);
        });
        assert!(late_call.is_err());
    }
}