use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::task::{AbortHandle, JoinHandle};
use crate::Action;
/// Name identifying a managed task.
///
/// The manager keeps at most one task per key: registering a new task under
/// an existing key replaces (and aborts) the previous one.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct TaskKey(String);

impl TaskKey {
    /// Creates a key from anything convertible into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        Self(name.into())
    }

    /// Returns the key's name.
    pub fn name(&self) -> &str {
        &self.0
    }
}

// Accepts any borrowed string, not only `'static` literals, so keys can be
// built from runtime data (e.g. `format!("fetch-{id}").as_str()`).
impl From<&str> for TaskKey {
    fn from(s: &str) -> Self {
        Self::new(s)
    }
}

impl From<String> for TaskKey {
    fn from(s: String) -> Self {
        Self(s)
    }
}

/// Formats the key as its bare name (useful for logging).
impl std::fmt::Display for TaskKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Cloneable handle sharing a `TaskManager`'s pause flag and action queue,
/// so delivery can be paused and resumed without access to the manager.
pub struct TaskPauseHandle<A> {
    paused: Arc<AtomicBool>,
    queued_actions: Arc<Mutex<Vec<A>>>,
}

// Implemented by hand: `#[derive(Clone)]` would add an `A: Clone` bound, but
// cloning the handle only bumps the two `Arc` reference counts.
impl<A> Clone for TaskPauseHandle<A> {
    fn clone(&self) -> Self {
        Self {
            paused: Arc::clone(&self.paused),
            queued_actions: Arc::clone(&self.queued_actions),
        }
    }
}

impl<A> TaskPauseHandle<A> {
    /// Sets the shared pause flag.
    pub fn pause(&self) {
        self.paused.store(true, Ordering::SeqCst);
    }

    /// Clears the shared pause flag and returns every queued action, in the
    /// order they were pushed, leaving the queue empty.
    ///
    /// The returned actions are not re-sent anywhere; dropping the `Vec`
    /// discards them.
    ///
    /// # Panics
    ///
    /// Panics if the queue mutex is poisoned.
    pub fn resume(&self) -> Vec<A> {
        self.paused.store(false, Ordering::SeqCst);
        std::mem::take(&mut *self.queued_actions.lock().unwrap())
    }

