task_tracker/
tracker.rs

1use crate::inner_task::{InnerTask, TaskResult};
2use hashbrown::{HashMap, HashSet};
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use tracing::debug;
6
7#[derive(Clone)]
8pub struct TaskTracker<K, R>
9where
10    K: Clone + Eq + std::hash::Hash + Sync + Send + std::fmt::Debug + 'static,
11    R: Sync + Send + std::fmt::Debug + 'static,
12{
13    inner: Arc<RwLock<HashMap<K, InnerTask<K, R>>>>,
14}
15
16impl<K, R> TaskTracker<K, R>
17where
18    K: Clone + Eq + std::hash::Hash + Sync + Send + std::fmt::Debug + 'static,
19    R: Sync + Send + std::fmt::Debug + 'static,
20{
21    pub fn new() -> Self {
22        Self {
23            inner: Default::default(),
24        }
25    }
26
27    pub async fn add<Fut>(&self, key: K, fut: Fut) -> Option<TaskResult<R>>
28    where
29        Fut: std::future::Future<Output = R> + Send + 'static,
30    {
31        let prev_result = self.remove(&key).await;
32        self.inner
33            .write()
34            .await
35            .insert(key.clone(), InnerTask::new(key, fut));
36        prev_result
37    }
38
39    pub async fn remove(&self, key: &K) -> Option<TaskResult<R>> {
40        let prev_task = self.inner.write().await.remove(key);
41        if let Some(prev_task) = prev_task {
42            debug!(?key, %prev_task.id, "removed previous task");
43            Some(prev_task.cancel_and_wait().await)
44        } else {
45            None
46        }
47    }
48
49    pub async fn apply<I, IT, G, F, Fut>(
50        &self,
51        iter: IT,
52        get_key: G,
53        create_task: F,
54    ) -> HashMap<K, TaskResult<R>>
55    where
56        IT: IntoIterator<Item = I>,
57        G: Fn(&I) -> &K,
58        F: Fn(I) -> Fut,
59        Fut: std::future::Future<Output = R> + Send + 'static,
60    {
61        let mut finished = self.remove_finished().await;
62        let mut old_keys: HashSet<K> = self.inner.read().await.keys().cloned().collect();
63
64        for item in iter {
65            let key = get_key(&item);
66            if !old_keys.remove(key) {
67                self.add(key.to_owned(), create_task(item)).await;
68            }
69        }
70
71        for key in old_keys {
72            let result = self.remove(&key).await;
73            if let Some(result) = result {
74                finished.insert(key, result);
75            }
76        }
77
78        finished
79    }
80
81    pub async fn remove_finished(&self) -> HashMap<K, TaskResult<R>> {
82        let mut results = HashMap::new();
83        let mut inner = self.inner.write().await;
84        for (key, inner_task) in inner.drain_filter(|_, inner_task| inner_task.is_finished()) {
85            results.insert(key, inner_task.wait().await);
86        }
87        results
88    }
89
90    pub async fn wait_for_tasks(&self) -> HashMap<K, TaskResult<R>> {
91        let mut results = HashMap::new();
92        debug!("waiting for all tasks to finish");
93        let mut inner = self.inner.write().await;
94        for (key, inner_task) in inner.drain() {
95            let result = inner_task.wait().await;
96            results.insert(key, result);
97        }
98        results
99    }
100}