use futures::future::BoxFuture;
use std::collections::BTreeMap;
use tokio_util::sync::CancellationToken;
/// Error type every registered task is normalised to: any thread-safe error, boxed.
type BoxedError = Box<dyn std::error::Error + Send + Sync>;

/// Type-erased task factory stored in [`TaskMap`].
///
/// Calling it with a [`CancellationToken`] yields a boxed, `Send` future that
/// resolves to the task's value or a [`BoxedError`].
type BoxedTaskFn<T> =
    Box<dyn FnOnce(CancellationToken) -> BoxFuture<'static, Result<T, BoxedError>> + Send>;
/// A named collection of async tasks, built with [`TaskMap::insert`].
///
/// Tasks are keyed by a `&'static str` name and kept in a [`BTreeMap`], so they
/// are stored (and iterate) in name order. Each entry is a boxed factory that,
/// given a [`CancellationToken`], produces the task's future; registering a task
/// does not run it.
pub struct TaskMap<T> {
    /// Registered task factories, keyed by task name.
    pub(crate) tasks: BTreeMap<&'static str, BoxedTaskFn<T>>,
}

// Manual impl: the boxed closures are not `Debug`, so `#[derive(Debug)]` is not
// possible. Only the task names are printed, which also avoids requiring `T: Debug`.
impl<T> std::fmt::Debug for TaskMap<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TaskMap")
            .field("tasks", &self.tasks.keys().collect::<Vec<_>>())
            .finish()
    }
}
impl<T> TaskMap<T> {
#[must_use]
pub fn new() -> Self {
Self {
tasks: BTreeMap::new(),
}
}
#[must_use]
pub fn insert<F, Fut, E>(mut self, name: &'static str, task: F) -> Self
where
F: FnOnce(CancellationToken) -> Fut + Send + 'static,
Fut: std::future::Future<Output = std::result::Result<T, E>> + Send + 'static,
E: Into<BoxedError>,
T: Send + 'static,
{
let boxed: BoxedTaskFn<T> = Box::new(move |token| {
let fut = task(token);
Box::pin(async move { fut.await.map_err(Into::into) })
});
self.tasks.insert(name, boxed);
self
}
#[must_use]
pub fn len(&self) -> usize {
self.tasks.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.tasks.is_empty()
}
}
impl<T> Default for TaskMap<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn new_is_empty() {
        let m: TaskMap<u32> = TaskMap::new();
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
    }

    #[test]
    fn insert_increments_len() {
        let m: TaskMap<u32> = TaskMap::new()
            .insert("a", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("b", |_| async { Ok::<_, std::io::Error>(2) });
        assert_eq!(m.len(), 2);
    }

    #[tokio::test]
    async fn insert_duplicate_overwrites() {
        let m: TaskMap<u32> = TaskMap::new()
            .insert("a", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("a", |_| async { Ok::<_, std::io::Error>(2) });
        assert_eq!(m.len(), 1);
        // A length check alone would also pass if the second insert were ignored;
        // run the surviving task to prove the later registration replaced the first.
        let (_name, task_fn) = m.tasks.into_iter().next().unwrap();
        let out = task_fn(tokio_util::sync::CancellationToken::new())
            .await
            .unwrap();
        assert_eq!(out, 2);
    }

    #[test]
    fn default_is_empty() {
        let m: TaskMap<u32> = TaskMap::default();
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
    }

    #[test]
    fn len_after_three_inserts() {
        let m: TaskMap<u32> = TaskMap::new()
            .insert("a", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("b", |_| async { Ok::<_, std::io::Error>(2) })
            .insert("c", |_| async { Ok::<_, std::io::Error>(3) });
        assert_eq!(m.len(), 3);
        assert!(!m.is_empty());
    }

    #[test]
    fn tasks_are_stored_in_name_order() {
        let m: TaskMap<u32> = TaskMap::new()
            .insert("b", |_| async { Ok::<_, std::io::Error>(2) })
            .insert("c", |_| async { Ok::<_, std::io::Error>(3) })
            .insert("a", |_| async { Ok::<_, std::io::Error>(1) });
        let names: Vec<&str> = m.tasks.keys().copied().collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn task_closure_executes_after_insert() {
        let m: TaskMap<u32> =
            TaskMap::new().insert("only", |_| async { Ok::<_, std::io::Error>(99) });
        assert_eq!(m.len(), 1);
        let (_name, task_fn) = m.tasks.into_iter().next().unwrap();
        let ct = tokio_util::sync::CancellationToken::new();
        let out = task_fn(ct).await.unwrap();
        assert_eq!(out, 99);
    }

    #[tokio::test]
    async fn task_error_is_converted_to_boxed_error() {
        // `&str` is not itself an `Error`; it must go through `Into<BoxedError>`.
        let m: TaskMap<u32> = TaskMap::new().insert("fails", |_| async { Err::<u32, _>("boom") });
        let (_name, task_fn) = m.tasks.into_iter().next().unwrap();
        let err = task_fn(tokio_util::sync::CancellationToken::new())
            .await
            .unwrap_err();
        assert_eq!(err.to_string(), "boom".to_string());
    }

    #[tokio::test]
    async fn task_receives_the_given_cancellation_token() {
        let m: TaskMap<bool> = TaskMap::new().insert("watch", |ct| async move {
            Ok::<_, std::io::Error>(ct.is_cancelled())
        });
        let (_name, task_fn) = m.tasks.into_iter().next().unwrap();
        let token = tokio_util::sync::CancellationToken::new();
        token.cancel();
        // A fresh token would report `false`; `true` shows the caller's token was forwarded.
        assert!(task_fn(token).await.unwrap());
    }
}