async_safe_defer/
lib.rs

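//! This library provides RAII-style deferred execution for synchronous code
//! (`defer` / `defer!`) and two async scopes that run deferred async tasks in
//! LIFO order once they are explicitly awaited: a heap-backed `AsyncScope`
//! (the default) and, behind the `no_alloc` feature, a fixed-capacity
//! `no_alloc::AsyncScopeNoAlloc` that keeps its deferred function pointers in
//! an inline array (each task's future is still boxed when it runs).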
4
5#![cfg_attr(not(test), no_std)]
6#![cfg_attr(docsrs, feature(doc_cfg))]
7
8extern crate alloc;
9use alloc::boxed::Box;
10use alloc::vec::Vec;
11use core::future::Future;
12use core::pin::Pin;
13
/// Returns a guard that calls `f` exactly once, at the moment the guard is dropped.
///
/// Bind the guard to a named variable (for example `let _guard = defer(..)`)
/// so that it lives until the end of the enclosing scope. Writing
/// `let _ = defer(..)` drops the guard at once, which runs `f` immediately.
#[must_use = "Defer must be stored in a variable to execute the closure"]
pub fn defer<F>(f: F) -> impl Drop
where
    F: FnOnce(),
{
    // Tuple wrapper around the pending closure; it becomes `None` once taken.
    struct Defer<F: FnOnce()>(Option<F>);

    impl<F: FnOnce()> Drop for Defer<F> {
        fn drop(&mut self) {
            // `take` moves the closure out of `&mut self` so it can be invoked by value.
            if let Some(callback) = self.0.take() {
                callback();
            }
        }
    }

    Defer(Some(f))
}
34
/// Defers the given code until the end of the enclosing scope.
///
/// Expands to a `let` binding that holds the guard returned by the `defer`
/// function, so the code runs when that scope is exited, including by an early
/// `return`. Accepts a single expression, `defer!(expr)`, or a sequence of
/// statements, `defer! { a(); b(); }`. The generated closure is not `move`, so
/// captured variables are borrowed. Several `defer!`s in one scope run in
/// reverse order of declaration, like any other locals being dropped.
#[macro_export]
macro_rules! defer {
    ($($body:tt)*) => {
        // A named binding keeps the guard alive until the caller's scope ends;
        // `let _ = ...` would drop it (and run the code) immediately.
        let _guard = $crate::defer(|| { $($body)* });
    };
}
43
/// An async-aware scope that collects deferred async tasks on the heap and
/// runs them, most recently registered first, when [`AsyncScope::run`] is awaited.
///
/// Unlike the synchronous `defer` guard, this scope is not triggered by `Drop`:
/// dropping an `AsyncScope` without awaiting `run` discards every pending task
/// without running it. The stored closures and futures carry no `Send` bound,
/// so neither the scope nor the future returned by `run` is `Send`.
pub struct AsyncScope {
    defer: Vec<Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + 'static>> + 'static>>,
}

impl AsyncScope {
    /// Creates an empty `AsyncScope` for collecting async deferred tasks.
    pub fn new() -> Self {
        AsyncScope { defer: Vec::new() }
    }

    /// Registers a task to be executed later by [`AsyncScope::run`] (LIFO).
    ///
    /// `f` is only called when `run` reaches it, and the future it returns is
    /// awaited to completion before the next task is started.
    pub fn defer<F>(&mut self, f: F)
    where
        F: FnOnce() -> Pin<Box<dyn Future<Output = ()> + 'static>> + 'static,
    {
        // `f` already yields a pinned, boxed future, so store it as-is instead of
        // re-boxing its output (which cost a second allocation per task).
        self.defer.push(Box::new(f));
    }

    /// Runs all stored tasks in reverse registration order, awaiting each one
    /// before starting the next.
    pub async fn run(mut self) {
        while let Some(task) = self.defer.pop() {
            task().await;
        }
    }
}

impl Default for AsyncScope {
    /// Equivalent to [`AsyncScope::new`].
    fn default() -> Self {
        Self::new()
    }
}
70
/// Creates a future that runs `$body` with a fresh `AsyncScope` bound to
/// `$scope`, then awaits every task the body deferred (last registered first).
///
/// The body runs inside its own nested `async` block, so a `return` inside the
/// body only leaves that block and the deferred tasks still run afterwards.
/// A panic in the body, or dropping the returned future before it completes,
/// still skips them.
#[macro_export]
macro_rules! async_scope {
    ($scope:ident, $body:block) => {
        async {
            let mut $scope = $crate::AsyncScope::new();
            // Nested block: an early `return` in `$body` exits only this inner
            // future, so the `run` below is always reached.
            async { $body }.await;
            $scope.run().await;
        }
    };
}
82
/// A fixed-capacity async scope whose task storage is an inline array.
///
/// Only compiled if `feature = "no_alloc"` is enabled, when tests run, or
/// under `docsrs`. Appears in docs if `docsrs` or the feature is active.
///
/// The scope itself never allocates for its storage, but every [`DeferredFn`]
/// returns a `Pin<Box<..>>`, so each task's future is heap-allocated when it is
/// run. This module therefore still depends on the `alloc` crate.
#[cfg_attr(docsrs, doc(cfg(feature = "no_alloc")))]
#[cfg(any(feature = "no_alloc", test, docsrs))]
pub mod no_alloc {
    use alloc::boxed::Box;
    use core::{future::Future, pin::Pin};

    /// A plain function pointer (it cannot capture state) that produces the
    /// pinned, boxed future to await for one deferred task.
    pub type DeferredFn = fn() -> Pin<Box<dyn Future<Output = ()> + 'static>>;

    /// A fixed-capacity async scope that stores up to `N` deferred tasks inline.
    ///
    /// Tasks run in LIFO order when [`run`](Self::run) is awaited; dropping the
    /// scope without calling `run` discards any pending tasks.
    pub struct AsyncScopeNoAlloc<const N: usize> {
        // Invariant: slots `0..len` are `Some`, slots `len..N` are `None`.
        tasks: [Option<DeferredFn>; N],
        len: usize,
    }

    impl<const N: usize> AsyncScopeNoAlloc<N> {
        /// Creates a new, empty `AsyncScopeNoAlloc` with capacity `N`.
        pub const fn new() -> Self {
            Self {
                tasks: [None; N],
                len: 0,
            }
        }

        /// Returns the number of tasks currently registered.
        pub const fn len(&self) -> usize {
            self.len
        }

        /// Returns `true` if no tasks are registered.
        pub const fn is_empty(&self) -> bool {
            self.len == 0
        }

        /// Returns the maximum number of tasks this scope can hold (`N`).
        pub const fn capacity(&self) -> usize {
            N
        }

        /// Registers `f` to be called later, without panicking.
        ///
        /// # Errors
        ///
        /// Returns `Err(f)`, handing the function pointer back, if the scope
        /// already holds `N` tasks.
        pub fn try_defer(&mut self, f: DeferredFn) -> Result<(), DeferredFn> {
            match self.tasks.get_mut(self.len) {
                Some(slot) => {
                    *slot = Some(f);
                    self.len += 1;
                    Ok(())
                }
                None => Err(f),
            }
        }

        /// Registers a `'static` function pointer to be called later.
        ///
        /// # Panics
        ///
        /// Panics if capacity is exceeded; use [`try_defer`](Self::try_defer)
        /// to handle a full scope without panicking.
        #[track_caller]
        pub fn defer(&mut self, f: DeferredFn) {
            if self.try_defer(f).is_err() {
                panic!("No space left for more tasks.");
            }
        }

        /// Executes all tasks in reverse registration order, awaiting each one.
        ///
        /// Leaves the scope empty and reusable. If the returned future is dropped
        /// part-way, the task being awaited is lost (its slot was already
        /// cleared) while the tasks not yet started stay registered.
        pub async fn run(&mut self) {
            while self.len > 0 {
                self.len -= 1;
                let task = self.tasks[self.len]
                    .take()
                    .expect("slots below `len` are always occupied");
                task().await;
            }
        }
    }

    impl<const N: usize> Default for AsyncScopeNoAlloc<N> {
        /// Equivalent to [`AsyncScopeNoAlloc::new`].
        fn default() -> Self {
            Self::new()
        }
    }

    /// Creates a future that runs `$body` with a fresh `AsyncScopeNoAlloc` of
    /// capacity `$cap` bound to `$scope`, then awaits every deferred task.
    ///
    /// `$cap` must be a constant expression. An early `return` in the body
    /// exits only a nested `async` block, so the deferred tasks still run;
    /// a panic, or dropping the returned future early, skips them.
    #[cfg_attr(docsrs, doc(cfg(feature = "no_alloc")))]
    #[macro_export]
    macro_rules! no_alloc_async_scope {
        ($scope:ident : $cap:expr, $body:block) => {
            async {
                // Braces make any constant expression a valid const generic argument.
                let mut $scope = $crate::no_alloc::AsyncScopeNoAlloc::<{ $cap }>::new();
                async { $body }.await;
                $scope.run().await;
            }
        };
    }
}
145
#[cfg(test)]
mod tests {
    extern crate std;
    use self::std::sync::{
        Arc, Mutex,
        atomic::{AtomicUsize, Ordering},
    };
    use super::*;

    #[test]
    fn test_sync_defer() {
        let val = Arc::new(AtomicUsize::new(0));
        {
            let v = val.clone();
            defer!(v.store(42, Ordering::SeqCst));
            // The deferred store must not happen while the guard is still alive.
            assert_eq!(val.load(Ordering::SeqCst), 0);
        }
        assert_eq!(val.load(Ordering::SeqCst), 42);
    }

    #[tokio::test]
    async fn test_async_scope_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut scope = AsyncScope::new();

        for id in [1, 2] {
            let log = Arc::clone(&log);
            scope.defer(move || {
                Box::pin(async move {
                    log.lock().unwrap().push(id);
                })
            });
        }

        // Nothing runs until the scope is explicitly awaited.
        assert!(log.lock().unwrap().is_empty());
        scope.run().await;

        // Tasks run last-registered first.
        assert_eq!(*log.lock().unwrap(), vec![2, 1]);
    }

    #[tokio::test]
    async fn test_async_scope_macro() {
        use crate::async_scope;
        let flag = Arc::new(AtomicUsize::new(0));
        let f = Arc::clone(&flag);
        async_scope!(scope, {
            let f2 = Arc::clone(&f);
            scope.defer(move || {
                Box::pin(async move {
                    f2.store(1, Ordering::SeqCst);
                })
            });
            // Deferred tasks only run after the body has finished.
            assert_eq!(f.load(Ordering::SeqCst), 0);
        })
        .await;
        assert_eq!(flag.load(Ordering::SeqCst), 1);
    }

    // The `no_alloc` module is compiled under `cfg(test)` regardless of the
    // feature, so this test runs in every `cargo test` invocation.
    #[tokio::test]
    async fn test_no_alloc_scope() {
        use super::no_alloc::{AsyncScopeNoAlloc, DeferredFn};
        use core::future::Future;
        use core::pin::Pin;

        // Plain function pointers cannot capture state, so record order in a static.
        static ORDER: Mutex<Vec<u8>> = Mutex::new(Vec::new());

        fn task_one() -> Pin<Box<dyn Future<Output = ()> + 'static>> {
            Box::pin(async {
                ORDER.lock().unwrap().push(1);
            })
        }
        fn task_two() -> Pin<Box<dyn Future<Output = ()> + 'static>> {
            Box::pin(async {
                ORDER.lock().unwrap().push(2);
            })
        }

        let mut scope = AsyncScopeNoAlloc::<2>::new();
        scope.defer(task_one as DeferredFn);
        scope.defer(task_two as DeferredFn);
        scope.run().await;

        assert_eq!(*ORDER.lock().unwrap(), vec![2, 1]);
    }
}