use std::future::Future;
use tokio::task::{AbortHandle, JoinSet};
use tokio_util::sync::CancellationToken;
use crate::domain::TurnId;
/// Tracks every child task spawned on behalf of a single turn.
///
/// All children are expected to observe one shared [`CancellationToken`].
/// Calling `cancel` on the scope, or dropping it, triggers that token.
#[derive(Debug)]
pub struct TurnScope {
    /// The turn this scope was created for.
    id: TurnId,
    /// Shared token handed out to child tasks for cooperative shutdown.
    token: CancellationToken,
    /// The set of child tasks spawned through this scope.
    joins: JoinSet<()>,
}
impl TurnScope {
    /// Creates an empty scope for turn `id` with a fresh, uncancelled token.
    pub fn new(id: TurnId) -> Self {
        Self {
            id,
            token: CancellationToken::new(),
            joins: JoinSet::new(),
        }
    }

    /// Returns the identifier of the turn this scope belongs to.
    pub fn id(&self) -> TurnId {
        self.id
    }

    /// Returns a clone of the scope's cancellation token.
    ///
    /// Child tasks should watch this token (e.g. `token.cancelled().await`)
    /// so they stop when [`TurnScope::cancel`] is called or the scope is dropped.
    pub fn token(&self) -> CancellationToken {
        self.token.clone()
    }

    /// Spawns `fut` as a child task tracked by this scope.
    ///
    /// The returned [`AbortHandle`] can abort that single task.
    pub fn spawn<Fut>(&mut self, fut: Fut) -> AbortHandle
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.joins.spawn(fut)
    }

    /// Triggers the scope's cancellation token.
    ///
    /// This is cooperative. It only signals the token and does not abort any
    /// task, so children that ignore the token keep running.
    pub fn cancel(&self) {
        self.token.cancel();
    }

    /// Returns `true` once the scope's token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.token.is_cancelled()
    }

    /// Waits for the next child task to finish and returns its join result.
    ///
    /// Returns `None` when no child tasks are tracked. Errors are passed
    /// through to the caller unlogged.
    pub async fn join_next(&mut self) -> Option<Result<(), tokio::task::JoinError>> {
        self.joins.join_next().await
    }

    /// Returns `true` if no child tasks are currently tracked.
    pub fn is_empty(&self) -> bool {
        self.joins.is_empty()
    }

    /// Reaps child tasks that have already finished, without waiting.
    ///
    /// Panicked children are logged exactly as in [`TurnScope::drain`].
    /// Previously their errors were silently discarded here.
    pub fn drain_completed(&mut self) {
        while let Some(result) = self.joins.try_join_next() {
            Self::log_panic(self.id, result);
        }
    }

    /// Returns the number of child tasks currently tracked.
    pub fn len(&self) -> usize {
        self.joins.len()
    }

    /// Waits until every child task has finished.
    ///
    /// This does not cancel anything itself. Call [`TurnScope::cancel`] first
    /// if children should be asked to stop. A child that never completes keeps
    /// this future pending. Panicked children are logged; cancelled (aborted)
    /// children are ignored.
    pub async fn drain(&mut self) {
        while let Some(result) = self.joins.join_next().await {
            Self::log_panic(self.id, result);
        }
    }

    /// Logs a warning when `result` carries a non-cancellation join error.
    ///
    /// Aborted tasks are an expected outcome of shutdown and are not reported.
    fn log_panic(id: TurnId, result: Result<(), tokio::task::JoinError>) {
        if let Err(e) = result
            && !e.is_cancelled()
        {
            tracing::warn!(
                turn = %id,
                error = %e,
                "turn_scope: child task panicked"
            );
        }
    }
}
impl Drop for TurnScope {
    /// Triggers the scope's token on drop so children watching it can shut down.
    ///
    /// Does nothing to the token if it was already cancelled (e.g. via `cancel`).
    fn drop(&mut self) {
        if self.token.is_cancelled() {
            return;
        }
        self.token.cancel();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn fresh_scope_has_no_tasks() {
        let scope = TurnScope::new(TurnId(1));
        assert_eq!(scope.len(), 0);
        assert!(scope.is_empty());
        assert!(!scope.is_cancelled());
    }

    #[tokio::test]
    async fn spawned_task_completes_within_scope() {
        let mut scope = TurnScope::new(TurnId(1));
        scope.spawn(async {
            tokio::time::sleep(Duration::from_millis(5)).await;
        });
        assert_eq!(scope.len(), 1);
        let result = scope.join_next().await;
        // `is_some()` alone would also accept a panicked child; require success.
        assert!(
            matches!(result, Some(Ok(()))),
            "unexpected join result: {result:?}"
        );
        assert!(scope.is_empty());
        assert!(scope.join_next().await.is_none());
    }

    #[tokio::test]
    async fn cancel_signals_child_tasks() {
        let mut scope = TurnScope::new(TurnId(1));
        let token = scope.token();
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<&'static str>();
        scope.spawn(async move {
            tokio::select! {
                _ = token.cancelled() => {
                    let _ = tx.send("cancelled");
                },
                _ = tokio::time::sleep(Duration::from_secs(30)) => {
                    let _ = tx.send("timeout");
                },
            }
        });
        tokio::time::sleep(Duration::from_millis(10)).await;
        scope.cancel();
        let msg = tokio::time::timeout(Duration::from_millis(500), rx.recv())
            .await
            .expect("cancellation should propagate")
            .expect("sender alive");
        assert_eq!(msg, "cancelled");
        scope.drain().await;
        assert!(scope.is_empty());
    }

    #[tokio::test]
    async fn drop_cancels_token() {
        let token = {
            let scope = TurnScope::new(TurnId(2));
            scope.token()
        };
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn drain_runs_to_completion_on_normal_tasks() {
        let mut scope = TurnScope::new(TurnId(3));
        for i in 0..5 {
            scope.spawn(async move {
                tokio::time::sleep(Duration::from_millis(i)).await;
            });
        }
        assert_eq!(scope.len(), 5);
        scope.drain().await;
        assert!(scope.is_empty());
    }

    #[tokio::test]
    async fn drain_completed_reaps_finished_tasks() {
        let mut scope = TurnScope::new(TurnId(5));
        scope.spawn(async {});
        // Yield long enough for the trivially-finishing child to run to completion.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert_eq!(scope.len(), 1);
        scope.drain_completed();
        assert!(scope.is_empty());
    }

    #[tokio::test]
    async fn cancel_then_drain_is_quick() {
        let mut scope = TurnScope::new(TurnId(4));
        let token = scope.token();
        for _ in 0..10 {
            let t = token.clone();
            scope.spawn(async move {
                tokio::select! {
                    _ = t.cancelled() => {},
                    _ = tokio::time::sleep(Duration::from_secs(60)) => {},
                }
            });
        }
        scope.cancel();
        let start = std::time::Instant::now();
        scope.drain().await;
        assert!(
            start.elapsed() < Duration::from_millis(100),
            "cancel+drain took {:?}",
            start.elapsed()
        );
        assert!(scope.is_empty());
    }
}