awaken_runtime/runtime/agent_runtime/
control.rs1use awaken_contract::contract::message::Message;
4use awaken_contract::contract::suspension::ToolCallResume;
5
6use super::AgentRuntime;
7use super::active_registry::HandleLookup;
8
/// Upper bound on how long `cancel_and_wait_by_thread` waits for a cancelled
/// run to signal completion.
///
/// Regular builds wait up to 5 seconds; builds compiled with `cfg(test)` use
/// 25 milliseconds (presumably so the timeout path can be exercised quickly).
const CANCEL_WAIT_TIMEOUT: std::time::Duration = if cfg!(test) {
    std::time::Duration::from_millis(25)
} else {
    std::time::Duration::from_secs(5)
};
13
14impl AgentRuntime {
15 pub async fn cancel_and_wait_by_thread(&self, thread_id: &str) -> bool {
20 let notify = match self.active_runs.cancel_and_get_notify(thread_id) {
21 Some(n) => n,
22 None => return false,
23 };
24 if !self.active_runs.has_active_thread(thread_id) {
25 return true;
26 }
27 tokio::time::timeout(CANCEL_WAIT_TIMEOUT, notify.notified())
29 .await
30 .is_ok()
31 || !self.active_runs.has_active_thread(thread_id)
32 }
33
34 pub fn cancel_by_thread(&self, thread_id: &str) -> bool {
36 if let Some(handle) = self.active_runs.get_by_thread_id(thread_id) {
37 handle.cancel();
38 true
39 } else {
40 false
41 }
42 }
43
44 pub fn cancel_by_run_id(&self, run_id: &str) -> bool {
46 if let Some(handle) = self.active_runs.get_by_run_id(run_id) {
47 handle.cancel();
48 true
49 } else {
50 false
51 }
52 }
53
54 pub fn cancel(&self, id: &str) -> bool {
57 match self.active_runs.lookup_strict(id) {
58 HandleLookup::Found(handle) => {
59 handle.cancel();
60 true
61 }
62 HandleLookup::NotFound => false,
63 HandleLookup::Ambiguous => {
64 tracing::warn!(id = %id, "cancel rejected: ambiguous control id");
65 false
66 }
67 }
68 }
69
70 pub fn send_decisions(
72 &self,
73 thread_id: &str,
74 decisions: Vec<(String, ToolCallResume)>,
75 ) -> bool {
76 if let Some(handle) = self.active_runs.get_by_thread_id(thread_id) {
77 if handle.send_decisions(decisions).is_err() {
78 tracing::warn!(
79 thread_id = %thread_id,
80 "send_decisions failed: channel closed"
81 );
82 return false;
83 }
84 true
85 } else {
86 false
87 }
88 }
89
90 pub fn send_decision(&self, id: &str, tool_call_id: String, resume: ToolCallResume) -> bool {
93 match self.active_runs.lookup_strict(id) {
94 HandleLookup::Found(handle) => handle.send_decision(tool_call_id, resume).is_ok(),
95 HandleLookup::NotFound => false,
96 HandleLookup::Ambiguous => {
97 tracing::warn!(id = %id, "send_decision rejected: ambiguous control id");
98 false
99 }
100 }
101 }
102
103 pub fn send_messages(&self, id: &str, messages: Vec<Message>) -> bool {
106 match self.active_runs.lookup_strict(id) {
107 HandleLookup::Found(handle) => handle.send_messages(messages),
108 HandleLookup::NotFound => false,
109 HandleLookup::Ambiguous => {
110 tracing::warn!(id = %id, "send_messages rejected: ambiguous control id");
111 false
112 }
113 }
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use awaken_contract::contract::suspension::{ResumeDecisionAction, ToolCallResume};
    use serde_json::json;
    use std::sync::Arc;

    use crate::error::RuntimeError;
    use crate::registry::{AgentResolver, ResolvedAgent};

    /// Resolver whose `resolve` always fails with a `"stub"` message; it only
    /// exists so a runtime can be constructed.
    struct StubResolver;

    impl AgentResolver for StubResolver {
        fn resolve(&self, _agent_id: &str) -> Result<ResolvedAgent, RuntimeError> {
            let message = "stub".into();
            Err(RuntimeError::ResolveFailed { message })
        }
    }

    /// Builds a runtime backed by [`StubResolver`].
    fn make_runtime() -> AgentRuntime {
        let resolver = Arc::new(StubResolver);
        AgentRuntime::new(resolver)
    }

    /// Builds a fixed `Resume` decision used as payload in the send tests.
    fn make_resume() -> ToolCallResume {
        ToolCallResume {
            decision_id: "d1".into(),
            action: ResumeDecisionAction::Resume,
            result: json!(null),
            reason: None,
            updated_at: 0,
        }
    }

    // ---- cancel_by_run_id -------------------------------------------------

    /// A registered run can be cancelled through its run id.
    #[test]
    fn cancel_by_run_id_returns_true_when_registered() {
        let runtime = make_runtime();
        let (run_handle, _token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        let cancelled = runtime.cancel_by_run_id("r1");
        assert!(cancelled);
    }

    /// An unregistered run id yields `false`.
    #[test]
    fn cancel_by_run_id_returns_false_when_not_found() {
        let runtime = make_runtime();
        let cancelled = runtime.cancel_by_run_id("nonexistent");
        assert!(!cancelled);
    }

    /// Cancelling by run id flips the run's cancellation token.
    #[test]
    fn cancel_by_run_id_signals_cancellation_token() {
        let runtime = make_runtime();
        let (run_handle, cancel_token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        assert!(!cancel_token.is_cancelled());
        runtime.cancel_by_run_id("r1");
        assert!(cancel_token.is_cancelled());
    }

    // ---- cancel_by_thread -------------------------------------------------

    /// A registered run can be cancelled through its thread id.
    #[test]
    fn cancel_by_thread_returns_true_when_registered() {
        let runtime = make_runtime();
        let (run_handle, _token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        let cancelled = runtime.cancel_by_thread("t1");
        assert!(cancelled);
    }

    /// An unregistered thread id yields `false`.
    #[test]
    fn cancel_by_thread_returns_false_when_not_found() {
        let runtime = make_runtime();
        let cancelled = runtime.cancel_by_thread("nonexistent");
        assert!(!cancelled);
    }

    /// Cancelling by thread id flips the run's cancellation token.
    #[test]
    fn cancel_by_thread_signals_cancellation_token() {
        let runtime = make_runtime();
        let (run_handle, cancel_token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        assert!(!cancel_token.is_cancelled());
        runtime.cancel_by_thread("t1");
        assert!(cancel_token.is_cancelled());
    }

    // ---- cancel (strict lookup over both id kinds) ------------------------

    /// `cancel` accepts a run id.
    #[test]
    fn cancel_by_run_id_via_dual_index() {
        let runtime = make_runtime();
        let (run_handle, cancel_token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        let cancelled = runtime.cancel("r1");
        assert!(cancelled);
        assert!(cancel_token.is_cancelled());
    }

    /// `cancel` accepts a thread id.
    #[test]
    fn cancel_by_thread_id_via_dual_index() {
        let runtime = make_runtime();
        let (run_handle, cancel_token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        let cancelled = runtime.cancel("t1");
        assert!(cancelled);
        assert!(cancel_token.is_cancelled());
    }

    /// `cancel` reports `false` for an id nothing is registered under.
    #[test]
    fn cancel_returns_false_for_unknown_id() {
        let runtime = make_runtime();
        let cancelled = runtime.cancel("unknown");
        assert!(!cancelled);
    }

    /// `"shared"` is the thread id of one run and the run id of another, so
    /// `cancel("shared")` must refuse to pick one.
    #[test]
    fn cancel_returns_false_for_ambiguous_id() {
        let runtime = make_runtime();
        let (first_handle, _first_token, _first_rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("shared", first_handle).unwrap();
        let (second_handle, _second_token, _second_rx) =
            runtime.create_run_channels("shared".into());
        runtime.register_run("t2", second_handle).unwrap();

        let cancelled = runtime.cancel("shared");
        assert!(!cancelled);
    }

    // ---- send_decisions ---------------------------------------------------

    /// A decision batch sent by thread id arrives on the run's receiver.
    #[test]
    fn send_decisions_returns_true_and_delivers() {
        let runtime = make_runtime();
        let (run_handle, _token, mut decision_rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        let decisions = vec![("tc1".into(), make_resume())];
        assert!(runtime.send_decisions("t1", decisions));

        let received = decision_rx.try_recv().unwrap();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].0, "tc1");
    }

    /// Sending to a thread without a registered run yields `false`.
    #[test]
    fn send_decisions_returns_false_for_unknown_thread() {
        let runtime = make_runtime();
        let decisions = vec![("tc1".into(), make_resume())];
        assert!(!runtime.send_decisions("unknown", decisions));
    }

    /// Sending after the receiver has been dropped yields `false`.
    #[test]
    fn send_decisions_returns_false_when_channel_closed() {
        let runtime = make_runtime();
        let (run_handle, _token, decision_rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        drop(decision_rx);

        let decisions = vec![("tc1".into(), make_resume())];
        assert!(!runtime.send_decisions("t1", decisions));
    }

    // ---- send_decision ----------------------------------------------------

    /// A single decision addressed by run id arrives as a one-element batch.
    #[test]
    fn send_decision_by_run_id() {
        let runtime = make_runtime();
        let (run_handle, _token, mut decision_rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        let sent = runtime.send_decision("r1", "tc1".into(), make_resume());
        assert!(sent);

        let received = decision_rx.try_recv().unwrap();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].0, "tc1");
    }

    /// A single decision addressed by thread id is delivered as well.
    #[test]
    fn send_decision_by_thread_id() {
        let runtime = make_runtime();
        let (run_handle, _token, mut decision_rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        let sent = runtime.send_decision("t1", "tc1".into(), make_resume());
        assert!(sent);

        let received = decision_rx.try_recv().unwrap();
        assert_eq!(received.len(), 1);
    }

    /// An unknown id yields `false`.
    #[test]
    fn send_decision_returns_false_for_unknown_id() {
        let runtime = make_runtime();
        let sent = runtime.send_decision("unknown", "tc1".into(), make_resume());
        assert!(!sent);
    }

    /// An id that names both a thread and a different run is rejected.
    #[test]
    fn send_decision_returns_false_for_ambiguous_id() {
        let runtime = make_runtime();
        let (first_handle, _first_token, _first_rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("shared", first_handle).unwrap();
        let (second_handle, _second_token, _second_rx) =
            runtime.create_run_channels("shared".into());
        runtime.register_run("t2", second_handle).unwrap();

        let sent = runtime.send_decision("shared", "tc1".into(), make_resume());
        assert!(!sent);
    }

    /// Sending after the receiver has been dropped yields `false`.
    #[test]
    fn send_decision_returns_false_when_channel_closed() {
        let runtime = make_runtime();
        let (run_handle, _token, decision_rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();
        drop(decision_rx);

        let sent = runtime.send_decision("r1", "tc1".into(), make_resume());
        assert!(!sent);
    }

    // ---- send_messages ----------------------------------------------------

    /// Messages addressed by run id land in the inbox attached to the run.
    #[test]
    fn send_messages_by_run_id_delivers_to_inbox() {
        let runtime = make_runtime();
        let (inbox_tx, mut inbox_rx) = crate::inbox::inbox_channel();
        let (run_handle, _token, _rx) =
            runtime.create_run_channels_with_inbox("r1".into(), None, Some(inbox_tx));
        runtime.register_run("t1", run_handle).unwrap();

        let sent = runtime.send_messages("r1", vec![Message::user("live")]);
        assert!(sent);

        let payload = inbox_rx.try_recv().expect("payload should be delivered");
        let delivered = crate::inbox::inbox_payload_messages(&payload);
        assert_eq!(delivered.len(), 1);
        assert_eq!(delivered[0].text(), "live");
    }

    /// A run created without an inbox cannot receive messages.
    #[test]
    fn send_messages_returns_false_without_inbox() {
        let runtime = make_runtime();
        let (run_handle, _token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        let sent = runtime.send_messages("r1", vec![Message::user("live")]);
        assert!(!sent);
    }

    /// An inbox whose receiving half is already dropped rejects messages.
    #[test]
    fn send_messages_returns_false_for_closed_inbox() {
        let runtime = make_runtime();
        let (inbox_tx, inbox_rx) = crate::inbox::inbox_channel();
        drop(inbox_rx);
        let (run_handle, _token, _rx) =
            runtime.create_run_channels_with_inbox("r1".into(), None, Some(inbox_tx));
        runtime.register_run("t1", run_handle).unwrap();

        let sent = runtime.send_messages("r1", vec![Message::user("live")]);
        assert!(!sent);
    }

    // ---- lifecycle --------------------------------------------------------

    /// Once a run is unregistered, neither of its ids can be cancelled.
    #[test]
    fn cancel_after_unregister_returns_false() {
        let runtime = make_runtime();
        let (run_handle, _token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();
        runtime.unregister_run("r1");

        for id in ["r1", "t1"] {
            assert!(!runtime.cancel(id));
        }
    }

    // ---- cancel_and_wait_by_thread ----------------------------------------

    /// With nothing registered there is nothing to cancel or wait for.
    #[tokio::test]
    async fn cancel_and_wait_returns_false_when_no_run() {
        let runtime = make_runtime();
        let finished = runtime.cancel_and_wait_by_thread("unknown").await;
        assert!(!finished);
    }

    /// A run that never unregisters makes the wait time out, but the token is
    /// still cancelled.
    #[tokio::test]
    async fn cancel_and_wait_returns_false_when_run_does_not_unregister() {
        let runtime = make_runtime();
        let (run_handle, cancel_token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        let finished = runtime.cancel_and_wait_by_thread("t1").await;
        assert!(!finished);
        assert!(cancel_token.is_cancelled());
    }

    /// A run that unregisters shortly after being cancelled lets the wait
    /// complete successfully, after which the thread has no run left.
    #[tokio::test]
    async fn cancel_and_wait_completes_after_unregister() {
        let runtime = Arc::new(make_runtime());
        let (run_handle, cancel_token, _rx) = runtime.create_run_channels("r1".into());
        runtime.register_run("t1", run_handle).unwrap();

        // Background task standing in for the run loop winding down.
        let background_runtime = Arc::clone(&runtime);
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
            background_runtime.unregister_run("r1");
        });

        let finished = runtime.cancel_and_wait_by_thread("t1").await;
        assert!(finished);
        assert!(cancel_token.is_cancelled());
        assert!(!runtime.cancel_by_thread("t1"));
    }
}