1use std::sync::Arc;
4
5use crate::backend::{
6 BackendControl, BackendLocalRootContext, BackendRootRunRequest, ExecutionBackend,
7 ExecutionBackendError, LocalBackend, execute_remote_root_lifecycle, execution_capabilities,
8 validate_root_execution_request,
9};
10use crate::loop_runner::{AgentLoopError, AgentRunResult};
11use crate::registry::{ExecutionResolver, ResolvedExecution};
12use awaken_contract::contract::active_agent::ActiveAgentIdKey;
13use awaken_contract::contract::event_sink::{EventSink, NullEventSink};
14use awaken_contract::contract::identity::RunIdentity;
15use awaken_contract::contract::message::{Message, Role, Visibility};
16use awaken_contract::contract::storage::RunRecord;
17use awaken_contract::contract::suspension::ToolCallStatus;
18use awaken_contract::now_ms;
19use awaken_contract::state::PersistedState;
20
21use super::AgentRuntime;
22use super::run_request::RunRequest;
23
/// Agent id used when neither the request nor the thread's history names one.
const DEFAULT_AGENT_ID: &str = "default";
25
/// RAII guard that releases a run's slot in the runtime's active-run table
/// when the run finishes or unwinds.
struct RunSlotGuard<'a> {
    // Runtime whose run registry must be cleaned up.
    runtime: &'a AgentRuntime,
    // Identifier of the run slot to release on drop.
    run_id: String,
}
32
impl Drop for RunSlotGuard<'_> {
    fn drop(&mut self) {
        // Always unregister, including on early return and panic unwind.
        self.runtime.unregister_run(&self.run_id);
    }
}
38
/// Intermediate state assembled by `prepare_local_root_execution` before a
/// local root run is dispatched.
struct PreparedLocalRootExecution {
    // Full transcript: stored history (with stale tool calls stripped) plus
    // the new request messages when they were not already persisted.
    messages: Vec<Message>,
    // Phase runtime bound to the freshly created state store.
    phase_runtime: crate::phase::PhaseRuntime,
    // Receiving half of the run's inbox channel.
    inbox: crate::inbox::InboxReceiver,
}
44
impl AgentRuntime {
    /// Convenience wrapper around [`AgentRuntime::run`] that discards all
    /// streaming events via a [`NullEventSink`].
    pub async fn run_to_completion(
        &self,
        request: RunRequest,
    ) -> Result<AgentRunResult, AgentLoopError> {
        self.run(request, Arc::new(NullEventSink)).await
    }

    /// Executes one root run for `request`, emitting events into `sink`.
    ///
    /// Flow: resolve the agent id → resolve its execution backend (local or
    /// non-local) → choose the run id (continuation, hints, or a fresh
    /// UUIDv7) → load transcript and state → build and validate the backend
    /// request against the backend's declared capabilities → register the
    /// run so it can be cancelled → dispatch to the local or remote
    /// execution path.
    ///
    /// # Errors
    /// Propagates resolver, storage, validation, and backend failures as
    /// [`AgentLoopError`].
    pub async fn run(
        &self,
        request: RunRequest,
        sink: Arc<dyn EventSink>,
    ) -> Result<AgentRunResult, AgentLoopError> {
        let RunRequest {
            messages: request_messages,
            messages_already_persisted,
            thread_id,
            agent_id,
            overrides,
            decisions,
            frontend_tools,
            origin: req_origin,
            run_mode,
            adapter,
            parent_run_id: req_parent_run_id,
            parent_thread_id: req_parent_thread_id,
            continue_run_id,
            run_id_hint,
            dispatch_id_hint,
            dispatch_id,
            session_id,
            transport_request_id,
            run_inbox,
        } = request;
        // Keep a pristine copy of the incoming messages for the backend's
        // `new_messages` field; `request_messages` is consumed below.
        let new_messages = request_messages.clone();
        let requested_continue_run_id = continue_run_id.clone();
        let agent_id = self.resolve_agent_id(agent_id, &thread_id).await?;
        // Prefer a point-in-time registry snapshot so this run sees a
        // consistent registry view even if registries change mid-run.
        let run_resolver: Arc<dyn ExecutionResolver> =
            if let Some(snapshot) = self.registry_snapshot() {
                Arc::new(crate::registry::resolve::RegistrySetResolver::new(
                    snapshot.into_registries(),
                ))
            } else {
                self.execution_resolver_arc()
            };
        let resolved_execution = run_resolver
            .resolve_execution(&agent_id)
            .map_err(AgentLoopError::RuntimeError)?;
        let capabilities =
            execution_capabilities(&resolved_execution).map_err(local_root_execution_error)?;
        // Waiting-run reuse is only permitted for local executions.
        let (run_id, is_continuation) = self
            .next_root_run_id(
                &thread_id,
                continue_run_id,
                run_id_hint,
                dispatch_id_hint,
                matches!(&resolved_execution, ResolvedExecution::Local(_)),
            )
            .await?;
        // Map the request-level origin onto the identity-level origin
        // (A2A requests are recorded as subagent-originated runs).
        let run_origin = match req_origin {
            awaken_contract::contract::storage::RunRequestOrigin::User => {
                awaken_contract::contract::identity::RunOrigin::User
            }
            awaken_contract::contract::storage::RunRequestOrigin::A2A => {
                awaken_contract::contract::identity::RunOrigin::Subagent
            }
            awaken_contract::contract::storage::RunRequestOrigin::Internal => {
                awaken_contract::contract::identity::RunOrigin::Internal
            }
        };
        let mut run_identity = RunIdentity::new(
            thread_id.clone(),
            req_parent_thread_id,
            run_id.clone(),
            req_parent_run_id,
            agent_id.clone(),
            run_origin,
        )
        .with_run_mode(run_mode)
        .with_adapter(adapter);
        if let Some(dispatch_id) = dispatch_id {
            run_identity = run_identity.with_dispatch_id(dispatch_id);
        }
        if let Some(session_id) = session_id {
            run_identity = run_identity.with_session_id(session_id);
        }
        if let Some(transport_request_id) = transport_request_id {
            run_identity = run_identity.with_transport_request_id(transport_request_id);
        }

        let mut run_inbox = run_inbox;
        // Local runs get a phase runtime plus an inbox receiver; non-local
        // runs need only the transcript, an optional inbox receiver, and any
        // previously persisted state to hand to the remote backend.
        let (messages, phase_runtime, inbox, previous_non_local_state) = match &resolved_execution {
            ResolvedExecution::Local(preflight_resolved) => {
                let prepared = self
                    .prepare_local_root_execution(
                        preflight_resolved,
                        &thread_id,
                        request_messages,
                        messages_already_persisted,
                        &decisions,
                        run_inbox.take(),
                    )
                    .await?;
                (
                    prepared.messages,
                    Some(prepared.phase_runtime),
                    Some(prepared.inbox),
                    None,
                )
            }
            ResolvedExecution::NonLocal(_) => (
                self.load_non_local_messages(
                    &thread_id,
                    request_messages,
                    messages_already_persisted,
                )
                .await?,
                None,
                run_inbox.take().map(|run_inbox| run_inbox.receiver),
                self.load_non_local_state(&thread_id, requested_continue_run_id.as_deref())
                    .await?,
            ),
        };
        let run_created_at = now_ms();

        let (handle, cancellation_token, raw_decision_rx) =
            self.create_run_channels(run_id.clone());
        let runtime_cancellation_token = cancellation_token.clone();
        // Only hand the decision channel to backends that can use it;
        // dropping the receiver closes the channel for any senders.
        let decision_rx = if capabilities.decisions {
            Some(raw_decision_rx)
        } else {
            drop(raw_decision_rx);
            None
        };

        let backend_request = BackendRootRunRequest {
            agent_id: &agent_id,
            messages,
            new_messages,
            sink: sink.clone(),
            resolver: run_resolver.as_ref(),
            run_identity: run_identity.clone(),
            checkpoint_store: match &resolved_execution {
                // NOTE(review): for local runs `phase_runtime` is always
                // `Some` at this point, so `.and(...)` effectively passes
                // the storage handle through unchanged — confirm intent.
                ResolvedExecution::Local(_) => phase_runtime.as_ref().and(self.storage.as_deref()),
                ResolvedExecution::NonLocal(_) => self.storage.as_deref(),
            },
            control: BackendControl {
                // Pass a cooperative token only when the backend honours it.
                cancellation_token: capabilities
                    .cancellation
                    .supports_cooperative_token()
                    .then_some(cancellation_token),
                decision_rx,
            },
            decisions,
            overrides,
            frontend_tools,
            local: phase_runtime
                .as_ref()
                .map(|phase_runtime| BackendLocalRootContext { phase_runtime }),
            inbox,
            is_continuation,
        };
        // Reject requests that exceed the backend's declared capabilities
        // before the run is registered.
        validate_root_execution_request(&resolved_execution, &backend_request).map_err(
            |error| match error {
                ExecutionBackendError::Loop(loop_error) => loop_error,
                other => AgentLoopError::RuntimeError(crate::RuntimeError::ResolveFailed {
                    message: other.to_string(),
                }),
            },
        )?;

        self.register_run(&thread_id, handle)
            .map_err(AgentLoopError::RuntimeError)?;
        // Ensure the run slot is released on every exit path (incl. panic).
        let _guard = RunSlotGuard {
            runtime: self,
            run_id: run_id.clone(),
        };

        match &resolved_execution {
            ResolvedExecution::Local(_) => {
                let result = LocalBackend::new()
                    .execute_root(backend_request)
                    .await
                    .map_err(local_root_execution_error)?;
                Ok(AgentRunResult {
                    run_id: run_id.clone(),
                    response: result.response.unwrap_or_default(),
                    termination: result.termination,
                    steps: result.steps,
                })
            }
            ResolvedExecution::NonLocal(non_local) => {
                execute_remote_root_lifecycle(
                    non_local,
                    backend_request,
                    run_created_at,
                    runtime_cancellation_token,
                    previous_non_local_state,
                )
                .await
            }
        }
    }

    /// Builds the state store, phase runtime, inbox, and transcript needed
    /// for a local root execution.
    ///
    /// Restores thread-scoped persisted state from the latest stored run (if
    /// any), loads the stored transcript, strips superseded-suspended and
    /// unpaired tool calls, and appends the new request messages unless the
    /// caller already persisted them.
    async fn prepare_local_root_execution(
        &self,
        preflight_resolved: &crate::registry::ResolvedAgent,
        thread_id: &str,
        request_messages: Vec<Message>,
        messages_already_persisted: bool,
        decisions: &[(
            String,
            awaken_contract::contract::suspension::ToolCallResume,
        )],
        run_inbox: Option<super::run_request::RunInbox>,
    ) -> Result<PreparedLocalRootExecution, AgentLoopError> {
        let store = crate::state::StateStore::new();
        let phase_runtime =
            crate::phase::PhaseRuntime::new(store.clone()).map_err(AgentLoopError::PhaseError)?;
        store
            .install_plugin(crate::loop_runner::LoopStatePlugin)
            .map_err(AgentLoopError::PhaseError)?;
        // Fall back to a fresh inbox channel when the caller passed none.
        let run_inbox = run_inbox.unwrap_or_else(|| {
            let (sender, receiver) = crate::inbox::inbox_channel();
            super::run_request::RunInbox { sender, receiver }
        });
        let owner_inbox = run_inbox.sender.clone();
        crate::backend::LocalBackend::bind_local_execution_env(
            &store,
            preflight_resolved,
            Some(&owner_inbox),
        )
        .map_err(AgentLoopError::PhaseError)?;

        // Restore persisted thread state (skipping unknown keys), then load
        // the stored transcript; without storage we start from scratch.
        let mut messages = if let Some(ref ts) = self.storage {
            if let Some(prev_run) = ts
                .latest_run(thread_id)
                .await
                .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
                && let Some(persisted) = prev_run.state
            {
                store
                    .restore_thread_scoped(persisted, awaken_contract::UnknownKeyPolicy::Skip)
                    .map_err(AgentLoopError::PhaseError)?;
            }
            ts.load_messages(thread_id)
                .await
                .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
                .unwrap_or_default()
        } else {
            vec![]
        };
        // A fresh user message with no resume decisions supersedes any
        // suspended tool calls left behind by a previous run.
        if should_supersede_suspended_calls(&request_messages, decisions) {
            strip_superseded_suspended_tool_calls(&mut messages, &store);
        }
        strip_unpaired_tool_calls(&mut messages);
        if !messages_already_persisted {
            messages.extend(request_messages);
        }

        Ok(PreparedLocalRootExecution {
            messages,
            phase_runtime,
            inbox: run_inbox.receiver,
        })
    }

    /// Loads the transcript for a non-local run: stored messages (if any)
    /// with unpaired tool calls stripped, plus the new request messages
    /// unless the caller already persisted them.
    async fn load_non_local_messages(
        &self,
        thread_id: &str,
        request_messages: Vec<Message>,
        messages_already_persisted: bool,
    ) -> Result<Vec<Message>, AgentLoopError> {
        let mut messages = if let Some(ref storage) = self.storage {
            storage
                .load_messages(thread_id)
                .await
                .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
                .unwrap_or_default()
        } else {
            Vec::new()
        };
        strip_unpaired_tool_calls(&mut messages);
        if !messages_already_persisted {
            messages.extend(request_messages);
        }
        Ok(messages)
    }

