1pub mod streamed;
57
58use std::collections::{BTreeMap, BTreeSet};
59
60use serde::{Deserialize, Serialize};
61
62use crate::{
63 OneOrMany,
64 agent::prompt_request::{
65 CompletionCall, PromptResponse, TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER,
66 assistant_text_from_choice, build_full_history, build_history_for_request,
67 hooks::{InvalidToolCallContext, InvalidToolCallHookAction},
68 invalid_tool_retry_user_message, is_empty_assistant_turn, tool_result_user_content,
69 },
70 completion::{Message, PromptError, Usage},
71 json_utils,
72 message::{AssistantContent, ToolCall, ToolChoice, ToolResult, ToolResultContent, UserContent},
73};
74
75pub use streamed::{
76 PartialStreamedTurn, StreamedInvalidToolCall, StreamedResolution, StreamedTurn,
77 StreamedTurnAssembler, StreamedTurnEvent,
78};
79
/// The next action the driver of an [`AgentRun`] must perform, as returned by
/// [`AgentRun::next_step`].
#[derive(Debug, Clone)]
pub enum AgentRunStep {
    /// Send a completion request to the model and report the outcome via
    /// [`AgentRun::model_response`] (or the streamed entry points).
    CallModel {
        /// The message to send as the prompt: the most recent message recorded
        /// by the run.
        prompt: Message,
        /// Request history: the input history (if any) followed by every
        /// earlier message of this run.
        history: Vec<Message>,
        /// 1-based index of this model call within the run.
        turn: usize,
    },
    /// Handle these tool calls, listed in the order the model emitted them, and
    /// report their results via [`AgentRun::tool_results`].
    CallTools {
        /// The pending calls; entries whose `preresolved_result` is `Some`
        /// already carry result content prepared by the run.
        calls: Vec<PendingToolCall>,
    },
    /// The run finished; carries the final response.
    Done(PromptResponse),
}
106
/// A tool call the driver is asked to handle during a `CallTools` step.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct PendingToolCall {
    /// The tool call as recorded in the run's assistant message (including any
    /// repaired tool name).
    pub tool_call: ToolCall,
    /// Result content prepared by the run when the call was skipped: either the
    /// reason given by a `Skip` resolution, or a placeholder for a peer call of
    /// a skipped one. `None` otherwise — presumably meaning the tool should be
    /// executed normally (confirm against the driver).
    pub preresolved_result: Option<UserContent>,
    /// Internal id supplied by a streamed turn for this call, matched by tool
    /// call id; always `None` for non-streamed turns. Falls back to `None` when
    /// missing from serialized data.
    #[serde(default)]
    pub internal_call_id: Option<String>,
}
124
/// One complete, non-streamed model response, fed to [`AgentRun::model_response`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ModelTurn {
    /// Optional message id, recorded as the `id` of the assistant message.
    pub message_id: Option<String>,
    /// The assistant content returned by the model, in emission order.
    pub choice: OneOrMany<AssistantContent>,
    /// Usage of this completion call; added to the run's accumulated usage.
    pub usage: Usage,
    /// Tool names reported as `available_tools` in invalid-tool-call diagnostics.
    pub executable_tool_names: BTreeSet<String>,
    /// Tool names the model may call; a tool call naming anything outside this
    /// set is invalid and must be resolved before the turn proceeds.
    pub allowed_tool_names: BTreeSet<String>,
}
140
141impl ModelTurn {
142 pub fn new(
145 message_id: Option<String>,
146 choice: OneOrMany<AssistantContent>,
147 usage: Usage,
148 executable_tool_names: BTreeSet<String>,
149 allowed_tool_names: BTreeSet<String>,
150 ) -> Self {
151 Self {
152 message_id,
153 choice,
154 usage,
155 executable_tool_names,
156 allowed_tool_names,
157 }
158 }
159}
160
/// Result of feeding a model turn to the run, or of resolving one of its
/// invalid tool calls.
#[derive(Debug)]
pub enum ModelTurnOutcome {
    /// The turn was accepted; call [`AgentRun::next_step`] to continue.
    Continue {
        /// `true` when at least one invalid tool call in the turn was repaired
        /// or skipped (presumably the driver should then suppress its
        /// response hook — confirm against callers).
        response_hook_suppressed: bool,
    },
    /// A tool call names a tool outside the allowed set; answer it with
    /// [`AgentRun::resolve_invalid_tool_call`].
    NeedsResolution(InvalidToolCallContext),
    /// The turn was abandoned for a retry: the assistant turn and a corrective
    /// user message were appended, and the next step is a new `CallModel`.
    TurnRetried,
}
189
/// Bookkeeping while the invalid tool calls of a non-streamed turn are
/// resolved one at a time.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ResolvingState {
    /// Message id of the turn being resolved.
    message_id: Option<String>,
    /// The turn exactly as received; used for diagnostics and retry messages.
    original_choice: OneOrMany<AssistantContent>,
    /// Working copy of the turn; repaired tool calls are renamed in place here.
    items: Vec<AssistantContent>,
    /// Index into `items` of the next item that has not been checked yet.
    next_index: usize,
    /// Copied from the originating `ModelTurn`; used in diagnostics.
    executable_tool_names: BTreeSet<String>,
    /// Copied from the originating `ModelTurn`; defines which calls are valid.
    allowed_tool_names: BTreeSet<String>,
    /// Prepared tool-result content keyed by tool call id for skipped calls.
    skipped: BTreeMap<String, UserContent>,
    /// Set once any invalid call has been repaired or skipped.
    recovered: bool,
    /// Set once any invalid call has been skipped.
    any_skipped: bool,
    /// Whether the turn contains at least one tool call.
    has_tool_calls: bool,
}
208
/// A validated model turn waiting for `next_step` to record it and choose
/// between `CallTools` and `Done`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TurnState {
    /// Message id recorded on the assistant message.
    message_id: Option<String>,
    /// The (possibly repaired) assistant content, in emission order.
    items: Vec<AssistantContent>,
    /// Whether `items` contains at least one tool call.
    has_tool_calls: bool,
    /// Prepared tool-result content keyed by tool call id for calls that must
    /// be reported as pre-resolved.
    skipped: BTreeMap<String, UserContent>,
    /// `(tool call id, internal call id)` pairs supplied by a streamed turn;
    /// empty for non-streamed turns. Falls back to empty when missing from
    /// serialized data.
    #[serde(default)]
    internal_call_ids: Vec<(String, String)>,
}
220
/// Internal state machine of an [`AgentRun`].
#[derive(Debug, Clone, Serialize, Deserialize)]
enum RunState {
    /// `next_step` will issue a `CallModel` step.
    PreparingRequest,
    /// A `CallModel` step is outstanding; waiting for `model_response`,
    /// `streamed_turn`, or a streamed invalid-tool-call resolution.
    AwaitingModel,
    /// A non-streamed turn has an invalid tool call waiting for
    /// `resolve_invalid_tool_call`.
    ResolvingToolCalls(Box<ResolvingState>),
    /// A turn was accepted; `next_step` records it and decides what is next.
    AwaitingAdvance(Box<TurnState>),
    /// A `CallTools` step is outstanding; `next_step` repeats it until
    /// `tool_results` answers every call.
    ExecutingTools(Vec<PendingToolCall>),
    /// The run finished; `next_step` keeps returning this response.
    Done(Box<PromptResponse>),
    /// Terminal failure state. Also used as a temporary placeholder while a
    /// transition takes the previous state out with `std::mem::replace`.
    Failed,
}
242
/// A caller-driven agent loop: the caller repeatedly asks for the next step,
/// performs it (model call or tool execution), and feeds the result back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentRun {
    /// Turn budget; see `AgentRun::max_turns`.
    max_turns: usize,
    /// Total number of `Retry` resolutions allowed for invalid tool calls.
    max_invalid_tool_call_retries: usize,
    /// Tool choice reported in diagnostics; `ToolChoice::None` forbids `Skip`.
    tool_choice: Option<ToolChoice>,
    /// Input history supplied via `with_history`.
    chat_history: Option<Vec<Message>>,
    /// Messages produced by this run, starting with the prompt.
    new_messages: Vec<Message>,
    /// Number of `CallModel` steps issued so far.
    current_turn: usize,
    /// Usage accumulated over all recorded completion calls.
    usage: Usage,
    /// Every completion call recorded so far.
    completion_calls: Vec<CompletionCall>,
    /// Index assigned to the next recorded completion call.
    completion_call_index: usize,
    /// Number of `Retry` resolutions consumed so far.
    invalid_tool_call_retries: usize,
    /// Set when a streamed turn was abandoned (retry/skip) before the next
    /// `CallModel`; lets `record_streamed_completion_call` still run.
    #[serde(default)]
    rollback_pending: bool,
    /// Set once a completion call has been recorded for the current streamed
    /// turn; reset by `next_step` when a new model call is issued.
    #[serde(default)]
    streamed_completion_call_recorded: bool,
    /// Current position in the state machine.
    state: RunState,
}
268
269impl AgentRun {
270 pub fn new(prompt: impl Into<Message>) -> Self {
273 Self {
274 max_turns: 0,
275 max_invalid_tool_call_retries: 0,
276 tool_choice: None,
277 chat_history: None,
278 new_messages: vec![prompt.into()],
279 current_turn: 0,
280 usage: Usage::new(),
281 completion_calls: Vec::new(),
282 completion_call_index: 0,
283 invalid_tool_call_retries: 0,
284 rollback_pending: false,
285 streamed_completion_call_recorded: false,
286 state: RunState::PreparingRequest,
287 }
288 }
289
290 pub fn with_history(mut self, history: Vec<Message>) -> Self {
292 self.chat_history = Some(history);
293 self
294 }
295
296 pub fn max_turns(mut self, max_turns: usize) -> Self {
299 self.max_turns = max_turns;
300 self
301 }
302
303 pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
306 self.max_invalid_tool_call_retries = retries;
307 self
308 }
309
310 pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
314 self.tool_choice = Some(tool_choice);
315 self
316 }
317
318 pub fn usage(&self) -> Usage {
320 self.usage
321 }
322
323 pub fn turn(&self) -> usize {
325 self.current_turn
326 }
327
328 pub fn completion_calls(&self) -> &[CompletionCall] {
330 &self.completion_calls
331 }
332
333 pub fn messages(&self) -> &[Message] {
336 &self.new_messages
337 }
338
339 pub fn full_history(&self) -> Vec<Message> {
341 build_full_history(self.chat_history.as_deref(), self.new_messages.clone())
342 }
343
344 pub fn is_done(&self) -> bool {
346 matches!(self.state, RunState::Done(_))
347 }
348
349 pub fn response(&self) -> Option<&PromptResponse> {
354 match &self.state {
355 RunState::Done(response) => Some(response),
356 _ => None,
357 }
358 }
359
360 pub fn cancel_error(&self, reason: impl Into<String>) -> PromptError {
363 PromptError::prompt_cancelled(self.full_history(), reason)
364 }
365
366 pub fn pending_invalid_tool_call(&self) -> Option<InvalidToolCallContext> {
370 let RunState::ResolvingToolCalls(resolving) = &self.state else {
371 return None;
372 };
373 let AssistantContent::ToolCall(tool_call) = resolving.items.get(resolving.next_index)?
374 else {
375 return None;
376 };
377 if resolving
378 .allowed_tool_names
379 .contains(&tool_call.function.name)
380 {
381 return None;
382 }
383
384 Some(InvalidToolCallContext {
385 tool_name: tool_call.function.name.clone(),
386 tool_call_id: Some(tool_call.id.clone()),
387 internal_call_id: None,
388 args: Some(json_utils::value_to_json_string(
389 &tool_call.function.arguments,
390 )),
391 available_tools: resolving.executable_tool_names.iter().cloned().collect(),
392 allowed_tools: resolving.allowed_tool_names.iter().cloned().collect(),
393 tool_choice: self.tool_choice.clone(),
394 chat_history: self.diagnostic_history(resolving),
395 is_streaming: false,
396 })
397 }
398
399 pub fn next_step(&mut self) -> Result<AgentRunStep, PromptError> {
407 match std::mem::replace(&mut self.state, RunState::Failed) {
408 RunState::PreparingRequest => {
409 let Some((prompt_ref, history_for_turn)) = self.new_messages.split_last() else {
410 return Err(PromptError::prompt_cancelled(
411 self.full_history(),
412 "prompt loop lost its pending prompt",
413 ));
414 };
415 let prompt = prompt_ref.clone();
416
417 if self.current_turn > self.max_turns + 1 {
418 return Err(PromptError::MaxTurnsError {
419 max_turns: self.max_turns,
420 chat_history: self.full_history().into(),
421 prompt: prompt.into(),
422 });
423 }
424
425 let history =
426 build_history_for_request(self.chat_history.as_deref(), history_for_turn);
427 self.current_turn += 1;
428 self.rollback_pending = false;
429 self.streamed_completion_call_recorded = false;
430 self.state = RunState::AwaitingModel;
431 Ok(AgentRunStep::CallModel {
432 prompt,
433 history,
434 turn: self.current_turn,
435 })
436 }
437 RunState::AwaitingAdvance(turn_state) => {
438 let TurnState {
439 message_id,
440 items,
441 has_tool_calls,
442 skipped,
443 mut internal_call_ids,
444 } = *turn_state;
445 let Some(choice) = OneOrMany::from_iter_optional(items.clone()) else {
446 return Err(PromptError::prompt_cancelled(
447 self.full_history(),
448 "model turn lost its assistant content",
449 ));
450 };
451
452 if !is_empty_assistant_turn(&choice) {
453 self.new_messages.push(Message::Assistant {
454 id: message_id,
455 content: choice.clone(),
456 });
457 }
458
459 if has_tool_calls {
460 let calls: Vec<PendingToolCall> = items
461 .iter()
462 .filter_map(|item| match item {
463 AssistantContent::ToolCall(tool_call) => {
464 let internal_call_id = internal_call_ids
468 .iter()
469 .position(|(id, _)| *id == tool_call.id)
470 .map(|index| internal_call_ids.remove(index).1);
471 Some(PendingToolCall {
472 tool_call: tool_call.clone(),
473 preresolved_result: skipped.get(&tool_call.id).cloned(),
474 internal_call_id,
475 })
476 }
477 _ => None,
478 })
479 .collect();
480 self.state = RunState::ExecutingTools(calls.clone());
481 Ok(AgentRunStep::CallTools { calls })
482 } else {
483 let response =
484 PromptResponse::new(assistant_text_from_choice(&choice), self.usage)
485 .with_messages(self.new_messages.clone())
486 .with_completion_calls(self.completion_calls.clone());
487 self.state = RunState::Done(Box::new(response.clone()));
488 Ok(AgentRunStep::Done(response))
489 }
490 }
491 RunState::ExecutingTools(calls) => {
492 let step = AgentRunStep::CallTools {
495 calls: calls.clone(),
496 };
497 self.state = RunState::ExecutingTools(calls);
498 Ok(step)
499 }
500 RunState::Done(response) => {
501 let step = AgentRunStep::Done((*response).clone());
502 self.state = RunState::Done(response);
503 Ok(step)
504 }
505 state @ (RunState::AwaitingModel | RunState::ResolvingToolCalls(_)) => {
506 let reason = match &state {
507 RunState::AwaitingModel => {
508 "next_step called while a model response is pending; feed it via model_response first"
509 }
510 _ => {
511 "next_step called while an invalid tool-call resolution is pending; answer it via resolve_invalid_tool_call first"
512 }
513 };
514 self.state = state;
515 Err(self.protocol_violation(reason))
516 }
517 RunState::Failed => Err(self.protocol_violation(
518 "next_step called after the run already failed or was misdriven",
519 )),
520 }
521 }
522
523 pub fn model_response(&mut self, turn: ModelTurn) -> Result<ModelTurnOutcome, PromptError> {
529 if !matches!(self.state, RunState::AwaitingModel) {
530 return Err(
531 self.protocol_violation("model_response called without a pending CallModel step")
532 );
533 }
534 if self.streamed_completion_call_recorded {
535 return Err(self.protocol_violation(
536 "model_response called after record_streamed_completion_call for the same turn; feed streamed turns via streamed_turn",
537 ));
538 }
539
540 self.completion_calls
541 .push(CompletionCall::new(self.completion_call_index, turn.usage));
542 self.completion_call_index += 1;
543 self.usage += turn.usage;
544
545 let items: Vec<AssistantContent> = turn.choice.iter().cloned().collect();
546 let has_tool_calls = items
547 .iter()
548 .any(|item| matches!(item, AssistantContent::ToolCall(_)));
549
550 self.state = RunState::ResolvingToolCalls(Box::new(ResolvingState {
551 message_id: turn.message_id,
552 original_choice: turn.choice,
553 items,
554 next_index: 0,
555 executable_tool_names: turn.executable_tool_names,
556 allowed_tool_names: turn.allowed_tool_names,
557 skipped: BTreeMap::new(),
558 recovered: false,
559 any_skipped: false,
560 has_tool_calls,
561 }));
562
563 self.advance_resolution()
564 }
565
566 pub fn resolve_invalid_tool_call(
579 &mut self,
580 action: InvalidToolCallHookAction,
581 ) -> Result<ModelTurnOutcome, PromptError> {
582 let mut resolving = match std::mem::replace(&mut self.state, RunState::Failed) {
585 RunState::ResolvingToolCalls(resolving) => resolving,
586 other => {
587 self.state = other;
588 return Err(self.protocol_violation(
589 "resolve_invalid_tool_call called without a pending invalid tool call",
590 ));
591 }
592 };
593 let tool_call = match resolving.items.get(resolving.next_index) {
594 Some(AssistantContent::ToolCall(tool_call))
595 if !resolving
596 .allowed_tool_names
597 .contains(&tool_call.function.name) =>
598 {
599 tool_call.clone()
600 }
601 _ => {
602 self.state = RunState::ResolvingToolCalls(resolving);
603 return Err(self.protocol_violation(
604 "resolve_invalid_tool_call called without a pending invalid tool call",
605 ));
606 }
607 };
608
609 let diagnostic_history = self.diagnostic_history(&resolving);
610 let executable_tool_names: Vec<String> =
611 resolving.executable_tool_names.iter().cloned().collect();
612 let allowed_tool_names: Vec<String> =
613 resolving.allowed_tool_names.iter().cloned().collect();
614
615 match action {
616 InvalidToolCallHookAction::Fail => Err(PromptError::UnknownToolCall {
617 tool_name: tool_call.function.name,
618 available_tools: executable_tool_names,
619 allowed_tools: allowed_tool_names,
620 chat_history: Box::new(diagnostic_history),
621 }),
622 InvalidToolCallHookAction::Retry { feedback } => {
623 if self.invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
624 return Err(PromptError::UnknownToolCall {
625 tool_name: tool_call.function.name,
626 available_tools: executable_tool_names,
627 allowed_tools: allowed_tool_names,
628 chat_history: Box::new(diagnostic_history),
629 });
630 }
631 self.invalid_tool_call_retries += 1;
632
633 self.new_messages.push(Message::Assistant {
634 id: resolving.message_id.clone(),
635 content: resolving.original_choice.clone(),
636 });
637 let Some(user_message) = invalid_tool_retry_user_message(
638 &resolving.original_choice,
639 &tool_call.id,
640 feedback,
641 ) else {
642 return Err(PromptError::prompt_cancelled(
643 diagnostic_history,
644 "invalid tool call retry produced no retry messages",
645 ));
646 };
647 self.new_messages.push(user_message);
648 self.state = RunState::PreparingRequest;
649 Ok(ModelTurnOutcome::TurnRetried)
650 }
651 InvalidToolCallHookAction::Repair { tool_name } => {
652 if !allowed_tool_names.contains(&tool_name) {
653 return Err(PromptError::UnknownToolCall {
654 tool_name,
655 available_tools: executable_tool_names,
656 allowed_tools: allowed_tool_names,
657 chat_history: Box::new(diagnostic_history),
658 });
659 }
660 if let Some(AssistantContent::ToolCall(tool_call)) =
661 resolving.items.get_mut(resolving.next_index)
662 {
663 tool_call.function.name = tool_name;
664 }
665 resolving.recovered = true;
666 self.state = RunState::ResolvingToolCalls(resolving);
667 self.advance_resolution()
668 }
669 InvalidToolCallHookAction::Skip { reason } => {
670 if matches!(self.tool_choice, Some(ToolChoice::None)) {
671 return Err(PromptError::UnknownToolCall {
672 tool_name: tool_call.function.name,
673 available_tools: executable_tool_names,
674 allowed_tools: allowed_tool_names,
675 chat_history: Box::new(diagnostic_history),
676 });
677 }
678 let user_content = if let Some(call_id) = tool_call.call_id.clone() {
679 UserContent::tool_result_with_call_id(
680 tool_call.id.clone(),
681 call_id,
682 OneOrMany::one(reason.into()),
683 )
684 } else {
685 UserContent::tool_result(tool_call.id.clone(), OneOrMany::one(reason.into()))
686 };
687 resolving.skipped.insert(tool_call.id.clone(), user_content);
688 resolving.recovered = true;
689 resolving.any_skipped = true;
690 resolving.next_index += 1;
691 self.state = RunState::ResolvingToolCalls(resolving);
692 self.advance_resolution()
693 }
694 }
695 }
696
697 pub fn tool_results(&mut self, results: Vec<UserContent>) -> Result<(), PromptError> {
705 let RunState::ExecutingTools(pending) = &self.state else {
706 return Err(
707 self.protocol_violation("tool_results called without a pending CallTools step")
708 );
709 };
710 let mut unanswered: Vec<String> = pending
713 .iter()
714 .map(|call| call.tool_call.id.clone())
715 .collect();
716
717 if results.is_empty() {
718 self.state = RunState::Failed;
719 return Err(PromptError::prompt_cancelled(
720 self.full_history(),
721 "tool execution produced no tool results",
722 ));
723 }
724 for result in &results {
725 let UserContent::ToolResult(tool_result) = result else {
726 return Err(self.protocol_violation(
727 "tool_results received content that is not a tool result",
728 ));
729 };
730 let Some(index) = unanswered.iter().position(|id| *id == tool_result.id) else {
731 return Err(self.protocol_violation(&format!(
732 "tool_results received a result for unknown or already-answered tool call id `{}`",
733 tool_result.id
734 )));
735 };
736 unanswered.swap_remove(index);
737 }
738 if !unanswered.is_empty() {
739 return Err(self.protocol_violation(&format!(
740 "tool_results left pending tool call id(s) unanswered: {unanswered:?}"
741 )));
742 }
743
744 let Some(content) = OneOrMany::from_iter_optional(results) else {
746 return Err(
747 self.protocol_violation("internal: tool results vanished during validation")
748 );
749 };
750
751 self.new_messages.push(Message::User { content });
752 self.state = RunState::PreparingRequest;
753 Ok(())
754 }
755
756 fn advance_resolution(&mut self) -> Result<ModelTurnOutcome, PromptError> {
759 let mut resolving = match std::mem::replace(&mut self.state, RunState::Failed) {
760 RunState::ResolvingToolCalls(resolving) => resolving,
761 other => {
762 self.state = other;
763 return Err(self.protocol_violation(
764 "internal: advance_resolution outside of tool-call resolution",
765 ));
766 }
767 };
768 while let Some(item) = resolving.items.get(resolving.next_index) {
769 match item {
770 AssistantContent::ToolCall(tool_call)
771 if !resolving
772 .allowed_tool_names
773 .contains(&tool_call.function.name) =>
774 {
775 break;
776 }
777 _ => resolving.next_index += 1,
778 }
779 }
780
781 if resolving.next_index < resolving.items.len() {
782 self.state = RunState::ResolvingToolCalls(resolving);
783 return match self.pending_invalid_tool_call() {
784 Some(context) => Ok(ModelTurnOutcome::NeedsResolution(context)),
785 None => Err(self.protocol_violation(
786 "internal: pending invalid tool call could not be derived",
787 )),
788 };
789 }
790
791 let ResolvingState {
792 message_id,
793 items,
794 mut skipped,
795 recovered,
796 any_skipped,
797 has_tool_calls,
798 ..
799 } = *resolving;
800
801 if any_skipped {
804 for item in &items {
805 if let AssistantContent::ToolCall(tool_call) = item {
806 skipped.entry(tool_call.id.clone()).or_insert_with(|| {
807 tool_result_user_content(
808 tool_call.id.clone(),
809 tool_call.call_id.clone(),
810 TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
811 )
812 });
813 }
814 }
815 }
816
817 self.state = RunState::AwaitingAdvance(Box::new(TurnState {
818 message_id,
819 items,
820 has_tool_calls,
821 skipped,
822 internal_call_ids: Vec::new(),
823 }));
824 Ok(ModelTurnOutcome::Continue {
825 response_hook_suppressed: recovered,
826 })
827 }
828
829 pub fn record_streamed_completion_call(
843 &mut self,
844 usage: Usage,
845 ) -> Result<CompletionCall, PromptError> {
846 let recordable = matches!(self.state, RunState::AwaitingModel)
847 || (matches!(self.state, RunState::PreparingRequest) && self.rollback_pending);
848 if !recordable {
849 return Err(self.protocol_violation(
850 "record_streamed_completion_call called without a pending or rolled-back CallModel step",
851 ));
852 }
853 if self.streamed_completion_call_recorded {
854 return Err(self.protocol_violation(
855 "record_streamed_completion_call called twice for the same model turn",
856 ));
857 }
858 self.streamed_completion_call_recorded = true;
859
860 let call = CompletionCall::new(self.completion_call_index, usage);
861 self.completion_call_index += 1;
862 self.completion_calls.push(call);
863 self.usage += usage;
864 Ok(call)
865 }
866
867 pub fn streamed_invalid_tool_call_context(
870 &self,
871 partial: &PartialStreamedTurn,
872 invalid: &StreamedInvalidToolCall,
873 ) -> InvalidToolCallContext {
874 InvalidToolCallContext {
875 tool_name: invalid.tool_call.function.name.clone(),
876 tool_call_id: Some(invalid.tool_call.id.clone()),
877 internal_call_id: Some(invalid.internal_call_id.clone()),
878 args: invalid.args.clone(),
879 available_tools: invalid.executable_tool_names.iter().cloned().collect(),
880 allowed_tools: invalid.allowed_tool_names.iter().cloned().collect(),
881 tool_choice: self.tool_choice.clone(),
882 chat_history: self
883 .streamed_diagnostic_history(partial, Some(invalid.tool_call.clone())),
884 is_streaming: true,
885 }
886 }
887
888 pub fn resolve_streamed_invalid_tool_call(
896 &mut self,
897 partial: &PartialStreamedTurn,
898 invalid: &StreamedInvalidToolCall,
899 action: InvalidToolCallHookAction,
900 ) -> Result<StreamedResolution, PromptError> {
901 if !matches!(self.state, RunState::AwaitingModel) {
902 return Err(self.protocol_violation(
903 "resolve_streamed_invalid_tool_call called without a pending CallModel step",
904 ));
905 }
906
907 let diagnostic_history =
908 self.streamed_diagnostic_history(partial, Some(invalid.tool_call.clone()));
909 let executable_tool_names: Vec<String> =
910 invalid.executable_tool_names.iter().cloned().collect();
911 let allowed_tool_names: Vec<String> = invalid.allowed_tool_names.iter().cloned().collect();
912
913 match action {
914 InvalidToolCallHookAction::Fail => {
915 self.state = RunState::Failed;
916 Err(PromptError::UnknownToolCall {
917 tool_name: invalid.tool_call.function.name.clone(),
918 available_tools: executable_tool_names,
919 allowed_tools: allowed_tool_names,
920 chat_history: Box::new(diagnostic_history),
921 })
922 }
923 InvalidToolCallHookAction::Retry { feedback } => {
924 if self.invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
925 self.state = RunState::Failed;
926 return Err(PromptError::UnknownToolCall {
927 tool_name: invalid.tool_call.function.name.clone(),
928 available_tools: executable_tool_names,
929 allowed_tools: allowed_tool_names,
930 chat_history: Box::new(diagnostic_history),
931 });
932 }
933 self.invalid_tool_call_retries += 1;
934
935 let Some((assistant_message, user_message)) =
936 partial.rollback_messages(invalid.tool_call.clone(), feedback)
937 else {
938 self.state = RunState::Failed;
939 return Err(PromptError::prompt_cancelled(
940 diagnostic_history,
941 "invalid tool call retry produced no retry messages",
942 ));
943 };
944 self.new_messages.push(assistant_message);
945 self.new_messages.push(user_message);
946 self.rollback_pending = true;
947 self.state = RunState::PreparingRequest;
948 Ok(StreamedResolution::TurnAbandoned {
949 skipped_tool_result: None,
950 })
951 }
952 InvalidToolCallHookAction::Repair { tool_name } => {
953 if !invalid.allowed_tool_names.contains(&tool_name) {
954 self.state = RunState::Failed;
955 return Err(PromptError::UnknownToolCall {
956 tool_name,
957 available_tools: executable_tool_names,
958 allowed_tools: allowed_tool_names,
959 chat_history: Box::new(diagnostic_history),
960 });
961 }
962 Ok(StreamedResolution::Repaired { tool_name })
963 }
964 InvalidToolCallHookAction::Skip { reason } => {
965 if matches!(self.tool_choice, Some(ToolChoice::None)) {
966 self.state = RunState::Failed;
967 return Err(PromptError::UnknownToolCall {
968 tool_name: invalid.tool_call.function.name.clone(),
969 available_tools: executable_tool_names,
970 allowed_tools: allowed_tool_names,
971 chat_history: Box::new(diagnostic_history),
972 });
973 }
974
975 let skipped_tool_result = ToolResult {
976 id: invalid.tool_call.id.clone(),
977 call_id: invalid.tool_call.call_id.clone(),
978 content: ToolResultContent::from_tool_output(reason.clone()),
979 };
980 let Some((assistant_message, user_message)) =
981 partial.rollback_messages(invalid.tool_call.clone(), reason)
982 else {
983 self.state = RunState::Failed;
984 return Err(PromptError::prompt_cancelled(
985 diagnostic_history,
986 "invalid tool call skip produced no recovery messages",
987 ));
988 };
989 self.new_messages.push(assistant_message);
990 self.new_messages.push(user_message);
991 self.rollback_pending = true;
992 self.state = RunState::PreparingRequest;
993 Ok(StreamedResolution::TurnAbandoned {
994 skipped_tool_result: Some(skipped_tool_result),
995 })
996 }
997 }
998 }
999
1000 pub fn streamed_turn(&mut self, turn: StreamedTurn) -> Result<(), PromptError> {
1007 if !matches!(self.state, RunState::AwaitingModel) {
1008 return Err(
1009 self.protocol_violation("streamed_turn called without a pending CallModel step")
1010 );
1011 }
1012
1013 if !self.streamed_completion_call_recorded {
1017 self.completion_calls.push(CompletionCall::new(
1018 self.completion_call_index,
1019 Usage::new(),
1020 ));
1021 self.completion_call_index += 1;
1022 self.streamed_completion_call_recorded = true;
1023 }
1024
1025 let items: Vec<AssistantContent> = turn.choice.iter().cloned().collect();
1026 let has_tool_calls = items
1027 .iter()
1028 .any(|item| matches!(item, AssistantContent::ToolCall(_)));
1029
1030 for item in &items {
1031 let AssistantContent::ToolCall(tool_call) = item else {
1032 continue;
1033 };
1034 if !turn.allowed_tool_names.contains(&tool_call.function.name) {
1035 let mut diagnostic_messages = self.new_messages.clone();
1036 if !is_empty_assistant_turn(&turn.choice) {
1037 diagnostic_messages.push(Message::Assistant {
1038 id: turn.message_id.clone(),
1039 content: turn.choice.clone(),
1040 });
1041 }
1042 let diagnostic_history =
1043 build_full_history(self.chat_history.as_deref(), diagnostic_messages);
1044 self.state = RunState::Failed;
1045 return Err(PromptError::UnknownToolCall {
1046 tool_name: tool_call.function.name.clone(),
1047 available_tools: turn.executable_tool_names.iter().cloned().collect(),
1048 allowed_tools: turn.allowed_tool_names.iter().cloned().collect(),
1049 chat_history: Box::new(diagnostic_history),
1050 });
1051 }
1052 }
1053
1054 self.state = RunState::AwaitingAdvance(Box::new(TurnState {
1055 message_id: turn.message_id,
1056 items,
1057 has_tool_calls,
1058 skipped: BTreeMap::new(),
1059 internal_call_ids: turn.internal_call_ids,
1060 }));
1061 Ok(())
1062 }
1063
1064 fn streamed_diagnostic_history(
1067 &self,
1068 partial: &PartialStreamedTurn,
1069 current_tool_call: Option<ToolCall>,
1070 ) -> Vec<Message> {
1071 let mut messages = self.new_messages.clone();
1072 if let Some(assistant) = partial.assistant_message(current_tool_call) {
1073 messages.push(assistant);
1074 }
1075 build_full_history(self.chat_history.as_deref(), messages)
1076 }
1077
1078 fn diagnostic_history(&self, resolving: &ResolvingState) -> Vec<Message> {
1081 let mut diagnostic_messages = self.new_messages.clone();
1082 diagnostic_messages.push(Message::Assistant {
1083 id: resolving.message_id.clone(),
1084 content: resolving.original_choice.clone(),
1085 });
1086 build_full_history(self.chat_history.as_deref(), diagnostic_messages)
1087 }
1088
1089 fn protocol_violation(&self, reason: &str) -> PromptError {
1090 PromptError::prompt_cancelled(
1091 self.full_history(),
1092 format!("agent run driver protocol violation: {reason}"),
1093 )
1094 }
1095}
1096
1097#[cfg(test)]
1098mod tests {
1099 use super::*;
1100 use crate::message::{ToolFunction, ToolResultContent};
1101 use serde_json::json;
1102
1103 fn tool_names(names: &[&str]) -> BTreeSet<String> {
1104 names.iter().map(|name| (*name).to_string()).collect()
1105 }
1106
1107 fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
1108 Usage {
1109 input_tokens,
1110 output_tokens,
1111 total_tokens: input_tokens + output_tokens,
1112 ..Usage::new()
1113 }
1114 }
1115
1116 fn text_turn(text: &str) -> ModelTurn {
1117 ModelTurn::new(
1118 None,
1119 OneOrMany::one(AssistantContent::text(text)),
1120 Usage::new(),
1121 tool_names(&["add"]),
1122 tool_names(&["add"]),
1123 )
1124 }
1125
1126 fn tool_call(id: &str, name: &str) -> AssistantContent {
1127 AssistantContent::ToolCall(ToolCall::new(
1128 id.to_string(),
1129 ToolFunction::new(name.to_string(), json!({"x": 1})),
1130 ))
1131 }
1132
1133 fn tool_call_turn(id: &str, name: &str) -> ModelTurn {
1134 ModelTurn::new(
1135 None,
1136 OneOrMany::one(tool_call(id, name)),
1137 Usage::new(),
1138 tool_names(&["add"]),
1139 tool_names(&["add"]),
1140 )
1141 }
1142
1143 fn tool_result(id: &str, output: &str) -> UserContent {
1144 UserContent::tool_result(
1145 id.to_string(),
1146 ToolResultContent::from_tool_output(output.to_string()),
1147 )
1148 }
1149
1150 fn expect_call_model(run: &mut AgentRun) -> (Message, Vec<Message>, usize) {
1151 match run.next_step().expect("next_step should succeed") {
1152 AgentRunStep::CallModel {
1153 prompt,
1154 history,
1155 turn,
1156 } => (prompt, history, turn),
1157 step => panic!("expected CallModel, got {step:?}"),
1158 }
1159 }
1160
1161 fn expect_call_tools(run: &mut AgentRun) -> Vec<PendingToolCall> {
1162 match run.next_step().expect("next_step should succeed") {
1163 AgentRunStep::CallTools { calls } => calls,
1164 step => panic!("expected CallTools, got {step:?}"),
1165 }
1166 }
1167
1168 fn expect_done(run: &mut AgentRun) -> PromptResponse {
1169 match run.next_step().expect("next_step should succeed") {
1170 AgentRunStep::Done(response) => response,
1171 step => panic!("expected Done, got {step:?}"),
1172 }
1173 }
1174
1175 fn expect_continue(outcome: ModelTurnOutcome) -> bool {
1176 match outcome {
1177 ModelTurnOutcome::Continue {
1178 response_hook_suppressed,
1179 } => response_hook_suppressed,
1180 outcome => panic!("expected Continue, got {outcome:?}"),
1181 }
1182 }
1183
1184 fn expect_needs_resolution(outcome: ModelTurnOutcome) -> InvalidToolCallContext {
1185 match outcome {
1186 ModelTurnOutcome::NeedsResolution(context) => context,
1187 outcome => panic!("expected NeedsResolution, got {outcome:?}"),
1188 }
1189 }
1190
1191 #[test]
1192 fn text_only_run_completes_in_one_turn() {
1193 let mut run = AgentRun::new("hello");
1194
1195 let (prompt, history, turn) = expect_call_model(&mut run);
1196 assert_eq!(prompt, Message::user("hello"));
1197 assert!(history.is_empty());
1198 assert_eq!(turn, 1);
1199
1200 let suppressed = expect_continue(
1201 run.model_response(text_turn("hi there"))
1202 .expect("model_response should succeed"),
1203 );
1204 assert!(!suppressed);
1205
1206 let response = expect_done(&mut run);
1207 assert_eq!(response.output, "hi there");
1208 let messages = response.messages.expect("messages should be recorded");
1209 assert_eq!(messages.len(), 2);
1210 assert!(run.is_done());
1211 }
1212
1213 #[test]
1214 fn input_history_prefixes_request_history() {
1215 let mut run = AgentRun::new("question")
1216 .with_history(vec![Message::user("earlier"), Message::assistant("reply")]);
1217
1218 let (_, history, _) = expect_call_model(&mut run);
1219 assert_eq!(
1220 history,
1221 vec![Message::user("earlier"), Message::assistant("reply")]
1222 );
1223
1224 expect_continue(
1225 run.model_response(text_turn("answer"))
1226 .expect("model_response should succeed"),
1227 );
1228 let response = expect_done(&mut run);
1229 assert_eq!(
1231 response
1232 .messages
1233 .expect("messages should be recorded")
1234 .len(),
1235 2
1236 );
1237 }
1238
1239 #[test]
1240 fn tool_roundtrip_threads_history_and_usage() {
1241 let mut run = AgentRun::new("add things").max_turns(2);
1242
1243 expect_call_model(&mut run);
1244 expect_continue(
1245 run.model_response(tool_call_turn("call_1", "add").with_usage_for_test(usage(10, 5)))
1246 .expect("model_response should succeed"),
1247 );
1248
1249 let calls = expect_call_tools(&mut run);
1250 assert_eq!(calls.len(), 1);
1251 assert_eq!(calls[0].tool_call.function.name, "add");
1252 assert!(calls[0].preresolved_result.is_none());
1253
1254 run.tool_results(vec![tool_result("call_1", "2")])
1255 .expect("tool_results should succeed");
1256
1257 let (prompt, history, turn) = expect_call_model(&mut run);
1258 assert_eq!(turn, 2);
1259 assert!(matches!(prompt, Message::User { .. }));
1262 assert_eq!(history.len(), 2);
1263
1264 expect_continue(
1265 run.model_response(text_turn("the answer is 2").with_usage_for_test(usage(20, 7)))
1266 .expect("model_response should succeed"),
1267 );
1268
1269 let response = expect_done(&mut run);
1270 assert_eq!(response.output, "the answer is 2");
1271 assert_eq!(response.usage, usage(30, 12));
1272 assert_eq!(response.completion_calls.len(), 2);
1273 assert_eq!(response.completion_calls[0].call_index, 0);
1274 assert_eq!(response.completion_calls[0].usage, usage(10, 5));
1275 assert_eq!(response.completion_calls[1].usage, usage(20, 7));
1276 assert_eq!(
1278 response
1279 .messages
1280 .expect("messages should be recorded")
1281 .len(),
1282 4
1283 );
1284 }
1285
1286 #[test]
1287 fn parallel_tool_calls_surface_in_emission_order() {
1288 let mut run = AgentRun::new("do both").max_turns(2);
1289
1290 expect_call_model(&mut run);
1291 let turn = ModelTurn::new(
1292 None,
1293 OneOrMany::many(vec![tool_call("call_1", "add"), tool_call("call_2", "add")])
1294 .expect("two items"),
1295 Usage::new(),
1296 tool_names(&["add"]),
1297 tool_names(&["add"]),
1298 );
1299 expect_continue(
1300 run.model_response(turn)
1301 .expect("model_response should succeed"),
1302 );
1303
1304 let calls = expect_call_tools(&mut run);
1305 assert_eq!(calls.len(), 2);
1306 assert_eq!(calls[0].tool_call.id, "call_1");
1307 assert_eq!(calls[1].tool_call.id, "call_2");
1308
1309 run.tool_results(vec![tool_result("call_2", "b"), tool_result("call_1", "a")])
1311 .expect("tool_results should succeed");
1312 let messages = run.messages();
1313 assert!(matches!(
1314 messages.last(),
1315 Some(Message::User { content }) if content.len() == 2
1316 ));
1317 }
1318
1319 #[test]
1320 fn max_turns_exhaustion_returns_max_turns_error() {
1321 let mut run = AgentRun::new("loop forever");
1322
1323 for turn_id in ["call_1", "call_2"] {
1324 expect_call_model(&mut run);
1325 expect_continue(
1326 run.model_response(tool_call_turn(turn_id, "add"))
1327 .expect("model_response should succeed"),
1328 );
1329 expect_call_tools(&mut run);
1330 run.tool_results(vec![tool_result(turn_id, "0")])
1331 .expect("tool_results should succeed");
1332 }
1333
1334 let err = run.next_step().expect_err("depth should be exhausted");
1335 assert!(matches!(
1336 err,
1337 PromptError::MaxTurnsError { max_turns: 0, .. }
1338 ));
1339 }
1340
1341 #[test]
1342 fn invalid_tool_call_fail_returns_unknown_tool_call() {
1343 let mut run = AgentRun::new("call something");
1344
1345 expect_call_model(&mut run);
1346 let context = expect_needs_resolution(
1347 run.model_response(tool_call_turn("call_1", "unknown"))
1348 .expect("model_response should succeed"),
1349 );
1350 assert_eq!(context.tool_name, "unknown");
1351 assert_eq!(context.available_tools, vec!["add".to_string()]);
1352 assert!(!context.is_streaming);
1353 assert_eq!(context.chat_history.len(), 2);
1355
1356 let err = run
1357 .resolve_invalid_tool_call(InvalidToolCallHookAction::fail())
1358 .expect_err("fail action should error");
1359 assert!(matches!(
1360 err,
1361 PromptError::UnknownToolCall { tool_name, .. } if tool_name == "unknown"
1362 ));
1363 }
1364
    /// A `retry` resolution is accepted as `ModelTurnOutcome::TurnRetried` and leads to a new
    /// model call. Once the configured retry budget (1 here) is used up, a further `retry`
    /// fails with `UnknownToolCall`.
    #[test]
    fn invalid_tool_call_retry_rolls_back_with_feedback() {
        let mut run = AgentRun::new("call something")
            .max_turns(2)
            .max_invalid_tool_call_retries(1);

        expect_call_model(&mut run);
        expect_needs_resolution(
            run.model_response(tool_call_turn("call_1", "unknown"))
                .expect("model_response should succeed"),
        );
        let outcome = run
            .resolve_invalid_tool_call(InvalidToolCallHookAction::retry("use add instead"))
            .expect("retry should be accepted");
        assert!(matches!(outcome, ModelTurnOutcome::TurnRetried));

        // The run's message log holds exactly three entries after the retry.
        assert_eq!(run.messages().len(), 3);
        // The follow-up model call is turn 2. Its prompt is a user message whose first item is
        // a tool result (presumably carrying the retry feedback; NOTE(review): confirm).
        let (prompt, _, turn) = expect_call_model(&mut run);
        assert_eq!(turn, 2);
        assert!(matches!(
            prompt,
            Message::User { ref content }
                if matches!(content.first(), UserContent::ToolResult(_))
        ));

        // A second invalid call finds the single allowed retry already spent.
        expect_needs_resolution(
            run.model_response(tool_call_turn("call_2", "unknown"))
                .expect("model_response should succeed"),
        );
        let err = run
            .resolve_invalid_tool_call(InvalidToolCallHookAction::retry("again"))
            .expect_err("budget exhausted");
        assert!(matches!(err, PromptError::UnknownToolCall { .. }));
    }
