1use std::collections::{BTreeSet, HashMap};
37
38use serde::{Deserialize, Serialize};
39
40use crate::{
41 OneOrMany,
42 agent::prompt_request::{TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER, tool_result_user_content},
43 completion::{CompletionError, GetTokenUsage, Message, Usage},
44 json_utils,
45 message::{AssistantContent, Reasoning, ToolCall, ToolFunction, ToolResult},
46 streaming::{StreamedAssistantContent, ToolCallDeltaContent},
47};
48
49pub(crate) fn merge_reasoning_blocks(
52 accumulated_reasoning: &mut Vec<Reasoning>,
53 incoming: &Reasoning,
54) {
55 let ids_match = |existing: &Reasoning| {
56 matches!(
57 (&existing.id, &incoming.id),
58 (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
59 )
60 };
61
62 if let Some(existing) = accumulated_reasoning
63 .iter_mut()
64 .rev()
65 .find(|existing| ids_match(existing))
66 {
67 existing.content.extend(incoming.content.clone());
68 } else {
69 accumulated_reasoning.push(incoming.clone());
70 }
71}
72
73pub(crate) fn ordered_streaming_assistant_content(
76 reasoning_items: impl IntoIterator<Item = Reasoning>,
77 text_items: impl IntoIterator<Item = AssistantContent>,
78 trailing_items: impl IntoIterator<Item = AssistantContent>,
79) -> Option<OneOrMany<AssistantContent>> {
80 let mut content_items = reasoning_items
81 .into_iter()
82 .map(AssistantContent::Reasoning)
83 .collect::<Vec<_>>();
84 content_items.extend(text_items);
85 content_items.extend(trailing_items);
86
87 OneOrMany::from_iter_optional(content_items)
88}
89
90pub(crate) fn assistant_text_items_from_choice(
91 choice: &OneOrMany<AssistantContent>,
92) -> Vec<AssistantContent> {
93 choice
94 .iter()
95 .filter_map(|content| match content {
96 AssistantContent::Text(text) => (!text.text.is_empty()
97 || text.additional_params.is_some())
98 .then(|| AssistantContent::Text(text.clone())),
99 _ => None,
100 })
101 .collect()
102}
103
/// Description of a streamed tool call whose name is not in the assembler's
/// allowed tool set.
///
/// Built by `StreamedTurnAssembler::ingest` and delivered inside
/// `StreamedTurnEvent::InvalidToolCall`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StreamedInvalidToolCall {
    /// The offending call. For a complete tool call this is a clone of the
    /// streamed call; for a name delta it is a diagnostic call built from the
    /// id, the rejected name and whatever arguments were buffered so far.
    pub tool_call: ToolCall,
    /// Internal call id copied from the streamed item.
    pub internal_call_id: String,
    /// Arguments as a string: the JSON rendering of a complete call's
    /// arguments, or the concatenation of the buffered argument deltas.
    pub args: Option<String>,
    /// Copy of the assembler's executable tool names when the call was rejected.
    pub executable_tool_names: BTreeSet<String>,
    /// Copy of the assembler's allowed tool names when the call was rejected.
    pub allowed_tool_names: BTreeSet<String>,
}
121
/// Snapshot of a streamed turn that has not finished, as produced by
/// `StreamedTurnAssembler::partial_turn`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct PartialStreamedTurn {
    /// Message id supplied by the caller of `partial_turn`, if any.
    pub message_id: Option<String>,
    /// Text aggregated so far; `None` when no text item has been seen.
    pub text: Option<String>,
    /// Reasoning blocks collected so far.
    pub reasoning: Vec<Reasoning>,
    /// Tool calls the assembler had already accepted when the snapshot was taken.
    pub pending_tool_calls: Vec<ToolCall>,
}
138
139impl PartialStreamedTurn {
140 pub(crate) fn assistant_message(&self, current_tool_call: Option<ToolCall>) -> Option<Message> {
144 let text_items = match &self.text {
145 Some(text) if !text.is_empty() => vec![AssistantContent::text(text.clone())],
146 _ => Vec::new(),
147 };
148 let mut tool_items = self
149 .pending_tool_calls
150 .iter()
151 .cloned()
152 .map(AssistantContent::ToolCall)
153 .collect::<Vec<_>>();
154 if let Some(tool_call) = current_tool_call {
155 tool_items.push(AssistantContent::ToolCall(tool_call));
156 }
157
158 let content = ordered_streaming_assistant_content(
159 self.reasoning.iter().cloned(),
160 text_items,
161 tool_items,
162 )?;
163 Some(Message::Assistant {
164 id: self.message_id.clone(),
165 content,
166 })
167 }
168
169 pub(crate) fn rollback_messages(
173 &self,
174 invalid_tool_call: ToolCall,
175 feedback: String,
176 ) -> Option<(Message, Message)> {
177 let assistant_message = self.assistant_message(Some(invalid_tool_call.clone()))?;
178
179 let mut retry_results = self
180 .pending_tool_calls
181 .iter()
182 .map(|tool_call| {
183 tool_result_user_content(
184 tool_call.id.clone(),
185 tool_call.call_id.clone(),
186 TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
187 )
188 })
189 .collect::<Vec<_>>();
190 retry_results.push(tool_result_user_content(
191 invalid_tool_call.id,
192 invalid_tool_call.call_id,
193 feedback,
194 ));
195
196 let user_message = Message::User {
197 content: OneOrMany::from_iter_optional(retry_results)?,
198 };
199
200 Some((assistant_message, user_message))
201 }
202}
203
/// A fully assembled streamed turn, as returned by
/// `StreamedTurnAssembler::finish`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StreamedTurn {
    /// Message id supplied by the caller of `finish`, if any.
    pub message_id: Option<String>,
    /// Assistant content for the turn: either re-assembled from the streamed
    /// reasoning and tool calls, or the caller's final choice unchanged.
    pub choice: OneOrMany<AssistantContent>,
    /// Executable tool names the assembler was created with.
    pub executable_tool_names: BTreeSet<String>,
    /// Allowed tool names the assembler was created with.
    pub allowed_tool_names: BTreeSet<String>,
    /// `(tool call id, internal call id)` pairs, one per accepted tool call in
    /// the order the calls were accepted. Defaults to empty when the field is
    /// absent during deserialization.
    #[serde(default)]
    pub internal_call_ids: Vec<(String, String)>,
}
225
/// Outcome chosen for an invalid streamed tool call; passed to
/// `StreamedTurnAssembler::resolve_pending_invalid`.
#[derive(Debug)]
pub enum StreamedResolution {
    /// The call proceeds under a replacement tool name.
    Repaired {
        /// Name that replaces the rejected one.
        tool_name: String,
    },
    /// The invalid call is dropped without being repaired.
    TurnAbandoned {
        /// Not read by the assembler. In this file's tests it is `None` for a
        /// retry action and `Some` for a skip action.
        skipped_tool_result: Option<ToolResult>,
    },
}
249
/// Events handed back to the caller by `StreamedTurnAssembler`.
#[derive(Debug, Clone)]
pub enum StreamedTurnEvent {
    /// Returned for text, reasoning and reasoning-delta items.
    /// NOTE(review): presumably tells the driver to forward the item it just
    /// ingested unchanged — confirm against the driver.
    EmitIngested,
    /// A tool-call delta (name or argument fragment) that passed validation.
    EmitToolCallDelta {
        /// Tool call id the delta belongs to.
        id: String,
        /// Internal call id the delta belongs to.
        internal_call_id: String,
        /// The name or argument fragment itself.
        content: ToolCallDeltaContent,
    },
    /// A tool call or name delta named a tool outside the allowed set; the
    /// assembler rejects further `ingest` calls until it is resolved.
    InvalidToolCall(Box<StreamedInvalidToolCall>),
    /// The stream delivered its final item.
    Completed {
        /// Token usage reported by the final response.
        usage: Usage,
        /// `true` when text was ingested before the final item.
        emit_final: bool,
    },
}
287
/// Per-tool-call bookkeeping for streamed deltas, keyed by
/// `(id, internal_call_id)` in the assembler.
#[derive(Default)]
struct ToolCallDeltaState {
    /// Set once a name for this call has been accepted (directly or via repair).
    name_validated: bool,
    /// Argument fragments received before the name was validated, in arrival order.
    buffered_arguments: Vec<String>,
}
293
/// The invalid tool call the assembler is currently waiting to have resolved.
enum PendingInvalid {
    /// A complete tool call was rejected.
    FullCall {
        /// The rejected call, kept so a repair can rename and accept it.
        tool_call: Box<ToolCall>,
        /// Internal call id that arrived with the call.
        internal_call_id: String,
    },
    /// A streamed name delta was rejected.
    NameDelta {
        /// Tool call id of the rejected delta.
        id: String,
        /// Internal call id of the rejected delta.
        internal_call_id: String,
    },
}
306
/// Incrementally assembles one streamed assistant turn from the items fed to
/// `ingest`, validating tool-call names against the allowed set as they arrive.
pub struct StreamedTurnAssembler {
    /// Copied into invalid-call reports and the finished turn; not consulted
    /// for validation in this type.
    executable_tool_names: BTreeSet<String>,
    /// Names a streamed tool call must match to be accepted.
    allowed_tool_names: BTreeSet<String>,
    /// Concatenation of the text items seen so far.
    text: String,
    /// Whether any text item has been seen since construction or the last final item.
    saw_text: bool,
    /// Complete reasoning blocks, merged by id.
    accumulated_reasoning: Vec<Reasoning>,
    /// Concatenation of reasoning deltas; used only when no complete
    /// reasoning block was accumulated.
    pending_reasoning_delta_text: String,
    /// First id seen on a reasoning delta, if any.
    pending_reasoning_delta_id: Option<String>,
    /// Accepted tool calls paired with their internal call ids, in arrival order.
    pending_tool_calls: Vec<(ToolCall, String)>,
    /// Delta bookkeeping keyed by `(id, internal_call_id)`.
    delta_states: HashMap<(String, String), ToolCallDeltaState>,
    /// Invalid call awaiting `resolve_pending_invalid`, if any.
    pending_invalid: Option<PendingInvalid>,
}
321
322impl StreamedTurnAssembler {
323 pub fn new(
326 executable_tool_names: BTreeSet<String>,
327 allowed_tool_names: BTreeSet<String>,
328 ) -> Self {
329 Self {
330 executable_tool_names,
331 allowed_tool_names,
332 text: String::new(),
333 saw_text: false,
334 accumulated_reasoning: Vec::new(),
335 pending_reasoning_delta_text: String::new(),
336 pending_reasoning_delta_id: None,
337 pending_tool_calls: Vec::new(),
338 delta_states: HashMap::new(),
339 pending_invalid: None,
340 }
341 }
342
343 pub fn aggregated_text(&self) -> &str {
346 &self.text
347 }
348
    /// Feeds one streamed item into the assembler and returns the events the
    /// caller should act on.
    ///
    /// * Text, reasoning and reasoning deltas are accumulated and answered
    ///   with a single `EmitIngested`.
    /// * A complete tool call with an allowed name is stored and produces no
    ///   event; one with a disallowed name produces `InvalidToolCall`.
    /// * A name delta with an allowed name emits the name followed by any
    ///   argument fragments buffered for that call; a disallowed name
    ///   produces `InvalidToolCall`.
    /// * An argument delta is emitted if the call's name is already
    ///   validated, otherwise buffered with no event.
    /// * The final item produces `Completed`.
    ///
    /// # Errors
    ///
    /// Returns `CompletionError::ResponseError` when called while an invalid
    /// tool call is still awaiting `resolve_pending_invalid`, or when the
    /// final item arrives while some call still has buffered arguments and no
    /// validated name.
    pub fn ingest<R>(
        &mut self,
        item: &StreamedAssistantContent<R>,
    ) -> Result<Vec<StreamedTurnEvent>, CompletionError>
    where
        R: Clone + Unpin + GetTokenUsage,
    {
        // Nothing may be ingested until the outstanding invalid call is resolved.
        if self.pending_invalid.is_some() {
            return Err(CompletionError::ResponseError(
                "streamed turn ingested while an invalid tool call awaits resolution".to_string(),
            ));
        }

        match item {
            StreamedAssistantContent::Text(text) => {
                // First text item since construction or the last final item:
                // start from an empty buffer.
                if !self.saw_text {
                    self.text.clear();
                    self.saw_text = true;
                }
                self.text.push_str(&text.text);
                Ok(vec![StreamedTurnEvent::EmitIngested])
            }
            StreamedAssistantContent::Reasoning(reasoning) => {
                merge_reasoning_blocks(&mut self.accumulated_reasoning, reasoning);
                Ok(vec![StreamedTurnEvent::EmitIngested])
            }
            StreamedAssistantContent::ReasoningDelta { reasoning, id } => {
                self.pending_reasoning_delta_text.push_str(reasoning);
                // Keep the first id seen; later deltas do not overwrite it.
                if self.pending_reasoning_delta_id.is_none() {
                    self.pending_reasoning_delta_id = id.clone();
                }
                Ok(vec![StreamedTurnEvent::EmitIngested])
            }
            StreamedAssistantContent::ToolCall {
                tool_call,
                internal_call_id,
            } => {
                if !self.allowed_tool_names.contains(&tool_call.function.name) {
                    let invalid = StreamedInvalidToolCall {
                        tool_call: tool_call.clone(),
                        internal_call_id: internal_call_id.clone(),
                        args: Some(json_utils::value_to_json_string(
                            &tool_call.function.arguments,
                        )),
                        executable_tool_names: self.executable_tool_names.clone(),
                        allowed_tool_names: self.allowed_tool_names.clone(),
                    };
                    // Remember the call so a repair can rename and accept it.
                    self.pending_invalid = Some(PendingInvalid::FullCall {
                        tool_call: Box::new(tool_call.clone()),
                        internal_call_id: internal_call_id.clone(),
                    });
                    return Ok(vec![StreamedTurnEvent::InvalidToolCall(Box::new(invalid))]);
                }

                // Accepted complete calls are stored silently; no event is produced.
                self.pending_tool_calls
                    .push((tool_call.clone(), internal_call_id.clone()));
                Ok(Vec::new())
            }
            StreamedAssistantContent::ToolCallDelta {
                id,
                internal_call_id,
                content,
            } => {
                let key = (id.clone(), internal_call_id.clone());
                match content {
                    ToolCallDeltaContent::Name(name) => {
                        if !self.allowed_tool_names.contains(name) {
                            // Report the arguments buffered so far; they stay
                            // in `delta_states` so a repair can replay them.
                            let buffered_args = self
                                .delta_states
                                .get(&key)
                                .map(|state| state.buffered_arguments.join(""))
                                .unwrap_or_default();
                            let invalid = StreamedInvalidToolCall {
                                tool_call: self.name_delta_diagnostic_tool_call(
                                    id,
                                    name,
                                    &buffered_args,
                                ),
                                internal_call_id: internal_call_id.clone(),
                                args: Some(buffered_args),
                                executable_tool_names: self.executable_tool_names.clone(),
                                allowed_tool_names: self.allowed_tool_names.clone(),
                            };
                            self.pending_invalid = Some(PendingInvalid::NameDelta {
                                id: id.clone(),
                                internal_call_id: internal_call_id.clone(),
                            });
                            return Ok(vec![StreamedTurnEvent::InvalidToolCall(Box::new(invalid))]);
                        }

                        Ok(self.validate_delta_name(&key, name.clone()))
                    }
                    ToolCallDeltaContent::Delta(arguments) => {
                        let state = self.delta_states.entry(key.clone()).or_default();
                        if state.name_validated {
                            Ok(vec![StreamedTurnEvent::EmitToolCallDelta {
                                id: id.clone(),
                                internal_call_id: internal_call_id.clone(),
                                content: ToolCallDeltaContent::Delta(arguments.clone()),
                            }])
                        } else {
                            // Hold arguments back until the name is validated.
                            state.buffered_arguments.push(arguments.clone());
                            Ok(Vec::new())
                        }
                    }
                }
            }
            StreamedAssistantContent::Final(final_response) => {
                // Arguments that never got a validated name make the turn invalid.
                if let Some(err) = self.pending_delta_error() {
                    return Err(err);
                }

                let usage = final_response.token_usage();
                let emit_final = self.saw_text;
                // Reset so text arriving after this final item starts a fresh buffer.
                self.saw_text = false;
                Ok(vec![StreamedTurnEvent::Completed { usage, emit_final }])
            }
        }
    }
478
479 pub fn resolve_pending_invalid(
484 &mut self,
485 resolution: &StreamedResolution,
486 ) -> Vec<StreamedTurnEvent> {
487 let Some(pending) = self.pending_invalid.take() else {
488 return Vec::new();
489 };
490
491 match (resolution, pending) {
492 (
493 StreamedResolution::Repaired { tool_name },
494 PendingInvalid::FullCall {
495 mut tool_call,
496 internal_call_id,
497 },
498 ) => {
499 tool_call.function.name = tool_name.clone();
500 self.pending_tool_calls.push((*tool_call, internal_call_id));
501 Vec::new()
502 }
503 (
504 StreamedResolution::Repaired { tool_name },
505 PendingInvalid::NameDelta {
506 id,
507 internal_call_id,
508 },
509 ) => {
510 let key = (id, internal_call_id);
511 self.validate_delta_name(&key, tool_name.clone())
512 }
513 (
514 StreamedResolution::TurnAbandoned { .. },
515 PendingInvalid::NameDelta {
516 id,
517 internal_call_id,
518 },
519 ) => {
520 self.delta_states.remove(&(id, internal_call_id));
523 Vec::new()
524 }
525 (StreamedResolution::TurnAbandoned { .. }, PendingInvalid::FullCall { .. }) => {
526 Vec::new()
527 }
528 }
529 }
530
531 pub fn pending_delta_error(&self) -> Option<CompletionError> {
534 self.delta_states
535 .iter()
536 .find(|(_, state)| !state.name_validated && !state.buffered_arguments.is_empty())
537 .map(|((id, internal_call_id), state)| {
538 CompletionError::ResponseError(format!(
539 "streamed tool call arguments received before a validated tool name for id `{id}` and internal_call_id `{internal_call_id}` ({} buffered argument delta(s))",
540 state.buffered_arguments.len()
541 ))
542 })
543 }
544
545 pub fn partial_turn(&self, message_id: Option<String>) -> PartialStreamedTurn {
547 let mut reasoning = self.accumulated_reasoning.clone();
548 if reasoning.is_empty() && !self.pending_reasoning_delta_text.is_empty() {
549 let mut assembled = Reasoning::new(&self.pending_reasoning_delta_text);
550 if let Some(id) = self.pending_reasoning_delta_id.clone() {
551 assembled = assembled.with_id(id);
552 }
553 reasoning.push(assembled);
554 }
555
556 PartialStreamedTurn {
557 message_id,
558 text: self.saw_text.then(|| self.text.clone()),
559 reasoning,
560 pending_tool_calls: self
561 .pending_tool_calls
562 .iter()
563 .map(|(tool_call, _)| tool_call.clone())
564 .collect(),
565 }
566 }
567
568 pub fn finish(
572 mut self,
573 message_id: Option<String>,
574 final_choice: &OneOrMany<AssistantContent>,
575 ) -> StreamedTurn {
576 let internal_call_ids: Vec<(String, String)> = self
577 .pending_tool_calls
578 .iter()
579 .map(|(tool_call, internal_call_id)| (tool_call.id.clone(), internal_call_id.clone()))
580 .collect();
581 if self.accumulated_reasoning.is_empty() && !self.pending_reasoning_delta_text.is_empty() {
585 let mut assembled = Reasoning::new(&self.pending_reasoning_delta_text);
586 if let Some(id) = self.pending_reasoning_delta_id.take() {
587 assembled = assembled.with_id(id);
588 }
589 self.accumulated_reasoning.push(assembled);
590 }
591
592 let choice =
595 if !self.pending_tool_calls.is_empty() || !self.accumulated_reasoning.is_empty() {
596 let text_items = assistant_text_items_from_choice(final_choice);
597 let tool_items = self
598 .pending_tool_calls
599 .iter()
600 .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone()))
601 .collect::<Vec<_>>();
602 ordered_streaming_assistant_content(
603 self.accumulated_reasoning.drain(..),
604 text_items,
605 tool_items,
606 )
607 .unwrap_or_else(|| final_choice.clone())
608 } else {
609 final_choice.clone()
610 };
611
612 StreamedTurn {
613 message_id,
614 choice,
615 executable_tool_names: self.executable_tool_names,
616 allowed_tool_names: self.allowed_tool_names,
617 internal_call_ids,
618 }
619 }
620
621 fn name_delta_diagnostic_tool_call(
622 &self,
623 id: &str,
624 name: &str,
625 buffered_args: &str,
626 ) -> ToolCall {
627 let diagnostic_args = if buffered_args.trim().is_empty() {
628 serde_json::Value::Null
629 } else {
630 serde_json::from_str(buffered_args).unwrap_or(serde_json::Value::Null)
631 };
632 ToolCall::new(
633 id.to_string(),
634 ToolFunction::new(name.to_string(), diagnostic_args),
635 )
636 }
637
638 fn validate_delta_name(
639 &mut self,
640 key: &(String, String),
641 name: String,
642 ) -> Vec<StreamedTurnEvent> {
643 let state = self.delta_states.entry(key.clone()).or_default();
644 state.name_validated = true;
645 let buffered_arguments = std::mem::take(&mut state.buffered_arguments);
646
647 let mut events = vec![StreamedTurnEvent::EmitToolCallDelta {
648 id: key.0.clone(),
649 internal_call_id: key.1.clone(),
650 content: ToolCallDeltaContent::Name(name),
651 }];
652 events.extend(buffered_arguments.into_iter().map(|arguments| {
653 StreamedTurnEvent::EmitToolCallDelta {
654 id: key.0.clone(),
655 internal_call_id: key.1.clone(),
656 content: ToolCallDeltaContent::Delta(arguments),
657 }
658 }));
659 events
660 }
661}
662
// Unit tests for the assembler on its own and for its hand-off to `AgentRun`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::prompt_request::hooks::InvalidToolCallHookAction;
    use crate::agent::run::{AgentRun, AgentRunStep};
    use crate::completion::PromptError;
    use crate::message::{Text, ToolResultContent, UserContent};
    use crate::test_utils::MockResponse;
    use serde_json::json;

    /// Builds an owned name set from string literals.
    fn tool_names(names: &[&str]) -> BTreeSet<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Assembler whose only executable and allowed tool is `add`.
    fn assembler() -> StreamedTurnAssembler {
        StreamedTurnAssembler::new(tool_names(&["add"]), tool_names(&["add"]))
    }

    /// Streamed text item carrying `text`.
    fn text_item(text: &str) -> StreamedAssistantContent<MockResponse> {
        StreamedAssistantContent::Text(Text::new(text.to_string()))
    }

    /// Tool call with the given id and name and the fixed arguments `{"x": 1}`.
    fn tool_call(id: &str, name: &str) -> ToolCall {
        ToolCall::new(
            id.to_string(),
            ToolFunction::new(name.to_string(), json!({"x": 1})),
        )
    }

    /// Streamed complete tool call; its internal call id is `internal_<id>`.
    fn tool_call_item(id: &str, name: &str) -> StreamedAssistantContent<MockResponse> {
        StreamedAssistantContent::ToolCall {
            tool_call: tool_call(id, name),
            internal_call_id: format!("internal_{id}"),
        }
    }

    /// Streamed final item built from a fresh `Usage::new()`.
    fn final_item() -> StreamedAssistantContent<MockResponse> {
        StreamedAssistantContent::Final(MockResponse::with_usage(Usage::new()))
    }

    /// Streamed name delta; its internal call id is `internal_<id>`.
    fn name_delta(id: &str, name: &str) -> StreamedAssistantContent<MockResponse> {
        StreamedAssistantContent::ToolCallDelta {
            id: id.to_string(),
            internal_call_id: format!("internal_{id}"),
            content: ToolCallDeltaContent::Name(name.to_string()),
        }
    }

    /// Streamed argument delta; its internal call id is `internal_<id>`.
    fn args_delta(id: &str, arguments: &str) -> StreamedAssistantContent<MockResponse> {
        StreamedAssistantContent::ToolCallDelta {
            id: id.to_string(),
            internal_call_id: format!("internal_{id}"),
            content: ToolCallDeltaContent::Delta(arguments.to_string()),
        }
    }

    /// Unwraps the first event as an invalid tool call, panicking otherwise.
    fn expect_invalid(events: Vec<StreamedTurnEvent>) -> StreamedInvalidToolCall {
        match events.into_iter().next() {
            Some(StreamedTurnEvent::InvalidToolCall(invalid)) => *invalid,
            other => panic!("expected InvalidToolCall, got {other:?}"),
        }
    }

    /// Text items each yield `EmitIngested` and concatenate in arrival order.
    #[test]
    fn text_accumulates_and_emits() {
        let mut asm = assembler();
        let events = asm
            .ingest(&text_item("hel"))
            .expect("ingest should succeed");
        assert!(matches!(
            events.as_slice(),
            [StreamedTurnEvent::EmitIngested]
        ));
        asm.ingest(&text_item("lo")).expect("ingest should succeed");
        assert_eq!(asm.aggregated_text(), "hello");
    }

    /// Argument deltas that arrive before the name are held back, replayed
    /// right after the name event, and emitted immediately afterwards.
    #[test]
    fn argument_deltas_buffer_until_name_validates() {
        let mut asm = assembler();

        let events = asm
            .ingest(&args_delta("tc_1", "{\"x\""))
            .expect("ingest should succeed");
        assert!(events.is_empty(), "arguments must buffer before the name");

        let events = asm
            .ingest(&name_delta("tc_1", "add"))
            .expect("ingest should succeed");
        let contents: Vec<_> = events
            .iter()
            .map(|event| match event {
                StreamedTurnEvent::EmitToolCallDelta { content, .. } => content.clone(),
                other => panic!("expected EmitToolCallDelta, got {other:?}"),
            })
            .collect();
        assert_eq!(
            contents,
            vec![
                ToolCallDeltaContent::Name("add".to_string()),
                ToolCallDeltaContent::Delta("{\"x\"".to_string()),
            ]
        );

        // With the name validated, a further argument delta is emitted directly.
        let events = asm
            .ingest(&args_delta("tc_1", ":1}"))
            .expect("ingest should succeed");
        assert_eq!(events.len(), 1);
    }

    /// Buffered arguments with no validated name surface as an error, both
    /// from `pending_delta_error` and when the final item is ingested.
    #[test]
    fn buffered_arguments_without_validated_name_error_at_final() {
        let mut asm = assembler();
        asm.ingest(&args_delta("tc_1", "{\"x\":1}"))
            .expect("ingest should succeed");

        assert!(asm.pending_delta_error().is_some());
        assert!(asm.ingest(&final_item()).is_err());
    }

    /// `finish` emits reasoning first, then text, then tool calls, even though
    /// the reasoning here came only from a delta.
    #[test]
    fn finish_orders_reasoning_text_then_tool_calls() {
        let mut asm = assembler();
        asm.ingest(&StreamedAssistantContent::<MockResponse>::ReasoningDelta {
            id: Some("rs_1".to_string()),
            reasoning: "think".to_string(),
        })
        .expect("ingest should succeed");
        asm.ingest(&tool_call_item("tc_1", "add"))
            .expect("ingest should succeed");

        let final_choice = OneOrMany::many(vec![
            AssistantContent::text("answer"),
            AssistantContent::ToolCall(tool_call("tc_1", "add")),
        ])
        .expect("two items");

        let turn = asm.finish(Some("msg_1".to_string()), &final_choice);
        let kinds: Vec<&'static str> = turn
            .choice
            .iter()
            .map(|item| match item {
                AssistantContent::Reasoning(_) => "reasoning",
                AssistantContent::Text(_) => "text",
                AssistantContent::ToolCall(_) => "tool_call",
                _ => "other",
            })
            .collect();
        assert_eq!(kinds, vec!["reasoning", "text", "tool_call"]);
    }

    /// With no tool calls and no reasoning, `finish` returns a choice that
    /// serializes identically to the final choice it was given.
    #[test]
    fn finish_passes_raw_choice_through_for_plain_text_turns() {
        let mut asm = assembler();
        asm.ingest(&text_item("hi")).expect("ingest should succeed");

        let final_choice = OneOrMany::one(AssistantContent::text("hi"));
        let turn = asm.finish(None, &final_choice);
        assert_eq!(
            serde_json::to_value(&turn.choice).expect("serialize"),
            serde_json::to_value(&final_choice).expect("serialize"),
        );
    }

    /// Drives a two-turn run: a streamed tool call, its tool result, then a
    /// plain-text answer; checks the output, usage, recorded completion calls
    /// and message count.
    #[test]
    fn streamed_run_completes_a_tool_roundtrip() {
        let mut run = AgentRun::new("add things").max_turns(2);

        // Turn 1: the model streams a single valid tool call.
        let AgentRunStep::CallModel { .. } = run.next_step().expect("next_step") else {
            panic!("expected CallModel");
        };
        let mut asm = assembler();
        assert!(
            asm.ingest(&tool_call_item("tc_1", "add"))
                .expect("ingest should succeed")
                .is_empty()
        );
        let usage = Usage {
            input_tokens: 5,
            output_tokens: 7,
            total_tokens: 12,
            ..Usage::new()
        };
        run.record_streamed_completion_call(usage)
            .expect("record should succeed");
        let final_choice = OneOrMany::one(AssistantContent::ToolCall(tool_call("tc_1", "add")));
        run.streamed_turn(asm.finish(Some("msg_1".to_string()), &final_choice))
            .expect("streamed_turn should succeed");

        let AgentRunStep::CallTools { calls } = run.next_step().expect("next_step") else {
            panic!("expected CallTools");
        };
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].internal_call_id.as_deref(), Some("internal_tc_1"));
        run.tool_results(vec![UserContent::tool_result(
            "tc_1".to_string(),
            ToolResultContent::from_tool_output("2".to_string()),
        )])
        .expect("tool_results should succeed");

        // Turn 2: the model answers with plain text.
        let AgentRunStep::CallModel { .. } = run.next_step().expect("next_step") else {
            panic!("expected CallModel");
        };
        let asm = assembler();
        run.record_streamed_completion_call(Usage::new())
            .expect("record should succeed");
        let final_choice = OneOrMany::one(AssistantContent::text("done"));
        run.streamed_turn(asm.finish(None, &final_choice))
            .expect("streamed_turn should succeed");

        let AgentRunStep::Done(response) = run.next_step().expect("next_step") else {
            panic!("expected Done");
        };
        assert_eq!(response.output, "done");
        assert_eq!(response.usage, usage);
        assert_eq!(response.completion_calls.len(), 2);
        assert_eq!(response.completion_calls[0].usage, usage);
        assert_eq!(response.completion_calls[1].usage, Usage::new());
        assert_eq!(
            response
                .messages
                .expect("messages should be recorded")
                .len(),
            4
        );
    }

    /// A retry action on an invalid streamed tool call resolves to
    /// `TurnAbandoned` with no skipped result; afterwards the run holds three
    /// messages and asks for the model again on turn 2.
    #[test]
    fn streamed_invalid_tool_call_retry_rolls_back_with_partial_turn() {
        let mut run = AgentRun::new("use the tool")
            .max_turns(2)
            .max_invalid_tool_call_retries(1);
        run.next_step().expect("next_step");

        let mut asm = assembler();
        asm.ingest(&text_item("thinking ")).expect("ingest");
        let invalid = expect_invalid(
            asm.ingest(&tool_call_item("tc_1", "default_api"))
                .expect("ingest should succeed"),
        );
        let partial = asm.partial_turn(Some("msg_1".to_string()));
        assert_eq!(partial.text.as_deref(), Some("thinking "));

        let context = run.streamed_invalid_tool_call_context(&partial, &invalid);
        assert!(context.is_streaming);
        assert_eq!(context.tool_name, "default_api");
        assert_eq!(context.internal_call_id.as_deref(), Some("internal_tc_1"));

        let resolution = run
            .resolve_streamed_invalid_tool_call(
                &partial,
                &invalid,
                InvalidToolCallHookAction::retry("use add instead"),
            )
            .expect("retry should be accepted");
        assert!(matches!(
            resolution,
            StreamedResolution::TurnAbandoned {
                skipped_tool_result: None
            }
        ));
        asm.resolve_pending_invalid(&resolution);

        run.record_streamed_completion_call(Usage::new())
            .expect("record after rollback should succeed");

        assert_eq!(run.messages().len(), 3);
        let AgentRunStep::CallModel { turn, .. } = run.next_step().expect("next_step") else {
            panic!("expected CallModel retry");
        };
        assert_eq!(turn, 2);
    }

    /// A skip action resolves to `TurnAbandoned` carrying a tool result whose
    /// id is the invalid call's id.
    #[test]
    fn streamed_invalid_tool_call_skip_returns_synthetic_result() {
        let mut run = AgentRun::new("use the tool").max_turns(2);
        run.next_step().expect("next_step");

        let mut asm = assembler();
        let invalid = expect_invalid(
            asm.ingest(&tool_call_item("tc_1", "default_api"))
                .expect("ingest should succeed"),
        );
        let partial = asm.partial_turn(None);

        let resolution = run
            .resolve_streamed_invalid_tool_call(
                &partial,
                &invalid,
                InvalidToolCallHookAction::skip("not available"),
            )
            .expect("skip should be accepted");
        let StreamedResolution::TurnAbandoned {
            skipped_tool_result: Some(tool_result),
        } = &resolution
        else {
            panic!("expected skipped tool result");
        };
        assert_eq!(tool_result.id, "tc_1");
    }

    /// Repairing an invalid name delta emits the repaired name followed by the
    /// arguments that were buffered before the name arrived.
    #[test]
    fn streamed_invalid_name_delta_repair_replays_buffered_arguments() {
        let mut run = AgentRun::new("use the tool").max_turns(2);
        run.next_step().expect("next_step");

        let mut asm = assembler();
        asm.ingest(&args_delta("tc_1", "{\"x\":1}"))
            .expect("ingest should succeed");
        let invalid = expect_invalid(
            asm.ingest(&name_delta("tc_1", "default_api"))
                .expect("ingest should succeed"),
        );
        assert_eq!(invalid.args.as_deref(), Some("{\"x\":1}"));

        let partial = asm.partial_turn(None);
        let resolution = run
            .resolve_streamed_invalid_tool_call(
                &partial,
                &invalid,
                InvalidToolCallHookAction::repair("add"),
            )
            .expect("repair should be accepted");
        assert!(matches!(
            resolution,
            StreamedResolution::Repaired { ref tool_name } if tool_name == "add"
        ));

        let events = asm.resolve_pending_invalid(&resolution);
        let contents: Vec<_> = events
            .iter()
            .map(|event| match event {
                StreamedTurnEvent::EmitToolCallDelta { content, .. } => content.clone(),
                other => panic!("expected EmitToolCallDelta, got {other:?}"),
            })
            .collect();
        assert_eq!(
            contents,
            vec![
                ToolCallDeltaContent::Name("add".to_string()),
                ToolCallDeltaContent::Delta("{\"x\":1}".to_string()),
            ]
        );
    }

    /// A hand-built turn containing a tool call outside the allowed set is
    /// rejected by `streamed_turn` with `UnknownToolCall`.
    #[test]
    fn streamed_turn_rejects_unknown_tool_calls_fail_fast() {
        let mut run = AgentRun::new("use the tool");
        run.next_step().expect("next_step");

        let turn = StreamedTurn {
            message_id: None,
            choice: OneOrMany::one(AssistantContent::ToolCall(tool_call("tc_1", "unknown"))),
            executable_tool_names: tool_names(&["add"]),
            allowed_tool_names: tool_names(&["add"]),
            internal_call_ids: Vec::new(),
        };
        let err = run
            .streamed_turn(turn)
            .expect_err("unknown tool should fail fast");
        assert!(matches!(
            err,
            PromptError::UnknownToolCall { tool_name, .. } if tool_name == "unknown"
        ));
    }

    /// Recording a streamed completion call before `next_step` fails with
    /// `PromptCancelled`; after `next_step` it succeeds.
    #[test]
    fn streamed_completion_call_record_requires_a_model_call() {
        let mut run = AgentRun::new("hello");
        let err = run
            .record_streamed_completion_call(Usage::new())
            .expect_err("recording before any model call must be rejected");
        assert!(matches!(err, PromptError::PromptCancelled { .. }));

        run.next_step().expect("next_step should still succeed");
        run.record_streamed_completion_call(Usage::new())
            .expect("recording during a pending model call succeeds");
    }

    /// Two tool calls sharing one id keep their separate internal call ids, in
    /// order, after the run is serialized and restored.
    #[test]
    fn duplicate_tool_call_ids_keep_distinct_internal_ids_through_the_run() {
        let mut run = AgentRun::new("do both").max_turns(2);
        run.next_step().expect("next_step");

        let mut asm = assembler();
        asm.ingest(&StreamedAssistantContent::<MockResponse>::ToolCall {
            tool_call: tool_call("tc_1", "add"),
            internal_call_id: "internal_a".to_string(),
        })
        .expect("ingest should succeed");
        asm.ingest(&StreamedAssistantContent::<MockResponse>::ToolCall {
            tool_call: tool_call("tc_1", "add"),
            internal_call_id: "internal_b".to_string(),
        })
        .expect("ingest should succeed");
        run.record_streamed_completion_call(Usage::new())
            .expect("record should succeed");

        let final_choice = OneOrMany::many(vec![
            AssistantContent::ToolCall(tool_call("tc_1", "add")),
            AssistantContent::ToolCall(tool_call("tc_1", "add")),
        ])
        .expect("two items");
        run.streamed_turn(asm.finish(None, &final_choice))
            .expect("streamed_turn should succeed");

        let serialized = serde_json::to_string(&run).expect("serialize");
        let mut restored: AgentRun = serde_json::from_str(&serialized).expect("deserialize");
        let AgentRunStep::CallTools { calls } = restored.next_step().expect("next_step") else {
            panic!("expected CallTools");
        };
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].internal_call_id.as_deref(), Some("internal_a"));
        assert_eq!(calls[1].internal_call_id.as_deref(), Some("internal_b"));
    }

    /// If the driver never records the completion call, `streamed_turn` leaves
    /// exactly one recorded call whose usage equals `Usage::new()`.
    #[test]
    fn streamed_turn_records_the_completion_call_when_the_driver_did_not() {
        let mut run = AgentRun::new("hello");
        run.next_step().expect("next_step");

        let asm = assembler();
        let final_choice = OneOrMany::one(AssistantContent::text("done"));
        run.streamed_turn(asm.finish(None, &final_choice))
            .expect("streamed_turn should succeed");

        assert_eq!(run.completion_calls().len(), 1);
        assert_eq!(run.completion_calls()[0].usage, Usage::new());
    }

    /// A second record within the same turn fails with `PromptCancelled` and
    /// leaves the recorded count at one.
    #[test]
    fn streamed_completion_call_is_recorded_once_per_turn() {
        let mut run = AgentRun::new("hello");
        run.next_step().expect("next_step");

        run.record_streamed_completion_call(Usage::new())
            .expect("first record succeeds");
        let err = run
            .record_streamed_completion_call(Usage::new())
            .expect_err("second record for the same turn must be rejected");
        assert!(matches!(err, PromptError::PromptCancelled { .. }));
        assert_eq!(run.completion_calls().len(), 1);
    }

    /// A run serialized while tool calls are pending can be restored, accept
    /// the tool result, and move on to the model call for turn 2.
    #[test]
    fn streamed_run_serde_round_trips_while_tools_pend() {
        let mut run = AgentRun::new("add things").max_turns(2);
        run.next_step().expect("next_step");

        let mut asm = assembler();
        asm.ingest(&tool_call_item("tc_1", "add"))
            .expect("ingest should succeed");
        run.record_streamed_completion_call(Usage::new())
            .expect("record should succeed");
        let final_choice = OneOrMany::one(AssistantContent::ToolCall(tool_call("tc_1", "add")));
        run.streamed_turn(asm.finish(None, &final_choice))
            .expect("streamed_turn should succeed");
        run.next_step().expect("CallTools step");

        let serialized = serde_json::to_string(&run).expect("serialize mid-run");
        let mut restored: AgentRun =
            serde_json::from_str(&serialized).expect("deserialize mid-run");
        restored
            .tool_results(vec![UserContent::tool_result(
                "tc_1".to_string(),
                ToolResultContent::from_tool_output("2".to_string()),
            )])
            .expect("tool_results should succeed");
        assert!(matches!(
            restored.next_step().expect("next turn"),
            AgentRunStep::CallModel { turn: 2, .. }
        ));
    }
}