    /// Chooses the run id for a new root run and whether it continues an
    /// existing one. Precedence:
    /// 1. `continue_run_id` — must reference an existing stored run;
    /// 2. `run_id_hint` — usable if unused, or existing but still `Created`;
    /// 3. `dispatch_id_hint` — must not collide with an existing run;
    /// 4. a resumable waiting run on the thread (when `allow_waiting_reuse`);
    /// 5. a freshly generated UUIDv7.
    ///
    /// Returns `(run_id, is_continuation)`.
    async fn next_root_run_id(
        &self,
        thread_id: &str,
        continue_run_id: Option<String>,
        run_id_hint: Option<String>,
        dispatch_id_hint: Option<String>,
        allow_waiting_reuse: bool,
    ) -> Result<(String, bool), AgentLoopError> {
        if let Some(run_id) = continue_run_id {
            let Some(ref ts) = self.storage else {
                return Err(AgentLoopError::InvalidResume(format!(
                    "continue_run_id '{run_id}' requires run storage"
                )));
            };
            if ts
                .load_run(&run_id)
                .await
                .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
                .is_some()
            {
                return Ok((run_id, true));
            }
            return Err(AgentLoopError::InvalidResume(format!(
                "continue_run_id '{run_id}' does not reference an existing run"
            )));
        }
        // Hints that are blank after trimming are ignored.
        if let Some(run_id) = run_id_hint.and_then(|id| {
            let trimmed = id.trim();
            (!trimmed.is_empty()).then(|| trimmed.to_string())
        }) {
            if let Some(ref ts) = self.storage
                && let Some(existing) = ts
                    .load_run(&run_id)
                    .await
                    .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
            {
                // A pre-created run record may be claimed; anything further
                // along its lifecycle is treated as a collision.
                if existing.status == awaken_contract::contract::lifecycle::RunStatus::Created {
                    return Ok((run_id, false));
                }
                return Err(AgentLoopError::InvalidResume(format!(
                    "run_id_hint '{run_id}' already exists as a run"
                )));
            }
            return Ok((run_id, false));
        }
        if let Some(run_id) = dispatch_id_hint.and_then(|id| {
            let trimmed = id.trim();
            (!trimmed.is_empty()).then(|| trimmed.to_string())
        }) {
            if let Some(ref ts) = self.storage
                && ts
                    .load_run(&run_id)
                    .await
                    .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
                    .is_some()
            {
                return Err(AgentLoopError::InvalidResume(format!(
                    "dispatch_id_hint '{run_id}' already exists as a run"
                )));
            }
            return Ok((run_id, false));
        }
        if allow_waiting_reuse && let Some(prev) = self.reusable_waiting_run(thread_id).await? {
            return Ok((prev.run_id.clone(), true));
        }
        Ok((uuid::Uuid::now_v7().to_string(), false))
    }

    /// Finds a run on `thread_id` that is waiting and may be resumed instead
    /// of creating a new run.
    ///
    /// Checks the thread's recorded open run first (verifying it actually
    /// belongs to this thread), then falls back to the thread's latest run.
    /// Returns `None` when no storage is configured.
    async fn reusable_waiting_run(
        &self,
        thread_id: &str,
    ) -> Result<Option<RunRecord>, AgentLoopError> {
        let Some(ref ts) = self.storage else {
            return Ok(None);
        };

        if let Some(thread) = ts
            .load_thread(thread_id)
            .await
            .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
            && let Some(open_run_id) = thread.open_run_id.as_deref()
            && let Some(run) = ts
                .load_run(open_run_id)
                .await
                .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
            && run.thread_id == thread_id
            && run.is_resumable_waiting()
        {
            return Ok(Some(run));
        }

        Ok(ts
            .latest_run(thread_id)
            .await
            .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
            .filter(RunRecord::is_resumable_waiting))
    }

    /// Resolves the agent id for a request: the explicit request value,
    /// then the agent inferred from the thread's history, then
    /// [`DEFAULT_AGENT_ID`].
    async fn resolve_agent_id(
        &self,
        requested_agent_id: Option<String>,
        thread_id: &str,
    ) -> Result<String, AgentLoopError> {
        if let Some(agent_id) = requested_agent_id {
            return Ok(agent_id);
        }

        if let Some(inferred) = self.infer_agent_id_from_thread(thread_id).await? {
            return Ok(inferred);
        }

        Ok(DEFAULT_AGENT_ID.to_string())
    }

    /// Infers the agent id from the thread's latest run: first the active
    /// agent recorded in persisted state, then the run's own `agent_id`
    /// (ignored when blank). Returns `None` without storage or prior runs.
    async fn infer_agent_id_from_thread(
        &self,
        thread_id: &str,
    ) -> Result<Option<String>, AgentLoopError> {
        let Some(storage) = &self.storage else {
            return Ok(None);
        };

        let Some(prev_run) = storage
            .latest_run(thread_id)
            .await
            .map_err(|e| AgentLoopError::StorageError(e.to_string()))?
        else {
            return Ok(None);
        };

        if let Some(agent_id) = prev_run.state.as_ref().and_then(active_agent_from_state) {
            return Ok(Some(agent_id));
        }

        let agent_id = prev_run.agent_id.trim();
        if agent_id.is_empty() {
            Ok(None)
        } else {
            Ok(Some(agent_id.to_string()))
        }
    }

