awaken_runtime/runtime/agent_runtime/
control.rs1use awaken_contract::contract::suspension::ToolCallResume;
4
5use super::AgentRuntime;
6use super::active_registry::HandleLookup;
7
8impl AgentRuntime {
9 pub async fn cancel_and_wait_by_thread(&self, thread_id: &str) -> bool {
14 let notify = match self.active_runs.cancel_and_get_notify(thread_id) {
15 Some(n) => n,
16 None => return false,
17 };
18 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), notify.notified()).await;
20 true
21 }
22
23 pub fn cancel_by_thread(&self, thread_id: &str) -> bool {
25 if let Some(handle) = self.active_runs.get_by_thread_id(thread_id) {
26 handle.cancel();
27 true
28 } else {
29 false
30 }
31 }
32
33 pub fn cancel_by_run_id(&self, run_id: &str) -> bool {
35 if let Some(handle) = self.active_runs.get_by_run_id(run_id) {
36 handle.cancel();
37 true
38 } else {
39 false
40 }
41 }
42
43 pub fn cancel(&self, id: &str) -> bool {
46 match self.active_runs.lookup_strict(id) {
47 HandleLookup::Found(handle) => {
48 handle.cancel();
49 true
50 }
51 HandleLookup::NotFound => false,
52 HandleLookup::Ambiguous => {
53 tracing::warn!(id = %id, "cancel rejected: ambiguous control id");
54 false
55 }
56 }
57 }
58
59 pub fn send_decisions(
61 &self,
62 thread_id: &str,
63 decisions: Vec<(String, ToolCallResume)>,
64 ) -> bool {
65 if let Some(handle) = self.active_runs.get_by_thread_id(thread_id) {
66 if handle.send_decisions(decisions).is_err() {
67 tracing::warn!(
68 thread_id = %thread_id,
69 "send_decisions failed: channel closed"
70 );
71 return false;
72 }
73 true
74 } else {
75 false
76 }
77 }
78
79 pub fn send_decision(&self, id: &str, tool_call_id: String, resume: ToolCallResume) -> bool {
82 match self.active_runs.lookup_strict(id) {
83 HandleLookup::Found(handle) => handle.send_decision(tool_call_id, resume).is_ok(),
84 HandleLookup::NotFound => false,
85 HandleLookup::Ambiguous => {
86 tracing::warn!(id = %id, "send_decision rejected: ambiguous control id");
87 false
88 }
89 }
90 }
91}
92
#[cfg(test)]
mod tests {
    //! Unit tests for the run-control surface of [`AgentRuntime`]:
    //! cancellation by run id / thread id / either, decision delivery, and
    //! cancel-and-wait.

    use super::*;
    use awaken_contract::contract::suspension::{ResumeDecisionAction, ToolCallResume};
    use serde_json::json;
    use std::sync::Arc;

    use crate::error::RuntimeError;
    use crate::registry::{AgentResolver, ResolvedAgent};

    /// Resolver that always fails; these tests never resolve an agent.
    struct StubResolver;
    impl AgentResolver for StubResolver {
        fn resolve(&self, _agent_id: &str) -> Result<ResolvedAgent, RuntimeError> {
            Err(RuntimeError::ResolveFailed {
                message: "stub".into(),
            })
        }
    }

    /// Builds a runtime backed by [`StubResolver`].
    fn make_runtime() -> AgentRuntime {
        AgentRuntime::new(Arc::new(StubResolver))
    }

    /// Builds a minimal resume payload used as decision content.
    fn make_resume() -> ToolCallResume {
        ToolCallResume {
            decision_id: "d1".into(),
            action: ResumeDecisionAction::Resume,
            result: json!(null),
            reason: None,
            updated_at: 0,
        }
    }

    // ── cancel_by_run_id ────────────────────────────────────────────────

