1use async_trait::async_trait;
2use futures::stream::StreamExt;
3use serde_json::{json, Value};
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::mpsc;
7use tokio_util::sync::CancellationToken;
8
9use crate::compaction::{estimate_messages_tokens, CompactionContext, CompactionStrategy};
10use crate::event::{HarnessInternalEvent, HarnessUsage, NativeHarnessError, NativeTurnInput};
11use crate::model::{
12 AssistantThinking, ChatMessage, ModelChunk, ModelClient, ModelClientError, ModelTurnInput,
13};
14use crate::runner::NativeHarness;
15use crate::tools::{
16 bounded::BoundedToolRuntime, ToolFailure, ToolFailureKind, ToolInvocation, ToolOutcome,
17 ToolRuntime, ToolRuntimeError,
18};
19
/// Settings that enable automatic history compaction in the agent loop.
///
/// Before each model call the loop asks `strategy.should_compact` whether the
/// current message history needs compacting for `context_window_tokens`; when
/// it does, `strategy.compact` is invoked with a `CompactionContext` that
/// carries `model_client`.
#[derive(Clone)]
pub struct CompactionPolicy {
    /// Decides when to compact and produces the compacted history.
    pub strategy: Arc<dyn CompactionStrategy>,
    /// Model client handed to the strategy through `CompactionContext`.
    pub model_client: Arc<dyn ModelClient>,
    /// Context window size, in tokens, passed to the strategy.
    pub context_window_tokens: u64,
}
34
/// Default idle timeout for a model stream: if no chunk arrives for this long
/// the step is failed as a stalled stream.
const DEFAULT_STREAM_IDLE_TIMEOUT: Duration = Duration::from_secs(90);
41
/// Default upper bound on stream attempts per step (the first attempt
/// included) when a stream fails before producing any output.
const DEFAULT_STREAM_MAX_ATTEMPTS: u32 = 6;
49
/// Native harness that alternates model calls and tool executions until the
/// model answers without tool calls, the step limit is hit, or the turn is
/// cancelled.
#[derive(Clone)]
pub struct AgentLoopHarness<M, R> {
    /// Model client used for every step of the loop.
    model: M,
    /// Tool runtime, wrapped in `BoundedToolRuntime` at construction.
    tools: BoundedToolRuntime<R>,
    /// Maximum number of model steps per turn; `0` disables the limit.
    max_steps: usize,
    /// Optional history compaction policy; `None` disables compaction.
    compaction: Option<CompactionPolicy>,
    /// Tool-choice setting forwarded to the model on each step.
    tool_choice: crate::model::ToolChoice,
    /// Parallel-tool-calls setting forwarded to the model on each step.
    parallel_tool_calls: Option<bool>,
    /// How long a model stream may stay silent before it counts as stalled.
    stream_idle_timeout: Duration,
    /// Upper bound on stream attempts per step (first attempt included).
    stream_max_attempts: u32,
}
64
65impl<M, R: ToolRuntime> AgentLoopHarness<M, R> {
66 pub fn new(model: M, tools: R) -> Self {
67 Self {
68 model,
69 tools: BoundedToolRuntime::new(tools),
70 max_steps: 8,
71 compaction: None,
72 tool_choice: crate::model::ToolChoice::Auto,
73 parallel_tool_calls: None,
74 stream_idle_timeout: DEFAULT_STREAM_IDLE_TIMEOUT,
75 stream_max_attempts: DEFAULT_STREAM_MAX_ATTEMPTS,
76 }
77 }
78
79 pub fn with_max_steps(mut self, max_steps: usize) -> Self {
84 self.max_steps = max_steps;
85 self
86 }
87
88 pub fn with_compaction(mut self, policy: CompactionPolicy) -> Self {
94 self.compaction = Some(policy);
95 self
96 }
97
98 pub fn with_tool_choice(mut self, choice: crate::model::ToolChoice) -> Self {
101 self.tool_choice = choice;
102 self
103 }
104
105 pub fn with_parallel_tool_calls(mut self, parallel: Option<bool>) -> Self {
109 self.parallel_tool_calls = parallel;
110 self
111 }
112
113 pub fn with_stream_resilience(mut self, idle_timeout: Duration, max_attempts: u32) -> Self {
120 self.stream_idle_timeout = idle_timeout;
121 self.stream_max_attempts = max_attempts.max(1);
122 self
123 }
124}
125
126#[async_trait]
127impl<M, R> NativeHarness for AgentLoopHarness<M, R>
128where
129 M: ModelClient + Clone + Send + Sync + 'static,
130 R: ToolRuntime + Clone + Send + Sync + 'static,
131{
132 async fn run_turn(
133 &self,
134 input: NativeTurnInput,
135 ) -> Result<mpsc::Receiver<Result<HarnessInternalEvent, NativeHarnessError>>, NativeHarnessError>
136 {
137 let (tx, rx) = mpsc::channel(16);
138 let model = self.model.clone();
139 let tools = self.tools.clone();
140 let max_steps = self.max_steps;
141 let compaction = self.compaction.clone();
142 let tool_choice = self.tool_choice.clone();
143 let parallel_tool_calls = self.parallel_tool_calls;
144 let stream_idle_timeout = self.stream_idle_timeout;
145 let stream_max_attempts = self.stream_max_attempts;
146
147 tokio::spawn(async move {
148 run_loop(
149 model,
150 tools,
151 RunLoopConfig {
152 max_steps,
153 compaction,
154 tool_choice,
155 parallel_tool_calls,
156 stream_idle_timeout,
157 stream_max_attempts,
158 },
159 input,
160 tx,
161 )
162 .await;
163 });
164
165 Ok(rx)
166 }
167}
168
169fn cancel_fired(token: Option<&CancellationToken>) -> bool {
171 token.is_some_and(|t| t.is_cancelled())
172}
173
/// Per-turn copy of the harness settings handed to `run_loop`.
struct RunLoopConfig {
    /// Maximum number of model steps; `0` disables the limit.
    max_steps: usize,
    /// Optional history compaction policy.
    compaction: Option<CompactionPolicy>,
    /// Tool-choice value forwarded to the model.
    tool_choice: crate::model::ToolChoice,
    /// Parallel-tool-calls value forwarded to the model.
    parallel_tool_calls: Option<bool>,
    /// Idle timeout applied while draining a model stream.
    stream_idle_timeout: Duration,
    /// Upper bound on stream attempts per step (first attempt included).
    stream_max_attempts: u32,
}
182
183async fn run_loop<M, R>(
184 model: M,
185 tools: R,
186 config: RunLoopConfig,
187 input: NativeTurnInput,
188 tx: mpsc::Sender<Result<HarnessInternalEvent, NativeHarnessError>>,
189) where
190 M: ModelClient + Send + Sync,
191 R: ToolRuntime + Clone + Send + Sync + 'static,
192{
193 let system_prompt = input.system_prompt.clone();
194 let cancel_token = input.cancel_token.clone();
195 let context_path = input.context_path.clone();
196 let tools_snapshot = tools.specs();
200 let mut messages: Vec<ChatMessage> = if let Some(ref path) = context_path {
203 crate::context::jsonl::load_context(path).await
204 } else {
205 input.prior_messages
206 };
207 messages.push(ChatMessage::User {
208 content: input.prompt_text,
209 attachments: input.attachments,
210 });
211 let mut ctx_written: usize = match context_path.as_deref() {
215 None => 0,
216 Some(path) => {
217 let start = messages.len() - 1;
218 crate::context::jsonl::append_context(path, &messages[start..]).await;
219 messages.len()
220 }
221 };
222 let mut total_usage = HarnessUsage::default();
226 let mut saw_any_usage = false;
227
228 macro_rules! check_cancel {
234 () => {
235 if cancel_fired(cancel_token.as_ref()) {
236 let _ = tx
237 .send(Ok(HarnessInternalEvent::TurnEnd {
238 stop_reason: "interrupt".into(),
239 usage: saw_any_usage.then(|| total_usage.clone()),
240 final_messages: if context_path.is_none() { messages.clone() } else { vec![] },
241 }))
242 .await;
243 return;
244 }
245 };
246 }
247
248 for step in 0.. {
249 if config.max_steps != 0 && step >= config.max_steps {
253 break;
254 }
255 check_cancel!();
256 if let Some(policy) = &config.compaction {
262 if policy
263 .strategy
264 .should_compact(&messages, policy.context_window_tokens)
265 {
266 let original_count = messages.len();
267 let original_tokens = estimate_messages_tokens(&messages);
268 let cctx = CompactionContext {
269 system_prompt: system_prompt.clone(),
270 model_client: policy.model_client.clone(),
271 context_window_tokens: policy.context_window_tokens,
272 tools: tools_snapshot.clone(),
273 };
274 match policy.strategy.compact(messages.clone(), &cctx).await {
275 Ok(outcome) => {
276 let compacted_count = outcome.messages.len();
277 let compacted_tokens = estimate_messages_tokens(&outcome.messages);
278 messages = outcome.messages;
279 if let Some(u) = outcome.usage.as_ref() {
285 saw_any_usage = true;
286 total_usage.input_tokens += u.input_tokens;
287 total_usage.output_tokens += u.output_tokens;
288 total_usage.cache_read_input_tokens += u.cache_read_input_tokens;
289 total_usage.cache_creation_input_tokens +=
290 u.cache_creation_input_tokens;
291 total_usage.compaction_input_tokens += u.input_tokens;
292 total_usage.compaction_output_tokens += u.output_tokens;
293 }
294 tracing::info!(
298 target: "harness::compaction",
299 step,
300 original_message_count = original_count,
301 compacted_message_count = compacted_count,
302 original_estimated_tokens = original_tokens,
303 compacted_estimated_tokens = compacted_tokens,
304 context_window_tokens = policy.context_window_tokens,
305 "compaction applied"
306 );
307 if let Some(ref path) = context_path {
309 crate::context::jsonl::rewrite_context(path, &messages).await;
310 ctx_written = messages.len();
311 }
312 if tx
313 .send(Ok(HarnessInternalEvent::CompactionApplied {
314 original_message_count: original_count,
315 compacted_message_count: compacted_count,
316 original_tokens,
317 compacted_tokens,
318 }))
319 .await
320 .is_err()
321 {
322 return;
323 }
324 }
325 Err(e) => {
326 tracing::warn!(
327 target: "harness::compaction",
328 step,
329 error = %e,
330 "compaction skipped; history retained as-is, model call may now fail with context overflow"
331 );
332 }
333 }
334 }
335 }
336
337 const MAX_RETRIES: u32 = 3;
342 const BASE_BACKOFF_MS: u64 = 1_000;
343 const MAX_BACKOFF_MS: u64 = 16_000;
344
345 let model_input = ModelTurnInput {
346 system_prompt: system_prompt.clone(),
347 messages: messages.clone(),
348 tools: tools_snapshot.clone(),
349 tool_choice: config.tool_choice.clone(),
350 parallel_tool_calls: config.parallel_tool_calls,
351 };
352
353 let mut stream_attempt = 0u32;
363 let outcome = 'stream: loop {
364 let stream = {
365 let mut attempt = 0u32;
366 loop {
367 match model.stream(model_input.clone()).await {
368 Ok(s) => break s,
369 Err(e) => {
370 if e.retryable() && attempt < MAX_RETRIES {
371 let delay_ms =
372 (BASE_BACKOFF_MS * (1 << attempt)).min(MAX_BACKOFF_MS);
373 tracing::warn!(
374 attempt,
375 delay_ms,
376 error = %e,
377 "model call failed (retryable) — backing off"
378 );
379 if !backoff_sleep(delay_ms, cancel_token.as_ref()).await {
380 let _ = tx
381 .send(Err(NativeHarnessError::ModelOther(
382 "interrupted during retry backoff".into(),
383 )))
384 .await;
385 return;
386 }
387 attempt += 1;
388 } else {
389 tracing::error!(
392 attempt,
393 error = %e,
394 retryable = e.retryable(),
395 "model call failed — terminating turn"
396 );
397 let _ = tx.send(Err(model_error_to_native(e))).await;
398 return;
399 }
400 }
401 }
402 }
403 };
404
405 match consume_step_stream(
410 stream,
411 &tx,
412 step,
413 cancel_token.as_ref(),
414 config.stream_idle_timeout,
415 )
416 .await
417 {
418 Ok(StepDrain::Complete(o)) => break 'stream o,
419 Ok(StepDrain::Cancelled) => {
420 let _ = tx
421 .send(Ok(HarnessInternalEvent::TurnEnd {
422 stop_reason: "interrupt".into(),
423 usage: saw_any_usage.then(|| total_usage.clone()),
424 final_messages: if context_path.is_none() { messages.clone() } else { vec![] },
425 }))
426 .await;
427 return;
428 }
429 Err(StepFailure::Model { err, had_progress }) => {
430 if !had_progress
433 && err.retryable()
434 && stream_attempt + 1 < config.stream_max_attempts
435 {
436 let delay_ms =
437 (BASE_BACKOFF_MS * (1 << stream_attempt)).min(MAX_BACKOFF_MS);
438 tracing::warn!(
439 step,
440 stream_attempt,
441 delay_ms,
442 error = %err,
443 "model stream failed before any output — reconnecting"
444 );
445 if !backoff_sleep(delay_ms, cancel_token.as_ref()).await {
446 let _ = tx
447 .send(Err(NativeHarnessError::ModelOther(
448 "interrupted during stream reconnect backoff".into(),
449 )))
450 .await;
451 return;
452 }
453 stream_attempt += 1;
454 continue 'stream;
455 }
456 tracing::error!(
459 step,
460 stream_attempt,
461 error = %err,
462 had_progress,
463 retryable = err.retryable(),
464 "model stream failed — terminating turn"
465 );
466 let _ = tx.send(Err(model_error_to_native(err))).await;
467 return;
468 }
469 Err(StepFailure::ChannelClosed) => return,
470 Err(StepFailure::Fatal(e)) => {
471 let _ = tx.send(Err(e)).await;
472 return;
473 }
474 }
475 };
476
477 if let Some(u) = outcome.usage.as_ref() {
478 saw_any_usage = true;
479 total_usage.input_tokens += u.input_tokens;
480 total_usage.output_tokens += u.output_tokens;
481 total_usage.cache_read_input_tokens += u.cache_read_input_tokens;
482 total_usage.cache_creation_input_tokens += u.cache_creation_input_tokens;
483 }
484
485 match outcome.next {
486 StepNext::Message { text, stop_reason } => {
487 let assistant_text = (!text.is_empty()).then_some(text);
488 messages.push(ChatMessage::Assistant {
489 text: assistant_text,
490 tool_calls: vec![],
491 thinking: outcome.thinking.clone(),
492 });
493 if let Some(ref path) = context_path {
495 crate::context::jsonl::append_context(path, &messages[ctx_written..]).await;
496 }
497 let final_msgs = if context_path.is_none() { messages.clone() } else { vec![] };
500 let _ = tx
501 .send(Ok(HarnessInternalEvent::TurnEnd {
502 stop_reason,
503 usage: saw_any_usage.then(|| total_usage.clone()),
504 final_messages: final_msgs,
505 }))
506 .await;
507 return;
508 }
509 StepNext::ToolCalls {
510 preface,
511 mut invocations,
512 } => {
513 check_cancel!();
514 for inv in &mut invocations {
522 if let Some(repairs) = tools.repair_invocation(inv) {
527 tracing::warn!(
528 target: "harness::tool_repair",
529 tool = %inv.name,
530 id = %inv.id,
531 repairs = ?repairs,
532 "schema-guided tool input repair applied"
533 );
534 }
535 }
536 let preface_text = preface.filter(|s| !s.is_empty());
537 messages.push(ChatMessage::Assistant {
545 text: preface_text,
546 tool_calls: invocations.clone(),
547 thinking: outcome.thinking.clone(),
548 });
549
550 for inv in &invocations {
556 if tx
557 .send(Ok(HarnessInternalEvent::ToolCall {
558 id: inv.id.clone(),
559 name: inv.name.clone(),
560 input: inv.input.clone(),
561 }))
562 .await
563 .is_err()
564 {
565 return;
566 }
567 }
568
569 let handles = invocations.iter().cloned().map(|inv| {
575 let tools = tools.clone();
576 let cancel_for_task = cancel_token.clone();
577 let invocation_for_task = inv.clone();
578 let handle = tokio::spawn(async move {
579 tools
580 .invoke_cancellable(invocation_for_task, cancel_for_task.as_ref())
581 .await
582 });
583 (inv, handle)
584 });
585 let join = futures::future::join_all(handles.map(|(inv, handle)| async move {
586 let outcome = match handle.await {
587 Ok(outcome) => outcome,
588 Err(e) => Err(ToolRuntimeError::Runtime(format!("tool task failed: {e}"))),
589 };
590 (inv, outcome)
591 }));
592
593 let pairs_opt = if let Some(token) = cancel_token.as_ref() {
594 tokio::select! {
595 biased;
596 _ = token.cancelled() => None,
597 results = join => Some(results),
598 }
599 } else {
600 Some(join.await)
601 };
602 let pairs = match pairs_opt {
603 Some(o) => o,
604 None => {
605 let _ = tx
611 .send(Ok(HarnessInternalEvent::TurnEnd {
612 stop_reason: "interrupt".into(),
613 usage: saw_any_usage.then(|| total_usage.clone()),
614 final_messages: if context_path.is_none() { messages.clone() } else { vec![] },
615 }))
616 .await;
617 return;
618 }
619 };
620
621 let mut runtime_error: Option<String> = None;
626 for (inv, outcome) in pairs {
627 let id = inv.id.clone();
628 let outcome = match outcome {
629 Ok(o) => o,
630 Err(ToolRuntimeError::Timeout(message)) => ToolOutcome {
631 output: Err(ToolFailure::new(ToolFailureKind::Timeout, message)),
632 attachments: vec![],
633 },
634 Err(ToolRuntimeError::InvalidInput { tool, message }) => ToolOutcome {
635 output: Err(crate::tools::invalid_input_failure(
636 &tool,
637 message,
638 &inv.input,
639 tools_snapshot
640 .iter()
641 .find(|s| s.name == tool)
642 .map(|s| &s.input_schema),
643 )),
644 attachments: vec![],
645 },
646 Err(e) => {
647 runtime_error = Some(e.to_string());
653 break;
654 }
655 };
656 let tool_attachments = outcome.attachments;
657 let output = outcome.output.map_err(|failure| failure.to_string());
658
659 let (tool_content, is_error) = match &output {
665 Ok(value) => (value.to_string(), false),
666 Err(err) => (json!({ "error": err }).to_string(), true),
667 };
668 messages.push(ChatMessage::Tool {
669 tool_call_id: id.clone(),
670 content: tool_content,
671 is_error,
672 attachments: tool_attachments,
673 });
674
675 if tx
676 .send(Ok(HarnessInternalEvent::ToolResult { id, output }))
677 .await
678 .is_err()
679 {
680 return;
681 }
682 }
683 if let Some(err) = runtime_error {
684 let _ = tx.send(Err(NativeHarnessError::ToolRuntime(err))).await;
685 return;
686 }
687 if let Some(ref path) = context_path {
689 crate::context::jsonl::append_context(path, &messages[ctx_written..]).await;
690 ctx_written = messages.len();
691 }
692 }
695 }
696 }
697
698 if let Some(ref path) = context_path {
700 crate::context::jsonl::append_context(path, &messages[ctx_written..]).await;
701 }
702 let final_msgs = if context_path.is_none() { messages } else { vec![] };
703 let _ = tx
704 .send(Ok(HarnessInternalEvent::TurnEnd {
705 stop_reason: "max_turns".into(),
706 usage: saw_any_usage.then(|| total_usage.clone()),
707 final_messages: final_msgs,
708 }))
709 .await;
710}
711
712async fn backoff_sleep(delay_ms: u64, cancel_token: Option<&CancellationToken>) -> bool {
723 let sleep = tokio::time::sleep(Duration::from_millis(delay_ms));
724 tokio::pin!(sleep);
725 let cancelled = async {
726 if let Some(t) = cancel_token {
727 t.cancelled().await
728 } else {
729 std::future::pending().await
730 }
731 };
732 tokio::select! {
733 _ = &mut sleep => true,
734 _ = cancelled => false,
735 }
736}
737
738fn model_error_to_native(err: ModelClientError) -> NativeHarnessError {
739 match err {
740 ModelClientError::RateLimit(s) => NativeHarnessError::ModelRateLimit(s),
741 ModelClientError::Auth(s) => NativeHarnessError::ModelAuth(s),
742 ModelClientError::ContextOverflow(s) => NativeHarnessError::ModelContextOverflow(s),
743 ModelClientError::BadRequest(s) => NativeHarnessError::ModelBadRequest(s),
744 ModelClientError::ServerError(s) => NativeHarnessError::ModelServerError(s),
745 ModelClientError::Network(s) => NativeHarnessError::ModelNetwork(s),
746 ModelClientError::Other(s) => NativeHarnessError::ModelOther(s),
747 }
748}
749
/// Result of fully draining one model stream.
struct StepOutcome {
    /// What the loop should do next: finish with a message or run tools.
    next: StepNext,
    /// Usage taken from the stream's `Done` chunk, if one was reported.
    usage: Option<HarnessUsage>,
    /// Accumulated thinking text and signature, if the stream carried any.
    thinking: Option<AssistantThinking>,
}
762
/// Non-error ways in which draining a model stream can finish.
enum StepDrain {
    /// The stream ended and produced a step outcome.
    Complete(StepOutcome),
    /// The cancel token fired while waiting for the next chunk.
    Cancelled,
}
771
/// Ways in which draining a model stream can fail.
enum StepFailure {
    /// The stream yielded an error or stalled past the idle timeout.
    Model {
        /// The underlying model-client error.
        err: ModelClientError,
        /// Whether any text, thinking, or tool-call start had already been
        /// received; the caller only reconnects when this is `false`.
        had_progress: bool,
    },
    /// The event receiver was dropped, so nobody is listening any more.
    ChannelClosed,
    /// An unrecoverable error that is forwarded to the consumer as-is.
    Fatal(NativeHarnessError),
}
793
/// What a drained model stream asks the loop to do next.
enum StepNext {
    /// The model answered without tool calls; the turn ends.
    Message {
        /// Concatenated text deltas (may be empty).
        text: String,
        /// Stop reason from the `Done` chunk, `"end_turn"` if none arrived.
        stop_reason: String,
    },
    /// The model requested one or more tool invocations.
    ToolCalls {
        /// Text streamed alongside the tool calls; `None` when empty.
        preface: Option<String>,
        /// The requested invocations, in the order their starts were received.
        invocations: Vec<ToolInvocation>,
    },
}
809
810fn stall_failure(idle_timeout: Duration, had_progress: bool) -> StepFailure {
821 StepFailure::Model {
822 err: ModelClientError::Network(format!(
823 "model stream stalled: no output for {}s (connection open but idle)",
824 idle_timeout.as_secs()
825 )),
826 had_progress,
827 }
828}
829
830async fn consume_step_stream(
831 mut stream: futures::stream::BoxStream<'static, Result<ModelChunk, ModelClientError>>,
832 tx: &mpsc::Sender<Result<HarnessInternalEvent, NativeHarnessError>>,
833 step: usize,
834 cancel_token: Option<&CancellationToken>,
835 idle_timeout: Duration,
836) -> Result<StepDrain, StepFailure> {
837 let emit_msg_id = format!("msg_native_{step}");
838 let emit_thinking_id = format!("thinking_native_{step}");
839 let mut text_buf = String::new();
840 let mut thinking_buf = String::new();
841 let mut thinking_signature: Option<String> = None;
842 let mut saw_thinking = false;
843 let mut tool_states: Vec<ToolBuf> = Vec::new();
844 let mut stop_reason = "end_turn".to_string();
845 let mut usage: Option<HarnessUsage> = None;
846 let mut had_progress = false;
849
850 loop {
851 let idle = tokio::time::sleep(idle_timeout);
858 tokio::pin!(idle);
859
860 let item = if let Some(token) = cancel_token {
866 tokio::select! {
867 biased;
868 _ = token.cancelled() => {
869 return Ok(StepDrain::Cancelled);
870 }
871 _ = &mut idle => return Err(stall_failure(idle_timeout, had_progress)),
872 next = stream.next() => next,
873 }
874 } else {
875 tokio::select! {
876 _ = &mut idle => return Err(stall_failure(idle_timeout, had_progress)),
877 next = stream.next() => next,
878 }
879 };
880 let Some(item) = item else { break };
881 let chunk = match item {
882 Ok(c) => c,
883 Err(e) => {
884 return Err(StepFailure::Model {
885 err: e,
886 had_progress,
887 })
888 }
889 };
890 match chunk {
891 ModelChunk::TextDelta { msg_id: _, delta } => {
892 if delta.is_empty() {
893 continue;
894 }
895 text_buf.push_str(&delta);
896 had_progress = true;
899 if tx
903 .send(Ok(HarnessInternalEvent::AssistantTextChunk {
904 msg_id: emit_msg_id.clone(),
905 delta,
906 }))
907 .await
908 .is_err()
909 {
910 return Err(StepFailure::ChannelClosed);
911 }
912 }
913 ModelChunk::ThinkingDelta {
914 thinking_id: _,
915 delta,
916 signature,
917 } => {
918 if let Some(sig) = signature {
925 if !sig.is_empty() {
926 thinking_signature = Some(sig);
927 }
928 }
929 if !delta.is_empty() {
930 saw_thinking = true;
931 had_progress = true;
932 thinking_buf.push_str(&delta);
933 if tx
934 .send(Ok(HarnessInternalEvent::AssistantThinkingChunk {
935 msg_id: emit_thinking_id.clone(),
936 delta,
937 }))
938 .await
939 .is_err()
940 {
941 return Err(StepFailure::ChannelClosed);
942 }
943 }
944 }
945 ModelChunk::ToolCallStart { id, name } => {
946 had_progress = true;
952 tool_states.push(ToolBuf {
953 id,
954 name,
955 args_buf: String::new(),
956 early_input: None,
957 });
958 }
959 ModelChunk::ToolCallInputDelta { id, delta } => {
960 if let Some(s) = tool_states.iter_mut().find(|s| s.id == id) {
961 s.args_buf.push_str(&delta);
962 }
963 }
964 ModelChunk::ToolCallEnd { id, input } => {
965 if let Some(s) = tool_states.iter_mut().find(|s| s.id == id) {
966 s.early_input = input;
967 }
968 }
969 ModelChunk::Done {
970 stop_reason: sr,
971 usage: u,
972 } => {
973 stop_reason = sr;
974 usage = u;
975 }
976 }
977 }
978
979 let thinking = if saw_thinking || thinking_signature.is_some() {
983 Some(AssistantThinking {
984 text: thinking_buf,
985 signature: thinking_signature,
986 })
987 } else {
988 None
989 };
990
991 if !tool_states.is_empty() {
997 let mut invocations = Vec::with_capacity(tool_states.len());
998 for state in tool_states {
999 let parsed_input = match state.early_input {
1000 Some(v) => v,
1001 None => {
1002 let trimmed = state.args_buf.trim();
1003 if trimmed.is_empty() {
1004 Value::Object(serde_json::Map::new())
1008 } else {
1009 match serde_json::from_str(trimmed) {
1010 Ok(v) => v,
1011 Err(e) => {
1012 let res = crate::tool_repair::repair_truncated_json(trimmed);
1018 match serde_json::from_str(&res.repaired) {
1019 Ok(v) if res.changed => {
1020 tracing::warn!(
1021 target: "harness::tool_repair",
1022 tool = %state.name,
1023 id = %state.id,
1024 notes = ?res.notes,
1025 "repaired malformed tool arguments"
1026 );
1027 v
1028 }
1029 _ => {
1030 return Err(StepFailure::Fatal(
1031 NativeHarnessError::ModelOther(format!(
1032 "decode tool arguments for {id}: {e}",
1033 id = state.id
1034 )),
1035 ))
1036 }
1037 }
1038 }
1039 }
1040 }
1041 }
1042 };
1043 invocations.push(ToolInvocation {
1044 id: state.id,
1045 name: state.name,
1046 input: parsed_input,
1047 });
1048 }
1049 return Ok(StepDrain::Complete(StepOutcome {
1050 next: StepNext::ToolCalls {
1051 preface: (!text_buf.is_empty()).then_some(text_buf),
1052 invocations,
1053 },
1054 usage,
1055 thinking,
1056 }));
1057 }
1058
1059 Ok(StepDrain::Complete(StepOutcome {
1060 next: StepNext::Message {
1061 text: text_buf,
1062 stop_reason,
1063 },
1064 usage,
1065 thinking,
1066 }))
1067}
1068
/// Accumulator for one tool call while its stream chunks are arriving.
struct ToolBuf {
    /// Tool-call id from the start chunk; later chunks are matched on it.
    id: String,
    /// Name of the tool being called.
    name: String,
    /// Raw argument text concatenated from input deltas.
    args_buf: String,
    /// Input supplied by the end chunk, if any; used instead of `args_buf`.
    early_input: Option<Value>,
}
1075
1076#[cfg(test)]
1077mod tests {
1078 use super::*;
1079 use crate::compaction::{CompactionContext, CompactionError, CompactionStrategy};
1080 use crate::model::{ModelChunk, ModelClient, ModelClientError, ModelResponse};
1081 use crate::tools::{ToolInvocation, ToolOutcome};
1082 use crate::{HarnessInternalEvent, MockToolRuntime, ScriptedModelClient};
1083 use async_trait::async_trait;
1084 use futures::stream::{BoxStream, StreamExt};
1085 use std::sync::atomic::{AtomicUsize, Ordering};
1086 use std::sync::{Arc, Mutex};
1087
    /// Test model client that replays a fixed list of `ModelResponse`s, one
    /// per `stream` call, in order.
    #[derive(Clone)]
    struct QueueModelClient {
        // Remaining responses; shared so clones drain the same queue.
        queue: Arc<Mutex<Vec<ModelResponse>>>,
    }

    impl QueueModelClient {
        /// Creates a client that will serve `responses` front to back.
        fn new(responses: Vec<ModelResponse>) -> Self {
            Self {
                queue: Arc::new(Mutex::new(responses)),
            }
        }
    }

    #[async_trait]
    impl ModelClient for QueueModelClient {
        /// Pops the next queued response and streams it as chunks; fails with
        /// `ModelClientError::Other("queue exhausted")` once the queue is empty.
        async fn stream(
            &self,
            _input: ModelTurnInput,
        ) -> Result<BoxStream<'static, Result<ModelChunk, ModelClientError>>, ModelClientError>
        {
            let mut q = self.queue.lock().unwrap();
            if q.is_empty() {
                return Err(ModelClientError::Other("queue exhausted".into()));
            }
            let response = q.remove(0);
            let chunks = response_to_chunks(response);
            Ok(futures::stream::iter(chunks.into_iter().map(Ok)).boxed())
        }
    }
1120
1121 fn response_to_chunks(response: ModelResponse) -> Vec<ModelChunk> {
1125 match response {
1126 ModelResponse::Message {
1127 text,
1128 stop_reason,
1129 usage,
1130 } => {
1131 let mut out = Vec::new();
1132 if !text.is_empty() {
1133 out.push(ModelChunk::TextDelta {
1134 msg_id: "queue_msg".into(),
1135 delta: text,
1136 });
1137 }
1138 out.push(ModelChunk::Done { stop_reason, usage });
1139 out
1140 }
1141 ModelResponse::ToolCall {
1142 preface,
1143 invocation,
1144 usage,
1145 } => {
1146 let mut out = Vec::new();
1147 if let Some(p) = preface {
1148 if !p.is_empty() {
1149 out.push(ModelChunk::TextDelta {
1150 msg_id: "queue_msg".into(),
1151 delta: p,
1152 });
1153 }
1154 }
1155 out.push(ModelChunk::ToolCallStart {
1156 id: invocation.id.clone(),
1157 name: invocation.name.clone(),
1158 });
1159 out.push(ModelChunk::ToolCallEnd {
1160 id: invocation.id.clone(),
1161 input: Some(invocation.input.clone()),
1162 });
1163 out.push(ModelChunk::Done {
1164 stop_reason: "end_turn".into(),
1165 usage,
1166 });
1167 out
1168 }
1169 }
1170 }
1171
    /// Builds a `HarnessUsage` with the given input, output and cache-read
    /// token counts; every other counter is zero.
    fn usage(input: u64, output: u64, cache_read: u64) -> HarnessUsage {
        HarnessUsage {
            input_tokens: input,
            output_tokens: output,
            cache_read_input_tokens: cache_read,
            cache_creation_input_tokens: 0,
            compaction_input_tokens: 0,
            compaction_output_tokens: 0,
        }
    }
1182
    /// Usage reported by each model step (a tool-call step followed by a
    /// final message) is summed into the usage carried by `TurnEnd`.
    #[tokio::test]
    async fn agent_loop_accumulates_usage_across_steps() {
        let model = QueueModelClient::new(vec![
            // Step 0: the model asks for a tool call.
            ModelResponse::ToolCall {
                preface: None,
                invocation: ToolInvocation {
                    id: "tc_1".into(),
                    name: "bash".into(),
                    input: serde_json::json!({"command": "pwd"}),
                },
                usage: Some(usage(10, 5, 0)),
            },
            // Step 1: the model answers and the turn ends.
            ModelResponse::Message {
                text: "done".into(),
                stop_reason: "end_turn".into(),
                usage: Some(usage(20, 15, 4)),
            },
        ]);
        let harness = AgentLoopHarness::new(model, MockToolRuntime::new());
        let mut rx = harness
            .run_turn(NativeTurnInput {
                prompt_text: "pwd".into(),
                system_prompt: None,
                attachments: vec![],
                cancel_token: None,
                prior_messages: vec![],
                context_path: None,
            })
            .await
            .unwrap();

        let mut final_usage = None;
        // Skip intermediate events; only the terminal TurnEnd matters here.
        while let Some(item) = rx.recv().await {
            if let HarnessInternalEvent::TurnEnd { usage: u, .. } = item.unwrap() {
                final_usage = u;
                break;
            }
        }
        let u = final_usage.expect("TurnEnd carried usage");
        assert_eq!(u.input_tokens, 30);
        assert_eq!(u.output_tokens, 20);
        assert_eq!(u.cache_read_input_tokens, 4);
    }
1228
    /// When no step reports usage, `TurnEnd` carries `None` rather than a
    /// zeroed usage record.
    #[tokio::test]
    async fn agent_loop_turn_end_usage_is_none_when_no_step_reported() {
        let model = QueueModelClient::new(vec![ModelResponse::Message {
            text: "ok".into(),
            stop_reason: "end_turn".into(),
            usage: None,
        }]);
        let harness = AgentLoopHarness::new(model, MockToolRuntime::new());
        let mut rx = harness
            .run_turn(NativeTurnInput {
                prompt_text: "noop".into(),
                system_prompt: None,
                attachments: vec![],
                cancel_token: None,
                prior_messages: vec![],
                context_path: None,
            })
            .await
            .unwrap();
        // Outer Option: whether TurnEnd was seen; inner Option: its usage.
        let mut saw_usage = None;
        while let Some(item) = rx.recv().await {
            if let HarnessInternalEvent::TurnEnd { usage, .. } = item.unwrap() {
                saw_usage = Some(usage);
                break;
            }
        }
        assert_eq!(saw_usage.unwrap(), None);
    }
1258
    /// Test model client that replays pre-built chunk lists, one list per
    /// `stream` call, so tests control the exact chunk sequence.
    #[derive(Clone)]
    struct StreamingFakeClient {
        // Remaining per-call chunk lists; shared so clones drain the same list.
        chunks_per_call: Arc<Mutex<Vec<Vec<ModelChunk>>>>,
    }

    impl StreamingFakeClient {
        /// Creates a client whose n-th `stream` call yields `per_call[n]`.
        fn new(per_call: Vec<Vec<ModelChunk>>) -> Self {
            Self {
                chunks_per_call: Arc::new(Mutex::new(per_call)),
            }
        }
    }

    #[async_trait]
    impl ModelClient for StreamingFakeClient {
        /// Streams the next chunk list; fails with
        /// `ModelClientError::Other("queue exhausted")` once none are left.
        async fn stream(
            &self,
            _input: ModelTurnInput,
        ) -> Result<BoxStream<'static, Result<ModelChunk, ModelClientError>>, ModelClientError>
        {
            let mut bucket = self.chunks_per_call.lock().unwrap();
            if bucket.is_empty() {
                return Err(ModelClientError::Other("queue exhausted".into()));
            }
            let chunks = bucket.remove(0);
            Ok(futures::stream::iter(chunks.into_iter().map(Ok)).boxed())
        }
    }
1291
    /// Text deltas are forwarded one event per chunk, in order, followed by
    /// a single `TurnEnd`.
    #[tokio::test]
    async fn agent_loop_forwards_token_chunks_to_harness_output() {
        let model = StreamingFakeClient::new(vec![vec![
            ModelChunk::TextDelta {
                msg_id: "remote_msg".into(),
                delta: "Hel".into(),
            },
            ModelChunk::TextDelta {
                msg_id: "remote_msg".into(),
                delta: "lo ".into(),
            },
            ModelChunk::TextDelta {
                msg_id: "remote_msg".into(),
                delta: "world".into(),
            },
            ModelChunk::Done {
                stop_reason: "end_turn".into(),
                usage: None,
            },
        ]]);
        let harness = AgentLoopHarness::new(model, MockToolRuntime::new());
        let mut rx = harness
            .run_turn(NativeTurnInput {
                prompt_text: "hi".into(),
                system_prompt: None,
                attachments: vec![],
                cancel_token: None,
                prior_messages: vec![],
                context_path: None,
            })
            .await
            .unwrap();

        let mut deltas: Vec<String> = Vec::new();
        let mut saw_end = false;
        while let Some(item) = rx.recv().await {
            match item.unwrap() {
                HarnessInternalEvent::AssistantTextChunk { msg_id, delta } => {
                    // The harness substitutes its own step-based id for the
                    // remote message id.
                    assert_eq!(msg_id, "msg_native_0");
                    deltas.push(delta);
                }
                HarnessInternalEvent::TurnEnd { stop_reason, .. } => {
                    assert_eq!(stop_reason, "end_turn");
                    saw_end = true;
                    break;
                }
                other => panic!("unexpected event: {other:?}"),
            }
        }
        assert_eq!(deltas, vec!["Hel", "lo ", "world"]);
        assert!(saw_end);
    }
1347
    /// A streamed tool call whose arguments arrive as split JSON deltas is
    /// assembled, executed, and followed by a second model step that ends
    /// the turn. The exact event order is asserted.
    #[tokio::test]
    async fn agent_loop_streaming_tool_call_then_summary() {
        let model = StreamingFakeClient::new(vec![
            // First call: preface text, then a tool call with split arguments.
            vec![
                ModelChunk::TextDelta {
                    msg_id: "r1".into(),
                    delta: "running ".into(),
                },
                ModelChunk::ToolCallStart {
                    id: "tc_1".into(),
                    name: "bash".into(),
                },
                ModelChunk::ToolCallInputDelta {
                    id: "tc_1".into(),
                    delta: "{\"command\":".into(),
                },
                ModelChunk::ToolCallInputDelta {
                    id: "tc_1".into(),
                    delta: "\"pwd\"}".into(),
                },
                ModelChunk::ToolCallEnd {
                    id: "tc_1".into(),
                    input: None,
                },
                ModelChunk::Done {
                    stop_reason: "end_turn".into(),
                    usage: None,
                },
            ],
            // Second call: the final summary message.
            vec![
                ModelChunk::TextDelta {
                    msg_id: "r2".into(),
                    delta: "done".into(),
                },
                ModelChunk::Done {
                    stop_reason: "end_turn".into(),
                    usage: None,
                },
            ],
        ]);
        let harness = AgentLoopHarness::new(model, MockToolRuntime::new());
        let mut rx = harness
            .run_turn(NativeTurnInput {
                prompt_text: "pwd".into(),
                system_prompt: None,
                attachments: vec![],
                cancel_token: None,
                prior_messages: vec![],
                context_path: None,
            })
            .await
            .unwrap();

        // Expected order: preface chunk, ToolCall, ToolResult, summary chunk,
        // TurnEnd.
        let ev = rx.recv().await.unwrap().unwrap();
        assert!(matches!(
            ev,
            HarnessInternalEvent::AssistantTextChunk { ref delta, .. } if delta == "running "
        ));
        let ev = rx.recv().await.unwrap().unwrap();
        let HarnessInternalEvent::ToolCall { name, input, .. } = ev else {
            panic!("expected ToolCall");
        };
        assert_eq!(name, "bash");
        assert_eq!(input["command"], "pwd");
        let ev = rx.recv().await.unwrap().unwrap();
        assert!(matches!(ev, HarnessInternalEvent::ToolResult { .. }));
        let ev = rx.recv().await.unwrap().unwrap();
        assert!(matches!(
            ev,
            HarnessInternalEvent::AssistantTextChunk { ref delta, .. } if delta == "done"
        ));
        let ev = rx.recv().await.unwrap().unwrap();
        assert!(matches!(ev, HarnessInternalEvent::TurnEnd { .. }));
    }
1434
    /// Tool arguments that stop short of valid JSON (here the closing brace
    /// is missing) are repaired instead of failing the turn.
    #[tokio::test]
    async fn agent_loop_repairs_truncated_tool_arguments() {
        let model = StreamingFakeClient::new(vec![
            vec![
                ModelChunk::ToolCallStart {
                    id: "tc_trunc".into(),
                    name: "bash".into(),
                },
                ModelChunk::ToolCallInputDelta {
                    id: "tc_trunc".into(),
                    // Deliberately truncated: no closing `}`.
                    delta: r#"{"command":"pwd""#.into(),
                },
                ModelChunk::ToolCallEnd {
                    id: "tc_trunc".into(),
                    input: None,
                },
                ModelChunk::Done {
                    stop_reason: "tool_use".into(),
                    usage: None,
                },
            ],
            vec![
                ModelChunk::TextDelta {
                    msg_id: "r2".into(),
                    delta: "done".into(),
                },
                ModelChunk::Done {
                    stop_reason: "end_turn".into(),
                    usage: None,
                },
            ],
        ]);
        let harness = AgentLoopHarness::new(model, MockToolRuntime::new());
        let mut rx = harness
            .run_turn(NativeTurnInput {
                prompt_text: "pwd".into(),
                system_prompt: None,
                attachments: vec![],
                cancel_token: None,
                prior_messages: vec![],
                context_path: None,
            })
            .await
            .unwrap();

        let mut saw_tool_call = false;
        let mut saw_turn_end = false;
        while let Some(item) = rx.recv().await {
            match item.expect("turn must not fail on truncated args") {
                HarnessInternalEvent::ToolCall { name, input, .. } => {
                    assert_eq!(name, "bash");
                    assert_eq!(input["command"], "pwd", "repaired args reach the wire");
                    saw_tool_call = true;
                }
                HarnessInternalEvent::TurnEnd { .. } => {
                    saw_turn_end = true;
                    break;
                }
                // Other events are irrelevant to this test.
                _ => {}
            }
        }
        assert!(saw_tool_call, "expected ToolCall with repaired input");
        assert!(saw_turn_end);
    }
1502
    /// Test tool runtime that remembers the input it was last invoked with.
    #[derive(Clone)]
    struct ProbeToolRuntime {
        /// Input of the most recent `invoke` call; `None` until the tool has run.
        /// Shared through an `Arc` so a test can keep a handle and inspect it.
        seen_input: Arc<Mutex<Option<Value>>>,
    }
1509
1510 #[async_trait]
1511 impl ToolRuntime for ProbeToolRuntime {
1512 fn specs(&self) -> Vec<crate::tools::ToolSpec> {
1513 vec![crate::tools::ToolSpec {
1514 name: "probe".into(),
1515 description: "records its input".into(),
1516 input_schema: serde_json::json!({
1517 "type": "object",
1518 "properties": {
1519 "pattern": {"type": "string"},
1520 "literal": {"type": "boolean"},
1521 "limit": {"type": "integer"}
1522 },
1523 "required": ["pattern"]
1524 }),
1525 }]
1526 }
1527
1528 async fn invoke(
1529 &self,
1530 invocation: ToolInvocation,
1531 ) -> Result<ToolOutcome, ToolRuntimeError> {
1532 *self.seen_input.lock().unwrap() = Some(invocation.input);
1533 Ok(ToolOutcome {
1534 output: Ok(r#"{"ok":true}"#.into()),
1535 attachments: vec![],
1536 })
1537 }
1538 }
1539
1540 #[tokio::test]
1541 async fn agent_loop_applies_schema_repair_before_dispatch() {
1542 let model = StreamingFakeClient::new(vec![
1546 vec![
1547 ModelChunk::ToolCallStart {
1548 id: "tc_shape".into(),
1549 name: "probe".into(),
1550 },
1551 ModelChunk::ToolCallEnd {
1552 id: "tc_shape".into(),
1553 input: Some(json!({"pattern": "x", "literal": "true", "limit": "30"})),
1554 },
1555 ModelChunk::Done {
1556 stop_reason: "tool_use".into(),
1557 usage: None,
1558 },
1559 ],
1560 vec![
1561 ModelChunk::TextDelta {
1562 msg_id: "r2".into(),
1563 delta: "done".into(),
1564 },
1565 ModelChunk::Done {
1566 stop_reason: "end_turn".into(),
1567 usage: None,
1568 },
1569 ],
1570 ]);
1571 let seen_input = Arc::new(Mutex::new(None));
1572 let tools = ProbeToolRuntime {
1573 seen_input: seen_input.clone(),
1574 };
1575 let harness = AgentLoopHarness::new(model, tools);
1576 let mut rx = harness
1577 .run_turn(NativeTurnInput {
1578 prompt_text: "go".into(),
1579 system_prompt: None,
1580 attachments: vec![],
1581 cancel_token: None,
1582 prior_messages: vec![],
1583 context_path: None,
1584 })
1585 .await
1586 .unwrap();
1587
1588 let mut wire_input: Option<Value> = None;
1589 let mut history: Option<Vec<ChatMessage>> = None;
1590 while let Some(item) = rx.recv().await {
1591 match item.unwrap() {
1592 HarnessInternalEvent::ToolCall { input, .. } => wire_input = Some(input),
1593 HarnessInternalEvent::TurnEnd { final_messages, .. } => {
1594 history = Some(final_messages);
1595 break;
1596 }
1597 _ => {}
1598 }
1599 }
1600 let repaired = json!({"pattern": "x", "literal": true, "limit": 30});
1601 assert_eq!(seen_input.lock().unwrap().clone().unwrap(), repaired);
1603 assert_eq!(wire_input.unwrap(), repaired);
1604 let history = history.unwrap();
1605 let assistant_tool_calls = history
1606 .iter()
1607 .find_map(|m| match m {
1608 ChatMessage::Assistant { tool_calls, .. } if !tool_calls.is_empty() => {
1609 Some(tool_calls.clone())
1610 }
1611 _ => None,
1612 })
1613 .expect("assistant message with tool_calls in history");
1614 assert_eq!(assistant_tool_calls[0].input, repaired);
1615 }
1616
    /// Stateless test tool runtime whose single tool fails every invocation
    /// with `ToolRuntimeError::Timeout`.
    #[derive(Clone)]
    struct TimeoutToolRuntime;
1619
1620 #[async_trait]
1621 impl ToolRuntime for TimeoutToolRuntime {
1622 fn specs(&self) -> Vec<crate::tools::ToolSpec> {
1623 vec![crate::tools::ToolSpec {
1624 name: "slow".into(),
1625 description: "always times out".into(),
1626 input_schema: serde_json::json!({"type": "object"}),
1627 }]
1628 }
1629
1630 async fn invoke(
1631 &self,
1632 _invocation: ToolInvocation,
1633 ) -> Result<ToolOutcome, ToolRuntimeError> {
1634 Err(ToolRuntimeError::Timeout("tool timed out after 1s".into()))
1635 }
1636 }
1637
1638 #[tokio::test]
1639 async fn agent_loop_tool_timeout_is_model_observable_result() {
1640 let model = StreamingFakeClient::new(vec![
1641 vec![
1642 ModelChunk::ToolCallStart {
1643 id: "tc_timeout".into(),
1644 name: "slow".into(),
1645 },
1646 ModelChunk::ToolCallEnd {
1647 id: "tc_timeout".into(),
1648 input: Some(json!({})),
1649 },
1650 ModelChunk::Done {
1651 stop_reason: "tool_use".into(),
1652 usage: None,
1653 },
1654 ],
1655 vec![
1656 ModelChunk::TextDelta {
1657 msg_id: "r2".into(),
1658 delta: "recovered".into(),
1659 },
1660 ModelChunk::Done {
1661 stop_reason: "end_turn".into(),
1662 usage: None,
1663 },
1664 ],
1665 ]);
1666 let harness = AgentLoopHarness::new(model, TimeoutToolRuntime);
1667 let mut rx = harness
1668 .run_turn(NativeTurnInput {
1669 prompt_text: "run slow".into(),
1670 system_prompt: None,
1671 attachments: vec![],
1672 cancel_token: None,
1673 prior_messages: vec![],
1674 context_path: None,
1675 })
1676 .await
1677 .unwrap();
1678
1679 assert!(matches!(
1680 rx.recv().await.unwrap().unwrap(),
1681 HarnessInternalEvent::ToolCall { .. }
1682 ));
1683 match rx.recv().await.unwrap().unwrap() {
1684 HarnessInternalEvent::ToolResult { output, .. } => {
1685 let err = output.unwrap_err();
1686 assert!(err.contains("Timeout"));
1687 assert!(err.contains("tool timed out"));
1688 }
1689 other => panic!("expected timeout ToolResult, got {other:?}"),
1690 }
1691 assert!(matches!(
1692 rx.recv().await.unwrap().unwrap(),
1693 HarnessInternalEvent::AssistantTextChunk { ref delta, .. } if delta == "recovered"
1694 ));
1695 assert!(matches!(
1696 rx.recv().await.unwrap().unwrap(),
1697 HarnessInternalEvent::TurnEnd { ref stop_reason, .. } if stop_reason == "end_turn"
1698 ));
1699 }
1700
1701 #[tokio::test]
1702 async fn agent_loop_invalid_tool_input_is_model_observable_and_bounded() {
1703 let huge_content = "x".repeat(20_000);
1704 let model = StreamingFakeClient::new(vec![
1705 vec![
1706 ModelChunk::ToolCallStart {
1707 id: "tc_bad_write".into(),
1708 name: "write".into(),
1709 },
1710 ModelChunk::ToolCallEnd {
1711 id: "tc_bad_write".into(),
1712 input: Some(json!({"content": huge_content})),
1713 },
1714 ModelChunk::Done {
1715 stop_reason: "tool_use".into(),
1716 usage: None,
1717 },
1718 ],
1719 vec![
1720 ModelChunk::TextDelta {
1721 msg_id: "r2".into(),
1722 delta: "recovered".into(),
1723 },
1724 ModelChunk::Done {
1725 stop_reason: "end_turn".into(),
1726 usage: None,
1727 },
1728 ],
1729 ]);
1730 let harness = AgentLoopHarness::new(model, MockToolRuntime::new());
1731 let mut rx = harness
1732 .run_turn(NativeTurnInput {
1733 prompt_text: "write file".into(),
1734 system_prompt: None,
1735 attachments: vec![],
1736 cancel_token: None,
1737 prior_messages: vec![],
1738 context_path: None,
1739 })
1740 .await
1741 .unwrap();
1742
1743 assert!(matches!(
1744 rx.recv().await.unwrap().unwrap(),
1745 HarnessInternalEvent::ToolCall { .. }
1746 ));
1747 match rx.recv().await.unwrap().unwrap() {
1748 HarnessInternalEvent::ToolResult { output, .. } => {
1749 let err = output.unwrap_err();
1750 assert!(err.contains("The write tool was called with invalid arguments"));
1753 assert!(err.contains("missing required field `path`"), "{err}");
1754 assert!(err.contains("Received fields: content"));
1755 assert!(err.contains("string(20000 chars"));
1756 assert!(err.contains("Expected shape"), "teaching example missing: {err}");
1757 assert!(!err.contains(&"x".repeat(2000)), "error should not echo full content");
1758 }
1759 other => panic!("expected invalid-input ToolResult, got {other:?}"),
1760 }
1761 assert!(matches!(
1762 rx.recv().await.unwrap().unwrap(),
1763 HarnessInternalEvent::AssistantTextChunk { ref delta, .. } if delta == "recovered"
1764 ));
1765 }
1766
    /// Compaction strategy test double that always asks to compact and counts
    /// how many times `compact` ran.
    struct CountingCompactionStrategy {
        /// Incremented once per `compact` call; shared so tests can read it.
        calls: Arc<AtomicUsize>,
    }
1773
1774 #[async_trait]
1775 impl CompactionStrategy for CountingCompactionStrategy {
1776 fn should_compact(&self, _messages: &[ChatMessage], _context_window_tokens: u64) -> bool {
1777 true
1778 }
1779
1780 async fn compact(
1781 &self,
1782 _messages: Vec<ChatMessage>,
1783 _ctx: &CompactionContext,
1784 ) -> Result<crate::compaction::CompactionOutcome, CompactionError> {
1785 self.calls.fetch_add(1, Ordering::SeqCst);
1786 Ok(crate::compaction::CompactionOutcome {
1789 messages: vec![ChatMessage::User {
1790 content: "<conversation-summary>FOLDED</conversation-summary>".into(),
1791 attachments: vec![],
1792 }],
1793 usage: None,
1794 })
1795 }
1796 }
1797
    /// Compaction strategy test double that compacts at most once and attaches
    /// a fixed usage record to the compaction outcome.
    struct UsageReportingCompactionStrategy {
        /// Number of `compact` calls so far; `should_compact` is true only while
        /// this is zero.
        invoked: Arc<AtomicUsize>,
        /// Usage cloned into every `CompactionOutcome` this strategy returns.
        per_call_usage: HarnessUsage,
    }
1805
1806 #[async_trait]
1807 impl CompactionStrategy for UsageReportingCompactionStrategy {
1808 fn should_compact(&self, _: &[ChatMessage], _: u64) -> bool {
1809 self.invoked.load(Ordering::SeqCst) == 0
1813 }
1814 async fn compact(
1815 &self,
1816 messages: Vec<ChatMessage>,
1817 _ctx: &CompactionContext,
1818 ) -> Result<crate::compaction::CompactionOutcome, CompactionError> {
1819 self.invoked.fetch_add(1, Ordering::SeqCst);
1820 Ok(crate::compaction::CompactionOutcome {
1821 messages,
1822 usage: Some(self.per_call_usage.clone()),
1823 })
1824 }
1825 }
1826
1827 #[tokio::test]
1828 async fn agent_loop_attributes_compaction_usage_to_subbucket_and_total() {
1829 let model = StreamingFakeClient::new(vec![vec![
1833 ModelChunk::TextDelta {
1834 msg_id: "m".into(),
1835 delta: "done".into(),
1836 },
1837 ModelChunk::Done {
1838 stop_reason: "end_turn".into(),
1839 usage: Some(usage(100, 30, 0)),
1840 },
1841 ]]);
1842 let invoked = Arc::new(AtomicUsize::new(0));
1843 let strategy = UsageReportingCompactionStrategy {
1844 invoked: invoked.clone(),
1845 per_call_usage: usage(50, 20, 0),
1846 };
1847 let policy = CompactionPolicy {
1848 strategy: Arc::new(strategy),
1849 model_client: Arc::new(ScriptedModelClient),
1850 context_window_tokens: 1, };
1852 let harness = AgentLoopHarness::new(model, MockToolRuntime::new()).with_compaction(policy);
1853 let mut rx = harness
1854 .run_turn(NativeTurnInput {
1855 prompt_text: "hi".into(),
1856 system_prompt: None,
1857 attachments: vec![],
1858 cancel_token: None,
1859 prior_messages: vec![],
1860 context_path: None,
1861 })
1862 .await
1863 .unwrap();
1864 let mut final_usage = None;
1865 while let Some(item) = rx.recv().await {
1866 if let HarnessInternalEvent::TurnEnd { usage, .. } = item.unwrap() {
1867 final_usage = usage;
1868 break;
1869 }
1870 }
1871 assert_eq!(invoked.load(Ordering::SeqCst), 1);
1872 let u = final_usage.expect("TurnEnd carried usage");
1873 assert_eq!(u.input_tokens, 150);
1875 assert_eq!(u.output_tokens, 50);
1876 assert_eq!(u.compaction_input_tokens, 50);
1878 assert_eq!(u.compaction_output_tokens, 20);
1879 }
1880
    /// Model client test double that records the messages it was asked to
    /// stream for and replays a fixed chunk script.
    #[derive(Clone)]
    struct RecordingFakeClient {
        /// Messages from the most recent `stream` call; `None` before the first call.
        last_messages: Arc<Mutex<Option<Vec<ChatMessage>>>>,
        /// Chunks replayed, in order, on every `stream` call.
        chunks: Vec<ModelChunk>,
    }
1889
1890 #[async_trait]
1891 impl ModelClient for RecordingFakeClient {
1892 async fn stream(
1893 &self,
1894 input: ModelTurnInput,
1895 ) -> Result<BoxStream<'static, Result<ModelChunk, ModelClientError>>, ModelClientError>
1896 {
1897 *self.last_messages.lock().unwrap() = Some(input.messages);
1898 Ok(futures::stream::iter(self.chunks.clone().into_iter().map(Ok)).boxed())
1899 }
1900 }
1901
1902 #[tokio::test]
1903 async fn agent_loop_invokes_compaction_between_steps() {
1904 let calls = Arc::new(AtomicUsize::new(0));
1905 let last_messages = Arc::new(Mutex::new(None::<Vec<ChatMessage>>));
1906 let model = RecordingFakeClient {
1907 last_messages: last_messages.clone(),
1908 chunks: vec![
1909 ModelChunk::TextDelta {
1910 msg_id: "m".into(),
1911 delta: "done".into(),
1912 },
1913 ModelChunk::Done {
1914 stop_reason: "end_turn".into(),
1915 usage: None,
1916 },
1917 ],
1918 };
1919 let summary_client: Arc<dyn ModelClient> = Arc::new(ScriptedModelClient);
1923 let policy = CompactionPolicy {
1924 strategy: Arc::new(CountingCompactionStrategy {
1925 calls: calls.clone(),
1926 }),
1927 model_client: summary_client,
1928 context_window_tokens: 1, };
1930
1931 let harness = AgentLoopHarness::new(model, MockToolRuntime::new()).with_compaction(policy);
1932 let mut rx = harness
1933 .run_turn(NativeTurnInput {
1934 prompt_text: "hello".into(),
1935 system_prompt: None,
1936 attachments: vec![],
1937 cancel_token: None,
1938 prior_messages: vec![],
1939 context_path: None,
1940 })
1941 .await
1942 .unwrap();
1943 let mut compaction_event: Option<(usize, usize)> = None;
1944 while let Some(item) = rx.recv().await {
1945 match item.unwrap() {
1946 HarnessInternalEvent::CompactionApplied {
1947 original_message_count,
1948 compacted_message_count,
1949 ..
1950 } => {
1951 compaction_event = Some((original_message_count, compacted_message_count));
1952 }
1953 HarnessInternalEvent::TurnEnd { .. } => break,
1954 _ => {}
1955 }
1956 }
1957 assert_eq!(calls.load(Ordering::SeqCst), 1);
1959 let (orig, comp) = compaction_event.expect("CompactionApplied event emitted");
1961 assert_eq!(orig, 1, "started with 1 message ([User \"hello\"])");
1962 assert_eq!(comp, 1, "spy strategy folded to single User message");
1963 let observed = last_messages.lock().unwrap().clone().expect("model called");
1966 assert_eq!(observed.len(), 1);
1967 match &observed[0] {
1968 ChatMessage::User { content, .. } => {
1969 assert!(content.contains("FOLDED"), "got {content:?}");
1970 }
1971 other => panic!("expected User, got {other:?}"),
1972 }
1973 }
1974
    /// Model client test double whose stream never yields a chunk and never ends.
    #[derive(Clone)]
    struct HangingModelClient {
        /// Signalled once by the background task spawned in `stream`, so a test
        /// can wait until a stream has been set up.
        started: Arc<tokio::sync::Notify>,
    }
1982
1983 #[async_trait]
1984 impl ModelClient for HangingModelClient {
1985 async fn stream(
1986 &self,
1987 _input: ModelTurnInput,
1988 ) -> Result<BoxStream<'static, Result<ModelChunk, ModelClientError>>, ModelClientError>
1989 {
1990 let (tx, rx) = mpsc::channel::<Result<ModelChunk, ModelClientError>>(1);
1995 let started = self.started.clone();
1996 tokio::spawn(async move {
1997 started.notify_one();
2002 let _retain = tx; let () = std::future::pending().await;
2004 });
2005 Ok(tokio_stream::wrappers::ReceiverStream::new(rx).boxed())
2006 }
2007 }
2008
2009 #[tokio::test]
2010 async fn agent_loop_cancellation_interrupts_in_flight_stream() {
2011 let started = Arc::new(tokio::sync::Notify::new());
2016 let model = HangingModelClient {
2017 started: started.clone(),
2018 };
2019 let cancel = CancellationToken::new();
2020 let cancel_for_outside = cancel.clone();
2021
2022 let harness = AgentLoopHarness::new(model, MockToolRuntime::new());
2023 let mut rx = harness
2024 .run_turn(NativeTurnInput {
2025 prompt_text: "hi".into(),
2026 system_prompt: None,
2027 attachments: vec![],
2028 cancel_token: Some(cancel),
2029 prior_messages: vec![],
2030 context_path: None,
2031 })
2032 .await
2033 .unwrap();
2034
2035 started.notified().await;
2040 cancel_for_outside.cancel();
2041
2042 let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2);
2044 let mut saw_interrupt = false;
2045 while tokio::time::Instant::now() < deadline {
2046 tokio::select! {
2047 item = rx.recv() => {
2048 match item {
2049 Some(Ok(HarnessInternalEvent::TurnEnd { stop_reason, .. })) => {
2050 assert_eq!(stop_reason, "interrupt");
2051 saw_interrupt = true;
2052 break;
2053 }
2054 Some(_) => continue,
2055 None => break,
2056 }
2057 }
2058 _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {}
2059 }
2060 }
2061 assert!(saw_interrupt, "expected TurnEnd{{interrupt}} after cancel");
2062 }
2063
    /// Script entry describing how one `StallingModelClient` stream behaves.
    enum StallBehavior {
        /// Send every chunk, then let the sender drop so the stream ends.
        Complete(Vec<ModelChunk>),
        /// Send every chunk, then keep the sender alive forever so the stream
        /// neither yields more items nor ends.
        EmitThenHang(Vec<ModelChunk>),
    }
2076
    /// Model client test double that plays one scripted `StallBehavior` per
    /// `stream` call and counts how many streams were requested.
    #[derive(Clone)]
    struct StallingModelClient {
        /// Remaining script, consumed from the front; once empty, each further
        /// call hangs without emitting anything.
        behaviors: Arc<Mutex<Vec<StallBehavior>>>,
        /// Number of `stream` calls made so far.
        calls: Arc<AtomicUsize>,
    }
2082
2083 impl StallingModelClient {
2084 fn new(behaviors: Vec<StallBehavior>) -> Self {
2085 Self {
2086 behaviors: Arc::new(Mutex::new(behaviors)),
2087 calls: Arc::new(AtomicUsize::new(0)),
2088 }
2089 }
2090 }
2091
2092 #[async_trait]
2093 impl ModelClient for StallingModelClient {
2094 async fn stream(
2095 &self,
2096 _input: ModelTurnInput,
2097 ) -> Result<BoxStream<'static, Result<ModelChunk, ModelClientError>>, ModelClientError>
2098 {
2099 self.calls.fetch_add(1, Ordering::SeqCst);
2100 let behavior = {
2103 let mut b = self.behaviors.lock().unwrap();
2104 if b.is_empty() {
2105 StallBehavior::EmitThenHang(vec![])
2106 } else {
2107 b.remove(0)
2108 }
2109 };
2110 let (tx, rx) = mpsc::channel::<Result<ModelChunk, ModelClientError>>(8);
2111 tokio::spawn(async move {
2112 match behavior {
2113 StallBehavior::Complete(chunks) => {
2114 for c in chunks {
2115 if tx.send(Ok(c)).await.is_err() {
2116 return;
2117 }
2118 }
2119 }
2121 StallBehavior::EmitThenHang(chunks) => {
2122 for c in chunks {
2123 if tx.send(Ok(c)).await.is_err() {
2124 return;
2125 }
2126 }
2127 let _retain = tx; let () = std::future::pending().await;
2129 }
2130 }
2131 });
2132 Ok(tokio_stream::wrappers::ReceiverStream::new(rx).boxed())
2133 }
2134 }
2135
2136 #[tokio::test(start_paused = true)]
2137 async fn agent_loop_reconnects_after_stall_before_any_output() {
2138 let model = StallingModelClient::new(vec![
2143 StallBehavior::EmitThenHang(vec![]),
2144 StallBehavior::Complete(vec![
2145 ModelChunk::TextDelta {
2146 msg_id: "m".into(),
2147 delta: "ok".into(),
2148 },
2149 ModelChunk::Done {
2150 stop_reason: "end_turn".into(),
2151 usage: None,
2152 },
2153 ]),
2154 ]);
2155 let calls = model.calls.clone();
2156 let harness = AgentLoopHarness::new(model, MockToolRuntime::new())
2157 .with_stream_resilience(Duration::from_millis(50), 3);
2158 let mut rx = harness
2159 .run_turn(NativeTurnInput {
2160 prompt_text: "hi".into(),
2161 system_prompt: None,
2162 attachments: vec![],
2163 cancel_token: None,
2164 prior_messages: vec![],
2165 context_path: None,
2166 })
2167 .await
2168 .unwrap();
2169
2170 let mut text = String::new();
2171 let mut stop = None;
2172 while let Some(item) = rx.recv().await {
2173 match item.expect("no error expected") {
2174 HarnessInternalEvent::AssistantTextChunk { delta, .. } => text.push_str(&delta),
2175 HarnessInternalEvent::TurnEnd { stop_reason, .. } => {
2176 stop = Some(stop_reason);
2177 break;
2178 }
2179 _ => {}
2180 }
2181 }
2182 assert_eq!(stop.as_deref(), Some("end_turn"));
2183 assert_eq!(text, "ok", "text delivered exactly once, no duplication");
2184 assert_eq!(
2185 calls.load(Ordering::SeqCst),
2186 2,
2187 "stream established twice (one reconnect)"
2188 );
2189 }
2190
2191 #[tokio::test(start_paused = true)]
2192 async fn agent_loop_surfaces_error_when_reconnect_budget_exhausted() {
2193 let model = StallingModelClient::new(vec![]); let calls = model.calls.clone();
2197 let harness = AgentLoopHarness::new(model, MockToolRuntime::new())
2198 .with_stream_resilience(Duration::from_millis(50), 2);
2199 let mut rx = harness
2200 .run_turn(NativeTurnInput {
2201 prompt_text: "hi".into(),
2202 system_prompt: None,
2203 attachments: vec![],
2204 cancel_token: None,
2205 prior_messages: vec![],
2206 context_path: None,
2207 })
2208 .await
2209 .unwrap();
2210
2211 let mut saw_error = false;
2212 while let Some(item) = rx.recv().await {
2213 match item {
2214 Err(NativeHarnessError::ModelNetwork(msg)) => {
2215 assert!(msg.contains("stalled"), "got {msg:?}");
2216 saw_error = true;
2217 break;
2218 }
2219 Err(other) => panic!("unexpected error variant: {other:?}"),
2220 Ok(_) => {}
2221 }
2222 }
2223 assert!(
2224 saw_error,
2225 "expected ModelNetwork stall error after budget exhausted"
2226 );
2227 assert_eq!(
2228 calls.load(Ordering::SeqCst),
2229 2,
2230 "two establishments (initial + one reconnect)"
2231 );
2232 }
2233
2234 #[tokio::test(start_paused = true)]
2235 async fn agent_loop_does_not_reconnect_after_stall_with_partial_output() {
2236 let model = StallingModelClient::new(vec![StallBehavior::EmitThenHang(vec![
2242 ModelChunk::TextDelta {
2243 msg_id: "m".into(),
2244 delta: "partial".into(),
2245 },
2246 ])]);
2247 let calls = model.calls.clone();
2248 let harness = AgentLoopHarness::new(model, MockToolRuntime::new())
2249 .with_stream_resilience(Duration::from_millis(50), 5);
2250 let mut rx = harness
2251 .run_turn(NativeTurnInput {
2252 prompt_text: "hi".into(),
2253 system_prompt: None,
2254 attachments: vec![],
2255 cancel_token: None,
2256 prior_messages: vec![],
2257 context_path: None,
2258 })
2259 .await
2260 .unwrap();
2261
2262 let mut text = String::new();
2263 let mut saw_error = false;
2264 while let Some(item) = rx.recv().await {
2265 match item {
2266 Ok(HarnessInternalEvent::AssistantTextChunk { delta, .. }) => text.push_str(&delta),
2267 Err(NativeHarnessError::ModelNetwork(_)) => {
2268 saw_error = true;
2269 break;
2270 }
2271 Err(other) => panic!("unexpected error variant: {other:?}"),
2272 Ok(_) => {}
2273 }
2274 }
2275 assert!(saw_error, "expected terminal ModelNetwork error");
2276 assert_eq!(
2277 text, "partial",
2278 "partial output delivered once, not replayed"
2279 );
2280 assert_eq!(
2281 calls.load(Ordering::SeqCst),
2282 1,
2283 "no reconnect once output has reached the user"
2284 );
2285 }
2286
2287 #[tokio::test]
2288 async fn agent_loop_accumulates_thinking_chunks_and_signature() {
2289 let model = StreamingFakeClient::new(vec![vec![
2298 ModelChunk::ThinkingDelta {
2299 thinking_id: "th_1".into(),
2300 delta: "let me think...".into(),
2301 signature: None,
2302 },
2303 ModelChunk::ThinkingDelta {
2304 thinking_id: "th_1".into(),
2305 delta: "".into(),
2306 signature: Some("sig_abc".into()),
2307 },
2308 ModelChunk::TextDelta {
2309 msg_id: "m1".into(),
2310 delta: "ok".into(),
2311 },
2312 ModelChunk::Done {
2313 stop_reason: "end_turn".into(),
2314 usage: None,
2315 },
2316 ]]);
2317 let harness = AgentLoopHarness::new(model, MockToolRuntime::new());
2318 let mut rx = harness
2319 .run_turn(NativeTurnInput {
2320 prompt_text: "hi".into(),
2321 system_prompt: None,
2322 attachments: vec![],
2323 cancel_token: None,
2324 prior_messages: vec![],
2325 context_path: None,
2326 })
2327 .await
2328 .unwrap();
2329
2330 let mut thinking_chunks: Vec<String> = Vec::new();
2331 let mut text_chunks: Vec<String> = Vec::new();
2332 let mut saw_end = false;
2333 while let Some(item) = rx.recv().await {
2334 match item.unwrap() {
2335 HarnessInternalEvent::AssistantThinkingChunk { msg_id, delta } => {
2336 assert_eq!(msg_id, "thinking_native_0");
2337 thinking_chunks.push(delta);
2338 }
2339 HarnessInternalEvent::AssistantTextChunk { msg_id, delta } => {
2340 assert_eq!(msg_id, "msg_native_0");
2341 text_chunks.push(delta);
2342 }
2343 HarnessInternalEvent::TurnEnd { .. } => {
2344 saw_end = true;
2345 break;
2346 }
2347 other => panic!("unexpected event: {other:?}"),
2348 }
2349 }
2350 assert_eq!(thinking_chunks, vec!["let me think..."]);
2353 assert_eq!(text_chunks, vec!["ok"]);
2354 assert!(saw_end);
2355 }
2356
2357 #[tokio::test]
2358 async fn agent_loop_runs_tool_then_final_message() {
2359 let harness = AgentLoopHarness::new(
2360 ScriptedModelClient,
2361 MockToolRuntime::new().with_file("README.md", "hello"),
2362 );
2363 let mut rx = harness
2364 .run_turn(NativeTurnInput {
2365 prompt_text: "read README.md".into(),
2366 system_prompt: None,
2367 attachments: vec![],
2368 cancel_token: None,
2369 prior_messages: vec![],
2370 context_path: None,
2371 })
2372 .await
2373 .unwrap();
2374
2375 assert!(matches!(
2376 rx.recv().await.unwrap().unwrap(),
2377 HarnessInternalEvent::AssistantTextChunk { .. }
2378 ));
2379 assert!(matches!(
2380 rx.recv().await.unwrap().unwrap(),
2381 HarnessInternalEvent::ToolCall { ref name, .. } if name == "read"
2382 ));
2383 assert!(matches!(
2384 rx.recv().await.unwrap().unwrap(),
2385 HarnessInternalEvent::ToolResult { .. }
2386 ));
2387 assert!(matches!(
2388 rx.recv().await.unwrap().unwrap(),
2389 HarnessInternalEvent::AssistantTextChunk { .. }
2390 ));
2391 assert!(matches!(
2392 rx.recv().await.unwrap().unwrap(),
2393 HarnessInternalEvent::TurnEnd { .. }
2394 ));
2395 assert!(rx.recv().await.is_none());
2396 }
2397
2398 #[tokio::test]
2405 async fn agent_loop_turn_end_carries_full_message_history() {
2406 let model = QueueModelClient::new(vec![ModelResponse::Message {
2407 text: "second reply".into(),
2408 stop_reason: "end_turn".into(),
2409 usage: None,
2410 }]);
2411 let harness = AgentLoopHarness::new(model, MockToolRuntime::new());
2412 let prior = vec![
2414 ChatMessage::User {
2415 content: "first prompt".into(),
2416 attachments: vec![],
2417 },
2418 ChatMessage::Assistant {
2419 text: Some("first reply".into()),
2420 tool_calls: vec![],
2421 thinking: None,
2422 },
2423 ];
2424 let mut rx = harness
2425 .run_turn(NativeTurnInput {
2426 prompt_text: "second prompt".into(),
2427 system_prompt: None,
2428 attachments: vec![],
2429 cancel_token: None,
2430 prior_messages: prior,
2431 context_path: None,
2432 })
2433 .await
2434 .unwrap();
2435 let mut final_messages: Option<Vec<ChatMessage>> = None;
2436 while let Some(item) = rx.recv().await {
2437 if let HarnessInternalEvent::TurnEnd {
2438 final_messages: m, ..
2439 } = item.unwrap()
2440 {
2441 final_messages = Some(m);
2442 break;
2443 }
2444 }
2445 let msgs = final_messages.expect("TurnEnd carried final_messages");
2446 assert_eq!(msgs.len(), 4, "got {msgs:?}");
2448 match &msgs[0] {
2449 ChatMessage::User { content, .. } => assert_eq!(content, "first prompt"),
2450 other => panic!("msgs[0] not user-1: {other:?}"),
2451 }
2452 match &msgs[1] {
2453 ChatMessage::Assistant { text, .. } => {
2454 assert_eq!(text.as_deref(), Some("first reply"));
2455 }
2456 other => panic!("msgs[1] not assistant-1: {other:?}"),
2457 }
2458 match &msgs[2] {
2459 ChatMessage::User { content, .. } => assert_eq!(content, "second prompt"),
2460 other => panic!("msgs[2] not user-2: {other:?}"),
2461 }
2462 match &msgs[3] {
2463 ChatMessage::Assistant { text, .. } => {
2464 assert_eq!(text.as_deref(), Some("second reply"));
2465 }
2466 other => panic!("msgs[3] not assistant-2: {other:?}"),
2467 }
2468 }
2469
    /// Test tool runtime whose single tool sleeps for a fixed duration while
    /// recording concurrency and cancellation statistics.
    #[derive(Clone)]
    struct ConcurrencyProbeRuntime {
        /// How long each invocation sleeps before succeeding.
        sleep_for: std::time::Duration,
        /// Number of invocations currently between entry and exit.
        in_flight: Arc<AtomicUsize>,
        /// Highest `in_flight` value observed so far.
        max_concurrency: Arc<AtomicUsize>,
        /// Invocation ids in the order the calls started.
        call_order: Arc<Mutex<Vec<String>>>,
        /// Number of invocations that ended because their cancel token fired.
        cancelled: Arc<AtomicUsize>,
    }
2482
2483 impl ConcurrencyProbeRuntime {
2484 fn new(sleep_for: std::time::Duration) -> Self {
2485 Self {
2486 sleep_for,
2487 in_flight: Arc::new(AtomicUsize::new(0)),
2488 max_concurrency: Arc::new(AtomicUsize::new(0)),
2489 call_order: Arc::new(Mutex::new(Vec::new())),
2490 cancelled: Arc::new(AtomicUsize::new(0)),
2491 }
2492 }
2493 }
2494
2495 #[async_trait]
2496 impl ToolRuntime for ConcurrencyProbeRuntime {
2497 fn specs(&self) -> Vec<crate::tools::ToolSpec> {
2498 vec![crate::tools::ToolSpec {
2499 name: "slow".into(),
2500 description: "sleeps".into(),
2501 input_schema: serde_json::json!({"type": "object"}),
2502 }]
2503 }
2504
2505 async fn invoke(
2506 &self,
2507 invocation: ToolInvocation,
2508 ) -> Result<ToolOutcome, ToolRuntimeError> {
2509 self.call_order.lock().unwrap().push(invocation.id.clone());
2510 let now = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
2511 let mut prev = self.max_concurrency.load(Ordering::SeqCst);
2512 while now > prev {
2513 match self.max_concurrency.compare_exchange(
2514 prev,
2515 now,
2516 Ordering::SeqCst,
2517 Ordering::SeqCst,
2518 ) {
2519 Ok(_) => break,
2520 Err(actual) => prev = actual,
2521 }
2522 }
2523 tokio::time::sleep(self.sleep_for).await;
2524 self.in_flight.fetch_sub(1, Ordering::SeqCst);
2525 Ok(ToolOutcome {
2526 output: Ok(serde_json::json!({"slept": true, "id": invocation.id})),
2527 attachments: vec![],
2528 })
2529 }
2530
2531 async fn invoke_cancellable(
2532 &self,
2533 invocation: ToolInvocation,
2534 cancel: Option<&CancellationToken>,
2535 ) -> Result<ToolOutcome, ToolRuntimeError> {
2536 self.call_order.lock().unwrap().push(invocation.id.clone());
2537 let now = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
2538 let mut prev = self.max_concurrency.load(Ordering::SeqCst);
2539 while now > prev {
2540 match self.max_concurrency.compare_exchange(
2541 prev,
2542 now,
2543 Ordering::SeqCst,
2544 Ordering::SeqCst,
2545 ) {
2546 Ok(_) => break,
2547 Err(actual) => prev = actual,
2548 }
2549 }
2550 if let Some(token) = cancel {
2551 tokio::select! {
2552 _ = token.cancelled() => {
2553 self.cancelled.fetch_add(1, Ordering::SeqCst);
2554 self.in_flight.fetch_sub(1, Ordering::SeqCst);
2555 Err(ToolRuntimeError::Runtime("cancelled".into()))
2556 }
2557 _ = tokio::time::sleep(self.sleep_for) => {
2558 self.in_flight.fetch_sub(1, Ordering::SeqCst);
2559 Ok(ToolOutcome {
2560 output: Ok(serde_json::json!({"slept": true, "id": invocation.id})),
2561 attachments: vec![],
2562 })
2563 }
2564 }
2565 } else {
2566 tokio::time::sleep(self.sleep_for).await;
2567 self.in_flight.fetch_sub(1, Ordering::SeqCst);
2568 Ok(ToolOutcome {
2569 output: Ok(serde_json::json!({"slept": true, "id": invocation.id})),
2570 attachments: vec![],
2571 })
2572 }
2573 }
2574 }
2575
2576 #[tokio::test]
2582 async fn agent_loop_runs_multi_tool_calls_concurrently() {
2583 let model = StreamingFakeClient::new(vec![
2586 vec![
2587 ModelChunk::ToolCallStart {
2588 id: "tc_a".into(),
2589 name: "slow".into(),
2590 },
2591 ModelChunk::ToolCallEnd {
2592 id: "tc_a".into(),
2593 input: Some(json!({})),
2594 },
2595 ModelChunk::ToolCallStart {
2596 id: "tc_b".into(),
2597 name: "slow".into(),
2598 },
2599 ModelChunk::ToolCallEnd {
2600 id: "tc_b".into(),
2601 input: Some(json!({})),
2602 },
2603 ModelChunk::ToolCallStart {
2604 id: "tc_c".into(),
2605 name: "slow".into(),
2606 },
2607 ModelChunk::ToolCallEnd {
2608 id: "tc_c".into(),
2609 input: Some(json!({})),
2610 },
2611 ModelChunk::Done {
2612 stop_reason: "tool_use".into(),
2613 usage: None,
2614 },
2615 ],
2616 vec![
2617 ModelChunk::TextDelta {
2618 msg_id: "remote".into(),
2619 delta: "done".into(),
2620 },
2621 ModelChunk::Done {
2622 stop_reason: "end_turn".into(),
2623 usage: None,
2624 },
2625 ],
2626 ]);
2627 let probe = ConcurrencyProbeRuntime::new(std::time::Duration::from_millis(80));
2628 let max_concurrency = probe.max_concurrency.clone();
2629 let harness = AgentLoopHarness::new(model, probe);
2630
2631 let start = std::time::Instant::now();
2632 let mut rx = harness
2633 .run_turn(NativeTurnInput {
2634 prompt_text: "go".into(),
2635 system_prompt: None,
2636 attachments: vec![],
2637 cancel_token: None,
2638 prior_messages: vec![],
2639 context_path: None,
2640 })
2641 .await
2642 .unwrap();
2643 let mut tool_results = 0;
2644 while let Some(item) = rx.recv().await {
2645 match item.unwrap() {
2646 HarnessInternalEvent::ToolResult { .. } => tool_results += 1,
2647 HarnessInternalEvent::TurnEnd { .. } => break,
2648 _ => {}
2649 }
2650 }
2651 let elapsed = start.elapsed();
2652 assert_eq!(
2654 tool_results, 3,
2655 "expected 3 tool results, got {tool_results}"
2656 );
2657 assert_eq!(
2659 max_concurrency.load(Ordering::SeqCst),
2660 3,
2661 "expected max concurrency 3 (parallel dispatch), got {}",
2662 max_concurrency.load(Ordering::SeqCst)
2663 );
2664 assert!(
2668 elapsed < std::time::Duration::from_millis(200),
2669 "elapsed {elapsed:?} suggests sequential execution"
2670 );
2671 }
2672
2673 #[tokio::test]
2677 async fn agent_loop_cancels_in_flight_tool_calls() {
2678 let model = StreamingFakeClient::new(vec![vec![
2679 ModelChunk::ToolCallStart {
2680 id: "tc_slow".into(),
2681 name: "slow".into(),
2682 },
2683 ModelChunk::ToolCallEnd {
2684 id: "tc_slow".into(),
2685 input: Some(json!({})),
2686 },
2687 ModelChunk::Done {
2688 stop_reason: "tool_use".into(),
2689 usage: None,
2690 },
2691 ]]);
2692 let probe = ConcurrencyProbeRuntime::new(std::time::Duration::from_secs(5));
2695 let cancelled_count = probe.cancelled.clone();
2696 let harness = AgentLoopHarness::new(model, probe);
2697
2698 let cancel = CancellationToken::new();
2699 let cancel_for_input = cancel.clone();
2700 let mut rx = harness
2701 .run_turn(NativeTurnInput {
2702 prompt_text: "go".into(),
2703 system_prompt: None,
2704 attachments: vec![],
2705 cancel_token: Some(cancel_for_input),
2706 prior_messages: vec![],
2707 context_path: None,
2708 })
2709 .await
2710 .unwrap();
2711
2712 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2714 cancel.cancel();
2715
2716 let start = std::time::Instant::now();
2717 let mut saw_interrupt = false;
2718 while let Some(item) = rx.recv().await {
2719 if let HarnessInternalEvent::TurnEnd { stop_reason, .. } = item.unwrap() {
2720 assert_eq!(stop_reason, "interrupt");
2721 saw_interrupt = true;
2722 break;
2723 }
2724 }
2725 let elapsed = start.elapsed();
2726 assert!(saw_interrupt, "must see interrupt TurnEnd");
2727 assert!(
2728 elapsed < std::time::Duration::from_millis(200),
2729 "cancel propagation took too long: {elapsed:?}"
2730 );
2731 let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
2732 while cancelled_count.load(Ordering::SeqCst) == 0 && tokio::time::Instant::now() < deadline
2733 {
2734 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
2735 }
2736 assert_eq!(
2737 cancelled_count.load(Ordering::SeqCst),
2738 1,
2739 "tool runtime must observe the cancellation token"
2740 );
2741 }
2742
2743 #[tokio::test]
2748 async fn agent_loop_cancel_during_tools_yields_clean_history() {
2749 let model = StreamingFakeClient::new(vec![vec![
2750 ModelChunk::ToolCallStart {
2751 id: "tc_a".into(),
2752 name: "slow".into(),
2753 },
2754 ModelChunk::ToolCallEnd {
2755 id: "tc_a".into(),
2756 input: Some(json!({})),
2757 },
2758 ModelChunk::ToolCallStart {
2759 id: "tc_b".into(),
2760 name: "slow".into(),
2761 },
2762 ModelChunk::ToolCallEnd {
2763 id: "tc_b".into(),
2764 input: Some(json!({})),
2765 },
2766 ModelChunk::Done {
2767 stop_reason: "tool_use".into(),
2768 usage: None,
2769 },
2770 ]]);
2771 let probe = ConcurrencyProbeRuntime::new(std::time::Duration::from_secs(3));
2772 let harness = AgentLoopHarness::new(model, probe);
2773
2774 let cancel = CancellationToken::new();
2775 let cancel_for_input = cancel.clone();
2776 let mut rx = harness
2777 .run_turn(NativeTurnInput {
2778 prompt_text: "go".into(),
2779 system_prompt: None,
2780 attachments: vec![],
2781 cancel_token: Some(cancel_for_input),
2782 prior_messages: vec![],
2783 context_path: None,
2784 })
2785 .await
2786 .unwrap();
2787
2788 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
2789 cancel.cancel();
2790
2791 let mut final_msgs = None;
2792 while let Some(item) = rx.recv().await {
2793 if let HarnessInternalEvent::TurnEnd { final_messages, .. } = item.unwrap() {
2794 final_msgs = Some(final_messages);
2795 break;
2796 }
2797 }
2798 let msgs = final_msgs.expect("interrupt TurnEnd");
2799 assert_eq!(msgs.len(), 2, "expected 2 messages, got {msgs:?}");
2802 match &msgs[1] {
2803 ChatMessage::Assistant { tool_calls, .. } => {
2804 assert_eq!(tool_calls.len(), 2);
2805 assert_eq!(tool_calls[0].id, "tc_a");
2806 assert_eq!(tool_calls[1].id, "tc_b");
2807 }
2808 other => panic!("msgs[1] not assistant: {other:?}"),
2809 }
2810 }
2811}