use std::{future::Future, time::Duration};
use tokio::sync::broadcast::Sender;
use tokio::{sync::broadcast, time::Instant};
/// Spawns Tokio tasks that can be cancelled as a group, optionally bounded by a shared deadline.
///
/// Every task spawned through [`TaskController::spawn`] resolves to:
/// - `Some(output)` if its future completes first, or
/// - `None` if the controller is cancelled (via [`TaskController::cancel`] *or* by simply being
///   dropped) or the deadline passes first.
///
/// Cancellation works by dropping the wrapped future inside the spawned task; it does not abort
/// the task from the outside.
#[derive(Debug)]
pub struct TaskController {
    /// Absolute deadline shared by all spawned tasks, fixed when the controller is created.
    /// `None` means tasks never time out.
    timeout: Option<Instant>,
    /// Broadcasts the cancel signal. Each spawned task holds its own subscription; the channel
    /// closing (when this sender is dropped) is treated as cancellation too.
    cancel_sender: Sender<()>,
}

impl TaskController {
    /// Cancels every task spawned from this controller; their handles resolve to `None`.
    ///
    /// Consuming `self` also drops the sender, so any task that misses the explicit signal
    /// still observes the closed channel and is cancelled.
    pub fn cancel(self) {
        // `send` only fails when no task is currently subscribed, i.e. there is nothing to cancel.
        let _ = self.cancel_sender.send(());
    }

    /// Creates a controller whose tasks run until they complete or are cancelled.
    pub fn new() -> TaskController {
        Self::with_deadline(None)
    }

    /// Creates a controller whose tasks are cancelled once `timeout` has elapsed.
    ///
    /// The deadline is measured from this call, not from each `spawn`: tasks spawned later get
    /// correspondingly less time, and a task spawned after the deadline races an already-elapsed
    /// timer and will typically resolve to `None` right away. A `timeout` too large to represent
    /// as an [`Instant`] is treated as "no timeout" instead of panicking.
    pub fn with_timeout(timeout: Duration) -> TaskController {
        Self::with_deadline(Instant::now().checked_add(timeout))
    }

    /// Shared constructor: builds the cancellation channel around an optional deadline.
    fn with_deadline(deadline: Option<Instant>) -> TaskController {
        // Capacity 1 is enough: the only message ever sent is the single cancel signal.
        let (cancel_sender, _) = broadcast::channel(1);
        TaskController {
            timeout: deadline,
            cancel_sender,
        }
    }

    /// Spawns `future` on the Tokio runtime, racing it against cancellation and the deadline.
    ///
    /// The returned handle resolves to `Some(output)` if `future` finished first, or `None` if
    /// the controller was cancelled/dropped or the deadline passed (in which case `future` is
    /// dropped unfinished). If several outcomes are ready at the same poll, `tokio::select!`
    /// picks one at random.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, as [`tokio::task::spawn`] does.
    pub fn spawn<T>(&self, future: T) -> tokio::task::JoinHandle<Option<T::Output>>
    where
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        let mut cancel_rx = self.cancel_sender.subscribe();
        let deadline = self.timeout;
        tokio::task::spawn(async move {
            // Without a deadline this branch never completes, so only the other two can win.
            let expired = async move {
                match deadline {
                    Some(instant) => tokio::time::sleep_until(instant).await,
                    None => std::future::pending::<()>().await,
                }
            };
            tokio::select! {
                res = future => Some(res),
                // Matches both the explicit signal and the channel closing on sender drop.
                _ = cancel_rx.recv() => None,
                _ = expired => None,
            }
        })
    }
}

impl Default for TaskController {
    /// Equivalent to [`TaskController::new`]: no deadline.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Awaits a spawned task's outcome, failing the test if it panicked or did not finish
    /// within one second.
    ///
    /// The generous bound keeps the tests robust on loaded machines while still catching a task
    /// that was never cancelled (the cancelled tasks below would otherwise sleep for 60 s).
    async fn await_outcome<T>(join: tokio::task::JoinHandle<Option<T>>) -> Option<T> {
        tokio::time::timeout(Duration::from_secs(1), join)
            .await
            .expect("task did not finish in time")
            .expect("task panicked")
    }

    #[tokio::test]
    async fn cancel_handle_cancels_task() {
        let mut controller = TaskController::new();
        let join = controller.spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });
        controller.cancel();
        assert!(
            await_outcome(join).await.is_none(),
            "cancelled task must resolve to None"
        );
    }

    #[tokio::test]
    async fn dropping_controller_cancels_task() {
        let mut controller = TaskController::new();
        let join = controller.spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });
        // Dropping closes the broadcast channel, which every spawned task treats as cancellation.
        drop(controller);
        assert!(
            await_outcome(join).await.is_none(),
            "task must be cancelled when its controller is dropped"
        );
    }

    #[tokio::test]
    async fn duration_cancels_task() {
        let mut controller = TaskController::with_timeout(Duration::from_millis(10));
        let join = controller.spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });
        assert!(
            await_outcome(join).await.is_none(),
            "timed-out task must resolve to None"
        );
    }

    #[tokio::test]
    async fn completed_task_returns_output() {
        let mut controller = TaskController::with_timeout(Duration::from_secs(60));
        let join = controller.spawn(async { 42 });
        // `controller` stays alive until the end of the test, so only the future can win.
        assert_eq!(await_outcome(join).await, Some(42));
    }
}