1401
1402 #[test]
1403 fn invalid_tool_call_repair_renames_and_suppresses_response_hook() {
1404 let mut run = AgentRun::new("call something").max_turns(2);
1405
1406 expect_call_model(&mut run);
1407 expect_needs_resolution(
1408 run.model_response(tool_call_turn("call_1", "default_api"))
1409 .expect("model_response should succeed"),
1410 );
1411 let suppressed = expect_continue(
1412 run.resolve_invalid_tool_call(InvalidToolCallHookAction::repair("add"))
1413 .expect("repair should be accepted"),
1414 );
1415 assert!(suppressed);
1416
1417 let calls = expect_call_tools(&mut run);
1418 assert_eq!(calls[0].tool_call.function.name, "add");
1419 assert!(calls[0].preresolved_result.is_none());
1420 }
1421
1422 #[test]
1423 fn invalid_tool_call_repair_to_disallowed_name_fails() {
1424 let mut run = AgentRun::new("call something");
1425
1426 expect_call_model(&mut run);
1427 expect_needs_resolution(
1428 run.model_response(tool_call_turn("call_1", "unknown"))
1429 .expect("model_response should succeed"),
1430 );
1431 let err = run
1432 .resolve_invalid_tool_call(InvalidToolCallHookAction::repair("also_unknown"))
1433 .expect_err("repair to disallowed name should fail");
1434 assert!(matches!(
1435 err,
1436 PromptError::UnknownToolCall { tool_name, .. } if tool_name == "also_unknown"
1437 ));
1438 }
1439
1440 #[test]
1441 fn invalid_tool_call_skip_suppresses_all_peer_executions() {
1442 let mut run = AgentRun::new("call things").max_turns(2);
1443
1444 expect_call_model(&mut run);
1445 let turn = ModelTurn::new(
1446 None,
1447 OneOrMany::many(vec![
1448 tool_call("call_1", "unknown"),
1449 tool_call("call_2", "add"),
1450 ])
1451 .expect("two items"),
1452 Usage::new(),
1453 tool_names(&["add"]),
1454 tool_names(&["add"]),
1455 );
1456 expect_needs_resolution(
1457 run.model_response(turn)
1458 .expect("model_response should succeed"),
1459 );
1460 let suppressed = expect_continue(
1461 run.resolve_invalid_tool_call(InvalidToolCallHookAction::skip("not available"))
1462 .expect("skip should be accepted"),
1463 );
1464 assert!(suppressed);
1465
1466 let calls = expect_call_tools(&mut run);
1467 assert_eq!(calls.len(), 2);
1468 assert!(calls.iter().all(|call| call.preresolved_result.is_some()));
1470 }
1471
1472 #[test]
1473 fn skip_under_tool_choice_none_fails() {
1474 let mut run = AgentRun::new("call something").with_tool_choice(ToolChoice::None);
1475
1476 expect_call_model(&mut run);
1477 expect_needs_resolution(
1478 run.model_response(ModelTurn::new(
1479 None,
1480 OneOrMany::one(tool_call("call_1", "add")),
1481 Usage::new(),
1482 tool_names(&["add"]),
1483 BTreeSet::new(),
1484 ))
1485 .expect("model_response should succeed"),
1486 );
1487 let err = run
1488 .resolve_invalid_tool_call(InvalidToolCallHookAction::skip("nope"))
1489 .expect_err("skip under ToolChoice::None should fail");
1490 assert!(matches!(err, PromptError::UnknownToolCall { .. }));
1491 }
1492
1493 #[test]
1494 fn empty_tool_results_cancel_the_run() {
1495 let mut run = AgentRun::new("call something").max_turns(2);
1496
1497 expect_call_model(&mut run);
1498 expect_continue(
1499 run.model_response(tool_call_turn("call_1", "add"))
1500 .expect("model_response should succeed"),
1501 );
1502 expect_call_tools(&mut run);
1503
1504 let err = run
1505 .tool_results(Vec::new())
1506 .expect_err("empty results should cancel");
1507 assert!(matches!(
1508 err,
1509 PromptError::PromptCancelled { reason, .. }
1510 if reason.contains("tool execution produced no tool results")
1511 ));
1512 }
1513
    /// Calls that don't match the run's current step are rejected with `PromptCancelled`.
    /// The run stays usable afterwards: the same turn still completes normally.
    #[test]
    fn out_of_protocol_calls_are_rejected_without_corrupting_state() {
        let mut run = AgentRun::new("hello");

        // Submitting tool results before any tool calls were requested is rejected.
        let err = run
            .tool_results(vec![tool_result("call_1", "x")])
            .expect_err("no CallTools pending");
        assert!(matches!(err, PromptError::PromptCancelled { .. }));

        expect_call_model(&mut run);
        // While a model response is awaited, asking for the next step is rejected.
        let err = run
            .next_step()
            .expect_err("model response is pending, next_step must be rejected");
        assert!(matches!(err, PromptError::PromptCancelled { .. }));
        // Neither rejection disturbed the run, so the pending turn finishes as usual.
        expect_continue(
            run.model_response(text_turn("hi"))
                .expect("model_response should still succeed"),
        );
        assert_eq!(expect_done(&mut run).output, "hi");
    }
