mermaid_cli/effect/turn_scope.rs
1//! Per-turn structured concurrency.
2//!
3//! A `TurnScope` owns exactly one `CancellationToken` and one
4//! `JoinSet`. Every task spawned for this turn gets a clone of the
5//! token, and every handle lands in the set. When the user cancels
6//! (or the reducer dispatches `Cmd::CancelScope`), we cancel the
7//! token — tokio's cooperative cancellation then unwinds every child
//! at its next `.await`. Tasks still in the set when the scope is
//! dropped are aborted, so no task leaks.
9//!
10//! The point: cancellation is a **signal**, not a **poll**. No child
11//! task has to remember to check a shared flag; they're awaiting an
12//! mpsc receive or an HTTP body, and `token.cancelled()` races every
13//! such await via `select!`. Abort latency = the time to reach the
14//! next await point — microseconds for HTTP streams, milliseconds for
15//! tool subprocess fan-out.
16//!
17//! Forgetting cancellation is impossible: the token is baked into
18//! every adapter's `StreamContext` / `ExecContext`, and the adapter
19//! must `select!` on it to proceed. Contrast with a "drain events
20//! every 50ms" polling pattern, where long-running ops (web search,
21//! execute command) had to remember to check a shared flag — silent
22//! forgetting there shipped as hangs-until-timeout bugs.
23
24use std::future::Future;
25
26use tokio::task::{AbortHandle, JoinSet};
27use tokio_util::sync::CancellationToken;
28
29use crate::domain::TurnId;
30
/// One turn's cancellable scope. Construct once per `SubmitPrompt`;
/// abandon (drop) at the end of the turn.
#[derive(Debug)]
pub struct TurnScope {
    /// Identifier of the turn this scope belongs to.
    id: TurnId,
    /// The scope's cancellation signal; `token()` hands out clones.
    token: CancellationToken,
    /// Every task started through `spawn`.
    joins: JoinSet<()>,
}
39
40impl TurnScope {
41 pub fn new(id: TurnId) -> Self {
42 Self {
43 id,
44 token: CancellationToken::new(),
45 joins: JoinSet::new(),
46 }
47 }
48
49 pub fn id(&self) -> TurnId {
50 self.id
51 }
52
53 /// Clone the scope's token. Hand this to child tasks so they can
54 /// participate in cooperative cancellation.
55 pub fn token(&self) -> CancellationToken {
56 self.token.clone()
57 }
58
59 /// Spawn a child task under this scope. The returned handle is
60 /// retained inside the scope's `JoinSet` — callers don't need to
61 /// keep it. Cancellation of the scope aborts the task at its next
62 /// await.
63 pub fn spawn<Fut>(&mut self, fut: Fut) -> AbortHandle
64 where
65 Fut: Future<Output = ()> + Send + 'static,
66 {
67 self.joins.spawn(fut)
68 }
69
70 /// Signal cancellation to every child task. Returns immediately —
71 /// callers drain the `JoinSet` separately via `drain_completed`.
72 pub fn cancel(&self) {
73 self.token.cancel();
74 }
75
76 /// True iff the scope has been cancelled.
77 pub fn is_cancelled(&self) -> bool {
78 self.token.is_cancelled()
79 }
80
81 /// Join one task if any has completed. Returns `None` immediately
82 /// when the set is empty or nothing is ready. Intended for the
83 /// main loop's per-tick bookkeeping — not a blocking drain.
84 pub async fn join_next(&mut self) -> Option<Result<(), tokio::task::JoinError>> {
85 self.joins.join_next().await
86 }
87
88 /// True iff no child task is currently running inside this scope.
89 /// The main loop uses this after a `cancel()` to decide when to
90 /// transition from `TurnState::Cancelling` back to `Idle`.
91 pub fn is_empty(&self) -> bool {
92 self.joins.is_empty()
93 }
94
95 /// Drain any already-completed tasks from the JoinSet without
96 /// blocking. `JoinSet::is_empty` only flips to true after finished
97 /// tasks are explicitly harvested via `join_next`; without this,
98 /// `EffectRunner::reap_empty_scopes` would see finished-but-not-
99 /// joined scopes as "still busy" and never reap them. F12.
100 pub fn drain_completed(&mut self) {
101 while self.joins.try_join_next().is_some() {}
102 }
103
104 pub fn len(&self) -> usize {
105 self.joins.len()
106 }
107
108 /// Await every outstanding task to completion, swallowing
109 /// `JoinError`s (they happen on normal `abort()`s after
110 /// cancellation). Use during shutdown.
111 pub async fn drain(&mut self) {
112 while let Some(result) = self.joins.join_next().await {
113 if let Err(e) = result
114 && !e.is_cancelled()
115 {
116 tracing::warn!(
117 turn = %self.id,
118 error = %e,
119 "turn_scope: child task panicked"
120 );
121 }
122 }
123 }
124}
125
126impl Drop for TurnScope {
127 fn drop(&mut self) {
128 // If the caller forgot to cancel before dropping, be safe: a
129 // live child task holding resources after the turn ended is
130 // exactly the kind of leak this architecture exists to
131 // prevent. `JoinSet::drop` already aborts its members, but we
132 // still flip the cancellation token so any child that branches
133 // on it observes the abort intent too.
134 if !self.token.is_cancelled() {
135 self.token.cancel();
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn fresh_scope_has_no_tasks() {
        let scope = TurnScope::new(TurnId(1));
        assert_eq!(scope.len(), 0);
        assert!(scope.is_empty());
        assert!(!scope.is_cancelled());
    }

    #[tokio::test]
    async fn join_next_on_empty_scope_returns_none() {
        let mut scope = TurnScope::new(TurnId(1));
        // With nothing spawned there is nothing to wait for.
        assert!(scope.join_next().await.is_none());
    }

    #[tokio::test]
    async fn spawned_task_completes_within_scope() {
        let mut scope = TurnScope::new(TurnId(1));
        scope.spawn(async {
            tokio::time::sleep(Duration::from_millis(5)).await;
        });
        assert_eq!(scope.len(), 1);
        // `join_next` waits for the task to finish.
        let result = scope.join_next().await;
        assert!(result.is_some());
        assert!(scope.is_empty());
    }

    #[tokio::test]
    async fn cancel_signals_child_tasks() {
        let mut scope = TurnScope::new(TurnId(1));
        let token = scope.token();
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<&'static str>();
        scope.spawn(async move {
            tokio::select! {
                _ = token.cancelled() => {
                    let _ = tx.send("cancelled");
                },
                _ = tokio::time::sleep(Duration::from_secs(30)) => {
                    let _ = tx.send("timeout");
                },
            }
        });

        // Give the task a moment to register its select.
        tokio::time::sleep(Duration::from_millis(10)).await;
        scope.cancel();
        let msg = tokio::time::timeout(Duration::from_millis(500), rx.recv())
            .await
            .expect("cancellation should propagate")
            .expect("sender alive");
        assert_eq!(msg, "cancelled");
        scope.drain().await;
    }

    #[tokio::test]
    async fn drop_cancels_token() {
        let token = {
            let scope = TurnScope::new(TurnId(2));
            scope.token()
        };
        // Scope dropped — token should be cancelled.
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn drain_runs_to_completion_on_normal_tasks() {
        let mut scope = TurnScope::new(TurnId(3));
        for i in 0..5 {
            scope.spawn(async move {
                tokio::time::sleep(Duration::from_millis(i)).await;
            });
        }
        assert_eq!(scope.len(), 5);
        scope.drain().await;
        assert!(scope.is_empty());
    }

    #[tokio::test]
    async fn cancel_then_drain_is_quick() {
        let mut scope = TurnScope::new(TurnId(4));
        let token = scope.token();
        for _ in 0..10 {
            let t = token.clone();
            scope.spawn(async move {
                tokio::select! {
                    _ = t.cancelled() => {},
                    _ = tokio::time::sleep(Duration::from_secs(60)) => {},
                }
            });
        }
        scope.cancel();
        let start = std::time::Instant::now();
        scope.drain().await;
        // All tasks cancel and unwind within 100ms (generous bound —
        // realistic would be <10ms).
        assert!(
            start.elapsed() < Duration::from_millis(100),
            "cancel+drain took {:?}",
            start.elapsed()
        );
    }

    // Regression coverage for F12: a finished task keeps the set
    // non-empty until it is harvested, and `drain_completed` must be
    // able to harvest it without anyone calling `join_next`.
    #[tokio::test]
    async fn drain_completed_harvests_finished_tasks() {
        let mut scope = TurnScope::new(TurnId(5));
        scope.spawn(async {});
        // Spawned but not yet joined: the set still counts it.
        assert!(!scope.is_empty());
        tokio::time::timeout(Duration::from_millis(500), async {
            while !scope.is_empty() {
                // Yield so the child gets a chance to run to completion.
                tokio::task::yield_now().await;
                scope.drain_completed();
            }
        })
        .await
        .expect("finished task should be harvested without join_next");
        assert_eq!(scope.len(), 0);
    }
}