1use crate::{
2 OneOrMany,
3 agent::completion::{DynamicContextStore, build_prepared_completion_request},
4 agent::prompt_request::{
5 HookAction, assistant_text_from_choice,
6 hooks::{InvalidToolCallHookAction, PromptHook},
7 is_empty_assistant_turn, tool_result_user_content,
8 },
9 agent::run::{
10 AgentRun, AgentRunStep,
11 streamed::{StreamedResolution, StreamedTurnAssembler, StreamedTurnEvent},
12 },
13 completion::{Document, GetTokenUsage},
14 json_utils,
15 memory::ConversationMemory,
16 message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
17 streaming::{StreamedAssistantContent, StreamedUserContent, ToolCallDeltaContent},
18 tool::server::ToolServerHandle,
19 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
20};
21use futures::{Stream, StreamExt};
22use serde::{Deserialize, Serialize};
23use std::{collections::VecDeque, pin::Pin, sync::Arc};
24use tracing::info_span;
25use tracing_futures::Instrument;
26
27use super::{CompletionCall, ToolCallHookAction};
28use crate::{
29 agent::Agent,
30 completion::{CompletionError, CompletionModel, PromptError},
31 message::{Message, Text},
32 tool::ToolSetError,
33};
34
/// Boxed stream of [`MultiTurnStreamItem`]s (or [`StreamingError`]s) produced by a
/// streaming prompt request.
///
/// On targets other than `wasm32` + the `wasm` feature the stream is `Send`.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// Boxed stream of [`MultiTurnStreamItem`]s (or [`StreamingError`]s) produced by a
/// streaming prompt request.
///
/// On `wasm32` with the `wasm` feature enabled the `Send` bound is dropped.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
42
/// One item yielded by a multi-turn streaming prompt.
///
/// Serialized as an internally tagged enum: the `"type"` field carries the
/// camelCase variant name.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// Streamed assistant content for the current turn (e.g. text, reasoning,
    /// tool call deltas, tool calls, or the provider's final chunk).
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// Streamed user-side content, such as the result of an executed or skipped
    /// tool call.
    StreamUserItem(StreamedUserContent),
    /// Record of a finished model call; the streaming request emits one per turn.
    CompletionCall(CompletionCall),
    /// Summary emitted once when the agent run finishes.
    FinalResponse(FinalResponse),
}
72
/// Summary of a completed multi-turn streaming run.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// Assistant content of the final model turn.
    content: OneOrMany<AssistantContent>,
    /// Plain text extracted from `content` when the response is constructed.
    response: String,
    /// Token usage aggregated over the run, as supplied by the producer.
    aggregated_usage: crate::completion::Usage,
    /// Completion calls made during the run. Defaults to empty when deserializing
    /// and is omitted from the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    completion_calls: Vec<CompletionCall>,
    /// Message history attached by the producer, if any; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}