1535
1536 #[test]
1537 fn model_response_rejected_after_streamed_completion_call_record() {
1538 let mut run = AgentRun::new("hello");
1539 expect_call_model(&mut run);
1540 run.record_streamed_completion_call(Usage::new())
1541 .expect("record should succeed");
1542
1543 let err = run
1544 .model_response(text_turn("hi"))
1545 .expect_err("mixed streamed/non-streamed ingestion must be rejected");
1546 assert!(matches!(err, PromptError::PromptCancelled { .. }));
1547 assert_eq!(run.completion_calls().len(), 1);
1549 }
1550
1551 #[test]
1552 fn done_step_is_idempotent() {
1553 let mut run = AgentRun::new("hello");
1554 expect_call_model(&mut run);
1555 expect_continue(
1556 run.model_response(text_turn("hi"))
1557 .expect("model_response should succeed"),
1558 );
1559 assert_eq!(expect_done(&mut run).output, "hi");
1560 assert_eq!(expect_done(&mut run).output, "hi");
1561 }
1562
1563 #[test]
1564 fn serialized_run_alone_carries_pending_tool_calls() {
1565 let mut run = AgentRun::new("add things").max_turns(2);
1566 expect_call_model(&mut run);
1567 expect_continue(
1568 run.model_response(tool_call_turn("call_1", "add"))
1569 .expect("model_response should succeed"),
1570 );
1571 expect_call_tools(&mut run);
1572
1573 let serialized = serde_json::to_string(&run).expect("mid-run state should serialize");
1576 drop(run);
1577 let mut resumed: AgentRun =
1578 serde_json::from_str(&serialized).expect("mid-run state should deserialize");
1579
1580 let calls = expect_call_tools(&mut resumed);
1581 assert_eq!(calls.len(), 1);
1582 assert_eq!(calls[0].tool_call.function.name, "add");
1583 let calls_again = expect_call_tools(&mut resumed);
1585 assert_eq!(calls_again[0].tool_call.id, calls[0].tool_call.id);
1586
1587 let results = calls
1589 .iter()
1590 .map(|call| tool_result(&call.tool_call.id, "2"))
1591 .collect::<Vec<_>>();
1592 resumed
1593 .tool_results(results)
1594 .expect("tool_results should succeed");
1595 expect_call_model(&mut resumed);
1596 expect_continue(
1597 resumed
1598 .model_response(text_turn("done"))
1599 .expect("model_response should succeed"),
1600 );
1601 assert_eq!(expect_done(&mut resumed).output, "done");
1602 }
1603
1604 #[test]
1605 fn tool_results_validates_against_pending_calls() {
1606 let drive_to_pending_tools = || {
1607 let mut run = AgentRun::new("add things").max_turns(2);
1608 expect_call_model(&mut run);
1609 expect_continue(
1610 run.model_response(tool_call_turn("call_1", "add"))
1611 .expect("model_response should succeed"),
1612 );
1613 expect_call_tools(&mut run);
1614 run
1615 };
1616
1617 let mut run = drive_to_pending_tools();
1619 let err = run
1620 .tool_results(vec![tool_result("call_unknown", "2")])
1621 .expect_err("unknown tool call id must be rejected");
1622 assert!(matches!(err, PromptError::PromptCancelled { .. }));
1623 run.tool_results(vec![tool_result("call_1", "2")])
1624 .expect("valid results should still be accepted after a rejection");
1625
1626 let mut run = drive_to_pending_tools();
1628 let err = run
1629 .tool_results(vec![tool_result("call_1", "2"), tool_result("call_1", "3")])
1630 .expect_err("answering one call twice must be rejected");
1631 assert!(matches!(err, PromptError::PromptCancelled { .. }));
1632
1633 let mut run = drive_to_pending_tools();
1635 let err = run
1636 .tool_results(vec![UserContent::text("not a tool result")])
1637 .expect_err("non-tool-result content must be rejected");
1638 assert!(matches!(err, PromptError::PromptCancelled { .. }));
1639 }
1640
    /// A suspended run serialized in the older ("pre-monoid") format still deserializes and
    /// resumes. In that format a completion-call record may have `"usage": null`, which must
    /// read back as `Usage::new()`.
    #[test]
    fn agent_run_deserializes_pre_monoid_suspended_state() {
        // Legacy snapshot parked in `ExecutingTools` with one pending `add` call (`call_1`).
        // Keep byte-for-byte: it pins the old wire format.
        let fixture = r#"{"max_turns":2,"max_invalid_tool_call_retries":0,"tool_choice":null,"chat_history":null,"new_messages":[{"role":"user","content":[{"type":"text","text":"add things"}]},{"role":"assistant","id":null,"content":[{"id":"call_1","call_id":null,"function":{"name":"add","arguments":{"x":1}},"signature":null,"additional_params":null}]}],"current_turn":1,"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15,"cached_input_tokens":0,"cache_creation_input_tokens":0,"tool_use_prompt_tokens":0,"reasoning_tokens":0},"completion_calls":[{"call_index":0,"usage":null}],"completion_call_index":1,"invalid_tool_call_retries":0,"rollback_pending":false,"streamed_completion_call_recorded":false,"state":{"ExecutingTools":[{"tool_call":{"id":"call_1","call_id":null,"function":{"name":"add","arguments":{"x":1}},"signature":null,"additional_params":null},"preresolved_result":null,"internal_call_id":null}]}}"#;

        let mut restored: AgentRun =
            serde_json::from_str(fixture).expect("old-format suspended run should deserialize");
        // The `null` usage in the legacy record is read back as an empty `Usage`.
        assert_eq!(restored.completion_calls()[0].usage, Usage::new());

        // The restored run continues the protocol from the pending tool call.
        let calls = expect_call_tools(&mut restored);
        assert_eq!(calls.len(), 1);
        restored
            .tool_results(vec![tool_result("call_1", "2")])
            .expect("tool_results should succeed");
        expect_call_model(&mut restored);
    }
