Skip to main content

awaken_runtime/runtime/agent_runtime/
control.rs

1//! Control methods: cancel, send_decisions — with dual-index lookup (run_id + thread_id).
2
3use awaken_contract::contract::suspension::ToolCallResume;
4
5use super::AgentRuntime;
6use super::active_registry::HandleLookup;
7
8impl AgentRuntime {
9    /// Cancel an active run by thread ID and wait for it to finish.
10    ///
11    /// Returns `true` if a run was cancelled, `false` if no active run existed.
12    /// Waits up to 5 seconds for the run to actually unregister.
13    pub async fn cancel_and_wait_by_thread(&self, thread_id: &str) -> bool {
14        let notify = match self.active_runs.cancel_and_get_notify(thread_id) {
15            Some(n) => n,
16            None => return false,
17        };
18        // Wait for the RunSlotGuard to drop (calls unregister, which fires the notify).
19        let _ = tokio::time::timeout(std::time::Duration::from_secs(5), notify.notified()).await;
20        true
21    }
22
23    /// Cancel an active run by thread ID.
24    pub fn cancel_by_thread(&self, thread_id: &str) -> bool {
25        if let Some(handle) = self.active_runs.get_by_thread_id(thread_id) {
26            handle.cancel();
27            true
28        } else {
29            false
30        }
31    }
32
33    /// Cancel an active run by run ID.
34    pub fn cancel_by_run_id(&self, run_id: &str) -> bool {
35        if let Some(handle) = self.active_runs.get_by_run_id(run_id) {
36            handle.cancel();
37            true
38        } else {
39            false
40        }
41    }
42
43    /// Cancel an active run by dual-index ID (run_id or thread_id).
44    /// Ambiguous IDs are rejected.
45    pub fn cancel(&self, id: &str) -> bool {
46        match self.active_runs.lookup_strict(id) {
47            HandleLookup::Found(handle) => {
48                handle.cancel();
49                true
50            }
51            HandleLookup::NotFound => false,
52            HandleLookup::Ambiguous => {
53                tracing::warn!(id = %id, "cancel rejected: ambiguous control id");
54                false
55            }
56        }
57    }
58
59    /// Send decisions to an active run by thread ID.
60    pub fn send_decisions(
61        &self,
62        thread_id: &str,
63        decisions: Vec<(String, ToolCallResume)>,
64    ) -> bool {
65        if let Some(handle) = self.active_runs.get_by_thread_id(thread_id) {
66            if handle.send_decisions(decisions).is_err() {
67                tracing::warn!(
68                    thread_id = %thread_id,
69                    "send_decisions failed: channel closed"
70                );
71                return false;
72            }
73            true
74        } else {
75            false
76        }
77    }
78
79    /// Send a decision by dual-index ID (run_id or thread_id).
80    /// Ambiguous IDs are rejected.
81    pub fn send_decision(&self, id: &str, tool_call_id: String, resume: ToolCallResume) -> bool {
82        match self.active_runs.lookup_strict(id) {
83            HandleLookup::Found(handle) => handle.send_decision(tool_call_id, resume).is_ok(),
84            HandleLookup::NotFound => false,
85            HandleLookup::Ambiguous => {
86                tracing::warn!(id = %id, "send_decision rejected: ambiguous control id");
87                false
88            }
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::suspension::{ResumeDecisionAction, ToolCallResume};
    use serde_json::json;
    use std::sync::Arc;

    use crate::error::RuntimeError;
    use crate::registry::{AgentResolver, ResolvedAgent};

    /// Resolver that always fails; control tests never resolve agents.
    struct StubResolver;
    impl AgentResolver for StubResolver {
        fn resolve(&self, _agent_id: &str) -> Result<ResolvedAgent, RuntimeError> {
            Err(RuntimeError::ResolveFailed {
                message: "stub".into(),
            })
        }
    }

    fn make_runtime() -> AgentRuntime {
        AgentRuntime::new(Arc::new(StubResolver))
    }

    fn make_resume() -> ToolCallResume {
        ToolCallResume {
            decision_id: "d1".into(),
            action: ResumeDecisionAction::Resume,
            result: json!(null),
            reason: None,
            updated_at: 0,
        }
    }

    // -- cancel_by_run_id --

    #[test]
    fn cancel_by_run_id_returns_true_when_registered() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel_by_run_id("r1"));
    }

    #[test]
    fn cancel_by_run_id_returns_false_when_not_found() {
        let rt = make_runtime();
        assert!(!rt.cancel_by_run_id("nonexistent"));
    }

    #[test]
    fn cancel_by_run_id_signals_cancellation_token() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(!token.is_cancelled());
        rt.cancel_by_run_id("r1");
        assert!(token.is_cancelled());
    }

    // -- cancel_by_thread --

    #[test]
    fn cancel_by_thread_returns_true_when_registered() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel_by_thread("t1"));
    }

    #[test]
    fn cancel_by_thread_returns_false_when_not_found() {
        let rt = make_runtime();
        assert!(!rt.cancel_by_thread("nonexistent"));
    }

    #[test]
    fn cancel_by_thread_signals_cancellation_token() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(!token.is_cancelled());
        rt.cancel_by_thread("t1");
        assert!(token.is_cancelled());
    }

    // -- cancel (dual-index) --

    #[test]
    fn cancel_by_run_id_via_dual_index() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel("r1"));
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancel_by_thread_id_via_dual_index() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel("t1"));
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancel_returns_false_for_unknown_id() {
        let rt = make_runtime();
        assert!(!rt.cancel("unknown"));
    }

    #[test]
    fn cancel_returns_false_for_ambiguous_id() {
        let rt = make_runtime();
        // Register two runs where thread_id of first == run_id of second
        let (h1, token1, _rx1) = rt.create_run_channels("r1".into());
        rt.register_run("shared", h1).unwrap();
        let (h2, token2, _rx2) = rt.create_run_channels("shared".into());
        rt.register_run("t2", h2).unwrap();

        // "shared" matches both as thread_id (-> r1) and run_id (-> shared), different runs
        assert!(!rt.cancel("shared"));
        // A rejected ambiguous cancel must not touch either run.
        assert!(!token1.is_cancelled());
        assert!(!token2.is_cancelled());
    }

    // -- send_decisions --

    #[test]
    fn send_decisions_returns_true_and_delivers() {
        let rt = make_runtime();
        let (handle, _token, mut rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        let resume = make_resume();
        assert!(rt.send_decisions("t1", vec![("tc1".into(), resume)]));

        // Verify delivery
        let batch = rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].0, "tc1");
    }

    #[test]
    fn send_decisions_returns_false_for_unknown_thread() {
        let rt = make_runtime();
        assert!(!rt.send_decisions("unknown", vec![("tc1".into(), make_resume())]));
    }

    #[test]
    fn send_decisions_returns_false_when_channel_closed() {
        let rt = make_runtime();
        let (handle, _token, rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        // Drop receiver to close the channel
        drop(rx);

        assert!(!rt.send_decisions("t1", vec![("tc1".into(), make_resume())]));
    }

    // -- send_decision (dual-index) --

    #[test]
    fn send_decision_by_run_id() {
        let rt = make_runtime();
        let (handle, _token, mut rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.send_decision("r1", "tc1".into(), make_resume()));

        let batch = rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].0, "tc1");
    }

    #[test]
    fn send_decision_by_thread_id() {
        let rt = make_runtime();
        let (handle, _token, mut rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.send_decision("t1", "tc1".into(), make_resume()));

        let batch = rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
    }

    #[test]
    fn send_decision_returns_false_for_unknown_id() {
        let rt = make_runtime();
        assert!(!rt.send_decision("unknown", "tc1".into(), make_resume()));
    }

    #[test]
    fn send_decision_returns_false_for_ambiguous_id() {
        let rt = make_runtime();
        let (h1, _token1, mut rx1) = rt.create_run_channels("r1".into());
        rt.register_run("shared", h1).unwrap();
        let (h2, _token2, mut rx2) = rt.create_run_channels("shared".into());
        rt.register_run("t2", h2).unwrap();

        assert!(!rt.send_decision("shared", "tc1".into(), make_resume()));
        // A rejected ambiguous send must not deliver to either run.
        assert!(rx1.try_recv().is_err());
        assert!(rx2.try_recv().is_err());
    }

    #[test]
    fn send_decision_returns_false_when_channel_closed() {
        let rt = make_runtime();
        let (handle, _token, rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();
        drop(rx);

        assert!(!rt.send_decision("r1", "tc1".into(), make_resume()));
    }

    // -- cancel after unregister --

    #[test]
    fn cancel_after_unregister_returns_false() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();
        rt.unregister_run("r1");

        assert!(!rt.cancel("r1"));
        assert!(!rt.cancel("t1"));
    }

    // -- cancel_and_wait_by_thread --

    #[tokio::test]
    async fn cancel_and_wait_returns_false_when_no_run() {
        let rt = make_runtime();
        assert!(!rt.cancel_and_wait_by_thread("unknown").await);
    }

    #[tokio::test]
    async fn cancel_and_wait_completes_after_unregister() {
        let rt = Arc::new(make_runtime());
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        // Spawn a task that unregisters after a short delay
        let rt2 = Arc::clone(&rt);
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            rt2.unregister_run("r1");
        });

        // cancel_and_wait should return true and complete once unregister fires
        assert!(rt.cancel_and_wait_by_thread("t1").await);
        assert!(token.is_cancelled());
        // Slot should be free now
        assert!(!rt.cancel_by_thread("t1"));
    }
}