mountpoint_s3_fs/
async_util.rs

1use std::{fmt::Debug, future::Future};
2
3use async_channel::{Receiver, Sender};
4use futures::task::{Spawn, SpawnError, SpawnExt};
5
6use crate::sync::Arc;
7
8/// Type-erasure for a [Spawn] implementation.
9#[derive(Clone)]
10pub struct Runtime(Arc<dyn Spawn + Send + Sync>);
11
12impl Spawn for Runtime {
13    fn spawn_obj(&self, future: futures::task::FutureObj<'static, ()>) -> Result<(), SpawnError> {
14        self.0.spawn_obj(future)
15    }
16}
17
18impl Debug for Runtime {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_tuple("Runtime").field(&"dyn").finish()
21    }
22}
23
24impl Runtime {
25    pub fn new(runtime: impl Spawn + Sync + Send + 'static) -> Self {
26        Runtime(Arc::new(runtime))
27    }
28
29    /// Spawns a task that polls the given future to completion and return
30    /// a [RemoteResult] with its output.
31    pub fn spawn_with_result<T, E, F>(&self, future: F) -> Result<RemoteResult<T, E>, SpawnError>
32    where
33        T: Send + 'static,
34        E: Send + 'static,
35        F: Future<Output = Result<T, E>> + Send + 'static,
36    {
37        let (sender, receiver) = result_channel();
38        self.spawn(async move {
39            let result = future.await;
40            sender.send(result).await;
41        })?;
42        Ok(receiver)
43    }
44}
45
46/// Creates an async one shot channel with a [RemoteResult] on the receiving end.
47pub fn result_channel<T, E>() -> (ResultSender<T, E>, RemoteResult<T, E>) {
48    let (sender, receiver) = async_channel::bounded(1);
49    (ResultSender { sender }, RemoteResult { receiver, value: None })
50}
51
/// Holds the result of a spawned task.
///
/// The task's `Result<T, E>` arrives through the one-shot channel created by [result_channel].
/// Once an `Ok` value has been received it is cached, so [RemoteResult::get_mut] can be called
/// repeatedly without waiting on the channel again.
///
/// Dropping a `RemoteResult` blocks the current thread until the sending side has either
/// delivered its value or been dropped.
#[derive(Debug)]
pub struct RemoteResult<T, E> {
    // Receiving end of the one-shot channel; carries the task's full `Result`.
    receiver: Receiver<Result<T, E>>,
    // The received `Ok` value. It stays `None` until a value arrives. It also stays `None` if
    // an `Err` was received (that is returned to the caller instead) or if the sender was
    // dropped without sending.
    value: Option<T>,
}
58
59/// Sender side of a [RemoteResult].
60pub struct ResultSender<T, E> {
61    sender: Sender<Result<T, E>>,
62}
63
64impl<T, E> ResultSender<T, E> {
65    pub async fn send(self, value: Result<T, E>) -> bool {
66        self.sender.send(value).await.is_ok()
67    }
68}
69
70impl<T, E> RemoteResult<T, E> {
71    async fn receive(&mut self) -> Result<&mut Option<T>, E> {
72        if self.value.is_none()
73            && let Ok(value) = self.receiver.recv().await
74        {
75            self.value = Some(value?);
76        }
77        Ok(&mut self.value)
78    }
79
80    pub async fn get_mut(&mut self) -> Result<Option<&mut T>, E> {
81        Ok(self.receive().await?.as_mut())
82    }
83
84    pub async fn into_inner(mut self) -> Result<Option<T>, E> {
85        Ok(self.receive().await?.take())
86    }
87}
88
89impl<T, E> Drop for RemoteResult<T, E> {
90    fn drop(&mut self) {
91        // Blocks to wait for the result and then drop it.
92        // Ignore the error if the sender has already been dropped.
93        _ = self.receiver.recv_blocking();
94    }
95}
96
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use futures::executor::{ThreadPool, block_on};
    use test_case::test_case;

    use super::{Runtime, result_channel};

    /// A result sent through the channel comes back unchanged from `into_inner`.
    #[test_case(Ok(42))]
    #[test_case(Err("error"))]
    fn test_into_inner(result: Result<i32, &'static str>) {
        let expected = result;
        let (sender, receiver) = result_channel();
        block_on(sender.send(result));

        let result = block_on(receiver.into_inner()).transpose().unwrap();
        assert_eq!(result, expected);
    }

    /// A result sent through the channel is observable through `get_mut`.
    #[test_case(Ok(42))]
    #[test_case(Err("error"))]
    fn test_get_mut(result: Result<i32, &'static str>) {
        let expected = result;
        let (sender, mut receiver) = result_channel();
        block_on(sender.send(result));

        let result = block_on(receiver.get_mut()).transpose().unwrap();
        match expected {
            Ok(expected_value) => assert!(matches!(result, Ok(value) if *value == expected_value)),
            Err(expected_error) => assert!(matches!(result, Err(error) if *error == *expected_error)),
        }
    }

    /// If the sender is dropped without sending, the result resolves to `Ok(None)`.
    #[test]
    fn test_sender_dropped_without_sending() {
        let (sender, receiver) = result_channel::<i32, &'static str>();
        drop(sender);

        assert_eq!(block_on(receiver.into_inner()), Ok(None));
    }

    /// The value received by `get_mut` is cached, so mutations are visible to `into_inner`.
    #[test]
    fn test_get_mut_then_into_inner() {
        let (sender, mut receiver) = result_channel::<i32, &'static str>();
        assert!(block_on(sender.send(Ok(1))));

        let value = block_on(receiver.get_mut()).unwrap().unwrap();
        *value += 1;

        assert_eq!(block_on(receiver.into_inner()), Ok(Some(2)));
    }

    /// Verify that [RemoteResult](super::RemoteResult) always drops the result.
    #[test_case(true; "after await")]
    #[test_case(false; "without await")]
    fn test_drop(await_result: bool) {
        let runtime = Runtime::new(ThreadPool::new().unwrap());

        struct Dropping(Arc<AtomicBool>);

        impl Drop for Dropping {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }

        let was_dropped = Arc::new(AtomicBool::new(false));
        let clone = was_dropped.clone();

        let mut result = runtime
            .spawn_with_result(async move { Ok::<_, &'static str>(Dropping(clone)) })
            .unwrap();

        if await_result {
            block_on(async {
                // The task always succeeds, so a value must be available here.
                assert!(matches!(result.get_mut().await, Ok(Some(_))));
            });
        }

        drop(result);

        assert!(was_dropped.load(Ordering::SeqCst));
    }
}