88
89impl FinalResponse {
90 pub fn empty() -> Self {
91 Self::new(
92 OneOrMany::one(AssistantContent::text("")),
93 crate::completion::Usage::new(),
94 None,
95 )
96 }
97
98 pub fn new(
99 content: OneOrMany<AssistantContent>,
100 aggregated_usage: crate::completion::Usage,
101 history: Option<Vec<Message>>,
102 ) -> Self {
103 let response = assistant_text_from_choice(&content);
104 Self {
105 content,
106 response,
107 aggregated_usage,
108 completion_calls: Vec::new(),
109 history,
110 }
111 }
112
113 pub fn response(&self) -> &str {
115 &self.response
116 }
117
118 pub fn content(&self) -> &OneOrMany<AssistantContent> {
120 &self.content
121 }
122
123 pub fn assistant_content(&self) -> &OneOrMany<AssistantContent> {
125 &self.content
126 }
127
128 pub fn usage(&self) -> crate::completion::Usage {
129 self.aggregated_usage
130 }
131
132 pub fn completion_calls(&self) -> &[CompletionCall] {
139 &self.completion_calls
140 }
141
142 pub fn requests(&self) -> usize {
144 self.completion_calls.len()
145 }
146
147 pub fn history(&self) -> Option<&[Message]> {
148 self.history.as_deref()
149 }
150}
151
152impl<R> MultiTurnStreamItem<R> {
153 pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
154 Self::StreamAssistantItem(item)
155 }
156
157 pub fn final_response(
158 content: OneOrMany<AssistantContent>,
159 aggregated_usage: crate::completion::Usage,
160 ) -> Self {
161 Self::FinalResponse(FinalResponse::new(content, aggregated_usage, None))
162 }
163
164 pub fn final_response_with_history(
165 content: OneOrMany<AssistantContent>,
166 aggregated_usage: crate::completion::Usage,
167 history: Option<Vec<Message>>,
168 ) -> Self {
169 Self::FinalResponse(FinalResponse::new(content, aggregated_usage, history))
170 }
171
172 pub(crate) fn final_response_with_completion_calls(
173 content: OneOrMany<AssistantContent>,
174 aggregated_usage: crate::completion::Usage,
175 completion_calls: Vec<CompletionCall>,
176 history: Option<Vec<Message>>,
177 ) -> Self {
178 let mut response = FinalResponse::new(content, aggregated_usage, history);
179 response.completion_calls = completion_calls;
180 Self::FinalResponse(response)
181 }
182}
183
184async fn drain_stream_usage<R>(
187 stream: &mut crate::streaming::StreamingCompletionResponse<R>,
188) -> Result<crate::completion::Usage, StreamingError>
189where
190 R: Clone + Unpin + GetTokenUsage,
191{
192 while let Some(content) = stream.next().await {
193 match content {
194 Ok(StreamedAssistantContent::Final(final_resp)) => {
195 return Ok(final_resp.token_usage());
196 }
197 Ok(_) => {}
198 Err(err) => return Err(err.into()),
199 }
200 }
201
202 Ok(crate::completion::Usage::new())
203}
204
205fn record_usage_on_span(span: &tracing::Span, usage: crate::completion::Usage) {
206 span.record("gen_ai.usage.input_tokens", usage.input_tokens);
207 span.record("gen_ai.usage.output_tokens", usage.output_tokens);
208 span.record(
209 "gen_ai.usage.cache_read.input_tokens",
210 usage.cached_input_tokens,
211 );
212 span.record(
213 "gen_ai.usage.cache_creation.input_tokens",
214 usage.cache_creation_input_tokens,
215 );
216 span.record(
217 "gen_ai.usage.tool_use_prompt_tokens",
218 usage.tool_use_prompt_tokens,
219 );
220 span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);
221}
222
/// Errors yielded by a streaming prompt request.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// Building or streaming a completion request failed. Conversation memory
    /// failures are also mapped here (as a request error).
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// The agent run reported an error or was cancelled, for example by a hook
    /// returning `Terminate`.
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool set operation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
232
233impl From<crate::memory::MemoryError> for StreamingError {
237 fn from(err: crate::memory::MemoryError) -> Self {
238 Self::Completion(CompletionError::RequestError(Box::new(err)))
239 }
240}
241
/// Agent name used in tracing spans and log fields when the agent has no name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
243
/// Builder for a streaming, multi-turn prompt against an [`Agent`].
///
/// Settings are copied from the agent when the request is constructed and can be
/// adjusted with the builder methods; awaiting the request yields a
/// [`StreamingResult`].
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Prompt that starts the run.
    prompt: Message,
    /// Explicit history; when `Some`, conversation memory is neither loaded nor appended to.
    chat_history: Option<Vec<Message>>,
    /// Maximum number of turns handed to the agent run.
    max_turns: usize,

    /// Completion model used for every turn.
    model: Arc<M>,
    /// Agent name reported in tracing (falls back to `UNKNOWN_AGENT_NAME`).
    agent_name: Option<String>,
    /// System preamble forwarded to each completion request.
    preamble: Option<String>,
    /// Static context documents forwarded to each completion request.
    static_context: Vec<Document>,
    /// Sampling temperature forwarded to each completion request.
    temperature: Option<f64>,
    /// Max-tokens setting forwarded to each completion request.
    max_tokens: Option<u64>,
    /// Extra provider parameters forwarded to each completion request.
    additional_params: Option<serde_json::Value>,
    /// Handle used to execute tool calls.
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store consulted when building each completion request.
    dynamic_context: DynamicContextStore,
    /// Tool choice forwarded to both the completion requests and the agent run.
    tool_choice: Option<ToolChoice>,
    /// Optional output schema forwarded to each completion request.
    output_schema: Option<schemars::Schema>,
    /// Optional hook notified of (and able to terminate) each stage of the run.
    hook: Option<P>,
    /// Invalid tool call retry budget handed to the agent run.
    max_invalid_tool_call_retries: usize,
    /// Conversation memory backend, used only when `conversation_id` is also set.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Conversation id used to load from and append to `memory`.
    conversation_id: Option<String>,
}
296
297impl<M, P> StreamingPromptRequest<M, P>
298where
299 M: CompletionModel + 'static,
300 <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
301 P: PromptHook<M>,
302{
303 pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
306 StreamingPromptRequest {
307 prompt: prompt.into(),
308 chat_history: None,
309 max_turns: agent.default_max_turns.unwrap_or_default(),
310 model: agent.model.clone(),
311 agent_name: agent.name.clone(),
312 preamble: agent.preamble.clone(),
313 static_context: agent.static_context.clone(),
314 temperature: agent.temperature,
315 max_tokens: agent.max_tokens,
316 additional_params: agent.additional_params.clone(),
317 tool_server_handle: agent.tool_server_handle.clone(),
318 dynamic_context: agent.dynamic_context.clone(),
319 tool_choice: agent.tool_choice.clone(),
320 output_schema: agent.output_schema.clone(),
321 hook: None,
322 max_invalid_tool_call_retries: 0,
323 memory: agent.memory.clone(),
324 conversation_id: agent.default_conversation_id.clone(),
325 }
326 }
327
328 pub fn from_agent<P2>(
330 agent: &Agent<M, P2>,
331 prompt: impl Into<Message>,
332 ) -> StreamingPromptRequest<M, P2>
333 where
334 P2: PromptHook<M>,
335 {
336 StreamingPromptRequest {
337 prompt: prompt.into(),
338 chat_history: None,
339 max_turns: agent.default_max_turns.unwrap_or_default(),
340 model: agent.model.clone(),
341 agent_name: agent.name.clone(),
342 preamble: agent.preamble.clone(),
343 static_context: agent.static_context.clone(),
344 temperature: agent.temperature,
345 max_tokens: agent.max_tokens,
346 additional_params: agent.additional_params.clone(),
347 tool_server_handle: agent.tool_server_handle.clone(),
348 dynamic_context: agent.dynamic_context.clone(),
349 tool_choice: agent.tool_choice.clone(),
350 output_schema: agent.output_schema.clone(),
351 hook: agent.hook.clone(),
352 max_invalid_tool_call_retries: 0,
353 memory: agent.memory.clone(),
354 conversation_id: agent.default_conversation_id.clone(),
355 }
356 }
357
358 fn agent_name(&self) -> &str {
359 self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
360 }
361
362 pub fn multi_turn(mut self, turns: usize) -> Self {
365 self.max_turns = turns;
366 self
367 }
368
369 pub fn with_history<H, T>(mut self, history: H) -> Self
382 where
383 H: IntoIterator<Item = T>,
384 T: Into<Message>,
385 {
386 self.chat_history = Some(history.into_iter().map(Into::into).collect());
387 self
388 }
389
390 pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
393 where
394 P2: PromptHook<M>,
395 {
396 StreamingPromptRequest {
397 prompt: self.prompt,
398 chat_history: self.chat_history,
399 max_turns: self.max_turns,
400 model: self.model,
401 agent_name: self.agent_name,
402 preamble: self.preamble,
403 static_context: self.static_context,
404 temperature: self.temperature,
405 max_tokens: self.max_tokens,
406 additional_params: self.additional_params,
407 tool_server_handle: self.tool_server_handle,
408 dynamic_context: self.dynamic_context,
409 tool_choice: self.tool_choice,
410 output_schema: self.output_schema,
411 hook: Some(hook),
412 max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
413 memory: self.memory,
414 conversation_id: self.conversation_id,
415 }
416 }
417
418 pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
422 self.max_invalid_tool_call_retries = retries;
423 self
424 }
425
426 pub fn conversation(mut self, id: impl Into<String>) -> Self {
431 self.conversation_id = Some(id.into());
432 self
433 }
434
435 pub fn without_memory(mut self) -> Self {
439 self.memory = None;
440 self.conversation_id = None;
441 self
442 }
443
444 async fn send(self) -> StreamingResult<M::StreamingResponse> {
445 let (agent_span, created_agent_span) = if tracing::Span::current().is_disabled() {
446 (
447 info_span!(
448 "invoke_agent",
449 gen_ai.operation.name = "invoke_agent",
450 gen_ai.agent.name = self.agent_name(),
451 gen_ai.system_instructions = self.preamble,
452 gen_ai.prompt = tracing::field::Empty,
453 gen_ai.completion = tracing::field::Empty,
454 gen_ai.usage.input_tokens = tracing::field::Empty,
455 gen_ai.usage.output_tokens = tracing::field::Empty,
456 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
457 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
458 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
459 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
460 ),
461 true,
462 )
463 } else {
464 (tracing::Span::current(), false)
465 };
466
467 let prompt = self.prompt;
468 if let Some(text) = prompt.rag_text() {
469 agent_span.record("gen_ai.prompt", text);
470 }
471
472 let model = self.model.clone();
474 let preamble = self.preamble.clone();
475 let static_context = self.static_context.clone();
476 let temperature = self.temperature;
477 let max_tokens = self.max_tokens;
478 let additional_params = self.additional_params.clone();
479 let tool_server_handle = self.tool_server_handle.clone();
480 let dynamic_context = self.dynamic_context.clone();
481 let tool_choice = self.tool_choice.clone();
482 let agent_name = self.agent_name.clone();
483 let output_schema = self.output_schema;
484 let (chat_history, memory_handle) = match self.chat_history {
489 Some(history) => (Some(history), None),
490 None => match (self.memory, self.conversation_id) {
491 (Some(memory), Some(id)) => match memory.load(&id).await {
492 Ok(loaded) => (Some(loaded), Some((memory, id))),
493 Err(err) => {
494 let stream = async_stream::stream! {
495 yield Err(StreamingError::from(err));
496 };
497 return Box::pin(stream);
498 }
499 },
500 _ => (None, None),
501 },
502 };
503 let has_history = chat_history.is_some();
504
505 let mut run = AgentRun::new(prompt.clone())
506 .max_turns(self.max_turns)
507 .max_invalid_tool_call_retries(self.max_invalid_tool_call_retries);
508 if let Some(history) = chat_history {
509 run = run.with_history(history);
510 }
511 if let Some(tool_choice) = tool_choice.clone() {
512 run = run.with_tool_choice(tool_choice);
513 }
514
515 let stream = async_stream::stream! {
522 let mut last_final_choice: OneOrMany<AssistantContent> =
526 OneOrMany::one(AssistantContent::text(""));
527 let mut last_message_id: Option<String> = None;
528
529 'outer: loop {
530 let step = match run.next_step() {
531 Ok(step) => step,
532 Err(err) => {
533 yield Err(Box::new(err).into());
534 break 'outer;
535 }
536 };
537
538 match step {
539 AgentRunStep::CallModel { prompt: current_prompt, history, turn } => {
540 if self.max_turns > 1 {
541 tracing::info!(
542 "Current conversation Turns: {}/{}",
543 turn,
544 self.max_turns
545 );
546 }
547
548 if let Some(ref hook) = self.hook
549 && let HookAction::Terminate { reason } =
550 hook.on_completion_call(¤t_prompt, &history).await
551 {
552 yield Err(StreamingError::Prompt(Box::new(run.cancel_error(reason))));
553 break 'outer;
554 }
555
556 let chat_stream_span = info_span!(
557 target: "rig::agent_chat",
558 parent: tracing::Span::current(),
559 "chat_streaming",
560 gen_ai.operation.name = "chat",
561 gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
562 gen_ai.system_instructions = preamble,
563 gen_ai.provider.name = tracing::field::Empty,
564 gen_ai.request.model = tracing::field::Empty,
565 gen_ai.response.id = tracing::field::Empty,
566 gen_ai.response.model = tracing::field::Empty,
567 gen_ai.usage.output_tokens = tracing::field::Empty,
568 gen_ai.usage.input_tokens = tracing::field::Empty,
569 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
570 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
571 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
572 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
573 gen_ai.input.messages = tracing::field::Empty,
574 gen_ai.output.messages = tracing::field::Empty,
575 );
576
577 let prepared_request = build_prepared_completion_request(
578 &model,
579 current_prompt.clone(),
580 &history,
581 preamble.as_deref(),
582 &static_context,
583 temperature,
584 max_tokens,
585 additional_params.as_ref(),
586 tool_choice.as_ref(),
587 &tool_server_handle,
588 &dynamic_context,
589 output_schema.as_ref(),
590 )
591 .await?;
592
593 let mut stream = prepared_request
594 .builder
595 .stream()
596 .instrument(chat_stream_span.clone())
597 .await?;
598
599 let mut assembler = StreamedTurnAssembler::new(
600 prepared_request.executable_tool_names.clone(),
601 prepared_request.allowed_tool_names.clone(),
602 );
603 let mut completion_call_emitted = false;
604 let mut turn_abandoned = false;
605
606 'turn: while let Some(item) = stream.next().await {
607 let item = match item {
608 Ok(item) => item,
609 Err(err) => {
610 yield Err(err.into());
611 break 'outer;
612 }
613 };
614 let mut events: VecDeque<StreamedTurnEvent> =
615 match assembler.ingest(&item) {
616 Ok(events) => events.into(),
617 Err(err) => {
618 yield Err(err.into());
619 break 'outer;
620 }
621 };
622 let mut item_slot = Some(item);
626 while let Some(event) = events.pop_front() {
627 match event {
628 StreamedTurnEvent::EmitIngested => {
629 if let Some(StreamedAssistantContent::Text(text)) =
630 item_slot.as_ref()
631 && let Some(ref hook) = self.hook
632 && let HookAction::Terminate { reason } = hook
633 .on_text_delta(
634 &text.text,
635 assembler.aggregated_text(),
636 )
637 .await
638 {
639 yield Err(StreamingError::Prompt(Box::new(
640 run.cancel_error(reason),
641 )));
642 break 'outer;
643 }
644 if let Some(item) = item_slot.take() {
645 yield Ok(MultiTurnStreamItem::stream_item(item));
646 }
647 }
648 StreamedTurnEvent::EmitToolCallDelta {
649 id,
650 internal_call_id,
651 content,
652 } => {
653 if let Some(ref hook) = self.hook {
654 let (name, delta) = match &content {
655 ToolCallDeltaContent::Name(name) => {
656 (Some(name.as_str()), "")
657 }
658 ToolCallDeltaContent::Delta(delta) => {
659 (None, delta.as_str())
660 }
661 };
662
663 if let HookAction::Terminate { reason } = hook
664 .on_tool_call_delta(
665 &id,
666 &internal_call_id,
667 name,
668 delta,
669 )
670 .await
671 {
672 yield Err(StreamingError::Prompt(Box::new(
673 run.cancel_error(reason),
674 )));
675 break 'outer;
676 }
677 }
678
679 yield Ok(MultiTurnStreamItem::StreamAssistantItem(
680 StreamedAssistantContent::ToolCallDelta {
681 id,
682 internal_call_id,
683 content,
684 },
685 ));
686 }
687 StreamedTurnEvent::Completed { usage, emit_final } => {
688 if !completion_call_emitted {
689 if usage.has_values() {
690 record_usage_on_span(&chat_stream_span, usage);
691 }
692 let completion_call =
693 match run.record_streamed_completion_call(usage) {
694 Ok(call) => call,
695 Err(err) => {
696 yield Err(Box::new(err).into());
697 break 'outer;
698 }
699 };
700 completion_call_emitted = true;
701 yield Ok(MultiTurnStreamItem::CompletionCall(
702 completion_call,
703 ));
704 }
705
706 if emit_final
707 && let Some(StreamedAssistantContent::Final(
708 final_resp,
709 )) = item_slot.as_ref()
710 {
711 if let Some(ref hook) = self.hook
712 && let HookAction::Terminate { reason } = hook
713 .on_stream_completion_response_finish(
714 ¤t_prompt,
715 final_resp,
716 )
717 .await
718 {
719 yield Err(StreamingError::Prompt(Box::new(
720 run.cancel_error(reason),
721 )));
722 break 'outer;
723 }
724 if let Some(item) = item_slot.take() {
725 yield Ok(MultiTurnStreamItem::stream_item(item));
726 }
727 }
728 }
729 StreamedTurnEvent::InvalidToolCall(invalid) => {
730 let partial =
731 assembler.partial_turn(stream.message_id.clone());
732 let action = match self.hook.as_ref() {
733 Some(hook) => {
734 let context = run
735 .streamed_invalid_tool_call_context(
736 &partial, &invalid,
737 );
738 hook.on_invalid_tool_call(&context).await
739 }
740 None => InvalidToolCallHookAction::fail(),
741 };
742
743 let resolution = match run
744 .resolve_streamed_invalid_tool_call(
745 &partial, &invalid, action,
746 ) {
747 Ok(resolution) => resolution,
748 Err(err) => {
749 yield Err(Box::new(err).into());
750 break 'outer;
751 }
752 };
753
754 match resolution {
755 StreamedResolution::Repaired { .. } => {
756 events.extend(
759 assembler.resolve_pending_invalid(&resolution),
760 );
761 }
762 StreamedResolution::TurnAbandoned {
763 ref skipped_tool_result,
764 } => {
765 let skipped_tool_result =
766 skipped_tool_result.clone();
767 assembler.resolve_pending_invalid(&resolution);
768
769 if let Some(err) = assembler.pending_delta_error() {
770 yield Err(err.into());
771 break 'outer;
772 }
773 let drained_usage =
774 match drain_stream_usage(&mut stream).await {
775 Ok(usage) => usage,
776 Err(err) => {
777 yield Err(err);
778 break 'outer;
779 }
780 };
781 if !completion_call_emitted {
782 if drained_usage.has_values() {
783 record_usage_on_span(
784 &chat_stream_span,
785 drained_usage,
786 );
787 }
788 let completion_call = match run
789 .record_streamed_completion_call(
790 drained_usage,
791 ) {
792 Ok(call) => call,
793 Err(err) => {
794 yield Err(Box::new(err).into());
795 break 'outer;
796 }
797 };
798 completion_call_emitted = true;
799 yield Ok(MultiTurnStreamItem::CompletionCall(
800 completion_call,
801 ));
802 }
803 if let Some(tool_result) = skipped_tool_result {
804 yield Ok(MultiTurnStreamItem::StreamUserItem(
805 StreamedUserContent::ToolResult {
806 tool_result,
807 internal_call_id: invalid
808 .internal_call_id
809 .clone(),
810 },
811 ));
812 }
813 turn_abandoned = true;
814 break 'turn;
815 }
816 }
817 }
818 }
819 }
820 }
821
822 if turn_abandoned {
823 continue 'outer;
824 }
825
826 if let Some(err) = assembler.pending_delta_error() {
827 yield Err(err.into());
828 break 'outer;
829 }
830
831 if !completion_call_emitted {
832 let completion_call =
833 match run
834 .record_streamed_completion_call(crate::completion::Usage::new())
835 {
836 Ok(call) => call,
837 Err(err) => {
838 yield Err(Box::new(err).into());
839 break 'outer;
840 }
841 };
842 yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
843 }
844
845 let final_turn_content = stream.choice.clone();
846 tracing::Span::current().record(
847 "gen_ai.completion",
848 assistant_text_from_choice(&final_turn_content),
849 );
850
851 last_message_id = stream.message_id.clone();
852 let streamed_turn =
853 assembler.finish(stream.message_id.clone(), &final_turn_content);
854 if let Err(err) = run.streamed_turn(streamed_turn) {
855 yield Err(Box::new(err).into());
856 break 'outer;
857 }
858 last_final_choice = final_turn_content;
859 }
860 AgentRunStep::CallTools { calls } => {
861 let full_history_for_errors = run.full_history();
862 let mut results: Vec<UserContent> = Vec::with_capacity(calls.len());
863
864 for pending in calls {
865 let tool_call = pending.tool_call;
866 if let Some(result) = pending.preresolved_result {
867 results.push(result);
873 continue;
874 }
875 let internal_call_id = pending
876 .internal_call_id
877 .unwrap_or_else(|| nanoid::nanoid!());
878
879 let tool_span = info_span!(
880 parent: tracing::Span::current(),
881 "execute_tool",
882 gen_ai.operation.name = "execute_tool",
883 gen_ai.tool.type = "function",
884 gen_ai.tool.name = tracing::field::Empty,
885 gen_ai.tool.call.id = tracing::field::Empty,
886 gen_ai.tool.call.arguments = tracing::field::Empty,
887 gen_ai.tool.call.result = tracing::field::Empty
888 );
889
890 yield Ok(MultiTurnStreamItem::stream_item(
891 StreamedAssistantContent::ToolCall {
892 tool_call: tool_call.clone(),
893 internal_call_id: internal_call_id.clone(),
894 },
895 ));
896
897 let tc_result = async {
898 let tool_span = tracing::Span::current();
899 let tool_args =
900 json_utils::value_to_json_string(&tool_call.function.arguments);
901 if let Some(ref hook) = self.hook {
902 let action = hook
903 .on_tool_call(
904 &tool_call.function.name,
905 tool_call.call_id.clone(),
906 &internal_call_id,
907 &tool_args,
908 )
909 .await;
910
911 if let ToolCallHookAction::Terminate { reason } = action {
912 return Err(StreamingError::Prompt(Box::new(
913 PromptError::prompt_cancelled(
914 full_history_for_errors.clone(),
915 reason,
916 ),
917 )));
918 }
919
920 if let ToolCallHookAction::Skip { reason } = action {
921 tracing::info!(
923 tool_name = tool_call.function.name.as_str(),
924 reason = reason,
925 "Tool call rejected"
926 );
927 return Ok(reason);
928 }
929 }
930
931 tool_span.record("gen_ai.tool.name", &tool_call.function.name);
932 tool_span.record("gen_ai.tool.call.arguments", &tool_args);
933
934 let tool_result = match tool_server_handle
935 .call_tool(&tool_call.function.name, &tool_args)
936 .await
937 {
938 Ok(result) => result,
939 Err(err) => {
940 tracing::warn!("Error while calling tool: {err}");
941 err.to_string()
942 }
943 };
944
945 tool_span.record("gen_ai.tool.call.result", &tool_result);
946
947 if let Some(ref hook) = self.hook
948 && let HookAction::Terminate { reason } = hook
949 .on_tool_result(
950 &tool_call.function.name,
951 tool_call.call_id.clone(),
952 &internal_call_id,
953 &tool_args,
954 &tool_result.to_string(),
955 )
956 .await
957 {
958 return Err(StreamingError::Prompt(Box::new(
959 PromptError::prompt_cancelled(
960 full_history_for_errors.clone(),
961 reason,
962 ),
963 )));
964 }
965
966 Ok(tool_result)
967 }
968 .instrument(tool_span)
969 .await;
970
971 match tc_result {
972 Ok(text) => {
973 results.push(tool_result_user_content(
974 tool_call.id.clone(),
975 tool_call.call_id.clone(),
976 text.clone(),
977 ));
978 let tool_result = ToolResult {
979 id: tool_call.id,
980 call_id: tool_call.call_id,
981 content: ToolResultContent::from_tool_output(text),
982 };
983 yield Ok(MultiTurnStreamItem::StreamUserItem(
984 StreamedUserContent::ToolResult {
985 tool_result,
986 internal_call_id,
987 },
988 ));
989 }
990 Err(err) => {
991 yield Err(err);
992 break 'outer;
993 }
994 }
995 }
996
997 if let Err(err) = run.tool_results(results) {
998 yield Err(Box::new(err).into());
999 break 'outer;
1000 }
1001 }
1002 AgentRunStep::Done(response) => {
1003 if is_empty_assistant_turn(&last_final_choice) {
1004 tracing::warn!(
1005 agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
1006 message_id = ?last_message_id,
1007 "Streaming turn completed without assistant text; final response will be empty"
1008 );
1009 }
1010
1011 if created_agent_span {
1012 let current_span = tracing::Span::current();
1013 record_usage_on_span(¤t_span, response.usage);
1014 }
1015 tracing::info!("Agent multi-turn stream finished");
1016 if let Some((memory, id)) = memory_handle.as_ref()
1017 && let Err(err) = memory
1018 .append(id, response.messages.clone().unwrap_or_default())
1019 .await
1020 {
1021 tracing::warn!(
1022 error = %err,
1023 conversation_id = %id,
1024 "conversation memory append failed; yielding final response anyway"
1025 );
1026 }
1027 let final_messages: Option<Vec<Message>> = if has_history {
1028 Some(response.messages.clone().unwrap_or_default())
1029 } else {
1030 None
1031 };
1032 yield Ok(MultiTurnStreamItem::final_response_with_completion_calls(
1033 last_final_choice.clone(),
1034 response.usage,
1035 response.completion_calls.clone(),
1036 final_messages,
1037 ));
1038 break 'outer;
1039 }
1040 }
1041 }
1042 };
1043
1044 Box::pin(stream.instrument(agent_span))
1045 }
1046}
1047
1048impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
1049where
1050 M: CompletionModel + 'static,
1051 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
1052 P: PromptHook<M> + 'static,
1053{
1054 type Output = StreamingResult<M::StreamingResponse>; type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1056
1057 fn into_future(self) -> Self::IntoFuture {
1058 Box::pin(async move { self.send().await })
1060 }
1061}
1062
1063pub async fn stream_to_stdout<R>(
1070 stream: &mut StreamingResult<R>,
1071) -> Result<FinalResponse, std::io::Error> {
1072 let mut final_res = FinalResponse::empty();
1073 print!("Response: ");
1074 while let Some(content) = stream.next().await {
1075 match content {
1076 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1077 Text { text, .. },
1078 ))) => {
1079 print!("{text}");
1080 std::io::Write::flush(&mut std::io::stdout())?;
1081 }
1082 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
1083 reasoning,
1084 ))) => {
1085 let reasoning = reasoning.display_text();
1086 print!("{reasoning}");
1087 std::io::Write::flush(&mut std::io::stdout())?;
1088 }
1089 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1090 final_res = res;
1091 }
1092 Err(err) => {
1093 eprintln!("Error: {err}");
1094 }
1095 _ => {}
1096 }
1097 }
1098
1099 Ok(final_res)
1100}
1101
1102#[cfg(test)]
1103mod tests {
1104 use super::*;
1105 use crate::agent::AgentBuilder;
1106 use crate::agent::prompt_request::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER;
1107 use crate::agent::prompt_request::hooks::{
1108 InvalidToolCallContext, InvalidToolCallHookAction, PromptHook, ToolCallHookAction,
1109 };
1110 use crate::agent::run::streamed::merge_reasoning_blocks;
1111 use crate::client::ProviderClient;
1112 use crate::client::completion::CompletionClient;
1113 use crate::completion::{CompletionRequest, PromptError, ToolDefinition, Usage};
1114 use crate::message::{
1115 AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent,
1116 ToolChoice, ToolResultContent, UserContent,
1117 };
1118 use crate::providers::anthropic;
1119 use crate::streaming::{StreamingPrompt, ToolCallDeltaContent};
1120 use crate::test_utils::{
1121 AppendFailingMemory, FailingMemory, MockAddTool, MockCompletionModel, MockResponse,
1122 MockStreamEvent, MockSubtractTool, MockToolError,
1123 };
1124 use crate::tool::Tool;
1125 use futures::{StreamExt, TryStreamExt};
1126 use serde::Deserialize;
1127 use std::collections::HashMap;
1128 use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
1129 use std::sync::{Arc, Mutex};
1130 use std::time::Duration;
1131 use tracing::field::{Field, Visit};
1132 use tracing::{Id, Subscriber};
1133 use tracing_subscriber::layer::{Context, SubscriberExt};
1134 use tracing_subscriber::{Layer, Registry, registry::LookupSpan};
1135
1136 #[test]
1137 fn merge_reasoning_blocks_preserves_order_and_signatures() {
1138 let mut accumulated = Vec::new();
1139 let first = crate::message::Reasoning {
1140 id: Some("rs_1".to_string()),
1141 content: vec![ReasoningContent::Text {
1142 text: "step-1".to_string(),
1143 signature: Some("sig-1".to_string()),
1144 }],
1145 };
1146 let second = crate::message::Reasoning {
1147 id: Some("rs_1".to_string()),
1148 content: vec![
1149 ReasoningContent::Text {
1150 text: "step-2".to_string(),
1151 signature: Some("sig-2".to_string()),
1152 },
1153 ReasoningContent::Summary("summary".to_string()),
1154 ],
1155 };
1156
1157 merge_reasoning_blocks(&mut accumulated, &first);
1158 merge_reasoning_blocks(&mut accumulated, &second);
1159
1160 assert_eq!(accumulated.len(), 1);
1161 let merged = accumulated.first().expect("expected accumulated reasoning");
1162 assert_eq!(merged.id.as_deref(), Some("rs_1"));
1163 assert_eq!(merged.content.len(), 3);
1164 assert!(matches!(
1165 merged.content.first(),
1166 Some(ReasoningContent::Text { text, signature: Some(sig) })
1167 if text == "step-1" && sig == "sig-1"
1168 ));
1169 assert!(matches!(
1170 merged.content.get(1),
1171 Some(ReasoningContent::Text { text, signature: Some(sig) })
1172 if text == "step-2" && sig == "sig-2"
1173 ));
1174 }
1175
1176 #[test]
1177 fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
1178 let mut accumulated = vec![crate::message::Reasoning {
1179 id: Some("rs_a".to_string()),
1180 content: vec![ReasoningContent::Text {
1181 text: "step-1".to_string(),
1182 signature: None,
1183 }],
1184 }];
1185 let incoming = crate::message::Reasoning {
1186 id: Some("rs_b".to_string()),
1187 content: vec![ReasoningContent::Text {
1188 text: "step-2".to_string(),
1189 signature: None,
1190 }],
1191 };
1192
1193 merge_reasoning_blocks(&mut accumulated, &incoming);
1194 assert_eq!(accumulated.len(), 2);
1195 assert_eq!(
1196 accumulated.first().and_then(|r| r.id.as_deref()),
1197 Some("rs_a")
1198 );
1199 assert_eq!(
1200 accumulated.get(1).and_then(|r| r.id.as_deref()),
1201 Some("rs_b")
1202 );
1203 }
1204
1205 #[test]
1206 fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
1207 let mut accumulated = vec![crate::message::Reasoning {
1208 id: None,
1209 content: vec![ReasoningContent::Text {
1210 text: "first".to_string(),
1211 signature: None,
1212 }],
1213 }];
1214 let incoming = crate::message::Reasoning {
1215 id: None,
1216 content: vec![ReasoningContent::Text {
1217 text: "second".to_string(),
1218 signature: None,
1219 }],
1220 };
1221
1222 merge_reasoning_blocks(&mut accumulated, &incoming);
1223 assert_eq!(accumulated.len(), 2);
1224 assert!(matches!(
1225 accumulated.first(),
1226 Some(crate::message::Reasoning {
1227 id: None,
1228 content
1229 }) if matches!(
1230 content.first(),
1231 Some(ReasoningContent::Text { text, .. }) if text == "first"
1232 )
1233 ));
1234 assert!(matches!(
1235 accumulated.get(1),
1236 Some(crate::message::Reasoning {
1237 id: None,
1238 content
1239 }) if matches!(
1240 content.first(),
1241 Some(ReasoningContent::Text { text, .. }) if text == "second"
1242 )
1243 ));
1244 }
1245
1246 #[test]
1247 fn tool_result_user_content_preserves_multimodal_tool_output() {
1248 let user_content = tool_result_user_content(
1249 "tool_call_1".to_string(),
1250 Some("call_1".to_string()),
1251 serde_json::json!({
1252 "response": {
1253 "instruction": "Use the image part to answer."
1254 },
1255 "parts": [
1256 {
1257 "type": "image",
1258 "data": "base64data==",
1259 "mimeType": "image/png"
1260 }
1261 ]
1262 })
1263 .to_string(),
1264 );
1265
1266 let tool_result = match user_content {
1267 UserContent::ToolResult(tool_result) => tool_result,
1268 other => panic!("expected tool result content, got {other:?}"),
1269 };
1270
1271 assert_eq!(tool_result.id, "tool_call_1");
1272 assert_eq!(tool_result.call_id.as_deref(), Some("call_1"));
1273 assert_eq!(tool_result.content.len(), 2);
1274
1275 let mut items = tool_result.content.iter();
1276 match items.next() {
1277 Some(ToolResultContent::Text(text)) => {
1278 assert!(text.text.contains("Use the image part to answer."));
1279 }
1280 other => panic!("expected structured text payload first, got {other:?}"),
1281 }
1282
1283 match items.next() {
1284 Some(ToolResultContent::Image(image)) => {
1285 assert_eq!(image.media_type, Some(ImageMediaType::PNG));
1286 assert!(matches!(
1287 image.data,
1288 DocumentSourceKind::Base64(ref data) if data == "base64data=="
1289 ));
1290 }
1291 other => panic!("expected image payload second, got {other:?}"),
1292 }
1293 }
1294
1295 fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
1296 let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1297 if history.len() != 3 {
1298 return Err(format!(
1299 "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
1300 ));
1301 }
1302
1303 if !matches!(
1304 history.first(),
1305 Some(Message::User { content })
1306 if matches!(
1307 content.first(),
1308 UserContent::Text(text) if text.text == "do tool work"
1309 )
1310 ) {
1311 return Err(format!(
1312 "follow-up request should begin with the original user prompt: {history:?}"
1313 ));
1314 }
1315
1316 if !matches!(
1317 history.get(1),
1318 Some(Message::Assistant { content, .. })
1319 if matches!(
1320 content.first(),
1321 AssistantContent::ToolCall(tool_call)
1322 if tool_call.id == "tool_call_1"
1323 && tool_call.call_id.as_deref() == Some("call_1")
1324 )
1325 ) {
1326 return Err(format!(
1327 "follow-up request is missing the assistant tool call in position 2: {history:?}"
1328 ));
1329 }
1330
1331 if !matches!(
1332 history.get(2),
1333 Some(Message::User { content })
1334 if matches!(
1335 content.first(),
1336 UserContent::ToolResult(tool_result)
1337 if tool_result.id == "tool_call_1"
1338 && tool_result.call_id.as_deref() == Some("call_1")
1339 )
1340 ) {
1341 return Err(format!(
1342 "follow-up request should end with the user tool result: {history:?}"
1343 ));
1344 }
1345
1346 Ok(())
1347 }
1348
1349 fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1350 history.iter().any(|message| {
1351 matches!(
1352 message,
1353 Message::Assistant { content, .. }
1354 if content.iter().any(|item| matches!(
1355 item,
1356 AssistantContent::ToolCall(tool_call)
1357 if tool_call.function.name == tool_name
1358 ))
1359 )
1360 })
1361 }
1362
1363 fn history_contains_text(history: &[Message], expected: &str) -> bool {
1364 history.iter().any(|message| {
1365 matches!(
1366 message,
1367 Message::Assistant { content, .. }
1368 if content.iter().any(|item| matches!(
1369 item,
1370 AssistantContent::Text(text) if text.text == expected
1371 ))
1372 )
1373 })
1374 }
1375
1376 fn assistant_reasoning_precedes_tool_call(
1377 history: &[Message],
1378 expected_reasoning: &str,
1379 tool_name: &str,
1380 ) -> bool {
1381 history.iter().any(|message| {
1382 let Message::Assistant { content, .. } = message else {
1383 return false;
1384 };
1385
1386 let reasoning_index = content.iter().position(|item| {
1387 matches!(
1388 item,
1389 AssistantContent::Reasoning(reasoning)
1390 if reasoning.content.iter().any(|content| matches!(
1391 content,
1392 ReasoningContent::Text { text, .. }
1393 if text == expected_reasoning
1394 ))
1395 )
1396 });
1397 let tool_index = content.iter().position(|item| {
1398 matches!(
1399 item,
1400 AssistantContent::ToolCall(tool_call)
1401 if tool_call.function.name == tool_name
1402 )
1403 });
1404
1405 matches!((reasoning_index, tool_index), (Some(reasoning), Some(tool)) if reasoning < tool)
1406 })
1407 }
1408
1409 fn assistant_reasoning_precedes_text_and_tool_call(
1410 history: &[Message],
1411 expected_reasoning: &str,
1412 expected_text: &str,
1413 tool_name: &str,
1414 ) -> bool {
1415 history.iter().any(|message| {
1416 let Message::Assistant { content, .. } = message else {
1417 return false;
1418 };
1419
1420 let reasoning_index = content.iter().position(|item| {
1421 matches!(
1422 item,
1423 AssistantContent::Reasoning(reasoning)
1424 if reasoning.content.iter().any(|content| matches!(
1425 content,
1426 ReasoningContent::Text { text, .. }
1427 if text == expected_reasoning
1428 ))
1429 )
1430 });
1431 let text_index = content.iter().position(|item| {
1432 matches!(
1433 item,
1434 AssistantContent::Text(text) if text.text == expected_text
1435 )
1436 });
1437 let tool_index = content.iter().position(|item| {
1438 matches!(
1439 item,
1440 AssistantContent::ToolCall(tool_call)
1441 if tool_call.function.name == tool_name
1442 )
1443 });
1444
1445 matches!(
1446 (reasoning_index, text_index, tool_index),
1447 (Some(reasoning), Some(text), Some(tool))
1448 if reasoning < text && text < tool
1449 )
1450 })
1451 }
1452
1453 #[derive(Clone)]
1454 struct PanicOnUnknownToolHook;
1455
1456 impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
1457 async fn on_tool_call_delta(
1458 &self,
1459 _tool_call_id: &str,
1460 _internal_call_id: &str,
1461 _tool_name: Option<&str>,
1462 _tool_call_delta: &str,
1463 ) -> HookAction {
1464 panic!("unknown tool call delta should fail before delta hooks run")
1465 }
1466
1467 async fn on_tool_call(
1468 &self,
1469 _tool_name: &str,
1470 _tool_call_id: Option<String>,
1471 _internal_call_id: &str,
1472 _args: &str,
1473 ) -> ToolCallHookAction {
1474 panic!("unknown tool call should fail before tool hooks run")
1475 }
1476
1477 async fn on_stream_completion_response_finish(
1478 &self,
1479 _prompt: &Message,
1480 _response: &MockResponse,
1481 ) -> HookAction {
1482 panic!("unknown tool call should fail before stream finish hooks run")
1483 }
1484 }
1485
1486 #[derive(Clone)]
1487 struct CountingAddTool {
1488 calls: Arc<AtomicU32>,
1489 }
1490
1491 #[derive(Clone)]
1492 struct CountingSubtractTool {
1493 calls: Arc<AtomicU32>,
1494 }
1495
1496 #[derive(Deserialize)]
1497 struct CountingOperationArgs {
1498 x: i32,
1499 y: i32,
1500 }
1501
1502 fn arithmetic_tool_definition(name: &str, description: &str) -> ToolDefinition {
1503 ToolDefinition {
1504 name: name.to_string(),
1505 description: description.to_string(),
1506 parameters: serde_json::json!({
1507 "type": "object",
1508 "properties": {
1509 "x": {
1510 "type": "number",
1511 "description": "The first operand"
1512 },
1513 "y": {
1514 "type": "number",
1515 "description": "The second operand"
1516 }
1517 },
1518 "required": ["x", "y"],
1519 }),
1520 }
1521 }
1522
1523 impl Tool for CountingAddTool {
1524 const NAME: &'static str = "add";
1525 type Error = MockToolError;
1526 type Args = CountingOperationArgs;
1527 type Output = i32;
1528
1529 async fn definition(&self, _prompt: String) -> ToolDefinition {
1530 arithmetic_tool_definition(Self::NAME, "Add x and y together")
1531 }
1532
1533 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
1534 self.calls.fetch_add(1, Ordering::SeqCst);
1535 Ok(args.x + args.y)
1536 }
1537 }
1538
1539 impl Tool for CountingSubtractTool {
1540 const NAME: &'static str = "subtract";
1541 type Error = MockToolError;
1542 type Args = CountingOperationArgs;
1543 type Output = i32;
1544
1545 async fn definition(&self, _prompt: String) -> ToolDefinition {
1546 arithmetic_tool_definition(Self::NAME, "Subtract y from x")
1547 }
1548
1549 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
1550 self.calls.fetch_add(1, Ordering::SeqCst);
1551 Ok(args.x - args.y)
1552 }
1553 }
1554
1555 fn streaming_tool_then_text_model() -> MockCompletionModel {
1556 MockCompletionModel::from_stream_turns([
1557 vec![
1558 MockStreamEvent::tool_call(
1559 "tool_call_1",
1560 "add",
1561 serde_json::json!({"x": 1, "y": 2}),
1562 )
1563 .with_call_id("call_1"),
1564 MockStreamEvent::final_response_with_total_tokens(4),
1565 ],
1566 vec![
1567 MockStreamEvent::text("done"),
1568 MockStreamEvent::final_response_with_total_tokens(6),
1569 ],
1570 ])
1571 }
1572
1573 fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
1574 Usage {
1575 input_tokens,
1576 output_tokens,
1577 total_tokens: input_tokens + output_tokens,
1578 cached_input_tokens: 0,
1579 cache_creation_input_tokens: 0,
1580 tool_use_prompt_tokens: 0,
1581 reasoning_tokens: 0,
1582 }
1583 }
1584
1585 #[derive(Clone, Debug, Default)]
1586 struct CapturedSpan {
1587 id: u64,
1588 name: String,
1589 parent_id: Option<u64>,
1590 fields: HashMap<String, u64>,
1591 }
1592
1593 #[derive(Clone, Default)]
1594 struct CapturedSpans(Arc<Mutex<Vec<CapturedSpan>>>);
1595
1596 impl CapturedSpans {
1597 fn clear(&self) {
1598 if let Ok(mut spans) = self.0.lock() {
1599 spans.clear();
1600 }
1601 }
1602
1603 fn insert(&self, id: &Id, name: &str, parent_id: Option<u64>) {
1604 let id = id.into_u64();
1605 if let Ok(mut spans) = self.0.lock() {
1606 spans.push(CapturedSpan {
1607 id,
1608 name: name.to_string(),
1609 parent_id,
1610 fields: HashMap::new(),
1611 });
1612 }
1613 }
1614
1615 fn record(&self, id: &Id, fields: Vec<(String, u64)>) {
1616 if let Ok(mut spans) = self.0.lock()
1617 && let Some(span) = spans.iter_mut().rev().find(|span| span.id == id.into_u64())
1618 {
1619 span.fields.extend(fields);
1620 }
1621 }
1622
1623 fn snapshot(&self) -> Vec<CapturedSpan> {
1624 self.0.lock().map(|spans| spans.clone()).unwrap_or_default()
1625 }
1626 }
1627
1628 struct SpanCaptureLayer {
1629 spans: CapturedSpans,
1630 }
1631
1632 impl<S> Layer<S> for SpanCaptureLayer
1633 where
1634 S: Subscriber,
1635 S: for<'lookup> LookupSpan<'lookup>,
1636 {
1637 fn on_new_span(&self, attrs: &tracing::span::Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
1638 let parent_id = attrs
1639 .parent()
1640 .map(Id::into_u64)
1641 .or_else(|| ctx.current_span().id().map(Id::into_u64));
1642 self.spans.insert(id, attrs.metadata().name(), parent_id);
1643 }
1644
1645 fn on_record(&self, span: &Id, values: &tracing::span::Record<'_>, _ctx: Context<'_, S>) {
1646 let mut fields = Vec::new();
1647 values.record(&mut SpanFieldCaptureVisitor {
1648 fields: &mut fields,
1649 });
1650 self.spans.record(span, fields);
1651 }
1652 }
1653
1654 struct SpanFieldCaptureVisitor<'a> {
1655 fields: &'a mut Vec<(String, u64)>,
1656 }
1657
1658 impl Visit for SpanFieldCaptureVisitor<'_> {
1659 fn record_u64(&mut self, field: &Field, value: u64) {
1660 self.fields.push((field.name().to_string(), value));
1661 }
1662
1663 fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {}
1664 }
1665
1666 async fn assert_stream_usage_recorded_on_chat_spans(
1667 agent: crate::agent::Agent<MockCompletionModel>,
1668 prompt: &str,
1669 max_turns: usize,
1670 expected_usages: &[Usage],
1671 ) {
1672 let _isolation = crate::test_utils::scoped_tracing_subscriber_guard().await;
1675 let spans = CapturedSpans::default();
1676 let subscriber = Registry::default().with(SpanCaptureLayer {
1677 spans: spans.clone(),
1678 });
1679 let _default = tracing::subscriber::set_default(subscriber);
1680
1681 let warmup_model = MockCompletionModel::from_stream_turns([[
1692 MockStreamEvent::text("warmup"),
1693 MockStreamEvent::final_response(Usage::default()),
1694 ]]);
1695 let warmup_agent = crate::agent::AgentBuilder::new(warmup_model).build();
1696 let mut warmup_stream = warmup_agent.stream_prompt("warmup").multi_turn(1).await;
1697 while let Some(item) = warmup_stream
1698 .try_next()
1699 .await
1700 .expect("warmup stream should not error")
1701 {
1702 if matches!(item, MultiTurnStreamItem::FinalResponse(_)) {
1703 break;
1704 }
1705 }
1706 tracing::callsite::rebuild_interest_cache();
1707 spans.clear();
1708
1709 let empty_history: &[Message] = &[];
1710 let outer_span = tracing::info_span!("outer");
1711
1712 async {
1713 let mut stream = agent
1714 .stream_prompt(prompt)
1715 .with_history(empty_history)
1716 .multi_turn(max_turns)
1717 .await;
1718
1719 while let Some(item) = stream.try_next().await.expect("stream should not error") {
1720 if matches!(item, MultiTurnStreamItem::FinalResponse(_)) {
1721 break;
1722 }
1723 }
1724 }
1725 .instrument(outer_span)
1726 .await;
1727
1728 let span_snapshot = spans.snapshot();
1729 let outer_span_id = span_snapshot
1730 .iter()
1731 .find(|span| span.name == "outer")
1732 .map(|span| span.id)
1733 .expect("outer span should be captured");
1734 let chat_spans = span_snapshot
1735 .iter()
1736 .filter(|span| span.name == "chat_streaming")
1737 .collect::<Vec<_>>();
1738
1739 assert_eq!(chat_spans.len(), expected_usages.len());
1740 assert!(
1741 span_snapshot.iter().all(|span| span.name != "invoke_agent"),
1742 "outer span path should not create invoke_agent"
1743 );
1744
1745 for (chat_span, expected_usage) in chat_spans.into_iter().zip(expected_usages) {
1746 assert_eq!(chat_span.parent_id, Some(outer_span_id));
1747 assert_eq!(
1748 chat_span.fields.get("gen_ai.usage.input_tokens"),
1749 Some(&expected_usage.input_tokens)
1750 );
1751 assert_eq!(
1752 chat_span.fields.get("gen_ai.usage.output_tokens"),
1753 Some(&expected_usage.output_tokens)
1754 );
1755 assert_eq!(
1756 chat_span.fields.get("gen_ai.usage.cache_read.input_tokens"),
1757 Some(&expected_usage.cached_input_tokens)
1758 );
1759 assert_eq!(
1760 chat_span
1761 .fields
1762 .get("gen_ai.usage.cache_creation.input_tokens"),
1763 Some(&expected_usage.cache_creation_input_tokens)
1764 );
1765 assert_eq!(
1766 chat_span.fields.get("gen_ai.usage.tool_use_prompt_tokens"),
1767 Some(&expected_usage.tool_use_prompt_tokens)
1768 );
1769 assert_eq!(
1770 chat_span.fields.get("gen_ai.usage.reasoning_tokens"),
1771 Some(&expected_usage.reasoning_tokens)
1772 );
1773 }
1774
1775 let outer_span = span_snapshot
1776 .iter()
1777 .find(|span| span.id == outer_span_id)
1778 .expect("outer span should be present");
1779 assert!(
1780 outer_span
1781 .fields
1782 .keys()
1783 .all(|field| !field.starts_with("gen_ai.usage.")),
1784 "usage should not be recorded onto the caller's outer span"
1785 );
1786 }
1787
1788 #[test]
1789 fn completion_calls_stream_item_serializes_and_deserializes_expected_shape() {
1790 let item: MultiTurnStreamItem<MockResponse> =
1791 MultiTurnStreamItem::CompletionCall(CompletionCall::new(2, usage(3, 4)));
1792
1793 let value = serde_json::to_value(&item).expect("serialize completion call event");
1794
1795 assert_eq!(
1796 value,
1797 serde_json::json!({
1798 "type": "completionCall",
1799 "call_index": 2,
1800 "usage": {
1801 "input_tokens": 3,
1802 "output_tokens": 4,
1803 "total_tokens": 7,
1804 "cached_input_tokens": 0,
1805 "cache_creation_input_tokens": 0,
1806 "tool_use_prompt_tokens": 0,
1807 "reasoning_tokens": 0,
1808 }
1809 })
1810 );
1811
1812 let item: MultiTurnStreamItem<MockResponse> =
1813 serde_json::from_value(value).expect("deserialize completion call event");
1814 match item {
1815 MultiTurnStreamItem::CompletionCall(call_usage) => {
1816 assert_eq!(call_usage, CompletionCall::new(2, usage(3, 4)));
1817 }
1818 other => panic!("expected completion call event, got {other:?}"),
1819 }
1820
1821 let item: MultiTurnStreamItem<MockResponse> =
1822 MultiTurnStreamItem::CompletionCall(CompletionCall::new(3, Usage::new()));
1823 let value = serde_json::to_value(&item).expect("serialize missing usage event");
1824
1825 assert_eq!(
1828 value,
1829 serde_json::json!({
1830 "type": "completionCall",
1831 "call_index": 3,
1832 "usage": {
1833 "input_tokens": 0,
1834 "output_tokens": 0,
1835 "total_tokens": 0,
1836 "cached_input_tokens": 0,
1837 "cache_creation_input_tokens": 0,
1838 "tool_use_prompt_tokens": 0,
1839 "reasoning_tokens": 0,
1840 }
1841 })
1842 );
1843
1844 let legacy: MultiTurnStreamItem<MockResponse> = serde_json::from_value(serde_json::json!({
1847 "type": "completionCall",
1848 "call_index": 3,
1849 "usage": null
1850 }))
1851 .expect("legacy null-usage event should deserialize");
1852 match legacy {
1853 MultiTurnStreamItem::CompletionCall(call) => {
1854 assert_eq!(call, CompletionCall::new(3, Usage::new()));
1855 }
1856 other => panic!("expected completion call event, got {other:?}"),
1857 }
1858 }
1859
1860 #[test]
1861 fn final_response_serializes_completion_calls_with_missing_usage() {
1862 let item: MultiTurnStreamItem<MockResponse> =
1863 MultiTurnStreamItem::final_response_with_completion_calls(
1864 OneOrMany::one(AssistantContent::text("done")),
1865 usage(3, 4),
1866 vec![
1867 CompletionCall::new(0, Usage::new()),
1868 CompletionCall::new(1, usage(3, 4)),
1869 ],
1870 None,
1871 );
1872
1873 if let MultiTurnStreamItem::FinalResponse(response) = &item {
1874 assert_eq!(response.requests(), 2);
1875 }
1876
1877 let value = serde_json::to_value(&item).expect("serialize final response");
1878
1879 assert_eq!(
1880 value.get("completionCalls"),
1881 Some(&serde_json::json!([
1882 {
1883 "call_index": 0,
1884 "usage": {
1885 "input_tokens": 0,
1886 "output_tokens": 0,
1887 "total_tokens": 0,
1888 "cached_input_tokens": 0,
1889 "cache_creation_input_tokens": 0,
1890 "tool_use_prompt_tokens": 0,
1891 "reasoning_tokens": 0,
1892 }
1893 },
1894 {
1895 "call_index": 1,
1896 "usage": {
1897 "input_tokens": 3,
1898 "output_tokens": 4,
1899 "total_tokens": 7,
1900 "cached_input_tokens": 0,
1901 "cache_creation_input_tokens": 0,
1902 "tool_use_prompt_tokens": 0,
1903 "reasoning_tokens": 0,
1904 }
1905 }
1906 ]))
1907 );
1908 }
1909
1910 fn streaming_text_then_final_model() -> MockCompletionModel {
1911 MockCompletionModel::from_stream_turns([[
1912 MockStreamEvent::text("hello"),
1913 MockStreamEvent::text(" world"),
1914 MockStreamEvent::final_response_with_total_tokens(3),
1915 ]])
1916 }
1917
1918 fn citation_metadata() -> serde_json::Value {
1919 serde_json::json!({
1920 "citations": [{
1921 "type": "web_search_result_location",
1922 "cited_text": "Claude Shannon was born in 1916.",
1923 "url": "https://example.com/shannon",
1924 "title": "Claude Shannon",
1925 "encrypted_index": "encrypted-reference"
1926 }]
1927 })
1928 }
1929
1930 fn streaming_cited_text_then_final_model() -> MockCompletionModel {
1931 MockCompletionModel::from_stream_turns([[
1932 MockStreamEvent::text_start(Some(citation_metadata())),
1933 MockStreamEvent::text("cited "),
1934 MockStreamEvent::text_start(None),
1935 MockStreamEvent::text("answer"),
1936 MockStreamEvent::final_response_with_total_tokens(3),
1937 ]])
1938 }
1939
1940 fn streaming_cited_text_then_tool_model() -> MockCompletionModel {
1941 MockCompletionModel::from_stream_turns([
1942 vec![
1943 MockStreamEvent::text_start(Some(citation_metadata())),
1944 MockStreamEvent::text("I need a tool. "),
1945 MockStreamEvent::tool_call(
1946 "tool_call_1",
1947 "add",
1948 serde_json::json!({"x": 1, "y": 2}),
1949 )
1950 .with_call_id("call_1"),
1951 MockStreamEvent::final_response_with_total_tokens(4),
1952 ],
1953 vec![
1954 MockStreamEvent::text("done"),
1955 MockStreamEvent::final_response_with_total_tokens(6),
1956 ],
1957 ])
1958 }
1959
1960 fn streaming_final_only_model() -> MockCompletionModel {
1961 MockCompletionModel::from_stream_turns([[
1962 MockStreamEvent::final_response_with_total_tokens(1),
1963 ]])
1964 }
1965
1966 #[derive(Clone)]
1967 struct TerminateOnStreamFinish;
1968
1969 impl PromptHook<MockCompletionModel> for TerminateOnStreamFinish {
1970 async fn on_stream_completion_response_finish(
1971 &self,
1972 _prompt: &Message,
1973 _response: &<MockCompletionModel as CompletionModel>::StreamingResponse,
1974 ) -> HookAction {
1975 HookAction::terminate("stop after completion call")
1976 }
1977 }
1978
1979 type RecordedToolCallDelta = (String, String, Option<String>, String);
1980
1981 #[derive(Clone)]
1982 struct RepairDefaultApiHook;
1983
1984 impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
1985 fn on_invalid_tool_call(
1986 &self,
1987 context: &InvalidToolCallContext,
1988 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
1989 let tool_name = context.tool_name.clone();
1990 async move {
1991 assert_eq!(tool_name, "default_api");
1992 InvalidToolCallHookAction::repair("add")
1993 }
1994 }
1995 }
1996
1997 #[derive(Clone)]
1998 struct RetryDefaultApiHook;
1999
2000 impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
2001 fn on_invalid_tool_call(
2002 &self,
2003 context: &InvalidToolCallContext,
2004 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2005 let tool_name = context.tool_name.clone();
2006 let args = context.args.clone();
2007 async move {
2008 assert_eq!(tool_name, "default_api");
2009 if let Some(args) = args {
2010 assert!(!args.is_empty());
2011 }
2012 InvalidToolCallHookAction::retry("Use the add tool instead")
2013 }
2014 }
2015 }
2016
2017 #[derive(Clone)]
2018 struct SkipDefaultApiHook;
2019
2020 impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
2021 fn on_invalid_tool_call(
2022 &self,
2023 context: &InvalidToolCallContext,
2024 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2025 let tool_name = context.tool_name.clone();
2026 async move {
2027 assert_eq!(tool_name, "default_api");
2028 InvalidToolCallHookAction::skip("default_api was skipped")
2029 }
2030 }
2031 }
2032
2033 #[derive(Clone, Default)]
2034 struct RecordingInvalidToolCallHook {
2035 contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
2036 }
2037
2038 impl RecordingInvalidToolCallHook {
2039 fn observed(&self) -> Vec<InvalidToolCallContext> {
2040 self.contexts
2041 .lock()
2042 .expect("invalid tool context records mutex was poisoned")
2043 .clone()
2044 }
2045 }
2046
2047 impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
2048 fn on_invalid_tool_call(
2049 &self,
2050 context: &InvalidToolCallContext,
2051 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2052 let contexts = self.contexts.clone();
2053 let context = context.clone();
2054
2055 async move {
2056 contexts
2057 .lock()
2058 .expect("invalid tool context records mutex was poisoned")
2059 .push(context);
2060 InvalidToolCallHookAction::fail()
2061 }
2062 }
2063 }
2064
2065 #[derive(Clone, Default)]
2066 struct RecordingToolCallDeltaHook {
2067 deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2068 }
2069
2070 impl RecordingToolCallDeltaHook {
2071 fn observed(&self) -> Vec<RecordedToolCallDelta> {
2072 self.deltas
2073 .lock()
2074 .expect("tool call delta hook records mutex was poisoned")
2075 .clone()
2076 }
2077 }
2078
2079 impl PromptHook<MockCompletionModel> for RecordingToolCallDeltaHook {
2080 fn on_tool_call_delta(
2081 &self,
2082 tool_call_id: &str,
2083 internal_call_id: &str,
2084 tool_name: Option<&str>,
2085 tool_call_delta: &str,
2086 ) -> impl Future<Output = HookAction> + Send {
2087 let deltas = self.deltas.clone();
2088 let event = (
2089 tool_call_id.to_string(),
2090 internal_call_id.to_string(),
2091 tool_name.map(str::to_string),
2092 tool_call_delta.to_string(),
2093 );
2094
2095 async move {
2096 deltas
2097 .lock()
2098 .expect("tool call delta hook records mutex was poisoned")
2099 .push(event);
2100 HookAction::cont()
2101 }
2102 }
2103 }
2104
2105 #[derive(Clone, Default)]
2106 struct RecordingTextDeltaHook {
2107 deltas: Arc<Mutex<Vec<(String, String)>>>,
2108 }
2109
2110 impl RecordingTextDeltaHook {
2111 fn observed(&self) -> Vec<(String, String)> {
2112 self.deltas
2113 .lock()
2114 .expect("text delta hook records mutex was poisoned")
2115 .clone()
2116 }
2117 }
2118
2119 impl PromptHook<MockCompletionModel> for RecordingTextDeltaHook {
2120 fn on_text_delta(
2121 &self,
2122 text_delta: &str,
2123 full_text: &str,
2124 ) -> impl Future<Output = HookAction> + Send {
2125 let deltas = self.deltas.clone();
2126 let event = (text_delta.to_string(), full_text.to_string());
2127
2128 async move {
2129 deltas
2130 .lock()
2131 .expect("text delta hook records mutex was poisoned")
2132 .push(event);
2133 HookAction::cont()
2134 }
2135 }
2136 }
2137
2138 #[derive(Clone)]
2139 struct RecordingTextAndSkipInvalidToolHook {
2140 text: RecordingTextDeltaHook,
2141 }
2142
2143 impl PromptHook<MockCompletionModel> for RecordingTextAndSkipInvalidToolHook {
2144 fn on_text_delta(
2145 &self,
2146 text_delta: &str,
2147 full_text: &str,
2148 ) -> impl Future<Output = HookAction> + Send {
2149 self.text.on_text_delta(text_delta, full_text)
2150 }
2151
2152 fn on_invalid_tool_call(
2153 &self,
2154 context: &InvalidToolCallContext,
2155 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2156 SkipDefaultApiHook.on_invalid_tool_call(context)
2157 }
2158 }
2159
2160 #[derive(Clone)]
2161 struct RecordingTextAndRetryInvalidToolHook {
2162 text: RecordingTextDeltaHook,
2163 }
2164
2165 impl PromptHook<MockCompletionModel> for RecordingTextAndRetryInvalidToolHook {
2166 fn on_text_delta(
2167 &self,
2168 text_delta: &str,
2169 full_text: &str,
2170 ) -> impl Future<Output = HookAction> + Send {
2171 self.text.on_text_delta(text_delta, full_text)
2172 }
2173
2174 fn on_invalid_tool_call(
2175 &self,
2176 context: &InvalidToolCallContext,
2177 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2178 RetryDefaultApiHook.on_invalid_tool_call(context)
2179 }
2180 }
2181
2182 #[derive(Clone)]
2183 struct RecordingDeltaAndRetryInvalidToolHook {
2184 delta: RecordingToolCallDeltaHook,
2185 }
2186
2187 impl PromptHook<MockCompletionModel> for RecordingDeltaAndRetryInvalidToolHook {
2188 fn on_tool_call_delta(
2189 &self,
2190 tool_call_id: &str,
2191 internal_call_id: &str,
2192 tool_name: Option<&str>,
2193 tool_call_delta: &str,
2194 ) -> impl Future<Output = HookAction> + Send {
2195 self.delta.on_tool_call_delta(
2196 tool_call_id,
2197 internal_call_id,
2198 tool_name,
2199 tool_call_delta,
2200 )
2201 }
2202
2203 fn on_invalid_tool_call(
2204 &self,
2205 context: &InvalidToolCallContext,
2206 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2207 RetryDefaultApiHook.on_invalid_tool_call(context)
2208 }
2209 }
2210
2211 #[derive(Clone)]
2212 struct RecordingDeltaAndSkipInvalidToolHook {
2213 delta: RecordingToolCallDeltaHook,
2214 }
2215
2216 impl PromptHook<MockCompletionModel> for RecordingDeltaAndSkipInvalidToolHook {
2217 fn on_tool_call_delta(
2218 &self,
2219 tool_call_id: &str,
2220 internal_call_id: &str,
2221 tool_name: Option<&str>,
2222 tool_call_delta: &str,
2223 ) -> impl Future<Output = HookAction> + Send {
2224 self.delta.on_tool_call_delta(
2225 tool_call_id,
2226 internal_call_id,
2227 tool_name,
2228 tool_call_delta,
2229 )
2230 }
2231
2232 fn on_invalid_tool_call(
2233 &self,
2234 context: &InvalidToolCallContext,
2235 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2236 SkipDefaultApiHook.on_invalid_tool_call(context)
2237 }
2238 }
2239
2240 #[derive(Clone, Default)]
2241 struct TerminatingToolCallDeltaHook {
2242 deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2243 }
2244
2245 impl TerminatingToolCallDeltaHook {
2246 fn observed(&self) -> Vec<RecordedToolCallDelta> {
2247 self.deltas
2248 .lock()
2249 .expect("tool call delta hook records mutex was poisoned")
2250 .clone()
2251 }
2252 }
2253
2254 impl PromptHook<MockCompletionModel> for TerminatingToolCallDeltaHook {
2255 fn on_tool_call_delta(
2256 &self,
2257 tool_call_id: &str,
2258 internal_call_id: &str,
2259 tool_name: Option<&str>,
2260 tool_call_delta: &str,
2261 ) -> impl Future<Output = HookAction> + Send {
2262 let deltas = self.deltas.clone();
2263 let event = (
2264 tool_call_id.to_string(),
2265 internal_call_id.to_string(),
2266 tool_name.map(str::to_string),
2267 tool_call_delta.to_string(),
2268 );
2269
2270 async move {
2271 deltas
2272 .lock()
2273 .expect("tool call delta hook records mutex was poisoned")
2274 .push(event);
2275 HookAction::terminate("stop on tool call delta")
2276 }
2277 }
2278 }
2279
2280 fn text_metadata(content: &OneOrMany<AssistantContent>) -> Option<&serde_json::Value> {
2281 content.iter().find_map(|item| match item {
2282 AssistantContent::Text(text) => text.additional_params.as_ref(),
2283 _ => None,
2284 })
2285 }
2286
2287 #[tokio::test]
2288 async fn stream_prompt_continues_after_tool_call_turn() {
2289 let model = streaming_tool_then_text_model();
2290 let recorded = model.clone();
2291 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2292 let empty_history: &[Message] = &[];
2293
2294 let mut stream = agent
2295 .stream_prompt("do tool work")
2296 .with_history(empty_history)
2297 .multi_turn(3)
2298 .await;
2299 let mut saw_tool_call = false;
2300 let mut saw_tool_result = false;
2301 let mut saw_final_response = false;
2302 let mut final_text = String::new();
2303 let mut final_response_text = None;
2304 let mut final_history = None;
2305
2306 while let Some(item) = stream.next().await {
2307 match item {
2308 Ok(MultiTurnStreamItem::StreamAssistantItem(
2309 StreamedAssistantContent::ToolCall { .. },
2310 )) => {
2311 saw_tool_call = true;
2312 }
2313 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2314 ..
2315 })) => {
2316 saw_tool_result = true;
2317 }
2318 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
2319 text,
2320 ))) => {
2321 final_text.push_str(&text.text);
2322 }
2323 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
2324 saw_final_response = true;
2325 final_response_text = Some(res.response().to_owned());
2326 final_history = res.history().map(|history| history.to_vec());
2327 break;
2328 }
2329 Ok(_) => {}
2330 Err(err) => panic!("unexpected streaming error: {err:?}"),
2331 }
2332 }
2333
2334 assert!(saw_tool_call);
2335 assert!(saw_tool_result);
2336 assert!(saw_final_response);
2337 assert_eq!(final_text, "done");
2338 assert_eq!(final_response_text.as_deref(), Some("done"));
2339 let history = final_history.expect("expected final response history");
2340 assert!(history.iter().any(|message| matches!(
2341 message,
2342 Message::Assistant { content, .. }
2343 if content.iter().any(|item| matches!(
2344 item,
2345 AssistantContent::Text(text) if text.text == "done"
2346 ))
2347 )));
2348 let requests = recorded.requests();
2349 assert_eq!(requests.len(), 2);
2350 assert!(validate_follow_up_tool_history(&requests[1]).is_ok());
2351 }
2352
2353 #[tokio::test]
2354 async fn unknown_tool_call_fails_before_streaming_second_request() {
2355 let model = MockCompletionModel::from_stream_turns([
2356 vec![
2357 MockStreamEvent::tool_call(
2358 "tool_call_1",
2359 "default_api",
2360 serde_json::json!({"x": 1, "y": 2}),
2361 ),
2362 MockStreamEvent::final_response_with_total_tokens(4),
2363 ],
2364 vec![
2365 MockStreamEvent::text("should not be requested"),
2366 MockStreamEvent::final_response_with_total_tokens(6),
2367 ],
2368 ]);
2369 let recorded = model.clone();
2370 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2371
2372 let mut stream = agent
2373 .stream_prompt("use the tool")
2374 .with_hook(PanicOnUnknownToolHook)
2375 .multi_turn(3)
2376 .await;
2377 let mut saw_tool_call = false;
2378 let mut error = None;
2379
2380 while let Some(item) = stream.next().await {
2381 match item {
2382 Ok(MultiTurnStreamItem::StreamAssistantItem(
2383 StreamedAssistantContent::ToolCall { .. },
2384 )) => {
2385 saw_tool_call = true;
2386 }
2387 Ok(_) => {}
2388 Err(err) => {
2389 error = Some(err);
2390 break;
2391 }
2392 }
2393 }
2394
2395 assert!(!saw_tool_call);
2396 let error = error.expect("unknown model-emitted tool should fail");
2397 match error {
2398 StreamingError::Prompt(err) => match *err {
2399 PromptError::UnknownToolCall {
2400 tool_name,
2401 available_tools,
2402 allowed_tools,
2403 chat_history,
2404 } => {
2405 assert_eq!(tool_name, "default_api");
2406 assert_eq!(available_tools, vec!["add".to_string()]);
2407 assert_eq!(allowed_tools, vec!["add".to_string()]);
2408 assert!(history_contains_tool_call(&chat_history, "default_api"));
2409 }
2410 other => panic!("expected UnknownToolCall, got {other:?}"),
2411 },
2412 other => panic!("expected prompt streaming error, got {other:?}"),
2413 }
2414 assert_eq!(recorded.request_count(), 1);
2415 }
2416
2417 #[tokio::test]
2418 async fn invalid_tool_call_hook_can_repair_streaming_tool_name() {
2419 let model = MockCompletionModel::from_stream_turns([
2420 vec![
2421 MockStreamEvent::tool_call(
2422 "tool_call_1",
2423 "default_api",
2424 serde_json::json!({"x": 2, "y": 3}),
2425 ),
2426 MockStreamEvent::final_response_with_total_tokens(4),
2427 ],
2428 vec![
2429 MockStreamEvent::text("done"),
2430 MockStreamEvent::final_response_with_total_tokens(6),
2431 ],
2432 ]);
2433 let recorded = model.clone();
2434 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2435
2436 let mut stream = agent
2437 .stream_prompt("use the tool")
2438 .with_hook(RepairDefaultApiHook)
2439 .multi_turn(3)
2440 .with_history(Vec::<Message>::new())
2441 .await;
2442 let mut saw_repaired_tool_call = false;
2443 let mut saw_tool_result = false;
2444 let mut final_response_text = None;
2445
2446 while let Some(item) = stream.next().await {
2447 match item {
2448 Ok(MultiTurnStreamItem::StreamAssistantItem(
2449 StreamedAssistantContent::ToolCall { tool_call, .. },
2450 )) => {
2451 assert_eq!(tool_call.function.name, "add");
2452 saw_repaired_tool_call = true;
2453 }
2454 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2455 tool_result,
2456 ..
2457 })) => {
2458 assert!(tool_result.content.iter().any(|content| {
2459 matches!(
2460 content,
2461 ToolResultContent::Text(text) if text.text == "5"
2462 )
2463 }));
2464 saw_tool_result = true;
2465 }
2466 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2467 final_response_text = Some(response.response().to_string());
2468 break;
2469 }
2470 Ok(_) => {}
2471 Err(err) => panic!("unexpected streaming error: {err:?}"),
2472 }
2473 }
2474
2475 assert!(saw_repaired_tool_call);
2476 assert!(saw_tool_result);
2477 assert_eq!(final_response_text.as_deref(), Some("done"));
2478 assert_eq!(recorded.request_count(), 2);
2479 }
2480
2481 #[tokio::test]
2482 async fn invalid_tool_call_context_uses_completed_streaming_tool_call_provider_id() {
2483 let invalid_hook = RecordingInvalidToolCallHook::default();
2484 let model = MockCompletionModel::from_stream_turns([
2485 vec![
2486 MockStreamEvent::tool_call(
2487 "tool_call_1",
2488 "default_api",
2489 serde_json::json!({"x": 2, "y": 3}),
2490 )
2491 .with_call_id("provider_call_1"),
2492 MockStreamEvent::final_response_with_total_tokens(4),
2493 ],
2494 vec![
2495 MockStreamEvent::text("should not be requested"),
2496 MockStreamEvent::final_response_with_total_tokens(6),
2497 ],
2498 ]);
2499 let recorded = model.clone();
2500 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2501
2502 let mut stream = agent
2503 .stream_prompt("use the tool")
2504 .with_hook(invalid_hook.clone())
2505 .multi_turn(3)
2506 .await;
2507 let mut error = None;
2508
2509 while let Some(item) = stream.next().await {
2510 if let Err(err) = item {
2511 error = Some(err);
2512 break;
2513 }
2514 }
2515
2516 assert!(error.is_some(), "invalid tool should fail");
2517 assert_eq!(recorded.request_count(), 1);
2518 let contexts = invalid_hook.observed();
2519 assert_eq!(contexts.len(), 1);
2520 let context = &contexts[0];
2521 assert_eq!(context.tool_name, "default_api");
2522 assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
2523 assert!(context.internal_call_id.is_some());
2524 assert!(context.is_streaming);
2525 }
2526
2527 #[tokio::test]
2528 async fn invalid_tool_call_hook_skip_emits_streaming_tool_result() {
2529 let add_calls = Arc::new(AtomicU32::new(0));
2530 let model = MockCompletionModel::from_stream_turns([
2531 vec![
2532 MockStreamEvent::tool_call(
2533 "tool_call_1",
2534 "default_api",
2535 serde_json::json!({"x": 2, "y": 3}),
2536 )
2537 .with_call_id("call_1"),
2538 MockStreamEvent::final_response_with_total_tokens(4),
2539 ],
2540 vec![
2541 MockStreamEvent::text("continued"),
2542 MockStreamEvent::final_response_with_total_tokens(6),
2543 ],
2544 ]);
2545 let recorded = model.clone();
2546 let agent = AgentBuilder::new(model)
2547 .tool(CountingAddTool {
2548 calls: add_calls.clone(),
2549 })
2550 .build();
2551
2552 let mut stream = agent
2553 .stream_prompt("use the tool")
2554 .with_hook(SkipDefaultApiHook)
2555 .multi_turn(3)
2556 .with_history(Vec::<Message>::new())
2557 .await;
2558 let mut skipped_tool_result = None;
2559 let mut final_response_text = None;
2560
2561 while let Some(item) = stream.next().await {
2562 match item {
2563 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2564 tool_result,
2565 internal_call_id,
2566 })) => {
2567 assert!(!internal_call_id.is_empty());
2568 skipped_tool_result = Some(tool_result);
2569 }
2570 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2571 final_response_text = Some(response.response().to_string());
2572 break;
2573 }
2574 Ok(_) => {}
2575 Err(err) => panic!("unexpected streaming error: {err:?}"),
2576 }
2577 }
2578
2579 let skipped_tool_result =
2580 skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
2581 assert_eq!(skipped_tool_result.id, "tool_call_1");
2582 assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_1"));
2583 assert!(skipped_tool_result.content.iter().any(|content| matches!(
2584 content,
2585 ToolResultContent::Text(text) if text.text == "default_api was skipped"
2586 )));
2587 assert_eq!(final_response_text.as_deref(), Some("continued"));
2588 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
2589
2590 let requests = recorded.requests();
2591 assert_eq!(requests.len(), 2);
2592 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2593 assert!(matches!(
2594 follow_up_history.get(2),
2595 Some(Message::User { content })
2596 if content.iter().any(|item| matches!(
2597 item,
2598 UserContent::ToolResult(result)
2599 if result.id == "tool_call_1"
2600 && result.content.iter().any(|content| matches!(
2601 content,
2602 ToolResultContent::Text(text)
2603 if text.text == "default_api was skipped"
2604 ))
2605 ))
2606 ));
2607 }
2608
    /// A turn that mixes a valid `add` call with an invalid `default_api` call,
    /// recovered by a hook-driven retry (retry budget 1), must:
    /// - execute neither call (the `add` counter stays at 0);
    /// - report two completion calls (usage 4 and 6) both as stream events and
    ///   on the final response, whose aggregated usage is 10;
    /// - send a retry request whose history keeps the aborted assistant turn
    ///   (text plus both tool calls) and answers both calls with tool results.
    #[tokio::test]
    async fn invalid_tool_call_hook_retries_mixed_streaming_turn_without_executing_valid_call() {
        // Counts real executions of the `add` tool.
        let add_calls = Arc::new(AtomicU32::new(0));
        let model = MockCompletionModel::from_stream_turns([
            // Turn 1: text, a valid `add` call, then an invalid `default_api` call.
            vec![
                MockStreamEvent::text("checking "),
                MockStreamEvent::tool_call(
                    "tool_call_1",
                    "add",
                    serde_json::json!({"x": 2, "y": 3}),
                )
                .with_call_id("call_1"),
                MockStreamEvent::tool_call(
                    "tool_call_2",
                    "default_api",
                    serde_json::json!({"x": 4, "y": 5}),
                )
                .with_call_id("call_2"),
                MockStreamEvent::final_response_with_total_tokens(4),
            ],
            // Turn 2: the retried completion.
            vec![
                MockStreamEvent::text("retried"),
                MockStreamEvent::final_response_with_total_tokens(6),
            ],
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model)
            .tool(CountingAddTool {
                calls: add_calls.clone(),
            })
            .build();

        let mut stream = agent
            .stream_prompt("use the tool")
            .with_hook(RetryDefaultApiHook)
            .multi_turn(3)
            .with_history(Vec::<Message>::new())
            .max_invalid_tool_call_retries(1)
            .await;
        let mut completion_call_events = Vec::new();
        let mut final_response_text = None;
        let mut final_response_usage = Usage::new();
        let mut final_completion_calls = Vec::new();

        // Collect per-request `CompletionCall` events and capture the final response.
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
                    completion_call_events.push(completion_call);
                }
                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
                    final_response_text = Some(response.response().to_string());
                    final_response_usage = response.usage();
                    final_completion_calls = response.completion_calls().to_vec();
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        assert_eq!(final_response_text.as_deref(), Some("retried"));
        // The valid peer call was never executed because its turn was retried.
        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
        // One CompletionCall per model request, indexed 0 and 1 with the mock's
        // per-turn token totals (4 and 6).
        let mut first_usage = Usage::new();
        first_usage.total_tokens = 4;
        let mut second_usage = Usage::new();
        second_usage.total_tokens = 6;
        let expected_completion_calls = vec![
            CompletionCall::new(0, first_usage),
            CompletionCall::new(1, second_usage),
        ];
        assert_eq!(completion_call_events, expected_completion_calls);
        assert_eq!(final_completion_calls, expected_completion_calls);
        // Final usage aggregates both requests: 4 + 6.
        assert_eq!(final_response_usage.total_tokens, 10);

        let requests = recorded.requests();
        assert_eq!(requests.len(), 2);
        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
        assert_eq!(retry_history.len(), 3);
        // Index 1: the aborted assistant turn is preserved intact
        // (text, the valid `add` call and the invalid `default_api` call).
        assert!(matches!(
            retry_history.get(1),
            Some(Message::Assistant { content, .. })
                if content.iter().any(|item| matches!(
                    item,
                    AssistantContent::Text(text) if text.text == "checking "
                ))
                && content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.function.name == "add"
                ))
                && content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_2"
                            && tool_call.function.name == "default_api"
                ))
        ));
        // Index 2: exactly two tool results. The valid call gets the
        // "not executed due to invalid peer" marker, and the invalid call gets
        // the hook's retry feedback text.
        assert!(matches!(
            retry_history.get(2),
            Some(Message::User { content })
                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_1"
                            && result.content.iter().any(|content| matches!(
                                content,
                                ToolResultContent::Text(text)
                                    if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
                            ))
                ))
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_2"
                            && result.content.iter().any(|content| matches!(
                                content,
                                ToolResultContent::Text(text)
                                    if text.text == "Use the add tool instead"
                            ))
                ))
        ));
    }
