use {
futures::{
future::{
join_all,
select,
select_all,
},
Future,
StreamExt,
},
loga::{
ea,
DebugDisplay,
ResultContext,
},
std::{
collections::HashSet,
pin::pin,
sync::{
Arc,
Mutex,
},
time::Duration,
},
tokio::{
select,
signal::unix::{
signal,
SignalKind,
},
spawn,
task::JoinHandle,
time::sleep,
},
tokio_util::sync::CancellationToken,
waitgroup::WaitGroup,
};
/// Shared state behind a [`TaskManager`] handle.
struct TaskManagerInner {
    /// Prepended (with a `/` separator) to every task id registered on this manager; empty for
    /// a manager created by `TaskManager::new`, set from the parent's qualified id by `sub`.
    id_prefix: String,
    /// Qualified ids of tasks that are currently running. An id is inserted when a task (or a
    /// single run of a periodic/stream handler) starts and removed when it finishes; `join`
    /// logs this set while it waits. Wrapped in `Arc` so spawned tasks can update it.
    alive_task_ids: Arc<Mutex<HashSet<String>>>,
    /// Cancelled by `terminate`; for sub-managers this is a child of the parent's token.
    alive: CancellationToken,
    /// Join handles of critical tasks/streams, consumed by `join`; `None` once `join` has
    /// collected them.
    critical: Mutex<Option<Vec<JoinHandle<Result<(), loga::Error>>>>>,
    /// Each running task holds a worker from this group so `join` can wait for it; `None` once
    /// `join` has taken it, after which new tasks are not started.
    wg: Mutex<Option<WaitGroup>>,
}
/// Handle to a group of named async tasks that share a termination signal.
///
/// Cloning is cheap: every clone refers to the same underlying state (an `Arc`), so tasks
/// registered through one clone are visible to `join` called on another.
#[derive(Clone)]
pub struct TaskManager(Arc<TaskManagerInner>);
impl TaskManager {
    /// Creates a root task manager with an empty id prefix and a fresh termination token.
    ///
    /// Also spawns a non-critical task with id `Signals` that calls
    /// [`TaskManager::terminate`] on the first SIGINT or SIGTERM, and simply exits if the
    /// manager is terminated some other way first.
    ///
    /// # Panics
    ///
    /// Panics if either unix signal handler cannot be created. It also spawns a task with
    /// `tokio::spawn`, so it must be called from within a Tokio runtime.
    pub fn new() -> TaskManager {
        let tm = TaskManager(Arc::new(TaskManagerInner {
            id_prefix: "".to_string(),
            alive_task_ids: Arc::new(Mutex::new(HashSet::new())),
            alive: CancellationToken::new(),
            critical: Mutex::new(Some(Vec::new())),
            wg: Mutex::new(Some(WaitGroup::new())),
        }));
        tm.task("Signals", {
            // The signal streams are created here, eagerly, rather than inside the future.
            let mut sig1 = signal(SignalKind::interrupt()).unwrap();
            let mut sig2 = signal(SignalKind::terminate()).unwrap();
            let tm = tm.clone();
            async move {
                select!{
                    _ = select(Box::pin(sig1.recv()), Box::pin(sig2.recv())) => {
                        eprintln!("Got signal, terminating.");
                        tm.terminate();
                    },
                    // Someone else terminated the manager; stop listening so `join` can finish.
                    _ = tm.until_terminate() =>()
                }
            }
        });
        tm
    }

    /// Qualifies `id` as `prefix/id`, or returns it unchanged when this manager's prefix is
    /// empty (as it is for a manager created by [`TaskManager::new`]).
    fn prefix_id(&self, id: impl Into<String>) -> String {
        let id = id.into();
        if self.0.id_prefix.is_empty() {
            id
        } else {
            format!("{}/{}", self.0.id_prefix, id)
        }
    }

    /// Creates a sub-manager whose task ids are prefixed with this manager's qualified `id`.
    ///
    /// Its termination token is a child of this manager's, so terminating this manager also
    /// terminates the sub-manager, but terminating the sub-manager does not affect this one.
    /// The sub-manager has its own task set, critical-task list and wait group; this manager's
    /// [`TaskManager::join`] does not wait for the sub-manager's tasks.
    pub fn sub(&self, id: impl Into<String>) -> TaskManager {
        TaskManager(Arc::new(TaskManagerInner {
            id_prefix: self.prefix_id(id),
            alive_task_ids: Arc::new(Mutex::new(HashSet::new())),
            alive: self.0.alive.child_token(),
            critical: Mutex::new(Some(Vec::new())),
            wg: Mutex::new(Some(WaitGroup::new())),
        }))
    }

    /// Completes once this manager has been terminated: by [`TaskManager::terminate`], by
    /// [`TaskManager::join`] (which terminates), or by termination of a parent manager.
    pub async fn until_terminate(&self) {
        self.0.alive.cancelled().await;
    }

