1use std::collections::HashMap;
8use std::future::Future;
9use std::hash::Hash;
10use std::sync::Arc;
11use std::time::Duration;
12
13use parking_lot::Mutex;
14use tokio::runtime::Handle;
15use tokio::task::JoinHandle;
16
17use crate::TaskStatus;
18
/// Shared, thread-safe callback invoked with a task's key and its new status.
///
/// It is called both synchronously from `TaskPool` methods and from inside the
/// tasks the pool spawns on the runtime, which is why it must be `Send + Sync`.
type StatusChangeCallback<K, T> = Arc<dyn Fn(&K, TaskStatus<T>) + Send + Sync>;
20
/// State shared by every clone of a `TaskPool` and by the tasks it spawns.
struct TaskPoolInner<K, T> {
    /// Runtime on which pool tasks are spawned.
    rt: Handle,
    /// Tracked tasks by key. An entry stays here until it is replaced by a new
    /// `spawn` for the same key, removed by `abort`, or drained by `shutdown`;
    /// a task finishing does not remove its own entry.
    entries: Mutex<HashMap<K, PoolEntry>>,
    /// User callback notified of status transitions.
    on_status_change: StatusChangeCallback<K, T>,
}
26
/// Bookkeeping for one tracked task.
struct PoolEntry {
    /// Handle used to abort or await the task. Its output is `()` because the
    /// task's result is delivered through the status callback, not the handle.
    join_handle: JoinHandle<()>,
}
30
31#[derive(Clone)]
43pub struct TaskPool<K, T>
44where
45 T: Send + 'static,
46{
47 inner: Arc<TaskPoolInner<K, T>>,
48}
49
50impl<K, T> std::fmt::Debug for TaskPool<K, T>
51where
52 T: Send + 'static,
53{
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("TaskPool").finish_non_exhaustive()
56 }
57}
58
59impl<K, T> TaskPool<K, T>
60where
61 K: Hash + Eq + Clone + Send + 'static,
62 T: Send + 'static,
63{
64 #[must_use]
70 pub fn new(
71 rt: Handle,
72 on_status_change: impl Fn(&K, TaskStatus<T>) + Send + Sync + 'static,
73 ) -> Self {
74 Self {
75 inner: Arc::new(TaskPoolInner {
76 rt,
77 entries: Mutex::new(HashMap::new()),
78 on_status_change: Arc::new(on_status_change),
79 }),
80 }
81 }
82
83 pub fn spawn<F, Fut>(&self, key: K, f: F)
89 where
90 F: FnOnce() -> Fut + Send + 'static,
91 Fut: Future<Output = T> + Send + 'static,
92 {
93 if let Some(old) = self.inner.entries.lock().remove(&key) {
94 old.join_handle.abort();
95 }
96
97 (self.inner.on_status_change)(&key, TaskStatus::Pending);
98
99 let inner = Arc::clone(&self.inner);
100 let key_clone = key.clone();
101 let handle = self.inner.rt.spawn(async move {
102 let result = f().await;
103 (inner.on_status_change)(&key_clone, TaskStatus::Resolved(result));
104 });
105
106 self.inner.entries.lock().insert(key, PoolEntry { join_handle: handle });
107 }
108
109 pub fn abort(&self, key: &K) -> bool {
114 if let Some(entry) = self.inner.entries.lock().remove(key) {
115 entry.join_handle.abort();
116 (self.inner.on_status_change)(key, TaskStatus::Aborted);
117 true
118 } else {
119 false
120 }
121 }
122
123 pub async fn shutdown(&self, timeout: Duration) {
128 let entries: Vec<_> = self.inner.entries.lock().drain().collect();
129
130 for (_key, entry) in entries {
131 let handle = entry.join_handle;
132 match tokio::time::timeout(timeout, handle).await {
133 Ok(_) | Err(_) => {}
134 }
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap as StdHashMap;
    use std::sync::Arc;

    /// Shared per-key log of every status the pool reported, in call order.
    type PoolStatuses<T> = Arc<Mutex<StdHashMap<String, Vec<TaskStatus<T>>>>>;

    /// Builds a pool on `rt` whose callback appends every status to the returned log.
    fn make_pool<T>(rt: &Handle) -> (TaskPool<String, T>, PoolStatuses<T>)
    where
        T: Send + 'static,
    {
        let statuses: PoolStatuses<T> = Arc::new(Mutex::new(StdHashMap::new()));
        let sink = Arc::clone(&statuses);
        let pool = TaskPool::new(rt.clone(), move |key: &String, status| {
            sink.lock().entry(key.clone()).or_default().push(status);
        });
        (pool, statuses)
    }

    #[tokio::test]
    async fn spawn_calls_on_status_change_with_pending_then_resolved() {
        let rt = Handle::current();
        let (pool, statuses) = make_pool::<i32>(&rt);

        pool.spawn("key".to_string(), || async { 42 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        let log = statuses.lock().get("key").cloned().unwrap();
        assert_eq!(log.len(), 2);
        assert!(log[0].is_pending());
        assert!(log[1].is_resolved());
        assert_eq!(log[1].resolved(), Some(&42));
    }

    #[tokio::test]
    async fn re_spawn_with_same_key_aborts_previous_task() {
        let rt = Handle::current();
        let (pool, statuses) = make_pool::<i32>(&rt);
        pool.spawn("key".to_string(), || async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            1
        });
        tokio::time::sleep(Duration::from_millis(5)).await;

        pool.spawn("key".to_string(), || async { 2 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        let log = statuses.lock().get("key").cloned().unwrap();
        // One `Pending` per spawn.
        assert_eq!(log.iter().filter(|s| s.is_pending()).count(), 2);
        assert!(log.iter().any(|s| s.is_resolved() && s.resolved() == Some(&2)));
        // The replaced task was cancelled, so its output must never be reported.
        assert!(!log.iter().any(|s| s.resolved() == Some(&1)));
    }

    #[tokio::test]
    async fn abort_sets_status_to_aborted_and_returns_true() {
        let rt = Handle::current();
        let (pool, statuses) = make_pool::<()>(&rt);
        pool.spawn("key".to_string(), || async {
            tokio::time::sleep(Duration::from_secs(10)).await;
        });
        tokio::time::sleep(Duration::from_millis(5)).await;

        let found = pool.abort(&"key".to_string());

        assert!(found);
        let log = statuses.lock().get("key").cloned().unwrap();
        assert!(log.contains(&TaskStatus::Aborted));
        assert!(!log.iter().any(|s| s.is_resolved()));
    }

    #[tokio::test]
    async fn abort_returns_false_for_unknown_key() {
        let rt = Handle::current();
        let (pool, _statuses) = make_pool::<()>(&rt);

        let found = pool.abort(&"missing".to_string());

        assert!(!found);
    }

    #[tokio::test]
    async fn shutdown_awaits_cooperative_tasks() {
        let rt = Handle::current();
        let (pool, statuses) = make_pool::<i32>(&rt);
        pool.spawn("key".to_string(), || async { 99 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        pool.shutdown(Duration::from_secs(1)).await;

        let log = statuses.lock().get("key").cloned().unwrap();
        assert!(log.iter().any(|s| s.is_resolved() && s.resolved() == Some(&99)));
    }

    #[tokio::test]
    async fn shutdown_force_aborts_tasks_after_timeout() {
        let rt = Handle::current();
        let (pool, statuses) = make_pool::<()>(&rt);
        pool.spawn("key".to_string(), || async {
            tokio::time::sleep(Duration::from_secs(10)).await;
        });
        tokio::time::sleep(Duration::from_millis(5)).await;

        pool.shutdown(Duration::from_millis(5)).await;

        let log = statuses.lock().get("key").cloned().unwrap();
        assert!(log.iter().any(TaskStatus::is_pending));
    }

    #[tokio::test]
    async fn clone_is_cheap_shared_state() {
        let rt = Handle::current();
        let (pool, statuses) = make_pool::<i32>(&rt);

        let pool2 = pool.clone();
        pool2.spawn("key".to_string(), || async { 7 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        let log = statuses.lock().get("key").cloned().unwrap();
        assert!(log.iter().any(|s| s.is_resolved() && s.resolved() == Some(&7)));
    }

    #[tokio::test]
    async fn different_keys_run_independently() {
        let rt = Handle::current();
        let (pool, statuses) = make_pool::<i32>(&rt);

        pool.spawn("a".to_string(), || async { 1 });
        pool.spawn("b".to_string(), || async { 2 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        let log_a = statuses.lock().get("a").cloned().unwrap();
        let log_b = statuses.lock().get("b").cloned().unwrap();
        assert!(log_a.iter().any(|s| s.is_resolved() && s.resolved() == Some(&1)));
        assert!(log_b.iter().any(|s| s.is_resolved() && s.resolved() == Some(&2)));
    }

    #[tokio::test]
    async fn abort_after_task_completes_returns_true() {
        let rt = Handle::current();
        let (pool, statuses) = make_pool::<i32>(&rt);
        pool.spawn("key".to_string(), || async { 1 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        let found = pool.abort(&"key".to_string());

        assert!(found);
        let log = statuses.lock().get("key").cloned().unwrap();
        // A finished task's entry stays tracked, so aborting it still reports `Aborted`.
        assert!(log.iter().any(|s| s.resolved() == Some(&1)));
        assert!(log.contains(&TaskStatus::Aborted));
    }
}