1use crate::inner_task::{InnerTask, TaskResult};
2use hashbrown::{HashMap, HashSet};
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use tracing::debug;
6
7#[derive(Clone)]
8pub struct TaskTracker<K, R>
9where
10 K: Clone + Eq + std::hash::Hash + Sync + Send + std::fmt::Debug + 'static,
11 R: Sync + Send + std::fmt::Debug + 'static,
12{
13 inner: Arc<RwLock<HashMap<K, InnerTask<K, R>>>>,
14}
15
16impl<K, R> TaskTracker<K, R>
17where
18 K: Clone + Eq + std::hash::Hash + Sync + Send + std::fmt::Debug + 'static,
19 R: Sync + Send + std::fmt::Debug + 'static,
20{
21 pub fn new() -> Self {
22 Self {
23 inner: Default::default(),
24 }
25 }
26
27 pub async fn add<Fut>(&self, key: K, fut: Fut) -> Option<TaskResult<R>>
28 where
29 Fut: std::future::Future<Output = R> + Send + 'static,
30 {
31 let prev_result = self.remove(&key).await;
32 self.inner
33 .write()
34 .await
35 .insert(key.clone(), InnerTask::new(key, fut));
36 prev_result
37 }
38
39 pub async fn remove(&self, key: &K) -> Option<TaskResult<R>> {
40 let prev_task = self.inner.write().await.remove(key);
41 if let Some(prev_task) = prev_task {
42 debug!(?key, %prev_task.id, "removed previous task");
43 Some(prev_task.cancel_and_wait().await)
44 } else {
45 None
46 }
47 }
48
49 pub async fn apply<I, IT, G, F, Fut>(
50 &self,
51 iter: IT,
52 get_key: G,
53 create_task: F,
54 ) -> HashMap<K, TaskResult<R>>
55 where
56 IT: IntoIterator<Item = I>,
57 G: Fn(&I) -> &K,
58 F: Fn(I) -> Fut,
59 Fut: std::future::Future<Output = R> + Send + 'static,
60 {
61 let mut finished = self.remove_finished().await;
62 let mut old_keys: HashSet<K> = self.inner.read().await.keys().cloned().collect();
63
64 for item in iter {
65 let key = get_key(&item);
66 if !old_keys.remove(key) {
67 self.add(key.to_owned(), create_task(item)).await;
68 }
69 }
70
71 for key in old_keys {
72 let result = self.remove(&key).await;
73 if let Some(result) = result {
74 finished.insert(key, result);
75 }
76 }
77
78 finished
79 }
80
81 pub async fn remove_finished(&self) -> HashMap<K, TaskResult<R>> {
82 let mut results = HashMap::new();
83 let mut inner = self.inner.write().await;
84 for (key, inner_task) in inner.drain_filter(|_, inner_task| inner_task.is_finished()) {
85 results.insert(key, inner_task.wait().await);
86 }
87 results
88 }
89
90 pub async fn wait_for_tasks(&self) -> HashMap<K, TaskResult<R>> {
91 let mut results = HashMap::new();
92 debug!("waiting for all tasks to finish");
93 let mut inner = self.inner.write().await;
94 for (key, inner_task) in inner.drain() {
95 let result = inner_task.wait().await;
96 results.insert(key, result);
97 }
98 results
99 }
100}