use std::pin::Pin;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

use futures::Future;
use futures::task::{Context, Poll, Waker};
6
/// A [`Future`] that resolves once a callback-style API delivers its value.
///
/// The future wraps a *loader*: a closure that receives a one-shot completion
/// callback and should invoke it exactly once with the result. It may do so
/// synchronously (before the loader returns) or later from any thread. The
/// loader is not run until the future is first polled.
///
/// If the loader drops the completion callback without calling it, the future
/// never resolves.
pub struct CallbackFuture<T> {
    /// The loader still to be run; `None` once it has been started, or when the
    /// future was created with [`CallbackFuture::ready`].
    loader: Option<Box<dyn FnOnce(Box<dyn FnOnce(T) + Send + 'static>) + Send + 'static>>,
    /// State shared with the completion callback handed to the loader.
    shared: Arc<Mutex<Shared<T>>>,
}

/// State shared between a [`CallbackFuture`] and its completion callback.
struct Shared<T> {
    /// Value delivered by the callback and not yet returned from `poll`.
    value: Option<T>,
    /// Waker from the most recent `poll` that returned `Poll::Pending`.
    waker: Option<Waker>,
}

/// Locks the shared state, recovering the guard if the mutex is poisoned.
///
/// The state is only ever modified by single field writes/takes, so it stays
/// usable even if a panic occurred while the lock was held.
fn lock_shared<T>(shared: &Mutex<Shared<T>>) -> MutexGuard<'_, Shared<T>> {
    shared.lock().unwrap_or_else(PoisonError::into_inner)
}

impl<T> CallbackFuture<T> {
    /// Creates a future that runs `loader` on its first poll.
    ///
    /// `loader` receives a completion callback. Calling that callback with a
    /// value resolves the future with it. The callback is `Send + 'static`, so it
    /// may be moved to another thread and invoked there.
    pub fn new(
        loader: impl FnOnce(Box<dyn FnOnce(T) + Send + 'static>) + Send + 'static,
    ) -> CallbackFuture<T> {
        CallbackFuture {
            loader: Some(Box::new(loader)),
            shared: Arc::new(Mutex::new(Shared { value: None, waker: None })),
        }
    }

    /// Creates a future that is already resolved with `value`.
    pub fn ready(value: T) -> CallbackFuture<T> {
        CallbackFuture {
            loader: None,
            shared: Arc::new(Mutex::new(Shared { value: Some(value), waker: None })),
        }
    }
}

impl<T: Send + 'static> Future for CallbackFuture<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `CallbackFuture` only holds a `Box` and an `Arc`, so it is `Unpin`.
        let this = self.get_mut();

        // Start the operation on the first poll only. The lock is NOT held while
        // the loader runs, so a synchronous call to the callback cannot deadlock.
        if let Some(loader) = this.loader.take() {
            let shared = Arc::clone(&this.shared);
            loader(Box::new(move |value| {
                let waker = {
                    let mut state = lock_shared(&shared);
                    state.value = Some(value);
                    state.waker.take()
                };
                // Wake after releasing the lock so the woken task does not
                // contend with this thread for it.
                if let Some(waker) = waker {
                    waker.wake();
                }
            }));
        }

        let mut state = lock_shared(&this.shared);
        match state.value.take() {
            // Also covers a loader that completed synchronously just above.
            Some(value) => Poll::Ready(value),
            None => {
                // Always register the waker from *this* poll (the task may have
                // moved or its waker changed). The callback takes the same lock,
                // so a concurrently delivered value is never missed: either it is
                // already visible here, or the callback will see this waker.
                let needs_update = match &state.waker {
                    Some(existing) => !existing.will_wake(cx.waker()),
                    None => true,
                };
                if needs_update {
                    state.waker = Some(cx.waker().clone());
                }
                Poll::Pending
            }
        }
    }
}
86
#[cfg(test)]
mod tests {
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    use futures::task::{noop_waker, Context};
    use futures::{executor::block_on, join, FutureExt};

    use super::CallbackFuture;

    /// The callback is invoked from a spawned thread after the loader returns.
    #[test]
    fn test_complete_async() {
        let fu = CallbackFuture::new(|complete| {
            thread::spawn(move || complete(42));
        });

        assert_eq!(block_on(fu), 42);
    }

    /// The callback is invoked before the loader returns.
    #[test]
    fn test_complete_sync() {
        let fu = CallbackFuture::new(|complete| complete(42));

        assert_eq!(block_on(fu), 42);
    }

    /// A `ready` future resolves without any loader.
    #[test]
    fn test_ready() {
        let fu = CallbackFuture::ready(42);

        assert_eq!(block_on(fu), 42);
    }

    /// Sync, ready and async futures can be driven concurrently.
    #[test]
    fn test_join() {
        let all = async {
            let fu1 = CallbackFuture::new(|complete| complete("Hello"));

            let fu2 = CallbackFuture::ready(", ");

            let fu3 = CallbackFuture::new(|complete| {
                thread::spawn(move || complete("world!"));
            });

            let (r1, r2, r3) = join!(fu1, fu2, fu3);
            [r1, r2, r3].concat()
        };

        assert_eq!(block_on(all), "Hello, world!");
    }

    /// Futures can be awaited sequentially inside an async block.
    #[test]
    fn test_await() {
        let all = async {
            let r1 = CallbackFuture::new(|complete| {
                // Deliberately block inside the loader to simulate a slow,
                // synchronous callback API.
                thread::sleep(Duration::from_millis(100));
                complete("Hello");
            })
            .await;

            let r2 = CallbackFuture::ready(", ").await;

            let r3 = CallbackFuture::new(|complete| {
                thread::spawn(move || complete("world!"));
            })
            .await;

            [r1, r2, r3].concat()
        };

        assert_eq!(block_on(all), "Hello, world!");
    }

    /// A `CallbackFuture` can be the tail expression of an `async fn`.
    #[test]
    fn test_async_fn() {
        async fn do_async() -> String {
            CallbackFuture::new(|complete| {
                thread::spawn(move || complete("Hello, world!".to_string()));
            })
            .await
        }

        assert_eq!(block_on(do_async()), "Hello, world!");
    }

    /// Invoking the completion callback after the future was dropped must not
    /// panic: the shared state outlives the future via the callback's `Arc`.
    #[test]
    fn test_complete_after_future_dropped() {
        let (tx, rx) = mpsc::channel();
        let mut fu = CallbackFuture::<i32>::new(move |complete| {
            tx.send(complete).unwrap();
        });

        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(fu.poll_unpin(&mut cx).is_pending());
        drop(fu);

        let complete = rx.recv().unwrap();
        complete(42);
    }
}