Skip to main content

mermaid_cli/effect/
turn_scope.rs

//! Per-turn structured concurrency.
//!
//! A `TurnScope` owns exactly one `CancellationToken` and one
//! `JoinSet`. Every task spawned for this turn is expected to carry a
//! clone of the token (via `token()`), and every handle lands in the
//! set. When the user cancels (or the reducer dispatches
//! `Cmd::CancelScope`), we cancel the token; each child that races its
//! awaits against `token.cancelled()` then unwinds at its next `.await`.
//! On drop, the scope cancels the token and the `JoinSet` aborts any
//! task still in it, so no task outlives its turn.
//!
//! The point: cancellation is a **signal**, not a **poll**. No child
//! task has to remember to check a shared flag; they're awaiting an
//! mpsc receive or an HTTP body, and `token.cancelled()` races every
//! such await via `select!`. Abort latency = the time to reach the
//! next await point — microseconds for HTTP streams, milliseconds for
//! tool subprocess fan-out.
//!
//! Forgetting cancellation is hard to do by accident: the token is baked
//! into every adapter's `StreamContext` / `ExecContext`, and the adapter
//! must `select!` on it to proceed. Contrast with a "drain events
//! every 50ms" polling pattern, where long-running ops (web search,
//! execute command) had to remember to check a shared flag — silent
//! forgetting there shipped as hangs-until-timeout bugs.
23
24use std::future::Future;
25
26use tokio::task::{AbortHandle, JoinSet};
27use tokio_util::sync::CancellationToken;
28
29use crate::domain::TurnId;
30
31/// One turn's cancellable scope. Construct once per `SubmitPrompt`;
32/// abandon (drop) at the end of the turn.
33#[derive(Debug)]
34pub struct TurnScope {
35    id: TurnId,
36    token: CancellationToken,
37    joins: JoinSet<()>,
38}
39
40impl TurnScope {
41    pub fn new(id: TurnId) -> Self {
42        Self {
43            id,
44            token: CancellationToken::new(),
45            joins: JoinSet::new(),
46        }
47    }
48
49    pub fn id(&self) -> TurnId {
50        self.id
51    }
52
53    /// Clone the scope's token. Hand this to child tasks so they can
54    /// participate in cooperative cancellation.
55    pub fn token(&self) -> CancellationToken {
56        self.token.clone()
57    }
58
59    /// Spawn a child task under this scope. The returned handle is
60    /// retained inside the scope's `JoinSet` — callers don't need to
61    /// keep it. Cancellation of the scope aborts the task at its next
62    /// await.
63    pub fn spawn<Fut>(&mut self, fut: Fut) -> AbortHandle
64    where
65        Fut: Future<Output = ()> + Send + 'static,
66    {
67        self.joins.spawn(fut)
68    }
69
70    /// Signal cancellation to every child task. Returns immediately —
71    /// callers drain the `JoinSet` separately via `drain_completed`.
72    pub fn cancel(&self) {
73        self.token.cancel();
74    }
75
76    /// True iff the scope has been cancelled.
77    pub fn is_cancelled(&self) -> bool {
78        self.token.is_cancelled()
79    }
80
81    /// Join one task if any has completed. Returns `None` immediately
82    /// when the set is empty or nothing is ready. Intended for the
83    /// main loop's per-tick bookkeeping — not a blocking drain.
84    pub async fn join_next(&mut self) -> Option<Result<(), tokio::task::JoinError>> {
85        self.joins.join_next().await
86    }
87
88    /// True iff no child task is currently running inside this scope.
89    /// The main loop uses this after a `cancel()` to decide when to
90    /// transition from `TurnState::Cancelling` back to `Idle`.
91    pub fn is_empty(&self) -> bool {
92        self.joins.is_empty()
93    }
94
95    /// Drain any already-completed tasks from the JoinSet without
96    /// blocking. `JoinSet::is_empty` only flips to true after finished
97    /// tasks are explicitly harvested via `join_next`; without this,
98    /// `EffectRunner::reap_empty_scopes` would see finished-but-not-
99    /// joined scopes as "still busy" and never reap them. F12.
100    pub fn drain_completed(&mut self) {
101        while self.joins.try_join_next().is_some() {}
102    }
103
104    pub fn len(&self) -> usize {
105        self.joins.len()
106    }
107
108    /// Await every outstanding task to completion, swallowing
109    /// `JoinError`s (they happen on normal `abort()`s after
110    /// cancellation). Use during shutdown.
111    pub async fn drain(&mut self) {
112        while let Some(result) = self.joins.join_next().await {
113            if let Err(e) = result
114                && !e.is_cancelled()
115            {
116                tracing::warn!(
117                    turn = %self.id,
118                    error = %e,
119                    "turn_scope: child task panicked"
120                );
121            }
122        }
123    }
124}
125
126impl Drop for TurnScope {
127    fn drop(&mut self) {
128        // If the caller forgot to cancel before dropping, be safe: a
129        // live child task holding resources after the turn ended is
130        // exactly the kind of leak this architecture exists to
131        // prevent. `JoinSet::drop` already aborts its members, but we
132        // still flip the cancellation token so any child that branches
133        // on it observes the abort intent too.
134        if !self.token.is_cancelled() {
135            self.token.cancel();
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn fresh_scope_has_no_tasks() {
        let scope = TurnScope::new(TurnId(1));
        assert_eq!(scope.len(), 0);
        assert!(scope.is_empty());
        assert!(!scope.is_cancelled());
    }

    #[tokio::test]
    async fn spawned_task_completes_within_scope() {
        let mut scope = TurnScope::new(TurnId(1));
        scope.spawn(async {
            tokio::time::sleep(Duration::from_millis(5)).await;
        });
        assert_eq!(scope.len(), 1);
        // Wait for it.
        let result = scope.join_next().await;
        assert!(result.is_some());
        assert!(scope.is_empty());
    }

    #[tokio::test]
    async fn cancel_signals_child_tasks() {
        let mut scope = TurnScope::new(TurnId(1));
        let token = scope.token();
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<&'static str>();
        scope.spawn(async move {
            tokio::select! {
                _ = token.cancelled() => {
                    let _ = tx.send("cancelled");
                },
                _ = tokio::time::sleep(Duration::from_secs(30)) => {
                    let _ = tx.send("timeout");
                },
            }
        });

        // No need to wait for the child to reach its `select!`: a
        // `cancelled()` future created after `cancel()` resolves at once.
        scope.cancel();
        let msg = tokio::time::timeout(Duration::from_millis(500), rx.recv())
            .await
            .expect("cancellation should propagate")
            .expect("sender alive");
        assert_eq!(msg, "cancelled");
        scope.drain().await;
    }

    #[tokio::test]
    async fn drop_cancels_token() {
        let token = {
            let scope = TurnScope::new(TurnId(2));
            scope.token()
        };
        // Scope dropped — token should be cancelled.
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn drop_aborts_tasks_that_ignore_the_token() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<()>();
        {
            let mut scope = TurnScope::new(TurnId(6));
            scope.spawn(async move {
                // Never looks at the token; only an abort can stop it.
                let _tx = tx;
                tokio::time::sleep(Duration::from_secs(60)).await;
            });
        }
        // Aborting the task drops its future and, with it, the only sender,
        // so `recv` resolves to `None` instead of waiting out the sleep.
        let msg = tokio::time::timeout(Duration::from_millis(500), rx.recv())
            .await
            .expect("dropping the scope should abort the child");
        assert!(msg.is_none());
    }

    #[tokio::test]
    async fn drain_runs_to_completion_on_normal_tasks() {
        let mut scope = TurnScope::new(TurnId(3));
        for i in 0..5 {
            scope.spawn(async move {
                tokio::time::sleep(Duration::from_millis(i)).await;
            });
        }
        assert_eq!(scope.len(), 5);
        scope.drain().await;
        assert!(scope.is_empty());
    }

    #[tokio::test]
    async fn drain_completed_reaps_finished_tasks() {
        let mut scope = TurnScope::new(TurnId(5));
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<()>();
        scope.spawn(async move {
            let _ = tx.send(());
        });
        rx.recv().await.expect("child ran");
        // Finished or not, the task stays counted until it is joined.
        assert_eq!(scope.len(), 1);
        tokio::time::timeout(Duration::from_millis(500), async {
            loop {
                scope.drain_completed();
                if scope.is_empty() {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
        })
        .await
        .expect("drain_completed should reap the finished task");
    }

    #[tokio::test]
    async fn cancel_then_drain_is_quick() {
        let mut scope = TurnScope::new(TurnId(4));
        let token = scope.token();
        for _ in 0..10 {
            let t = token.clone();
            scope.spawn(async move {
                tokio::select! {
                    _ = t.cancelled() => {},
                    _ = tokio::time::sleep(Duration::from_secs(60)) => {},
                }
            });
        }
        scope.cancel();
        let start = std::time::Instant::now();
        scope.drain().await;
        // All tasks cancel and unwind within 100ms (generous bound —
        // realistic would be <10ms).
        assert!(
            start.elapsed() < Duration::from_millis(100),
            "cancel+drain took {:?}",
            start.elapsed()
        );
    }
}