    /// Loads persisted state to seed a non-local run: the state of the
    /// explicitly continued run when `continue_run_id` is given, otherwise
    /// the state of the thread's latest run. Returns `None` without storage.
    async fn load_non_local_state(
        &self,
        thread_id: &str,
        continue_run_id: Option<&str>,
    ) -> Result<Option<PersistedState>, AgentLoopError> {
        let Some(storage) = &self.storage else {
            return Ok(None);
        };

        if let Some(run_id) = continue_run_id {
            return Ok(storage
                .load_run(run_id)
                .await
                .map_err(|error| AgentLoopError::StorageError(error.to_string()))?
                .and_then(|run| run.state));
        }

        Ok(storage
            .latest_run(thread_id)
            .await
            .map_err(|error| AgentLoopError::StorageError(error.to_string()))?
            .and_then(|run| run.state))
    }
}
522
523fn local_root_execution_error(error: ExecutionBackendError) -> AgentLoopError {
524 match error {
525 ExecutionBackendError::Loop(loop_error) => loop_error,
526 other => AgentLoopError::RuntimeError(crate::RuntimeError::ResolveFailed {
527 message: other.to_string(),
528 }),
529 }
530}
531
532fn active_agent_from_state(state: &PersistedState) -> Option<String> {
533 state
534 .extensions
535 .get(<ActiveAgentIdKey as awaken_contract::StateKey>::KEY)
536 .and_then(|value| value.as_str())
537 .map(str::trim)
538 .filter(|v| !v.is_empty())
539 .map(ToOwned::to_owned)
540}
541
542fn strip_unpaired_tool_calls(messages: &mut Vec<Message>) {
549 use std::collections::HashSet;
550
551 let answered: HashSet<String> = messages
553 .iter()
554 .filter(|m| m.role == Role::Tool)
555 .filter_map(|m| m.tool_call_id.clone())
556 .collect();
557
558 for msg in messages.iter_mut() {
560 if msg.role != Role::Assistant {
561 continue;
562 }
563 if let Some(ref mut calls) = msg.tool_calls {
564 calls.retain(|c| answered.contains(&c.id));
565 if calls.is_empty() {
566 msg.tool_calls = None;
567 }
568 }
569 }
570
571 while let Some(last) = messages.last() {
573 if last.role == Role::Assistant
574 && last.tool_calls.is_none()
575 && last.text().trim().is_empty()
576 {
577 messages.pop();
578 } else {
579 break;
580 }
581 }
582}
583
584fn should_supersede_suspended_calls(
585 request_messages: &[Message],
586 decisions: &[(
587 String,
588 awaken_contract::contract::suspension::ToolCallResume,
589 )],
590) -> bool {
591 decisions.is_empty()
592 && request_messages
593 .iter()
594 .any(|message| message.role == Role::User && message.visibility == Visibility::All)
595}
596
597fn strip_superseded_suspended_tool_calls(
598 messages: &mut Vec<Message>,
599 store: &crate::state::StateStore,
600) {
601 use std::collections::HashSet;
602
603 let suspended_ids: HashSet<String> = store
604 .read::<crate::agent::state::ToolCallStates>()
605 .unwrap_or_default()
606 .calls
607 .into_iter()
608 .filter_map(|(call_id, state)| {
609 (state.status == ToolCallStatus::Suspended).then_some(call_id)
610 })
611 .collect();
612 if suspended_ids.is_empty() {
613 return;
614 }
615
616 for message in messages.iter_mut() {
617 if message.role != Role::Assistant {
618 continue;
619 }
620 if let Some(ref mut calls) = message.tool_calls {
621 calls.retain(|call| !suspended_ids.contains(&call.id));
622 if calls.is_empty() {
623 message.tool_calls = None;
624 }
625 }
626 }
627
628 messages.retain(|message| {
629 !(message.role == Role::Tool
630 && message
631 .tool_call_id
632 .as_ref()
633 .is_some_and(|call_id| suspended_ids.contains(call_id)))
634 });
635
636 while let Some(last) = messages.last() {
637 if last.role == Role::Assistant
638 && last.tool_calls.is_none()
639 && last.text().trim().is_empty()
640 {
641 messages.pop();
642 } else {
643 break;
644 }
645 }
646}
647
648#[cfg(test)]
649mod tests {
650 use super::super::*;
651 #[cfg(feature = "a2a")]
652 use crate::extensions::a2a::{
653 AgentBackend, AgentBackendError, AgentBackendFactory, AgentBackendFactoryError,
654 DelegateRunResult, DelegateRunStatus,
655 };
656 use crate::loop_runner::build_agent_env;
657 use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
658 #[cfg(feature = "a2a")]
659 use crate::registry::memory::{
660 MapAgentSpecRegistry, MapBackendRegistry, MapModelRegistry, MapPluginSource,
661 MapProviderRegistry, MapToolRegistry,
662 };
663 #[cfg(feature = "a2a")]
664 use crate::registry::snapshot::RegistryHandle;
665 #[cfg(feature = "a2a")]
666 use crate::registry::traits::{BackendRegistry, ModelBinding, RegistrySet};
667 use crate::registry::{AgentResolver, ResolvedAgent};
668 use crate::state::{KeyScope, StateCommand, StateKey, StateKeyOptions};
669 use crate::{PhaseContext, PhaseHook, ToolPolicyHook};
670 use async_trait::async_trait;
671 use awaken_contract::PersistedState;
672 use awaken_contract::contract::active_agent::ActiveAgentIdKey;
673 use awaken_contract::contract::content::ContentBlock;
674 use awaken_contract::contract::event::AgentEvent;
675 use awaken_contract::contract::event_sink::{EventSink, NullEventSink, VecEventSink};
676 use awaken_contract::contract::executor::{
677 InferenceExecutionError, InferenceRequest, LlmExecutor,
678 };
679 use awaken_contract::contract::inference::{InferenceOverride, StopReason, StreamResult};
680 use awaken_contract::contract::lifecycle::{RunStatus, TerminationReason};
681 use awaken_contract::contract::message::Message;
682 use awaken_contract::contract::storage::{
683 RunQuery, RunRecord, RunStore, RunWaitingState, ThreadRunStore, ThreadStore, WaitingReason,
684 };
685 use awaken_contract::contract::suspension::ResumeDecisionAction;
686 use awaken_contract::contract::suspension::ToolCallResume;
687 use awaken_contract::contract::tool::{
688 Tool, ToolCallContext, ToolDescriptor, ToolError, ToolOutput, ToolResult,
689 };
690 use awaken_contract::contract::tool_intercept::{
691 AdapterKind, RunMode, ToolPolicyContext, ToolPolicyDecision,
692 };
693 #[cfg(feature = "a2a")]
694 use awaken_contract::registry_spec::{AgentSpec, RemoteEndpoint};
695 use awaken_stores::InMemoryStore;
696 use serde_json::{Value, json};
697 use std::collections::HashMap;
698 use std::sync::atomic::{AtomicUsize, Ordering};
699 use std::sync::{Arc, Mutex};
700
    /// Test executor that replays a fixed script of responses and records
    /// the inference overrides it was handed.
    struct ScriptedLlm {
        // Remaining scripted responses, consumed front-to-back.
        responses: Mutex<Vec<StreamResult>>,
        // Overrides observed per `execute` call, in call order.
        seen_overrides: Mutex<Vec<Option<InferenceOverride>>>,
    }
705
706 impl ScriptedLlm {
707 fn new(responses: Vec<StreamResult>) -> Self {
708 Self {
709 responses: Mutex::new(responses),
710 seen_overrides: Mutex::new(Vec::new()),
711 }
712 }
713 }
714
715 #[async_trait]
716 impl LlmExecutor for ScriptedLlm {
717 async fn execute(
718 &self,
719 request: InferenceRequest,
720 ) -> Result<StreamResult, InferenceExecutionError> {
721 self.seen_overrides
722 .lock()
723 .expect("lock poisoned")
724 .push(request.overrides.clone());
725 let mut responses = self.responses.lock().expect("lock poisoned");
726 if responses.is_empty() {
727 Ok(StreamResult {
728 content: vec![ContentBlock::text("done")],
729 tool_calls: vec![],
730 usage: None,
731 stop_reason: Some(StopReason::EndTurn),
732 has_incomplete_tool_calls: false,
733 })
734 } else {
735 Ok(responses.remove(0))
736 }
737 }
738
739 fn name(&self) -> &str {
740 "scripted"
741 }
742 }
743
    #[cfg(feature = "a2a")]
    /// Scriptable stand-in for a remote agent backend; behaviour is fixed
    /// at construction time from endpoint options.
    struct StaticRemoteBackend {
        // Text returned as the run's response.
        response: String,
        // Artificial latency before responding, in milliseconds.
        delay_ms: u64,
        // Whether the backend advertises remote-abort cancellation.
        cancellation: bool,
        // Whether the backend advertises remote-state continuation.
        continuation: bool,
        // Shared counter incremented each time `abort` is invoked.
        abort_count: Arc<AtomicUsize>,
        // Termination the scripted run reports.
        termination: TerminationReason,
        // Optional status reason attached to the delegate result.
        status_reason: Option<String>,
    }
754
    #[cfg(feature = "a2a")]
    #[async_trait]
    impl AgentBackend for StaticRemoteBackend {
        /// Advertises capabilities derived from the scripted flags; the
        /// backend never supports decisions, overrides, or frontend tools.
        fn capabilities(&self) -> crate::backend::BackendCapabilities {
            crate::backend::BackendCapabilities {
                cancellation: if self.cancellation {
                    crate::backend::BackendCancellationCapability::RemoteAbort
                } else {
                    crate::backend::BackendCancellationCapability::None
                },
                decisions: false,
                overrides: false,
                frontend_tools: false,
                continuation: if self.continuation {
                    crate::backend::BackendContinuationCapability::RemoteState
                } else {
                    crate::backend::BackendContinuationCapability::None
                },
                waits: crate::backend::BackendWaitCapability::None,
                transcript: crate::backend::BackendTranscriptCapability::SinglePrompt,
                output: crate::backend::BackendOutputCapability::Text,
            }
        }

        /// Counts abort invocations so tests can assert the hook fired.
        async fn abort(
            &self,
            _request: crate::backend::BackendAbortRequest<'_>,
        ) -> Result<(), AgentBackendError> {
            self.abort_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        /// Sleeps for the configured delay, then returns the scripted
        /// response and termination without contacting anything remote.
        async fn execute_root(
            &self,
            request: crate::backend::BackendRootRunRequest<'_>,
        ) -> Result<DelegateRunResult, AgentBackendError> {
            if self.delay_ms > 0 {
                tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
            }
            Ok(DelegateRunResult {
                agent_id: request.agent_id.to_string(),
                // Delegate status mirrors the scripted termination reason.
                status: match &self.termination {
                    TerminationReason::Cancelled => DelegateRunStatus::Cancelled,
                    TerminationReason::Error(message) => DelegateRunStatus::Failed(message.clone()),
                    _ => DelegateRunStatus::Completed,
                },
                termination: self.termination.clone(),
                status_reason: self.status_reason.clone(),
                response: Some(self.response.clone()),
                output: crate::backend::BackendRunOutput::from_text(Some(self.response.clone())),
                steps: 1,
                run_id: Some("child-remote-run".into()),
                inbox: None,
                state: None,
            })
        }
    }
812
    #[cfg(feature = "a2a")]
    /// Factory producing [`StaticRemoteBackend`] instances for the
    /// `test-remote` backend id.
    struct StaticRemoteBackendFactory {
        // Shared abort counter threaded into every backend it builds.
        abort_count: Arc<AtomicUsize>,
    }
817
    #[cfg(feature = "a2a")]
    impl AgentBackendFactory for StaticRemoteBackendFactory {
        fn backend(&self) -> &str {
            "test-remote"
        }

        /// Builds a [`StaticRemoteBackend`] from endpoint options:
        /// `delay_ms` (default 0), `supports_cancellation` (default true),
        /// `supports_continuation` (default false), `termination`
        /// ("suspended" | "cancelled" | "error", otherwise natural end),
        /// and an optional `status_reason` string.
        fn build(
            &self,
            endpoint: &RemoteEndpoint,
        ) -> Result<Arc<dyn AgentBackend>, AgentBackendFactoryError> {
            // Guard against being wired to the wrong backend id.
            if endpoint.backend != "test-remote" {
                return Err(AgentBackendFactoryError::InvalidConfig(format!(
                    "unexpected backend '{}'",
                    endpoint.backend
                )));
            }
            let delay_ms = endpoint
                .options
                .get("delay_ms")
                .and_then(serde_json::Value::as_u64)
                .unwrap_or(0);
            let cancellation = endpoint
                .options
                .get("supports_cancellation")
                .and_then(serde_json::Value::as_bool)
                .unwrap_or(true);
            let continuation = endpoint
                .options
                .get("supports_continuation")
                .and_then(serde_json::Value::as_bool)
                .unwrap_or(false);
            let termination = match endpoint.options.get("termination").and_then(|v| v.as_str()) {
                Some("suspended") => TerminationReason::Suspended,
                Some("cancelled") => TerminationReason::Cancelled,
                Some("error") => TerminationReason::Error("remote root error".into()),
                _ => TerminationReason::NaturalEnd,
            };
            let status_reason = endpoint
                .options
                .get("status_reason")
                .and_then(serde_json::Value::as_str)
                .map(ToOwned::to_owned);
            Ok(Arc::new(StaticRemoteBackend {
                response: "remote root response".into(),
                delay_ms,
                cancellation,
                continuation,
                abort_count: self.abort_count.clone(),
                termination,
                status_reason,
            }))
        }
    }
871
    #[cfg(feature = "a2a")]
    /// Assembles an [`AgentRuntime`] whose `remote-root` agent is bound to
    /// `endpoint` through the `test-remote` backend factory, with in-memory
    /// registries and an in-memory thread/run store.
    fn build_remote_runtime(
        endpoint: RemoteEndpoint,
        abort_count: Arc<AtomicUsize>,
    ) -> AgentRuntime {
        let mut models = MapModelRegistry::new();
        models
            .register_model(
                "test-model",
                ModelBinding {
                    provider_id: "mock".into(),
                    upstream_model: "mock-model".into(),
                },
            )
            .unwrap();

        // Provider never actually executes: the endpoint routes the run to
        // the remote backend instead.
        let mut providers = MapProviderRegistry::new();
        providers
            .register_provider("mock", Arc::new(ScriptedLlm::new(Vec::new())))
            .unwrap();

        let mut agents = MapAgentSpecRegistry::new();
        agents
            .register_spec(
                AgentSpec::new("remote-root")
                    .with_model_id("test-model")
                    .with_system_prompt("remote root")
                    .with_endpoint(endpoint),
            )
            .unwrap();

        let mut backends = MapBackendRegistry::new();
        backends
            .register_backend_factory(Arc::new(StaticRemoteBackendFactory { abort_count }))
            .unwrap();

        let registries = RegistrySet {
            agents: Arc::new(agents),
            tools: Arc::new(MapToolRegistry::new()),
            models: Arc::new(models),
            providers: Arc::new(providers),
            plugins: Arc::new(MapPluginSource::new()),
            backends: Arc::new(backends) as Arc<dyn BackendRegistry>,
        };
        let handle = RegistryHandle::new(registries.clone());
        AgentRuntime::new(Arc::new(
            crate::registry::resolve::DynamicRegistryResolver::new(handle.clone()),
        ))
        .with_registry_handle(handle)
        .with_thread_run_store(Arc::new(InMemoryStore::new()))
    }
923
    // End-to-end: a root agent bound to a remote endpoint executes through
    // the remote backend, streams its response as events, and persists both
    // the run record and the assistant message.
    #[cfg(feature = "a2a")]
    #[tokio::test]
    async fn run_supports_endpoint_root_agents() {
        let runtime = build_remote_runtime(
            RemoteEndpoint {
                backend: "test-remote".into(),
                base_url: "https://remote.example.com".into(),
                ..Default::default()
            },
            Arc::new(AtomicUsize::new(0)),
        );

        let sink = Arc::new(VecEventSink::new());
        let result = runtime
            .run(
                RunRequest::new("remote-thread", vec![Message::user("hello")])
                    .with_agent_id("remote-root"),
                sink.clone(),
            )
            .await
            .expect("endpoint root run should succeed");

        assert_eq!(result.response, "remote root response");
        assert!(matches!(result.termination, TerminationReason::NaturalEnd));

        // Event stream: RunStart first, then the text delta and RunFinish.
        let events = sink.events();
        assert!(matches!(events.first(), Some(AgentEvent::RunStart { .. })));
        assert!(events.iter().any(|event| matches!(
            event,
            AgentEvent::TextDelta { delta } if delta == "remote root response"
        )));
        assert!(events.iter().any(|event| matches!(
            event,
            AgentEvent::RunFinish {
                termination: TerminationReason::NaturalEnd,
                ..
            }
        )));

        // The run record is persisted as Done under the remote agent.
        let latest_run = runtime
            .thread_run_store()
            .expect("store")
            .latest_run("remote-thread")
            .await
            .expect("run lookup should succeed")
            .expect("run record should be persisted");
        assert_eq!(latest_run.agent_id, "remote-root");
        assert_eq!(latest_run.status, RunStatus::Done);

        // The assistant response is persisted into the thread transcript.
        let messages = runtime
            .thread_run_store()
            .expect("store")
            .load_messages("remote-thread")
            .await
            .expect("message lookup should succeed")
            .expect("messages should be persisted");
        assert!(messages.iter().any(|message| {
            message.role == awaken_contract::contract::message::Role::Assistant
                && message.text() == "remote root response"
        }));
    }
985
    // A remote run that terminates Suspended with status_reason
    // "input_required" is persisted as Waiting on user input, and the
    // RunFinish event carries the status reason through.
    #[cfg(feature = "a2a")]
    #[tokio::test]
    async fn run_persists_non_local_waiting_reason_from_backend() {
        let runtime = build_remote_runtime(
            RemoteEndpoint {
                backend: "test-remote".into(),
                base_url: "https://remote.example.com".into(),
                options: std::collections::BTreeMap::from([
                    ("termination".into(), json!("suspended")),
                    ("status_reason".into(), json!("input_required")),
                ]),
                ..Default::default()
            },
            Arc::new(AtomicUsize::new(0)),
        );

        let sink = Arc::new(VecEventSink::new());
        let result = runtime
            .run(
                RunRequest::new("remote-thread-waiting", vec![Message::user("hello")])
                    .with_agent_id("remote-root"),
                sink.clone(),
            )
            .await
            .expect("endpoint root run should suspend cleanly");

        assert_eq!(result.termination, TerminationReason::Suspended);

        let latest_run = runtime
            .thread_run_store()
            .expect("store")
            .latest_run("remote-thread-waiting")
            .await
            .expect("run lookup should succeed")
            .expect("run record should be persisted");
        assert_eq!(latest_run.status, RunStatus::Waiting);
        assert_eq!(latest_run.waiting_reason(), Some(WaitingReason::UserInput));

        // RunFinish payload must expose the backend's status_reason.
        let events = sink.events();
        assert!(events.iter().any(|event| matches!(
            event,
            AgentEvent::RunFinish {
                termination: TerminationReason::Suspended,
                result: Some(result),
                ..
            } if result["status_reason"].as_str() == Some("input_required")
        )));
    }
1034
    // Requests carrying inference overrides are rejected up front when the
    // resolved remote backend does not advertise override support.
    #[cfg(feature = "a2a")]
    #[tokio::test]
    async fn run_rejects_remote_overrides_without_backend_capability() {
        let runtime = build_remote_runtime(
            RemoteEndpoint {
                backend: "test-remote".into(),
                base_url: "https://remote.example.com".into(),
                ..Default::default()
            },
            Arc::new(AtomicUsize::new(0)),
        );

        let error = runtime
            .run(
                RunRequest::new("remote-thread-overrides", vec![Message::user("hello")])
                    .with_agent_id("remote-root")
                    .with_overrides(InferenceOverride {
                        temperature: Some(0.2),
                        ..Default::default()
                    }),
                Arc::new(VecEventSink::new()),
            )
            .await
            .expect_err("remote backend should reject overrides");

        assert!(error.to_string().contains("does not support: overrides"));
    }
1062
    // Cancelling a remote run whose backend has no abort capability still
    // terminates the run as Cancelled, and the backend abort hook is never
    // invoked.
    #[cfg(feature = "a2a")]
    #[tokio::test]
    async fn run_allows_non_local_root_backends_without_cancellation_capability() {
        let abort_count = Arc::new(AtomicUsize::new(0));
        let runtime = Arc::new(build_remote_runtime(
            RemoteEndpoint {
                backend: "test-remote".into(),
                base_url: "https://remote.example.com".into(),
                // Long delay keeps the run alive until cancel() lands.
                options: std::collections::BTreeMap::from([
                    ("delay_ms".into(), json!(5_000_u64)),
                    ("supports_cancellation".into(), json!(false)),
                ]),
                ..Default::default()
            },
            abort_count.clone(),
        ));

        let run_handle = {
            let runtime = runtime.clone();
            tokio::spawn(async move {
                runtime
                    .run(
                        RunRequest::new("remote-thread-cancel", vec![Message::user("hello")])
                            .with_agent_id("remote-root"),
                        Arc::new(VecEventSink::new()),
                    )
                    .await
            })
        };

        // Poll until the run has registered itself and can be cancelled.
        let mut cancelled = false;
        for _ in 0..20 {
            if runtime.cancel("remote-thread-cancel") {
                cancelled = true;
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
        assert!(cancelled);

        let result = run_handle
            .await
            .expect("task should join")
            .expect("cancelled run should still return a result");
        assert!(matches!(result.termination, TerminationReason::Cancelled));
        // No abort capability advertised → hook must not have fired.
        assert_eq!(abort_count.load(Ordering::SeqCst), 0);
    }
1110
    // Cancelling a remote run whose backend advertises remote-abort (the
    // factory default) invokes the backend's abort hook exactly once.
    #[cfg(feature = "a2a")]
    #[tokio::test]
    async fn run_non_local_root_cancel_invokes_backend_abort_hook() {
        let abort_count = Arc::new(AtomicUsize::new(0));
        let runtime = Arc::new(build_remote_runtime(
            RemoteEndpoint {
                backend: "test-remote".into(),
                base_url: "https://remote.example.com".into(),
                // Long delay keeps the run alive until cancel() lands.
                options: std::collections::BTreeMap::from([("delay_ms".into(), json!(5_000_u64))]),
                ..Default::default()
            },
            abort_count.clone(),
        ));

        let run_handle = {
            let runtime = runtime.clone();
            tokio::spawn(async move {
                runtime
                    .run(
                        RunRequest::new("remote-thread-abort", vec![Message::user("hello")])
                            .with_agent_id("remote-root"),
                        Arc::new(VecEventSink::new()),
                    )
                    .await
            })
        };

        // Poll until the run has registered itself and can be cancelled.
        let mut cancelled = false;
        for _ in 0..20 {
            if runtime.cancel("remote-thread-abort") {
                cancelled = true;
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
        assert!(cancelled);
        let _ = run_handle.await.expect("task should join");

        assert_eq!(abort_count.load(Ordering::SeqCst), 1);
    }
1151
    // The default remote endpoint advertises no decision capability, so a request
    // carrying resume decisions must be rejected before dispatch.
    #[cfg(feature = "a2a")]
    #[tokio::test]
    async fn run_rejects_remote_resume_decisions_without_backend_capability() {
        let runtime = build_remote_runtime(
            RemoteEndpoint {
                backend: "test-remote".into(),
                base_url: "https://remote.example.com".into(),
                ..Default::default()
            },
            Arc::new(AtomicUsize::new(0)),
        );

        let error = runtime
            .run(
                RunRequest::new("remote-thread-decisions", vec![Message::user("hello")])
                    .with_agent_id("remote-root")
                    .with_decisions(vec![(
                        "call-1".into(),
                        ToolCallResume {
                            decision_id: "d1".into(),
                            action: ResumeDecisionAction::Resume,
                            result: Value::Null,
                            reason: None,
                            updated_at: 1,
                        },
                    )]),
                Arc::new(VecEventSink::new()),
            )
            .await
            .expect_err("remote backend should reject resume decisions");

        assert!(error.to_string().contains("does not support: decisions"));
    }

    // Same capability gate for frontend tools: the default remote endpoint does
    // not support them, so the request must fail with a descriptive error.
    #[cfg(feature = "a2a")]
    #[tokio::test]
    async fn run_rejects_remote_frontend_tools_without_backend_capability() {
        let runtime = build_remote_runtime(
            RemoteEndpoint {
                backend: "test-remote".into(),
                base_url: "https://remote.example.com".into(),
                ..Default::default()
            },
            Arc::new(AtomicUsize::new(0)),
        );

        let error = runtime
            .run(
                RunRequest::new("remote-thread-frontend", vec![Message::user("hello")])
                    .with_agent_id("remote-root")
                    .with_frontend_tools(vec![ToolDescriptor::new(
                        "browser",
                        "browser",
                        "frontend tool",
                    )]),
                Arc::new(VecEventSink::new()),
            )
            .await
            .expect_err("remote backend should reject frontend tools");

        assert!(
            error
                .to_string()
                .contains("does not support: frontend_tools")
        );
    }
1218
    // Continuing an existing run against a remote backend without the
    // continuation capability must fail with a descriptive error.
    // NOTE(review): unlike the sibling remote-capability tests above, this one is
    // not gated on `#[cfg(feature = "a2a")]` — confirm `build_remote_runtime` is
    // available without that feature, or add the gate for consistency.
    #[tokio::test]
    async fn run_rejects_remote_continuation_without_backend_capability() {
        let runtime = build_remote_runtime(
            RemoteEndpoint {
                backend: "test-remote".into(),
                base_url: "https://remote.example.com".into(),
                ..Default::default()
            },
            Arc::new(AtomicUsize::new(0)),
        );
        // Seed a real waiting run so the failure can only come from the missing
        // backend capability, not from an unknown continuation target.
        let store = runtime.thread_run_store().expect("store");
        let existing_run = RunRecord {
            run_id: "existing-run".into(),
            thread_id: "remote-thread-cont".into(),
            agent_id: "remote-root".into(),
            parent_run_id: None,
            request: None,
            input: None,
            output: None,
            status: RunStatus::Waiting,
            termination_reason: None,
            final_output: None,
            error_payload: None,
            dispatch_id: None,
            session_id: None,
            transport_request_id: None,
            waiting: None,
            outcome: None,
            created_at: 1,
            started_at: None,
            finished_at: None,
            updated_at: 1,
            steps: 1,
            input_tokens: 0,
            output_tokens: 0,
            state: None,
        };
        store
            .checkpoint(
                "remote-thread-cont",
                &[Message::user("previous remote turn")],
                &existing_run,
            )
            .await
            .expect("seed existing remote run");

        let error = runtime
            .run(
                RunRequest::new("remote-thread-cont", vec![Message::user("hello")])
                    .with_agent_id("remote-root")
                    .with_continue_run_id("existing-run"),
                Arc::new(VecEventSink::new()),
            )
            .await
            .expect_err("remote backend should reject continuation");

        assert!(error.to_string().contains("does not support: continuation"));
    }
1277
1278 #[tokio::test]
1279 async fn run_rejects_unknown_continue_run_id() {
1280 let runtime = build_remote_runtime(
1281 RemoteEndpoint {
1282 backend: "test-remote".into(),
1283 base_url: "https://remote.example.com".into(),
1284 options: std::collections::BTreeMap::from([(
1285 "supports_continuation".into(),
1286 json!(true),
1287 )]),
1288 ..Default::default()
1289 },
1290 Arc::new(AtomicUsize::new(0)),
1291 );
1292
1293 let error = runtime
1294 .run(
1295 RunRequest::new("remote-thread-missing-cont", vec![Message::user("resume")])
1296 .with_agent_id("remote-root")
1297 .with_continue_run_id("missing-run"),
1298 Arc::new(VecEventSink::new()),
1299 )
1300 .await
1301 .expect_err("unknown continuation run id should fail");
1302
1303 assert!(
1304 error
1305 .to_string()
1306 .contains("continue_run_id 'missing-run' does not reference an existing run")
1307 );
1308 }
1309
1310 #[tokio::test]
1311 async fn run_uses_dispatch_id_hint_for_new_run_identity() {
1312 let runtime = build_remote_runtime(
1313 RemoteEndpoint {
1314 backend: "test-remote".into(),
1315 base_url: "https://remote.example.com".into(),
1316 ..Default::default()
1317 },
1318 Arc::new(AtomicUsize::new(0)),
1319 );
1320
1321 runtime
1322 .run(
1323 RunRequest::new("remote-thread-dispatch-hint", vec![Message::user("hello")])
1324 .with_agent_id("remote-root")
1325 .with_dispatch_id_hint("external-task-1"),
1326 Arc::new(VecEventSink::new()),
1327 )
1328 .await
1329 .expect("dispatch id hint should create the run identity");
1330
1331 let store = runtime.thread_run_store().expect("store");
1332 let run = store
1333 .load_run("external-task-1")
1334 .await
1335 .expect("load hinted run")
1336 .expect("hinted run");
1337 assert_eq!(run.thread_id, "remote-thread-dispatch-hint");
1338 assert_eq!(run.status, RunStatus::Done);
1339 }
1340
    // A trace-only dispatch id on the request must not prevent reuse of an
    // existing local waiting run: the run resumes under its original run id and
    // no new run is created under the dispatch id.
    #[tokio::test]
    async fn run_trace_dispatch_id_does_not_block_local_waiting_reuse() {
        let store = Arc::new(InMemoryStore::new());
        let llm = Arc::new(ScriptedLlm::new(vec![StreamResult {
            content: vec![ContentBlock::text("continued")],
            tool_calls: vec![],
            usage: None,
            stop_reason: Some(StopReason::EndTurn),
            has_incomplete_tool_calls: false,
        }]));
        let resolver = Arc::new(FixedResolver {
            agent: ResolvedAgent::new("agent", "m", "sys", llm),
            plugins: vec![],
        });
        let runtime = AgentRuntime::new(resolver)
            .with_thread_run_store(store.clone() as Arc<dyn ThreadRunStore>);
        // Seed a run waiting on background tasks whose waiting state already
        // records the same dispatch id used on the request below.
        store
            .checkpoint(
                "thread-default-hint",
                &[Message::user("waiting")],
                &RunRecord {
                    run_id: "waiting-run".into(),
                    thread_id: "thread-default-hint".into(),
                    agent_id: "agent".into(),
                    parent_run_id: None,
                    request: None,
                    input: None,
                    output: None,
                    status: RunStatus::Waiting,
                    termination_reason: None,
                    final_output: None,
                    error_payload: None,
                    dispatch_id: None,
                    session_id: None,
                    transport_request_id: None,
                    waiting: Some(RunWaitingState {
                        reason: WaitingReason::BackgroundTasks,
                        ticket_ids: Vec::new(),
                        tickets: Vec::new(),
                        since_dispatch_id: Some("mailbox-dispatch-1".into()),
                        message: Some("waiting for background work".into()),
                    }),
                    outcome: None,
                    created_at: 1,
                    started_at: None,
                    finished_at: None,
                    updated_at: 1,
                    steps: 1,
                    input_tokens: 0,
                    output_tokens: 0,
                    state: None,
                },
            )
            .await
            .expect("seed waiting run");

        let result = runtime
            .run(
                RunRequest::new("thread-default-hint", vec![Message::user("resume")])
                    .with_agent_id("agent")
                    .with_trace_dispatch_id("mailbox-dispatch-1"),
                Arc::new(VecEventSink::new()),
            )
            .await
            .expect("default dispatch trace should allow waiting reuse");

        // The seeded waiting run is resumed; the dispatch id must stay trace-only.
        assert_eq!(result.run_id, "waiting-run");
        assert!(
            store
                .load_run("mailbox-dispatch-1")
                .await
                .expect("load default hint run")
                .is_none(),
            "default dispatch trace must not create a new run when a local waiting run is reusable"
        );
    }

    // Same reuse guarantee for a run waiting on tool permission (structured
    // waiting state with ticket ids): the waiting run resumes and the trace
    // dispatch id never becomes a run of its own.
    #[tokio::test]
    async fn run_reuses_structured_tool_permission_waiting_run() {
        let store = Arc::new(InMemoryStore::new());
        let llm = Arc::new(ScriptedLlm::new(vec![StreamResult {
            content: vec![ContentBlock::text("approved continuation")],
            tool_calls: vec![],
            usage: None,
            stop_reason: Some(StopReason::EndTurn),
            has_incomplete_tool_calls: false,
        }]));
        let resolver = Arc::new(FixedResolver {
            agent: ResolvedAgent::new("agent", "m", "sys", llm),
            plugins: vec![],
        });
        let runtime = AgentRuntime::new(resolver)
            .with_thread_run_store(store.clone() as Arc<dyn ThreadRunStore>);
        // Seed a run waiting on tool permission for call-1.
        store
            .checkpoint(
                "thread-tool-permission",
                &[Message::user("waiting")],
                &RunRecord {
                    run_id: "waiting-tool-run".into(),
                    thread_id: "thread-tool-permission".into(),
                    agent_id: "agent".into(),
                    parent_run_id: None,
                    request: None,
                    input: None,
                    output: None,
                    status: RunStatus::Waiting,
                    termination_reason: None,
                    final_output: None,
                    error_payload: None,
                    dispatch_id: None,
                    session_id: None,
                    transport_request_id: None,
                    waiting: Some(RunWaitingState {
                        reason: WaitingReason::ToolPermission,
                        ticket_ids: vec!["call-1".into()],
                        tickets: Vec::new(),
                        since_dispatch_id: None,
                        message: Some("approval required".into()),
                    }),
                    outcome: None,
                    created_at: 1,
                    started_at: None,
                    finished_at: None,
                    updated_at: 1,
                    steps: 1,
                    input_tokens: 0,
                    output_tokens: 0,
                    state: None,
                },
            )
            .await
            .expect("seed waiting run");

        let result = runtime
            .run(
                RunRequest::new("thread-tool-permission", vec![Message::user("approved")])
                    .with_agent_id("agent")
                    .with_trace_dispatch_id("mailbox-dispatch-tool"),
                Arc::new(VecEventSink::new()),
            )
            .await
            .expect("structured waiting run should be reusable");

        assert_eq!(result.run_id, "waiting-tool-run");
        assert!(
            store
                .load_run("mailbox-dispatch-tool")
                .await
                .expect("load default hint run")
                .is_none(),
            "default dispatch trace must stay trace-only when a structured waiting run is reusable"
        );
    }
1494
    // For a brand-new run, a trace dispatch id is recorded on the run identity's
    // trace but must not become the canonical run id.
    #[tokio::test]
    async fn run_trace_dispatch_id_is_trace_not_canonical_run_id_for_new_run() {
        let llm = Arc::new(ScriptedLlm::new(vec![StreamResult {
            content: vec![ContentBlock::text("new run")],
            tool_calls: vec![],
            usage: None,
            stop_reason: Some(StopReason::EndTurn),
            has_incomplete_tool_calls: false,
        }]));
        let resolver = Arc::new(FixedResolver {
            agent: ResolvedAgent::new("agent", "m", "sys", llm),
            plugins: vec![],
        });
        let runtime = AgentRuntime::new(resolver);
        let sink = Arc::new(VecEventSink::new());

        let result = runtime
            .run(
                RunRequest::new("thread-default-new", vec![Message::user("start")])
                    .with_agent_id("agent")
                    .with_trace_dispatch_id("mailbox-dispatch-new"),
                sink.clone(),
            )
            .await
            .expect("run should succeed");

        // The canonical run id was freshly generated, not the dispatch id.
        assert_ne!(result.run_id, "mailbox-dispatch-new");
        // The RunStart event must carry the trace dispatch id on its identity.
        let start = sink
            .events()
            .into_iter()
            .find_map(|event| match event {
                AgentEvent::RunStart {
                    run_id, identity, ..
                } => Some((run_id, identity)),
                _ => None,
            })
            .expect("run start event should be emitted");
        assert_eq!(start.0, result.run_id);
        assert_eq!(
            start.1.and_then(|identity| identity.trace.dispatch_id),
            Some("mailbox-dispatch-new".into())
        );
    }
1538
    // When continuing a specific run, the runtime must restore that run's
    // persisted state — not the thread's most recent run state.
    // NOTE(review): this remote test is not gated on `#[cfg(feature = "a2a")]`
    // unlike several siblings — confirm that is intentional.
    #[tokio::test]
    async fn run_non_local_continuation_uses_requested_run_state_not_latest() {
        let runtime = build_remote_runtime(
            RemoteEndpoint {
                backend: "test-remote".into(),
                base_url: "https://remote.example.com".into(),
                options: std::collections::BTreeMap::from([(
                    "supports_continuation".into(),
                    json!(true),
                )]),
                ..Default::default()
            },
            Arc::new(AtomicUsize::new(0)),
        );
        let store = runtime.thread_run_store().expect("store");
        // Distinct marker values let the final assertion tell the two states apart.
        let continued_state = PersistedState {
            revision: 1,
            extensions: HashMap::from([("marker".into(), json!("continued-run-state"))]),
        };
        let latest_state = PersistedState {
            revision: 2,
            extensions: HashMap::from([("marker".into(), json!("latest-run-state"))]),
        };

        // Older waiting run that the request below explicitly continues.
        store
            .checkpoint(
                "remote-thread-state",
                &[Message::user("waiting turn")],
                &RunRecord {
                    run_id: "continued-run".into(),
                    thread_id: "remote-thread-state".into(),
                    agent_id: "remote-root".into(),
                    parent_run_id: None,
                    request: None,
                    input: None,
                    output: None,
                    status: RunStatus::Waiting,
                    termination_reason: None,
                    final_output: None,
                    error_payload: None,
                    dispatch_id: None,
                    session_id: None,
                    transport_request_id: None,
                    waiting: None,
                    outcome: None,
                    created_at: 1,
                    started_at: None,
                    finished_at: None,
                    updated_at: 1,
                    steps: 1,
                    input_tokens: 0,
                    output_tokens: 0,
                    state: Some(continued_state),
                },
            )
            .await
            .expect("seed continued run");
        // Newer finished run whose state must NOT leak into the continuation.
        store
            .checkpoint(
                "remote-thread-state",
                &[Message::user("latest turn")],
                &RunRecord {
                    run_id: "latest-run".into(),
                    thread_id: "remote-thread-state".into(),
                    agent_id: "remote-root".into(),
                    parent_run_id: None,
                    request: None,
                    input: None,
                    output: None,
                    status: RunStatus::Done,
                    termination_reason: None,
                    final_output: None,
                    error_payload: None,
                    dispatch_id: None,
                    session_id: None,
                    transport_request_id: None,
                    waiting: None,
                    outcome: None,
                    created_at: 2,
                    started_at: None,
                    finished_at: None,
                    updated_at: 2,
                    steps: 1,
                    input_tokens: 0,
                    output_tokens: 0,
                    state: Some(latest_state),
                },
            )
            .await
            .expect("seed latest run");

        runtime
            .run(
                RunRequest::new("remote-thread-state", vec![Message::user("resume")])
                    .with_agent_id("remote-root")
                    .with_continue_run_id("continued-run"),
                Arc::new(VecEventSink::new()),
            )
            .await
            .expect("remote continuation should run");

        // The continued run keeps its own state marker after the continuation.
        let continued = store
            .load_run("continued-run")
            .await
            .expect("load continued run")
            .expect("continued run");
        assert_eq!(
            continued
                .state
                .as_ref()
                .and_then(|state| state.extensions.get("marker"))
                .and_then(Value::as_str),
            Some("continued-run-state")
        );
    }
1654
    // While a remote run without decision support is in flight, the runtime must
    // not expose a live decision channel: `send_decisions` returns false and the
    // run still completes normally.
    #[cfg(feature = "a2a")]
    #[tokio::test]
    async fn send_decisions_returns_false_for_remote_backend_without_decision_support() {
        let mut endpoint = RemoteEndpoint {
            backend: "test-remote".into(),
            base_url: "https://remote.example.com".into(),
            ..Default::default()
        };
        // The delay keeps the remote call in flight long enough to probe the channel.
        endpoint
            .options
            .insert("delay_ms".into(), serde_json::json!(100));
        let runtime = Arc::new(build_remote_runtime(
            endpoint,
            Arc::new(AtomicUsize::new(0)),
        ));
        let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);

        let run_task = {
            let runtime = runtime.clone();
            let sink = sink.clone();
            tokio::spawn(async move {
                runtime
                    .run(
                        RunRequest::new("remote-thread-live", vec![Message::user("hello")])
                            .with_agent_id("remote-root"),
                        sink,
                    )
                    .await
            })
        };

        // Let the spawned run start before probing the decision channel.
        tokio::task::yield_now().await;
        let sent = runtime.send_decisions(
            "remote-thread-live",
            vec![(
                "call-1".into(),
                ToolCallResume {
                    decision_id: "d1".into(),
                    action: ResumeDecisionAction::Resume,
                    result: Value::Null,
                    reason: None,
                    updated_at: 1,
                },
            )],
        );
        assert!(
            !sent,
            "remote backends without decision support must not expose a live decision channel"
        );

        let result = run_task
            .await
            .expect("join should succeed")
            .expect("run should succeed");
        assert_eq!(result.response, "remote root response");
    }
1711
1712 struct ToggleSuspendTool {
1713 calls: AtomicUsize,
1714 }
1715
1716 #[async_trait]
1717 impl Tool for ToggleSuspendTool {
1718 fn descriptor(&self) -> ToolDescriptor {
1719 ToolDescriptor::new("dangerous", "dangerous", "suspend then succeed")
1720 }
1721
1722 async fn execute(
1723 &self,
1724 args: Value,
1725 _ctx: &ToolCallContext,
1726 ) -> Result<ToolOutput, ToolError> {
1727 let n = self.calls.fetch_add(1, Ordering::SeqCst);
1728 if n == 0 {
1729 Ok(ToolResult::suspended("dangerous", "needs approval").into())
1730 } else {
1731 Ok(ToolResult::success_with_message("dangerous", args, "approved").into())
1732 }
1733 }
1734 }
1735
1736 struct EchoTool {
1737 calls: AtomicUsize,
1738 }
1739
1740 #[async_trait]
1741 impl Tool for EchoTool {
1742 fn descriptor(&self) -> ToolDescriptor {
1743 ToolDescriptor::new("echo", "echo", "echo success")
1744 }
1745
1746 async fn execute(
1747 &self,
1748 args: Value,
1749 _ctx: &ToolCallContext,
1750 ) -> Result<ToolOutput, ToolError> {
1751 self.calls.fetch_add(1, Ordering::SeqCst);
1752 Ok(ToolResult::success("echo", args).into())
1753 }
1754 }
1755
1756 struct RecordingToolPolicyHook {
1757 seen: Arc<Mutex<Vec<ToolPolicyContext>>>,
1758 }
1759
1760 #[async_trait]
1761 impl ToolPolicyHook for RecordingToolPolicyHook {
1762 async fn decide(
1763 &self,
1764 ctx: &ToolPolicyContext,
1765 ) -> Result<ToolPolicyDecision, awaken_contract::StateError> {
1766 self.seen.lock().expect("lock poisoned").push(ctx.clone());
1767 if ctx.run_mode == RunMode::Scheduled
1768 && ctx.adapter == AdapterKind::Acp
1769 && ctx.tool_name == "echo"
1770 {
1771 return Ok(ToolPolicyDecision::Deny {
1772 reason: "scheduled ACP echo denied".into(),
1773 });
1774 }
1775 Ok(ToolPolicyDecision::Allow)
1776 }
1777 }
1778
1779 struct RecordingToolPolicyPlugin {
1780 seen: Arc<Mutex<Vec<ToolPolicyContext>>>,
1781 }
1782
1783 impl Plugin for RecordingToolPolicyPlugin {
1784 fn descriptor(&self) -> PluginDescriptor {
1785 PluginDescriptor {
1786 name: "recording-tool-policy",
1787 }
1788 }
1789
1790 fn register(
1791 &self,
1792 registrar: &mut PluginRegistrar,
1793 ) -> Result<(), awaken_contract::StateError> {
1794 registrar.register_tool_policy_hook(
1795 "recording-tool-policy",
1796 RecordingToolPolicyHook {
1797 seen: Arc::clone(&self.seen),
1798 },
1799 )
1800 }
1801 }
1802
    /// Tool that spawns a short-lived background task (sleep for `delay_ms`,
    /// then report success) via the shared background task manager.
    struct SpawnShortBgTaskTool {
        manager: Arc<crate::extensions::background::BackgroundTaskManager>,
        delay_ms: u64,
    }

    #[async_trait]
    impl Tool for SpawnShortBgTaskTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("spawn_bg", "spawn_bg", "spawn short background task")
        }

        async fn execute(
            &self,
            _args: Value,
            ctx: &ToolCallContext,
        ) -> Result<ToolOutput, ToolError> {
            // Copy the delay out of `self` so the moved task closure does not
            // borrow the tool.
            let delay = self.delay_ms;
            self.manager
                .spawn(
                    &ctx.run_identity.thread_id,
                    "bg",
                    None,
                    "short task",
                    crate::extensions::background::TaskParentContext::default(),
                    move |_task_ctx| async move {
                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
                        crate::extensions::background::TaskResult::Success(json!({
                            "done": true,
                            "source": "background"
                        }))
                    },
                )
                .await
                .map_err(|e| ToolError::ExecutionFailed(e.to_string()))?;
            // The tool itself succeeds immediately; the task completes later.
            Ok(ToolResult::success("spawn_bg", json!({"spawned": true})).into())
        }
    }
1840
1841 struct RecordingLlm {
1842 responses: Mutex<Vec<StreamResult>>,
1843 requests: Arc<Mutex<Vec<InferenceRequest>>>,
1844 }
1845
1846 impl RecordingLlm {
1847 fn new(responses: Vec<StreamResult>, requests: Arc<Mutex<Vec<InferenceRequest>>>) -> Self {
1848 Self {
1849 responses: Mutex::new(responses),
1850 requests,
1851 }
1852 }
1853 }
1854
1855 #[async_trait]
1856 impl LlmExecutor for RecordingLlm {
1857 async fn execute(
1858 &self,
1859 request: InferenceRequest,
1860 ) -> Result<StreamResult, InferenceExecutionError> {
1861 self.requests.lock().expect("lock poisoned").push(request);
1862 let mut responses = self.responses.lock().expect("lock poisoned");
1863 Ok(responses.remove(0))
1864 }
1865
1866 fn name(&self) -> &str {
1867 "recording"
1868 }
1869 }
1870
1871 struct FixedResolver {
1872 agent: ResolvedAgent,
1873 plugins: Vec<Arc<dyn Plugin>>,
1874 }
1875
1876 impl AgentResolver for FixedResolver {
1877 fn resolve(&self, _agent_id: &str) -> Result<ResolvedAgent, crate::error::RuntimeError> {
1878 let mut agent = self.agent.clone();
1879 agent.env = build_agent_env(&self.plugins, &agent)?;
1880 Ok(agent)
1881 }
1882 }
1883
1884 struct ThreadCounterKey;
1885
1886 impl StateKey for ThreadCounterKey {
1887 const KEY: &'static str = "test.thread_counter";
1888 type Value = u32;
1889 type Update = u32;
1890
1891 fn apply(value: &mut Self::Value, update: Self::Update) {
1892 *value = update;
1893 }
1894 }
1895
1896 struct ThreadCounterPlugin;
1897
1898 impl Plugin for ThreadCounterPlugin {
1899 fn descriptor(&self) -> PluginDescriptor {
1900 PluginDescriptor {
1901 name: "test.thread-counter",
1902 }
1903 }
1904
1905 fn register(
1906 &self,
1907 registrar: &mut PluginRegistrar,
1908 ) -> Result<(), awaken_contract::StateError> {
1909 registrar.register_key::<ThreadCounterKey>(StateKeyOptions {
1910 persistent: true,
1911 scope: KeyScope::Thread,
1912 ..StateKeyOptions::default()
1913 })?;
1914 registrar.register_phase_hook(
1915 "test.thread-counter",
1916 awaken_contract::model::Phase::RunStart,
1917 ThreadCounterHook,
1918 )
1919 }
1920 }
1921
1922 struct ThreadCounterHook;
1923
1924 #[async_trait]
1925 impl PhaseHook for ThreadCounterHook {
1926 async fn run(
1927 &self,
1928 ctx: &PhaseContext,
1929 ) -> Result<StateCommand, awaken_contract::StateError> {
1930 let next = ctx.state::<ThreadCounterKey>().copied().unwrap_or(0) + 1;
1931 let mut cmd = StateCommand::new();
1932 cmd.update::<ThreadCounterKey>(next);
1933 Ok(cmd)
1934 }
1935 }
1936
1937 struct SequentialVisibilityKey;
1938
1939 impl StateKey for SequentialVisibilityKey {
1940 const KEY: &'static str = "test.sequential_visibility";
1941 type Value = bool;
1942 type Update = bool;
1943
1944 fn apply(value: &mut Self::Value, update: Self::Update) {
1945 *value = update;
1946 }
1947 }
1948
1949 struct SequentialVisibilityPlugin;
1950
1951 impl Plugin for SequentialVisibilityPlugin {
1952 fn descriptor(&self) -> PluginDescriptor {
1953 PluginDescriptor {
1954 name: "test.sequential-visibility",
1955 }
1956 }
1957
1958 fn register(
1959 &self,
1960 registrar: &mut PluginRegistrar,
1961 ) -> Result<(), awaken_contract::StateError> {
1962 registrar.register_key::<SequentialVisibilityKey>(StateKeyOptions::default())?;
1963 registrar.register_phase_hook(
1964 "test.sequential-visibility",
1965 awaken_contract::model::Phase::AfterToolExecute,
1966 SequentialVisibilityHook,
1967 )
1968 }
1969 }
1970
1971 struct SequentialVisibilityHook;
1972
1973 #[async_trait]
1974 impl PhaseHook for SequentialVisibilityHook {
1975 async fn run(
1976 &self,
1977 ctx: &PhaseContext,
1978 ) -> Result<StateCommand, awaken_contract::StateError> {
1979 let mut cmd = StateCommand::new();
1980 if ctx.tool_name.as_deref() == Some("writer") {
1981 cmd.update::<SequentialVisibilityKey>(true);
1982 }
1983 Ok(cmd)
1984 }
1985 }
1986
1987 struct WriterTool;
1988
1989 #[async_trait]
1990 impl Tool for WriterTool {
1991 fn descriptor(&self) -> ToolDescriptor {
1992 ToolDescriptor::new("writer", "writer", "writes marker in hook")
1993 }
1994
1995 async fn execute(
1996 &self,
1997 _args: Value,
1998 _ctx: &ToolCallContext,
1999 ) -> Result<ToolOutput, ToolError> {
2000 Ok(ToolResult::success("writer", Value::Null).into())
2001 }
2002 }
2003
2004 struct ReaderTool {
2005 saw_marker: Arc<std::sync::atomic::AtomicBool>,
2006 }
2007
2008 #[async_trait]
2009 impl Tool for ReaderTool {
2010 fn descriptor(&self) -> ToolDescriptor {
2011 ToolDescriptor::new("reader", "reader", "reads marker from snapshot")
2012 }
2013
2014 async fn execute(
2015 &self,
2016 _args: Value,
2017 ctx: &ToolCallContext,
2018 ) -> Result<ToolOutput, ToolError> {
2019 let saw = ctx
2020 .snapshot
2021 .get::<SequentialVisibilityKey>()
2022 .copied()
2023 .unwrap_or(false);
2024 self.saw_marker.store(saw, Ordering::SeqCst);
2025 Ok(ToolResult::success("reader", Value::Null).into())
2026 }
2027 }
2028
2029 fn seeded_run_record(
2030 run_id: &str,
2031 thread_id: &str,
2032 agent_id: &str,
2033 state: Option<PersistedState>,
2034 ) -> RunRecord {
2035 RunRecord {
2036 run_id: run_id.to_string(),
2037 thread_id: thread_id.to_string(),
2038 agent_id: agent_id.to_string(),
2039 parent_run_id: None,
2040 request: None,
2041 input: None,
2042 output: None,
2043 status: RunStatus::Done,
2044 termination_reason: None,
2045 final_output: None,
2046 error_payload: None,
2047 dispatch_id: None,
2048 session_id: None,
2049 transport_request_id: None,
2050 waiting: None,
2051 outcome: None,
2052 created_at: 1,
2053 started_at: None,
2054 finished_at: None,
2055 updated_at: 1,
2056 steps: 1,
2057 input_tokens: 0,
2058 output_tokens: 0,
2059 state,
2060 }
2061 }
2062
2063 #[tokio::test]
2064 async fn run_to_completion_returns_final_result() {
2065 let llm = Arc::new(ScriptedLlm::new(vec![StreamResult {
2066 content: vec![ContentBlock::text("ok")],
2067 tool_calls: vec![],
2068 usage: None,
2069 stop_reason: Some(StopReason::EndTurn),
2070 has_incomplete_tool_calls: false,
2071 }]));
2072 let resolver = Arc::new(FixedResolver {
2073 agent: ResolvedAgent::new("agent", "m", "sys", llm),
2074 plugins: vec![],
2075 });
2076 let runtime = AgentRuntime::new(resolver);
2077
2078 let result = runtime
2079 .run_to_completion(
2080 RunRequest::new("thread-completion", vec![Message::user("hi")])
2081 .with_agent_id("agent"),
2082 )
2083 .await
2084 .expect("run should succeed");
2085
2086 assert_eq!(result.response, "ok");
2087 assert_eq!(
2088 result.termination,
2089 awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
2090 );
2091 }
2092
    // Per-request inference overrides: temperature and max_tokens must reach the
    // LLM executor, while `upstream_model` is stripped before the executor and
    // surfaces only as the model name on the InferenceComplete event.
    #[tokio::test]
    async fn run_request_overrides_are_forwarded_to_inference() {
        let llm = Arc::new(ScriptedLlm::new(vec![StreamResult {
            content: vec![ContentBlock::text("ok")],
            tool_calls: vec![],
            usage: Some(awaken_contract::contract::inference::TokenUsage {
                prompt_tokens: Some(11),
                completion_tokens: Some(7),
                ..Default::default()
            }),
            stop_reason: Some(StopReason::EndTurn),
            has_incomplete_tool_calls: false,
        }]));
        let resolver = Arc::new(FixedResolver {
            agent: ResolvedAgent::new("agent", "m", "sys", llm.clone()),
            plugins: vec![],
        });
        let runtime = AgentRuntime::new(resolver);
        let sink = Arc::new(VecEventSink::new());
        let override_req = InferenceOverride {
            upstream_model: Some("override-model".into()),
            temperature: Some(0.3),
            max_tokens: Some(77),
            ..Default::default()
        };

        let result = runtime
            .run(
                RunRequest::new("thread-ovr", vec![Message::user("hi")])
                    .with_agent_id("agent")
                    .with_overrides(override_req.clone()),
                sink.clone(),
            )
            .await
            .expect("run should succeed");

        assert_eq!(
            result.termination,
            awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
        );
        // Exactly one inference happened and carried the numeric overrides.
        let seen = llm.seen_overrides.lock().expect("lock poisoned");
        assert_eq!(seen.len(), 1);
        assert_eq!(
            seen[0].as_ref().and_then(|o| o.temperature),
            override_req.temperature
        );
        assert_eq!(
            seen[0].as_ref().and_then(|o| o.max_tokens),
            override_req.max_tokens
        );
        // The model override must NOT be passed down to the executor itself.
        assert!(
            seen[0]
                .as_ref()
                .and_then(|o| o.upstream_model.as_ref())
                .is_none()
        );
        // ...but it does appear as the reported model on InferenceComplete.
        let complete_model = sink.events().into_iter().find_map(|event| match event {
            AgentEvent::InferenceComplete { model, .. } => Some(model),
            _ => None,
        });
        assert_eq!(complete_model.as_deref(), Some("override-model"));
    }
2155
    // Live decision channel: while a run is suspended on a tool-permission
    // decision, `send_decisions` must deliver the resume, the run must finish
    // naturally, and a ToolCallResumed event must carry the final tool result.
    #[tokio::test]
    async fn send_decisions_resumes_waiting_run() {
        // Turn 1 issues the suspending tool call; turn 2 ends after approval.
        let llm = Arc::new(ScriptedLlm::new(vec![
            StreamResult {
                content: vec![ContentBlock::text("calling tool")],
                tool_calls: vec![awaken_contract::contract::message::ToolCall::new(
                    "c1",
                    "dangerous",
                    json!({"x": 1}),
                )],
                usage: None,
                stop_reason: Some(StopReason::ToolUse),
                has_incomplete_tool_calls: false,
            },
            StreamResult {
                content: vec![ContentBlock::text("finished")],
                tool_calls: vec![],
                usage: None,
                stop_reason: Some(StopReason::EndTurn),
                has_incomplete_tool_calls: false,
            },
        ]));
        let tool = Arc::new(ToggleSuspendTool {
            calls: AtomicUsize::new(0),
        });
        let resolver = Arc::new(FixedResolver {
            agent: ResolvedAgent::new("agent", "m", "sys", llm).with_tool(tool),
            plugins: vec![],
        });
        let runtime = Arc::new(AgentRuntime::new(resolver));
        let sink = Arc::new(VecEventSink::new());

        let run_task = {
            let runtime = Arc::clone(&runtime);
            let sink = sink.clone();
            tokio::spawn(async move {
                runtime
                    .run(
                        RunRequest::new("thread-live", vec![Message::user("go")])
                            .with_agent_id("agent"),
                        sink as Arc<dyn EventSink>,
                    )
                    .await
            })
        };

        // Retry until the run has suspended and its decision channel is open.
        let mut sent = false;
        for _ in 0..40 {
            if runtime.send_decisions(
                "thread-live",
                vec![(
                    "c1".into(),
                    ToolCallResume {
                        decision_id: "d1".into(),
                        action: ResumeDecisionAction::Resume,
                        result: Value::Null,
                        reason: None,
                        updated_at: 1,
                    },
                )],
            ) {
                sent = true;
                break;
            }
            tokio::task::yield_now().await;
        }
        assert!(sent, "should send decision while run is active");

        let result = run_task
            .await
            .expect("join should succeed")
            .expect("run should succeed");
        assert_eq!(
            result.termination,
            awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
        );

        // The resumed call must be reported with the tool's final result payload.
        let events = sink.take();
        assert!(
            events.iter().any(|event| {
                matches!(
                    event,
                    AgentEvent::ToolCallResumed { target_id, result }
                        if target_id == "c1" && result == &json!({"x": 1})
                )
            }),
            "resumed replay should emit ToolCallResumed with the final tool result: {events:?}"
        );
    }
2245
    // Run mode and adapter set on the request must reach the tool policy gate:
    // the recording hook sees exactly one context with those fields, denies the
    // call, the tool never executes, and the run terminates Blocked.
    #[tokio::test]
    async fn run_request_policy_context_reaches_tool_gate() {
        let llm = Arc::new(ScriptedLlm::new(vec![StreamResult {
            content: vec![ContentBlock::text("calling echo")],
            tool_calls: vec![awaken_contract::contract::message::ToolCall::new(
                "c1",
                "echo",
                json!({"message": "hello"}),
            )],
            usage: None,
            stop_reason: Some(StopReason::ToolUse),
            has_incomplete_tool_calls: false,
        }]));
        let tool = Arc::new(EchoTool {
            calls: AtomicUsize::new(0),
        });
        let seen = Arc::new(Mutex::new(Vec::new()));
        let resolver = Arc::new(FixedResolver {
            agent: ResolvedAgent::new("agent", "m", "sys", llm).with_tool(tool.clone()),
            plugins: vec![Arc::new(RecordingToolPolicyPlugin {
                seen: Arc::clone(&seen),
            })],
        });
        let runtime = AgentRuntime::new(resolver);

        let result = runtime
            .run(
                RunRequest::new("thread-policy", vec![Message::user("use echo")])
                    .with_agent_id("agent")
                    .with_run_mode(RunMode::Scheduled)
                    .with_adapter(AdapterKind::Acp),
                Arc::new(VecEventSink::new()),
            )
            .await
            .expect("run should reach policy hook");

        // The hook denies scheduled ACP echo calls; the run ends Blocked with
        // the hook's reason string.
        assert!(matches!(
            result.termination,
            TerminationReason::Blocked(ref reason) if reason == "scheduled ACP echo denied"
        ));
        assert_eq!(
            tool.calls.load(Ordering::SeqCst),
            0,
            "denied tool must not execute"
        );

        // Exactly one policy context was observed, carrying the request's
        // run mode / adapter plus the tool identity.
        let contexts = seen.lock().expect("lock poisoned");
        assert_eq!(contexts.len(), 1);
        let ctx = &contexts[0];
        assert_eq!(ctx.thread_id, "thread-policy");
        assert_eq!(ctx.run_mode, RunMode::Scheduled);
        assert_eq!(ctx.adapter, AdapterKind::Acp);
        assert_eq!(ctx.dispatch_id, None);
        assert_eq!(ctx.tool_name, "echo");
    }
2301
    /// While a run is suspended on a tool-permission decision, completed
    /// background-task events must be buffered — not used to resume the LLM —
    /// and only injected into the request once the decision arrives.
    #[tokio::test]
    async fn background_events_buffer_while_suspended_until_decision_arrives() {
        use awaken_contract::contract::message::{Role, Visibility};

        // Captures every InferenceRequest the runtime sends to the LLM.
        let requests = Arc::new(Mutex::new(Vec::new()));
        let llm = Arc::new(RecordingLlm::new(
            vec![
                // Turn 1: spawn a short background task AND call a tool that suspends.
                StreamResult {
                    content: vec![ContentBlock::text("start tools")],
                    tool_calls: vec![
                        awaken_contract::contract::message::ToolCall::new(
                            "bg1",
                            "spawn_bg",
                            json!({}),
                        ),
                        awaken_contract::contract::message::ToolCall::new(
                            "c1",
                            "dangerous",
                            json!({"x": 1}),
                        ),
                    ],
                    usage: None,
                    stop_reason: Some(StopReason::ToolUse),
                    has_incomplete_tool_calls: false,
                },
                // Turn 2: plain end-of-turn reply after the approval lands.
                StreamResult {
                    content: vec![ContentBlock::text("done after approval")],
                    tool_calls: vec![],
                    usage: None,
                    stop_reason: Some(StopReason::EndTurn),
                    has_incomplete_tool_calls: false,
                },
            ],
            requests.clone(),
        ));
        let manager = Arc::new(crate::extensions::background::BackgroundTaskManager::new());
        let resolver = Arc::new(FixedResolver {
            agent: ResolvedAgent::new("agent", "m", "sys", llm)
                // Finishes in 25ms — well before the 80ms observation window below.
                .with_tool(Arc::new(SpawnShortBgTaskTool {
                    manager: manager.clone(),
                    delay_ms: 25,
                }))
                .with_tool(Arc::new(ToggleSuspendTool {
                    calls: AtomicUsize::new(0),
                })),
            plugins: vec![Arc::new(
                crate::extensions::background::BackgroundTaskPlugin::new(manager),
            )],
        });
        let runtime = Arc::new(AgentRuntime::new(resolver));
        let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);

        // Drive the run on a separate task so the test can send a decision while
        // the run is suspended.
        let run_task = {
            let runtime = runtime.clone();
            let sink = sink.clone();
            tokio::spawn(async move {
                runtime
                    .run(
                        RunRequest::new("thread-bg-suspend", vec![Message::user("go")])
                            .with_agent_id("agent"),
                        sink,
                    )
                    .await
            })
        };

        // By 80ms the background task (25ms) has finished, yet the LLM must still
        // have seen only the initial request: its completion may not resume a run
        // that is suspended on the "dangerous" tool.
        tokio::time::sleep(std::time::Duration::from_millis(80)).await;
        assert_eq!(
            requests.lock().expect("lock poisoned").len(),
            1,
            "background completion must not resume the LLM before the suspended tool is decided"
        );

        // Approve the suspended tool call; this should release the run.
        let sent = runtime.send_decisions(
            "thread-bg-suspend",
            vec![(
                "c1".into(),
                ToolCallResume {
                    decision_id: "d1".into(),
                    action: ResumeDecisionAction::Resume,
                    result: Value::Null,
                    reason: None,
                    updated_at: 1,
                },
            )],
        );
        assert!(sent, "decision should reach the waiting run");

        let result = run_task
            .await
            .expect("join should succeed")
            .expect("run should succeed");
        assert_eq!(
            result.termination,
            awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
        );

        // Exactly one resume, and the buffered background event must appear in it
        // as an internal user-role message.
        let recorded = requests.lock().expect("lock poisoned");
        assert_eq!(
            recorded.len(),
            2,
            "run should resume exactly once after approval"
        );
        assert!(
            recorded[1].messages.iter().any(|message| {
                message.role == Role::User
                    && message.visibility == Visibility::Internal
                    && message.text().contains("background-task-event")
                    && message.text().contains("\"done\":true")
            }),
            "buffered background event should be injected into the resumed request"
        );
    }
2415
    /// When a new user message arrives while a run is suspended, the suspended
    /// tool calls must be superseded (cancelled, stripped from later history)
    /// while results of tools that already completed remain durable.
    #[tokio::test]
    async fn new_user_message_supersedes_suspended_calls_but_keeps_completed_results() {
        use awaken_contract::contract::lifecycle::RunStatus;
        use awaken_contract::contract::message::Role;
        use awaken_contract::contract::storage::ThreadStore;
        use awaken_stores::InMemoryStore;

        let llm = Arc::new(ScriptedLlm::new(vec![
            // Run 1: one tool completes ("echo"), one suspends ("dangerous").
            StreamResult {
                content: vec![ContentBlock::text("call tools")],
                tool_calls: vec![
                    awaken_contract::contract::message::ToolCall::new(
                        "c_echo",
                        "echo",
                        json!({"ok": true}),
                    ),
                    awaken_contract::contract::message::ToolCall::new(
                        "c_suspend",
                        "dangerous",
                        json!({"danger": true}),
                    ),
                ],
                usage: None,
                stop_reason: Some(StopReason::ToolUse),
                has_incomplete_tool_calls: false,
            },
            // Run 2: plain answer for the superseding user message.
            StreamResult {
                content: vec![ContentBlock::text("fresh answer")],
                tool_calls: vec![],
                usage: None,
                stop_reason: Some(StopReason::EndTurn),
                has_incomplete_tool_calls: false,
            },
        ]));
        let echo = Arc::new(EchoTool {
            calls: AtomicUsize::new(0),
        });
        let dangerous = Arc::new(ToggleSuspendTool {
            calls: AtomicUsize::new(0),
        });
        let resolver = Arc::new(FixedResolver {
            agent: ResolvedAgent::new("agent", "m", "sys", llm)
                .with_tool(echo.clone())
                .with_tool(dangerous.clone()),
            plugins: vec![],
        });
        let store = Arc::new(InMemoryStore::new());
        let runtime = Arc::new(
            AgentRuntime::new(resolver)
                .with_thread_run_store(store.clone() as Arc<dyn ThreadRunStore>),
        );
        let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);

        // First run is driven on its own task so it can park in the Waiting state.
        let first_run = {
            let runtime = runtime.clone();
            let sink = sink.clone();
            tokio::spawn(async move {
                runtime
                    .run(
                        RunRequest::new("thread-supersede", vec![Message::user("first")])
                            .with_agent_id("agent"),
                        sink,
                    )
                    .await
            })
        };

        // Poll the store until the run is durably Waiting on tool permission,
        // then verify the persisted waiting ticket describes the suspended call.
        let wait_deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
        loop {
            if let Some(run) = store
                .latest_run("thread-supersede")
                .await
                .expect("latest run lookup should succeed")
                && run.status == RunStatus::Waiting
                && run.waiting_reason() == Some(WaitingReason::ToolPermission)
            {
                let waiting = run.waiting.expect("waiting state should be durable");
                assert_eq!(waiting.ticket_ids, vec!["c_suspend"]);
                assert_eq!(waiting.tickets.len(), 1);
                assert_eq!(waiting.tickets[0].tool_call_id, "c_suspend");
                assert_eq!(waiting.tickets[0].tool_name, "dangerous");
                assert_eq!(waiting.tickets[0].arguments, json!({"danger": true}));
                break;
            }
            assert!(
                std::time::Instant::now() < wait_deadline,
                "timed out waiting for suspended checkpoint"
            );
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }

        // This is the cancellation path a new user message takes internally.
        assert!(
            runtime.cancel_and_wait_by_thread("thread-supersede").await,
            "new message path should be able to supersede the suspended run"
        );

        let first = first_run
            .await
            .expect("join should succeed")
            .expect("first run should terminate cleanly");
        assert_eq!(
            first.termination,
            awaken_contract::contract::lifecycle::TerminationReason::Cancelled
        );

        let second = runtime
            .run(
                RunRequest::new("thread-supersede", vec![Message::user("second")])
                    .with_agent_id("agent"),
                sink,
            )
            .await
            .expect("second run should succeed");
        assert_eq!(
            second.termination,
            awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
        );
        // Each tool ran exactly once overall: no replay of the completed echo,
        // no re-execution of the superseded dangerous call.
        assert_eq!(
            echo.calls.load(Ordering::SeqCst),
            1,
            "successful tool calls from the superseded run must not replay"
        );
        assert_eq!(
            dangerous.calls.load(Ordering::SeqCst),
            1,
            "suspended tool calls must be superseded instead of replayed on new user input"
        );

        // Durable history keeps the completed tool result but drops the
        // never-decided suspended call from assistant messages.
        let messages = ThreadStore::load_messages(&*store, "thread-supersede")
            .await
            .expect("load messages should succeed")
            .expect("thread messages should exist");
        assert!(
            messages.iter().any(|message| message.role == Role::Tool
                && message.tool_call_id.as_deref() == Some("c_echo")),
            "completed tool result should remain in durable history"
        );
        assert!(
            !messages
                .iter()
                .filter(|message| message.role == Role::Assistant)
                .filter_map(|message| message.tool_calls.as_ref())
                .flatten()
                .any(|call| call.id == "c_suspend"),
            "superseded suspended tool calls should be stripped from later history"
        );
    }
2563
2564 #[tokio::test]
2565 async fn sequential_tool_execution_sees_latest_state_between_calls() {
2566 let llm = Arc::new(ScriptedLlm::new(vec![
2567 StreamResult {
2568 content: vec![ContentBlock::text("tools")],
2569 tool_calls: vec![
2570 awaken_contract::contract::message::ToolCall::new("c1", "writer", json!({})),
2571 awaken_contract::contract::message::ToolCall::new("c2", "reader", json!({})),
2572 ],
2573 usage: None,
2574 stop_reason: Some(StopReason::ToolUse),
2575 has_incomplete_tool_calls: false,
2576 },
2577 StreamResult {
2578 content: vec![ContentBlock::text("done")],
2579 tool_calls: vec![],
2580 usage: None,
2581 stop_reason: Some(StopReason::EndTurn),
2582 has_incomplete_tool_calls: false,
2583 },
2584 ]));
2585 let saw_marker = Arc::new(std::sync::atomic::AtomicBool::new(false));
2586 let resolver = Arc::new(FixedResolver {
2587 agent: ResolvedAgent::new("agent", "m", "sys", llm)
2588 .with_tool(Arc::new(WriterTool))
2589 .with_tool(Arc::new(ReaderTool {
2590 saw_marker: saw_marker.clone(),
2591 })),
2592 plugins: vec![Arc::new(SequentialVisibilityPlugin)],
2593 });
2594 let runtime = AgentRuntime::new(resolver);
2595 let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);
2596
2597 let result = runtime
2598 .run(
2599 RunRequest::new("thread-seq-visibility", vec![Message::user("go")])
2600 .with_agent_id("agent"),
2601 sink.clone(),
2602 )
2603 .await
2604 .expect("run should succeed");
2605
2606 assert_eq!(
2607 result.termination,
2608 awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
2609 );
2610 assert!(
2611 saw_marker.load(Ordering::SeqCst),
2612 "second tool should observe state written after first tool"
2613 );
2614 }
2615
2616 #[tokio::test]
2617 async fn checkpoint_persists_state_and_thread_together() {
2618 let llm = Arc::new(ScriptedLlm::new(vec![StreamResult {
2619 content: vec![ContentBlock::text("ok")],
2620 tool_calls: vec![],
2621 usage: Some(awaken_contract::contract::inference::TokenUsage {
2622 prompt_tokens: Some(11),
2623 completion_tokens: Some(7),
2624 ..Default::default()
2625 }),
2626 stop_reason: Some(StopReason::EndTurn),
2627 has_incomplete_tool_calls: false,
2628 }]));
2629 let resolver = Arc::new(FixedResolver {
2630 agent: ResolvedAgent::new("agent", "m", "sys", llm),
2631 plugins: vec![],
2632 });
2633 let store = Arc::new(InMemoryStore::new());
2634 let runtime = AgentRuntime::new(resolver)
2635 .with_thread_run_store(store.clone() as Arc<dyn ThreadRunStore>);
2636 let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);
2637
2638 let result = runtime
2639 .run(
2640 RunRequest::new("thread-tx", vec![Message::user("hi")]).with_agent_id("agent"),
2641 sink.clone(),
2642 )
2643 .await
2644 .expect("run should succeed");
2645 assert_eq!(
2646 result.termination,
2647 awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
2648 );
2649
2650 let latest = store
2651 .latest_run("thread-tx")
2652 .await
2653 .expect("latest run lookup")
2654 .expect("run persisted");
2655 assert_eq!(latest.thread_id, "thread-tx");
2656 assert!(latest.state.is_some(), "state snapshot should be persisted");
2657 assert_eq!(latest.input_tokens, 11);
2658 assert_eq!(latest.output_tokens, 7);
2659
2660 let msgs = store
2661 .load_messages("thread-tx")
2662 .await
2663 .expect("load messages")
2664 .expect("thread should exist");
2665 assert!(!msgs.is_empty());
2666 }
2667
2668 #[tokio::test]
2669 async fn run_request_without_agent_id_prefers_latest_thread_state_agent() {
2670 let llm = Arc::new(ScriptedLlm::new(vec![StreamResult {
2671 content: vec![ContentBlock::text("ok")],
2672 tool_calls: vec![],
2673 usage: None,
2674 stop_reason: Some(StopReason::EndTurn),
2675 has_incomplete_tool_calls: false,
2676 }]));
2677 let resolver = Arc::new(FixedResolver {
2678 agent: ResolvedAgent::new("agent", "m", "sys", llm),
2679 plugins: vec![],
2680 });
2681 let store = Arc::new(InMemoryStore::new());
2682
2683 let mut extensions = HashMap::new();
2684 extensions.insert(
2685 <ActiveAgentIdKey as StateKey>::KEY.to_string(),
2686 Value::String("agent-from-state".into()),
2687 );
2688 store
2689 .create_run(&seeded_run_record(
2690 "seed-1",
2691 "thread-infer-state",
2692 "agent-from-record",
2693 Some(PersistedState {
2694 revision: 1,
2695 extensions,
2696 }),
2697 ))
2698 .await
2699 .expect("seed run record");
2700
2701 let runtime = AgentRuntime::new(resolver)
2702 .with_thread_run_store(store.clone() as Arc<dyn ThreadRunStore>);
2703 let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);
2704
2705 runtime
2706 .run(
2707 RunRequest::new("thread-infer-state", vec![Message::user("hi")]),
2708 sink.clone(),
2709 )
2710 .await
2711 .expect("run should succeed");
2712
2713 let latest = store
2714 .latest_run("thread-infer-state")
2715 .await
2716 .expect("latest run lookup")
2717 .expect("run persisted");
2718 assert_eq!(latest.agent_id, "agent-from-state");
2719 }
2720
2721 #[tokio::test]
2722 async fn run_request_without_agent_id_falls_back_to_latest_run_record_agent_id() {
2723 let llm = Arc::new(ScriptedLlm::new(vec![StreamResult {
2724 content: vec![ContentBlock::text("ok")],
2725 tool_calls: vec![],
2726 usage: None,
2727 stop_reason: Some(StopReason::EndTurn),
2728 has_incomplete_tool_calls: false,
2729 }]));
2730 let resolver = Arc::new(FixedResolver {
2731 agent: ResolvedAgent::new("agent", "m", "sys", llm),
2732 plugins: vec![],
2733 });
2734 let store = Arc::new(InMemoryStore::new());
2735
2736 store
2737 .create_run(&seeded_run_record(
2738 "seed-2",
2739 "thread-infer-record",
2740 "agent-from-record",
2741 None,
2742 ))
2743 .await
2744 .expect("seed run record");
2745
2746 let runtime = AgentRuntime::new(resolver)
2747 .with_thread_run_store(store.clone() as Arc<dyn ThreadRunStore>);
2748 let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);
2749
2750 runtime
2751 .run(
2752 RunRequest::new("thread-infer-record", vec![Message::user("hi")]),
2753 sink.clone(),
2754 )
2755 .await
2756 .expect("run should succeed");
2757
2758 let latest = store
2759 .latest_run("thread-infer-record")
2760 .await
2761 .expect("latest run lookup")
2762 .expect("run persisted");
2763 assert_eq!(latest.agent_id, "agent-from-record");
2764 }
2765
2766 #[tokio::test]
2767 async fn thread_scoped_state_restores_before_run_start_hooks() {
2768 let llm = Arc::new(ScriptedLlm::new(vec![
2769 StreamResult {
2770 content: vec![ContentBlock::text("ok-1")],
2771 tool_calls: vec![],
2772 usage: None,
2773 stop_reason: Some(StopReason::EndTurn),
2774 has_incomplete_tool_calls: false,
2775 },
2776 StreamResult {
2777 content: vec![ContentBlock::text("ok-2")],
2778 tool_calls: vec![],
2779 usage: None,
2780 stop_reason: Some(StopReason::EndTurn),
2781 has_incomplete_tool_calls: false,
2782 },
2783 ]));
2784 let resolver = Arc::new(FixedResolver {
2785 agent: ResolvedAgent::new("agent", "m", "sys", llm),
2786 plugins: vec![Arc::new(ThreadCounterPlugin)],
2787 });
2788 let store = Arc::new(InMemoryStore::new());
2789 let runtime = AgentRuntime::new(resolver)
2790 .with_thread_run_store(store.clone() as Arc<dyn ThreadRunStore>);
2791 let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);
2792
2793 runtime
2794 .run(
2795 RunRequest::new("thread-counter", vec![Message::user("first")])
2796 .with_agent_id("agent"),
2797 sink.clone(),
2798 )
2799 .await
2800 .expect("first run should succeed");
2801
2802 runtime
2803 .run(
2804 RunRequest::new("thread-counter", vec![Message::user("second")])
2805 .with_agent_id("agent"),
2806 sink.clone(),
2807 )
2808 .await
2809 .expect("second run should succeed");
2810
2811 let runs = store
2812 .list_runs(&RunQuery {
2813 thread_id: Some("thread-counter".into()),
2814 ..RunQuery::default()
2815 })
2816 .await
2817 .expect("run list lookup");
2818
2819 let max_counter = runs
2820 .items
2821 .iter()
2822 .filter_map(|record| record.state.as_ref())
2823 .filter_map(|persisted| persisted.extensions.get(ThreadCounterKey::KEY))
2824 .filter_map(serde_json::Value::as_u64)
2825 .max()
2826 .expect("thread counter should be persisted");
2827 assert_eq!(max_counter, 2, "counter should continue across runs");
2828 }
2829
    /// Test executor whose first streamed response is cut off mid tool-call
    /// (stops with `MaxTokens`) and whose subsequent calls replay
    /// `followup_responses` front-to-back.
    struct TruncatingLlm {
        // Number of execute_stream calls so far; 0 selects the truncated stream.
        call_count: AtomicUsize,
        // Scripted responses consumed in order after the first (truncated) call.
        followup_responses: Mutex<Vec<StreamResult>>,
        // `upstream_model` of every incoming request, in call order — lets tests
        // assert that model overrides survive continuation retries.
        upstream_models_seen: Mutex<Vec<String>>,
    }
