Skip to main content

async_pool/
lib.rs

//! Statically allocated pool providing a std-like Box,
//! allowing tasks to asynchronously wait for a pool slot to become available.
3//!
4//! It is tailored to be used with no-std async runtimes, like [Embassy](https://embassy.dev/), but
5//! can also be used in std environments (check examples).
6//!
//! The most common use case is sharing large memory regions on constrained
//! devices (e.g. microcontrollers), where multiple tasks may need to use the
//! memory for buffering I/O or performing calculations, and having
//! separate static buffers would be too costly.
11//!
//! It is important to know that waiting forever for a memory slot to become
//! available can deadlock your code if done carelessly. With that in mind,
//! you should consider using a timeout when allocating asynchronously (e.g. [embassy_time::with_timeout](https://docs.rs/embassy-time/0.3.2/embassy_time/fn.with_timeout.html)).
15//!
16//! #### Dependencies
17//!
18//! This crate requires a critical section implementation. Check [critical-section](https://crates.io/crates/critical-section).
19//!
20//! #### Example
21//!
22//! ```
23//! use async_pool::{pool, Box};
24//!
25//!struct Buffer([u8; 256]);
26//!
//!// A maximum of 2 Buffer instances can be allocated at a time.
//!// A maximum of 3 futures can be waiting at a time.
29//!pool!(BufferPool: [Buffer; 2], 3);
30//!
31//!async fn run() {
32//!    // Allocate non-blocking (will return None if no data slot is available)
33//!    let box1 = Box::<BufferPool>::new(Buffer([0; 256]));
34//!
//!    // Allocate asynchronously (will wait if no data slot is available)
//!    // This can return None if all waker slots are taken
37//!    let box2 = Box::<BufferPool>::new_async(Buffer([0; 256])).await;
38//!}
39//! ```
40#![cfg_attr(not(test), no_std)]
41
42mod atomic_bitset;
43
44use core::cell::UnsafeCell;
45use core::future::{poll_fn, Future};
46use core::hash::{Hash, Hasher};
47use core::mem::MaybeUninit;
48use core::ops::{Deref, DerefMut};
49use core::task::Poll;
50use core::{cmp, mem, ptr::NonNull};
51use embassy_sync::waitqueue::AtomicWaker;
52use portable_atomic::AtomicU32;
53
54use crate::atomic_bitset::AtomicBitset;
55
/// Implementation detail. Not covered by semver guarantees.
///
/// Backing storage of a [`Pool`]: hands out pointers to item slots and takes
/// them back.
#[doc(hidden)]
pub trait PoolStorage<T> {
    /// Returns a pointer to a free (uninitialized) slot, or `None` if every
    /// slot is in use.
    fn alloc(&self) -> Option<NonNull<T>>;

    /// Waits until a slot is free and returns a pointer to it.
    ///
    /// Resolves to `None` when the implementation is unable to wait (for
    /// example because it has no room left to register the task's waker).
    fn alloc_async(&self) -> impl Future<Output = Option<NonNull<T>>>;