1659
    /// Suspending a run while tools are pending, round-tripping it through JSON, and then
    /// finishing it must give the same output, usage, completion-call records, and messages
    /// as running it straight through.
    #[test]
    fn serde_round_trip_mid_run_resumes_identically() {
        // Builds a run parked on one pending `add` call, with `usage(10, 5)` on the first
        // model call.
        let drive_to_pending_tools = || {
            let mut run = AgentRun::new("add things").max_turns(2);
            expect_call_model(&mut run);
            expect_continue(
                run.model_response(
                    tool_call_turn("call_1", "add").with_usage_for_test(usage(10, 5)),
                )
                .expect("model_response should succeed"),
            );
            expect_call_tools(&mut run);
            run
        };

        // Answers the pending call, feeds a final text turn, and returns the `Done` response.
        let finish = |mut run: AgentRun| {
            run.tool_results(vec![tool_result("call_1", "2")])
                .expect("tool_results should succeed");
            expect_call_model(&mut run);
            expect_continue(
                run.model_response(text_turn("done").with_usage_for_test(usage(3, 4)))
                    .expect("model_response should succeed"),
            );
            expect_done(&mut run)
        };

        // Baseline: no suspension.
        let uninterrupted = finish(drive_to_pending_tools());

        // Same run, serialized and restored at the pending-tools point.
        let suspended = drive_to_pending_tools();
        let serialized = serde_json::to_string(&suspended).expect("mid-run state should serialize");
        let restored: AgentRun =
            serde_json::from_str(&serialized).expect("mid-run state should deserialize");
        let resumed = finish(restored);

        assert_eq!(resumed.output, uninterrupted.output);
        assert_eq!(resumed.usage, uninterrupted.usage);
        assert_eq!(resumed.completion_calls, uninterrupted.completion_calls);
        // Messages are compared through their serialized JSON values.
        assert_eq!(
            serde_json::to_value(&resumed.messages).expect("messages should serialize"),
            serde_json::to_value(&uninterrupted.messages).expect("messages should serialize"),
        );
    }
1705
1706 #[test]
1707 fn pending_invalid_tool_call_survives_serde_round_trip() {
1708 let mut run = AgentRun::new("call something");
1709 expect_call_model(&mut run);
1710 let context = expect_needs_resolution(
1711 run.model_response(tool_call_turn("call_1", "unknown"))
1712 .expect("model_response should succeed"),
1713 );
1714
1715 let serialized = serde_json::to_string(&run).expect("state should serialize");
1716 let restored: AgentRun =
1717 serde_json::from_str(&serialized).expect("state should deserialize");
1718 let restored_context = restored
1719 .pending_invalid_tool_call()
1720 .expect("pending resolution should survive serialization");
1721 assert_eq!(restored_context.tool_name, context.tool_name);
1722 assert_eq!(
1723 restored_context.chat_history.len(),
1724 context.chat_history.len()
1725 );
1726 }
1727
1728 impl ModelTurn {
1729 fn with_usage_for_test(mut self, usage: Usage) -> Self {
1730 self.usage = usage;
1731 self
1732 }
1733 }
1734}