2842
2843 impl TruncatingLlm {
2844 fn new(followup_responses: Vec<StreamResult>) -> Self {
2845 Self {
2846 call_count: AtomicUsize::new(0),
2847 followup_responses: Mutex::new(followup_responses),
2848 upstream_models_seen: Mutex::new(Vec::new()),
2849 }
2850 }
2851 }
2852
    #[async_trait]
    impl LlmExecutor for TruncatingLlm {
        /// Non-streaming entry point; the runtime under test only streams, so
        /// reaching this is a bug in the test setup.
        async fn execute(
            &self,
            _request: InferenceRequest,
        ) -> Result<StreamResult, InferenceExecutionError> {
            unreachable!("execute_stream is overridden");
        }

        /// Streams a truncated response on the first call (text + a tool call
        /// whose JSON arguments are cut off, ending in `StopReason::MaxTokens`),
        /// then replays `followup_responses` on later calls, defaulting to a
        /// plain EndTurn reply once the script is exhausted.
        fn execute_stream(
            &self,
            request: InferenceRequest,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<
                        Output = Result<
                            awaken_contract::contract::executor::InferenceStream,
                            InferenceExecutionError,
                        >,
                    > + Send
                    + '_,
            >,
        > {
            use awaken_contract::contract::executor::{InferenceStream, LlmStreamEvent};
            use awaken_contract::contract::inference::TokenUsage;

            Box::pin(async move {
                // Record the requested upstream model so tests can assert that
                // overrides are preserved on continuation retries.
                self.upstream_models_seen
                    .lock()
                    .expect("lock poisoned")
                    .push(request.upstream_model.clone());
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    // First call: incomplete tool-call arguments + MaxTokens stop
                    // simulate a response truncated by the token limit.
                    let events: Vec<Result<LlmStreamEvent, InferenceExecutionError>> = vec![
                        Ok(LlmStreamEvent::TextDelta("partial ".into())),
                        Ok(LlmStreamEvent::ToolCallStart {
                            id: "tc1".into(),
                            name: "calculator".into(),
                        }),
                        Ok(LlmStreamEvent::ToolCallDelta {
                            id: "tc1".into(),
                            // Deliberately invalid/unterminated JSON.
                            args_delta: r#"{"expr": "1+1"#.into(),
                        }),
                        Ok(LlmStreamEvent::Usage(TokenUsage {
                            prompt_tokens: Some(50),
                            completion_tokens: Some(100),
                            ..Default::default()
                        })),
                        Ok(LlmStreamEvent::Stop(StopReason::MaxTokens)),
                    ];
                    Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
                } else {
                    // Later calls: pop the next scripted follow-up, or fall back
                    // to a canned EndTurn response when none remain.
                    let mut followups = self.followup_responses.lock().expect("lock poisoned");
                    let result = if followups.is_empty() {
                        StreamResult {
                            content: vec![ContentBlock::text("final response")],
                            tool_calls: vec![],
                            usage: None,
                            stop_reason: Some(StopReason::EndTurn),
                            has_incomplete_tool_calls: false,
                        }
                    } else {
                        followups.remove(0)
                    };
                    let events =
                        awaken_contract::contract::executor::collected_to_stream_events(result);
                    Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
                }
            })
        }

        fn name(&self) -> &str {
            "truncating"
        }
    }