2733
    /// A turn that mixes a valid `add` call with an invalid `default_api` call,
    /// recovered by skipping the invalid call, must execute neither call, must
    /// stream a synthetic result for the skipped call, and must continue with a
    /// follow-up request whose history answers both calls.
    #[tokio::test]
    async fn invalid_tool_call_hook_skips_mixed_streaming_turn_without_executing_valid_call() {
        // Counts real executions of the `add` tool.
        let add_calls = Arc::new(AtomicU32::new(0));
        let model = MockCompletionModel::from_stream_turns([
            // Turn 1: text, a valid `add` call, then an invalid `default_api` call.
            vec![
                MockStreamEvent::text("checking "),
                MockStreamEvent::tool_call(
                    "tool_call_1",
                    "add",
                    serde_json::json!({"x": 2, "y": 3}),
                )
                .with_call_id("call_1"),
                MockStreamEvent::tool_call(
                    "tool_call_2",
                    "default_api",
                    serde_json::json!({"x": 4, "y": 5}),
                )
                .with_call_id("call_2"),
                MockStreamEvent::final_response_with_total_tokens(4),
            ],
            // Turn 2: the follow-up after the skip.
            vec![
                MockStreamEvent::text("continued"),
                MockStreamEvent::final_response_with_total_tokens(6),
            ],
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model)
            .tool(CountingAddTool {
                calls: add_calls.clone(),
            })
            .build();

        let mut stream = agent
            .stream_prompt("use the tool")
            .with_hook(SkipDefaultApiHook)
            .multi_turn(3)
            .with_history(Vec::<Message>::new())
            .await;
        let mut skipped_tool_result = None;
        let mut final_response_text = None;

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    tool_result,
                    ..
                })) => {
                    // Keeps only the most recently streamed tool result. The
                    // assertions below require it to be the skipped `tool_call_2`.
                    // NOTE(review): if a result for `tool_call_1` were also streamed,
                    // it would have to arrive first for this to pass.
                    skipped_tool_result = Some(tool_result);
                }
                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
                    final_response_text = Some(response.response().to_string());
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        let skipped_tool_result =
            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
        assert_eq!(skipped_tool_result.id, "tool_call_2");
        assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_2"));
        assert_eq!(final_response_text.as_deref(), Some("continued"));
        // The valid peer call must not run when its turn contained an invalid call.
        assert_eq!(add_calls.load(Ordering::SeqCst), 0);

        let requests = recorded.requests();
        assert_eq!(requests.len(), 2);
        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
        assert_eq!(follow_up_history.len(), 3);
        // Index 1: the original assistant turn is preserved with both tool calls.
        assert!(matches!(
            follow_up_history.get(1),
            Some(Message::Assistant { content, .. })
                if content.iter().any(|item| matches!(
                    item,
                    AssistantContent::Text(text) if text.text == "checking "
                ))
                && content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.function.name == "add"
                ))
                && content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_2"
                            && tool_call.function.name == "default_api"
                ))
        ));
        // Index 2: exactly two tool results, each keeping its call id. The valid
        // call gets the "not executed" marker and the invalid call gets the skip
        // message.
        assert!(matches!(
            follow_up_history.get(2),
            Some(Message::User { content })
                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_1"
                            && result.call_id.as_deref() == Some("call_1")
                            && result.content.iter().any(|content| matches!(
                                content,
                                ToolResultContent::Text(text)
                                    if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
                            ))
                ))
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_2"
                            && result.call_id.as_deref() == Some("call_2")
                            && result.content.iter().any(|content| matches!(
                                content,
                                ToolResultContent::Text(text)
                                    if text.text == "default_api was skipped"
                            ))
                ))
        ));
    }