    #[test]
    fn cancel_by_run_id_returns_true_when_registered() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel_by_run_id("r1"));
    }

    #[test]
    fn cancel_by_run_id_returns_false_when_not_found() {
        let rt = make_runtime();
        assert!(!rt.cancel_by_run_id("nonexistent"));
    }

    #[test]
    fn cancel_by_run_id_signals_cancellation_token() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(!token.is_cancelled());
        assert!(rt.cancel_by_run_id("r1"));
        assert!(token.is_cancelled());
    }

    // ── cancel_by_thread ────────────────────────────────────────────────

    #[test]
    fn cancel_by_thread_returns_true_when_registered() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel_by_thread("t1"));
    }

    #[test]
    fn cancel_by_thread_returns_false_when_not_found() {
        let rt = make_runtime();
        assert!(!rt.cancel_by_thread("nonexistent"));
    }

    #[test]
    fn cancel_by_thread_signals_cancellation_token() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(!token.is_cancelled());
        assert!(rt.cancel_by_thread("t1"));
        assert!(token.is_cancelled());
    }

    // ── cancel (either id kind) ─────────────────────────────────────────

    #[test]
    fn cancel_by_run_id_via_dual_index() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel("r1"));
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancel_by_thread_id_via_dual_index() {
        let rt = make_runtime();
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.cancel("t1"));
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancel_returns_false_for_unknown_id() {
        let rt = make_runtime();
        assert!(!rt.cancel("unknown"));
    }

    #[test]
    fn cancel_returns_false_for_ambiguous_id() {
        let rt = make_runtime();
        // "shared" is the thread id of the first run...
        let (h1, _t1, _rx1) = rt.create_run_channels("r1".into());
        rt.register_run("shared", h1).unwrap();
        // ...and the run id of the second, so it matches two different runs.
        let (h2, _t2, _rx2) = rt.create_run_channels("shared".into());
        rt.register_run("t2", h2).unwrap();

        assert!(!rt.cancel("shared"));
    }

    // ── send_decisions ──────────────────────────────────────────────────

    #[test]
    fn send_decisions_returns_true_and_delivers() {
        let rt = make_runtime();
        let (handle, _token, mut rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        let resume = make_resume();
        assert!(rt.send_decisions("t1", vec![("tc1".into(), resume)]));

        // The batch must arrive on the run's receiver unchanged.
        let batch = rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].0, "tc1");
    }

    #[test]
    fn send_decisions_returns_false_for_unknown_thread() {
        let rt = make_runtime();
        assert!(!rt.send_decisions("unknown", vec![("tc1".into(), make_resume())]));
    }

    #[test]
    fn send_decisions_returns_false_when_channel_closed() {
        let rt = make_runtime();
        let (handle, _token, rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        // Dropping the receiver closes the channel, so the send must fail.
        drop(rx);

        assert!(!rt.send_decisions("t1", vec![("tc1".into(), make_resume())]));
    }

    // ── send_decision (either id kind) ──────────────────────────────────

    #[test]
    fn send_decision_by_run_id() {
        let rt = make_runtime();
        let (handle, _token, mut rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.send_decision("r1", "tc1".into(), make_resume()));

        let batch = rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].0, "tc1");
    }

    #[test]
    fn send_decision_by_thread_id() {
        let rt = make_runtime();
        let (handle, _token, mut rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        assert!(rt.send_decision("t1", "tc1".into(), make_resume()));

        let batch = rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
        // Same payload check as the run-id variant: the thread-id path must
        // deliver the identical tool call id.
        assert_eq!(batch[0].0, "tc1");
    }

    #[test]
    fn send_decision_returns_false_for_unknown_id() {
        let rt = make_runtime();
        assert!(!rt.send_decision("unknown", "tc1".into(), make_resume()));
    }

    #[test]
    fn send_decision_returns_false_for_ambiguous_id() {
        let rt = make_runtime();
        // "shared" is a thread id for one run and a run id for another.
        let (h1, _t1, _rx1) = rt.create_run_channels("r1".into());
        rt.register_run("shared", h1).unwrap();
        let (h2, _t2, _rx2) = rt.create_run_channels("shared".into());
        rt.register_run("t2", h2).unwrap();

        assert!(!rt.send_decision("shared", "tc1".into(), make_resume()));
    }

    #[test]
    fn send_decision_returns_false_when_channel_closed() {
        let rt = make_runtime();
        let (handle, _token, rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();
        drop(rx);

        assert!(!rt.send_decision("r1", "tc1".into(), make_resume()));
    }

    // ── lifecycle ───────────────────────────────────────────────────────

    #[test]
    fn cancel_after_unregister_returns_false() {
        let rt = make_runtime();
        let (handle, _token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();
        rt.unregister_run("r1");

        // Neither id kind may resolve once the run has been unregistered.
        assert!(!rt.cancel("r1"));
        assert!(!rt.cancel("t1"));
    }

    // ── cancel_and_wait_by_thread ───────────────────────────────────────

    #[tokio::test]
    async fn cancel_and_wait_returns_false_when_no_run() {
        let rt = make_runtime();
        assert!(!rt.cancel_and_wait_by_thread("unknown").await);
    }

    #[tokio::test]
    async fn cancel_and_wait_completes_after_unregister() {
        let rt = Arc::new(make_runtime());
        let (handle, token, _rx) = rt.create_run_channels("r1".into());
        rt.register_run("t1", handle).unwrap();

        let rt2 = Arc::clone(&rt);
        // Simulate the run winding down shortly after cancellation is requested.
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            rt2.unregister_run("r1");
        });

        assert!(rt.cancel_and_wait_by_thread("t1").await);
        // Cancellation was signalled...
        assert!(token.is_cancelled());
        // ...and the run is no longer registered on the thread.
        assert!(!rt.cancel_by_thread("t1"));
    }
}