use std::collections::HashMap;
use std::hash::Hash as StdHash;
use std::sync::Arc;
use tokio::sync::{Mutex, Notify, RwLock};
/// Shared registry of tracked tasks, keyed by `ID`.
///
/// The map is stored behind an `Arc<RwLock<..>>`, so cloning a tracker
/// produces another handle to the same underlying map rather than a copy.
#[derive(Clone, Debug)]
pub struct TaskTracker<T, ID>(Arc<RwLock<HashMap<ID, Task<T, ID>>>>);
impl<T, ID> TaskTracker<T, ID>
where
    T: Clone,
    ID: Copy + Eq + StdHash,
{
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self(Arc::new(RwLock::new(HashMap::new())))
    }

    /// Returns the number of tasks currently registered in the map.
    // Within this file the method is only called from the test module, hence
    // the lint suppression for non-test builds.
    #[allow(unused)]
    pub async fn len(&self) -> usize {
        self.0.read().await.len()
    }

    /// Returns the task registered under `id`, creating and registering a new
    /// pending task when none exists yet.
    ///
    /// Every caller asking for the same `id` while it is registered receives a
    /// clone of the same task and therefore shares its result slot and signal.
    pub async fn track(&self, id: ID) -> Task<T, ID> {
        let mut inner = self.0.write().await;
        // Entry API: a single hash lookup covers both the hit and the miss.
        inner.entry(id).or_insert_with(|| Task::new(id)).clone()
    }

    /// Removes the task registered under `id` and completes it with `result`.
    ///
    /// Does nothing when no task is registered under `id`, so calling this
    /// again for an already completed id is a no-op.
    pub async fn mark_as_done(&self, id: ID, result: T) {
        // The write guard is a temporary of this statement, so the map lock is
        // released before awaiting the task's own mutex below.
        let removed = self.0.write().await.remove(&id);
        if let Some(task) = removed {
            task.mark_as_done(result).await;
        }
    }
}
impl<T, ID> Default for TaskTracker<T, ID>
where
T: Clone,
ID: Copy + Eq + StdHash,
{
fn default() -> Self {
Self::new()
}
}
/// Handle to a single tracked task.
///
/// The result slot and the signal are both behind `Arc`s, so clones of a task
/// share the same state.
#[derive(Clone, Debug)]
pub struct Task<T, ID> {
    /// Identifier the task was created with; the only field compared by `PartialEq`.
    id: ID,
    /// Result slot: `None` until `mark_as_done` stores a value.
    ready_result: Arc<Mutex<Option<T>>>,
    /// Signal fired with `notify_waiters` once the result has been stored.
    ready_signal: Arc<Notify>,
}
/// Two tasks are equal when their identifiers are equal; the result slot and
/// the signal do not take part in the comparison.
impl<T, ID: PartialEq> PartialEq for Task<T, ID> {
    fn eq(&self, rhs: &Self) -> bool {
        self.id.eq(&rhs.id)
    }
}
impl<T, ID> Task<T, ID>
where
    T: Clone,
{
    /// Creates a pending task with an empty result slot and a fresh signal.
    fn new(id: ID) -> Self {
        Self {
            id,
            ready_result: Arc::new(Mutex::new(None)),
            ready_signal: Arc::new(Notify::new()),
        }
    }

    /// Stores `result` in the slot and then wakes every waiter registered on
    /// the signal.
    async fn mark_as_done(&self, result: T) {
        {
            let mut ready_result = self.ready_result.lock().await;
            *ready_result = Some(result);
        }
        // The mutex is released before waking, so woken waiters can lock it
        // immediately.
        self.ready_signal.notify_waiters();
    }

    /// Waits until the task has a result and returns a clone of it.
    ///
    /// Returns immediately when the result is already stored.
    ///
    /// # Panics
    ///
    /// Panics if the signal fires while the slot is still empty; `mark_as_done`
    /// stores the result before signalling, so this indicates a broken invariant.
    pub async fn ready(&self) -> T {
        // Register interest BEFORE inspecting the slot. `notify_waiters` stores
        // no permit, so a `Notified` future created only after the check could
        // miss a completion that lands between the check and the wait, and the
        // caller would hang forever. A `Notified` future receives wake-ups from
        // `notify_waiters` from the moment it is created.
        let notified = self.ready_signal.notified();
        {
            let slot = self.ready_result.lock().await;
            if let Some(result) = slot.as_ref() {
                return result.clone();
            }
        }
        notified.await;
        self.ready_result
            .lock()
            .await
            .clone()
            .expect("result exists after ready signal was fired")
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::TaskTracker;

    /// Tracking the same id twice yields the same task and a single map entry.
    #[tokio::test]
    async fn deduplicate_tasks_by_id() {
        let tasks = TaskTracker::<String, usize>::new();

        let task_a = tasks.track(1).await;
        let task_b = tasks.track(1).await;
        let task_c = tasks.track(2).await;

        assert_eq!(tasks.len().await, 2);
        assert_eq!(task_a, task_b);
        assert_ne!(task_a, task_c);
    }

    /// Every waiter on the same id receives the result once it is marked done.
    #[tokio::test]
    async fn notify_all() {
        const TASK_ID: usize = 1;

        let tasks = TaskTracker::<String, usize>::new();
        let futures: Vec<_> = (0..10)
            .map(|_| {
                let tasks = tasks.clone();
                tokio::spawn(async move {
                    let task_1 = tasks.track(TASK_ID).await;
                    assert_eq!(tasks.len().await, 1);
                    task_1.ready().await
                })
            })
            .collect();

        // Give the spawned waiters time to start before completing the task.
        tokio::time::sleep(Duration::from_millis(50)).await;
        tasks.mark_as_done(TASK_ID, "yay, we did it!".into()).await;

        let result = futures_util::future::join_all(futures).await;
        for outcome in &result {
            assert_eq!(outcome.as_ref().unwrap(), &"yay, we did it!".to_string());
        }
    }

    /// Marking an already-removed id as done again is a no-op, and the stored
    /// result stays available to existing handles.
    #[tokio::test]
    async fn concurrent_removal() {
        let tasks = TaskTracker::<String, usize>::new();
        let task_a = tasks.track(1).await;

        for _ in 0..2 {
            tasks.mark_as_done(1, "yay, we did it".to_string()).await;
        }

        let result = task_a.ready().await;
        assert_eq!(result, "yay, we did it".to_string());
    }
}