2851
2852 #[tokio::test]
2853 async fn invalid_completed_tool_call_skip_preserves_streaming_reasoning_history() {
2854 let model = MockCompletionModel::from_stream_turns([
2855 vec![
2856 MockStreamEvent::text("checking "),
2857 MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
2858 MockStreamEvent::tool_call(
2859 "tool_call_1",
2860 "default_api",
2861 serde_json::json!({"x": 2, "y": 3}),
2862 ),
2863 MockStreamEvent::final_response_with_total_tokens(4),
2864 ],
2865 vec![
2866 MockStreamEvent::text("continued"),
2867 MockStreamEvent::final_response_with_total_tokens(6),
2868 ],
2869 ]);
2870 let recorded = model.clone();
2871 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2872
2873 let mut stream = agent
2874 .stream_prompt("use the tool")
2875 .with_hook(SkipDefaultApiHook)
2876 .multi_turn(3)
2877 .with_history(Vec::<Message>::new())
2878 .await;
2879
2880 while let Some(item) = stream.next().await {
2881 match item {
2882 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
2883 Ok(_) => {}
2884 Err(err) => panic!("unexpected streaming error: {err:?}"),
2885 }
2886 }
2887
2888 let requests = recorded.requests();
2889 assert_eq!(requests.len(), 2);
2890 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2891 assert!(history_contains_text(&follow_up_history, "checking "));
2892 assert!(assistant_reasoning_precedes_tool_call(
2893 &follow_up_history,
2894 "reasoned step",
2895 "default_api"
2896 ));
2897 assert!(
2898 assistant_reasoning_precedes_text_and_tool_call(
2899 &follow_up_history,
2900 "reasoned step",
2901 "checking ",
2902 "default_api"
2903 ),
2904 "{follow_up_history:?}"
2905 );
2906 }
2907
2908 #[tokio::test]
2909 async fn invalid_name_delta_retry_preserves_streaming_reasoning_history() {
2910 let model = MockCompletionModel::from_stream_turns([
2911 vec![
2912 MockStreamEvent::reasoning_delta(Some("rs_1"), "delta reason"),
2913 MockStreamEvent::tool_call_arguments_delta(
2914 "tool_call_1",
2915 "internal_1",
2916 r#"{"x":2,"y":3}"#,
2917 ),
2918 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
2919 MockStreamEvent::final_response_with_total_tokens(4),
2920 ],
2921 vec![
2922 MockStreamEvent::text("retried"),
2923 MockStreamEvent::final_response_with_total_tokens(6),
2924 ],
2925 ]);
2926 let recorded = model.clone();
2927 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2928
2929 let mut stream = agent
2930 .stream_prompt("use the tool")
2931 .with_hook(RetryDefaultApiHook)
2932 .multi_turn(3)
2933 .with_history(Vec::<Message>::new())
2934 .max_invalid_tool_call_retries(1)
2935 .await;
2936
2937 while let Some(item) = stream.next().await {
2938 match item {
2939 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
2940 Ok(_) => {}
2941 Err(err) => panic!("unexpected streaming error: {err:?}"),
2942 }
2943 }
2944
2945 let requests = recorded.requests();
2946 assert_eq!(requests.len(), 2);
2947 let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2948 assert!(assistant_reasoning_precedes_tool_call(
2949 &retry_history,
2950 "delta reason",
2951 "default_api"
2952 ));
2953 }
2954
2955 #[tokio::test]
2956 async fn invalid_tool_call_hook_skip_resets_streaming_text_delta_state() {
2957 let text_hook = RecordingTextDeltaHook::default();
2958 let model = MockCompletionModel::from_stream_turns([
2959 vec![
2960 MockStreamEvent::text("stale "),
2961 MockStreamEvent::tool_call(
2962 "tool_call_1",
2963 "default_api",
2964 serde_json::json!({"x": 2, "y": 3}),
2965 ),
2966 MockStreamEvent::final_response_with_total_tokens(4),
2967 ],
2968 vec![
2969 MockStreamEvent::text("fresh"),
2970 MockStreamEvent::final_response_with_total_tokens(6),
2971 ],
2972 ]);
2973 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2974
2975 let mut stream = agent
2976 .stream_prompt("use the tool")
2977 .with_hook(RecordingTextAndSkipInvalidToolHook {
2978 text: text_hook.clone(),
2979 })
2980 .multi_turn(3)
2981 .with_history(Vec::<Message>::new())
2982 .await;
2983
2984 while let Some(item) = stream.next().await {
2985 match item {
2986 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
2987 Ok(_) => {}
2988 Err(err) => panic!("unexpected streaming error: {err:?}"),
2989 }
2990 }
2991
2992 assert_eq!(
2993 text_hook.observed(),
2994 vec![
2995 ("stale ".to_string(), "stale ".to_string()),
2996 ("fresh".to_string(), "fresh".to_string()),
2997 ]
2998 );
2999 }
3000
    /// An invalid tool call assembled from deltas (arguments first, then the
    /// name `default_api`), recovered by a hook-driven retry, must:
    /// - never surface the invalid call's deltas (neither as `ToolCallDelta`
    ///   stream items nor to the delta hook);
    /// - execute neither tool call in the aborted turn;
    /// - report per-request completion calls and aggregated usage;
    /// - send a retry request whose history keeps the aborted turn and answers
    ///   both calls with structured tool results.
    #[tokio::test]
    async fn invalid_tool_call_delta_retry_uses_structured_tool_feedback() {
        let delta_hook = RecordingToolCallDeltaHook::default();
        // Counts real executions of the `add` tool.
        let add_calls = Arc::new(AtomicU32::new(0));
        let model = MockCompletionModel::from_stream_turns([
            // Turn 1: text, reasoning, a valid `add` call, then an invalid call
            // delivered as deltas under provider id `tool_call_1` / internal id
            // `internal_1` (arguments first, then the name).
            vec![
                MockStreamEvent::text("checking "),
                MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
                MockStreamEvent::tool_call(
                    "tool_call_0",
                    "add",
                    serde_json::json!({"x": 1, "y": 2}),
                )
                .with_call_id("call_0"),
                MockStreamEvent::tool_call_arguments_delta(
                    "tool_call_1",
                    "internal_1",
                    r#"{"x":2,"y":3}"#,
                ),
                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
                MockStreamEvent::final_response_with_total_tokens(4),
            ],
            // Turn 2: the retried completion.
            vec![
                MockStreamEvent::text("retried"),
                MockStreamEvent::final_response_with_total_tokens(6),
            ],
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model)
            .tool(CountingAddTool {
                calls: add_calls.clone(),
            })
            .build();

        let mut stream = agent
            .stream_prompt("use the tool")
            .with_hook(RecordingDeltaAndRetryInvalidToolHook {
                delta: delta_hook.clone(),
            })
            .multi_turn(3)
            .with_history(Vec::<Message>::new())
            .max_invalid_tool_call_retries(1)
            .await;
        let mut completion_call_events = Vec::new();
        let mut final_response_text = None;
        let mut final_response_usage = Usage::new();
        let mut final_completion_calls = Vec::new();

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
                    completion_call_events.push(completion_call);
                }
                // No tool-call delta (in particular, none for the invalid call)
                // may reach the consumer in this scenario.
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCallDelta { .. },
                )) => panic!("invalid tool-call delta should not be emitted"),
                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
                    final_response_text = Some(response.response().to_string());
                    final_response_usage = response.usage();
                    final_completion_calls = response.completion_calls().to_vec();
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        assert_eq!(final_response_text.as_deref(), Some("retried"));
        // The delta hook must not have been invoked for the invalid call's deltas.
        assert!(delta_hook.observed().is_empty());
        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
        // One CompletionCall per model request with the mock's per-turn totals.
        let mut first_usage = Usage::new();
        first_usage.total_tokens = 4;
        let mut second_usage = Usage::new();
        second_usage.total_tokens = 6;
        let expected_completion_calls = vec![
            CompletionCall::new(0, first_usage),
            CompletionCall::new(1, second_usage),
        ];
        assert_eq!(completion_call_events, expected_completion_calls);
        assert_eq!(final_completion_calls, expected_completion_calls);
        // Aggregated usage across both requests: 4 + 6.
        assert_eq!(final_response_usage.total_tokens, 10);

        let requests = recorded.requests();
        assert_eq!(requests.len(), 2);
        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
        // Index 1: the aborted assistant turn, including the delta-assembled
        // invalid call with its argument string parsed back into JSON.
        assert!(matches!(
            retry_history.get(1),
            Some(Message::Assistant { content, .. })
                if content.iter().any(|item| matches!(
                    item,
                    AssistantContent::Text(text) if text.text == "checking "
                ))
                && content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_0"
                            && tool_call.function.name == "add"
                ))
                && content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.function.name == "default_api"
                            && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
                ))
        ));
        // Index 2: exactly two tool results. The valid call gets the
        // "not executed" marker (keeping `call_0`), and the invalid call gets the
        // hook's retry feedback.
        assert!(matches!(
            retry_history.get(2),
            Some(Message::User { content })
                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_0"
                            && result.call_id.as_deref() == Some("call_0")
                            && result.content.iter().any(|content| matches!(
                                content,
                                ToolResultContent::Text(text)
                                    if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
                            ))
                ))
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_1"
                            && result.content.iter().any(|content| matches!(
                                content,
                                ToolResultContent::Text(text)
                                    if text.text == "Use the add tool instead"
                            ))
                ))
        ));
    }
