reusable_box_future/
box_future.rs

1use alloc::alloc::Layout;
2use alloc::boxed::Box;
3use core::fmt;
4use core::future::Future;
5use core::mem::ManuallyDrop;
6use core::pin::Pin;
7use core::ptr::{self, NonNull};
8use core::task::{Context, Poll};
9
/// A reusable `Pin<Box<dyn Future<Output = T> + Send>>`.
///
/// The stored future can be swapped for another one. When the replacement has
/// the same size and alignment as the current future, the existing heap
/// allocation is reused instead of being freed and allocated again.
pub struct ReusableBoxFuture<T> {
    /// Owning pointer to the heap allocation that holds the current future.
    /// It originates from a `Box` and is turned back into one on drop.
    boxed: NonNull<dyn Future<Output = T> + Send>,
}
17
18impl<T> ReusableBoxFuture<T> {
19    /// Create a new `ReusableBoxFuture<T>` containing the provided future.
20    pub fn new<F>(future: F) -> Self
21    where
22        F: Future<Output = T> + Send + 'static,
23    {
24        let boxed: Box<dyn Future<Output = T> + Send> = Box::new(future);
25        let boxed = Box::into_raw(boxed);
26
27        // SAFETY: Box::into_raw does not return null pointers.
28        let boxed = unsafe { NonNull::new_unchecked(boxed) };
29
30        Self { boxed }
31    }
32
33    /// Replace the future currently stored in this box.
34    ///
35    /// This reallocates if and only if the layout of the provided future is
36    /// different from the layout of the currently stored future.
37    pub fn set<F>(&mut self, future: F)
38    where
39        F: Future<Output = T> + Send + 'static,
40    {
41        if let Err(future) = self.try_set(future) {
42            *self = Self::new(future);
43        }
44    }
45
46    /// Replace the future currently stored in this box.
47    ///
48    /// This function never reallocates, but returns an error if the provided
49    /// future has a different size or alignment from the currently stored
50    /// future.
51    pub fn try_set<F>(&mut self, future: F) -> Result<(), F>
52    where
53        F: Future<Output = T> + Send + 'static,
54    {
55        // SAFETY: The pointer is not dangling.
56        let self_layout = {
57            let dyn_future: &(dyn Future<Output = T> + Send) = unsafe { self.boxed.as_ref() };
58            Layout::for_value(dyn_future)
59        };
60
61        if Layout::new::<F>() == self_layout {
62            // SAFETY: We just checked that the layout of F is correct.
63            unsafe {
64                self.set_same_layout(future);
65            }
66
67            Ok(())
68        } else {
69            Err(future)
70        }
71    }
72
73    /// Set the current future.
74    ///
75    /// # Safety
76    ///
77    /// This function requires that the layout of the provided future is the
78    /// same as `self.boxed` layout.
79    unsafe fn set_same_layout<F>(&mut self, future: F)
80    where
81        F: Future<Output = T> + Send + 'static,
82    {
83        struct SetLayout<'a, F, T>
84        where
85            F: Future<Output = T> + Send + 'static,
86        {
87            rbf: &'a mut ReusableBoxFuture<T>,
88            new_future: ManuallyDrop<F>,
89        }
90
91        impl<'a, F, T> Drop for SetLayout<'a, F, T>
92        where
93            F: Future<Output = T> + Send + 'static,
94        {
95            fn drop(&mut self) {
96                // By doing the replacement on `drop` we make sure the change
97                // will happen even if the existing future panics on drop.
98                //
99                // We could use `catch_unwind`, but it is not available in `no_std`.
100                unsafe {
101                    // Overwrite the future behind the pointer. This is safe because the
102                    // allocation was allocated with the same size and alignment as the type F.
103                    let fut_ptr: *mut F = self.rbf.boxed.as_ptr() as *mut F;
104                    ptr::write(fut_ptr, ManuallyDrop::take(&mut self.new_future));
105
106                    // Update the vtable of self.boxed. The pointer is not null because we
107                    // just got it from self.boxed, which is not null.
108                    self.rbf.boxed = NonNull::new_unchecked(fut_ptr);
109                }
110            }
111        }
112
113        let set_layout = SetLayout {
114            rbf: self,
115            new_future: ManuallyDrop::new(future),
116        };
117
118        // Drop the existing future.
119        ptr::drop_in_place(set_layout.rbf.boxed.as_ptr());
120        // Now `set_layout` will be dropped and do the replacement.
121    }
122
123    /// Get a pinned reference to the underlying future.
124    pub fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> {
125        // SAFETY: The user of this box cannot move the box, and we do not move it
126        // either.
127        unsafe { Pin::new_unchecked(self.boxed.as_mut()) }
128    }
129
130    /// Poll the future stored inside this box.
131    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> {
132        self.get_pin().poll(cx)
133    }
134}
135
136impl<T> Future for ReusableBoxFuture<T> {
137    type Output = T;
138
139    /// Poll the future stored inside this box.
140    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
141        Pin::into_inner(self).get_pin().poll(cx)
142    }
143}
144
145// The future stored inside ReusableBoxFuture<T> must be Send.
146unsafe impl<T> Send for ReusableBoxFuture<T> {}
147
148// The only method called on self.boxed is poll, which takes &mut self, so this
149// struct being Sync does not permit any invalid access to the Future, even if
150// the future is not Sync.
151unsafe impl<T> Sync for ReusableBoxFuture<T> {}
152
153// Just like a Pin<Box<dyn Future>> is always Unpin, so is this type.
154impl<T> Unpin for ReusableBoxFuture<T> {}
155
156impl<T> Drop for ReusableBoxFuture<T> {
157    fn drop(&mut self) {
158        unsafe {
159            drop(Box::from_raw(self.boxed.as_ptr()));
160        }
161    }
162}
163
164impl<T> fmt::Debug for ReusableBoxFuture<T> {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        f.debug_struct("ReusableBoxFuture").finish()
167    }
168}
169
#[cfg(test)]
mod tests {
    // Unit tests covering allocation reuse, the advertised auto traits, and the
    // drop guard used when the replaced future panics while being dropped.

