callback_future/
lib.rs

use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::Waker;

use futures::task::{Context, Poll};
use futures::Future;
6
/// An adapter between callbacks and futures.
///
/// Allows wrapping an asynchronous, callback-based API into a [`Future`].
///
/// The loader given to [`CallbackFuture::new`] is invoked lazily, on the first call to
/// [`Future::poll`]. It receives a completion callback; calling that callback (synchronously
/// inside the loader, or later from any thread) stores the value and wakes the task that most
/// recently polled this future. If the completion callback is dropped without ever being called,
/// the future never resolves.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct CallbackFuture<T> {
    /// Deferred operation; `Some` until the first poll, which takes and invokes it.
    loader: Option<Box<dyn FnOnce(Box<dyn FnOnce(T) + Send + 'static>) + Send + 'static>>,
    /// State shared with the completion callback handed to `loader`.
    state: Arc<Mutex<State<T>>>,
}

/// State shared between a [`CallbackFuture`] and its completion callback.
struct State<T> {
    /// Value delivered by the completion callback and not yet returned from `poll`.
    result: Option<T>,
    /// Waker passed to the most recent `poll` that returned `Poll::Pending`.
    waker: Option<Waker>,
}

/// Locks the shared state, recovering the guard if the mutex was poisoned.
///
/// Every critical section only stores or takes whole `Option`s, so the state is still
/// consistent even if a panic (e.g. inside a custom waker's `clone`/`drop`) poisoned the lock.
fn lock_state<T>(state: &Mutex<State<T>>) -> std::sync::MutexGuard<'_, State<T>> {
    state.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
}

impl<T> CallbackFuture<T> {
    /// Creates a new `CallbackFuture` that runs `loader` on its first poll.
    ///
    /// `loader` receives a completion callback. Call it once with the result, either directly
    /// or from another thread; the future then resolves to that value.
    ///
    /// # Examples
    /// ```
    /// use callback_future::CallbackFuture;
    /// use futures::executor::block_on;
    /// use std::thread;
    ///
    /// let future = CallbackFuture::new(|complete| {
    ///     // make call with callback here, call `complete` upon callback reception, e.g.:
    ///     thread::spawn(move || {
    ///         complete("Test");
    ///     });
    /// });
    /// assert_eq!(block_on(future), "Test");
    /// ```
    pub fn new(
        loader: impl FnOnce(Box<dyn FnOnce(T) + Send + 'static>) + Send + 'static,
    ) -> CallbackFuture<T> {
        CallbackFuture {
            loader: Some(Box::new(loader)),
            state: Arc::new(Mutex::new(State {
                result: None,
                waker: None,
            })),
        }
    }

    /// Creates a `CallbackFuture` that is already resolved to `value`.
    ///
    /// # Examples
    /// ```
    /// use callback_future::CallbackFuture;
    /// use futures::executor::block_on;
    ///
    /// assert_eq!(block_on(CallbackFuture::ready("Test")), "Test");
    /// ```
    pub fn ready(value: T) -> CallbackFuture<T> {
        CallbackFuture {
            loader: None,
            state: Arc::new(Mutex::new(State {
                result: Some(value),
                waker: None,
            })),
        }
    }
}

impl<T: Send + 'static> Future for CallbackFuture<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        {
            let mut state = lock_state(&this.state);
            if let Some(value) = state.result.take() {
                return Poll::Ready(value);
            }
            // Only the waker from the *latest* poll may be relied on: the future can be moved to
            // another task or re-polled through a combinator with a different waker. Registering
            // it under the same lock the callback takes means a completion racing with this poll
            // is either observed above or finds this waker.
            if !matches!(&state.waker, Some(waker) if waker.will_wake(cx.waker())) {
                state.waker = Some(cx.waker().clone());
            }
        } // Release the lock before the loader runs: it may call the callback synchronously.

        if let Some(loader) = this.loader.take() {
            let shared = Arc::clone(&this.state);
            loader(Box::new(move |value| {
                let waker = {
                    let mut state = lock_state(&shared);
                    state.result = Some(value);
                    state.waker.take()
                };
                // Wake outside the lock so the woken task does not contend with us.
                if let Some(waker) = waker {
                    waker.wake();
                }
            }));

            // The loader may already have delivered the value synchronously.
            if let Some(value) = lock_state(&this.state).result.take() {
                return Poll::Ready(value);
            }
        }

        Poll::Pending
    }
}
86
#[cfg(test)]
mod tests {
    use std::thread;
    use std::time::Duration;

    use futures::{executor::block_on, join};

    use crate::CallbackFuture;

    /// The value is delivered later, from a freshly spawned thread.
    #[test]
    fn test_complete_async() {
        let future = CallbackFuture::new(|complete| {
            thread::spawn(move || complete(42));
        });

        assert_eq!(block_on(future), 42);
    }

    /// The loader delivers the value before it even returns.
    #[test]
    fn test_complete_sync() {
        let future = CallbackFuture::new(|complete| complete(42));

        assert_eq!(block_on(future), 42);
    }

    /// A pre-resolved future yields its value immediately.
    #[test]
    fn test_ready() {
        assert_eq!(block_on(CallbackFuture::ready(42)), 42);
    }

    /// Sync, ready and threaded futures can be driven concurrently with `join!`.
    #[test]
    fn test_join() {
        let greeting = async {
            let hello = CallbackFuture::new(|complete| complete("Hello"));
            let comma = CallbackFuture::ready(", ");
            let world = CallbackFuture::new(|complete| {
                thread::spawn(move || complete("world!"));
            });

            let (first, second, third) = join!(hello, comma, world);
            [first, second, third].concat()
        };

        assert_eq!(block_on(greeting), "Hello, world!");
    }

    /// Futures awaited one after another, including a loader that blocks before completing.
    #[test]
    fn test_await() {
        let greeting = async {
            let first = CallbackFuture::new(|complete| {
                thread::sleep(Duration::from_millis(100));
                complete("Hello");
            })
            .await;
            let second = CallbackFuture::ready(", ").await;
            let third = CallbackFuture::new(|complete| {
                thread::spawn(move || complete("world!"));
            })
            .await;

            [first, second, third].concat()
        };

        assert_eq!(block_on(greeting), "Hello, world!");
    }

    /// A `CallbackFuture` can be the tail expression of an `async fn`.
    #[test]
    fn test_async_fn() {
        async fn do_async() -> String {
            CallbackFuture::new(|complete| {
                thread::spawn(move || complete("Hello, world!".to_string()));
            })
            .await
        }

        assert_eq!(block_on(do_async()), "Hello, world!");
    }
}