3134
3135 #[tokio::test]
3136 async fn invalid_tool_call_delta_context_includes_same_turn_history_and_tool_call_id() {
3137 let invalid_hook = RecordingInvalidToolCallHook::default();
3138 let model = MockCompletionModel::from_stream_turns([
3139 vec![
3140 MockStreamEvent::text("checking "),
3141 MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3142 MockStreamEvent::tool_call(
3143 "tool_call_0",
3144 "add",
3145 serde_json::json!({"x": 1, "y": 2}),
3146 )
3147 .with_call_id("call_0"),
3148 MockStreamEvent::tool_call_arguments_delta(
3149 "tool_call_1",
3150 "internal_1",
3151 r#"{"x":2,"y":3}"#,
3152 ),
3153 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3154 MockStreamEvent::final_response_with_total_tokens(4),
3155 ],
3156 vec![
3157 MockStreamEvent::text("should not be requested"),
3158 MockStreamEvent::final_response_with_total_tokens(6),
3159 ],
3160 ]);
3161 let recorded = model.clone();
3162 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3163
3164 let mut stream = agent
3165 .stream_prompt("use the tool")
3166 .with_hook(invalid_hook.clone())
3167 .multi_turn(3)
3168 .await;
3169 let mut error = None;
3170
3171 while let Some(item) = stream.next().await {
3172 if let Err(err) = item {
3173 error = Some(err);
3174 break;
3175 }
3176 }
3177
3178 assert!(error.is_some(), "invalid name delta should fail");
3179 assert_eq!(recorded.request_count(), 1);
3180 let contexts = invalid_hook.observed();
3181 assert_eq!(contexts.len(), 1);
3182 let context = &contexts[0];
3183 assert_eq!(context.tool_name, "default_api");
3184 assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
3185 assert_eq!(context.internal_call_id.as_deref(), Some("internal_1"));
3186 assert!(context.is_streaming);
3187 assert!(history_contains_text(&context.chat_history, "checking "));
3188 assert!(
3189 assistant_reasoning_precedes_tool_call(
3190 &context.chat_history,
3191 "diagnostic reason",
3192 "add"
3193 ),
3194 "{:?}",
3195 context.chat_history
3196 );
3197 assert!(history_contains_tool_call(&context.chat_history, "add"));
3198 assert!(history_contains_tool_call(
3199 &context.chat_history,
3200 "default_api"
3201 ));
3202 }
3203
3204 #[tokio::test]
3205 async fn invalid_tool_call_delta_retry_resets_streaming_text_delta_state() {
3206 let text_hook = RecordingTextDeltaHook::default();
3207 let model = MockCompletionModel::from_stream_turns([
3208 vec![
3209 MockStreamEvent::text("stale "),
3210 MockStreamEvent::tool_call_arguments_delta(
3211 "tool_call_1",
3212 "internal_1",
3213 r#"{"x":2,"y":3}"#,
3214 ),
3215 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3216 MockStreamEvent::final_response_with_total_tokens(4),
3217 ],
3218 vec![
3219 MockStreamEvent::text("fresh"),
3220 MockStreamEvent::final_response_with_total_tokens(6),
3221 ],
3222 ]);
3223 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3224
3225 let mut stream = agent
3226 .stream_prompt("use the tool")
3227 .with_hook(RecordingTextAndRetryInvalidToolHook {
3228 text: text_hook.clone(),
3229 })
3230 .multi_turn(3)
3231 .with_history(Vec::<Message>::new())
3232 .max_invalid_tool_call_retries(1)
3233 .await;
3234
3235 while let Some(item) = stream.next().await {
3236 match item {
3237 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3238 Ok(_) => {}
3239 Err(err) => panic!("unexpected streaming error: {err:?}"),
3240 }
3241 }
3242
3243 assert_eq!(
3244 text_hook.observed(),
3245 vec![
3246 ("stale ".to_string(), "stale ".to_string()),
3247 ("fresh".to_string(), "fresh".to_string()),
3248 ]
3249 );
3250 }
3251
    /// An invalid tool call assembled from deltas and recovered by a skip must:
    /// - emit no `ToolCallDelta` items and give the delta hook nothing;
    /// - stream a synthetic tool result tied to internal id `internal_1` and
    ///   provider id `tool_call_1` (with no call id, since the deltas carried
    ///   none);
    /// - execute neither tool call in the turn;
    /// - continue with a follow-up request whose history answers both calls.
    #[tokio::test]
    async fn invalid_tool_call_delta_skip_uses_structured_tool_feedback() {
        let delta_hook = RecordingToolCallDeltaHook::default();
        // Counts real executions of the `add` tool.
        let add_calls = Arc::new(AtomicU32::new(0));
        let model = MockCompletionModel::from_stream_turns([
            // Turn 1: text, a valid `add` call, then an invalid call delivered as
            // deltas (arguments first, then the name `default_api`).
            vec![
                MockStreamEvent::text("checking "),
                MockStreamEvent::tool_call(
                    "tool_call_0",
                    "add",
                    serde_json::json!({"x": 1, "y": 2}),
                )
                .with_call_id("call_0"),
                MockStreamEvent::tool_call_arguments_delta(
                    "tool_call_1",
                    "internal_1",
                    r#"{"x":2,"y":3}"#,
                ),
                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
                MockStreamEvent::final_response_with_total_tokens(4),
            ],
            // Turn 2: the follow-up after the skip.
            vec![
                MockStreamEvent::text("continued"),
                MockStreamEvent::final_response_with_total_tokens(6),
            ],
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model)
            .tool(CountingAddTool {
                calls: add_calls.clone(),
            })
            .build();

        let mut stream = agent
            .stream_prompt("use the tool")
            .with_hook(RecordingDeltaAndSkipInvalidToolHook {
                delta: delta_hook.clone(),
            })
            .multi_turn(3)
            .with_history(Vec::<Message>::new())
            .await;
        let mut skipped_tool_result = None;
        let mut final_response_text = None;

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCallDelta { .. },
                )) => panic!("invalid tool-call delta should not be emitted"),
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    tool_result,
                    internal_call_id,
                })) => {
                    // Every streamed tool result must belong to the skipped call.
                    // NOTE(review): a streamed result for `tool_call_0` would fail here.
                    assert_eq!(internal_call_id, "internal_1");
                    skipped_tool_result = Some(tool_result);
                }
                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
                    final_response_text = Some(response.response().to_string());
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        let skipped_tool_result =
            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
        assert_eq!(skipped_tool_result.id, "tool_call_1");
        // The delta events supplied no provider call id.
        assert!(skipped_tool_result.call_id.is_none());
        assert!(skipped_tool_result.content.iter().any(|content| matches!(
            content,
            ToolResultContent::Text(text) if text.text == "default_api was skipped"
        )));
        assert_eq!(final_response_text.as_deref(), Some("continued"));
        assert!(delta_hook.observed().is_empty());
        assert_eq!(add_calls.load(Ordering::SeqCst), 0);

        let requests = recorded.requests();
        assert_eq!(requests.len(), 2);
        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
        // Index 1: the original assistant turn, including the delta-assembled
        // invalid call with its argument string parsed back into JSON.
        assert!(matches!(
            follow_up_history.get(1),
            Some(Message::Assistant { content, .. })
                if content.iter().any(|item| matches!(
                    item,
                    AssistantContent::Text(text) if text.text == "checking "
                ))
                && content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_0"
                            && tool_call.function.name == "add"
                ))
                && content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.function.name == "default_api"
                            && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
                ))
        ));
        // Index 2: exactly two tool results. The valid call gets the
        // "not executed" marker (keeping `call_0`), and the invalid call gets
        // the skip message.
        assert!(matches!(
            follow_up_history.get(2),
            Some(Message::User { content })
                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_0"
                            && result.call_id.as_deref() == Some("call_0")
                            && result.content.iter().any(|content| matches!(
                                content,
                                ToolResultContent::Text(text)
                                    if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
                            ))
                ))
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_1"
                            && result.content.iter().any(|content| matches!(
                                content,
                                ToolResultContent::Text(text)
                                    if text.text == "default_api was skipped"
                            ))
                ))
        ));
    }