2931
2932 #[tokio::test]
2933 async fn truncation_recovery_continues_on_max_tokens() {
2934 let llm = Arc::new(TruncatingLlm::new(vec![StreamResult {
2937 content: vec![ContentBlock::text("completed response")],
2938 tool_calls: vec![],
2939 usage: None,
2940 stop_reason: Some(StopReason::EndTurn),
2941 has_incomplete_tool_calls: false,
2942 }]));
2943 let resolver = Arc::new(FixedResolver {
2944 agent: ResolvedAgent::new("agent", "m", "sys", llm.clone())
2945 .with_max_continuation_retries(2),
2946 plugins: vec![],
2947 });
2948 let runtime = AgentRuntime::new(resolver);
2949 let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);
2950
2951 let result = runtime
2952 .run(
2953 RunRequest::new("thread-trunc", vec![Message::user("hi")]).with_agent_id("agent"),
2954 sink.clone(),
2955 )
2956 .await
2957 .expect("run should succeed");
2958
2959 assert_eq!(
2960 result.termination,
2961 awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
2962 );
2963 assert_eq!(result.response, "completed response");
2965 assert_eq!(llm.call_count.load(Ordering::SeqCst), 2);
2967 }
2968
2969 #[tokio::test]
2970 async fn text_truncation_recovery_continues_on_max_tokens() {
2971 let llm = Arc::new(ScriptedLlm::new(vec![
2972 StreamResult {
2973 content: vec![ContentBlock::text("partial ")],
2974 tool_calls: vec![],
2975 usage: None,
2976 stop_reason: Some(StopReason::MaxTokens),
2977 has_incomplete_tool_calls: false,
2978 },
2979 StreamResult {
2980 content: vec![ContentBlock::text("completed")],
2981 tool_calls: vec![],
2982 usage: None,
2983 stop_reason: Some(StopReason::EndTurn),
2984 has_incomplete_tool_calls: false,
2985 },
2986 ]));
2987 let resolver = Arc::new(FixedResolver {
2988 agent: ResolvedAgent::new("agent", "m", "sys", llm.clone())
2989 .with_max_continuation_retries(2),
2990 plugins: vec![],
2991 });
2992 let runtime = AgentRuntime::new(resolver);
2993 let sink = Arc::new(VecEventSink::new());
2994
2995 let result = runtime
2996 .run(
2997 RunRequest::new("thread-text-trunc", vec![Message::user("hi")])
2998 .with_agent_id("agent"),
2999 sink.clone(),
3000 )
3001 .await
3002 .expect("run should succeed");
3003
3004 assert_eq!(
3005 result.termination,
3006 awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
3007 );
3008 assert_eq!(result.response, "completed");
3009 assert_eq!(llm.seen_overrides.lock().expect("lock poisoned").len(), 2);
3010
3011 let text_deltas: Vec<String> = sink
3012 .events()
3013 .into_iter()
3014 .filter_map(|event| match event {
3015 AgentEvent::TextDelta { delta } => Some(delta),
3016 _ => None,
3017 })
3018 .collect();
3019 assert_eq!(text_deltas, vec!["partial ", "completed"]);
3020 }
3021
3022 #[tokio::test]
3023 async fn truncation_recovery_preserves_model_override() {
3024 let llm = Arc::new(TruncatingLlm::new(vec![StreamResult {
3025 content: vec![ContentBlock::text("completed response")],
3026 tool_calls: vec![],
3027 usage: None,
3028 stop_reason: Some(StopReason::EndTurn),
3029 has_incomplete_tool_calls: false,
3030 }]));
3031 let resolver = Arc::new(FixedResolver {
3032 agent: ResolvedAgent::new("agent", "base-model", "sys", llm.clone())
3033 .with_max_continuation_retries(2),
3034 plugins: vec![],
3035 });
3036 let runtime = AgentRuntime::new(resolver);
3037 let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);
3038
3039 let result = runtime
3040 .run(
3041 RunRequest::new("thread-trunc-override", vec![Message::user("hi")])
3042 .with_agent_id("agent")
3043 .with_overrides(InferenceOverride {
3044 upstream_model: Some("override-model".into()),
3045 ..Default::default()
3046 }),
3047 sink,
3048 )
3049 .await
3050 .expect("run should succeed");
3051
3052 assert_eq!(
3053 result.termination,
3054 awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
3055 );
3056 assert_eq!(
3057 llm.upstream_models_seen
3058 .lock()
3059 .expect("lock poisoned")
3060 .clone(),
3061 vec!["override-model".to_string(), "override-model".to_string()]
3062 );
3063 }
3064
    /// When every continuation attempt is itself truncated, recovery must stop
    /// after `max_continuation_retries` extra calls and end the run with
    /// whatever text was accumulated.
    #[tokio::test]
    async fn truncation_recovery_gives_up_after_max_retries() {
        /// Executor whose every stream ends in MaxTokens with an incomplete
        /// tool call — recovery can never succeed.
        struct AlwaysTruncatingLlm {
            // Total execute_stream invocations, asserted at the end of the test.
            call_count: AtomicUsize,
        }

        #[async_trait]
        impl LlmExecutor for AlwaysTruncatingLlm {
            // Never called: the runtime streams, and execute_stream is overridden.
            async fn execute(
                &self,
                _request: InferenceRequest,
            ) -> Result<StreamResult, InferenceExecutionError> {
                unreachable!("execute_stream is overridden");
            }

            fn execute_stream(
                &self,
                _request: InferenceRequest,
            ) -> std::pin::Pin<
                Box<
                    dyn std::future::Future<
                            Output = Result<
                                awaken_contract::contract::executor::InferenceStream,
                                InferenceExecutionError,
                            >,
                        > + Send
                        + '_,
                >,
            > {
                use awaken_contract::contract::executor::{InferenceStream, LlmStreamEvent};
                use awaken_contract::contract::inference::TokenUsage;

                Box::pin(async move {
                    self.call_count.fetch_add(1, Ordering::SeqCst);
                    // Unique tool-call id per call, but always unterminated JSON
                    // arguments followed by a MaxTokens stop.
                    let events: Vec<Result<LlmStreamEvent, InferenceExecutionError>> = vec![
                        Ok(LlmStreamEvent::TextDelta("truncated ".into())),
                        Ok(LlmStreamEvent::ToolCallStart {
                            id: format!("tc{}", self.call_count.load(Ordering::SeqCst)),
                            name: "calculator".into(),
                        }),
                        Ok(LlmStreamEvent::ToolCallDelta {
                            id: format!("tc{}", self.call_count.load(Ordering::SeqCst)),
                            args_delta: r#"{"incomplete"#.into(),
                        }),
                        Ok(LlmStreamEvent::Usage(TokenUsage {
                            prompt_tokens: Some(50),
                            completion_tokens: Some(100),
                            ..Default::default()
                        })),
                        Ok(LlmStreamEvent::Stop(StopReason::MaxTokens)),
                    ];
                    Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
                })
            }

            fn name(&self) -> &str {
                "always_truncating"
            }
        }

        let llm = Arc::new(AlwaysTruncatingLlm {
            call_count: AtomicUsize::new(0),
        });
        let resolver = Arc::new(FixedResolver {
            agent: ResolvedAgent::new("agent", "m", "sys", llm.clone())
                .with_max_continuation_retries(2),
            plugins: vec![],
        });
        let runtime = AgentRuntime::new(resolver);
        let sink: Arc<dyn EventSink> = Arc::new(NullEventSink);

        let result = runtime
            .run(
                RunRequest::new("thread-trunc-max", vec![Message::user("hi")])
                    .with_agent_id("agent"),
                sink.clone(),
            )
            .await
            .expect("run should succeed");

        // Initial call + 2 retries = 3 total; then recovery gives up.
        assert_eq!(llm.call_count.load(Ordering::SeqCst), 3);
        // Giving up still terminates naturally with the accumulated text of the
        // first (truncated) response.
        assert_eq!(
            result.termination,
            awaken_contract::contract::lifecycle::TerminationReason::NaturalEnd
        );
        assert_eq!(result.response, "truncated ");
    }