    /// Returns the current value of the shared pause flag.
    pub fn is_paused(&self) -> bool {
        self.paused.load(Ordering::SeqCst)
    }
}
/// Owns keyed background tasks and routes each task's resulting action either
/// to `action_tx` or, while paused, into a shared queue.
pub struct TaskManager<A> {
    /// Abort handles of spawned tasks; inserting under an existing key replaces the old entry.
    tasks: HashMap<TaskKey, AbortHandle>,
    /// Channel that completed tasks send their action on when not paused.
    action_tx: mpsc::UnboundedSender<A>,
    /// Pause flag, shared with spawned tasks and with every `TaskPauseHandle` handed out.
    paused: Arc<AtomicBool>,
    /// Actions produced while paused; drained and returned by `resume`.
    queued_actions: Arc<Mutex<Vec<A>>>,
}
impl<A> TaskManager<A>
where
    A: Action,
{
    /// Creates an empty, unpaused manager whose tasks report their actions on `action_tx`.
    pub fn new(action_tx: mpsc::UnboundedSender<A>) -> Self {
        Self {
            tasks: HashMap::new(),
            action_tx,
            paused: Arc::new(AtomicBool::new(false)),
            queued_actions: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Pauses delivery: tasks that finish while paused queue their action
    /// instead of sending it. Running tasks themselves are not affected.
    pub fn pause(&self) {
        self.paused.store(true, Ordering::SeqCst);
    }

    /// Resumes delivery and returns the actions queued while paused, in the
    /// order they were queued.
    ///
    /// The returned actions are *not* re-sent on the channel; the caller must
    /// handle them (dropping the `Vec` discards them).
    ///
    /// # Panics
    ///
    /// Panics if the queue mutex is poisoned.
    pub fn resume(&self) -> Vec<A> {
        self.paused.store(false, Ordering::SeqCst);
        std::mem::take(&mut *self.queued_actions.lock().unwrap())
    }

    /// Returns `true` while delivery is paused.
    pub fn is_paused(&self) -> bool {
        self.paused.load(Ordering::SeqCst)
    }

    /// Returns a cloneable handle sharing this manager's pause flag and queue,
    /// so delivery can be paused/resumed without access to the manager.
    pub fn pause_handle(&self) -> TaskPauseHandle<A> {
        TaskPauseHandle {
            paused: Arc::clone(&self.paused),
            queued_actions: Arc::clone(&self.queued_actions),
        }
    }

    /// Forgets the handles of tasks that have already finished.
    ///
    /// Called automatically by [`spawn`](Self::spawn) and [`debounce`](Self::debounce).
    pub fn cleanup(&mut self) {
        self.tasks.retain(|_, handle| !handle.is_finished());
    }

    /// Spawns `future` under `key`, aborting any task already registered
    /// under that key.
    ///
    /// When the future completes, its action is sent on the channel, or
    /// queued if the manager is paused at that moment.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (as `tokio::spawn` does).
    pub fn spawn<F>(&mut self, key: impl Into<TaskKey>, future: F) -> &mut Self
    where
        F: Future<Output = A> + Send + 'static,
    {
        self.spawn_keyed(key.into(), None, future)
    }

    /// Like [`spawn`](Self::spawn), but sleeps for `duration` before polling
    /// `future`.
    ///
    /// Calling it again with the same key before the delay elapses aborts the
    /// pending task, so only the most recent call produces an action.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (as `tokio::spawn` does).
    pub fn debounce<F>(
        &mut self,
        key: impl Into<TaskKey>,
        duration: Duration,
        future: F,
    ) -> &mut Self
    where
        F: Future<Output = A> + Send + 'static,
    {
        self.spawn_keyed(key.into(), Some(duration), future)
    }

    /// Aborts the task registered under `key`, if any, and forgets it.
    pub fn cancel(&mut self, key: &TaskKey) {
        if let Some(handle) = self.tasks.remove(key) {
            handle.abort();
        }
    }

    /// Aborts and forgets every registered task.
    pub fn cancel_all(&mut self) {
        for (_, handle) in self.tasks.drain() {
            handle.abort();
        }
    }

    /// Returns `true` if a task is registered under `key` and has not finished.
    pub fn is_running(&self, key: &TaskKey) -> bool {
        matches!(self.tasks.get(key), Some(handle) if !handle.is_finished())
    }

    /// Number of registered tasks that have not finished yet.
    pub fn len(&self) -> usize {
        self.running_keys().count()
    }

    /// Returns `true` if no registered task is still running.
    pub fn is_empty(&self) -> bool {
        self.running_keys().next().is_none()
    }

    /// Iterates, in arbitrary order, over the keys of registered tasks that
    /// have not finished.
    pub fn running_keys(&self) -> impl Iterator<Item = &TaskKey> {
        self.tasks
            .iter()
            .filter(|(_, handle)| !handle.is_finished())
            .map(|(key, _)| key)
    }

    /// Shared implementation of `spawn`/`debounce`: replaces any task under
    /// `key` with one that optionally sleeps for `delay`, awaits `future`, and
    /// delivers the resulting action.
    fn spawn_keyed<F>(&mut self, key: TaskKey, delay: Option<Duration>, future: F) -> &mut Self
    where
        F: Future<Output = A> + Send + 'static,
    {
        self.cleanup();
        self.cancel(&key);
        let tx = self.action_tx.clone();
        let paused = Arc::clone(&self.paused);
        let queued = Arc::clone(&self.queued_actions);
        let handle: JoinHandle<()> = tokio::spawn(async move {
            if let Some(delay) = delay {
                tokio::time::sleep(delay).await;
            }
            let action = future.await;
            Self::deliver(action, &tx, &paused, &queued);
        });
        self.tasks.insert(key, handle.abort_handle());
        self
    }

    /// Sends `action` on `tx`, or queues it if delivery is paused.
    ///
    /// The pause flag is read while holding the queue lock. `resume` clears
    /// the flag before locking the queue to drain it. So an action either
    /// sees the cleared flag and is sent, or is pushed before the drain and
    /// returned by it. It can no longer land in the queue just after a drain,
    /// where it would be stranded until the next `resume`.
    ///
    /// Kept synchronous so the `std::sync::MutexGuard` never lives inside the
    /// spawned future.
    fn deliver(
        action: A,
        tx: &mpsc::UnboundedSender<A>,
        paused: &AtomicBool,
        queued: &Mutex<Vec<A>>,
    ) {
        let mut queue = queued.lock().unwrap();
        if paused.load(Ordering::SeqCst) {
            queue.push(action);
        } else {
            // An error only means the receiver was dropped; nobody is left to deliver to.
            let _ = tx.send(action);
        }
    }
}
impl<A> Drop for TaskManager<A> {
    /// Aborts every task still registered, so dropping the manager cancels
    /// its outstanding work.
    fn drop(&mut self) {
        self.tasks
            .drain()
            .for_each(|(_key, handle)| handle.abort());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Minimal action type used to observe which task delivered.
    #[derive(Clone, Debug)]
    enum TestAction {
        Done(usize),
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "Done"
        }
    }

    #[test]
    fn test_task_key() {
        let k1 = TaskKey::new("test");
        let k2 = TaskKey::from("test");
        let k3: TaskKey = "test".into();
        assert_eq!(k1, k2);
        assert_eq!(k2, k3);
        assert_eq!(k1.name(), "test");
    }

    #[tokio::test]
    async fn test_spawn_sends_action() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut tasks = TaskManager::new(tx);
        tasks.spawn("test", async { TestAction::Done(42) });
        let action = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert!(matches!(action, TestAction::Done(42)));
    }

    #[tokio::test]
    async fn test_spawn_cancels_previous() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut tasks = TaskManager::new(tx);
        let counter = Arc::new(AtomicUsize::new(0));
        let c1 = counter.clone();
        tasks.spawn("test", async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            c1.fetch_add(1, Ordering::SeqCst);
            TestAction::Done(1)
        });
        let c2 = counter.clone();
        tasks.spawn("test", async move {
            c2.fetch_add(10, Ordering::SeqCst);
            TestAction::Done(2)
        });
        let action = tokio::time::timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert!(matches!(action, TestAction::Done(2)));
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_debounce() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut tasks = TaskManager::new(tx);
        tasks.debounce("test", Duration::from_millis(50), async {
            TestAction::Done(1)
        });
        let result = tokio::time::timeout(Duration::from_millis(30), rx.recv()).await;
        assert!(result.is_err());
        let action = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert!(matches!(action, TestAction::Done(1)));
    }

    #[tokio::test]
    async fn test_debounce_resets() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut tasks = TaskManager::new(tx);
        tasks.debounce("test", Duration::from_millis(50), async {
            TestAction::Done(1)
        });
        tokio::time::sleep(Duration::from_millis(30)).await;
        tasks.debounce("test", Duration::from_millis(50), async {
            TestAction::Done(2)
        });
        let action = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert!(matches!(action, TestAction::Done(2)));
    }

    #[tokio::test]
    async fn test_cancel() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut tasks = TaskManager::new(tx);
        tasks.spawn("test", async {
            tokio::time::sleep(Duration::from_millis(100)).await;
            TestAction::Done(1)
        });
        assert!(tasks.is_running(&TaskKey::new("test")));
        tasks.cancel(&TaskKey::new("test"));
        assert!(!tasks.is_running(&TaskKey::new("test")));
        let result = tokio::time::timeout(Duration::from_millis(150), rx.recv()).await;
        assert!(result.is_err() || result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cancel_all() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut tasks = TaskManager::new(tx);
        tasks.spawn("a", async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            TestAction::Done(1)
        });
        tasks.spawn("b", async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            TestAction::Done(2)
        });
        assert_eq!(tasks.len(), 2);
        tasks.cancel_all();
        assert!(tasks.is_empty());
    }

    /// Exercises `running_keys` itself (previously this test only checked
    /// `len`/`is_empty` on an empty manager).
    #[tokio::test]
    async fn test_running_keys() {
        let (tx, _rx) = mpsc::unbounded_channel::<TestAction>();
        let mut tasks = TaskManager::new(tx);
        assert!(tasks.is_empty());
        assert_eq!(tasks.len(), 0);
        assert_eq!(tasks.running_keys().count(), 0);
        tasks.spawn("a", async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            TestAction::Done(1)
        });
        tasks.spawn("b", async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            TestAction::Done(2)
        });
        let mut names: Vec<&str> = tasks.running_keys().map(TaskKey::name).collect();
        names.sort_unstable();
        assert_eq!(names, ["a", "b"]);
        tasks.cancel(&TaskKey::new("a"));
        let names: Vec<&str> = tasks.running_keys().map(TaskKey::name).collect();
        assert_eq!(names, ["b"]);
    }

    #[test]
    fn test_pause_handle_basic() {
        let (tx, _rx) = mpsc::unbounded_channel::<TestAction>();
        let tasks = TaskManager::new(tx);
        let handle = tasks.pause_handle();
        assert!(!handle.is_paused());
        handle.pause();
        assert!(handle.is_paused());
        let queued = handle.resume();
        assert!(!handle.is_paused());
        assert!(queued.is_empty());
    }

    #[tokio::test]
    async fn test_pause_queues_actions() {
        let (tx, mut rx) = mpsc::unbounded_channel::<TestAction>();
        let mut tasks = TaskManager::new(tx);
        let handle = tasks.pause_handle();
        handle.pause();
        tasks.spawn("test", async { TestAction::Done(42) });
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(rx.try_recv().is_err());
        let queued = handle.resume();
        assert_eq!(queued.len(), 1);
        assert!(matches!(queued[0], TestAction::Done(42)));
    }

    /// Covers `TaskManager::pause`/`resume` directly, and checks that tasks
    /// finishing after `resume` are sent on the channel instead of queued.
    #[tokio::test]
    async fn test_manager_pause_resume() {
        let (tx, mut rx) = mpsc::unbounded_channel::<TestAction>();
        let mut tasks = TaskManager::new(tx);
        tasks.pause();
        assert!(tasks.is_paused());
        tasks.spawn("first", async { TestAction::Done(1) });
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(rx.try_recv().is_err());
        let queued = tasks.resume();
        assert!(!tasks.is_paused());
        assert_eq!(queued.len(), 1);
        assert!(matches!(queued[0], TestAction::Done(1)));
        tasks.spawn("second", async { TestAction::Done(2) });
        let action = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout")
            .expect("channel closed");
        assert!(matches!(action, TestAction::Done(2)));
        assert!(tasks.resume().is_empty());
    }

    #[tokio::test]
    async fn test_pause_handle_clone() {
        let (tx, _rx) = mpsc::unbounded_channel::<TestAction>();
        let tasks = TaskManager::new(tx);
        let handle1 = tasks.pause_handle();
        let handle2 = handle1.clone();
        handle1.pause();
        assert!(handle2.is_paused());
        assert!(handle2.resume().is_empty());
        assert!(!handle1.is_paused());
    }

    #[tokio::test]
    async fn test_finished_tasks_cleaned_up() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let mut tasks = TaskManager::new(tx);
        tasks.spawn("fast", async { TestAction::Done(1) });
        let _ = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout");
        assert!(!tasks.is_running(&TaskKey::new("fast")));
        assert_eq!(tasks.len(), 0);
        tasks.spawn("another", async { TestAction::Done(2) });
        let _ = tokio::time::timeout(Duration::from_millis(100), rx.recv())
            .await
            .expect("timeout");
        assert_eq!(tasks.len(), 0);
    }

    #[tokio::test]
    async fn test_is_running_accurate_for_long_task() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let mut tasks = TaskManager::new(tx);
        tasks.spawn("slow", async {
            tokio::time::sleep(Duration::from_millis(200)).await;
            TestAction::Done(1)
        });
        assert!(tasks.is_running(&TaskKey::new("slow")));
        assert_eq!(tasks.len(), 1);
        tokio::time::sleep(Duration::from_millis(250)).await;
        assert!(!tasks.is_running(&TaskKey::new("slow")));
        assert_eq!(tasks.len(), 0);
    }
}