3380
3381 #[tokio::test]
3382 async fn streaming_retry_budget_exhaustion_history_contains_invalid_tool_call() {
3383 let model = MockCompletionModel::from_stream_turns([
3384 vec![
3385 MockStreamEvent::tool_call(
3386 "tool_call_1",
3387 "default_api",
3388 serde_json::json!({"x": 1, "y": 2}),
3389 ),
3390 MockStreamEvent::final_response_with_total_tokens(4),
3391 ],
3392 vec![
3393 MockStreamEvent::text("should not be requested"),
3394 MockStreamEvent::final_response_with_total_tokens(6),
3395 ],
3396 ]);
3397 let recorded = model.clone();
3398 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3399
3400 let mut stream = agent
3401 .stream_prompt("use the tool")
3402 .with_hook(RetryDefaultApiHook)
3403 .multi_turn(3)
3404 .max_invalid_tool_call_retries(0)
3405 .await;
3406 let mut error = None;
3407
3408 while let Some(item) = stream.next().await {
3409 if let Err(err) = item {
3410 error = Some(err);
3411 break;
3412 }
3413 }
3414
3415 let error = error.expect("retry budget exhaustion should fail");
3416 match error {
3417 StreamingError::Prompt(err) => match *err {
3418 PromptError::UnknownToolCall {
3419 tool_name,
3420 chat_history,
3421 ..
3422 } => {
3423 assert_eq!(tool_name, "default_api");
3424 assert!(history_contains_tool_call(&chat_history, "default_api"));
3425 }
3426 other => panic!("expected UnknownToolCall, got {other:?}"),
3427 },
3428 other => panic!("expected prompt streaming error, got {other:?}"),
3429 }
3430 assert_eq!(recorded.request_count(), 1);
3431 }
3432
3433 #[tokio::test]
3434 async fn streaming_name_delta_retry_budget_exhaustion_history_includes_same_turn_context() {
3435 let model = MockCompletionModel::from_stream_turns([
3436 vec![
3437 MockStreamEvent::text("checking "),
3438 MockStreamEvent::tool_call(
3439 "tool_call_0",
3440 "add",
3441 serde_json::json!({"x": 1, "y": 2}),
3442 )
3443 .with_call_id("call_0"),
3444 MockStreamEvent::tool_call_arguments_delta(
3445 "tool_call_1",
3446 "internal_1",
3447 r#"{"x":2,"y":3}"#,
3448 ),
3449 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3450 MockStreamEvent::final_response_with_total_tokens(4),
3451 ],
3452 vec![
3453 MockStreamEvent::text("should not be requested"),
3454 MockStreamEvent::final_response_with_total_tokens(6),
3455 ],
3456 ]);
3457 let recorded = model.clone();
3458 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3459
3460 let mut stream = agent
3461 .stream_prompt("use the tool")
3462 .with_hook(RetryDefaultApiHook)
3463 .multi_turn(3)
3464 .max_invalid_tool_call_retries(0)
3465 .await;
3466 let mut error = None;
3467
3468 while let Some(item) = stream.next().await {
3469 if let Err(err) = item {
3470 error = Some(err);
3471 break;
3472 }
3473 }
3474
3475 let error = error.expect("retry budget exhaustion should fail");
3476 match error {
3477 StreamingError::Prompt(err) => match *err {
3478 PromptError::UnknownToolCall {
3479 tool_name,
3480 chat_history,
3481 ..
3482 } => {
3483 assert_eq!(tool_name, "default_api");
3484 assert!(history_contains_text(&chat_history, "checking "));
3485 assert!(history_contains_tool_call(&chat_history, "add"));
3486 assert!(history_contains_tool_call(&chat_history, "default_api"));
3487 }
3488 other => panic!("expected UnknownToolCall, got {other:?}"),
3489 },
3490 other => panic!("expected prompt streaming error, got {other:?}"),
3491 }
3492 assert_eq!(recorded.request_count(), 1);
3493 }
3494
3495 #[tokio::test]
3496 async fn completed_unknown_tool_call_after_text_fails_before_finish_hook_or_later_emit() {
3497 let add_calls = Arc::new(AtomicU32::new(0));
3498 let model = MockCompletionModel::from_stream_turns([
3499 vec![
3500 MockStreamEvent::text("thinking "),
3501 MockStreamEvent::tool_call(
3502 "tool_call_1",
3503 "default_api",
3504 serde_json::json!({"x": 1, "y": 2}),
3505 ),
3506 MockStreamEvent::final_response_with_total_tokens(4),
3507 ],
3508 vec![
3509 MockStreamEvent::text("should not be requested"),
3510 MockStreamEvent::final_response_with_total_tokens(6),
3511 ],
3512 ]);
3513 let recorded = model.clone();
3514 let agent = AgentBuilder::new(model)
3515 .tool(CountingAddTool {
3516 calls: add_calls.clone(),
3517 })
3518 .build();
3519
3520 let mut stream = agent
3521 .stream_prompt("use the tool")
3522 .with_hook(PanicOnUnknownToolHook)
3523 .multi_turn(3)
3524 .await;
3525 let mut saw_text = false;
3526 let mut saw_completion_call = false;
3527 let mut saw_final_response = false;
3528 let mut saw_tool_call = false;
3529 let mut saw_tool_result = false;
3530 let mut error = None;
3531
3532 while let Some(item) = stream.next().await {
3533 match item {
3534 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(_))) => {
3535 saw_text = true;
3536 }
3537 Ok(MultiTurnStreamItem::CompletionCall(_)) => {
3538 saw_completion_call = true;
3539 }
3540 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Final(
3541 _,
3542 )))
3543 | Ok(MultiTurnStreamItem::FinalResponse(_)) => {
3544 saw_final_response = true;
3545 }
3546 Ok(MultiTurnStreamItem::StreamAssistantItem(
3547 StreamedAssistantContent::ToolCall { .. },
3548 )) => {
3549 saw_tool_call = true;
3550 }
3551 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3552 ..
3553 })) => {
3554 saw_tool_result = true;
3555 }
3556 Ok(_) => {}
3557 Err(err) => {
3558 error = Some(err);
3559 break;
3560 }
3561 }
3562 }
3563
3564 assert!(saw_text);
3565 assert!(!saw_completion_call);
3566 assert!(!saw_final_response);
3567 assert!(!saw_tool_call);
3568 assert!(!saw_tool_result);
3569 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3570 let error = error.expect("completed unknown tool call should fail immediately");
3571 match error {
3572 StreamingError::Prompt(err) => match *err {
3573 PromptError::UnknownToolCall {
3574 tool_name,
3575 available_tools,
3576 allowed_tools,
3577 chat_history,
3578 } => {
3579 assert_eq!(tool_name, "default_api");
3580 assert_eq!(available_tools, vec!["add".to_string()]);
3581 assert_eq!(allowed_tools, vec!["add".to_string()]);
3582 assert!(history_contains_tool_call(&chat_history, "default_api"));
3583 }
3584 other => panic!("expected UnknownToolCall, got {other:?}"),
3585 },
3586 other => panic!("expected prompt streaming error, got {other:?}"),
3587 }
3588 assert_eq!(recorded.request_count(), 1);
3589 }
3590
3591 #[tokio::test]
3592 async fn mixed_streaming_tool_calls_fail_before_any_tool_execution() {
3593 let add_calls = Arc::new(AtomicU32::new(0));
3594 let model = MockCompletionModel::from_stream_turns([
3595 vec![
3596 MockStreamEvent::tool_call(
3597 "tool_call_1",
3598 "add",
3599 serde_json::json!({"x": 1, "y": 2}),
3600 )
3601 .with_call_id("call_1"),
3602 MockStreamEvent::tool_call(
3603 "tool_call_2",
3604 "default_api",
3605 serde_json::json!({"x": 3, "y": 4}),
3606 ),
3607 MockStreamEvent::final_response_with_total_tokens(4),
3608 ],
3609 vec![
3610 MockStreamEvent::text("should not be requested"),
3611 MockStreamEvent::final_response_with_total_tokens(6),
3612 ],
3613 ]);
3614 let recorded = model.clone();
3615 let agent = AgentBuilder::new(model)
3616 .tool(CountingAddTool {
3617 calls: add_calls.clone(),
3618 })
3619 .build();
3620
3621 let mut stream = agent
3622 .stream_prompt("use tools")
3623 .with_hook(PanicOnUnknownToolHook)
3624 .multi_turn(3)
3625 .await;
3626 let mut saw_completion_call = false;
3627 let mut saw_tool_call = false;
3628 let mut saw_tool_result = false;
3629 let mut error = None;
3630
3631 while let Some(item) = stream.next().await {
3632 match item {
3633 Ok(MultiTurnStreamItem::CompletionCall(_)) => {
3634 saw_completion_call = true;
3635 }
3636 Ok(MultiTurnStreamItem::StreamAssistantItem(
3637 StreamedAssistantContent::ToolCall { .. },
3638 )) => {
3639 saw_tool_call = true;
3640 }
3641 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3642 ..
3643 })) => {
3644 saw_tool_result = true;
3645 }
3646 Ok(_) => {}
3647 Err(err) => {
3648 error = Some(err);
3649 break;
3650 }
3651 }
3652 }
3653
3654 assert!(!saw_completion_call);
3655 assert!(!saw_tool_call);
3656 assert!(!saw_tool_result);
3657 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3658 let error = error.expect("mixed unknown streamed tool call should fail");
3659 match error {
3660 StreamingError::Prompt(err) => match *err {
3661 PromptError::UnknownToolCall {
3662 tool_name,
3663 available_tools,
3664 allowed_tools,
3665 chat_history,
3666 } => {
3667 assert_eq!(tool_name, "default_api");
3668 assert_eq!(available_tools, vec!["add".to_string()]);
3669 assert_eq!(allowed_tools, vec!["add".to_string()]);
3670 assert!(history_contains_tool_call(&chat_history, "default_api"));
3671 }
3672 other => panic!("expected UnknownToolCall, got {other:?}"),
3673 },
3674 other => panic!("expected prompt streaming error, got {other:?}"),
3675 }
3676 assert_eq!(recorded.request_count(), 1);
3677 }
3678
3679 #[tokio::test]
3680 async fn multiple_valid_streaming_tool_calls_execute_after_batch_validation() {
3681 let add_calls = Arc::new(AtomicU32::new(0));
3682 let subtract_calls = Arc::new(AtomicU32::new(0));
3683 let model = MockCompletionModel::from_stream_turns([
3684 vec![
3685 MockStreamEvent::tool_call(
3686 "tool_call_1",
3687 "add",
3688 serde_json::json!({"x": 1, "y": 2}),
3689 )
3690 .with_call_id("call_1"),
3691 MockStreamEvent::tool_call(
3692 "tool_call_2",
3693 "subtract",
3694 serde_json::json!({"x": 8, "y": 3}),
3695 )
3696 .with_call_id("call_2"),
3697 MockStreamEvent::final_response_with_total_tokens(4),
3698 ],
3699 vec![
3700 MockStreamEvent::text("done"),
3701 MockStreamEvent::final_response_with_total_tokens(6),
3702 ],
3703 ]);
3704 let recorded = model.clone();
3705 let agent = AgentBuilder::new(model)
3706 .tool(CountingAddTool {
3707 calls: add_calls.clone(),
3708 })
3709 .tool(CountingSubtractTool {
3710 calls: subtract_calls.clone(),
3711 })
3712 .build();
3713
3714 let mut stream = agent.stream_prompt("use tools").multi_turn(3).await;
3715 let mut tool_call_names = Vec::new();
3716 let mut tool_result_ids = Vec::new();
3717 let mut final_response_text = None;
3718
3719 while let Some(item) = stream.next().await {
3720 match item {
3721 Ok(MultiTurnStreamItem::StreamAssistantItem(
3722 StreamedAssistantContent::ToolCall { tool_call, .. },
3723 )) => {
3724 tool_call_names.push(tool_call.function.name);
3725 }
3726 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3727 tool_result,
3728 ..
3729 })) => {
3730 tool_result_ids.push(tool_result.id);
3731 }
3732 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3733 final_response_text = Some(response.response().to_owned());
3734 break;
3735 }
3736 Ok(_) => {}
3737 Err(err) => panic!("unexpected streaming error: {err:?}"),
3738 }
3739 }
3740
3741 assert_eq!(
3742 tool_call_names,
3743 vec!["add".to_string(), "subtract".to_string()]
3744 );
3745 assert_eq!(
3746 tool_result_ids,
3747 vec!["tool_call_1".to_string(), "tool_call_2".to_string()]
3748 );
3749 assert_eq!(add_calls.load(Ordering::SeqCst), 1);
3750 assert_eq!(subtract_calls.load(Ordering::SeqCst), 1);
3751 assert_eq!(final_response_text.as_deref(), Some("done"));
3752 assert_eq!(recorded.request_count(), 2);
3753 }
3754
3755 #[tokio::test]
3756 async fn disallowed_specific_tool_call_fails_before_streaming_second_request() {
3757 let model = MockCompletionModel::from_stream_turns([
3758 vec![
3759 MockStreamEvent::tool_call(
3760 "tool_call_1",
3761 "subtract",
3762 serde_json::json!({"x": 3, "y": 1}),
3763 ),
3764 MockStreamEvent::final_response_with_total_tokens(4),
3765 ],
3766 vec![
3767 MockStreamEvent::text("should not be requested"),
3768 MockStreamEvent::final_response_with_total_tokens(6),
3769 ],
3770 ]);
3771 let recorded = model.clone();
3772 let agent = AgentBuilder::new(model)
3773 .tool(MockAddTool)
3774 .tool(MockSubtractTool)
3775 .tool_choice(ToolChoice::Specific {
3776 function_names: vec!["add".to_string()],
3777 })
3778 .build();
3779
3780 let mut stream = agent
3781 .stream_prompt("use the allowed tool")
3782 .with_hook(PanicOnUnknownToolHook)
3783 .multi_turn(3)
3784 .await;
3785 let mut saw_tool_call = false;
3786 let mut error = None;
3787
3788 while let Some(item) = stream.next().await {
3789 match item {
3790 Ok(MultiTurnStreamItem::StreamAssistantItem(
3791 StreamedAssistantContent::ToolCall { .. },
3792 )) => {
3793 saw_tool_call = true;
3794 }
3795 Ok(_) => {}
3796 Err(err) => {
3797 error = Some(err);
3798 break;
3799 }
3800 }
3801 }
3802
3803 assert!(!saw_tool_call);
3804 let error = error.expect("disallowed model-emitted tool should fail");
3805 match error {
3806 StreamingError::Prompt(err) => match *err {
3807 PromptError::UnknownToolCall {
3808 tool_name,
3809 available_tools,
3810 allowed_tools,
3811 chat_history,
3812 } => {
3813 assert_eq!(tool_name, "subtract");
3814 assert_eq!(
3815 available_tools,
3816 vec!["add".to_string(), "subtract".to_string()]
3817 );
3818 assert_eq!(allowed_tools, vec!["add".to_string()]);
3819 assert!(history_contains_tool_call(&chat_history, "subtract"));
3820 }
3821 other => panic!("expected UnknownToolCall, got {other:?}"),
3822 },
3823 other => panic!("expected prompt streaming error, got {other:?}"),
3824 }
3825 assert_eq!(recorded.request_count(), 1);
3826 }
3827
3828 #[tokio::test]
3829 async fn mixed_specific_tool_calls_fail_before_any_tool_execution() {
3830 let add_calls = Arc::new(AtomicU32::new(0));
3831 let model = MockCompletionModel::from_stream_turns([
3832 vec![
3833 MockStreamEvent::tool_call(
3834 "tool_call_1",
3835 "add",
3836 serde_json::json!({"x": 1, "y": 2}),
3837 ),
3838 MockStreamEvent::tool_call(
3839 "tool_call_2",
3840 "subtract",
3841 serde_json::json!({"x": 3, "y": 1}),
3842 ),
3843 MockStreamEvent::final_response_with_total_tokens(4),
3844 ],
3845 vec![
3846 MockStreamEvent::text("should not be requested"),
3847 MockStreamEvent::final_response_with_total_tokens(6),
3848 ],
3849 ]);
3850 let recorded = model.clone();
3851 let agent = AgentBuilder::new(model)
3852 .tool(CountingAddTool {
3853 calls: add_calls.clone(),
3854 })
3855 .tool(MockSubtractTool)
3856 .tool_choice(ToolChoice::Specific {
3857 function_names: vec!["add".to_string()],
3858 })
3859 .build();
3860
3861 let mut stream = agent
3862 .stream_prompt("use the allowed tool")
3863 .with_hook(PanicOnUnknownToolHook)
3864 .multi_turn(3)
3865 .await;
3866 let mut saw_tool_call = false;
3867 let mut saw_tool_result = false;
3868 let mut error = None;
3869
3870 while let Some(item) = stream.next().await {
3871 match item {
3872 Ok(MultiTurnStreamItem::StreamAssistantItem(
3873 StreamedAssistantContent::ToolCall { .. },
3874 )) => {
3875 saw_tool_call = true;
3876 }
3877 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3878 ..
3879 })) => {
3880 saw_tool_result = true;
3881 }
3882 Ok(_) => {}
3883 Err(err) => {
3884 error = Some(err);
3885 break;
3886 }
3887 }
3888 }
3889
3890 assert!(!saw_tool_call);
3891 assert!(!saw_tool_result);
3892 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3893 let error = error.expect("mixed disallowed streamed tool call should fail");
3894 match error {
3895 StreamingError::Prompt(err) => match *err {
3896 PromptError::UnknownToolCall {
3897 tool_name,
3898 available_tools,
3899 allowed_tools,
3900 chat_history,
3901 } => {
3902 assert_eq!(tool_name, "subtract");
3903 assert_eq!(
3904 available_tools,
3905 vec!["add".to_string(), "subtract".to_string()]
3906 );
3907 assert_eq!(allowed_tools, vec!["add".to_string()]);
3908 assert!(history_contains_tool_call(&chat_history, "subtract"));
3909 }
3910 other => panic!("expected UnknownToolCall, got {other:?}"),
3911 },
3912 other => panic!("expected prompt streaming error, got {other:?}"),
3913 }
3914 assert_eq!(recorded.request_count(), 1);
3915 }
3916
3917 #[tokio::test]
3918 async fn tool_choice_none_rejects_streaming_tool_call() {
3919 let model = MockCompletionModel::from_stream_turns([
3920 vec![
3921 MockStreamEvent::tool_call(
3922 "tool_call_1",
3923 "add",
3924 serde_json::json!({"x": 1, "y": 2}),
3925 ),
3926 MockStreamEvent::final_response_with_total_tokens(4),
3927 ],
3928 vec![
3929 MockStreamEvent::text("should not be requested"),
3930 MockStreamEvent::final_response_with_total_tokens(6),
3931 ],
3932 ]);
3933 let recorded = model.clone();
3934 let agent = AgentBuilder::new(model)
3935 .tool(MockAddTool)
3936 .tool_choice(ToolChoice::None)
3937 .build();
3938
3939 let mut stream = agent
3940 .stream_prompt("do not use tools")
3941 .with_hook(PanicOnUnknownToolHook)
3942 .multi_turn(3)
3943 .await;
3944 let mut saw_tool_call = false;
3945 let mut error = None;
3946
3947 while let Some(item) = stream.next().await {
3948 match item {
3949 Ok(MultiTurnStreamItem::StreamAssistantItem(
3950 StreamedAssistantContent::ToolCall { .. },
3951 )) => {
3952 saw_tool_call = true;
3953 }
3954 Ok(_) => {}
3955 Err(err) => {
3956 error = Some(err);
3957 break;
3958 }
3959 }
3960 }
3961
3962 assert!(!saw_tool_call);
3963 let error = error.expect("ToolChoice::None should reject returned tool calls");
3964 match error {
3965 StreamingError::Prompt(err) => match *err {
3966 PromptError::UnknownToolCall {
3967 tool_name,
3968 available_tools,
3969 allowed_tools,
3970 chat_history,
3971 } => {
3972 assert_eq!(tool_name, "add");
3973 assert_eq!(available_tools, vec!["add".to_string()]);
3974 assert!(allowed_tools.is_empty());
3975 assert!(history_contains_tool_call(&chat_history, "add"));
3976 }
3977 other => panic!("expected UnknownToolCall, got {other:?}"),
3978 },
3979 other => panic!("expected prompt streaming error, got {other:?}"),
3980 }
3981 assert_eq!(recorded.request_count(), 1);
3982 }
3983
3984 #[tokio::test]
3985 async fn tool_choice_none_rejects_streaming_tool_call_name_delta_before_hook_or_emit() {
3986 let model = MockCompletionModel::from_stream_turns([
3987 vec![
3988 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
3989 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
3990 MockStreamEvent::final_response_with_total_tokens(4),
3991 ],
3992 vec![
3993 MockStreamEvent::text("should not be requested"),
3994 MockStreamEvent::final_response_with_total_tokens(6),
3995 ],
3996 ]);
3997 let recorded = model.clone();
3998 let agent = AgentBuilder::new(model)
3999 .tool(MockAddTool)
4000 .tool_choice(ToolChoice::None)
4001 .build();
4002
4003 let mut stream = agent
4004 .stream_prompt("do not use tools")
4005 .with_hook(PanicOnUnknownToolHook)
4006 .multi_turn(3)
4007 .await;
4008 let mut saw_delta = false;
4009 let mut error = None;
4010
4011 while let Some(item) = stream.next().await {
4012 match item {
4013 Ok(MultiTurnStreamItem::StreamAssistantItem(
4014 StreamedAssistantContent::ToolCallDelta { .. },
4015 )) => {
4016 saw_delta = true;
4017 }
4018 Ok(_) => {}
4019 Err(err) => {
4020 error = Some(err);
4021 break;
4022 }
4023 }
4024 }
4025
4026 assert!(!saw_delta);
4027 let error = error.expect("ToolChoice::None should reject returned tool-call deltas");
4028 match error {
4029 StreamingError::Prompt(err) => match *err {
4030 PromptError::UnknownToolCall {
4031 tool_name,
4032 available_tools,
4033 allowed_tools,
4034 chat_history,
4035 } => {
4036 assert_eq!(tool_name, "add");
4037 assert_eq!(available_tools, vec!["add".to_string()]);
4038 assert!(allowed_tools.is_empty());
4039 assert!(history_contains_tool_call(&chat_history, "add"));
4040 }
4041 other => panic!("expected UnknownToolCall, got {other:?}"),
4042 },
4043 other => panic!("expected prompt streaming error, got {other:?}"),
4044 }
4045 assert_eq!(recorded.request_count(), 1);
4046 }
4047
4048 #[tokio::test]
4049 async fn unknown_tool_call_name_delta_fails_before_streaming_delta_hook_or_emit() {
4050 let model = MockCompletionModel::from_stream_turns([
4051 vec![
4052 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4053 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4054 MockStreamEvent::final_response_with_total_tokens(4),
4055 ],
4056 vec![
4057 MockStreamEvent::text("should not be requested"),
4058 MockStreamEvent::final_response_with_total_tokens(6),
4059 ],
4060 ]);
4061 let recorded = model.clone();
4062 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4063
4064 let mut stream = agent
4065 .stream_prompt("stream a bad tool call")
4066 .with_hook(PanicOnUnknownToolHook)
4067 .multi_turn(3)
4068 .await;
4069 let mut saw_delta = false;
4070 let mut error = None;
4071
4072 while let Some(item) = stream.next().await {
4073 match item {
4074 Ok(MultiTurnStreamItem::StreamAssistantItem(
4075 StreamedAssistantContent::ToolCallDelta { .. },
4076 )) => {
4077 saw_delta = true;
4078 }
4079 Ok(_) => {}
4080 Err(err) => {
4081 error = Some(err);
4082 break;
4083 }
4084 }
4085 }
4086
4087 assert!(!saw_delta);
4088 let error = error.expect("unknown tool-call name delta should fail");
4089 match error {
4090 StreamingError::Prompt(err) => match *err {
4091 PromptError::UnknownToolCall {
4092 tool_name,
4093 available_tools,
4094 allowed_tools,
4095 chat_history,
4096 } => {
4097 assert_eq!(tool_name, "default_api");
4098 assert_eq!(available_tools, vec!["add".to_string()]);
4099 assert_eq!(allowed_tools, vec!["add".to_string()]);
4100 assert!(history_contains_tool_call(&chat_history, "default_api"));
4101 }
4102 other => panic!("expected UnknownToolCall, got {other:?}"),
4103 },
4104 other => panic!("expected prompt streaming error, got {other:?}"),
4105 }
4106 assert_eq!(recorded.request_count(), 1);
4107 }
4108
4109 #[tokio::test]
4110 async fn tool_call_args_delta_before_unknown_name_fails_before_hook_or_emit() {
4111 let model = MockCompletionModel::from_stream_turns([
4112 vec![
4113 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4114 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4115 MockStreamEvent::final_response_with_total_tokens(4),
4116 ],
4117 vec![
4118 MockStreamEvent::text("should not be requested"),
4119 MockStreamEvent::final_response_with_total_tokens(6),
4120 ],
4121 ]);
4122 let recorded = model.clone();
4123 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4124
4125 let mut stream = agent
4126 .stream_prompt("stream a bad tool call")
4127 .with_hook(PanicOnUnknownToolHook)
4128 .multi_turn(3)
4129 .await;
4130 let mut saw_delta = false;
4131 let mut error = None;
4132
4133 while let Some(item) = stream.next().await {
4134 match item {
4135 Ok(MultiTurnStreamItem::StreamAssistantItem(
4136 StreamedAssistantContent::ToolCallDelta { .. },
4137 )) => {
4138 saw_delta = true;
4139 }
4140 Ok(_) => {}
4141 Err(err) => {
4142 error = Some(err);
4143 break;
4144 }
4145 }
4146 }
4147
4148 assert!(!saw_delta);
4149 let error = error.expect("unknown tool-call name should reject buffered args");
4150 match error {
4151 StreamingError::Prompt(err) => match *err {
4152 PromptError::UnknownToolCall {
4153 tool_name,
4154 available_tools,
4155 allowed_tools,
4156 chat_history,
4157 } => {
4158 assert_eq!(tool_name, "default_api");
4159 assert_eq!(available_tools, vec!["add".to_string()]);
4160 assert_eq!(allowed_tools, vec!["add".to_string()]);
4161 assert!(history_contains_tool_call(&chat_history, "default_api"));
4162 }
4163 other => panic!("expected UnknownToolCall, got {other:?}"),
4164 },
4165 other => panic!("expected prompt streaming error, got {other:?}"),
4166 }
4167 assert_eq!(recorded.request_count(), 1);
4168 }
4169
4170 #[tokio::test]
4171 async fn tool_call_args_delta_before_valid_name_buffers_then_emits_in_safe_order() {
4172 let model = MockCompletionModel::from_stream_turns([[
4173 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4174 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4175 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4176 MockStreamEvent::final_response_with_total_tokens(3),
4177 ]]);
4178 let hook = RecordingToolCallDeltaHook::default();
4179 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4180
4181 let mut stream = agent
4182 .stream_prompt("stream a tool call")
4183 .with_hook(hook.clone())
4184 .await;
4185 let mut stream_deltas = Vec::new();
4186
4187 while let Some(item) = stream.next().await {
4188 match item {
4189 Ok(MultiTurnStreamItem::StreamAssistantItem(
4190 StreamedAssistantContent::ToolCallDelta {
4191 id,
4192 internal_call_id,
4193 content,
4194 },
4195 )) => {
4196 stream_deltas.push((id, internal_call_id, content));
4197 }
4198 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4199 Ok(_) => {}
4200 Err(err) => panic!("unexpected streaming error: {err:?}"),
4201 }
4202 }
4203
4204 assert_eq!(
4205 hook.observed(),
4206 vec![
4207 (
4208 "tool_1".to_string(),
4209 "internal_1".to_string(),
4210 Some("add".to_string()),
4211 String::new()
4212 ),
4213 (
4214 "tool_1".to_string(),
4215 "internal_1".to_string(),
4216 None,
4217 "{\"x\":".to_string()
4218 ),
4219 (
4220 "tool_1".to_string(),
4221 "internal_1".to_string(),
4222 None,
4223 "1}".to_string()
4224 ),
4225 ]
4226 );
4227 assert_eq!(
4228 stream_deltas,
4229 vec![
4230 (
4231 "tool_1".to_string(),
4232 "internal_1".to_string(),
4233 ToolCallDeltaContent::Name("add".to_string())
4234 ),
4235 (
4236 "tool_1".to_string(),
4237 "internal_1".to_string(),
4238 ToolCallDeltaContent::Delta("{\"x\":".to_string())
4239 ),
4240 (
4241 "tool_1".to_string(),
4242 "internal_1".to_string(),
4243 ToolCallDeltaContent::Delta("1}".to_string())
4244 ),
4245 ]
4246 );
4247 }
4248
4249 #[tokio::test]
4250 async fn tool_call_args_delta_without_name_errors_at_stream_end() {
4251 let model = MockCompletionModel::from_stream_turns([
4252 vec![
4253 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4254 MockStreamEvent::final_response_with_total_tokens(4),
4255 ],
4256 vec![
4257 MockStreamEvent::text("should not be requested"),
4258 MockStreamEvent::final_response_with_total_tokens(6),
4259 ],
4260 ]);
4261 let recorded = model.clone();
4262 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4263
4264 let mut stream = agent
4265 .stream_prompt("stream an incomplete tool call")
4266 .with_hook(PanicOnUnknownToolHook)
4267 .multi_turn(3)
4268 .await;
4269 let mut saw_delta = false;
4270 let mut saw_completion_call = false;
4271 let mut saw_final_response = false;
4272 let mut error = None;
4273
4274 while let Some(item) = stream.next().await {
4275 match item {
4276 Ok(MultiTurnStreamItem::StreamAssistantItem(
4277 StreamedAssistantContent::ToolCallDelta { .. },
4278 )) => {
4279 saw_delta = true;
4280 }
4281 Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4282 saw_completion_call = true;
4283 }
4284 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4285 saw_final_response = true;
4286 }
4287 Ok(_) => {}
4288 Err(err) => {
4289 error = Some(err);
4290 break;
4291 }
4292 }
4293 }
4294
4295 assert!(!saw_delta);
4296 assert!(!saw_completion_call);
4297 assert!(!saw_final_response);
4298 let error = error.expect("unterminated tool-call args delta should fail");
4299 match error {
4300 StreamingError::Completion(CompletionError::ResponseError(message)) => {
4301 assert!(
4302 message.contains("streamed tool call arguments"),
4303 "{message}"
4304 );
4305 assert!(message.contains("tool_1"), "{message}");
4306 assert!(message.contains("internal_1"), "{message}");
4307 }
4308 other => panic!("expected completion response error, got {other:?}"),
4309 }
4310 assert_eq!(recorded.request_count(), 1);
4311 }
4312
4313 #[tokio::test]
4314 async fn tool_choice_none_buffers_args_then_rejects_name_without_emit() {
4315 let model = MockCompletionModel::from_stream_turns([
4316 vec![
4317 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4318 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4319 MockStreamEvent::final_response_with_total_tokens(4),
4320 ],
4321 vec![
4322 MockStreamEvent::text("should not be requested"),
4323 MockStreamEvent::final_response_with_total_tokens(6),
4324 ],
4325 ]);
4326 let recorded = model.clone();
4327 let agent = AgentBuilder::new(model)
4328 .tool(MockAddTool)
4329 .tool_choice(ToolChoice::None)
4330 .build();
4331
4332 let mut stream = agent
4333 .stream_prompt("do not use tools")
4334 .with_hook(PanicOnUnknownToolHook)
4335 .multi_turn(3)
4336 .await;
4337 let mut saw_delta = false;
4338 let mut error = None;
4339
4340 while let Some(item) = stream.next().await {
4341 match item {
4342 Ok(MultiTurnStreamItem::StreamAssistantItem(
4343 StreamedAssistantContent::ToolCallDelta { .. },
4344 )) => {
4345 saw_delta = true;
4346 }
4347 Ok(_) => {}
4348 Err(err) => {
4349 error = Some(err);
4350 break;
4351 }
4352 }
4353 }
4354
4355 assert!(!saw_delta);
4356 let error = error.expect("ToolChoice::None should reject buffered tool-call deltas");
4357 match error {
4358 StreamingError::Prompt(err) => match *err {
4359 PromptError::UnknownToolCall {
4360 tool_name,
4361 available_tools,
4362 allowed_tools,
4363 chat_history,
4364 } => {
4365 assert_eq!(tool_name, "add");
4366 assert_eq!(available_tools, vec!["add".to_string()]);
4367 assert!(allowed_tools.is_empty());
4368 assert!(history_contains_tool_call(&chat_history, "add"));
4369 }
4370 other => panic!("expected UnknownToolCall, got {other:?}"),
4371 },
4372 other => panic!("expected prompt streaming error, got {other:?}"),
4373 }
4374 assert_eq!(recorded.request_count(), 1);
4375 }
4376
4377 #[tokio::test]
4378 async fn stream_prompt_emits_tool_call_deltas_without_hook() {
4379 let model = MockCompletionModel::from_stream_turns([[
4380 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4381 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4382 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4383 MockStreamEvent::final_response_with_total_tokens(3),
4384 ]]);
4385 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4386
4387 let mut stream = agent.stream_prompt("stream a tool call").await;
4388 let mut deltas = Vec::new();
4389
4390 while let Some(item) = stream.next().await {
4391 match item {
4392 Ok(MultiTurnStreamItem::StreamAssistantItem(
4393 StreamedAssistantContent::ToolCallDelta {
4394 id,
4395 internal_call_id,
4396 content,
4397 },
4398 )) => {
4399 deltas.push((id, internal_call_id, content));
4400 }
4401 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4402 Ok(_) => {}
4403 Err(err) => panic!("unexpected streaming error: {err:?}"),
4404 }
4405 }
4406
4407 assert_eq!(
4408 deltas,
4409 vec![
4410 (
4411 "tool_1".to_string(),
4412 "internal_1".to_string(),
4413 ToolCallDeltaContent::Name("add".to_string())
4414 ),
4415 (
4416 "tool_1".to_string(),
4417 "internal_1".to_string(),
4418 ToolCallDeltaContent::Delta("{\"x\":".to_string())
4419 ),
4420 (
4421 "tool_1".to_string(),
4422 "internal_1".to_string(),
4423 ToolCallDeltaContent::Delta("1}".to_string())
4424 ),
4425 ]
4426 );
4427 }
4428
4429 #[tokio::test]
4430 async fn stream_prompt_emits_tool_call_deltas_after_hook_continue() {
4431 let model = MockCompletionModel::from_stream_turns([[
4432 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4433 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4434 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4435 MockStreamEvent::final_response_with_total_tokens(3),
4436 ]]);
4437 let hook = RecordingToolCallDeltaHook::default();
4438 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4439
4440 let mut stream = agent
4441 .stream_prompt("stream a tool call")
4442 .with_hook(hook.clone())
4443 .await;
4444 let mut stream_deltas = Vec::new();
4445
4446 while let Some(item) = stream.next().await {
4447 match item {
4448 Ok(MultiTurnStreamItem::StreamAssistantItem(
4449 StreamedAssistantContent::ToolCallDelta {
4450 id,
4451 internal_call_id,
4452 content,
4453 },
4454 )) => {
4455 stream_deltas.push((id, internal_call_id, content));
4456 }
4457 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4458 Ok(_) => {}
4459 Err(err) => panic!("unexpected streaming error: {err:?}"),
4460 }
4461 }
4462
4463 assert_eq!(
4464 hook.observed(),
4465 vec![
4466 (
4467 "tool_1".to_string(),
4468 "internal_1".to_string(),
4469 Some("add".to_string()),
4470 String::new()
4471 ),
4472 (
4473 "tool_1".to_string(),
4474 "internal_1".to_string(),
4475 None,
4476 "{\"x\":".to_string()
4477 ),
4478 (
4479 "tool_1".to_string(),
4480 "internal_1".to_string(),
4481 None,
4482 "1}".to_string()
4483 ),
4484 ]
4485 );
4486 assert_eq!(
4487 stream_deltas,
4488 vec![
4489 (
4490 "tool_1".to_string(),
4491 "internal_1".to_string(),
4492 ToolCallDeltaContent::Name("add".to_string())
4493 ),
4494 (
4495 "tool_1".to_string(),
4496 "internal_1".to_string(),
4497 ToolCallDeltaContent::Delta("{\"x\":".to_string())
4498 ),
4499 (
4500 "tool_1".to_string(),
4501 "internal_1".to_string(),
4502 ToolCallDeltaContent::Delta("1}".to_string())
4503 ),
4504 ]
4505 );
4506 }
4507
4508 #[tokio::test]
4509 async fn stream_prompt_tool_call_deltas_hook_termination_prevents_delta_emit() {
4510 let model = MockCompletionModel::from_stream_turns([[
4511 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4512 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4513 MockStreamEvent::final_response_with_total_tokens(3),
4514 ]]);
4515 let hook = TerminatingToolCallDeltaHook::default();
4516 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4517
4518 let mut stream = agent
4519 .stream_prompt("stream a tool call")
4520 .with_hook(hook.clone())
4521 .await;
4522 let mut saw_delta = false;
4523 let mut saw_final_response = false;
4524 let mut error_message = None;
4525
4526 while let Some(item) = stream.next().await {
4527 match item {
4528 Ok(MultiTurnStreamItem::StreamAssistantItem(
4529 StreamedAssistantContent::ToolCallDelta { .. },
4530 )) => {
4531 saw_delta = true;
4532 }
4533 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4534 saw_final_response = true;
4535 }
4536 Ok(_) => {}
4537 Err(err) => {
4538 error_message = Some(err.to_string());
4539 break;
4540 }
4541 }
4542 }
4543
4544 assert_eq!(
4545 hook.observed(),
4546 vec![(
4547 "tool_1".to_string(),
4548 "internal_1".to_string(),
4549 Some("add".to_string()),
4550 String::new()
4551 )]
4552 );
4553 assert!(!saw_delta);
4554 assert!(!saw_final_response);
4555 assert!(
4556 error_message
4557 .as_deref()
4558 .is_some_and(|message| message.contains("PromptCancelled: stop on tool call delta")),
4559 "expected hook termination error, got {error_message:?}"
4560 );
4561 }
4562
4563 #[tokio::test]
4564 async fn stream_prompt_exposes_completion_calls() {
4565 let first_call_usage = usage(10, 2);
4566 let second_call_usage = usage(25, 5);
4567 let model = MockCompletionModel::from_stream_turns([
4568 vec![
4569 MockStreamEvent::tool_call(
4570 "tool_call_1",
4571 "add",
4572 serde_json::json!({"x": 1, "y": 2}),
4573 )
4574 .with_call_id("call_1"),
4575 MockStreamEvent::final_response(first_call_usage),
4576 ],
4577 vec![
4578 MockStreamEvent::text("done"),
4579 MockStreamEvent::final_response(second_call_usage),
4580 ],
4581 ]);
4582 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4583 let empty_history: &[Message] = &[];
4584
4585 let mut stream = agent
4586 .stream_prompt("do tool work")
4587 .with_history(empty_history)
4588 .multi_turn(3)
4589 .await;
4590 let mut completion_calls_events = Vec::new();
4591 let mut final_response = None;
4592
4593 while let Some(item) = stream.next().await {
4594 match item {
4595 Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
4596 completion_calls_events.push(call_usage);
4597 }
4598 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4599 final_response = Some(response);
4600 break;
4601 }
4602 Ok(_) => {}
4603 Err(err) => panic!("unexpected streaming error: {err:?}"),
4604 }
4605 }
4606
4607 assert_eq!(
4608 completion_calls_events,
4609 vec![
4610 CompletionCall::new(0, first_call_usage),
4611 CompletionCall::new(1, second_call_usage)
4612 ]
4613 );
4614
4615 let final_response = final_response.expect("expected final response");
4616 assert_eq!(
4617 final_response.usage(),
4618 Usage {
4619 input_tokens: 35,
4620 output_tokens: 7,
4621 total_tokens: 42,
4622 cached_input_tokens: 0,
4623 cache_creation_input_tokens: 0,
4624 tool_use_prompt_tokens: 0,
4625 reasoning_tokens: 0,
4626 }
4627 );
4628 assert_eq!(
4629 final_response.completion_calls(),
4630 &[
4631 CompletionCall::new(0, first_call_usage),
4632 CompletionCall::new(1, second_call_usage)
4633 ]
4634 );
4635 }
4636
4637 #[tokio::test(flavor = "current_thread")]
4638 async fn stream_prompt_records_single_call_usage_on_chat_span_under_outer_span() {
4639 let call_usage = usage(10, 2);
4640 let model = MockCompletionModel::from_stream_turns([[
4641 MockStreamEvent::text("done"),
4642 MockStreamEvent::final_response(call_usage),
4643 ]]);
4644 let agent = AgentBuilder::new(model).build();
4645
4646 assert_stream_usage_recorded_on_chat_spans(agent, "say done", 1, &[call_usage]).await;
4647 }
4648
4649 #[tokio::test(flavor = "current_thread")]
4650 async fn stream_prompt_records_multi_turn_usage_on_chat_spans_under_outer_span() {
4651 let first_call_usage = usage(10, 2);
4652 let second_call_usage = usage(25, 5);
4653 let model = MockCompletionModel::from_stream_turns([
4654 vec![
4655 MockStreamEvent::tool_call(
4656 "tool_call_1",
4657 "add",
4658 serde_json::json!({"x": 1, "y": 2}),
4659 )
4660 .with_call_id("call_1"),
4661 MockStreamEvent::final_response(first_call_usage),
4662 ],
4663 vec![
4664 MockStreamEvent::text("done"),
4665 MockStreamEvent::final_response(second_call_usage),
4666 ],
4667 ]);
4668 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4669
4670 assert_stream_usage_recorded_on_chat_spans(
4671 agent,
4672 "do tool work",
4673 3,
4674 &[first_call_usage, second_call_usage],
4675 )
4676 .await;
4677 }
4678
4679 #[tokio::test]
4680 async fn stream_prompt_emits_completion_call_before_finish_hook_termination() {
4681 let call_usage = usage(10, 2);
4682 let model = MockCompletionModel::from_stream_turns([[
4683 MockStreamEvent::text("done"),
4684 MockStreamEvent::final_response(call_usage),
4685 ]]);
4686 let agent = AgentBuilder::new(model).build();
4687
4688 let mut stream = agent
4689 .stream_prompt("say done")
4690 .with_hook(TerminateOnStreamFinish)
4691 .await;
4692 let mut completion_calls = Vec::new();
4693 let mut saw_error = false;
4694
4695 while let Some(item) = stream.next().await {
4696 match item {
4697 Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
4698 completion_calls.push(completion_call);
4699 }
4700 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4701 panic!("unexpected final response after hook termination: {response:?}");
4702 }
4703 Ok(_) => {}
4704 Err(_) => {
4705 saw_error = true;
4706 break;
4707 }
4708 }
4709 }
4710
4711 assert_eq!(completion_calls, vec![CompletionCall::new(0, call_usage)]);
4712 assert!(saw_error);
4713 }
4714
4715 #[tokio::test]
4716 async fn stream_prompt_completion_calls_records_unreported_usage() {
4717 let second_call_usage = usage(25, 5);
4718 let model = MockCompletionModel::from_stream_turns([
4719 vec![
4720 MockStreamEvent::tool_call(
4721 "tool_call_1",
4722 "add",
4723 serde_json::json!({"x": 1, "y": 2}),
4724 )
4725 .with_call_id("call_1"),
4726 ],
4727 vec![
4728 MockStreamEvent::text("done"),
4729 MockStreamEvent::final_response(second_call_usage),
4730 ],
4731 ]);
4732 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4733 let empty_history: &[Message] = &[];
4734
4735 let mut stream = agent
4736 .stream_prompt("do tool work")
4737 .with_history(empty_history)
4738 .multi_turn(3)
4739 .await;
4740 let mut completion_calls_events = Vec::new();
4741 let mut final_response = None;
4742
4743 while let Some(item) = stream.next().await {
4744 match item {
4745 Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
4746 completion_calls_events.push(call_usage);
4747 }
4748 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4749 final_response = Some(response);
4750 break;
4751 }
4752 Ok(_) => {}
4753 Err(err) => panic!("unexpected streaming error: {err:?}"),
4754 }
4755 }
4756
4757 let expected_usage = vec![
4758 CompletionCall::new(0, Usage::new()),
4759 CompletionCall::new(1, second_call_usage),
4760 ];
4761 assert_eq!(completion_calls_events, expected_usage);
4762
4763 let final_response = final_response.expect("expected final response");
4764 assert_eq!(final_response.completion_calls(), expected_usage.as_slice());
4765 }
4766
4767 #[tokio::test]
4768 async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
4769 let agent = AgentBuilder::new(streaming_text_then_final_model()).build();
4770
4771 let mut stream = agent.stream_prompt("say hello").await;
4772 let mut streamed_text = String::new();
4773 let mut final_response_text = None;
4774
4775 while let Some(item) = stream.next().await {
4776 match item {
4777 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
4778 text,
4779 ))) => streamed_text.push_str(&text.text),
4780 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
4781 final_response_text = Some(res.response().to_owned());
4782 break;
4783 }
4784 Ok(_) => {}
4785 Err(err) => panic!("unexpected streaming error: {err:?}"),
4786 }
4787 }
4788
4789 assert_eq!(streamed_text, "hello world");
4790 assert_eq!(final_response_text.as_deref(), Some("hello world"));
4791 }
4792
4793 #[tokio::test]
4794 async fn final_response_preserves_structured_text_metadata() {
4795 let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
4796
4797 let mut stream = agent.stream_prompt("answer with citations").await;
4798 let mut final_response = None;
4799
4800 while let Some(item) = stream.next().await {
4801 match item {
4802 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
4803 final_response = Some(res);
4804 break;
4805 }
4806 Ok(_) => {}
4807 Err(err) => panic!("unexpected streaming error: {err:?}"),
4808 }
4809 }
4810
4811 let final_response = final_response.expect("expected final response");
4812 assert_eq!(final_response.response(), "cited answer");
4813 let metadata = text_metadata(final_response.content())
4814 .expect("expected text metadata in final content");
4815 assert_eq!(
4816 metadata["citations"][0]["encrypted_index"],
4817 "encrypted-reference"
4818 );
4819 }
4820
4821 #[tokio::test]
4822 async fn final_response_history_preserves_structured_text_metadata() {
4823 let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
4824
4825 let empty_history: &[Message] = &[];
4826 let mut stream = agent
4827 .stream_prompt("answer with citations")
4828 .with_history(empty_history)
4829 .await;
4830 let mut final_response = None;
4831
4832 while let Some(item) = stream.next().await {
4833 match item {
4834 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
4835 final_response = Some(res);
4836 break;
4837 }
4838 Ok(_) => {}
4839 Err(err) => panic!("unexpected streaming error: {err:?}"),
4840 }
4841 }
4842
4843 let final_response = final_response.expect("expected final response");
4844 let history = final_response
4845 .history()
4846 .expect("with_history should include final history");
4847 let assistant_content = history
4848 .iter()
4849 .find_map(|message| match message {
4850 Message::Assistant { content, .. } => Some(content),
4851 _ => None,
4852 })
4853 .expect("expected assistant message in history");
4854 let metadata =
4855 text_metadata(assistant_content).expect("expected text metadata in assistant history");
4856 assert_eq!(
4857 metadata["citations"][0]["encrypted_index"],
4858 "encrypted-reference"
4859 );
4860 }
4861
4862 #[tokio::test]
4863 async fn tool_follow_up_history_preserves_structured_text_metadata() {
4864 let model = streaming_cited_text_then_tool_model();
4865 let recorded = model.clone();
4866 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4867 let empty_history: &[Message] = &[];
4868
4869 let mut stream = agent
4870 .stream_prompt("use a tool with citations")
4871 .with_history(empty_history)
4872 .multi_turn(3)
4873 .await;
4874
4875 while let Some(item) = stream.next().await {
4876 match item {
4877 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4878 Ok(_) => {}
4879 Err(err) => panic!("unexpected streaming error: {err:?}"),
4880 }
4881 }
4882
4883 let requests = recorded.requests();
4884 assert_eq!(requests.len(), 2);
4885 let follow_up_history = requests[1].chat_history.iter().collect::<Vec<_>>();
4886 let assistant_content = follow_up_history
4887 .iter()
4888 .find_map(|message| match message {
4889 Message::Assistant { content, .. } => Some(content),
4890 _ => None,
4891 })
4892 .expect("expected assistant message in follow-up history");
4893 let metadata = text_metadata(assistant_content)
4894 .expect("expected citation metadata in follow-up assistant history");
4895 assert_eq!(
4896 metadata["citations"][0]["encrypted_index"],
4897 "encrypted-reference"
4898 );
4899 }
4900
4901 #[tokio::test]
4902 async fn final_response_can_remain_empty_for_truly_textless_turns() {
4903 let agent = AgentBuilder::new(streaming_final_only_model()).build();
4904
4905 let mut stream = agent.stream_prompt("say nothing").await;
4906 let mut streamed_text = String::new();
4907 let mut final_response_text = None;
4908
4909 while let Some(item) = stream.next().await {
4910 match item {
4911 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
4912 text,
4913 ))) => streamed_text.push_str(&text.text),
4914 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
4915 final_response_text = Some(res.response().to_owned());
4916 break;
4917 }
4918 Ok(_) => {}
4919 Err(err) => panic!("unexpected streaming error: {err:?}"),
4920 }
4921 }
4922
4923 assert!(streamed_text.is_empty());
4924 assert_eq!(final_response_text.as_deref(), Some(""));
4925 }
4926
4927 async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
4930 let mut interval = tokio::time::interval(Duration::from_millis(50));
4931 let mut count = 0u32;
4932
4933 while !stop.load(Ordering::Relaxed) {
4934 interval.tick().await;
4935 count += 1;
4936
4937 tracing::event!(
4938 target: "background_logger",
4939 tracing::Level::INFO,
4940 count = count,
4941 "Background tick"
4942 );
4943
4944 let current = tracing::Span::current();
4946 if !current.is_disabled() && !current.is_none() {
4947 leak_count.fetch_add(1, Ordering::Relaxed);
4948 }
4949 }
4950
4951 tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
4952 }
4953
4954 #[tokio::test(flavor = "current_thread")]
4962 #[ignore = "This requires an API key"]
4963 async fn test_span_context_isolation() -> anyhow::Result<()> {
4964 let stop = Arc::new(AtomicBool::new(false));
4965 let leak_count = Arc::new(AtomicU32::new(0));
4966
4967 let bg_stop = stop.clone();
4969 let bg_leak = leak_count.clone();
4970 let bg_handle = tokio::spawn(async move {
4971 background_logger(bg_stop, bg_leak).await;
4972 });
4973
4974 tokio::time::sleep(Duration::from_millis(100)).await;
4976
4977 let client = anthropic::Client::from_env()?;
4980 let agent = client
4981 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
4982 .preamble("You are a helpful assistant.")
4983 .temperature(0.1)
4984 .max_tokens(100)
4985 .build();
4986
4987 let mut stream = agent
4988 .stream_prompt("Say 'hello world' and nothing else.")
4989 .await;
4990
4991 let mut full_content = String::new();
4992 while let Some(item) = stream.next().await {
4993 match item {
4994 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
4995 text,
4996 ))) => {
4997 full_content.push_str(&text.text);
4998 }
4999 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
5000 break;
5001 }
5002 Err(e) => {
5003 tracing::warn!("Error: {:?}", e);
5004 break;
5005 }
5006 _ => {}
5007 }
5008 }
5009
5010 tracing::info!("Got response: {:?}", full_content);
5011
5012 stop.store(true, Ordering::Relaxed);
5014 bg_handle.await?;
5015
5016 let leaks = leak_count.load(Ordering::Relaxed);
5017 anyhow::ensure!(
5018 leaks == 0,
5019 "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
5020 This indicates that span.enter() is being used inside async_stream instead of .instrument()"
5021 );
5022
5023 Ok(())
5024 }
5025
5026 #[tokio::test]
5032 #[ignore = "This requires an API key"]
5033 async fn test_chat_history_in_final_response() -> anyhow::Result<()> {
5034 use crate::message::Message;
5035
5036 let client = anthropic::Client::from_env()?;
5037 let agent = client
5038 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5039 .preamble("You are a helpful assistant. Keep responses brief.")
5040 .temperature(0.1)
5041 .max_tokens(50)
5042 .build();
5043
5044 let empty_history: &[Message] = &[];
5046 let mut stream = agent
5047 .stream_prompt("Say 'hello' and nothing else.")
5048 .with_history(empty_history)
5049 .await;
5050
5051 let mut response_text = String::new();
5053 let mut final_history = None;
5054 while let Some(item) = stream.next().await {
5055 match item {
5056 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5057 text,
5058 ))) => {
5059 response_text.push_str(&text.text);
5060 }
5061 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5062 final_history = res.history().map(|h| h.to_vec());
5063 break;
5064 }
5065 Err(e) => {
5066 return Err(e.into());
5067 }
5068 _ => {}
5069 }
5070 }
5071
5072 let history = final_history
5073 .ok_or_else(|| anyhow::anyhow!("final response should include history"))?;
5074
5075 anyhow::ensure!(
5077 history.iter().any(|m| matches!(m, Message::User { .. })),
5078 "History should contain the user message"
5079 );
5080
5081 anyhow::ensure!(
5083 history
5084 .iter()
5085 .any(|m| matches!(m, Message::Assistant { .. })),
5086 "History should contain the assistant response"
5087 );
5088
5089 tracing::info!(
5090 "History after streaming: {} messages, response: {:?}",
5091 history.len(),
5092 response_text
5093 );
5094
5095 Ok(())
5096 }
5097
5098 #[tokio::test]
5099 async fn streaming_appends_to_memory_after_final_response() {
5100 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5101
5102 let memory = InMemoryConversationMemory::new();
5103 let agent = AgentBuilder::new(streaming_text_then_final_model())
5104 .memory(memory.clone())
5105 .build();
5106
5107 let mut stream = agent
5108 .stream_prompt("hi there")
5109 .conversation("stream-thread")
5110 .await;
5111
5112 let mut history_in_final = None;
5113 while let Some(item) = stream.next().await {
5114 match item {
5115 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5116 history_in_final = res.history().map(|h| h.to_vec());
5117 break;
5118 }
5119 Ok(_) => {}
5120 Err(err) => panic!("unexpected streaming error: {err:?}"),
5121 }
5122 }
5123
5124 let final_history = history_in_final
5125 .expect("FinalResponse.history should be populated when memory is configured");
5126 assert_eq!(
5127 final_history.len(),
5128 2,
5129 "user prompt + assistant response in final history: {final_history:?}"
5130 );
5131
5132 let stored = memory.load("stream-thread").await.unwrap();
5133 assert_eq!(stored.len(), 2, "memory should contain user + assistant");
5134 }
5135
5136 #[tokio::test]
5137 async fn streaming_reasoning_without_tools_does_not_duplicate_final_history() {
5138 let agent = AgentBuilder::new(MockCompletionModel::from_stream_turns([[
5139 MockStreamEvent::text("final answer"),
5140 MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
5141 MockStreamEvent::final_response_with_total_tokens(3),
5142 ]]))
5143 .build();
5144
5145 let mut stream = agent
5146 .stream_prompt("think before answering")
5147 .with_history(Vec::<Message>::new())
5148 .await;
5149
5150 let mut history_in_final = None;
5151 while let Some(item) = stream.next().await {
5152 match item {
5153 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5154 history_in_final = res.history().map(|h| h.to_vec());
5155 break;
5156 }
5157 Ok(_) => {}
5158 Err(err) => panic!("unexpected streaming error: {err:?}"),
5159 }
5160 }
5161
5162 let final_history = history_in_final
5163 .expect("FinalResponse.history should be populated when with_history is used");
5164 assert_eq!(
5165 final_history.len(),
5166 2,
5167 "user prompt + one assistant response in final history: {final_history:?}"
5168 );
5169
5170 assert!(matches!(
5171 final_history.first(),
5172 Some(Message::User { content })
5173 if matches!(
5174 content.first(),
5175 UserContent::Text(text) if text.text == "think before answering"
5176 )
5177 ));
5178
5179 let assistant_messages = final_history
5180 .iter()
5181 .filter_map(|message| match message {
5182 Message::Assistant { content, .. } => Some(content),
5183 _ => None,
5184 })
5185 .collect::<Vec<_>>();
5186 assert_eq!(
5187 assistant_messages.len(),
5188 1,
5189 "reasoning turn should produce exactly one assistant history message: {final_history:?}"
5190 );
5191 let assistant_content = assistant_messages
5192 .first()
5193 .expect("expected assistant history message");
5194 assert!(assistant_content.iter().any(|item| matches!(
5195 item,
5196 AssistantContent::Text(text) if text.text == "final answer"
5197 )));
5198 assert!(assistant_content.iter().any(|item| matches!(
5199 item,
5200 AssistantContent::Reasoning(reasoning)
5201 if reasoning.id.as_deref() == Some("rs_1")
5202 && reasoning.content.iter().any(|content| matches!(
5203 content,
5204 ReasoningContent::Text { text, .. } if text == "reasoned step"
5205 ))
5206 )));
5207 let reasoning_index = assistant_content
5208 .iter()
5209 .position(|item| matches!(item, AssistantContent::Reasoning(_)))
5210 .expect("assistant history should contain reasoning");
5211 let text_index = assistant_content
5212 .iter()
5213 .position(|item| matches!(item, AssistantContent::Text(_)))
5214 .expect("assistant history should contain text");
5215 assert!(
5216 reasoning_index < text_index,
5217 "assistant reasoning must be stored before assistant text: {assistant_content:?}"
5218 );
5219 }
5220
5221 #[tokio::test]
5222 async fn streaming_with_history_overrides_memory() {
5223 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5224
5225 let memory = InMemoryConversationMemory::new();
5226 memory
5227 .append("t1", vec![Message::user("from-memory")])
5228 .await
5229 .unwrap();
5230
5231 let agent = AgentBuilder::new(streaming_text_then_final_model())
5232 .memory(memory.clone())
5233 .build();
5234
5235 let mut stream = agent
5236 .stream_prompt("hi")
5237 .conversation("t1")
5238 .with_history(vec![Message::user("from-caller")])
5239 .await;
5240
5241 while let Some(item) = stream.next().await {
5242 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5243 break;
5244 }
5245 }
5246
5247 let stored = memory.load("t1").await.unwrap();
5248 assert_eq!(
5249 stored.len(),
5250 1,
5251 "with_history bypasses memory; only the pre-seeded entry remains: {stored:?}"
5252 );
5253 }
5254
5255 #[tokio::test]
5256 async fn streaming_without_memory_disables_for_request() {
5257 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5258
5259 let memory = InMemoryConversationMemory::new();
5260 let agent = AgentBuilder::new(streaming_text_then_final_model())
5261 .memory(memory.clone())
5262 .conversation_id("default")
5263 .build();
5264
5265 let mut stream = agent.stream_prompt("hi").without_memory().await;
5266
5267 while let Some(item) = stream.next().await {
5268 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5269 break;
5270 }
5271 }
5272
5273 let stored = memory.load("default").await.unwrap();
5274 assert!(stored.is_empty(), "without_memory disables save");
5275 }
5276
5277 #[tokio::test]
5278 async fn streaming_load_error_yields_memory_error() {
5279 let agent = AgentBuilder::new(streaming_text_then_final_model())
5280 .memory(FailingMemory::default())
5281 .build();
5282
5283 let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5284
5285 let first = stream.next().await.expect("at least one item");
5286 match first {
5287 Err(err) => {
5288 let msg = format!("{err:?}");
5289 assert!(
5290 msg.contains("Memory") || msg.contains("memory") || msg.contains("load boom"),
5291 "expected memory error, got: {msg}"
5292 );
5293 }
5294 Ok(other) => panic!("expected memory error, got {other:?}"),
5295 }
5296 }
5297
5298 #[tokio::test]
5299 async fn streaming_with_filter_shapes_loaded_history() {
5300 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5301
5302 let memory = InMemoryConversationMemory::new()
5303 .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
5304 memory
5305 .append(
5306 "t1",
5307 vec![
5308 Message::user("1"),
5309 Message::assistant("2"),
5310 Message::user("3"),
5311 Message::assistant("4"),
5312 ],
5313 )
5314 .await
5315 .unwrap();
5316
5317 let model = MockCompletionModel::from_stream_turns([[
5318 MockStreamEvent::text("ok"),
5319 MockStreamEvent::final_response_with_total_tokens(1),
5320 ]]);
5321 let recorded = model.clone();
5322 let agent = AgentBuilder::new(model).memory(memory).build();
5323
5324 let mut stream = agent.stream_prompt("ping").conversation("t1").await;
5325 while let Some(item) = stream.next().await {
5326 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5327 break;
5328 }
5329 }
5330
5331 let received = recorded.requests()[0]
5332 .chat_history
5333 .iter()
5334 .cloned()
5335 .collect::<Vec<_>>();
5336 assert_eq!(
5337 received.len(),
5338 3,
5339 "window-truncated history (2) + current prompt: {received:?}"
5340 );
5341 }
5342
5343 #[tokio::test]
5344 async fn streaming_append_error_does_not_suppress_final_response() {
5345 let agent = AgentBuilder::new(streaming_text_then_final_model())
5346 .memory(AppendFailingMemory::default())
5347 .build();
5348
5349 let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5350
5351 let mut saw_final = false;
5352 while let Some(item) = stream.next().await {
5353 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5354 saw_final = true;
5355 break;
5356 }
5357 }
5358 assert!(
5359 saw_final,
5360 "FinalResponse must be yielded even when memory.append fails"
5361 );
5362 }
5363}