hermes_five/utils/
task.rs1use std::future::Future;
3
4use parking_lot::Mutex;
5use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
6use tokio::sync::OnceCell;
7use tokio::task;
8use tokio::task::JoinHandle;
9
10use crate::errors::{Error, RuntimeError, UnknownError};
11
12pub enum TaskResult {
16 Ok,
17 Err(Error),
18}
19
20pub type TaskHandler = JoinHandle<Result<(), Error>>;
22
23pub static RUNTIME_TX: OnceCell<Mutex<Option<UnboundedSender<UnboundedReceiver<TaskResult>>>>> =
25 OnceCell::const_new();
26pub static RUNTIME_RX: OnceCell<Mutex<Option<UnboundedReceiver<UnboundedReceiver<TaskResult>>>>> =
27 OnceCell::const_new();
28
29impl From<Result<(), Error>> for TaskResult {
30 fn from(result: Result<(), Error>) -> Self {
31 match result {
32 Ok(_) => TaskResult::Ok,
33 Err(e) => TaskResult::Err(e),
34 }
35 }
36}
37
38impl From<()> for TaskResult {
39 fn from(_: ()) -> Self {
40 TaskResult::Ok
41 }
42}
43
44pub async fn init_task_channel() {
45 RUNTIME_RX
47 .get_or_init(|| async {
48 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<UnboundedReceiver<TaskResult>>();
50
51 RUNTIME_TX
53 .get_or_init(|| async { Mutex::new(Some(tx)) })
54 .await;
55
56 Mutex::new(Some(rx))
58 })
59 .await;
60}
61
62pub fn run<F, T>(future: F) -> Result<TaskHandler, Error>
87where
88 F: Future<Output = T> + Send + 'static,
89 T: Into<TaskResult> + Send + 'static,
90{
91 let (task_tx, task_rx) = tokio::sync::mpsc::unbounded_channel();
93
94 let handler = task::spawn(async move {
97 let result = future.await.into();
99 task_tx.send(result).map_err(|err| UnknownError {
100 info: err.to_string(),
101 })?;
102 Ok(())
103 });
104
105 let cell = RUNTIME_TX.get().ok_or(RuntimeError)?;
109 let mut lock = cell.lock();
110 let runtime_tx = lock.as_mut().ok_or(RuntimeError)?;
111
112 runtime_tx.send(task_rx).map_err(|err| UnknownError {
113 info: err.to_string(),
114 })?;
115
116 Ok(handler)
117}
118
/// Asynchronously suspends the current task for `$ms` milliseconds.
///
/// Expands to an awaited `tokio::time::sleep`, so it must be used inside an `async`
/// context. `$ms` is converted with `as u64`.
#[macro_export]
macro_rules! pause {
    ($ms:expr) => {{
        let delay = tokio::time::Duration::from_millis($ms as u64);
        tokio::time::sleep(delay).await
    }};
}
125
/// Blocks the current thread for `$ms` milliseconds.
///
/// Expands to `std::thread::sleep`; `$ms` is converted with `as u64`.
#[macro_export]
macro_rules! pause_sync {
    ($ms:expr) => {{
        let delay = std::time::Duration::from_millis($ms as u64);
        std::thread::sleep(delay)
    }};
}
132
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::sync::Arc;
    use std::time::Instant;

    use serial_test::serial;

    use crate::errors::{Error, UnknownError};
    use crate::utils::task;

    // Entry point used by `test_task_parallel_execution`: spawns one chain of nested
    // tasks (500ms + 100ms + 100ms) alongside two independent 500ms tasks.
    #[hermes_five_macros::runtime]
    async fn my_runtime() -> Result<(), Error> {
        task::run(async move {
            pause!(500);
            task::run(async move {
                pause!(100);
                task::run(async move {
                    pause!(100);
                })?;
                Ok(())
            })?;
            Ok(())
        })?;

        task::run(async move {
            pause!(500);
        })?;

        task::run(async move {
            pause!(500);
        })?;

        Ok(())
    }

    #[serial]
    #[test]
    fn test_task_parallel_execution() {
        // `Instant` is monotonic, so the measurement cannot fail or go backwards if the
        // system clock is adjusted while the test runs.
        let start = Instant::now();
        my_runtime().unwrap();
        let duration = start.elapsed().as_millis();

        // Longer than 500ms: the runtime did not return before the tasks were done.
        assert!(
            duration > 500,
            "Duration should be greater than 500ms (found: {})",
            duration,
        );
        // Shorter than the 1500ms the three top-level tasks would take one after another.
        assert!(
            duration < 1500,
            "Duration should be lower than 1500ms (found: {})",
            duration,
        );
    }

    #[hermes_five_macros::test]
    async fn test_task_abort_execution() {
        let flag = Arc::new(AtomicU8::new(0));
        let flag_clone = flag.clone();

        // First task: left to run to completion.
        task::run(async move {
            pause!(100);
            flag_clone.fetch_add(1, Ordering::SeqCst);
        })
        .expect("Should not panic");

        pause!(50);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            0,
            "Flag should not be updated by the task before 100ms",
        );

        pause!(100);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            1,
            "Flag should be updated by the task after 100ms",
        );

        let flag_clone = flag.clone();

        // Second task: aborted halfway through its pause.
        let handler = task::run(async move {
            pause!(100);
            flag_clone.fetch_add(1, Ordering::SeqCst);
        })
        .expect("Should not panic");

        pause!(50);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            1,
            "Flag should not be updated by the task before 100ms",
        );

        handler.abort();

        pause!(100);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            1,
            "Flag should not be updated by an aborted task",
        );
    }

    #[hermes_five_macros::test]
    async fn test_task_with_result() {
        let task = task::run(async move { Ok(()) });

        assert!(task.is_ok(), "An Ok(()) task do not panic the runtime");

        // A task whose future resolves to an error is still spawned successfully.
        let task = task::run(async move {
            Err(UnknownError {
                info: "wow panic!".to_string(),
            })
        });

        assert!(task.is_ok(), "A panicking task do not panic the runtime");
    }
}