use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task::yield_now;
use tokio::time::{sleep, timeout};
use crate::proto::DynFut;
use crate::*;
/// Shared text buffer that the test executors append their output to; tests
/// read it back to check which executors ran and in what order.
type Output = Arc<Mutex<String>>;
/// Executor registered as `"hi"`: appends `"Hi, <payload>!\n"` to `out`.
struct TaskHi {
out: Output,
}
/// Executor registered as `"prog"`: reports progress 20 then 80 and finishes
/// with `TaskStatus::Error("expected error")`.
struct TaskProgress;
/// Executor registered as `"err"`: its future always resolves to `Err("oh")`.
struct TaskErr;
/// Periodic executor registered as `"per"`: appends `"periodic\n"` to `out`
/// on every run and reports `status` as the run's result.
struct Periodic {
out: Output,
// Status returned from every run (tests use both success and error here).
status: TaskStatus,
}
/// Periodic executor registered as `"per-spawn"`: each run creates a new
/// `"hi"` task with payload `"Mark"` through the task controller.
struct PeriodicSpawner;
impl ITaskExecutor for TaskHi {
type Error = ();
fn name(&self) -> &'static str {
"hi"
}
fn execute(&self, _: Box<dyn ITaskController>, task: Task) -> DynFut<Result<TaskStatus, ()>> {
let mtx = self.out.clone();
Box::pin(async move {
let msg = format!("Hi, {}!\n", task.payload);
mtx.lock().await.push_str(&msg);
Ok(TaskStatus::Completed)
})
}
}
impl ITaskExecutor for TaskProgress {
type Error = ();
fn name(&self) -> &'static str {
"prog"
}
fn execute(
&self,
ctrl: Box<dyn ITaskController>,
task: Task,
) -> DynFut<Result<TaskStatus, ()>> {
let id = task.id;
Box::pin(async move {
ctrl.set_progress(id, 20).await.unwrap();
sleep(Duration::from_secs(1)).await;
ctrl.set_progress(id, 80).await.unwrap();
sleep(Duration::from_secs(1)).await;
Ok(TaskStatus::Error("expected error".to_string()))
})
}
}
impl ITaskExecutor for Periodic {
type Error = ();
fn name(&self) -> &'static str {
"per"
}
fn periodic_interval(&self) -> Option<u64> {
Some(1)
}
fn execute(&self, _: Box<dyn ITaskController>, _: Task) -> DynFut<Result<TaskStatus, ()>> {
let mtx = self.out.clone();
let status = self.status.clone();
Box::pin(async move {
let mut out = mtx.lock().await;
out.push_str("periodic\n");
Ok(status)
})
}
}
// "err" executor: every execution fails with the executor-level error "oh"
// (as opposed to returning a `TaskStatus::Error`).
impl ITaskExecutor for TaskErr {
    type Error = String;

