mountpoint_s3_crt/io/
futures.rs1use std::fmt::Debug;
5use std::future::Future;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8
9use futures::channel::oneshot;
10use futures::future::BoxFuture;
11use futures::task::ArcWake;
12use futures::{FutureExt, TryFutureExt};
13use thiserror::Error;
14
15use crate::common::allocator::Allocator;
16use crate::common::task_scheduler::{Task, TaskScheduler, TaskStatus};
17
18#[derive(Debug)]
20pub struct FutureJoinHandle<T: Send + 'static> {
21 inner: Arc<Mutex<Option<FutureTaskInner<T>>>>,
22
23 receiver: oneshot::Receiver<Result<T, JoinError>>,
24}
25
26impl<T> FutureJoinHandle<T>
27where
28 T: Send + 'static,
29{
30 pub fn into_future(self) -> impl Future<Output = Result<T, JoinError>> {
32 self.receiver
33 .unwrap_or_else(|oneshot::Canceled| Err(JoinError::Canceled))
34 }
35
36 pub fn wait(self) -> Result<T, JoinError> {
38 futures::executor::block_on(self.into_future())
39 }
40
41 pub fn cancel(self) {
48 let mut locked = self.inner.lock().unwrap();
49
50 if let Some(inner) = locked.take() {
54 std::mem::drop(inner);
55 }
56 }
57}
58
59struct FutureTaskInner<T: Send + 'static> {
61 future: BoxFuture<'static, T>,
63
64 result_channel: oneshot::Sender<Result<T, JoinError>>,
66}
67
68impl<T: Debug + Send + 'static> Debug for FutureTaskInner<T> {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("FutureTaskInner")
72 .field("future", &(&self.future as *const BoxFuture<'static, T>))
73 .field("result_channel", &self.result_channel)
74 .finish()
75 }
76}
77
78struct FutureTaskWaker<S: TaskScheduler, T: Send + 'static> {
80 inner: Arc<Mutex<Option<FutureTaskInner<T>>>>,
85
86 scheduler: S,
88}
89
90impl<S: TaskScheduler, T: Send + 'static> FutureTaskWaker<S, T> {
91 fn finish_with_error(arc_self: &Arc<Self>, error: JoinError) {
94 let mut locked = arc_self.inner.lock().unwrap();
95
96 if let Some(inner) = locked.take() {
97 std::mem::drop(inner.future);
100 let _ = inner.result_channel.send(Err(error));
101 }
102 }
103
104 fn poll(arc_self: &Arc<Self>) {
108 let mut locked = arc_self.inner.lock().unwrap();
111
112 if let Some(mut inner) = locked.take() {
114 let waker = futures::task::waker_ref(arc_self);
116 let context = &mut Context::from_waker(&waker);
117
118 match Future::poll(inner.future.as_mut(), context) {
119 Poll::Ready(value) => {
120 std::mem::drop(inner.future);
124 let _ = inner.result_channel.send(Ok(value));
125 }
126 Poll::Pending => {
127 *locked = Some(inner);
130 }
131 }
132 }
133 }
134}
135
136impl<S: TaskScheduler, T: Send + 'static> ArcWake for FutureTaskWaker<S, T> {
137 fn wake_by_ref(arc_self: &Arc<Self>) {
140 let task_arc_self = arc_self.clone();
141
142 let task = Task::init(
145 &Allocator::default(),
146 move |status| match status {
147 TaskStatus::RunReady => FutureTaskWaker::poll(&task_arc_self),
148 TaskStatus::Canceled => FutureTaskWaker::finish_with_error(&task_arc_self, JoinError::Canceled),
149 },
150 "FutureTaskWaker_wake_by_ref",
151 );
152
153 match arc_self.scheduler.schedule_task_now(task) {
155 Ok(()) => {}
156 Err(err) => FutureTaskWaker::finish_with_error(arc_self, err.into()),
157 }
158 }
159}
160
161pub trait FutureSpawner: crate::private::Sealed {
163 fn spawn_future<T>(&self, future: impl Future<Output = T> + Send + 'static) -> FutureJoinHandle<T>
178 where
179 T: Send + 'static;
180}
181
182impl<S: TaskScheduler + Clone> FutureSpawner for S {
183 fn spawn_future<T>(&self, future: impl Future<Output = T> + Send + 'static) -> FutureJoinHandle<T>
184 where
185 T: Send + 'static,
186 {
187 let future = future.boxed();
188
189 let (tx, rx) = oneshot::channel();
190
191 let task_inner = Arc::new(Mutex::new(Some(FutureTaskInner {
192 future,
193 result_channel: tx,
194 })));
195
196 let waker = futures::task::waker(Arc::new(FutureTaskWaker {
197 inner: task_inner.clone(),
198 scheduler: self.clone(),
199 }));
200
201 waker.wake_by_ref();
204
205 FutureJoinHandle {
206 inner: task_inner,
207 receiver: rx,
208 }
209 }
210}
211
212#[derive(Error, Debug)]
214pub enum JoinError {
215 #[error("The task was cancelled")]
217 Canceled,
218
219 #[error("Internal CRT error: {0}")]
221 InternalError(#[from] crate::common::error::Error),
222}
223
#[cfg(test)]
mod test {
    use std::sync::atomic::Ordering;
    use std::sync::atomic::{AtomicBool, AtomicU64};
    use std::time::Duration;

    use futures::executor::block_on;
    use futures::future::join_all;

    use super::*;
    use crate::common::allocator::Allocator;
    use crate::io::event_loop::{EventLoopGroup, EventLoopTimer};

    /// A trivial future spawned on an event loop group runs to completion.
    #[test]
    fn test_simple_future() {
        let allocator = Allocator::default();
        let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap();

        el_group
            .spawn_future(async {
                println!("Hello from the future");
            })
            .wait()
            .unwrap();
    }

    /// Spawns many counting futures on `scheduler` and checks three things:
    /// - they all ran;
    /// - they all succeeded;
    /// - they all released their captures.
    fn test_join_all_futures(scheduler: &impl FutureSpawner) {
        const NUM_FUTURES: u64 = 50_000;

        let counter = Arc::new(AtomicU64::new(0));

        let handles: Vec<_> = (0..NUM_FUTURES)
            .map(|_| {
                let task_counter = Arc::clone(&counter);
                scheduler.spawn_future(async move {
                    task_counter.fetch_add(1, Ordering::SeqCst);
                })
            })
            .collect();

        let outcomes = block_on(join_all(handles.into_iter().map(FutureJoinHandle::into_future)));

        assert_eq!(
            Arc::strong_count(&counter),
            1,
            "all references to the counter except ours should be dropped"
        );

        outcomes
            .into_iter()
            .collect::<Result<(), JoinError>>()
            .expect("one or more futures failed");

        assert_eq!(counter.load(Ordering::SeqCst), NUM_FUTURES);
    }

    /// Many futures spawned on a single event loop all complete.
    #[test]
    fn test_join_all_futures_event_loop() {
        let allocator = Allocator::default();
        let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap();
        let single_loop = el_group.get_next_loop().unwrap();

        test_join_all_futures(&single_loop);
    }

    /// Many futures spawned on an event loop group all complete.
    #[test]
    fn test_join_all_futures_event_loop_group() {
        let allocator = Allocator::default();
        let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap();

        test_join_all_futures(&el_group);
    }

    /// Cancelling a pending future drops it without letting it run to completion.
    #[test]
    fn test_cancel_future() {
        let allocator = Allocator::default();
        let el_group = EventLoopGroup::new_default(&allocator, None, || {}).unwrap();

        // 20 seconds is far longer than this test waits, so the future should still be pending
        // when it is cancelled below.
        let timer = EventLoopTimer::new(&el_group.get_next_loop().unwrap(), Duration::from_secs(20));

        let flag = Arc::new(AtomicBool::new(false));

        let future_flag = Arc::clone(&flag);
        let future_handle = el_group.spawn_future(async move {
            timer.await.expect("failed while awaiting timer");
            future_flag.store(true, Ordering::SeqCst);
        });

        assert_eq!(
            Arc::strong_count(&flag),
            2,
            "there should be 2 references to flag: ours and the Future's"
        );

        // Give the event loop time to start polling the future before cancelling it.
        std::thread::sleep(Duration::from_secs(1));

        future_handle.cancel();

        assert_eq!(
            Arc::strong_count(&flag),
            1,
            "The Future should be dropped at this point"
        );
        assert!(
            !flag.load(Ordering::SeqCst),
            "flag should still be false after cancellation"
        );
    }
}