hermes_five/utils/
task.rs

//! Defines the Hermes-Five runtime task runner.
2use std::future::Future;
3
4use parking_lot::Mutex;
5use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
6use tokio::sync::OnceCell;
7use tokio::task;
8use tokio::task::JoinHandle;
9
10use crate::errors::{Error, RuntimeError, UnknownError};
11
12/// Represents the result of a TaskResult.
13/// A task may return either () or Result<(), Error> for flexibility which
14/// will be converted to TaskResult sent to the runtime.
15pub enum TaskResult {
16    Ok,
17    Err(Error),
18}
19
20/// Represents an arc protected handler for a task.
21pub type TaskHandler = JoinHandle<Result<(), Error>>;
22
23/// Globally accessible runtime transmitter(TX)/receiver(RX) (not initialised yet)
24pub static RUNTIME_TX: OnceCell<Mutex<Option<UnboundedSender<UnboundedReceiver<TaskResult>>>>> =
25    OnceCell::const_new();
26pub static RUNTIME_RX: OnceCell<Mutex<Option<UnboundedReceiver<UnboundedReceiver<TaskResult>>>>> =
27    OnceCell::const_new();
28
29impl From<Result<(), Error>> for TaskResult {
30    fn from(result: Result<(), Error>) -> Self {
31        match result {
32            Ok(_) => TaskResult::Ok,
33            Err(e) => TaskResult::Err(e),
34        }
35    }
36}
37
38impl From<()> for TaskResult {
39    fn from(_: ()) -> Self {
40        TaskResult::Ok
41    }
42}
43
44pub async fn init_task_channel() {
45    // If no receiver is configured, create a new one (with associated sender).
46    RUNTIME_RX
47        .get_or_init(|| async {
48            // Arbitrary limit to 100 simultaneous tasks.
49            let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<UnboundedReceiver<TaskResult>>();
50
51            // Set the runtime sender.
52            RUNTIME_TX
53                .get_or_init(|| async { Mutex::new(Some(tx)) })
54                .await;
55
56            // Set the runtime receiver.
57            Mutex::new(Some(rx))
58        })
59        .await;
60}
61
62/// Runs a given future as a Tokio task while ensuring the main function (marked by `#[hermes_five::runtime]`)
63/// will not finish before all tasks running as done.
64/// This is done by using a globally accessible channel to communicate the handlers to be waited by the
65/// runtime.
66///
67/// # Parameters
68/// * `future`: A future that implements `Future<Output = ()>`, `Send`, and has a `'static` lifetime.
69///
70/// # Errors
71/// Returns an error if the lock cannot be acquired or if the sender is not initialized or if sending the task handle fails.
72///
73/// # Example
74/// ```
75/// use hermes_five::utils::task;
76///
77/// #[hermes_five::runtime]
78/// async fn main() {
79///     let handler = task::run(async move {
80///         // whatever
81///     }).unwrap();
82///     // Abort the task early.
83///     handler.abort();
84/// }
85/// ```
86pub fn run<F, T>(future: F) -> Result<TaskHandler, Error>
87where
88    F: Future<Output = T> + Send + 'static,
89    T: Into<TaskResult> + Send + 'static,
90{
91    // Create a transmitter(tx)/receiver(rx) unique to this task.
92    let (task_tx, task_rx) = tokio::sync::mpsc::unbounded_channel();
93
94    // --
95    // Create a task to run our future: note how we capture the tx...
96    let handler = task::spawn(async move {
97        // ...to send the result of the future through that channel.
98        let result = future.await.into();
99        task_tx.send(result).map_err(|err| UnknownError {
100            info: err.to_string(),
101        })?;
102        Ok(())
103    });
104
105    // --
106    // Send the receiver(rx) side of the task-channel to the runtime.
107
108    let cell = RUNTIME_TX.get().ok_or(RuntimeError)?;
109    let mut lock = cell.lock();
110    let runtime_tx = lock.as_mut().ok_or(RuntimeError)?;
111
112    runtime_tx.send(task_rx).map_err(|err| UnknownError {
113        info: err.to_string(),
114    })?;
115
116    Ok(handler)
117}
118
/// Asynchronously pauses for the given number of milliseconds.
///
/// Expands to an awaited `tokio::time::sleep`, so it may only be used inside an async context.
/// The argument is converted with `as u64`.
#[macro_export]
macro_rules! pause {
    ($millis:expr) => {{
        let delay = tokio::time::Duration::from_millis($millis as u64);
        tokio::time::sleep(delay).await
    }};
}
125
/// Blocks the current thread for the given number of milliseconds.
///
/// Expands to `std::thread::sleep`; the argument is converted with `as u64`.
#[macro_export]
macro_rules! pause_sync {
    ($millis:expr) => {{
        let delay = std::time::Duration::from_millis($millis as u64);
        std::thread::sleep(delay)
    }};
}
132
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::sync::Arc;
    use std::time::Instant;

    use serial_test::serial;

    use crate::errors::{Error, UnknownError};
    use crate::utils::task;

    /// Runtime spawning three top-level tasks; the first one spawns nested tasks, so the longest
    /// chain pauses for 500 + 100 + 100 ms.
    #[hermes_five_macros::runtime]
    async fn my_runtime() -> Result<(), Error> {
        task::run(async move {
            pause!(500);
            task::run(async move {
                pause!(100);
                task::run(async move {
                    pause!(100);
                })?;
                Ok(())
            })?;
            Ok(())
        })?;

        task::run(async move {
            pause!(500);
        })?;

        task::run(async move {
            pause!(500);
        })?;

        Ok(())
    }

    #[serial]
    #[test]
    fn test_task_parallel_execution() {
        // Tasks should be parallel and function should be blocked until all done.
        // Therefore the `my_runtime()` function should take more time than the longest task, but less
        // than the sum of task times.
        // `Instant` is monotonic: unlike `SystemTime`, a wall-clock adjustment during the test can
        // neither skew the measurement nor make it panic.
        let start = Instant::now();
        my_runtime().unwrap();
        let duration = start.elapsed().as_millis();

        assert!(
            duration > 500,
            "Duration should be greater than 500ms (found: {})",
            duration,
        );
        assert!(
            duration < 1500,
            "Duration should be lower than 1500ms (found: {})",
            duration,
        );
    }

    #[hermes_five_macros::test]
    async fn test_task_abort_execution() {
        let flag = Arc::new(AtomicU8::new(0));
        let flag_clone = flag.clone();

        // Increment the flag after 100ms
        task::run(async move {
            pause!(100);
            flag_clone.fetch_add(1, Ordering::SeqCst);
        })
        .expect("Should not panic");

        // The flag should not have been incremented before the 100ms elapsed.
        pause!(50);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            0,
            "Flag should not be updated by the task before 100ms",
        );

        // The flag should have been incremented after the 100ms elapsed.
        pause!(100);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            1,
            "Flag should be updated by the task after 100ms",
        );

        // ########################################
        // Same test but aborting
        let flag_clone = flag.clone();

        // Increment the flag after 100ms
        let handler = task::run(async move {
            pause!(100);
            flag_clone.fetch_add(1, Ordering::SeqCst);
        })
        .expect("Should not panic");

        // The flag should not have been incremented before the 100ms elapsed.
        pause!(50);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            1,
            "Flag should not be updated by the task before 100ms",
        );

        // Abort the task
        handler.abort();

        // The aborted task must never increment the flag, even after the 100ms elapsed.
        pause!(100);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            1,
            "Flag should not be updated by an aborted task",
        );
    }

    #[hermes_five_macros::test]
    async fn test_task_with_result() {
        let task = task::run(async move { Ok(()) });

        assert!(task.is_ok(), "An Ok(()) task does not panic the runtime");

        let task = task::run(async move {
            Err(UnknownError {
                info: "wow panic!".to_string(),
            })
        });

        assert!(task.is_ok(), "A failing task does not panic the runtime");
    }
}