    /// Spawns `future` as a task with id `id` (qualified with this manager's prefix).
    ///
    /// When `critical` is true the task's join handle is recorded so that
    /// [`TaskManager::join`] treats its completion as a shutdown trigger and reports its error,
    /// wrapped with the task id. Errors from non-critical tasks are discarded.
    ///
    /// If the manager is already shutting down (`join` has taken the wait group), `future` is
    /// dropped without being run.
    ///
    /// # Panics
    ///
    /// Panics if a task with the same qualified id is currently running.
    pub fn task_(
        &self,
        critical: bool,
        id: impl Into<String>,
        future: impl Future<Output = Result<(), loga::Error>> + Send + 'static,
    ) {
        let id = self.prefix_id(id);
        // Acquire the worker before claiming the id so that bailing out during shutdown does
        // not leave a stale id in the alive set (same order as `periodic` and `stream_`).
        let w = match self.0.wg.lock().unwrap().as_ref() {
            Some(w) => w.worker(),
            None => {
                return;
            },
        };
        let task_ids = self.0.alive_task_ids.clone();
        if !task_ids.lock().unwrap().insert(id.clone()) {
            panic!("Task with id {} already running!", id);
        }
        let j = spawn(async move {
            // Held for the task's whole lifetime so `join` waits for it.
            let _w = w;
            let res = future.await;
            task_ids.lock().unwrap().remove(&id);
            if critical {
                res.context(&format!("Critical task failed: {}", id))
            } else {
                res
            }
        });
        if critical {
            // `join` may have already collected the critical list. The task still holds a
            // wait-group worker, so `join` waits for it; only its result is not reported.
            if let Some(critical_tasks) = self.0.critical.lock().unwrap().as_mut() {
                critical_tasks.push(j);
            }
        }
    }

    /// Spawns a critical task: when it finishes (successfully or not) [`TaskManager::join`]
    /// begins shutdown, and its error, if any, is included in `join`'s result.
    ///
    /// # Panics
    ///
    /// Panics if a task with the same qualified id is currently running.
    pub fn critical_task(
        &self,
        id: impl Into<String>,
        future: impl Future<Output = Result<(), loga::Error>> + Send + 'static,
    ) {
        self.task_(true, id, future);
    }

    /// Spawns a non-critical task. `join` waits for it to finish but does not observe its
    /// outcome.
    ///
    /// # Panics
    ///
    /// Panics if a task with the same qualified id is currently running.
    pub fn task(&self, id: impl Into<String>, future: impl Future<Output = ()> + Send + 'static) {
        self.task_(false, id, async move {
            future.await;
            Ok(())
        })
    }

    /// Runs `f` immediately and then repeatedly, sleeping `period` between the end of one run
    /// and the start of the next.
    ///
    /// Termination is only checked between runs; an in-progress run is not interrupted, and
    /// `join` waits for it (each run holds a wait-group worker, the sleep does not). The loop
    /// also stops, without running `f`, once `join` has taken the wait group.
    ///
    /// # Panics
    ///
    /// The spawned loop (not the caller) panics if another task with the same qualified id is
    /// running when a run starts.
    pub fn periodic<
        F: FnMut() -> T + Send + 'static,
        T: Future<Output = ()> + Send + 'static,
    >(&self, id: impl Into<String>, period: Duration, mut f: F) {
        let id = self.prefix_id(id);
        let task_ids = self.0.alive_task_ids.clone();
        let tm = self.clone();
        spawn(async move {
            loop {
                let _w = match tm.0.wg.lock().unwrap().as_ref() {
                    Some(w) => w.worker(),
                    None => break,
                };
                if !task_ids.lock().unwrap().insert(id.clone()) {
                    panic!("Task with id {} already running!", id);
                }
                f().await;
                // Release the worker before sleeping so `join` does not wait out the period.
                drop(_w);
                task_ids.lock().unwrap().remove(&id);
                select!{
                    _ = tm.until_terminate() => {
                        break;
                    }
                    _ = sleep(period) => {
                    }
                };
            }
        });
    }

    /// Spawns a loop that feeds each item of `stream` to `handler` and awaits the returned
    /// future before pulling the next item.
    ///
    /// The loop ends with `Ok(())` when the stream is exhausted, when the manager is
    /// terminated while waiting for an item, or once `join` has taken the wait group. It ends
    /// with the handler's error on the first failing item; for critical streams that error is
    /// wrapped with the stream id and the join handle is recorded for [`TaskManager::join`].
    ///
    /// # Panics
    ///
    /// The spawned loop (not the caller) panics if another task with the same qualified id is
    /// running when an item's handler starts.
    fn stream_<
        T,
        S: StreamExt<Item = T> + Send + 'static + Unpin,
        Hn: FnMut(T) -> F + Send + 'static,
        F: Future<Output = Result<(), loga::Error>> + Send + 'static,
    >(&self, critical: bool, id: impl Into<String>, mut stream: S, mut handler: Hn) {
        let id = self.prefix_id(id);
        let task_ids = self.0.alive_task_ids.clone();
        let tm = self.clone();
        let join_handle = spawn(async move {
            loop {
                let f = {
                    let e = select!{
                        _ = tm.until_terminate() => break Ok(()),
                        e = stream.next() => e,
                    };
                    match e {
                        Some(x) => handler(x),
                        None => break Ok(()),
                    }
                };
                // Only the handler run holds a worker; waiting for the next item does not.
                let _w = match tm.0.wg.lock().unwrap().as_ref() {
                    Some(w) => w.worker(),
                    None => break Ok(()),
                };
                if !task_ids.lock().unwrap().insert(id.clone()) {
                    panic!("Task with id {} already running!", id);
                }
                let res = f.await;
                drop(_w);
                task_ids.lock().unwrap().remove(&id);
                if let Err(e) = res {
                    if critical {
                        break Err(e).context(format!("Critical task failed: {}", id));
                    } else {
                        break Err(e);
                    }
                }
            }
        });
        if critical {
            // After `join` has collected the critical list the loop exits on its own (the
            // manager is terminated and the wait group is gone), so the handle can be dropped.
            if let Some(critical_tasks) = self.0.critical.lock().unwrap().as_mut() {
                critical_tasks.push(join_handle);
            }
        }
    }