    fn name(&self) -> &'static str {
        "err"
    }

    fn execute(
        &self,
        _ctrl: Box<dyn ITaskController>,
        _task: Task,
    ) -> DynFut<Result<TaskStatus, String>> {
        Box::pin(async { Err(String::from("oh")) })
    }
}
impl ITaskExecutor for PeriodicSpawner {
type Error = ();
fn name(&self) -> &'static str {
"per-spawn"
}
fn periodic_interval(&self) -> Option<u64> {
Some(1)
}
fn execute(&self, ctrl: Box<dyn ITaskController>, _: Task) -> DynFut<Result<TaskStatus, ()>> {
Box::pin(async move {
ctrl.create_task(NewTask {
executor: "hi",
payload: "Mark".to_string(),
})
.await
.unwrap();
Ok(TaskStatus::Completed)
})
}
}
/// End-to-end scenario with a progress-reporting task, a one-shot greeting and
/// a periodic executor running on the same manager.
///
/// The `"prog"` task is polled through `InProgress(0)`, `InProgress(20)` and
/// `InProgress(80)` until it ends in `Error("expected error")`. The shared
/// output must then hold the greeting plus two periodic runs, in either of the
/// two orders the scheduler may produce.
#[tokio::test]
async fn test_all() {
    let out = Arc::new(Mutex::new(String::new()));
    let man = TaskManagerBuilder::new_mem()
        .add_executor(TaskHi { out: out.clone() })
        .add_executor(TaskProgress)
        .add_executor(Periodic {
            out: out.clone(),
            status: TaskStatus::Completed,
        })
        .build()
        .unwrap();
    // Run the manager in the background; the timeout stops the runner after 10s.
    let fut = timeout(Duration::from_secs(10), man.run());
    tokio::task::spawn(fut);
    let prog_id = man
        .create_task(NewTask {
            executor: "prog",
            payload: "".to_string(),
        })
        .await
        .unwrap();
    man.create_task(NewTask {
        executor: "hi",
        payload: "Alex".to_string(),
    })
    .await
    .unwrap();
    // Phase 1: wait for the first progress report (20).
    loop {
        let task = man.get(prog_id).await.unwrap().unwrap();
        assert_eq!(task.executor, "prog");
        match task.status {
            TaskStatus::InProgress(0) => (),
            TaskStatus::InProgress(20) => break,
            other => panic!("unexpected status 1: {:?}", other),
        }
        yield_now().await
    }
    // Phase 2: wait for the second progress report (80).
    loop {
        let task = man.get(prog_id).await.unwrap().unwrap();
        assert_eq!(task.executor, "prog");
        match task.status {
            TaskStatus::InProgress(20) => (),
            TaskStatus::InProgress(80) => break,
            other => panic!("unexpected status 2: {:?}", other),
        }
        yield_now().await
    }
    // Phase 3: wait for the executor's final (error) status.
    loop {
        let task = man.get(prog_id).await.unwrap().unwrap();
        assert_eq!(task.executor, "prog");
        match task.status {
            TaskStatus::InProgress(80) => (),
            TaskStatus::Error(msg) if msg == "expected error" => break,
            other => panic!("unexpected status 3: {:?}", other),
        }
        yield_now().await
    }
    let raw_out = out.lock().await.clone();
    // The greeting may land before or after the first periodic run, so both
    // orderings are accepted. Print the actual text to make failures diagnosable.
    assert!(
        raw_out == "periodic\nHi, Alex!\nperiodic\n"
            || raw_out == "Hi, Alex!\nperiodic\nperiodic\n",
        "unexpected executor output: {:?}",
        raw_out
    )
}
/// With a worker pool of size 1, every one of several queued `"hi"` tasks must
/// still be executed; only the set of greetings is checked, not their order.
#[tokio::test]
async fn test_too_many_tasks_for_cell() {
    let out = Arc::new(Mutex::new(String::new()));
    let man = TaskManagerBuilder::new_mem()
        .add_executor(TaskHi { out: out.clone() })
        .add_executor(TaskProgress)
        .set_pool_size(1)
        .set_sleep(Duration::from_millis(100))
        .build()
        .unwrap();
    // The runner future is built first but spawned only after the backlog is
    // queued, so the single-slot pool starts with more tasks than it can hold.
    let runner = timeout(Duration::from_secs(10), man.run());
    for name in ["a", "b", "c"] {
        man.create_task(NewTask {
            executor: "hi",
            payload: name.to_string(),
        })
        .await
        .unwrap();
    }
    tokio::task::spawn(runner);
    sleep(Duration::from_secs(3)).await;
    let text = out.lock().await.clone();
    let seen: HashSet<&str> = text.split('\n').filter(|line| !line.is_empty()).collect();
    let expected: HashSet<&str> = ["Hi, a!", "Hi, b!", "Hi, c!"].into_iter().collect();
    assert_eq!(seen, expected)
}
/// A periodic executor whose runs end in `TaskStatus::Error` is still run
/// again: the test expects three `"periodic"` lines after ~2.3s.
#[tokio::test]
async fn test_periodic_fail() {
    let shared = Arc::new(Mutex::new(String::new()));
    let failing = Periodic {
        out: Arc::clone(&shared),
        status: TaskStatus::Error("x".to_string()),
    };
    let man = TaskManagerBuilder::new_mem()
        .add_executor(TaskHi {
            out: Arc::clone(&shared),
        })
        .add_executor(TaskProgress)
        .add_executor(failing)
        .set_pool_size(1)
        .set_sleep(Duration::from_millis(100))
        .build()
        .unwrap();
    let runner = timeout(Duration::from_secs(10), man.run());
    tokio::task::spawn(runner);
    sleep(Duration::from_millis(2300)).await;
    let captured = shared.lock().await.clone();
    assert_eq!(captured, "periodic\n".repeat(3))
}
/// Submitting a task for an executor name that was never registered must be
/// rejected by `create_task`.
#[tokio::test]
async fn test_bad_executor() {
    let sink = Arc::new(Mutex::new(String::new()));
    let man = TaskManagerBuilder::new_mem()
        .add_executor(TaskHi { out: sink.clone() })
        .add_executor(TaskProgress)
        .set_pool_size(1)
        .set_sleep(Duration::from_millis(100))
        .build()
        .unwrap();
    tokio::task::spawn(timeout(Duration::from_secs(2), man.run()));
    let unknown = NewTask {
        executor: "xxx",
        payload: String::new(),
    };
    let outcome = man.create_task(unknown).await;
    assert!(outcome.is_err());
}
#[tokio::test]
async fn test_prune() {
let out = Arc::new(Mutex::new(String::new()));
let man = TaskManagerBuilder::new_mem()
.add_executor(TaskHi { out })
.add_executor(TaskErr)
.set_pool_size(5)
.set_succ_task_lifetime(1)
.set_err_task_lifetime(1)
.set_sleep(Duration::from_millis(100))
.build()
.unwrap();
tokio::task::spawn(timeout(Duration::from_secs(4), man.run()));
let hi_id = man
.create_task(NewTask {
executor: "hi",
payload: "a".to_string(),
})
.await
.unwrap();
let err_id = man
.create_task(NewTask {
executor: "err",
payload: "".to_string(),
})
.await
.unwrap();
loop {
let task = man.get(hi_id).await.unwrap().unwrap();
if matches!(task.status, TaskStatus::Completed) {
break;
}
yield_now().await
}
loop {
let task = man.get(err_id).await.unwrap().unwrap();
if matches!(task.status, TaskStatus::Error(_)) {
break;
}
yield_now().await
}
sleep(Duration::from_secs(3)).await;
assert!(man.get(hi_id).await.unwrap().is_none());
assert!(man.get(err_id).await.unwrap().is_none());
}
#[tokio::test]
async fn test_lock_mechanism() {
let out = Arc::new(Mutex::new(String::new()));
let man = TaskManagerBuilder::new_mem()
.add_executor(TaskHi { out })
.set_sleep(Duration::from_millis(100))
.build()
.unwrap();
let task1 = man
.create_task(NewTask {
executor: "hi",
payload: "a".to_string(),
})
.await
.unwrap();
let task2 = man
.create_task(NewTask {
executor: "hi",
payload: "b".to_string(),
})
.await
.unwrap();
let task3 = man
.create_task(NewTask {
executor: "hi",
payload: "a".to_string(),
})
.await
.unwrap();
tokio::task::spawn(timeout(Duration::from_secs(4), man.run()));
assert_eq!(task1, task3);
assert_ne!(task1, task2);
loop {
let status = man.get(task1).await.unwrap().unwrap().status;
if matches!(status, TaskStatus::Completed) {
break;
}
}
let task4 = man
.create_task(NewTask {
executor: "hi",
payload: "a".to_string(),
})
.await
.unwrap();
assert_ne!(task4, task1);
}
/// `PeriodicSpawner` (periodic interval 1, presumably seconds) creates a
/// `"hi"` task for `"Mark"` on every run. Depending on scheduling, two or
/// three greetings are expected after ~2.1s.
#[tokio::test]
async fn test_spawner() {
    let out = Arc::new(Mutex::new(String::new()));
    let man = TaskManagerBuilder::new_mem()
        .add_executor(TaskHi { out: out.clone() })
        .add_executor(PeriodicSpawner)
        .set_sleep(Duration::from_millis(100))
        .build()
        .unwrap();
    tokio::task::spawn(timeout(Duration::from_secs(4), man.run()));
    sleep(Duration::from_millis(2100)).await;
    let text = out.lock().await.clone();
    // Timing-dependent check: print the actual output so a failure is diagnosable.
    assert!(
        text == "Hi, Mark!\nHi, Mark!\nHi, Mark!\n" || text == "Hi, Mark!\nHi, Mark!\n",
        "expected two or three greetings, got {:?}",
        text
    );
}