    /// Returns the slot `p` points to back to the storage.
    ///
    /// # Safety
    ///
    /// `p` must have been returned by `alloc` or `alloc_async` on this same
    /// storage and must not have been freed since. `free` does not drop the
    /// slot's contents: any value stored there must already have been dropped
    /// or moved out by the caller.
    unsafe fn free(&self, p: NonNull<T>);
}
63
64/// Implementation detail. Not covered by semver guarantees.
65#[doc(hidden)]
66pub struct PoolStorageImpl<T, const N: usize, const K: usize, const WN: usize, const WK: usize>
67where
68    [AtomicU32; K]: Sized,
69    [AtomicU32; WK]: Sized,
70{
71    used: AtomicBitset<N, K>,
72    data: [UnsafeCell<MaybeUninit<T>>; N],
73    wakers_used: AtomicBitset<WN, WK>,
74    wakers: [AtomicWaker; WN],
75}
76
77unsafe impl<T, const N: usize, const K: usize, const WN: usize, const WK: usize> Send
78    for PoolStorageImpl<T, N, K, WN, WK>
79{
80}
81unsafe impl<T, const N: usize, const K: usize, const WN: usize, const WK: usize> Sync
82    for PoolStorageImpl<T, N, K, WN, WK>
83{
84}
85
86impl<T, const N: usize, const K: usize, const WN: usize, const WK: usize> Default
87    for PoolStorageImpl<T, N, K, WN, WK>
88where
89    [AtomicU32; K]: Sized,
90    [AtomicU32; WK]: Sized,
91{
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl<T, const N: usize, const K: usize, const WN: usize, const WK: usize>
98    PoolStorageImpl<T, N, K, WN, WK>
99where
100    [AtomicU32; K]: Sized,
101    [AtomicU32; WK]: Sized,
102{
103    const UNINIT: UnsafeCell<MaybeUninit<T>> = UnsafeCell::new(MaybeUninit::uninit());
104
105    const WAKER: AtomicWaker = AtomicWaker::new();
106
107    pub const fn new() -> Self {
108        Self {
109            used: AtomicBitset::new(),
110            data: [Self::UNINIT; N],
111            wakers_used: AtomicBitset::new(),
112            wakers: [Self::WAKER; WN],
113        }
114    }
115}
116
117impl<T, const N: usize, const K: usize, const WN: usize, const WK: usize> PoolStorage<T>
118    for PoolStorageImpl<T, N, K, WN, WK>
119where
120    [AtomicU32; K]: Sized,
121    [AtomicU32; WK]: Sized,
122{
123    /// Returns an item from the data pool, if available.
124    /// Returns None if the data pool is full.
125    fn alloc(&self) -> Option<NonNull<T>> {
126        let n = self.used.alloc()?;
127        let p = self.data[n].get() as *mut T;
128        Some(unsafe { NonNull::new_unchecked(p) })
129    }
130
131    /// Wait until an item is available in the data pool, then return it.
132    /// Returns None if the waker pool is full.
133    fn alloc_async(&self) -> impl Future<Output = Option<NonNull<T>>> {
134        let mut waker_slot = None;
135        poll_fn(move |cx| {
136            // Check if there is a free slot in the data pool
137            if let Some(n) = self.used.alloc() {
138                let p = self.data[n].get() as *mut T;
139                return Poll::Ready(Some(unsafe { NonNull::new_unchecked(p) }));
140            }
141
142            // Try to allocate a waker slot if necessary
143            if waker_slot.is_none() {
144                waker_slot = self.wakers_used.alloc_droppable();
145            }
146
147            match &waker_slot {
148                Some(bit) => {
149                    self.wakers[bit.inner()].register(cx.waker());
150                    Poll::Pending
151                }
152                None => Poll::Ready(None), // No waker slots available
153            }
154        })
155    }
156
157    /// safety: p must be a pointer obtained from self.alloc that hasn't been freed yet.
158    unsafe fn free(&self, p: NonNull<T>) {
159        let origin = self.data.as_ptr() as *mut T;
160        let n = p.as_ptr().offset_from(origin);
161        assert!(n >= 0);
162        assert!((n as usize) < N);
163        self.used.free(n as usize);
164
165        // Wake up any wakers waiting for a slot
166        for waker in self.wakers.iter() {
167            waker.wake();
168        }
169    }
170}
171
172pub trait Pool: 'static {
173    type Item: 'static;
174
175    /// Implementation detail. Not covered by semver guarantees.
176    #[doc(hidden)]
177    type Storage: PoolStorage<Self::Item>;
178
179    /// Implementation detail. Not covered by semver guarantees.
180    #[doc(hidden)]
181    fn get() -> &'static Self::Storage;
182}
183
184pub struct Box<P: Pool> {
185    ptr: NonNull<P::Item>,
186}
187
188impl<P: Pool> Box<P> {
189    /// Returns an item from the data pool, if available.
190    /// Returns None if the data pool is full.
191    pub fn new(item: P::Item) -> Option<Self> {
192        let p = P::get().alloc()?;
193        unsafe { p.as_ptr().write(item) };
194        Some(Self { ptr: p })
195    }
196
197    /// Wait until an item is available in the data pool, then return it.
198    /// Returns None if the waker pool is full.
199    pub async fn new_async(item: P::Item) -> Option<Self> {
200        let p = match P::get().alloc_async().await {
201            Some(p) => p,
202            None => return None,
203        };
204        unsafe { p.as_ptr().write(item) };
205        Some(Self { ptr: p })
206    }
207
208    pub fn into_raw(b: Self) -> NonNull<P::Item> {
209        let res = b.ptr;
210        mem::forget(b);
211        res
212    }
213
214    /// # Safety
215    ///
216    /// The caller must ensure the pointer is valid and that it will live long enough.
217    pub unsafe fn from_raw(ptr: NonNull<P::Item>) -> Self {
218        Self { ptr }
219    }
220}
221
222impl<P: Pool> Drop for Box<P> {
223    fn drop(&mut self) {
224        unsafe {
225            //trace!("dropping {:u32}", self.ptr as u32);
226            self.ptr.as_ptr().drop_in_place();
227            P::get().free(self.ptr);
228        };
229    }
230}
231
232unsafe impl<P: Pool> Send for Box<P> where P::Item: Send {}
233
234unsafe impl<P: Pool> Sync for Box<P> where P::Item: Sync {}
235
236unsafe impl<P: Pool> stable_deref_trait::StableDeref for Box<P> {}
237
238impl<P: Pool> as_slice_01::AsSlice for Box<P>
239where
240    P::Item: as_slice_01::AsSlice,
241{
242    type Element = <P::Item as as_slice_01::AsSlice>::Element;
243
244    fn as_slice(&self) -> &[Self::Element] {
245        self.deref().as_slice()
246    }
247}
248
249impl<P: Pool> as_slice_01::AsMutSlice for Box<P>
250where
251    P::Item: as_slice_01::AsMutSlice,
252{
253    fn as_mut_slice(&mut self) -> &mut [Self::Element] {
254        self.deref_mut().as_mut_slice()
255    }
256}
257
258impl<P: Pool> as_slice_02::AsSlice for Box<P>
259where
260    P::Item: as_slice_02::AsSlice,
261{
262    type Element = <P::Item as as_slice_02::AsSlice>::Element;
263
264    fn as_slice(&self) -> &[Self::Element] {
265        self.deref().as_slice()
266    }
267}
268
269impl<P: Pool> as_slice_02::AsMutSlice for Box<P>
270where
271    P::Item: as_slice_02::AsMutSlice,
272{
273    fn as_mut_slice(&mut self) -> &mut [Self::Element] {
274        self.deref_mut().as_mut_slice()
275    }
276}
277
278impl<P: Pool> Deref for Box<P> {
279    type Target = P::Item;
280
281    fn deref(&self) -> &P::Item {
282        unsafe { self.ptr.as_ref() }
283    }
284}
285
286impl<P: Pool> DerefMut for Box<P> {
287    fn deref_mut(&mut self) -> &mut P::Item {
288        unsafe { self.ptr.as_mut() }
289    }
290}
291
292impl<P: Pool> core::fmt::Debug for Box<P>
293where
294    P::Item: core::fmt::Debug,
295{
296    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
297        <P::Item as core::fmt::Debug>::fmt(self, f)
298    }
299}
300
301impl<P: Pool> core::fmt::Display for Box<P>
302where
303    P::Item: core::fmt::Display,
304{
305    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
306        <P::Item as core::fmt::Display>::fmt(self, f)
307    }
308}
309
310impl<P: Pool> PartialEq for Box<P>
311where
312    P::Item: PartialEq,
313{
314    fn eq(&self, rhs: &Box<P>) -> bool {
315        <P::Item as PartialEq>::eq(self, rhs)
316    }
317}
318
319impl<P: Pool> Eq for Box<P> where P::Item: Eq {}
320
321impl<P: Pool> PartialOrd for Box<P>
322where
323    P::Item: PartialOrd,
324{
325    fn partial_cmp(&self, rhs: &Box<P>) -> Option<cmp::Ordering> {
326        <P::Item as PartialOrd>::partial_cmp(self, rhs)
327    }
328}
329
330impl<P: Pool> Ord for Box<P>
331where
332    P::Item: Ord,
333{
334    fn cmp(&self, rhs: &Box<P>) -> cmp::Ordering {
335        <P::Item as Ord>::cmp(self, rhs)
336    }
337}
338
339impl<P: Pool> Hash for Box<P>
340where
341    P::Item: Hash,
342{
343    fn hash<H>(&self, state: &mut H)
344    where
345        H: Hasher,
346    {
347        <P::Item as Hash>::hash(self, state)
348    }
349}
350
/// Create an item pool of a given type and size, as well as a waker pool of a given length.
///
/// The waker pool is used to wake up tasks waiting for an item to become available in the data pool.
/// Its length should be at least the number of tasks that can be waiting for an item at the same time:
/// when the data pool is full and every waker slot is taken, [`Box::new_async`] returns `None`
/// instead of waiting.
///
/// Attributes written before the pool name (including doc comments) are applied to the
/// generated pool type.
///
/// Example:
/// ```
/// use async_pool::{pool, Box};
///
/// #[derive(Debug)]
/// #[allow(dead_code)]
/// struct Packet(u32);
///
/// pool!(PacketPool: [Packet; 4], 2); // Item pool of 4 Packet instances, waker pool of 2 wakers
///
/// pool!(
///     /// Pool of scratch buffers shared by the I/O tasks.
///     ScratchPool: [[u8; 64]; 2], 1
/// );
/// ```
#[macro_export]
macro_rules! pool {
    ($(#[$attr:meta])* $vis:vis $name:ident: [$ty:ty; $n:expr], $wn:expr) => {
        $(#[$attr])*
        $vis struct $name { _uninhabited: ::core::convert::Infallible }
        impl $crate::Pool for $name {
            type Item = $ty;
            // Occupancy bitsets use one `AtomicU32` word per 32 slots, rounded up.
            type Storage = $crate::PoolStorageImpl<$ty, {$n}, {($n+31)/32}, {$wn}, {($wn+31)/32}>;
            fn get() -> &'static Self::Storage {
                static POOL: $crate::PoolStorageImpl<$ty, {$n}, {($n+31)/32}, {$wn}, {($wn+31)/32}> = $crate::PoolStorageImpl::new();
                &POOL
            }
        }
    };
}
379
#[cfg(test)]
mod test {
    use embassy_futures::yield_now;
    use tokio::{join, select};

    use super::*;
    use core::mem;

    // Pools are statics shared by the whole test binary, and the test harness
    // runs tests on parallel threads. Every test that fills a pool therefore
    // gets its own pool; otherwise one test's "pool is full" assertions (and
    // `unwrap`s) race with another test holding slots of the same pool.
    pool!(TestPool: [u32; 4], 0);
    pool!(TestPool2: [u32; 4], 1);
    pool!(AsyncTestPool: [u32; 4], 0);

    /// Synchronous allocation fills the pool, fails when full, and reuses freed slots.
    #[test]
    fn test_pool() {
        let b1 = Box::<TestPool>::new(111).unwrap();
        let b2 = Box::<TestPool>::new(222).unwrap();
        let b3 = Box::<TestPool>::new(333).unwrap();
        let b4 = Box::<TestPool>::new(444).unwrap();
        assert!(Box::<TestPool>::new(555).is_none());
        assert_eq!(*b1, 111);
        assert_eq!(*b2, 222);
        assert_eq!(*b3, 333);
        assert_eq!(*b4, 444);
        mem::drop(b3);
        let b5 = Box::<TestPool>::new(555).unwrap();
        assert!(Box::<TestPool>::new(666).is_none());
        assert_eq!(*b1, 111);
        assert_eq!(*b2, 222);
        assert_eq!(*b4, 444);
        assert_eq!(*b5, 555);
    }

    /// A pool with a waker slot is larger than the same pool without one.
    #[test]
    fn test_async_sizes() {
        let pool1 = <TestPool as Pool>::get();
        let pool2 = <TestPool2 as Pool>::get();
        assert!(mem::size_of_val(pool1) < mem::size_of_val(pool2));
    }

    /// Without waker slots, `new_async` returns `None` instead of waiting
    /// when the pool is full.
    #[tokio::test]
    async fn empty_async_pool() {
        let b1 = Box::<AsyncTestPool>::new(111).unwrap();
        let b2 = Box::<AsyncTestPool>::new_async(222).await.unwrap();
        let b3 = Box::<AsyncTestPool>::new(333).unwrap();
        let b4 = Box::<AsyncTestPool>::new_async(444).await.unwrap();
        assert!(Box::<AsyncTestPool>::new_async(555).await.is_none());
        assert_eq!(*b3, 333);
        mem::drop(b3);
        let b5 = Box::<AsyncTestPool>::new_async(555).await.unwrap();
        assert!(Box::<AsyncTestPool>::new_async(666).await.is_none());
        assert_eq!(*b1, 111);
        assert_eq!(*b2, 222);
        assert_eq!(*b4, 444);
        assert_eq!(*b5, 555);
    }

    /// Dropping a waiting future releases its waker slot for the next waiter,
    /// and freeing an item wakes that waiter.
    #[tokio::test]
    async fn cancelled_future() {
        let b1 = Box::<TestPool2>::new_async(111).await.unwrap();
        let b2 = Box::<TestPool2>::new(222).unwrap();
        let b3 = Box::<TestPool2>::new_async(333).await.unwrap();
        let b4 = Box::<TestPool2>::new(444).unwrap();
        assert_eq!(*b1, 111);
        assert_eq!(*b2, 222);
        assert_eq!(*b3, 333);
        assert_eq!(*b4, 444);

        let fut1 = async {
            yield_now().await;
            yield_now().await;
        };

        let fut2 = async { Box::<TestPool2>::new_async(555).await.unwrap() };

        select! {
            _ = fut1 => {},
            v = fut2 => panic!("Future should have been cancelled: {:?}", v),
        }

        let fut3 = async {
            yield_now().await;
            yield_now().await;
            mem::drop(b1);
        };

        let fut4 = Box::<TestPool2>::new_async(666);

        let (b6, _) = join!(fut4, fut3);
        assert_eq!(*b6.unwrap(), 666);
    }
}