use std::collections::HashMap;
use std::future::Future;
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
use parking_lot::Mutex;
use tokio::runtime::Handle;
use tokio::task::JoinHandle;
use crate::TaskStatus;
/// Shared, thread-safe callback that receives a task's key together with the
/// status that task has just transitioned to.
type StatusChangeCallback<K, T> = Arc<dyn Fn(&K, TaskStatus<T>) + Sync + Send + 'static>;
/// State shared (behind an `Arc`) by every clone of a `TaskPool`.
struct TaskPoolInner<K, T> {
/// Runtime handle that `spawn` submits tasks to.
rt: Handle,
/// Registered tasks, by key. An entry is NOT removed when its task finishes;
/// entries leave the map only via `spawn` (replacing the key), `abort`, or
/// `shutdown` (which drains the whole map).
entries: Mutex<HashMap<K, PoolEntry>>,
/// Callback notified of every task status transition.
on_status_change: StatusChangeCallback<K, T>,
}
/// Bookkeeping kept for one registered task.
struct PoolEntry {
/// Handle used to abort the task or to await its completion.
join_handle: JoinHandle<()>,
}
/// Cheaply clonable handle to a keyed set of tasks spawned on a Tokio runtime.
///
/// All clones share the same task table and status callback, so a task
/// spawned through one clone can be aborted through another.
pub struct TaskPool<K, T>
where
    T: Send + 'static,
{
    inner: Arc<TaskPoolInner<K, T>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require
// `K: Clone` and `T: Clone`, even though cloning only bumps the `Arc` refcount.
// A non-`Clone` task output must not make the pool itself unclonable.
impl<K, T> Clone for TaskPool<K, T>
where
    T: Send + 'static,
{
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<K, T> std::fmt::Debug for TaskPool<K, T>
where
    T: Send + 'static,
{
    /// Renders the pool opaquely as `TaskPool { .. }`; neither the task table
    /// nor the callback is shown.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = out.debug_struct("TaskPool");
        view.finish_non_exhaustive()
    }
}
impl<K, T> TaskPool<K, T>
where
K: Hash + Eq + Clone + Send + 'static,
T: Send + 'static,
{
#[must_use]
pub fn new(
rt: Handle,
on_status_change: impl Fn(&K, TaskStatus<T>) + Send + Sync + 'static,
) -> Self {
Self {
inner: Arc::new(TaskPoolInner {
rt,
entries: Mutex::new(HashMap::new()),
on_status_change: Arc::new(on_status_change),
}),
}
}
pub fn spawn<F, Fut>(&self, key: K, f: F)
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = T> + Send + 'static,
{
if let Some(old) = self.inner.entries.lock().remove(&key) {
old.join_handle.abort();
}
(self.inner.on_status_change)(&key, TaskStatus::Pending);
let inner = Arc::clone(&self.inner);
let key_clone = key.clone();
let handle = self.inner.rt.spawn(async move {
let result = f().await;
(inner.on_status_change)(&key_clone, TaskStatus::Resolved(result));
});
self.inner.entries.lock().insert(key, PoolEntry { join_handle: handle });
}
pub fn abort(&self, key: &K) -> bool {
if let Some(entry) = self.inner.entries.lock().remove(key) {
entry.join_handle.abort();
(self.inner.on_status_change)(key, TaskStatus::Aborted);
true
} else {
false
}
}
pub async fn shutdown(&self, timeout: Duration) {
let entries: Vec<_> = self.inner.entries.lock().drain().collect();
for (_key, entry) in entries {
let handle = entry.join_handle;
match tokio::time::timeout(timeout, handle).await {
Ok(_) | Err(_) => {}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap as StdHashMap;
    use std::sync::Arc;

    /// Per-key log of every status the pool reported, in arrival order.
    type PoolStatuses<T> = Arc<Mutex<StdHashMap<String, Vec<TaskStatus<T>>>>>;

    /// Builds a pool on `rt` whose callback appends each status to a shared log.
    fn make_pool<T>(rt: &Handle) -> (TaskPool<String, T>, PoolStatuses<T>)
    where
        T: Send + 'static,
    {
        let recorded: PoolStatuses<T> = Arc::new(Mutex::new(StdHashMap::new()));
        let sink = Arc::clone(&recorded);
        let pool = TaskPool::new(rt.clone(), move |key: &String, status| {
            sink.lock().entry(key.clone()).or_default().push(status);
        });
        (pool, recorded)
    }

    /// Copy of the statuses recorded for `key`; panics if none were recorded.
    fn history<T>(log: &PoolStatuses<T>, key: &str) -> Vec<TaskStatus<T>>
    where
        TaskStatus<T>: Clone,
    {
        log.lock().get(key).cloned().unwrap()
    }

    /// True if `events` contains a `Resolved` status carrying `expected`.
    fn resolved_with(events: &[TaskStatus<i32>], expected: i32) -> bool {
        events
            .iter()
            .any(|s| s.is_resolved() && s.resolved() == Some(&expected))
    }

    #[tokio::test]
    async fn spawn_calls_on_status_change_with_pending_then_resolved() {
        let (pool, statuses) = make_pool::<i32>(&Handle::current());
        pool.spawn(String::from("key"), || async { 42 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        let events = history(&statuses, "key");
        assert_eq!(events.len(), 2);
        assert!(events[0].is_pending());
        assert!(events[1].is_resolved());
        assert_eq!(events[1].resolved(), Some(&42));
    }

    #[tokio::test]
    async fn re_spawn_with_same_key_aborts_previous_task() {
        let (pool, statuses) = make_pool::<i32>(&Handle::current());
        pool.spawn(String::from("key"), || async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            1
        });
        tokio::time::sleep(Duration::from_millis(5)).await;
        pool.spawn(String::from("key"), || async { 2 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        let events = history(&statuses, "key");
        assert!(events.contains(&TaskStatus::Pending));
        assert!(resolved_with(&events, 2));
    }

    #[tokio::test]
    async fn abort_sets_status_to_aborted_and_returns_true() {
        let (pool, statuses) = make_pool::<()>(&Handle::current());
        pool.spawn(String::from("key"), || async {
            tokio::time::sleep(Duration::from_secs(10)).await;
        });
        tokio::time::sleep(Duration::from_millis(5)).await;

        assert!(pool.abort(&String::from("key")));
        assert!(history(&statuses, "key").contains(&TaskStatus::Aborted));
    }

    #[tokio::test]
    async fn abort_returns_false_for_unknown_key() {
        let (pool, _statuses) = make_pool::<()>(&Handle::current());
        assert!(!pool.abort(&String::from("missing")));
    }

    #[tokio::test]
    async fn shutdown_awaits_cooperative_tasks() {
        let (pool, statuses) = make_pool::<i32>(&Handle::current());
        pool.spawn(String::from("key"), || async { 99 });
        tokio::time::sleep(Duration::from_millis(10)).await;
        pool.shutdown(Duration::from_secs(1)).await;

        assert!(resolved_with(&history(&statuses, "key"), 99));
    }

    #[tokio::test]
    async fn shutdown_force_aborts_tasks_after_timeout() {
        let (pool, statuses) = make_pool::<()>(&Handle::current());
        pool.spawn(String::from("key"), || async {
            tokio::time::sleep(Duration::from_secs(10)).await;
        });
        tokio::time::sleep(Duration::from_millis(5)).await;
        pool.shutdown(Duration::from_millis(5)).await;

        assert!(history(&statuses, "key").iter().any(TaskStatus::is_pending));
    }

    #[tokio::test]
    async fn clone_is_cheap_shared_state() {
        let (pool, statuses) = make_pool::<i32>(&Handle::current());
        let twin = pool.clone();
        twin.spawn(String::from("key"), || async { 7 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(resolved_with(&history(&statuses, "key"), 7));
    }

    #[tokio::test]
    async fn different_keys_run_independently() {
        let (pool, statuses) = make_pool::<i32>(&Handle::current());
        pool.spawn(String::from("a"), || async { 1 });
        pool.spawn(String::from("b"), || async { 2 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        let events_a = history(&statuses, "a");
        let events_b = history(&statuses, "b");
        assert!(resolved_with(&events_a, 1));
        assert!(resolved_with(&events_b, 2));
    }

    #[tokio::test]
    async fn abort_after_task_completes_returns_true() {
        let (pool, _statuses) = make_pool::<i32>(&Handle::current());
        pool.spawn(String::from("key"), || async { 1 });
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(pool.abort(&String::from("key")));
    }
}