Skip to main content

awaken_runtime/runtime/agent_runtime/
control.rs

//! Control methods for active runs — cancel, cancel-and-wait, and delivery of tool-call decisions and input messages — with dual-index lookup (run_id + thread_id).
2
3use awaken_contract::contract::message::Message;
4use awaken_contract::contract::suspension::ToolCallResume;
5
6use super::AgentRuntime;
7use super::active_registry::HandleLookup;
8
/// Upper bound on how long `cancel_and_wait_by_thread` waits for a cancelled run to
/// release its slot before giving up.
#[cfg(not(test))]
const CANCEL_WAIT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Shortened wait used by unit tests so that timeout paths complete quickly.
#[cfg(test)]
const CANCEL_WAIT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(25);
13
14impl AgentRuntime {
15    /// Cancel an active run by thread ID and wait for it to finish.
16    ///
17    /// Returns `true` only when the run slot is released before the wait timeout.
18    /// Returns `false` when no active run exists or cancellation does not finish in time.
19    pub async fn cancel_and_wait_by_thread(&self, thread_id: &str) -> bool {
20        let notify = match self.active_runs.cancel_and_get_notify(thread_id) {
21            Some(n) => n,
22            None => return false,
23        };
24        if !self.active_runs.has_active_thread(thread_id) {
25            return true;
26        }
27        // Wait for the RunSlotGuard to drop (calls unregister, which fires the notify).
28        tokio::time::timeout(CANCEL_WAIT_TIMEOUT, notify.notified())
29            .await
30            .is_ok()
31            || !self.active_runs.has_active_thread(thread_id)
32    }
33
34    /// Cancel an active run by thread ID.
35    pub fn cancel_by_thread(&self, thread_id: &str) -> bool {
36        if let Some(handle) = self.active_runs.get_by_thread_id(thread_id) {
37            handle.cancel();
38            true
39        } else {
40            false
41        }
42    }
43
44    /// Cancel an active run by run ID.
45    pub fn cancel_by_run_id(&self, run_id: &str) -> bool {
46        if let Some(handle) = self.active_runs.get_by_run_id(run_id) {
47            handle.cancel();
48            true
49        } else {
50            false
51        }
52    }
53
54    /// Cancel an active run by dual-index ID (run_id or thread_id).
55    /// Ambiguous IDs are rejected.
56    pub fn cancel(&self, id: &str) -> bool {
57        match self.active_runs.lookup_strict(id) {
58            HandleLookup::Found(handle) => {
59                handle.cancel();
60                true
61            }
62            HandleLookup::NotFound => false,
63            HandleLookup::Ambiguous => {
64                tracing::warn!(id = %id, "cancel rejected: ambiguous control id");
65                false
66            }
67        }
68    }
69
70    /// Send decisions to an active run by thread ID.
71    pub fn send_decisions(
72        &self,
73        thread_id: &str,
74        decisions: Vec<(String, ToolCallResume)>,
75    ) -> bool {
76        if let Some(handle) = self.active_runs.get_by_thread_id(thread_id) {
77            if handle.send_decisions(decisions).is_err() {
78                tracing::warn!(
79                    thread_id = %thread_id,
80                    "send_decisions failed: channel closed"
81                );
82                return false;
83            }
84            true
85        } else {
86            false
87        }
88    }
89
90    /// Send a decision by dual-index ID (run_id or thread_id).
91    /// Ambiguous IDs are rejected.
92    pub fn send_decision(&self, id: &str, tool_call_id: String, resume: ToolCallResume) -> bool {
93        match self.active_runs.lookup_strict(id) {
94            HandleLookup::Found(handle) => handle.send_decision(tool_call_id, resume).is_ok(),
95            HandleLookup::NotFound => false,
96            HandleLookup::Ambiguous => {
97                tracing::warn!(id = %id, "send_decision rejected: ambiguous control id");
98                false
99            }
100        }
101    }
102
103    /// Send direct input messages to an active run by run ID or thread ID.
104    /// Ambiguous IDs are rejected.
105    pub fn send_messages(&self, id: &str, messages: Vec<Message>) -> bool {
106        match self.active_runs.lookup_strict(id) {
107            HandleLookup::Found(handle) => handle.send_messages(messages),
108            HandleLookup::NotFound => false,
109            HandleLookup::Ambiguous => {
110                tracing::warn!(id = %id, "send_messages rejected: ambiguous control id");
111                false
112            }
113        }
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::suspension::{ResumeDecisionAction, ToolCallResume};
    use serde_json::json;
    use std::sync::Arc;

    use crate::error::RuntimeError;
    use crate::registry::{AgentResolver, ResolvedAgent};

    /// Resolver that never resolves; control-path tests never need a real agent.
    struct StubResolver;
    impl AgentResolver for StubResolver {
        fn resolve(&self, _agent_id: &str) -> Result<ResolvedAgent, RuntimeError> {
            Err(RuntimeError::ResolveFailed {
                message: "stub".into(),
            })
        }
    }

    fn make_runtime() -> AgentRuntime {
        AgentRuntime::new(Arc::new(StubResolver))
    }

    fn make_resume() -> ToolCallResume {
        ToolCallResume {
            decision_id: "d1".into(),
            action: ResumeDecisionAction::Resume,
            result: json!(null),
            reason: None,
            updated_at: 0,
        }
    }

    // -- cancel_by_run_id --

    #[test]
    fn cancel_by_run_id_returns_true_when_registered() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel_by_run_id("r1"));
    }

    #[test]
    fn cancel_by_run_id_returns_false_when_not_found() {
        let rt = make_runtime();
        assert!(!rt.cancel_by_run_id("nonexistent"));
    }

    #[test]
    fn cancel_by_run_id_signals_cancellation_token() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(!token.is_cancelled());
        rt.cancel_by_run_id("r1");
        assert!(token.is_cancelled());
    }

    // -- cancel_by_thread --

    #[test]
    fn cancel_by_thread_returns_true_when_registered() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel_by_thread("t1"));
    }

    #[test]
    fn cancel_by_thread_returns_false_when_not_found() {
        let rt = make_runtime();
        assert!(!rt.cancel_by_thread("nonexistent"));
    }

    #[test]
    fn cancel_by_thread_signals_cancellation_token() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(!token.is_cancelled());
        rt.cancel_by_thread("t1");
        assert!(token.is_cancelled());
    }

    // -- cancel (dual-index) --

    #[test]
    fn cancel_by_run_id_via_dual_index() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel("r1"));
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancel_by_thread_id_via_dual_index() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel("t1"));
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancel_returns_false_for_unknown_id() {
        let rt = make_runtime();
        assert!(!rt.cancel("unknown"));
    }

    #[test]
    fn cancel_returns_false_for_ambiguous_id() {
        let rt = make_runtime();
        // Register two runs where thread_id of first == run_id of second
        let (h1, t1, _rx1) = rt.create_run_channels("r1".into());
        rt.register_run("shared", h1).unwrap();
        let (h2, t2, _rx2) = rt.create_run_channels("shared".into());
        rt.register_run("t2", h2).unwrap();

        // "shared" matches both as thread_id (-> r1) and run_id (-> shared), different runs
        assert!(!rt.cancel("shared"));
        // Rejection must be side-effect free: neither candidate run is cancelled.
        assert!(!t1.is_cancelled());
        assert!(!t2.is_cancelled());
    }

    // -- send_decisions --

    #[test]
    fn send_decisions_returns_true_and_delivers() {
        let rt = make_runtime();
        let (handle, _token, mut rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        let resume = make_resume();
        assert!(rt.send_decisions("t1", vec![("tc1".into(), resume)]));

        // Verify delivery
        let batch = rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].0, "tc1");
    }

    #[test]
    fn send_decisions_returns_false_for_unknown_thread() {
        let rt = make_runtime();
        assert!(!rt.send_decisions("unknown", vec![("tc1".into(), make_resume())]));
    }

    #[test]
    fn send_decisions_returns_false_when_channel_closed() {
        let rt = make_runtime();
        let (handle, _token, rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        // Drop receiver to close the channel
        drop(rx);

        assert!(!rt.send_decisions("t1", vec![("tc1".into(), make_resume())]));
    }

    // -- send_decision (dual-index) --

    #[test]
    fn send_decision_by_run_id() {
        let rt = make_runtime();
        let (handle, _token, mut rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.send_decision("r1", "tc1".into(), make_resume()));

        let batch = rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].0, "tc1");
    }

    #[test]
    fn send_decision_by_thread_id() {
        let rt = make_runtime();
        let (handle, _token, mut rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.send_decision("t1", "tc1".into(), make_resume()));

        let batch = rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
    }

    #[test]
    fn send_decision_returns_false_for_unknown_id() {
        let rt = make_runtime();
        assert!(!rt.send_decision("unknown", "tc1".into(), make_resume()));
    }

    #[test]
    fn send_decision_returns_false_for_ambiguous_id() {
        let rt = make_runtime();
        let (h1, _t1, mut rx1) = rt.create_run_channels("r1".into());
        rt.register_run("shared", h1).unwrap();
        let (h2, _t2, mut rx2) = rt.create_run_channels("shared".into());
        rt.register_run("t2", h2).unwrap();

        assert!(!rt.send_decision("shared", "tc1".into(), make_resume()));
        // Rejection must be side-effect free: neither candidate run receives the decision.
        assert!(rx1.try_recv().is_err());
        assert!(rx2.try_recv().is_err());
    }

    #[test]
    fn send_decision_returns_false_when_channel_closed() {
        let rt = make_runtime();
        let (handle, _token, rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();
        drop(rx);

        assert!(!rt.send_decision("r1", "tc1".into(), make_resume()));
    }

    // -- send_messages (dual-index) --

    #[test]
    fn send_messages_by_run_id_delivers_to_inbox() {
        let rt = make_runtime();
        let (inbox_tx, mut inbox_rx) = crate::inbox::inbox_channel();
        let (handle, _token, _rx) =
            rt.create_run_channels_with_inbox("r1".into(), None, Some(inbox_tx));
        rt.register_run("t1", handle).unwrap();

        assert!(rt.send_messages("r1", vec![Message::user("live")]));

        let payload = inbox_rx.try_recv().expect("payload should be delivered");
        let messages = crate::inbox::inbox_payload_messages(&payload);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].text(), "live");
    }

    #[test]
    fn send_messages_returns_false_for_unknown_id() {
        let rt = make_runtime();
        assert!(!rt.send_messages("unknown", vec![Message::user("live")]));
    }

    #[test]
    fn send_messages_returns_false_without_inbox() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(!rt.send_messages("r1", vec![Message::user("live")]));
    }

    #[test]
    fn send_messages_returns_false_for_closed_inbox() {
        let rt = make_runtime();
        let (inbox_tx, inbox_rx) = crate::inbox::inbox_channel();
        drop(inbox_rx);
        let (handle, _token, _rx) =
            rt.create_run_channels_with_inbox("r1".into(), None, Some(inbox_tx));
        rt.register_run("t1", handle).unwrap();

        assert!(!rt.send_messages("r1", vec![Message::user("live")]));
    }

    // -- cancel after unregister --

    #[test]
    fn cancel_after_unregister_returns_false() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();
        rt.unregister_run("r1");

        assert!(!rt.cancel("r1"));
        assert!(!rt.cancel("t1"));
    }

    // -- cancel_and_wait_by_thread --

    #[tokio::test]
    async fn cancel_and_wait_returns_false_when_no_run() {
        let rt = make_runtime();
        assert!(!rt.cancel_and_wait_by_thread("unknown").await);
    }

    #[tokio::test]
    async fn cancel_and_wait_returns_false_when_run_does_not_unregister() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(!rt.cancel_and_wait_by_thread("t1").await);
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn cancel_and_wait_completes_after_unregister() {
        let rt = Arc::new(make_runtime());
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        // Simulate the run finishing after cancellation was requested. `#[tokio::test]`
        // runs on a current-thread runtime, so this task only runs once the test task
        // parks inside `cancel_and_wait_by_thread`. Yielding instead of sleeping keeps
        // the test independent of wall-clock timing versus the short test timeout.
        let rt2 = Arc::clone(&rt);
        tokio::spawn(async move {
            tokio::task::yield_now().await;
            rt2.unregister_run("r1");
        });

        // cancel_and_wait should return true and complete once unregister fires
        assert!(rt.cancel_and_wait_by_thread("t1").await);
        assert!(token.is_cancelled());
        // Slot should be free now
        assert!(!rt.cancel_by_thread("t1"));
    }
}
427}