    use super::*;
    use futures_executor::block_on;
    use static_assertions::assert_impl_all;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Future that returns `Pending` on its first poll (after waking itself) and
    /// `Ready(ready_val)` on its second; any further poll panics.
    ///
    /// `dropped` is set when the value is dropped. `_buf` is never filled; its
    /// type parameter only changes the future's size so tests can force a
    /// layout mismatch.
    struct TestFut<T: Unpin> {
        polled_nr: u32,
        ready_val: u32,
        dropped: Arc<AtomicBool>,
        _buf: Option<T>,
    }

    impl<T: Unpin> TestFut<T> {
        fn new(ready_val: u32) -> Self {
            TestFut {
                polled_nr: 0,
                ready_val,
                dropped: Arc::new(AtomicBool::new(false)),
                _buf: None,
            }
        }
    }

    impl<T: Unpin> Future for TestFut<T> {
        type Output = u32;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.polled_nr += 1;

            match self.polled_nr {
                1 => {
                    // First poll, simulate pending
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                2 => {
                    // Second poll, simulate ready
                    Poll::Ready(self.ready_val)
                }
                _ => panic!("Future completed"),
            }
        }
    }

    impl<T: Unpin> Drop for TestFut<T> {
        fn drop(&mut self) {
            self.dropped.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn alloc() {
        block_on(async {
            let test_fut = TestFut::<[u8; 32]>::new(1);
            let dropped = Arc::clone(&test_fut.dropped);

            let mut fut = ReusableBoxFuture::new(test_fut);
            assert!(!dropped.load(Ordering::SeqCst));

            assert_eq!((&mut fut).await, 1);
            assert!(!dropped.load(Ordering::SeqCst));

            // Same layout: the allocation must be reused.
            let ptr = fut.boxed.as_ptr().cast::<u8>();
            let test_fut = TestFut::<[u8; 32]>::new(2);
            let dropped_2 = Arc::clone(&test_fut.dropped);
            assert!(fut.try_set(test_fut).is_ok());
            assert!(dropped.load(Ordering::SeqCst));
            assert!(!dropped_2.load(Ordering::SeqCst));
            assert_eq!(ptr, fut.boxed.as_ptr().cast::<u8>());

            assert_eq!((&mut fut).await, 2);
            assert!(!dropped_2.load(Ordering::SeqCst));

            // Different layout: `try_set` refuses and hands the future back.
            let test_fut = TestFut::<[u8; 256]>::new(3);
            let dropped_3 = Arc::clone(&test_fut.dropped);
            assert!(fut.try_set(test_fut).is_err());
            assert!(!dropped_2.load(Ordering::SeqCst));
            assert!(dropped_3.load(Ordering::SeqCst));

            // Different layout: `set` reallocates.
            let test_fut = TestFut::<[u8; 256]>::new(4);
            let dropped_4 = Arc::clone(&test_fut.dropped);
            fut.set(test_fut);
            assert!(dropped_2.load(Ordering::SeqCst));
            assert!(!dropped_4.load(Ordering::SeqCst));
            assert_ne!(ptr, fut.boxed.as_ptr().cast::<u8>());

            assert_eq!((&mut fut).await, 4);
            assert!(!dropped_4.load(Ordering::SeqCst));
        })
    }

    #[test]
    fn static_assertion() {
        assert_impl_all!(ReusableBoxFuture<()>: Sync, Send, Unpin);
    }

    #[test]
    fn panicking_drop() {
        struct PanicDrop(Arc<AtomicUsize>);

        impl Future for PanicDrop {
            type Output = ();

            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                Poll::Ready(())
            }
        }

        impl Drop for PanicDrop {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::Relaxed);

                if !std::thread::panicking() {
                    // `panic_any` carries a non-string payload on every edition;
                    // `panic!(1u32)` is rejected by the 2021 edition.
                    std::panic::panic_any(1u32);
                }
            }
        }

        // We use this second type to verify that we replace the vtable, by
        // having a different drop implementation (adding 100 instead of 1).
        struct NonPanicDrop(Arc<AtomicUsize>);

        impl Future for NonPanicDrop {
            type Output = ();

            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                Poll::Ready(())
            }
        }

        impl Drop for NonPanicDrop {
            fn drop(&mut self) {
                self.0.fetch_add(100, Ordering::Relaxed);
            }
        }

        let drop1 = Arc::new(AtomicUsize::new(0));
        let drop2 = Arc::new(AtomicUsize::new(0));

        let result = std::panic::catch_unwind({
            let drop1 = Arc::clone(&drop1);
            let drop2 = Arc::clone(&drop2);

            move || {
                let mut fut = ReusableBoxFuture::new(PanicDrop(drop1));

                match fut.try_set(NonPanicDrop(drop2)) {
                    Ok(_) => std::panic::panic_any(2u32),
                    Err(_) => std::panic::panic_any(3u32),
                }
            }
        });

        // Make sure that panic was propagated correctly
        assert_eq!(*result.err().unwrap().downcast::<u32>().unwrap(), 1);

        // Make sure we drop only once per item
        assert_eq!(drop1.load(Ordering::Relaxed), 1);
        assert_eq!(drop2.load(Ordering::Relaxed), 100);
    }
}