    /// Processes `stream` with a fallible `handler` as a critical task: the stream ending or a
    /// handler error makes [`TaskManager::join`] begin shutdown, and the error is reported.
    pub fn critical_stream<
        T,
        S: StreamExt<Item = T> + Send + 'static + Unpin,
        Hn: FnMut(T) -> F + Send + 'static,
        F: Future<Output = Result<(), loga::Error>> + Send + 'static,
    >(&self, id: impl Into<String>, stream: S, handler: Hn) {
        self.stream_(true, id, stream, handler);
    }

    /// Processes `stream` with an infallible `handler` as a non-critical task.
    pub fn stream<
        T,
        S: StreamExt<Item = T> + Send + 'static + Unpin,
        Hn: FnMut(T) -> F + Send + 'static,
        F: Future<Output = ()> + Send + 'static,
    >(&self, id: impl Into<String>, stream: S, mut handler: Hn) {
        self.stream_(false, id, stream, move |e| {
            let t = handler(e);
            async move {
                t.await;
                Ok(())
            }
        });
    }

    /// Signals termination to this manager and to every manager derived from it with
    /// [`TaskManager::sub`]. Tasks are expected to watch [`TaskManager::until_terminate`];
    /// nothing is forcibly aborted.
    pub fn terminate(&self) {
        self.0.alive.cancel();
    }

    /// Shuts the manager down and waits for its tasks.
    ///
    /// First waits for any one critical task to finish. If no critical task is registered,
    /// this step is skipped and shutdown starts immediately. It then stops new tasks from
    /// starting, terminates the manager and waits for all remaining tasks. While waiting, it
    /// logs the ids still alive every 10 seconds. Critical tasks registered while `join` waits
    /// for the first one are also awaited and reported.
    ///
    /// # Errors
    ///
    /// Returns an aggregate error containing every critical task's error, including tasks
    /// that panicked or were cancelled (reported via their `JoinError`).
    ///
    /// # Panics
    ///
    /// Panics if `join` has already been called on this manager or on one of its clones.
    pub async fn join(self, log: &loga::Log) -> Result<(), loga::Error> {
        let alive_ids = self.0.alive_task_ids.clone();
        // Take the critical tasks registered so far but leave an empty list in place:
        // critical tasks registered while we wait below can still be recorded instead of
        // panicking, and are collected once the wait group is gone.
        let mut critical_tasks = self
            .0
            .critical
            .lock()
            .unwrap()
            .as_mut()
            .map(std::mem::take)
            .expect("TaskManager::join called more than once");
        let mut results = vec![];
        if !critical_tasks.is_empty() {
            let first_critical_task_res;
            (first_critical_task_res, _, critical_tasks) = select_all(critical_tasks).await;
            results.push(first_critical_task_res);
        }
        // Taking the wait group stops new tasks from starting...
        let wg = self
            .0
            .wg
            .lock()
            .unwrap()
            .take()
            .expect("TaskManager::join called more than once");
        // ...so the critical list is final now; pick up late registrations.
        critical_tasks.extend(self.0.critical.lock().unwrap().take().unwrap_or_default());
        self.terminate();
        let mut work = pin!(async move {
            let results = join_all(critical_tasks).await;
            wg.wait().await;
            results
        });
        let remaining_results = loop {
            select!{
                results = &mut work => break results,
                _ = sleep(Duration::from_secs(10)) => {
                    log.log_with(
                        loga::INFO,
                        "Waiting for all tasks to finish",
                        ea!(alive = (*alive_ids.lock().unwrap()).dbg_str()),
                    );
                }
            }
        };
        results.extend(remaining_results);
        let errs = results.into_iter().filter_map(|r| match r {
            Ok(Ok(())) => None,
            Ok(Err(e)) => Some(e),
            Err(e) => Some(e.into()),
        }).collect::<Vec<loga::Error>>();
        if !errs.is_empty() {
            return Err(loga::agg_err("The task manager exited after critical tasks failed", errs));
        }
        Ok(())
    }
}