3159
3160 mod strip_unpaired {
3163 use super::super::strip_unpaired_tool_calls;
3164 use awaken_contract::contract::message::{Message, Role, ToolCall};
3165
3166 fn assistant_with_calls(text: &str, call_ids: &[&str]) -> Message {
3167 let mut msg = Message::assistant(text);
3168 msg.tool_calls = Some(
3169 call_ids
3170 .iter()
3171 .map(|id| ToolCall {
3172 id: id.to_string(),
3173 name: "test_tool".into(),
3174 arguments: serde_json::json!({}),
3175 })
3176 .collect(),
3177 );
3178 msg
3179 }
3180
3181 fn tool_response(call_id: &str) -> Message {
3182 Message::tool(call_id, "result")
3183 }
3184
3185 #[test]
3186 fn paired_calls_unchanged() {
3187 let mut msgs = vec![
3188 Message::user("hi"),
3189 assistant_with_calls("calling", &["tc1"]),
3190 tool_response("tc1"),
3191 Message::assistant("done"),
3192 ];
3193 let original_len = msgs.len();
3194 strip_unpaired_tool_calls(&mut msgs);
3195 assert_eq!(msgs.len(), original_len);
3196 assert!(msgs[1].tool_calls.as_ref().unwrap().len() == 1);
3198 }
3199
3200 #[test]
3201 fn trailing_unpaired_calls_stripped() {
3202 let mut msgs = vec![
3203 Message::user("hi"),
3204 assistant_with_calls("calling", &["tc1", "tc2"]),
3205 tool_response("tc1"),
3206 ];
3208 strip_unpaired_tool_calls(&mut msgs);
3209 let calls = msgs[1].tool_calls.as_ref().unwrap();
3210 assert_eq!(calls.len(), 1);
3211 assert_eq!(calls[0].id, "tc1");
3212 }
3213
3214 #[test]
3215 fn all_unpaired_removes_tool_calls_field() {
3216 let mut msgs = vec![
3217 Message::user("hi"),
3218 assistant_with_calls("", &["tc1"]),
3219 ];
3221 strip_unpaired_tool_calls(&mut msgs);
3222 assert_eq!(msgs.len(), 1);
3224 assert_eq!(msgs[0].role, Role::User);
3225 }
3226
3227 #[test]
3228 fn middle_paired_not_affected() {
3229 let mut msgs = vec![
3230 Message::user("first"),
3231 assistant_with_calls("first call", &["tc1"]),
3232 tool_response("tc1"),
3233 Message::user("second"),
3234 assistant_with_calls("", &["tc2"]),
3235 ];
3237 strip_unpaired_tool_calls(&mut msgs);
3238 assert_eq!(msgs[1].tool_calls.as_ref().unwrap().len(), 1);
3240 assert_eq!(msgs.len(), 4); }
3243
3244 #[test]
3245 fn no_tool_calls_is_noop() {
3246 let mut msgs = vec![Message::user("hi"), Message::assistant("hello")];
3247 strip_unpaired_tool_calls(&mut msgs);
3248 assert_eq!(msgs.len(), 2);
3249 }
3250
3251 #[test]
3252 fn empty_messages_is_noop() {
3253 let mut msgs: Vec<Message> = vec![];
3254 strip_unpaired_tool_calls(&mut msgs);
3255 assert!(msgs.is_empty());
3256 }
3257 }
3258}