use std::future::Future;
use parking_lot::Mutex;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::OnceCell;
use tokio::task;
use tokio::task::JoinHandle;
use crate::errors::{Error, RuntimeError, UnknownError};
/// Outcome of a task, as published on the task's result channel by [`run`].
///
/// Values are produced from a task's output through the `From<()>` and
/// `From<Result<(), Error>>` conversions implemented for this type.
pub enum TaskResult {
/// The task completed without reporting an error.
Ok,
/// The task completed and reported the contained error.
Err(Error),
}
/// Join handle of a task spawned by [`run`].
///
/// The handle can be used to await or abort the spawned task. It resolves to
/// `Err` when the task was unable to publish its result on its channel.
pub type TaskHandler = JoinHandle<Result<(), Error>>;
/// Sending half of the global task channel.
///
/// Populated once by [`init_task_channel`]. [`run`] uses it to hand over, for every
/// spawned task, the receiver on which that task will publish its [`TaskResult`].
pub static RUNTIME_TX: OnceCell<Mutex<Option<UnboundedSender<UnboundedReceiver<TaskResult>>>>> =
OnceCell::const_new();
/// Receiving half of the global task channel.
///
/// Populated once by [`init_task_channel`]. It yields the per-task result receivers
/// sent through [`RUNTIME_TX`].
// NOTE(review): the consumer of this receiver is not in this file — presumably the
// runtime entry point drains it to wait for every task; confirm against the caller.
pub static RUNTIME_RX: OnceCell<Mutex<Option<UnboundedReceiver<UnboundedReceiver<TaskResult>>>>> =
OnceCell::const_new();
impl From<Result<(), Error>> for TaskResult {
    /// Maps a fallible task output onto the matching [`TaskResult`] variant:
    /// `Ok(())` becomes `TaskResult::Ok`, `Err(e)` becomes `TaskResult::Err(e)`.
    fn from(result: Result<(), Error>) -> Self {
        result.map_or_else(TaskResult::Err, |()| TaskResult::Ok)
    }
}
impl From<()> for TaskResult {
fn from(_: ()) -> Self {
TaskResult::Ok
}
}
/// Initialises the global task channel.
///
/// Creates an unbounded channel carrying per-task result receivers, stores its
/// sending half in [`RUNTIME_TX`] and its receiving half in [`RUNTIME_RX`].
/// Both statics are `OnceCell`s initialised through `get_or_init`, so once
/// [`RUNTIME_RX`] is set, later calls leave the existing channel untouched.
pub async fn init_task_channel() {
    // Initializer for RUNTIME_RX: builds the channel and publishes the sender on the way.
    let build_channel = || async {
        let (sender, receiver) =
            tokio::sync::mpsc::unbounded_channel::<UnboundedReceiver<TaskResult>>();
        RUNTIME_TX
            .get_or_init(|| async { Mutex::new(Some(sender)) })
            .await;
        Mutex::new(Some(receiver))
    };

    RUNTIME_RX.get_or_init(build_channel).await;
}
/// Spawns `future` as a Tokio task and hands its result channel over to the runtime.
///
/// A dedicated unbounded channel is created for the task. Its receiving half is sent
/// through [`RUNTIME_TX`]; its sending half is moved into the spawned task, which uses
/// it to publish the future's output once converted into a [`TaskResult`].
///
/// The receiver is registered *before* the task is spawned: if registration fails the
/// future is dropped without being run, so an `Err` return never leaves an orphan task
/// executing in the background.
///
/// # Errors
/// * `RuntimeError` if [`RUNTIME_TX`] has not been initialised (see
///   [`init_task_channel`]) or currently holds no sender.
/// * `UnknownError` if the receiver cannot be sent through [`RUNTIME_TX`].
///
/// The returned handle itself resolves to `Err(UnknownError)` when the task could not
/// publish its result on its channel.
///
/// # Panics
/// `tokio::task::spawn` panics when called from outside a Tokio runtime.
pub fn run<F, T>(future: F) -> Result<TaskHandler, Error>
where
    F: Future<Output = T> + Send + 'static,
    T: Into<TaskResult> + Send + 'static,
{
    // Channel private to this task: the task writes its result, the runtime reads it.
    let (task_tx, task_rx) = tokio::sync::mpsc::unbounded_channel();

    // Register the receiver first, in its own scope so the lock is released before
    // spawning. Any failure here returns early and nothing gets spawned.
    {
        let cell = RUNTIME_TX.get().ok_or(RuntimeError)?;
        let lock = cell.lock();
        let runtime_tx = lock.as_ref().ok_or(RuntimeError)?;
        runtime_tx.send(task_rx).map_err(|err| UnknownError {
            info: err.to_string(),
        })?;
    }

    // Only now run the future; its converted output is published on the task channel.
    let handler = task::spawn(async move {
        let result = future.await.into();
        task_tx.send(result).map_err(|err| UnknownError {
            info: err.to_string(),
        })?;
        Ok(())
    });

    Ok(handler)
}
/// Asynchronously sleeps for `$ms` milliseconds using `tokio::time::sleep`.
///
/// The expansion ends with `.await`, so the macro can only be used inside an async
/// context. The argument is converted with `as u64`; pass a non-negative value.
#[macro_export]
macro_rules! pause {
($ms:expr) => {
tokio::time::sleep(tokio::time::Duration::from_millis($ms as u64)).await
};
}
/// Blocks the current thread for `$ms` milliseconds using `std::thread::sleep`.
///
/// Synchronous counterpart of `pause!`. The argument is converted with `as u64`;
/// pass a non-negative value.
#[macro_export]
macro_rules! pause_sync {
($ms:expr) => {
std::thread::sleep(std::time::Duration::from_millis($ms as u64))
};
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::sync::Arc;
    use std::time::Instant;

    use serial_test::serial;

    use crate::errors::{Error, UnknownError};
    use crate::utils::task;

    // Runtime entry point used by `test_task_parallel_execution`.
    // Spawns three top-level tasks sleeping 500ms each; the first one then chains two
    // nested tasks sleeping 100ms each (longest chain: 500 + 100 + 100 ms).
    #[hermes_five_macros::runtime]
    async fn my_runtime() -> Result<(), Error> {
        task::run(async move {
            pause!(500);
            task::run(async move {
                pause!(100);
                task::run(async move {
                    pause!(100);
                })?;
                Ok(())
            })?;
            Ok(())
        })?;

        task::run(async move {
            pause!(500);
        })?;

        task::run(async move {
            pause!(500);
        })?;

        Ok(())
    }

    // The three 500ms tasks must overlap: the total has to exceed a single task's
    // duration while staying below what a sequential execution would take.
    #[serial]
    #[test]
    fn test_task_parallel_execution() {
        // `Instant` is monotonic: unlike `SystemTime`, measuring with it cannot fail
        // (and panic on `unwrap`) when the wall clock is adjusted during the test.
        let start = Instant::now();
        my_runtime().unwrap();
        let duration = start.elapsed().as_millis();

        assert!(
            duration > 500,
            "Duration should be greater than 500ms (found: {})",
            duration,
        );
        assert!(
            duration < 1500,
            "Duration should be lower than 1500ms (found: {})",
            duration,
        );
    }

    // A task left alone updates the flag after its pause; an aborted one never does.
    #[hermes_five_macros::test]
    async fn test_task_abort_execution() {
        let flag = Arc::new(AtomicU8::new(0));

        // First task: runs to completion and increments the flag after 100ms.
        let flag_clone = flag.clone();
        task::run(async move {
            pause!(100);
            flag_clone.fetch_add(1, Ordering::SeqCst);
        })
        .expect("Should not panic");

        pause!(50);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            0,
            "Flag should not be updated by the task before 100ms",
        );

        pause!(100);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            1,
            "Flag should be updated by the task after 100ms",
        );

        // Second task: aborted halfway through its pause, so the flag must stay at 1.
        let flag_clone = flag.clone();
        let handler = task::run(async move {
            pause!(100);
            flag_clone.fetch_add(1, Ordering::SeqCst);
        })
        .expect("Should not panic");

        pause!(50);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            1,
            "Flag should not be updated by the task before 100ms",
        );

        handler.abort();

        pause!(100);
        assert_eq!(
            flag.load(Ordering::SeqCst),
            1,
            "Flag should not be updated by the aborted task",
        );
    }

    // `task::run` must return `Ok` whatever the outcome of the task itself.
    #[hermes_five_macros::test]
    async fn test_task_with_result() {
        let task = task::run(async move { Ok(()) });
        assert!(task.is_ok(), "An Ok(()) task do not panic the runtime");

        let task = task::run(async move {
            Err(UnknownError {
                info: "wow panic!".to_string(),
            })
        });
        assert!(task.is_ok(), "A panicking task do not panic the runtime");
    }
}