use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task::yield_now;
use tokio::time::{sleep, timeout};
use crate::proto::DynFut;
use crate::*;
/// Shared text buffer, guarded by an async (`tokio::sync`) mutex, that the test
/// executors append to so the tests can inspect what ran and in which order.
type Output = Arc<Mutex<String>>;
/// Test executor registered as `"hi"`: appends `Hi, <payload>!\n` to `out`.
struct TaskHi {
/// Buffer the greeting is appended to.
out: Output,
}
/// Test executor registered as `"prog"`: reports progress 20 then 80 and then
/// deliberately finishes with `TaskStatus::Error("expected error")`.
struct TaskProgress;
/// Periodic test executor registered as `"per"`: appends `periodic\n` to `out`
/// on every run.
struct Periodic {
/// Buffer each periodic run appends to.
out: Output,
}
impl ITaskExecutor for TaskHi {
fn name(&self) -> &'static str {
"hi"
}
fn execute(&self, _: Box<dyn ITaskController>, task: Task) -> DynFut<TaskStatus> {
let mtx = self.out.clone();
Box::pin(async move {
let msg = format!("Hi, {}!\n", task.payload);
let mut out = mtx.lock().await;
out.push_str(&msg);
TaskStatus::Completed
})
}
}
impl ITaskExecutor for TaskProgress {
    /// Executor key; tasks created with `executor: "prog"` target this executor.
    fn name(&self) -> &'static str {
        "prog"
    }

    /// Publishes progress 20 and then 80 through the controller, pausing one
    /// second after each update, and finally fails on purpose with
    /// `"expected error"` so the tests can observe the error state.
    ///
    /// # Panics
    /// Panics if the controller rejects a progress update.
    fn execute(&self, ctrl: Box<dyn ITaskController>, task: Task) -> DynFut<TaskStatus> {
        let task_id = task.id;
        let pause = Duration::from_secs(1);
        Box::pin(async move {
            for pct in [20, 80] {
                ctrl.set_progress(task_id, pct).await.unwrap();
                sleep(pause).await;
            }
            TaskStatus::Error("expected error".to_string())
        })
    }
}
impl ITaskExecutor for Periodic {
fn name(&self) -> &'static str {
"per"
}
fn periodic_interval(&self) -> Option<u64> {
Some(1)
}
fn execute(&self, _: Box<dyn ITaskController>, _: Task) -> DynFut<TaskStatus> {
let mtx = self.out.clone();
Box::pin(async move {
let mut out = mtx.lock().await;
out.push_str("periodic\n");
TaskStatus::Completed
})
}
}
/// End-to-end run of the SQLite-backed manager with all three test executors.
///
/// Checks that the `"prog"` task moves through progress 0 → 20 → 80 and then
/// ends in `TaskStatus::Error("expected error")`, and that the shared output
/// buffer holds the `"hi"` greeting together with exactly two `periodic` lines.
#[tokio::test]
async fn test_all() {
    let out = Arc::new(Mutex::new(String::new()));
    let man = TaskManagerBuilder::new_sqlite(":memory:")
        .add_executor(TaskHi { out: out.clone() })
        .add_executor(TaskProgress)
        .add_executor(Periodic { out: out.clone() })
        .build()
        .unwrap();
    let fut = timeout(Duration::from_secs(10), man.run());
    tokio::task::spawn(fut);
    let prog_id = man
        .create_task(NewTask {
            executor: "prog".to_string(),
            payload: String::new(),
        })
        .await
        .unwrap();
    man.create_task(NewTask {
        executor: "hi".to_string(),
        payload: "Alex".to_string(),
    })
    .await
    .unwrap();

    // Bound the whole polling phase: if the runner stalls (or its own timeout
    // expires) the status never changes, and without this the loops below
    // would spin forever instead of failing the test.
    timeout(Duration::from_secs(10), async {
        loop {
            let task = man.get(prog_id).await.unwrap().unwrap();
            assert_eq!(task.executor, "prog");
            match task.status {
                TaskStatus::InProgress(0) => (),
                TaskStatus::InProgress(20) => break,
                other => panic!("unexpected status 1: {:?}", other),
            }
            // Let the spawned runner make progress between polls.
            yield_now().await;
        }
        loop {
            let task = man.get(prog_id).await.unwrap().unwrap();
            assert_eq!(task.executor, "prog");
            match task.status {
                TaskStatus::InProgress(20) => (),
                TaskStatus::InProgress(80) => break,
                other => panic!("unexpected status 2: {:?}", other),
            }
            yield_now().await;
        }
        loop {
            let task = man.get(prog_id).await.unwrap().unwrap();
            assert_eq!(task.executor, "prog");
            match task.status {
                TaskStatus::InProgress(80) => (),
                TaskStatus::Error(msg) if msg == "expected error" => break,
                other => panic!("unexpected status 3: {:?}", other),
            }
            yield_now().await;
        }
    })
    .await
    .expect("prog task did not reach its final error state within 10s");

    // NOTE(review): expecting exactly two `periodic` lines ties this check to
    // the periodic interval vs. the ~2s runtime of the prog task; it may be
    // timing-sensitive.
    let raw_out = out.lock().await.clone();
    assert!(
        raw_out == "periodic\nHi, Alex!\nperiodic\n"
            || raw_out == "Hi, Alex!\nperiodic\nperiodic\n",
        "unexpected executor output: {:?}",
        raw_out
    );
}
#[tokio::test]
async fn test_lock_mechanism() {
let out = Arc::new(Mutex::new(String::new()));
let man = TaskManagerBuilder::new_mem()
.add_executor(TaskHi { out })
.set_sleep(Duration::from_millis(100))
.build()
.unwrap();
let task1 = man
.create_task(NewTask {
executor: "hi".to_string(),
payload: "a".to_string(),
})
.await
.unwrap();
let task2 = man
.create_task(NewTask {
executor: "hi".to_string(),
payload: "b".to_string(),
})
.await
.unwrap();
let task3 = man
.create_task(NewTask {
executor: "hi".to_string(),
payload: "a".to_string(),
})
.await
.unwrap();
tokio::task::spawn(timeout(Duration::from_secs(4), man.run()));
assert_eq!(task1, task3);
assert_ne!(task1, task2);
loop {
let status = man.get(task1).await.unwrap().unwrap().status;
if matches!(status, TaskStatus::Completed) {
break;
}
}
let task4 = man
.create_task(NewTask {
executor: "hi".to_string(),
payload: "a".to_string(),
})
.await
.unwrap();
assert_ne!(task4, task1);
}