1use crate::{
2 OneOrMany,
3 agent::completion::{DynamicContextStore, build_prepared_completion_request},
4 agent::prompt_request::{
5 HookAction, InvalidToolCallResolution, TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER,
6 hooks::PromptHook, resolve_invalid_tool_call, validate_tool_call_name,
7 },
8 completion::{Document, GetTokenUsage},
9 json_utils,
10 memory::ConversationMemory,
11 message::{
12 AssistantContent, ToolCall, ToolChoice, ToolFunction, ToolResult, ToolResultContent,
13 UserContent,
14 },
15 streaming::{StreamedAssistantContent, StreamedUserContent, ToolCallDeltaContent},
16 tool::server::ToolServerHandle,
17 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
18};
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{collections::HashMap, pin::Pin, sync::Arc};
22use tracing::info_span;
23use tracing_futures::Instrument;
24
25use super::{CompletionCall, ToolCallHookAction, reported_usage};
26use crate::{
27 agent::Agent,
28 completion::{CompletionError, CompletionModel, PromptError},
29 message::{Message, Text},
30 tool::ToolSetError,
31};
32
33#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
34pub type StreamingResult<R> =
35 Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;
36
37#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
38pub type StreamingResult<R> =
39 Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
40
/// One item yielded by a multi-turn streaming prompt.
///
/// Serialized with an internal `"type"` tag whose value is the camelCase
/// variant name (for example `"streamAssistantItem"`).
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A chunk of assistant output streamed from the model (e.g. a text delta).
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// User-side content produced by the loop itself, such as a tool result.
    StreamUserItem(StreamedUserContent),
    /// Record of a finished completion request: its index and reported usage.
    CompletionCall(CompletionCall),
    /// The final response, carrying usage aggregated across completion calls.
    FinalResponse(FinalResponse),
}
67
/// Final result of a multi-turn streaming prompt.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// Assistant content the response was built from.
    content: OneOrMany<AssistantContent>,
    /// Concatenation of every text item in `content` (computed by `FinalResponse::new`).
    response: String,
    /// Token usage aggregated across the run's completion calls.
    aggregated_usage: crate::completion::Usage,
    /// One record per completion call; omitted from serialized output when empty
    /// and defaulted to empty when missing on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    completion_calls: Vec<CompletionCall>,
    /// Conversation history, if one was attached; omitted from serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}
83
84impl FinalResponse {
85 pub fn empty() -> Self {
86 Self::new(
87 OneOrMany::one(AssistantContent::text("")),
88 crate::completion::Usage::new(),
89 None,
90 )
91 }
92
93 pub fn new(
94 content: OneOrMany<AssistantContent>,
95 aggregated_usage: crate::completion::Usage,
96 history: Option<Vec<Message>>,
97 ) -> Self {
98 let response = assistant_text_from_choice(&content);
99 Self {
100 content,
101 response,
102 aggregated_usage,
103 completion_calls: Vec::new(),
104 history,
105 }
106 }
107
108 pub fn response(&self) -> &str {
110 &self.response
111 }
112
113 pub fn content(&self) -> &OneOrMany<AssistantContent> {
115 &self.content
116 }
117
118 pub fn assistant_content(&self) -> &OneOrMany<AssistantContent> {
120 &self.content
121 }
122
123 pub fn usage(&self) -> crate::completion::Usage {
124 self.aggregated_usage
125 }
126
127 pub fn completion_calls(&self) -> &[CompletionCall] {
134 &self.completion_calls
135 }
136
137 pub fn history(&self) -> Option<&[Message]> {
138 self.history.as_deref()
139 }
140}
141
142impl<R> MultiTurnStreamItem<R> {
143 pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
144 Self::StreamAssistantItem(item)
145 }
146
147 pub fn final_response(
148 content: OneOrMany<AssistantContent>,
149 aggregated_usage: crate::completion::Usage,
150 ) -> Self {
151 Self::FinalResponse(FinalResponse::new(content, aggregated_usage, None))
152 }
153
154 pub fn final_response_with_history(
155 content: OneOrMany<AssistantContent>,
156 aggregated_usage: crate::completion::Usage,
157 history: Option<Vec<Message>>,
158 ) -> Self {
159 Self::FinalResponse(FinalResponse::new(content, aggregated_usage, history))
160 }
161
162 pub(crate) fn final_response_with_completion_calls(
163 content: OneOrMany<AssistantContent>,
164 aggregated_usage: crate::completion::Usage,
165 completion_calls: Vec<CompletionCall>,
166 history: Option<Vec<Message>>,
167 ) -> Self {
168 let mut response = FinalResponse::new(content, aggregated_usage, history);
169 response.completion_calls = completion_calls;
170 Self::FinalResponse(response)
171 }
172}
173
174fn merge_reasoning_blocks(
175 accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
176 incoming: &crate::message::Reasoning,
177) {
178 let ids_match = |existing: &crate::message::Reasoning| {
179 matches!(
180 (&existing.id, &incoming.id),
181 (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
182 )
183 };
184
185 if let Some(existing) = accumulated_reasoning
186 .iter_mut()
187 .rev()
188 .find(|existing| ids_match(existing))
189 {
190 existing.content.extend(incoming.content.clone());
191 } else {
192 accumulated_reasoning.push(incoming.clone());
193 }
194}
195
196fn flush_pending_reasoning_delta(
197 accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
198 pending_reasoning_delta_text: &mut String,
199 pending_reasoning_delta_id: &mut Option<String>,
200) {
201 if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
202 let mut assembled = crate::message::Reasoning::new(&*pending_reasoning_delta_text);
203 if let Some(id) = pending_reasoning_delta_id.take() {
204 assembled = assembled.with_id(id);
205 }
206 accumulated_reasoning.push(assembled);
207 pending_reasoning_delta_text.clear();
208 }
209}
210
211fn build_full_history(
213 chat_history: Option<&[Message]>,
214 new_messages: Vec<Message>,
215) -> Vec<Message> {
216 let input = chat_history.unwrap_or(&[]);
217 input.iter().cloned().chain(new_messages).collect()
218}
219
/// Inputs for `build_tool_call_validation_history`.
///
/// Together these describe the conversation as it stood when a streamed tool
/// call was being validated. The rebuilt history is used as diagnostic context
/// for invalid-tool-call handling.
struct ToolCallValidationHistory<'a> {
    /// History preceding this run (explicit, or loaded from memory).
    chat_history: Option<&'a [Message]>,
    /// Messages added during this run, starting with the prompt.
    new_messages: &'a [Message],
    /// Id for the assistant message being reconstructed, if known.
    assistant_message_id: &'a Option<String>,
    /// Completed assistant content for the turn. When present and not the
    /// "empty" placeholder, it is used verbatim and the partial fields below
    /// are ignored.
    final_turn_content: Option<&'a OneOrMany<AssistantContent>>,
    /// Text streamed so far in this turn, if any text arrived.
    text_delta_response: Option<&'a str>,
    /// Complete reasoning blocks received this turn.
    accumulated_reasoning: &'a [crate::message::Reasoning],
    /// Reasoning delta text not yet assembled into a block. It is only used
    /// when `accumulated_reasoning` is empty.
    pending_reasoning_delta_text: &'a str,
    /// Id attached to the block assembled from `pending_reasoning_delta_text`.
    pending_reasoning_delta_id: &'a Option<String>,
    /// Tool calls already accepted this turn, each paired with its internal call id.
    pending_tool_calls: &'a [(ToolCall, String)],
    /// The tool call currently under validation, appended after `pending_tool_calls`.
    current_tool_call: Option<ToolCall>,
}
232
233fn build_tool_call_validation_history(input: ToolCallValidationHistory<'_>) -> Vec<Message> {
234 let mut messages = input.new_messages.to_vec();
235
236 if let Some(final_turn_content) = input.final_turn_content
237 && !is_empty_assistant_choice(final_turn_content)
238 {
239 messages.push(Message::Assistant {
240 id: input.assistant_message_id.clone(),
241 content: final_turn_content.clone(),
242 });
243 return build_full_history(input.chat_history, messages);
244 }
245
246 let mut content_items = Vec::new();
247 if let Some(text) = input.text_delta_response
248 && !text.is_empty()
249 {
250 content_items.push(AssistantContent::text(text.to_string()));
251 }
252 content_items.extend(
253 input
254 .accumulated_reasoning
255 .iter()
256 .cloned()
257 .map(AssistantContent::Reasoning),
258 );
259 if input.accumulated_reasoning.is_empty() && !input.pending_reasoning_delta_text.is_empty() {
260 let mut reasoning = crate::message::Reasoning::new(input.pending_reasoning_delta_text);
261 if let Some(id) = input.pending_reasoning_delta_id.clone() {
262 reasoning = reasoning.with_id(id);
263 }
264 content_items.push(AssistantContent::Reasoning(reasoning));
265 }
266 content_items.extend(
267 input
268 .pending_tool_calls
269 .iter()
270 .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
271 );
272 if let Some(tool_call) = input.current_tool_call {
273 content_items.push(AssistantContent::ToolCall(tool_call));
274 }
275
276 if let Some(content) = OneOrMany::from_iter_optional(content_items) {
277 messages.push(Message::Assistant {
278 id: input.assistant_message_id.clone(),
279 content,
280 });
281 }
282
283 build_full_history(input.chat_history, messages)
284}
285
286async fn drain_recovered_stream_usage<R>(
287 stream: &mut crate::streaming::StreamingCompletionResponse<R>,
288 tool_call_delta_states: &HashMap<(String, String), ToolCallDeltaState>,
289 current_call_usage: &mut Option<crate::completion::Usage>,
290 aggregated_usage: &mut crate::completion::Usage,
291) -> Result<(), StreamingError>
292where
293 R: Clone + Unpin + GetTokenUsage,
294{
295 if let Some(err) = pending_tool_call_delta_error(tool_call_delta_states) {
296 return Err(err.into());
297 }
298
299 while let Some(content) = stream.next().await {
300 match content {
301 Ok(StreamedAssistantContent::Final(final_resp)) => {
302 if let Some(usage) = final_resp.token_usage() {
303 *current_call_usage = reported_usage(usage);
304 }
305 if let Some(usage) = *current_call_usage {
306 *aggregated_usage += usage;
307 }
308 return Ok(());
309 }
310 Ok(_) => {}
311 Err(err) => return Err(err.into()),
312 }
313 }
314
315 Ok(())
316}
317
318fn record_completion_call_if_needed(
319 completion_calls: &mut Vec<CompletionCall>,
320 completion_call_emitted: &mut bool,
321 call_index: usize,
322 current_call_usage: Option<crate::completion::Usage>,
323 chat_stream_span: &tracing::Span,
324) -> Option<CompletionCall> {
325 if *completion_call_emitted {
326 return None;
327 }
328
329 if let Some(usage) = current_call_usage {
330 record_usage_on_span(chat_stream_span, usage);
331 }
332
333 let completion_call = CompletionCall::new(call_index, current_call_usage);
334 completion_calls.push(completion_call);
335 *completion_call_emitted = true;
336 Some(completion_call)
337}
338
339fn record_usage_on_span(span: &tracing::Span, usage: crate::completion::Usage) {
340 span.record("gen_ai.usage.input_tokens", usage.input_tokens);
341 span.record("gen_ai.usage.output_tokens", usage.output_tokens);
342 span.record(
343 "gen_ai.usage.cache_read.input_tokens",
344 usage.cached_input_tokens,
345 );
346 span.record(
347 "gen_ai.usage.cache_creation.input_tokens",
348 usage.cache_creation_input_tokens,
349 );
350 span.record(
351 "gen_ai.usage.tool_use_prompt_tokens",
352 usage.tool_use_prompt_tokens,
353 );
354 span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);
355}
356
357fn build_history_for_request(
359 chat_history: Option<&[Message]>,
360 new_messages: &[Message],
361) -> Vec<Message> {
362 let input = chat_history.unwrap_or(&[]);
363 input.iter().chain(new_messages.iter()).cloned().collect()
364}
365
366async fn cancelled_prompt_error(
367 chat_history: Option<&[Message]>,
368 new_messages: Vec<Message>,
369 reason: String,
370) -> StreamingError {
371 StreamingError::Prompt(
372 PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
373 .into(),
374 )
375}
376
377fn tool_result_user_content(
378 id: String,
379 call_id: Option<String>,
380 tool_result: String,
381) -> UserContent {
382 let content = ToolResultContent::from_tool_output(tool_result);
383 match call_id {
384 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
385 None => UserContent::tool_result(id, content),
386 }
387}
388
389fn invalid_streaming_tool_retry_messages(
390 assistant_message_id: &Option<String>,
391 text_delta_response: Option<&str>,
392 accumulated_reasoning: &[crate::message::Reasoning],
393 pending_tool_calls: &[(ToolCall, String)],
394 invalid_tool_call: ToolCall,
395 feedback: String,
396) -> Option<(Message, Message)> {
397 let mut assistant_content = Vec::new();
398 if let Some(text) = text_delta_response
399 && !text.is_empty()
400 {
401 assistant_content.push(AssistantContent::text(text.to_string()));
402 }
403 assistant_content.extend(
404 accumulated_reasoning
405 .iter()
406 .cloned()
407 .map(AssistantContent::Reasoning),
408 );
409 assistant_content.extend(
410 pending_tool_calls
411 .iter()
412 .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
413 );
414 assistant_content.push(AssistantContent::ToolCall(invalid_tool_call.clone()));
415
416 let assistant_content = OneOrMany::from_iter_optional(assistant_content)?;
417 let assistant_message = Message::Assistant {
418 id: assistant_message_id.clone(),
419 content: assistant_content,
420 };
421
422 let mut retry_results = pending_tool_calls
423 .iter()
424 .map(|(tool_call, _)| {
425 tool_result_user_content(
426 tool_call.id.clone(),
427 tool_call.call_id.clone(),
428 TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
429 )
430 })
431 .collect::<Vec<_>>();
432 retry_results.push(tool_result_user_content(
433 invalid_tool_call.id,
434 invalid_tool_call.call_id,
435 feedback,
436 ));
437
438 let user_message = Message::User {
439 content: OneOrMany::from_iter_optional(retry_results)?,
440 };
441
442 Some((assistant_message, user_message))
443}
444
445fn invalid_streaming_name_delta_retry_messages(
446 assistant_message_id: &Option<String>,
447 text_delta_response: Option<&str>,
448 accumulated_reasoning: &[crate::message::Reasoning],
449 pending_tool_calls: &[(ToolCall, String)],
450 invalid_tool_call: ToolCall,
451 feedback: String,
452) -> Option<(Message, Message)> {
453 let mut assistant_content = Vec::new();
454 if let Some(text) = text_delta_response
455 && !text.is_empty()
456 {
457 assistant_content.push(AssistantContent::text(text.to_string()));
458 }
459 assistant_content.extend(
460 accumulated_reasoning
461 .iter()
462 .cloned()
463 .map(AssistantContent::Reasoning),
464 );
465 assistant_content.extend(
466 pending_tool_calls
467 .iter()
468 .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
469 );
470 assistant_content.push(AssistantContent::ToolCall(invalid_tool_call.clone()));
471
472 let assistant_message = Message::Assistant {
473 id: assistant_message_id.clone(),
474 content: OneOrMany::from_iter_optional(assistant_content)?,
475 };
476 let mut retry_results = pending_tool_calls
477 .iter()
478 .map(|(tool_call, _)| {
479 tool_result_user_content(
480 tool_call.id.clone(),
481 tool_call.call_id.clone(),
482 TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
483 )
484 })
485 .collect::<Vec<_>>();
486 retry_results.push(tool_result_user_content(
487 invalid_tool_call.id,
488 invalid_tool_call.call_id,
489 feedback,
490 ));
491 let user_message = Message::User {
492 content: OneOrMany::from_iter_optional(retry_results)?,
493 };
494
495 Some((assistant_message, user_message))
496}
497
498fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
499 choice
500 .iter()
501 .filter_map(|content| match content {
502 AssistantContent::Text(text) => Some(text.text.as_str()),
503 _ => None,
504 })
505 .collect()
506}
507
508fn assistant_text_items_from_choice(choice: &OneOrMany<AssistantContent>) -> Vec<AssistantContent> {
509 choice
510 .iter()
511 .filter_map(|content| match content {
512 AssistantContent::Text(text) => (!text.text.is_empty()
513 || text.additional_params.is_some())
514 .then(|| AssistantContent::Text(text.clone())),
515 _ => None,
516 })
517 .collect()
518}
519
520fn is_empty_assistant_choice(choice: &OneOrMany<AssistantContent>) -> bool {
521 choice.len() == 1
522 && matches!(
523 choice.first(),
524 AssistantContent::Text(text)
525 if text.text.is_empty() && text.additional_params.is_none()
526 )
527}
528
/// Bookkeeping for one streamed tool call's `ToolCallDelta` events, keyed
/// elsewhere by `(id, internal_call_id)`.
#[derive(Default)]
struct ToolCallDeltaState {
    /// Whether a name delta for this call has passed validation.
    name_validated: bool,
    /// Argument delta fragments held back until the name is validated; they
    /// are joined with no separator when reconstructed.
    buffered_arguments: Vec<String>,
}
534
535fn pending_tool_call_delta_error(
536 states: &HashMap<(String, String), ToolCallDeltaState>,
537) -> Option<CompletionError> {
538 states
539 .iter()
540 .find(|(_, state)| !state.name_validated && !state.buffered_arguments.is_empty())
541 .map(|((id, internal_call_id), state)| {
542 CompletionError::ResponseError(format!(
543 "streamed tool call arguments received before a validated tool name for id `{id}` and internal_call_id `{internal_call_id}` ({} buffered argument delta(s))",
544 state.buffered_arguments.len()
545 ))
546 })
547}
548
/// Errors yielded by a multi-turn streaming prompt.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The completion request or the response stream failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// The prompt loop failed or was cancelled (e.g. by a hook, or after
    /// exhausting invalid-tool-call retries). Boxed, presumably to keep this
    /// enum small.
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool-set operation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
558
559impl From<crate::memory::MemoryError> for StreamingError {
563 fn from(err: crate::memory::MemoryError) -> Self {
564 Self::Completion(CompletionError::RequestError(Box::new(err)))
565 }
566}
567
/// Agent name used in tracing spans when the agent has no name configured.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
569
/// Builder for a streaming, multi-turn prompt against an agent's model.
///
/// Settings are copied from the agent on construction and can be adjusted with
/// the builder methods before the request is sent.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Prompt message that starts the run.
    prompt: Message,
    /// Explicit chat history. When set, memory is not consulted; when unset,
    /// history is loaded from `memory` under `conversation_id` if both exist.
    chat_history: Option<Vec<Message>>,
    /// Turn limit for the multi-turn loop.
    max_turns: usize,

    /// Completion model used for every request in the run.
    model: Arc<M>,
    /// Agent name for tracing; `UNKNOWN_AGENT_NAME` is used when unset.
    agent_name: Option<String>,
    /// System preamble sent with each request.
    preamble: Option<String>,
    /// Documents always included as context.
    static_context: Vec<Document>,
    /// Sampling temperature override.
    temperature: Option<f64>,
    /// Maximum tokens per completion.
    max_tokens: Option<u64>,
    /// Provider-specific request parameters.
    additional_params: Option<serde_json::Value>,
    /// Handle to the tool server used to resolve and run tools.
    tool_server_handle: ToolServerHandle,
    /// Dynamic (retrieved) context store.
    dynamic_context: DynamicContextStore,
    /// Tool choice forwarded to the model.
    tool_choice: Option<ToolChoice>,
    /// Optional structured-output schema.
    output_schema: Option<schemars::Schema>,
    /// Optional hook notified during the run; it may terminate it.
    hook: Option<P>,
    /// Maximum number of invalid-tool-call retries allowed over the whole run.
    max_invalid_tool_call_retries: usize,
    /// Conversation memory used to load history when none is given explicitly.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Conversation id used as the memory key.
    conversation_id: Option<String>,
}
622
623impl<M, P> StreamingPromptRequest<M, P>
624where
625 M: CompletionModel + 'static,
626 <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
627 P: PromptHook<M>,
628{
629 pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
632 StreamingPromptRequest {
633 prompt: prompt.into(),
634 chat_history: None,
635 max_turns: agent.default_max_turns.unwrap_or_default(),
636 model: agent.model.clone(),
637 agent_name: agent.name.clone(),
638 preamble: agent.preamble.clone(),
639 static_context: agent.static_context.clone(),
640 temperature: agent.temperature,
641 max_tokens: agent.max_tokens,
642 additional_params: agent.additional_params.clone(),
643 tool_server_handle: agent.tool_server_handle.clone(),
644 dynamic_context: agent.dynamic_context.clone(),
645 tool_choice: agent.tool_choice.clone(),
646 output_schema: agent.output_schema.clone(),
647 hook: None,
648 max_invalid_tool_call_retries: 0,
649 memory: agent.memory.clone(),
650 conversation_id: agent.default_conversation_id.clone(),
651 }
652 }
653
654 pub fn from_agent<P2>(
656 agent: &Agent<M, P2>,
657 prompt: impl Into<Message>,
658 ) -> StreamingPromptRequest<M, P2>
659 where
660 P2: PromptHook<M>,
661 {
662 StreamingPromptRequest {
663 prompt: prompt.into(),
664 chat_history: None,
665 max_turns: agent.default_max_turns.unwrap_or_default(),
666 model: agent.model.clone(),
667 agent_name: agent.name.clone(),
668 preamble: agent.preamble.clone(),
669 static_context: agent.static_context.clone(),
670 temperature: agent.temperature,
671 max_tokens: agent.max_tokens,
672 additional_params: agent.additional_params.clone(),
673 tool_server_handle: agent.tool_server_handle.clone(),
674 dynamic_context: agent.dynamic_context.clone(),
675 tool_choice: agent.tool_choice.clone(),
676 output_schema: agent.output_schema.clone(),
677 hook: agent.hook.clone(),
678 max_invalid_tool_call_retries: 0,
679 memory: agent.memory.clone(),
680 conversation_id: agent.default_conversation_id.clone(),
681 }
682 }
683
684 fn agent_name(&self) -> &str {
685 self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
686 }
687
688 pub fn multi_turn(mut self, turns: usize) -> Self {
691 self.max_turns = turns;
692 self
693 }
694
695 pub fn with_history<H, T>(mut self, history: H) -> Self
708 where
709 H: IntoIterator<Item = T>,
710 T: Into<Message>,
711 {
712 self.chat_history = Some(history.into_iter().map(Into::into).collect());
713 self
714 }
715
716 pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
719 where
720 P2: PromptHook<M>,
721 {
722 StreamingPromptRequest {
723 prompt: self.prompt,
724 chat_history: self.chat_history,
725 max_turns: self.max_turns,
726 model: self.model,
727 agent_name: self.agent_name,
728 preamble: self.preamble,
729 static_context: self.static_context,
730 temperature: self.temperature,
731 max_tokens: self.max_tokens,
732 additional_params: self.additional_params,
733 tool_server_handle: self.tool_server_handle,
734 dynamic_context: self.dynamic_context,
735 tool_choice: self.tool_choice,
736 output_schema: self.output_schema,
737 hook: Some(hook),
738 max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
739 memory: self.memory,
740 conversation_id: self.conversation_id,
741 }
742 }
743
744 pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
748 self.max_invalid_tool_call_retries = retries;
749 self
750 }
751
752 pub fn conversation(mut self, id: impl Into<String>) -> Self {
757 self.conversation_id = Some(id.into());
758 self
759 }
760
761 pub fn without_memory(mut self) -> Self {
765 self.memory = None;
766 self.conversation_id = None;
767 self
768 }
769
770 async fn send(self) -> StreamingResult<M::StreamingResponse> {
771 let (agent_span, created_agent_span) = if tracing::Span::current().is_disabled() {
772 (
773 info_span!(
774 "invoke_agent",
775 gen_ai.operation.name = "invoke_agent",
776 gen_ai.agent.name = self.agent_name(),
777 gen_ai.system_instructions = self.preamble,
778 gen_ai.prompt = tracing::field::Empty,
779 gen_ai.completion = tracing::field::Empty,
780 gen_ai.usage.input_tokens = tracing::field::Empty,
781 gen_ai.usage.output_tokens = tracing::field::Empty,
782 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
783 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
784 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
785 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
786 ),
787 true,
788 )
789 } else {
790 (tracing::Span::current(), false)
791 };
792
793 let prompt = self.prompt;
794 if let Some(text) = prompt.rag_text() {
795 agent_span.record("gen_ai.prompt", text);
796 }
797
798 let model = self.model.clone();
800 let preamble = self.preamble.clone();
801 let static_context = self.static_context.clone();
802 let temperature = self.temperature;
803 let max_tokens = self.max_tokens;
804 let additional_params = self.additional_params.clone();
805 let tool_server_handle = self.tool_server_handle.clone();
806 let dynamic_context = self.dynamic_context.clone();
807 let tool_choice = self.tool_choice.clone();
808 let agent_name = self.agent_name.clone();
809 let (chat_history, memory_handle) = match self.chat_history {
814 Some(history) => (Some(history), None),
815 None => match (self.memory, self.conversation_id) {
816 (Some(memory), Some(id)) => match memory.load(&id).await {
817 Ok(loaded) => (Some(loaded), Some((memory, id))),
818 Err(err) => {
819 let stream = async_stream::stream! {
820 yield Err(StreamingError::from(err));
821 };
822 return Box::pin(stream);
823 }
824 },
825 _ => (None, None),
826 },
827 };
828 let has_history = chat_history.is_some();
829 let mut new_messages: Vec<Message> = vec![prompt.clone()];
830
831 let mut current_max_turns = 0;
832 let mut last_prompt_error = String::new();
833
834 let mut text_delta_response = String::new();
835 let mut saw_text_this_turn = false;
836 let mut max_turns_reached = false;
837 let output_schema = self.output_schema;
838
839 let mut aggregated_usage = crate::completion::Usage::new();
840 let mut completion_calls = Vec::new();
841 let mut completion_call_index = 0;
842 let mut invalid_tool_call_retries = 0;
843
844 let stream = async_stream::stream! {
851 'outer: loop {
852 let Some((current_prompt_ref, previous_messages)) = new_messages.split_last() else {
853 yield Err(cancelled_prompt_error(
854 chat_history.as_deref(),
855 new_messages.clone(),
856 "streaming loop lost its pending prompt".to_string(),
857 ).await);
858 break 'outer;
859 };
860 let current_prompt = current_prompt_ref.clone();
861
862 if current_max_turns > self.max_turns + 1 {
863 last_prompt_error = current_prompt.rag_text().unwrap_or_default();
864 max_turns_reached = true;
865 break;
866 }
867
868 current_max_turns += 1;
869
870 if self.max_turns > 1 {
871 tracing::info!(
872 "Current conversation Turns: {}/{}",
873 current_max_turns,
874 self.max_turns
875 );
876 }
877
878 let history_snapshot: Vec<Message> = build_history_for_request(
879 chat_history.as_deref(),
880 previous_messages,
881 );
882
883 if let Some(ref hook) = self.hook
884 && let HookAction::Terminate { reason } =
885 hook.on_completion_call(¤t_prompt, &history_snapshot).await
886 {
887 yield Err(
888 cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason)
889 .await,
890 );
891 break 'outer;
892 }
893
894 let chat_stream_span = info_span!(
895 target: "rig::agent_chat",
896 parent: tracing::Span::current(),
897 "chat_streaming",
898 gen_ai.operation.name = "chat",
899 gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
900 gen_ai.system_instructions = preamble,
901 gen_ai.provider.name = tracing::field::Empty,
902 gen_ai.request.model = tracing::field::Empty,
903 gen_ai.response.id = tracing::field::Empty,
904 gen_ai.response.model = tracing::field::Empty,
905 gen_ai.usage.output_tokens = tracing::field::Empty,
906 gen_ai.usage.input_tokens = tracing::field::Empty,
907 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
908 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
909 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
910 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
911 gen_ai.input.messages = tracing::field::Empty,
912 gen_ai.output.messages = tracing::field::Empty,
913 );
914
915 let prepared_request = build_prepared_completion_request(
916 &model,
917 current_prompt.clone(),
918 &history_snapshot,
919 preamble.as_deref(),
920 &static_context,
921 temperature,
922 max_tokens,
923 additional_params.as_ref(),
924 tool_choice.as_ref(),
925 &tool_server_handle,
926 &dynamic_context,
927 output_schema.as_ref(),
928 )
929 .await?;
930 let executable_tool_names = prepared_request.executable_tool_names.clone();
931 let allowed_tool_names = prepared_request.allowed_tool_names.clone();
932
933 let mut stream = prepared_request
934 .builder
935 .stream()
936 .instrument(chat_stream_span.clone())
937 .await?;
938
939 let call_index = completion_call_index;
940 completion_call_index += 1;
941 let mut current_call_usage = None;
942 let mut completion_call_emitted = false;
943 let mut pending_tool_calls: Vec<(ToolCall, String)> = vec![];
944 let mut tool_calls = vec![];
945 let mut tool_results = vec![];
946 let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
947 let mut pending_reasoning_delta_text = String::new();
950 let mut pending_reasoning_delta_id: Option<String> = None;
951 let mut tool_call_delta_states: HashMap<(String, String), ToolCallDeltaState> =
952 HashMap::new();
953 let mut saw_tool_call_this_turn = false;
954
955 while let Some(content) = stream.next().await {
956 match content {
957 Ok(StreamedAssistantContent::Text(text)) => {
958 if !saw_text_this_turn {
959 text_delta_response.clear();
960 saw_text_this_turn = true;
961 }
962 text_delta_response.push_str(&text.text);
963 if let Some(ref hook) = self.hook &&
964 let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &text_delta_response).await {
965 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
966 break 'outer;
967 }
968
969 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
970 },
971 Ok(StreamedAssistantContent::ToolCall { mut tool_call, internal_call_id }) => {
972 let diagnostic_history =
973 build_tool_call_validation_history(ToolCallValidationHistory {
974 chat_history: chat_history.as_deref(),
975 new_messages: &new_messages,
976 assistant_message_id: &stream.message_id,
977 final_turn_content: None,
978 text_delta_response: saw_text_this_turn
979 .then_some(text_delta_response.as_str()),
980 accumulated_reasoning: &accumulated_reasoning,
981 pending_reasoning_delta_text: &pending_reasoning_delta_text,
982 pending_reasoning_delta_id: &pending_reasoning_delta_id,
983 pending_tool_calls: &pending_tool_calls,
984 current_tool_call: Some(tool_call.clone()),
985 });
986
987 if !allowed_tool_names.contains(&tool_call.function.name) {
988 let args = json_utils::value_to_json_string(&tool_call.function.arguments);
989 let emitted_tool_name = tool_call.function.name.clone();
990 match resolve_invalid_tool_call::<M, P>(
991 self.hook.as_ref(),
992 &emitted_tool_name,
993 Some(tool_call.id.clone()),
994 Some(internal_call_id.clone()),
995 Some(args),
996 &executable_tool_names,
997 &allowed_tool_names,
998 self.tool_choice.as_ref(),
999 diagnostic_history.clone(),
1000 true,
1001 ).await {
1002 InvalidToolCallResolution::Fail(err) => {
1003 yield Err(Box::new(err).into());
1004 break 'outer;
1005 }
1006 InvalidToolCallResolution::Retry(feedback) => {
1007 if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
1008 yield Err(Box::new(PromptError::UnknownToolCall {
1009 tool_name: emitted_tool_name,
1010 available_tools: executable_tool_names.iter().cloned().collect(),
1011 allowed_tools: allowed_tool_names.iter().cloned().collect(),
1012 chat_history: Box::new(diagnostic_history),
1013 }).into());
1014 break 'outer;
1015 }
1016
1017 invalid_tool_call_retries += 1;
1018 flush_pending_reasoning_delta(
1019 &mut accumulated_reasoning,
1020 &mut pending_reasoning_delta_text,
1021 &mut pending_reasoning_delta_id,
1022 );
1023 let Some((assistant_message, user_message)) =
1024 invalid_streaming_tool_retry_messages(
1025 &stream.message_id,
1026 saw_text_this_turn.then_some(text_delta_response.as_str()),
1027 &accumulated_reasoning,
1028 &pending_tool_calls,
1029 tool_call,
1030 feedback,
1031 )
1032 else {
1033 yield Err(cancelled_prompt_error(
1034 chat_history.as_deref(),
1035 new_messages.clone(),
1036 "invalid tool call retry produced no retry messages".to_string(),
1037 ).await);
1038 break 'outer;
1039 };
1040 new_messages.push(assistant_message);
1041 new_messages.push(user_message);
1042 if let Err(err) = drain_recovered_stream_usage(
1043 &mut stream,
1044 &tool_call_delta_states,
1045 &mut current_call_usage,
1046 &mut aggregated_usage,
1047 )
1048 .await
1049 {
1050 yield Err(err);
1051 break 'outer;
1052 }
1053 if let Some(completion_call) = record_completion_call_if_needed(
1054 &mut completion_calls,
1055 &mut completion_call_emitted,
1056 call_index,
1057 current_call_usage,
1058 &chat_stream_span,
1059 ) {
1060 yield Ok(MultiTurnStreamItem::CompletionCall(
1061 completion_call,
1062 ));
1063 }
1064 text_delta_response.clear();
1065 saw_text_this_turn = false;
1066 continue 'outer;
1067 }
1068 InvalidToolCallResolution::Repair(repaired_name) => {
1069 tool_call.function.name = repaired_name;
1070 }
1071 InvalidToolCallResolution::Skip(reason) => {
1072 let skipped_tool_result = ToolResult {
1073 id: tool_call.id.clone(),
1074 call_id: tool_call.call_id.clone(),
1075 content: ToolResultContent::from_tool_output(
1076 reason.clone(),
1077 ),
1078 };
1079 flush_pending_reasoning_delta(
1080 &mut accumulated_reasoning,
1081 &mut pending_reasoning_delta_text,
1082 &mut pending_reasoning_delta_id,
1083 );
1084 let Some((assistant_message, user_message)) =
1085 invalid_streaming_tool_retry_messages(
1086 &stream.message_id,
1087 saw_text_this_turn
1088 .then_some(text_delta_response.as_str()),
1089 &accumulated_reasoning,
1090 &pending_tool_calls,
1091 tool_call,
1092 reason,
1093 )
1094 else {
1095 yield Err(cancelled_prompt_error(
1096 chat_history.as_deref(),
1097 new_messages.clone(),
1098 "invalid tool call skip produced no recovery messages".to_string(),
1099 ).await);
1100 break 'outer;
1101 };
1102 new_messages.push(assistant_message);
1103 new_messages.push(user_message);
1104 let tool_result = ToolResult {
1105 id: skipped_tool_result.id,
1106 call_id: skipped_tool_result.call_id,
1107 content: skipped_tool_result.content,
1108 };
1109 if let Err(err) = drain_recovered_stream_usage(
1110 &mut stream,
1111 &tool_call_delta_states,
1112 &mut current_call_usage,
1113 &mut aggregated_usage,
1114 )
1115 .await
1116 {
1117 yield Err(err);
1118 break 'outer;
1119 }
1120 if let Some(completion_call) = record_completion_call_if_needed(
1121 &mut completion_calls,
1122 &mut completion_call_emitted,
1123 call_index,
1124 current_call_usage,
1125 &chat_stream_span,
1126 ) {
1127 yield Ok(MultiTurnStreamItem::CompletionCall(
1128 completion_call,
1129 ));
1130 }
1131 yield Ok(MultiTurnStreamItem::StreamUserItem(
1132 StreamedUserContent::ToolResult {
1133 tool_result,
1134 internal_call_id,
1135 },
1136 ));
1137 text_delta_response.clear();
1138 saw_text_this_turn = false;
1139 continue 'outer;
1140 }
1141 }
1142 }
1143
1144 pending_tool_calls.push((tool_call, internal_call_id));
1145 },
1146 Ok(StreamedAssistantContent::ToolCallDelta {
1147 id,
1148 internal_call_id,
1149 content,
1150 }) => {
1151 let key = (id.clone(), internal_call_id.clone());
1152 let mut deltas_to_emit = Vec::new();
1153
1154 match content {
1155 ToolCallDeltaContent::Name(mut name) => {
1156 let buffered_args = tool_call_delta_states
1157 .get(&key)
1158 .map(|state| state.buffered_arguments.join(""))
1159 .unwrap_or_default();
1160 let diagnostic_args = if buffered_args.trim().is_empty() {
1161 serde_json::Value::Null
1162 } else {
1163 serde_json::from_str(&buffered_args)
1164 .unwrap_or(serde_json::Value::Null)
1165 };
1166 let diagnostic_tool_call = ToolCall::new(
1167 id.clone(),
1168 ToolFunction::new(name.clone(), diagnostic_args),
1169 );
1170 let diagnostic_history =
1171 build_tool_call_validation_history(ToolCallValidationHistory {
1172 chat_history: chat_history.as_deref(),
1173 new_messages: &new_messages,
1174 assistant_message_id: &stream.message_id,
1175 final_turn_content: None,
1176 text_delta_response: saw_text_this_turn
1177 .then_some(text_delta_response.as_str()),
1178 accumulated_reasoning: &accumulated_reasoning,
1179 pending_reasoning_delta_text: &pending_reasoning_delta_text,
1180 pending_reasoning_delta_id: &pending_reasoning_delta_id,
1181 pending_tool_calls: &pending_tool_calls,
1182 current_tool_call: Some(diagnostic_tool_call.clone()),
1183 });
1184
1185 if !allowed_tool_names.contains(&name) {
1186 let emitted_tool_name = name.clone();
1187 match resolve_invalid_tool_call::<M, P>(
1188 self.hook.as_ref(),
1189 &emitted_tool_name,
1190 Some(id.clone()),
1191 Some(internal_call_id.clone()),
1192 Some(buffered_args.clone()),
1193 &executable_tool_names,
1194 &allowed_tool_names,
1195 self.tool_choice.as_ref(),
1196 diagnostic_history.clone(),
1197 true,
1198 ).await {
1199 InvalidToolCallResolution::Fail(err) => {
1200 yield Err(Box::new(err).into());
1201 break 'outer;
1202 }
1203 InvalidToolCallResolution::Skip(reason) => {
1204 tool_call_delta_states.remove(&key);
1205 flush_pending_reasoning_delta(
1206 &mut accumulated_reasoning,
1207 &mut pending_reasoning_delta_text,
1208 &mut pending_reasoning_delta_id,
1209 );
1210 let Some((assistant_message, user_message)) =
1211 invalid_streaming_name_delta_retry_messages(
1212 &stream.message_id,
1213 saw_text_this_turn
1214 .then_some(text_delta_response.as_str()),
1215 &accumulated_reasoning,
1216 &pending_tool_calls,
1217 diagnostic_tool_call.clone(),
1218 reason.clone(),
1219 )
1220 else {
1221 yield Err(cancelled_prompt_error(
1222 chat_history.as_deref(),
1223 new_messages.clone(),
1224 "invalid tool call skip produced no recovery messages".to_string(),
1225 ).await);
1226 break 'outer;
1227 };
1228 new_messages.push(assistant_message);
1229 new_messages.push(user_message);
1230 let tool_result = ToolResult {
1231 id,
1232 call_id: None,
1233 content: ToolResultContent::from_tool_output(
1234 reason,
1235 ),
1236 };
1237 if let Err(err) = drain_recovered_stream_usage(
1238 &mut stream,
1239 &tool_call_delta_states,
1240 &mut current_call_usage,
1241 &mut aggregated_usage,
1242 )
1243 .await
1244 {
1245 yield Err(err);
1246 break 'outer;
1247 }
1248 if let Some(completion_call) = record_completion_call_if_needed(
1249 &mut completion_calls,
1250 &mut completion_call_emitted,
1251 call_index,
1252 current_call_usage,
1253 &chat_stream_span,
1254 ) {
1255 yield Ok(MultiTurnStreamItem::CompletionCall(
1256 completion_call,
1257 ));
1258 }
1259 yield Ok(MultiTurnStreamItem::StreamUserItem(
1260 StreamedUserContent::ToolResult {
1261 tool_result,
1262 internal_call_id,
1263 },
1264 ));
1265 text_delta_response.clear();
1266 saw_text_this_turn = false;
1267 continue 'outer;
1268 }
1269 InvalidToolCallResolution::Retry(feedback) => {
1270 tool_call_delta_states.remove(&key);
1271 if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
1272 yield Err(Box::new(PromptError::UnknownToolCall {
1273 tool_name: emitted_tool_name,
1274 available_tools: executable_tool_names.iter().cloned().collect(),
1275 allowed_tools: allowed_tool_names.iter().cloned().collect(),
1276 chat_history: Box::new(diagnostic_history),
1277 }).into());
1278 break 'outer;
1279 }
1280
1281 invalid_tool_call_retries += 1;
1282 flush_pending_reasoning_delta(
1283 &mut accumulated_reasoning,
1284 &mut pending_reasoning_delta_text,
1285 &mut pending_reasoning_delta_id,
1286 );
1287 let Some((assistant_message, user_message)) =
1288 invalid_streaming_name_delta_retry_messages(
1289 &stream.message_id,
1290 saw_text_this_turn
1291 .then_some(text_delta_response.as_str()),
1292 &accumulated_reasoning,
1293 &pending_tool_calls,
1294 diagnostic_tool_call.clone(),
1295 feedback,
1296 )
1297 else {
1298 yield Err(cancelled_prompt_error(
1299 chat_history.as_deref(),
1300 new_messages.clone(),
1301 "invalid tool call retry produced no retry messages".to_string(),
1302 ).await);
1303 break 'outer;
1304 };
1305 new_messages.push(assistant_message);
1306 new_messages.push(user_message);
1307 if let Err(err) = drain_recovered_stream_usage(
1308 &mut stream,
1309 &tool_call_delta_states,
1310 &mut current_call_usage,
1311 &mut aggregated_usage,
1312 )
1313 .await
1314 {
1315 yield Err(err);
1316 break 'outer;
1317 }
1318 if let Some(completion_call) = record_completion_call_if_needed(
1319 &mut completion_calls,
1320 &mut completion_call_emitted,
1321 call_index,
1322 current_call_usage,
1323 &chat_stream_span,
1324 ) {
1325 yield Ok(MultiTurnStreamItem::CompletionCall(
1326 completion_call,
1327 ));
1328 }
1329 text_delta_response.clear();
1330 saw_text_this_turn = false;
1331 continue 'outer;
1332 }
1333 InvalidToolCallResolution::Repair(repaired_name) => {
1334 name = repaired_name;
1335 }
1336 }
1337 }
1338
1339 let state =
1340 tool_call_delta_states.entry(key.clone()).or_default();
1341 state.name_validated = true;
1342 let buffered_arguments =
1343 std::mem::take(&mut state.buffered_arguments);
1344
1345 deltas_to_emit.push(ToolCallDeltaContent::Name(name));
1346 deltas_to_emit.extend(
1347 buffered_arguments
1348 .into_iter()
1349 .map(ToolCallDeltaContent::Delta),
1350 );
1351 }
1352 ToolCallDeltaContent::Delta(arguments) => {
1353 let state =
1354 tool_call_delta_states.entry(key.clone()).or_default();
1355 if state.name_validated {
1356 deltas_to_emit.push(ToolCallDeltaContent::Delta(arguments));
1357 } else {
1358 state.buffered_arguments.push(arguments);
1359 }
1360 }
1361 }
1362
1363 for content in deltas_to_emit {
1364 if let Some(ref hook) = self.hook {
1365 let (name, delta) = match &content {
1366 ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
1367 ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
1368 };
1369
1370 if let HookAction::Terminate { reason } = hook
1371 .on_tool_call_delta(
1372 &id,
1373 &internal_call_id,
1374 name,
1375 delta,
1376 )
1377 .await
1378 {
1379 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1380 break 'outer;
1381 }
1382 }
1383
1384 yield Ok(MultiTurnStreamItem::StreamAssistantItem(
1385 StreamedAssistantContent::ToolCallDelta {
1386 id: id.clone(),
1387 internal_call_id: internal_call_id.clone(),
1388 content,
1389 },
1390 ));
1391 }
1392 }
1393 Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
1394 merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
1398 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
1399 },
1400 Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
1401 pending_reasoning_delta_text.push_str(&reasoning);
1405 if pending_reasoning_delta_id.is_none() {
1406 pending_reasoning_delta_id = id.clone();
1407 }
1408 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
1409 },
1410 Ok(StreamedAssistantContent::Final(final_resp)) => {
1411 if let Some(err) =
1412 pending_tool_call_delta_error(&tool_call_delta_states)
1413 {
1414 yield Err(err.into());
1415 break 'outer;
1416 }
1417
1418 if let Some(usage) = final_resp.token_usage() {
1419 current_call_usage = reported_usage(usage);
1420 }
1421 if let Some(usage) = current_call_usage {
1422 aggregated_usage += usage;
1423 }
1424 if let Some(completion_call) = record_completion_call_if_needed(
1425 &mut completion_calls,
1426 &mut completion_call_emitted,
1427 call_index,
1428 current_call_usage,
1429 &chat_stream_span,
1430 ) {
1431 yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
1432 }
1433
1434 if saw_text_this_turn {
1435 if let Some(ref hook) = self.hook &&
1436 let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(¤t_prompt, &final_resp).await {
1437 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1438 break 'outer;
1439 }
1440
1441 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
1442 saw_text_this_turn = false;
1443 }
1444 }
1445 Err(e) => {
1446 yield Err(e.into());
1447 break 'outer;
1448 }
1449 }
1450 }
1451
1452 if let Some(err) = pending_tool_call_delta_error(&tool_call_delta_states) {
1453 yield Err(err.into());
1454 break 'outer;
1455 }
1456
1457 if let Some(completion_call) = record_completion_call_if_needed(
1458 &mut completion_calls,
1459 &mut completion_call_emitted,
1460 call_index,
1461 current_call_usage,
1462 &chat_stream_span,
1463 ) {
1464 yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
1465 }
1466
1467 flush_pending_reasoning_delta(
1471 &mut accumulated_reasoning,
1472 &mut pending_reasoning_delta_text,
1473 &mut pending_reasoning_delta_id,
1474 );
1475
1476 let final_turn_content = stream.choice.clone();
1477 let turn_text_response = assistant_text_from_choice(&final_turn_content);
1478 tracing::Span::current().record("gen_ai.completion", &turn_text_response);
1479
1480 if !pending_tool_calls.is_empty() {
1481 let diagnostic_history =
1482 build_tool_call_validation_history(ToolCallValidationHistory {
1483 chat_history: chat_history.as_deref(),
1484 new_messages: &new_messages,
1485 assistant_message_id: &stream.message_id,
1486 final_turn_content: Some(&final_turn_content),
1487 text_delta_response: None,
1488 accumulated_reasoning: &accumulated_reasoning,
1489 pending_reasoning_delta_text: "",
1490 pending_reasoning_delta_id: &None,
1491 pending_tool_calls: &pending_tool_calls,
1492 current_tool_call: None,
1493 });
1494
1495 for (tool_call, _) in &pending_tool_calls {
1496 if let Err(err) = validate_tool_call_name(
1497 &tool_call.function.name,
1498 &executable_tool_names,
1499 &allowed_tool_names,
1500 diagnostic_history.clone(),
1501 ) {
1502 yield Err(Box::new(err).into());
1503 break 'outer;
1504 }
1505 }
1506
1507 for (tool_call, internal_call_id) in pending_tool_calls {
1508 let tool_span = info_span!(
1509 parent: tracing::Span::current(),
1510 "execute_tool",
1511 gen_ai.operation.name = "execute_tool",
1512 gen_ai.tool.type = "function",
1513 gen_ai.tool.name = tracing::field::Empty,
1514 gen_ai.tool.call.id = tracing::field::Empty,
1515 gen_ai.tool.call.arguments = tracing::field::Empty,
1516 gen_ai.tool.call.result = tracing::field::Empty
1517 );
1518
1519 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));
1520
1521 let tc_result = async {
1522 let tool_span = tracing::Span::current();
1523 let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
1524 if let Some(ref hook) = self.hook {
1525 let action = hook
1526 .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
1527 .await;
1528
1529 if let ToolCallHookAction::Terminate { reason } = action {
1530 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1531 }
1532
1533 if let ToolCallHookAction::Skip { reason } = action {
1534 tracing::info!(
1536 tool_name = tool_call.function.name.as_str(),
1537 reason = reason,
1538 "Tool call rejected"
1539 );
1540 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
1541 tool_calls.push(tool_call_msg);
1542 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
1543 saw_tool_call_this_turn = true;
1544 return Ok(reason);
1545 }
1546 }
1547
1548 tool_span.record("gen_ai.tool.name", &tool_call.function.name);
1549 tool_span.record("gen_ai.tool.call.arguments", &tool_args);
1550
1551 let tool_result = match
1552 tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
1553 Ok(thing) => thing,
1554 Err(e) => {
1555 tracing::warn!("Error while calling tool: {e}");
1556 e.to_string()
1557 }
1558 };
1559
1560 tool_span.record("gen_ai.tool.call.result", &tool_result);
1561
1562 if let Some(ref hook) = self.hook &&
1563 let HookAction::Terminate { reason } =
1564 hook.on_tool_result(
1565 &tool_call.function.name,
1566 tool_call.call_id.clone(),
1567 &internal_call_id,
1568 &tool_args,
1569 &tool_result.to_string()
1570 )
1571 .await {
1572 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1573 }
1574
1575 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
1576
1577 tool_calls.push(tool_call_msg);
1578 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));
1579
1580 saw_tool_call_this_turn = true;
1581 Ok(tool_result)
1582 }.instrument(tool_span).await;
1583
1584 match tc_result {
1585 Ok(text) => {
1586 let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
1587 yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
1588 }
1589 Err(e) => {
1590 yield Err(e);
1591 break 'outer;
1592 }
1593 }
1594 }
1595 }
1596
1597 let mut assistant_turn_added_to_history = false;
1600 if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
1601 let mut content_items = assistant_text_items_from_choice(&final_turn_content);
1603
1604 for reasoning in accumulated_reasoning.drain(..) {
1606 content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
1607 }
1608
1609 content_items.extend(tool_calls.clone());
1610
1611 if let Some(content) = OneOrMany::from_iter_optional(content_items) {
1612 new_messages.push(Message::Assistant {
1613 id: stream.message_id.clone(),
1614 content,
1615 });
1616 assistant_turn_added_to_history = true;
1617 }
1618 }
1619
1620 let tool_result_contents: Vec<UserContent> = tool_results
1622 .into_iter()
1623 .map(|(id, call_id, tool_result)| {
1624 let content = ToolResultContent::from_tool_output(tool_result);
1625 match call_id {
1626 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
1627 None => UserContent::tool_result(id, content),
1628 }
1629 })
1630 .collect();
1631
1632 if let Some(content) = OneOrMany::from_iter_optional(tool_result_contents) {
1633 new_messages.push(Message::User { content });
1634 }
1635
1636 if !saw_tool_call_this_turn {
1637 let should_add_final_assistant = !assistant_turn_added_to_history
1639 && !is_empty_assistant_choice(&final_turn_content);
1640 if should_add_final_assistant {
1641 new_messages.push(Message::Assistant {
1642 id: stream.message_id.clone(),
1643 content: final_turn_content.clone(),
1644 });
1645 } else if !assistant_turn_added_to_history {
1646 tracing::warn!(
1647 agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
1648 message_id = ?stream.message_id,
1649 "Streaming turn completed without assistant text; final response will be empty"
1650 );
1651 }
1652
1653 if created_agent_span {
1654 let current_span = tracing::Span::current();
1655 record_usage_on_span(¤t_span, aggregated_usage);
1656 }
1657 tracing::info!("Agent multi-turn stream finished");
1658 if let Some((memory, id)) = memory_handle.as_ref()
1659 && let Err(err) = memory.append(id, new_messages.clone()).await
1660 {
1661 tracing::warn!(
1662 error = %err,
1663 conversation_id = %id,
1664 "conversation memory append failed; yielding final response anyway"
1665 );
1666 }
1667 let final_messages: Option<Vec<Message>> = if has_history {
1668 Some(new_messages.clone())
1669 } else {
1670 None
1671 };
1672 yield Ok(MultiTurnStreamItem::final_response_with_completion_calls(
1673 final_turn_content,
1674 aggregated_usage,
1675 completion_calls,
1676 final_messages,
1677 ));
1678 break;
1679 }
1680 }
1681
1682 if max_turns_reached {
1683 yield Err(Box::new(PromptError::MaxTurnsError {
1684 max_turns: self.max_turns,
1685 chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
1686 prompt: Box::new(last_prompt_error.clone().into()),
1687 }).into());
1688 }
1689 };
1690
1691 Box::pin(stream.instrument(agent_span))
1692 }
1693}
1694
1695impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
1696where
1697 M: CompletionModel + 'static,
1698 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
1699 P: PromptHook<M> + 'static,
1700{
1701 type Output = StreamingResult<M::StreamingResponse>; type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1703
1704 fn into_future(self) -> Self::IntoFuture {
1705 Box::pin(async move { self.send().await })
1707 }
1708}
1709
1710pub async fn stream_to_stdout<R>(
1717 stream: &mut StreamingResult<R>,
1718) -> Result<FinalResponse, std::io::Error> {
1719 let mut final_res = FinalResponse::empty();
1720 print!("Response: ");
1721 while let Some(content) = stream.next().await {
1722 match content {
1723 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1724 Text { text, .. },
1725 ))) => {
1726 print!("{text}");
1727 std::io::Write::flush(&mut std::io::stdout())?;
1728 }
1729 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
1730 reasoning,
1731 ))) => {
1732 let reasoning = reasoning.display_text();
1733 print!("{reasoning}");
1734 std::io::Write::flush(&mut std::io::stdout())?;
1735 }
1736 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1737 final_res = res;
1738 }
1739 Err(err) => {
1740 eprintln!("Error: {err}");
1741 }
1742 _ => {}
1743 }
1744 }
1745
1746 Ok(final_res)
1747}
1748
1749#[cfg(test)]
1750mod tests {
1751 use super::*;
1752 use crate::agent::AgentBuilder;
1753 use crate::agent::prompt_request::hooks::{
1754 InvalidToolCallContext, InvalidToolCallHookAction, PromptHook, ToolCallHookAction,
1755 };
1756 use crate::client::ProviderClient;
1757 use crate::client::completion::CompletionClient;
1758 use crate::completion::{CompletionRequest, PromptError, ToolDefinition, Usage};
1759 use crate::message::{
1760 AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent,
1761 ToolChoice, ToolResultContent, UserContent,
1762 };
1763 use crate::providers::anthropic;
1764 use crate::streaming::{StreamingPrompt, ToolCallDeltaContent};
1765 use crate::test_utils::{
1766 AppendFailingMemory, FailingMemory, MockAddTool, MockCompletionModel, MockResponse,
1767 MockStreamEvent, MockSubtractTool, MockToolError,
1768 };
1769 use crate::tool::Tool;
1770 use futures::{StreamExt, TryStreamExt};
1771 use serde::Deserialize;
1772 use std::collections::HashMap;
1773 use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
1774 use std::sync::{Arc, Mutex};
1775 use std::time::Duration;
1776 use tracing::field::{Field, Visit};
1777 use tracing::{Id, Subscriber};
1778 use tracing_subscriber::layer::{Context, SubscriberExt};
1779 use tracing_subscriber::{Layer, Registry, registry::LookupSpan};
1780
1781 #[test]
1782 fn merge_reasoning_blocks_preserves_order_and_signatures() {
1783 let mut accumulated = Vec::new();
1784 let first = crate::message::Reasoning {
1785 id: Some("rs_1".to_string()),
1786 content: vec![ReasoningContent::Text {
1787 text: "step-1".to_string(),
1788 signature: Some("sig-1".to_string()),
1789 }],
1790 };
1791 let second = crate::message::Reasoning {
1792 id: Some("rs_1".to_string()),
1793 content: vec![
1794 ReasoningContent::Text {
1795 text: "step-2".to_string(),
1796 signature: Some("sig-2".to_string()),
1797 },
1798 ReasoningContent::Summary("summary".to_string()),
1799 ],
1800 };
1801
1802 merge_reasoning_blocks(&mut accumulated, &first);
1803 merge_reasoning_blocks(&mut accumulated, &second);
1804
1805 assert_eq!(accumulated.len(), 1);
1806 let merged = accumulated.first().expect("expected accumulated reasoning");
1807 assert_eq!(merged.id.as_deref(), Some("rs_1"));
1808 assert_eq!(merged.content.len(), 3);
1809 assert!(matches!(
1810 merged.content.first(),
1811 Some(ReasoningContent::Text { text, signature: Some(sig) })
1812 if text == "step-1" && sig == "sig-1"
1813 ));
1814 assert!(matches!(
1815 merged.content.get(1),
1816 Some(ReasoningContent::Text { text, signature: Some(sig) })
1817 if text == "step-2" && sig == "sig-2"
1818 ));
1819 }
1820
1821 #[test]
1822 fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
1823 let mut accumulated = vec![crate::message::Reasoning {
1824 id: Some("rs_a".to_string()),
1825 content: vec![ReasoningContent::Text {
1826 text: "step-1".to_string(),
1827 signature: None,
1828 }],
1829 }];
1830 let incoming = crate::message::Reasoning {
1831 id: Some("rs_b".to_string()),
1832 content: vec![ReasoningContent::Text {
1833 text: "step-2".to_string(),
1834 signature: None,
1835 }],
1836 };
1837
1838 merge_reasoning_blocks(&mut accumulated, &incoming);
1839 assert_eq!(accumulated.len(), 2);
1840 assert_eq!(
1841 accumulated.first().and_then(|r| r.id.as_deref()),
1842 Some("rs_a")
1843 );
1844 assert_eq!(
1845 accumulated.get(1).and_then(|r| r.id.as_deref()),
1846 Some("rs_b")
1847 );
1848 }
1849
1850 #[test]
1851 fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
1852 let mut accumulated = vec![crate::message::Reasoning {
1853 id: None,
1854 content: vec![ReasoningContent::Text {
1855 text: "first".to_string(),
1856 signature: None,
1857 }],
1858 }];
1859 let incoming = crate::message::Reasoning {
1860 id: None,
1861 content: vec![ReasoningContent::Text {
1862 text: "second".to_string(),
1863 signature: None,
1864 }],
1865 };
1866
1867 merge_reasoning_blocks(&mut accumulated, &incoming);
1868 assert_eq!(accumulated.len(), 2);
1869 assert!(matches!(
1870 accumulated.first(),
1871 Some(crate::message::Reasoning {
1872 id: None,
1873 content
1874 }) if matches!(
1875 content.first(),
1876 Some(ReasoningContent::Text { text, .. }) if text == "first"
1877 )
1878 ));
1879 assert!(matches!(
1880 accumulated.get(1),
1881 Some(crate::message::Reasoning {
1882 id: None,
1883 content
1884 }) if matches!(
1885 content.first(),
1886 Some(ReasoningContent::Text { text, .. }) if text == "second"
1887 )
1888 ));
1889 }
1890
1891 #[test]
1892 fn tool_result_user_content_preserves_multimodal_tool_output() {
1893 let user_content = tool_result_user_content(
1894 "tool_call_1".to_string(),
1895 Some("call_1".to_string()),
1896 serde_json::json!({
1897 "response": {
1898 "instruction": "Use the image part to answer."
1899 },
1900 "parts": [
1901 {
1902 "type": "image",
1903 "data": "base64data==",
1904 "mimeType": "image/png"
1905 }
1906 ]
1907 })
1908 .to_string(),
1909 );
1910
1911 let tool_result = match user_content {
1912 UserContent::ToolResult(tool_result) => tool_result,
1913 other => panic!("expected tool result content, got {other:?}"),
1914 };
1915
1916 assert_eq!(tool_result.id, "tool_call_1");
1917 assert_eq!(tool_result.call_id.as_deref(), Some("call_1"));
1918 assert_eq!(tool_result.content.len(), 2);
1919
1920 let mut items = tool_result.content.iter();
1921 match items.next() {
1922 Some(ToolResultContent::Text(text)) => {
1923 assert!(text.text.contains("Use the image part to answer."));
1924 }
1925 other => panic!("expected structured text payload first, got {other:?}"),
1926 }
1927
1928 match items.next() {
1929 Some(ToolResultContent::Image(image)) => {
1930 assert_eq!(image.media_type, Some(ImageMediaType::PNG));
1931 assert!(matches!(
1932 image.data,
1933 DocumentSourceKind::Base64(ref data) if data == "base64data=="
1934 ));
1935 }
1936 other => panic!("expected image payload second, got {other:?}"),
1937 }
1938 }
1939
1940 fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
1941 let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1942 if history.len() != 3 {
1943 return Err(format!(
1944 "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
1945 ));
1946 }
1947
1948 if !matches!(
1949 history.first(),
1950 Some(Message::User { content })
1951 if matches!(
1952 content.first(),
1953 UserContent::Text(text) if text.text == "do tool work"
1954 )
1955 ) {
1956 return Err(format!(
1957 "follow-up request should begin with the original user prompt: {history:?}"
1958 ));
1959 }
1960
1961 if !matches!(
1962 history.get(1),
1963 Some(Message::Assistant { content, .. })
1964 if matches!(
1965 content.first(),
1966 AssistantContent::ToolCall(tool_call)
1967 if tool_call.id == "tool_call_1"
1968 && tool_call.call_id.as_deref() == Some("call_1")
1969 )
1970 ) {
1971 return Err(format!(
1972 "follow-up request is missing the assistant tool call in position 2: {history:?}"
1973 ));
1974 }
1975
1976 if !matches!(
1977 history.get(2),
1978 Some(Message::User { content })
1979 if matches!(
1980 content.first(),
1981 UserContent::ToolResult(tool_result)
1982 if tool_result.id == "tool_call_1"
1983 && tool_result.call_id.as_deref() == Some("call_1")
1984 )
1985 ) {
1986 return Err(format!(
1987 "follow-up request should end with the user tool result: {history:?}"
1988 ));
1989 }
1990
1991 Ok(())
1992 }
1993
1994 fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1995 history.iter().any(|message| {
1996 matches!(
1997 message,
1998 Message::Assistant { content, .. }
1999 if content.iter().any(|item| matches!(
2000 item,
2001 AssistantContent::ToolCall(tool_call)
2002 if tool_call.function.name == tool_name
2003 ))
2004 )
2005 })
2006 }
2007
2008 fn history_contains_text(history: &[Message], expected: &str) -> bool {
2009 history.iter().any(|message| {
2010 matches!(
2011 message,
2012 Message::Assistant { content, .. }
2013 if content.iter().any(|item| matches!(
2014 item,
2015 AssistantContent::Text(text) if text.text == expected
2016 ))
2017 )
2018 })
2019 }
2020
2021 fn assistant_reasoning_precedes_tool_call(
2022 history: &[Message],
2023 expected_reasoning: &str,
2024 tool_name: &str,
2025 ) -> bool {
2026 history.iter().any(|message| {
2027 let Message::Assistant { content, .. } = message else {
2028 return false;
2029 };
2030
2031 let reasoning_index = content.iter().position(|item| {
2032 matches!(
2033 item,
2034 AssistantContent::Reasoning(reasoning)
2035 if reasoning.content.iter().any(|content| matches!(
2036 content,
2037 ReasoningContent::Text { text, .. }
2038 if text == expected_reasoning
2039 ))
2040 )
2041 });
2042 let tool_index = content.iter().position(|item| {
2043 matches!(
2044 item,
2045 AssistantContent::ToolCall(tool_call)
2046 if tool_call.function.name == tool_name
2047 )
2048 });
2049
2050 matches!((reasoning_index, tool_index), (Some(reasoning), Some(tool)) if reasoning < tool)
2051 })
2052 }
2053
2054 #[derive(Clone)]
2055 struct PanicOnUnknownToolHook;
2056
2057 impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
2058 async fn on_tool_call_delta(
2059 &self,
2060 _tool_call_id: &str,
2061 _internal_call_id: &str,
2062 _tool_name: Option<&str>,
2063 _tool_call_delta: &str,
2064 ) -> HookAction {
2065 panic!("unknown tool call delta should fail before delta hooks run")
2066 }
2067
2068 async fn on_tool_call(
2069 &self,
2070 _tool_name: &str,
2071 _tool_call_id: Option<String>,
2072 _internal_call_id: &str,
2073 _args: &str,
2074 ) -> ToolCallHookAction {
2075 panic!("unknown tool call should fail before tool hooks run")
2076 }
2077
2078 async fn on_stream_completion_response_finish(
2079 &self,
2080 _prompt: &Message,
2081 _response: &MockResponse,
2082 ) -> HookAction {
2083 panic!("unknown tool call should fail before stream finish hooks run")
2084 }
2085 }
2086
2087 #[derive(Clone)]
2088 struct CountingAddTool {
2089 calls: Arc<AtomicU32>,
2090 }
2091
2092 #[derive(Clone)]
2093 struct CountingSubtractTool {
2094 calls: Arc<AtomicU32>,
2095 }
2096
2097 #[derive(Deserialize)]
2098 struct CountingOperationArgs {
2099 x: i32,
2100 y: i32,
2101 }
2102
2103 fn arithmetic_tool_definition(name: &str, description: &str) -> ToolDefinition {
2104 ToolDefinition {
2105 name: name.to_string(),
2106 description: description.to_string(),
2107 parameters: serde_json::json!({
2108 "type": "object",
2109 "properties": {
2110 "x": {
2111 "type": "number",
2112 "description": "The first operand"
2113 },
2114 "y": {
2115 "type": "number",
2116 "description": "The second operand"
2117 }
2118 },
2119 "required": ["x", "y"],
2120 }),
2121 }
2122 }
2123
2124 impl Tool for CountingAddTool {
2125 const NAME: &'static str = "add";
2126 type Error = MockToolError;
2127 type Args = CountingOperationArgs;
2128 type Output = i32;
2129
2130 async fn definition(&self, _prompt: String) -> ToolDefinition {
2131 arithmetic_tool_definition(Self::NAME, "Add x and y together")
2132 }
2133
2134 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
2135 self.calls.fetch_add(1, Ordering::SeqCst);
2136 Ok(args.x + args.y)
2137 }
2138 }
2139
2140 impl Tool for CountingSubtractTool {
2141 const NAME: &'static str = "subtract";
2142 type Error = MockToolError;
2143 type Args = CountingOperationArgs;
2144 type Output = i32;
2145
2146 async fn definition(&self, _prompt: String) -> ToolDefinition {
2147 arithmetic_tool_definition(Self::NAME, "Subtract y from x")
2148 }
2149
2150 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
2151 self.calls.fetch_add(1, Ordering::SeqCst);
2152 Ok(args.x - args.y)
2153 }
2154 }
2155
2156 fn streaming_tool_then_text_model() -> MockCompletionModel {
2157 MockCompletionModel::from_stream_turns([
2158 vec![
2159 MockStreamEvent::tool_call(
2160 "tool_call_1",
2161 "add",
2162 serde_json::json!({"x": 1, "y": 2}),
2163 )
2164 .with_call_id("call_1"),
2165 MockStreamEvent::final_response_with_total_tokens(4),
2166 ],
2167 vec![
2168 MockStreamEvent::text("done"),
2169 MockStreamEvent::final_response_with_total_tokens(6),
2170 ],
2171 ])
2172 }
2173
2174 fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
2175 Usage {
2176 input_tokens,
2177 output_tokens,
2178 total_tokens: input_tokens + output_tokens,
2179 cached_input_tokens: 0,
2180 cache_creation_input_tokens: 0,
2181 tool_use_prompt_tokens: 0,
2182 reasoning_tokens: 0,
2183 }
2184 }
2185
2186 #[derive(Clone, Debug, Default)]
2187 struct CapturedSpan {
2188 id: u64,
2189 name: String,
2190 parent_id: Option<u64>,
2191 fields: HashMap<String, u64>,
2192 }
2193
2194 #[derive(Clone, Default)]
2195 struct CapturedSpans(Arc<Mutex<Vec<CapturedSpan>>>);
2196
2197 impl CapturedSpans {
2198 fn insert(&self, id: &Id, name: &str, parent_id: Option<u64>) {
2199 let id = id.into_u64();
2200 if let Ok(mut spans) = self.0.lock() {
2201 spans.push(CapturedSpan {
2202 id,
2203 name: name.to_string(),
2204 parent_id,
2205 fields: HashMap::new(),
2206 });
2207 }
2208 }
2209
2210 fn record(&self, id: &Id, fields: Vec<(String, u64)>) {
2211 if let Ok(mut spans) = self.0.lock()
2212 && let Some(span) = spans.iter_mut().rev().find(|span| span.id == id.into_u64())
2213 {
2214 span.fields.extend(fields);
2215 }
2216 }
2217
2218 fn snapshot(&self) -> Vec<CapturedSpan> {
2219 self.0.lock().map(|spans| spans.clone()).unwrap_or_default()
2220 }
2221 }
2222
2223 struct SpanCaptureLayer {
2224 spans: CapturedSpans,
2225 }
2226
2227 impl<S> Layer<S> for SpanCaptureLayer
2228 where
2229 S: Subscriber,
2230 S: for<'lookup> LookupSpan<'lookup>,
2231 {
2232 fn on_new_span(&self, attrs: &tracing::span::Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
2233 let parent_id = attrs
2234 .parent()
2235 .map(Id::into_u64)
2236 .or_else(|| ctx.current_span().id().map(Id::into_u64));
2237 self.spans.insert(id, attrs.metadata().name(), parent_id);
2238 }
2239
2240 fn on_record(&self, span: &Id, values: &tracing::span::Record<'_>, _ctx: Context<'_, S>) {
2241 let mut fields = Vec::new();
2242 values.record(&mut SpanFieldCaptureVisitor {
2243 fields: &mut fields,
2244 });
2245 self.spans.record(span, fields);
2246 }
2247 }
2248
2249 struct SpanFieldCaptureVisitor<'a> {
2250 fields: &'a mut Vec<(String, u64)>,
2251 }
2252
2253 impl Visit for SpanFieldCaptureVisitor<'_> {
2254 fn record_u64(&mut self, field: &Field, value: u64) {
2255 self.fields.push((field.name().to_string(), value));
2256 }
2257
2258 fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {}
2259 }
2260
2261 async fn assert_stream_usage_recorded_on_chat_spans(
2262 agent: crate::agent::Agent<MockCompletionModel>,
2263 prompt: &str,
2264 max_turns: usize,
2265 expected_usages: &[Usage],
2266 ) {
2267 let spans = CapturedSpans::default();
2268 let subscriber = Registry::default().with(SpanCaptureLayer {
2269 spans: spans.clone(),
2270 });
2271 let _default = tracing::subscriber::set_default(subscriber);
2272 let empty_history: &[Message] = &[];
2273 let outer_span = tracing::info_span!("outer");
2274
2275 async {
2276 let mut stream = agent
2277 .stream_prompt(prompt)
2278 .with_history(empty_history)
2279 .multi_turn(max_turns)
2280 .await;
2281
2282 while let Some(item) = stream.try_next().await.expect("stream should not error") {
2283 if matches!(item, MultiTurnStreamItem::FinalResponse(_)) {
2284 break;
2285 }
2286 }
2287 }
2288 .instrument(outer_span)
2289 .await;
2290
2291 let span_snapshot = spans.snapshot();
2292 let outer_span_id = span_snapshot
2293 .iter()
2294 .find(|span| span.name == "outer")
2295 .map(|span| span.id)
2296 .expect("outer span should be captured");
2297 let chat_spans = span_snapshot
2298 .iter()
2299 .filter(|span| span.name == "chat_streaming")
2300 .collect::<Vec<_>>();
2301
2302 assert_eq!(chat_spans.len(), expected_usages.len());
2303 assert!(
2304 span_snapshot.iter().all(|span| span.name != "invoke_agent"),
2305 "outer span path should not create invoke_agent"
2306 );
2307
2308 for (chat_span, expected_usage) in chat_spans.into_iter().zip(expected_usages) {
2309 assert_eq!(chat_span.parent_id, Some(outer_span_id));
2310 assert_eq!(
2311 chat_span.fields.get("gen_ai.usage.input_tokens"),
2312 Some(&expected_usage.input_tokens)
2313 );
2314 assert_eq!(
2315 chat_span.fields.get("gen_ai.usage.output_tokens"),
2316 Some(&expected_usage.output_tokens)
2317 );
2318 assert_eq!(
2319 chat_span.fields.get("gen_ai.usage.cache_read.input_tokens"),
2320 Some(&expected_usage.cached_input_tokens)
2321 );
2322 assert_eq!(
2323 chat_span
2324 .fields
2325 .get("gen_ai.usage.cache_creation.input_tokens"),
2326 Some(&expected_usage.cache_creation_input_tokens)
2327 );
2328 assert_eq!(
2329 chat_span.fields.get("gen_ai.usage.tool_use_prompt_tokens"),
2330 Some(&expected_usage.tool_use_prompt_tokens)
2331 );
2332 assert_eq!(
2333 chat_span.fields.get("gen_ai.usage.reasoning_tokens"),
2334 Some(&expected_usage.reasoning_tokens)
2335 );
2336 }
2337
2338 let outer_span = span_snapshot
2339 .iter()
2340 .find(|span| span.id == outer_span_id)
2341 .expect("outer span should be present");
2342 assert!(
2343 outer_span
2344 .fields
2345 .keys()
2346 .all(|field| !field.starts_with("gen_ai.usage.")),
2347 "usage should not be recorded onto the caller's outer span"
2348 );
2349 }
2350
/// `CompletionCall` stream items serialize under the `completionCall` tag with
/// `call_index`/`usage` keys. A missing usage is written as an explicit `usage: null`.
/// Both shapes must deserialize back to the original call.
#[test]
fn completion_calls_stream_item_serializes_and_deserializes_expected_shape() {
    let item: MultiTurnStreamItem<MockResponse> =
        MultiTurnStreamItem::CompletionCall(CompletionCall::new(2, Some(usage(3, 4))));

    let value = serde_json::to_value(&item).expect("serialize completion call event");

    assert_eq!(
        value,
        serde_json::json!({
            "type": "completionCall",
            "call_index": 2,
            "usage": {
                "input_tokens": 3,
                "output_tokens": 4,
                "total_tokens": 7,
                "cached_input_tokens": 0,
                "cache_creation_input_tokens": 0,
                "tool_use_prompt_tokens": 0,
                "reasoning_tokens": 0,
            }
        })
    );

    let item: MultiTurnStreamItem<MockResponse> =
        serde_json::from_value(value).expect("deserialize completion call event");
    match item {
        MultiTurnStreamItem::CompletionCall(call_usage) => {
            assert_eq!(call_usage, CompletionCall::new(2, Some(usage(3, 4))));
        }
        other => panic!("expected completion call event, got {other:?}"),
    }

    let item: MultiTurnStreamItem<MockResponse> =
        MultiTurnStreamItem::CompletionCall(CompletionCall::new(3, None));
    let value = serde_json::to_value(&item).expect("serialize missing usage event");

    assert_eq!(
        value,
        serde_json::json!({
            "type": "completionCall",
            "call_index": 3,
            "usage": null
        })
    );

    // Round-trip the missing-usage shape too. An explicit `null` must come back as
    // `None`; it must neither fail to parse nor turn into a zeroed usage.
    let item: MultiTurnStreamItem<MockResponse> =
        serde_json::from_value(value).expect("deserialize missing usage event");
    match item {
        MultiTurnStreamItem::CompletionCall(call_usage) => {
            assert_eq!(call_usage, CompletionCall::new(3, None));
        }
        other => panic!("expected completion call event, got {other:?}"),
    }
}
2397
/// A final response lists its completion calls under `completionCalls`. A call without
/// usage serializes with `usage: null` alongside calls that do report usage.
#[test]
fn final_response_serializes_completion_calls_with_missing_usage() {
    let completion_calls = vec![
        CompletionCall::new(0, None),
        CompletionCall::new(1, Some(usage(3, 4))),
    ];
    let item: MultiTurnStreamItem<MockResponse> =
        MultiTurnStreamItem::final_response_with_completion_calls(
            OneOrMany::one(AssistantContent::text("done")),
            usage(3, 4),
            completion_calls,
            None,
        );

    let expected_calls = serde_json::json!([
        {
            "call_index": 0,
            "usage": null,
        },
        {
            "call_index": 1,
            "usage": {
                "input_tokens": 3,
                "output_tokens": 4,
                "total_tokens": 7,
                "cached_input_tokens": 0,
                "cache_creation_input_tokens": 0,
                "tool_use_prompt_tokens": 0,
                "reasoning_tokens": 0,
            }
        }
    ]);

    let serialized = serde_json::to_value(&item).expect("serialize final response");
    assert_eq!(serialized.get("completionCalls"), Some(&expected_calls));
}
2435
2436 fn streaming_text_then_final_model() -> MockCompletionModel {
2437 MockCompletionModel::from_stream_turns([[
2438 MockStreamEvent::text("hello"),
2439 MockStreamEvent::text(" world"),
2440 MockStreamEvent::final_response_with_total_tokens(3),
2441 ]])
2442 }
2443
2444 fn citation_metadata() -> serde_json::Value {
2445 serde_json::json!({
2446 "citations": [{
2447 "type": "web_search_result_location",
2448 "cited_text": "Claude Shannon was born in 1916.",
2449 "url": "https://example.com/shannon",
2450 "title": "Claude Shannon",
2451 "encrypted_index": "encrypted-reference"
2452 }]
2453 })
2454 }
2455
2456 fn streaming_cited_text_then_final_model() -> MockCompletionModel {
2457 MockCompletionModel::from_stream_turns([[
2458 MockStreamEvent::text_start(Some(citation_metadata())),
2459 MockStreamEvent::text("cited "),
2460 MockStreamEvent::text_start(None),
2461 MockStreamEvent::text("answer"),
2462 MockStreamEvent::final_response_with_total_tokens(3),
2463 ]])
2464 }
2465
2466 fn streaming_cited_text_then_tool_model() -> MockCompletionModel {
2467 MockCompletionModel::from_stream_turns([
2468 vec![
2469 MockStreamEvent::text_start(Some(citation_metadata())),
2470 MockStreamEvent::text("I need a tool. "),
2471 MockStreamEvent::tool_call(
2472 "tool_call_1",
2473 "add",
2474 serde_json::json!({"x": 1, "y": 2}),
2475 )
2476 .with_call_id("call_1"),
2477 MockStreamEvent::final_response_with_total_tokens(4),
2478 ],
2479 vec![
2480 MockStreamEvent::text("done"),
2481 MockStreamEvent::final_response_with_total_tokens(6),
2482 ],
2483 ])
2484 }
2485
2486 fn streaming_final_only_model() -> MockCompletionModel {
2487 MockCompletionModel::from_stream_turns([[
2488 MockStreamEvent::final_response_with_total_tokens(1),
2489 ]])
2490 }
2491
2492 #[derive(Clone)]
2493 struct TerminateOnStreamFinish;
2494
2495 impl PromptHook<MockCompletionModel> for TerminateOnStreamFinish {
2496 async fn on_stream_completion_response_finish(
2497 &self,
2498 _prompt: &Message,
2499 _response: &<MockCompletionModel as CompletionModel>::StreamingResponse,
2500 ) -> HookAction {
2501 HookAction::terminate("stop after completion call")
2502 }
2503 }
2504
2505 type RecordedToolCallDelta = (String, String, Option<String>, String);
2506
2507 #[derive(Clone)]
2508 struct RepairDefaultApiHook;
2509
2510 impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
2511 fn on_invalid_tool_call(
2512 &self,
2513 context: &InvalidToolCallContext,
2514 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2515 let tool_name = context.tool_name.clone();
2516 async move {
2517 assert_eq!(tool_name, "default_api");
2518 InvalidToolCallHookAction::repair("add")
2519 }
2520 }
2521 }
2522
2523 #[derive(Clone)]
2524 struct RetryDefaultApiHook;
2525
2526 impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
2527 fn on_invalid_tool_call(
2528 &self,
2529 context: &InvalidToolCallContext,
2530 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2531 let tool_name = context.tool_name.clone();
2532 let args = context.args.clone();
2533 async move {
2534 assert_eq!(tool_name, "default_api");
2535 if let Some(args) = args {
2536 assert!(!args.is_empty());
2537 }
2538 InvalidToolCallHookAction::retry("Use the add tool instead")
2539 }
2540 }
2541 }
2542
2543 #[derive(Clone)]
2544 struct SkipDefaultApiHook;
2545
2546 impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
2547 fn on_invalid_tool_call(
2548 &self,
2549 context: &InvalidToolCallContext,
2550 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2551 let tool_name = context.tool_name.clone();
2552 async move {
2553 assert_eq!(tool_name, "default_api");
2554 InvalidToolCallHookAction::skip("default_api was skipped")
2555 }
2556 }
2557 }
2558
2559 #[derive(Clone, Default)]
2560 struct RecordingInvalidToolCallHook {
2561 contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
2562 }
2563
2564 impl RecordingInvalidToolCallHook {
2565 fn observed(&self) -> Vec<InvalidToolCallContext> {
2566 self.contexts
2567 .lock()
2568 .expect("invalid tool context records mutex was poisoned")
2569 .clone()
2570 }
2571 }
2572
2573 impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
2574 fn on_invalid_tool_call(
2575 &self,
2576 context: &InvalidToolCallContext,
2577 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2578 let contexts = self.contexts.clone();
2579 let context = context.clone();
2580
2581 async move {
2582 contexts
2583 .lock()
2584 .expect("invalid tool context records mutex was poisoned")
2585 .push(context);
2586 InvalidToolCallHookAction::fail()
2587 }
2588 }
2589 }
2590
2591 #[derive(Clone, Default)]
2592 struct RecordingToolCallDeltaHook {
2593 deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2594 }
2595
2596 impl RecordingToolCallDeltaHook {
2597 fn observed(&self) -> Vec<RecordedToolCallDelta> {
2598 self.deltas
2599 .lock()
2600 .expect("tool call delta hook records mutex was poisoned")
2601 .clone()
2602 }
2603 }
2604
2605 impl PromptHook<MockCompletionModel> for RecordingToolCallDeltaHook {
2606 fn on_tool_call_delta(
2607 &self,
2608 tool_call_id: &str,
2609 internal_call_id: &str,
2610 tool_name: Option<&str>,
2611 tool_call_delta: &str,
2612 ) -> impl Future<Output = HookAction> + Send {
2613 let deltas = self.deltas.clone();
2614 let event = (
2615 tool_call_id.to_string(),
2616 internal_call_id.to_string(),
2617 tool_name.map(str::to_string),
2618 tool_call_delta.to_string(),
2619 );
2620
2621 async move {
2622 deltas
2623 .lock()
2624 .expect("tool call delta hook records mutex was poisoned")
2625 .push(event);
2626 HookAction::cont()
2627 }
2628 }
2629 }
2630
2631 #[derive(Clone, Default)]
2632 struct RecordingTextDeltaHook {
2633 deltas: Arc<Mutex<Vec<(String, String)>>>,
2634 }
2635
2636 impl RecordingTextDeltaHook {
2637 fn observed(&self) -> Vec<(String, String)> {
2638 self.deltas
2639 .lock()
2640 .expect("text delta hook records mutex was poisoned")
2641 .clone()
2642 }
2643 }
2644
2645 impl PromptHook<MockCompletionModel> for RecordingTextDeltaHook {
2646 fn on_text_delta(
2647 &self,
2648 text_delta: &str,
2649 full_text: &str,
2650 ) -> impl Future<Output = HookAction> + Send {
2651 let deltas = self.deltas.clone();
2652 let event = (text_delta.to_string(), full_text.to_string());
2653
2654 async move {
2655 deltas
2656 .lock()
2657 .expect("text delta hook records mutex was poisoned")
2658 .push(event);
2659 HookAction::cont()
2660 }
2661 }
2662 }
2663
2664 #[derive(Clone)]
2665 struct RecordingTextAndSkipInvalidToolHook {
2666 text: RecordingTextDeltaHook,
2667 }
2668
2669 impl PromptHook<MockCompletionModel> for RecordingTextAndSkipInvalidToolHook {
2670 fn on_text_delta(
2671 &self,
2672 text_delta: &str,
2673 full_text: &str,
2674 ) -> impl Future<Output = HookAction> + Send {
2675 self.text.on_text_delta(text_delta, full_text)
2676 }
2677
2678 fn on_invalid_tool_call(
2679 &self,
2680 context: &InvalidToolCallContext,
2681 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2682 SkipDefaultApiHook.on_invalid_tool_call(context)
2683 }
2684 }
2685
2686 #[derive(Clone)]
2687 struct RecordingTextAndRetryInvalidToolHook {
2688 text: RecordingTextDeltaHook,
2689 }
2690
2691 impl PromptHook<MockCompletionModel> for RecordingTextAndRetryInvalidToolHook {
2692 fn on_text_delta(
2693 &self,
2694 text_delta: &str,
2695 full_text: &str,
2696 ) -> impl Future<Output = HookAction> + Send {
2697 self.text.on_text_delta(text_delta, full_text)
2698 }
2699
2700 fn on_invalid_tool_call(
2701 &self,
2702 context: &InvalidToolCallContext,
2703 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2704 RetryDefaultApiHook.on_invalid_tool_call(context)
2705 }
2706 }
2707
2708 #[derive(Clone)]
2709 struct RecordingDeltaAndRetryInvalidToolHook {
2710 delta: RecordingToolCallDeltaHook,
2711 }
2712
2713 impl PromptHook<MockCompletionModel> for RecordingDeltaAndRetryInvalidToolHook {
2714 fn on_tool_call_delta(
2715 &self,
2716 tool_call_id: &str,
2717 internal_call_id: &str,
2718 tool_name: Option<&str>,
2719 tool_call_delta: &str,
2720 ) -> impl Future<Output = HookAction> + Send {
2721 self.delta.on_tool_call_delta(
2722 tool_call_id,
2723 internal_call_id,
2724 tool_name,
2725 tool_call_delta,
2726 )
2727 }
2728
2729 fn on_invalid_tool_call(
2730 &self,
2731 context: &InvalidToolCallContext,
2732 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2733 RetryDefaultApiHook.on_invalid_tool_call(context)
2734 }
2735 }
2736
2737 #[derive(Clone)]
2738 struct RecordingDeltaAndSkipInvalidToolHook {
2739 delta: RecordingToolCallDeltaHook,
2740 }
2741
2742 impl PromptHook<MockCompletionModel> for RecordingDeltaAndSkipInvalidToolHook {
2743 fn on_tool_call_delta(
2744 &self,
2745 tool_call_id: &str,
2746 internal_call_id: &str,
2747 tool_name: Option<&str>,
2748 tool_call_delta: &str,
2749 ) -> impl Future<Output = HookAction> + Send {
2750 self.delta.on_tool_call_delta(
2751 tool_call_id,
2752 internal_call_id,
2753 tool_name,
2754 tool_call_delta,
2755 )
2756 }
2757
2758 fn on_invalid_tool_call(
2759 &self,
2760 context: &InvalidToolCallContext,
2761 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2762 SkipDefaultApiHook.on_invalid_tool_call(context)
2763 }
2764 }
2765
2766 #[derive(Clone, Default)]
2767 struct TerminatingToolCallDeltaHook {
2768 deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2769 }
2770
2771 impl TerminatingToolCallDeltaHook {
2772 fn observed(&self) -> Vec<RecordedToolCallDelta> {
2773 self.deltas
2774 .lock()
2775 .expect("tool call delta hook records mutex was poisoned")
2776 .clone()
2777 }
2778 }
2779
2780 impl PromptHook<MockCompletionModel> for TerminatingToolCallDeltaHook {
2781 fn on_tool_call_delta(
2782 &self,
2783 tool_call_id: &str,
2784 internal_call_id: &str,
2785 tool_name: Option<&str>,
2786 tool_call_delta: &str,
2787 ) -> impl Future<Output = HookAction> + Send {
2788 let deltas = self.deltas.clone();
2789 let event = (
2790 tool_call_id.to_string(),
2791 internal_call_id.to_string(),
2792 tool_name.map(str::to_string),
2793 tool_call_delta.to_string(),
2794 );
2795
2796 async move {
2797 deltas
2798 .lock()
2799 .expect("tool call delta hook records mutex was poisoned")
2800 .push(event);
2801 HookAction::terminate("stop on tool call delta")
2802 }
2803 }
2804 }
2805
2806 fn text_metadata(content: &OneOrMany<AssistantContent>) -> Option<&serde_json::Value> {
2807 content.iter().find_map(|item| match item {
2808 AssistantContent::Text(text) => text.additional_params.as_ref(),
2809 _ => None,
2810 })
2811 }
2812
2813 #[tokio::test]
2814 async fn stream_prompt_continues_after_tool_call_turn() {
2815 let model = streaming_tool_then_text_model();
2816 let recorded = model.clone();
2817 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2818 let empty_history: &[Message] = &[];
2819
2820 let mut stream = agent
2821 .stream_prompt("do tool work")
2822 .with_history(empty_history)
2823 .multi_turn(3)
2824 .await;
2825 let mut saw_tool_call = false;
2826 let mut saw_tool_result = false;
2827 let mut saw_final_response = false;
2828 let mut final_text = String::new();
2829 let mut final_response_text = None;
2830 let mut final_history = None;
2831
2832 while let Some(item) = stream.next().await {
2833 match item {
2834 Ok(MultiTurnStreamItem::StreamAssistantItem(
2835 StreamedAssistantContent::ToolCall { .. },
2836 )) => {
2837 saw_tool_call = true;
2838 }
2839 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2840 ..
2841 })) => {
2842 saw_tool_result = true;
2843 }
2844 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
2845 text,
2846 ))) => {
2847 final_text.push_str(&text.text);
2848 }
2849 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
2850 saw_final_response = true;
2851 final_response_text = Some(res.response().to_owned());
2852 final_history = res.history().map(|history| history.to_vec());
2853 break;
2854 }
2855 Ok(_) => {}
2856 Err(err) => panic!("unexpected streaming error: {err:?}"),
2857 }
2858 }
2859
2860 assert!(saw_tool_call);
2861 assert!(saw_tool_result);
2862 assert!(saw_final_response);
2863 assert_eq!(final_text, "done");
2864 assert_eq!(final_response_text.as_deref(), Some("done"));
2865 let history = final_history.expect("expected final response history");
2866 assert!(history.iter().any(|message| matches!(
2867 message,
2868 Message::Assistant { content, .. }
2869 if content.iter().any(|item| matches!(
2870 item,
2871 AssistantContent::Text(text) if text.text == "done"
2872 ))
2873 )));
2874 let requests = recorded.requests();
2875 assert_eq!(requests.len(), 2);
2876 assert!(validate_follow_up_tool_history(&requests[1]).is_ok());
2877 }
2878
2879 #[tokio::test]
2880 async fn unknown_tool_call_fails_before_streaming_second_request() {
2881 let model = MockCompletionModel::from_stream_turns([
2882 vec![
2883 MockStreamEvent::tool_call(
2884 "tool_call_1",
2885 "default_api",
2886 serde_json::json!({"x": 1, "y": 2}),
2887 ),
2888 MockStreamEvent::final_response_with_total_tokens(4),
2889 ],
2890 vec![
2891 MockStreamEvent::text("should not be requested"),
2892 MockStreamEvent::final_response_with_total_tokens(6),
2893 ],
2894 ]);
2895 let recorded = model.clone();
2896 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2897
2898 let mut stream = agent
2899 .stream_prompt("use the tool")
2900 .with_hook(PanicOnUnknownToolHook)
2901 .multi_turn(3)
2902 .await;
2903 let mut saw_tool_call = false;
2904 let mut error = None;
2905
2906 while let Some(item) = stream.next().await {
2907 match item {
2908 Ok(MultiTurnStreamItem::StreamAssistantItem(
2909 StreamedAssistantContent::ToolCall { .. },
2910 )) => {
2911 saw_tool_call = true;
2912 }
2913 Ok(_) => {}
2914 Err(err) => {
2915 error = Some(err);
2916 break;
2917 }
2918 }
2919 }
2920
2921 assert!(!saw_tool_call);
2922 let error = error.expect("unknown model-emitted tool should fail");
2923 match error {
2924 StreamingError::Prompt(err) => match *err {
2925 PromptError::UnknownToolCall {
2926 tool_name,
2927 available_tools,
2928 allowed_tools,
2929 chat_history,
2930 } => {
2931 assert_eq!(tool_name, "default_api");
2932 assert_eq!(available_tools, vec!["add".to_string()]);
2933 assert_eq!(allowed_tools, vec!["add".to_string()]);
2934 assert!(history_contains_tool_call(&chat_history, "default_api"));
2935 }
2936 other => panic!("expected UnknownToolCall, got {other:?}"),
2937 },
2938 other => panic!("expected prompt streaming error, got {other:?}"),
2939 }
2940 assert_eq!(recorded.request_count(), 1);
2941 }
2942
2943 #[tokio::test]
2944 async fn invalid_tool_call_hook_can_repair_streaming_tool_name() {
2945 let model = MockCompletionModel::from_stream_turns([
2946 vec![
2947 MockStreamEvent::tool_call(
2948 "tool_call_1",
2949 "default_api",
2950 serde_json::json!({"x": 2, "y": 3}),
2951 ),
2952 MockStreamEvent::final_response_with_total_tokens(4),
2953 ],
2954 vec![
2955 MockStreamEvent::text("done"),
2956 MockStreamEvent::final_response_with_total_tokens(6),
2957 ],
2958 ]);
2959 let recorded = model.clone();
2960 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2961
2962 let mut stream = agent
2963 .stream_prompt("use the tool")
2964 .with_hook(RepairDefaultApiHook)
2965 .multi_turn(3)
2966 .with_history(Vec::<Message>::new())
2967 .await;
2968 let mut saw_repaired_tool_call = false;
2969 let mut saw_tool_result = false;
2970 let mut final_response_text = None;
2971
2972 while let Some(item) = stream.next().await {
2973 match item {
2974 Ok(MultiTurnStreamItem::StreamAssistantItem(
2975 StreamedAssistantContent::ToolCall { tool_call, .. },
2976 )) => {
2977 assert_eq!(tool_call.function.name, "add");
2978 saw_repaired_tool_call = true;
2979 }
2980 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2981 tool_result,
2982 ..
2983 })) => {
2984 assert!(tool_result.content.iter().any(|content| {
2985 matches!(
2986 content,
2987 ToolResultContent::Text(text) if text.text == "5"
2988 )
2989 }));
2990 saw_tool_result = true;
2991 }
2992 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2993 final_response_text = Some(response.response().to_string());
2994 break;
2995 }
2996 Ok(_) => {}
2997 Err(err) => panic!("unexpected streaming error: {err:?}"),
2998 }
2999 }
3000
3001 assert!(saw_repaired_tool_call);
3002 assert!(saw_tool_result);
3003 assert_eq!(final_response_text.as_deref(), Some("done"));
3004 assert_eq!(recorded.request_count(), 2);
3005 }
3006
3007 #[tokio::test]
3008 async fn invalid_tool_call_context_uses_completed_streaming_tool_call_provider_id() {
3009 let invalid_hook = RecordingInvalidToolCallHook::default();
3010 let model = MockCompletionModel::from_stream_turns([
3011 vec![
3012 MockStreamEvent::tool_call(
3013 "tool_call_1",
3014 "default_api",
3015 serde_json::json!({"x": 2, "y": 3}),
3016 )
3017 .with_call_id("provider_call_1"),
3018 MockStreamEvent::final_response_with_total_tokens(4),
3019 ],
3020 vec![
3021 MockStreamEvent::text("should not be requested"),
3022 MockStreamEvent::final_response_with_total_tokens(6),
3023 ],
3024 ]);
3025 let recorded = model.clone();
3026 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3027
3028 let mut stream = agent
3029 .stream_prompt("use the tool")
3030 .with_hook(invalid_hook.clone())
3031 .multi_turn(3)
3032 .await;
3033 let mut error = None;
3034
3035 while let Some(item) = stream.next().await {
3036 if let Err(err) = item {
3037 error = Some(err);
3038 break;
3039 }
3040 }
3041
3042 assert!(error.is_some(), "invalid tool should fail");
3043 assert_eq!(recorded.request_count(), 1);
3044 let contexts = invalid_hook.observed();
3045 assert_eq!(contexts.len(), 1);
3046 let context = &contexts[0];
3047 assert_eq!(context.tool_name, "default_api");
3048 assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
3049 assert!(context.internal_call_id.is_some());
3050 assert!(context.is_streaming);
3051 }
3052
3053 #[tokio::test]
3054 async fn invalid_tool_call_hook_skip_emits_streaming_tool_result() {
3055 let add_calls = Arc::new(AtomicU32::new(0));
3056 let model = MockCompletionModel::from_stream_turns([
3057 vec![
3058 MockStreamEvent::tool_call(
3059 "tool_call_1",
3060 "default_api",
3061 serde_json::json!({"x": 2, "y": 3}),
3062 )
3063 .with_call_id("call_1"),
3064 MockStreamEvent::final_response_with_total_tokens(4),
3065 ],
3066 vec![
3067 MockStreamEvent::text("continued"),
3068 MockStreamEvent::final_response_with_total_tokens(6),
3069 ],
3070 ]);
3071 let recorded = model.clone();
3072 let agent = AgentBuilder::new(model)
3073 .tool(CountingAddTool {
3074 calls: add_calls.clone(),
3075 })
3076 .build();
3077
3078 let mut stream = agent
3079 .stream_prompt("use the tool")
3080 .with_hook(SkipDefaultApiHook)
3081 .multi_turn(3)
3082 .with_history(Vec::<Message>::new())
3083 .await;
3084 let mut skipped_tool_result = None;
3085 let mut final_response_text = None;
3086
3087 while let Some(item) = stream.next().await {
3088 match item {
3089 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3090 tool_result,
3091 internal_call_id,
3092 })) => {
3093 assert!(!internal_call_id.is_empty());
3094 skipped_tool_result = Some(tool_result);
3095 }
3096 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3097 final_response_text = Some(response.response().to_string());
3098 break;
3099 }
3100 Ok(_) => {}
3101 Err(err) => panic!("unexpected streaming error: {err:?}"),
3102 }
3103 }
3104
3105 let skipped_tool_result =
3106 skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3107 assert_eq!(skipped_tool_result.id, "tool_call_1");
3108 assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_1"));
3109 assert!(skipped_tool_result.content.iter().any(|content| matches!(
3110 content,
3111 ToolResultContent::Text(text) if text.text == "default_api was skipped"
3112 )));
3113 assert_eq!(final_response_text.as_deref(), Some("continued"));
3114 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3115
3116 let requests = recorded.requests();
3117 assert_eq!(requests.len(), 2);
3118 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3119 assert!(matches!(
3120 follow_up_history.get(2),
3121 Some(Message::User { content })
3122 if content.iter().any(|item| matches!(
3123 item,
3124 UserContent::ToolResult(result)
3125 if result.id == "tool_call_1"
3126 && result.content.iter().any(|content| matches!(
3127 content,
3128 ToolResultContent::Text(text)
3129 if text.text == "default_api was skipped"
3130 ))
3131 ))
3132 ));
3133 }
3134
3135 #[tokio::test]
3136 async fn invalid_tool_call_hook_retries_mixed_streaming_turn_without_executing_valid_call() {
3137 let add_calls = Arc::new(AtomicU32::new(0));
3138 let model = MockCompletionModel::from_stream_turns([
3139 vec![
3140 MockStreamEvent::text("checking "),
3141 MockStreamEvent::tool_call(
3142 "tool_call_1",
3143 "add",
3144 serde_json::json!({"x": 2, "y": 3}),
3145 )
3146 .with_call_id("call_1"),
3147 MockStreamEvent::tool_call(
3148 "tool_call_2",
3149 "default_api",
3150 serde_json::json!({"x": 4, "y": 5}),
3151 )
3152 .with_call_id("call_2"),
3153 MockStreamEvent::final_response_with_total_tokens(4),
3154 ],
3155 vec![
3156 MockStreamEvent::text("retried"),
3157 MockStreamEvent::final_response_with_total_tokens(6),
3158 ],
3159 ]);
3160 let recorded = model.clone();
3161 let agent = AgentBuilder::new(model)
3162 .tool(CountingAddTool {
3163 calls: add_calls.clone(),
3164 })
3165 .build();
3166
3167 let mut stream = agent
3168 .stream_prompt("use the tool")
3169 .with_hook(RetryDefaultApiHook)
3170 .multi_turn(3)
3171 .with_history(Vec::<Message>::new())
3172 .max_invalid_tool_call_retries(1)
3173 .await;
3174 let mut completion_call_events = Vec::new();
3175 let mut final_response_text = None;
3176 let mut final_response_usage = Usage::new();
3177 let mut final_completion_calls = Vec::new();
3178
3179 while let Some(item) = stream.next().await {
3180 match item {
3181 Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
3182 completion_call_events.push(completion_call);
3183 }
3184 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3185 final_response_text = Some(response.response().to_string());
3186 final_response_usage = response.usage();
3187 final_completion_calls = response.completion_calls().to_vec();
3188 break;
3189 }
3190 Ok(_) => {}
3191 Err(err) => panic!("unexpected streaming error: {err:?}"),
3192 }
3193 }
3194
3195 assert_eq!(final_response_text.as_deref(), Some("retried"));
3196 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3197 let mut first_usage = Usage::new();
3198 first_usage.total_tokens = 4;
3199 let mut second_usage = Usage::new();
3200 second_usage.total_tokens = 6;
3201 let expected_completion_calls = vec![
3202 CompletionCall::new(0, Some(first_usage)),
3203 CompletionCall::new(1, Some(second_usage)),
3204 ];
3205 assert_eq!(completion_call_events, expected_completion_calls);
3206 assert_eq!(final_completion_calls, expected_completion_calls);
3207 assert_eq!(final_response_usage.total_tokens, 10);
3208
3209 let requests = recorded.requests();
3210 assert_eq!(requests.len(), 2);
3211 let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3212 assert_eq!(retry_history.len(), 3);
3213 assert!(matches!(
3214 retry_history.get(1),
3215 Some(Message::Assistant { content, .. })
3216 if content.iter().any(|item| matches!(
3217 item,
3218 AssistantContent::Text(text) if text.text == "checking "
3219 ))
3220 && content.iter().any(|item| matches!(
3221 item,
3222 AssistantContent::ToolCall(tool_call)
3223 if tool_call.id == "tool_call_1"
3224 && tool_call.function.name == "add"
3225 ))
3226 && content.iter().any(|item| matches!(
3227 item,
3228 AssistantContent::ToolCall(tool_call)
3229 if tool_call.id == "tool_call_2"
3230 && tool_call.function.name == "default_api"
3231 ))
3232 ));
3233 assert!(matches!(
3234 retry_history.get(2),
3235 Some(Message::User { content })
3236 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3237 && content.iter().any(|item| matches!(
3238 item,
3239 UserContent::ToolResult(result)
3240 if result.id == "tool_call_1"
3241 && result.content.iter().any(|content| matches!(
3242 content,
3243 ToolResultContent::Text(text)
3244 if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3245 ))
3246 ))
3247 && content.iter().any(|item| matches!(
3248 item,
3249 UserContent::ToolResult(result)
3250 if result.id == "tool_call_2"
3251 && result.content.iter().any(|content| matches!(
3252 content,
3253 ToolResultContent::Text(text)
3254 if text.text == "Use the add tool instead"
3255 ))
3256 ))
3257 ));
3258 }
3259
3260 #[tokio::test]
3261 async fn invalid_tool_call_hook_skips_mixed_streaming_turn_without_executing_valid_call() {
3262 let add_calls = Arc::new(AtomicU32::new(0));
3263 let model = MockCompletionModel::from_stream_turns([
3264 vec![
3265 MockStreamEvent::text("checking "),
3266 MockStreamEvent::tool_call(
3267 "tool_call_1",
3268 "add",
3269 serde_json::json!({"x": 2, "y": 3}),
3270 )
3271 .with_call_id("call_1"),
3272 MockStreamEvent::tool_call(
3273 "tool_call_2",
3274 "default_api",
3275 serde_json::json!({"x": 4, "y": 5}),
3276 )
3277 .with_call_id("call_2"),
3278 MockStreamEvent::final_response_with_total_tokens(4),
3279 ],
3280 vec![
3281 MockStreamEvent::text("continued"),
3282 MockStreamEvent::final_response_with_total_tokens(6),
3283 ],
3284 ]);
3285 let recorded = model.clone();
3286 let agent = AgentBuilder::new(model)
3287 .tool(CountingAddTool {
3288 calls: add_calls.clone(),
3289 })
3290 .build();
3291
3292 let mut stream = agent
3293 .stream_prompt("use the tool")
3294 .with_hook(SkipDefaultApiHook)
3295 .multi_turn(3)
3296 .with_history(Vec::<Message>::new())
3297 .await;
3298 let mut skipped_tool_result = None;
3299 let mut final_response_text = None;
3300
3301 while let Some(item) = stream.next().await {
3302 match item {
3303 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3304 tool_result,
3305 ..
3306 })) => {
3307 skipped_tool_result = Some(tool_result);
3308 }
3309 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3310 final_response_text = Some(response.response().to_string());
3311 break;
3312 }
3313 Ok(_) => {}
3314 Err(err) => panic!("unexpected streaming error: {err:?}"),
3315 }
3316 }
3317
3318 let skipped_tool_result =
3319 skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3320 assert_eq!(skipped_tool_result.id, "tool_call_2");
3321 assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_2"));
3322 assert_eq!(final_response_text.as_deref(), Some("continued"));
3323 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3324
3325 let requests = recorded.requests();
3326 assert_eq!(requests.len(), 2);
3327 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3328 assert_eq!(follow_up_history.len(), 3);
3329 assert!(matches!(
3330 follow_up_history.get(1),
3331 Some(Message::Assistant { content, .. })
3332 if content.iter().any(|item| matches!(
3333 item,
3334 AssistantContent::Text(text) if text.text == "checking "
3335 ))
3336 && content.iter().any(|item| matches!(
3337 item,
3338 AssistantContent::ToolCall(tool_call)
3339 if tool_call.id == "tool_call_1"
3340 && tool_call.function.name == "add"
3341 ))
3342 && content.iter().any(|item| matches!(
3343 item,
3344 AssistantContent::ToolCall(tool_call)
3345 if tool_call.id == "tool_call_2"
3346 && tool_call.function.name == "default_api"
3347 ))
3348 ));
3349 assert!(matches!(
3350 follow_up_history.get(2),
3351 Some(Message::User { content })
3352 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3353 && content.iter().any(|item| matches!(
3354 item,
3355 UserContent::ToolResult(result)
3356 if result.id == "tool_call_1"
3357 && result.call_id.as_deref() == Some("call_1")
3358 && result.content.iter().any(|content| matches!(
3359 content,
3360 ToolResultContent::Text(text)
3361 if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3362 ))
3363 ))
3364 && content.iter().any(|item| matches!(
3365 item,
3366 UserContent::ToolResult(result)
3367 if result.id == "tool_call_2"
3368 && result.call_id.as_deref() == Some("call_2")
3369 && result.content.iter().any(|content| matches!(
3370 content,
3371 ToolResultContent::Text(text)
3372 if text.text == "default_api was skipped"
3373 ))
3374 ))
3375 ));
3376 }
3377
3378 #[tokio::test]
3379 async fn invalid_completed_tool_call_skip_preserves_streaming_reasoning_history() {
3380 let model = MockCompletionModel::from_stream_turns([
3381 vec![
3382 MockStreamEvent::text("checking "),
3383 MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
3384 MockStreamEvent::tool_call(
3385 "tool_call_1",
3386 "default_api",
3387 serde_json::json!({"x": 2, "y": 3}),
3388 ),
3389 MockStreamEvent::final_response_with_total_tokens(4),
3390 ],
3391 vec![
3392 MockStreamEvent::text("continued"),
3393 MockStreamEvent::final_response_with_total_tokens(6),
3394 ],
3395 ]);
3396 let recorded = model.clone();
3397 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3398
3399 let mut stream = agent
3400 .stream_prompt("use the tool")
3401 .with_hook(SkipDefaultApiHook)
3402 .multi_turn(3)
3403 .with_history(Vec::<Message>::new())
3404 .await;
3405
3406 while let Some(item) = stream.next().await {
3407 match item {
3408 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3409 Ok(_) => {}
3410 Err(err) => panic!("unexpected streaming error: {err:?}"),
3411 }
3412 }
3413
3414 let requests = recorded.requests();
3415 assert_eq!(requests.len(), 2);
3416 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3417 assert!(history_contains_text(&follow_up_history, "checking "));
3418 assert!(assistant_reasoning_precedes_tool_call(
3419 &follow_up_history,
3420 "reasoned step",
3421 "default_api"
3422 ));
3423 }
3424
3425 #[tokio::test]
3426 async fn invalid_name_delta_retry_preserves_streaming_reasoning_history() {
3427 let model = MockCompletionModel::from_stream_turns([
3428 vec![
3429 MockStreamEvent::reasoning_delta(Some("rs_1"), "delta reason"),
3430 MockStreamEvent::tool_call_arguments_delta(
3431 "tool_call_1",
3432 "internal_1",
3433 r#"{"x":2,"y":3}"#,
3434 ),
3435 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3436 MockStreamEvent::final_response_with_total_tokens(4),
3437 ],
3438 vec![
3439 MockStreamEvent::text("retried"),
3440 MockStreamEvent::final_response_with_total_tokens(6),
3441 ],
3442 ]);
3443 let recorded = model.clone();
3444 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3445
3446 let mut stream = agent
3447 .stream_prompt("use the tool")
3448 .with_hook(RetryDefaultApiHook)
3449 .multi_turn(3)
3450 .with_history(Vec::<Message>::new())
3451 .max_invalid_tool_call_retries(1)
3452 .await;
3453
3454 while let Some(item) = stream.next().await {
3455 match item {
3456 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3457 Ok(_) => {}
3458 Err(err) => panic!("unexpected streaming error: {err:?}"),
3459 }
3460 }
3461
3462 let requests = recorded.requests();
3463 assert_eq!(requests.len(), 2);
3464 let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3465 assert!(assistant_reasoning_precedes_tool_call(
3466 &retry_history,
3467 "delta reason",
3468 "default_api"
3469 ));
3470 }
3471
3472 #[tokio::test]
3473 async fn invalid_tool_call_hook_skip_resets_streaming_text_delta_state() {
3474 let text_hook = RecordingTextDeltaHook::default();
3475 let model = MockCompletionModel::from_stream_turns([
3476 vec![
3477 MockStreamEvent::text("stale "),
3478 MockStreamEvent::tool_call(
3479 "tool_call_1",
3480 "default_api",
3481 serde_json::json!({"x": 2, "y": 3}),
3482 ),
3483 MockStreamEvent::final_response_with_total_tokens(4),
3484 ],
3485 vec![
3486 MockStreamEvent::text("fresh"),
3487 MockStreamEvent::final_response_with_total_tokens(6),
3488 ],
3489 ]);
3490 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3491
3492 let mut stream = agent
3493 .stream_prompt("use the tool")
3494 .with_hook(RecordingTextAndSkipInvalidToolHook {
3495 text: text_hook.clone(),
3496 })
3497 .multi_turn(3)
3498 .with_history(Vec::<Message>::new())
3499 .await;
3500
3501 while let Some(item) = stream.next().await {
3502 match item {
3503 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3504 Ok(_) => {}
3505 Err(err) => panic!("unexpected streaming error: {err:?}"),
3506 }
3507 }
3508
3509 assert_eq!(
3510 text_hook.observed(),
3511 vec![
3512 ("stale ".to_string(), "stale ".to_string()),
3513 ("fresh".to_string(), "fresh".to_string()),
3514 ]
3515 );
3516 }
3517
3518 #[tokio::test]
3519 async fn invalid_tool_call_delta_retry_uses_structured_tool_feedback() {
3520 let delta_hook = RecordingToolCallDeltaHook::default();
3521 let add_calls = Arc::new(AtomicU32::new(0));
3522 let model = MockCompletionModel::from_stream_turns([
3523 vec![
3524 MockStreamEvent::text("checking "),
3525 MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3526 MockStreamEvent::tool_call(
3527 "tool_call_0",
3528 "add",
3529 serde_json::json!({"x": 1, "y": 2}),
3530 )
3531 .with_call_id("call_0"),
3532 MockStreamEvent::tool_call_arguments_delta(
3533 "tool_call_1",
3534 "internal_1",
3535 r#"{"x":2,"y":3}"#,
3536 ),
3537 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3538 MockStreamEvent::final_response_with_total_tokens(4),
3539 ],
3540 vec![
3541 MockStreamEvent::text("retried"),
3542 MockStreamEvent::final_response_with_total_tokens(6),
3543 ],
3544 ]);
3545 let recorded = model.clone();
3546 let agent = AgentBuilder::new(model)
3547 .tool(CountingAddTool {
3548 calls: add_calls.clone(),
3549 })
3550 .build();
3551
3552 let mut stream = agent
3553 .stream_prompt("use the tool")
3554 .with_hook(RecordingDeltaAndRetryInvalidToolHook {
3555 delta: delta_hook.clone(),
3556 })
3557 .multi_turn(3)
3558 .with_history(Vec::<Message>::new())
3559 .max_invalid_tool_call_retries(1)
3560 .await;
3561 let mut completion_call_events = Vec::new();
3562 let mut final_response_text = None;
3563 let mut final_response_usage = Usage::new();
3564 let mut final_completion_calls = Vec::new();
3565
3566 while let Some(item) = stream.next().await {
3567 match item {
3568 Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
3569 completion_call_events.push(completion_call);
3570 }
3571 Ok(MultiTurnStreamItem::StreamAssistantItem(
3572 StreamedAssistantContent::ToolCallDelta { .. },
3573 )) => panic!("invalid tool-call delta should not be emitted"),
3574 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3575 final_response_text = Some(response.response().to_string());
3576 final_response_usage = response.usage();
3577 final_completion_calls = response.completion_calls().to_vec();
3578 break;
3579 }
3580 Ok(_) => {}
3581 Err(err) => panic!("unexpected streaming error: {err:?}"),
3582 }
3583 }
3584
3585 assert_eq!(final_response_text.as_deref(), Some("retried"));
3586 assert!(delta_hook.observed().is_empty());
3587 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3588 let mut first_usage = Usage::new();
3589 first_usage.total_tokens = 4;
3590 let mut second_usage = Usage::new();
3591 second_usage.total_tokens = 6;
3592 let expected_completion_calls = vec![
3593 CompletionCall::new(0, Some(first_usage)),
3594 CompletionCall::new(1, Some(second_usage)),
3595 ];
3596 assert_eq!(completion_call_events, expected_completion_calls);
3597 assert_eq!(final_completion_calls, expected_completion_calls);
3598 assert_eq!(final_response_usage.total_tokens, 10);
3599
3600 let requests = recorded.requests();
3601 assert_eq!(requests.len(), 2);
3602 let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3603 assert!(matches!(
3604 retry_history.get(1),
3605 Some(Message::Assistant { content, .. })
3606 if content.iter().any(|item| matches!(
3607 item,
3608 AssistantContent::Text(text) if text.text == "checking "
3609 ))
3610 && content.iter().any(|item| matches!(
3611 item,
3612 AssistantContent::ToolCall(tool_call)
3613 if tool_call.id == "tool_call_0"
3614 && tool_call.function.name == "add"
3615 ))
3616 && content.iter().any(|item| matches!(
3617 item,
3618 AssistantContent::ToolCall(tool_call)
3619 if tool_call.id == "tool_call_1"
3620 && tool_call.function.name == "default_api"
3621 && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3622 ))
3623 ));
3624 assert!(matches!(
3625 retry_history.get(2),
3626 Some(Message::User { content })
3627 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3628 && content.iter().any(|item| matches!(
3629 item,
3630 UserContent::ToolResult(result)
3631 if result.id == "tool_call_0"
3632 && result.call_id.as_deref() == Some("call_0")
3633 && result.content.iter().any(|content| matches!(
3634 content,
3635 ToolResultContent::Text(text)
3636 if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3637 ))
3638 ))
3639 && content.iter().any(|item| matches!(
3640 item,
3641 UserContent::ToolResult(result)
3642 if result.id == "tool_call_1"
3643 && result.content.iter().any(|content| matches!(
3644 content,
3645 ToolResultContent::Text(text)
3646 if text.text == "Use the add tool instead"
3647 ))
3648 ))
3649 ));
3650 }
3651
3652 #[tokio::test]
3653 async fn invalid_tool_call_delta_context_includes_same_turn_history_and_tool_call_id() {
3654 let invalid_hook = RecordingInvalidToolCallHook::default();
3655 let model = MockCompletionModel::from_stream_turns([
3656 vec![
3657 MockStreamEvent::text("checking "),
3658 MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3659 MockStreamEvent::tool_call(
3660 "tool_call_0",
3661 "add",
3662 serde_json::json!({"x": 1, "y": 2}),
3663 )
3664 .with_call_id("call_0"),
3665 MockStreamEvent::tool_call_arguments_delta(
3666 "tool_call_1",
3667 "internal_1",
3668 r#"{"x":2,"y":3}"#,
3669 ),
3670 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3671 MockStreamEvent::final_response_with_total_tokens(4),
3672 ],
3673 vec![
3674 MockStreamEvent::text("should not be requested"),
3675 MockStreamEvent::final_response_with_total_tokens(6),
3676 ],
3677 ]);
3678 let recorded = model.clone();
3679 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3680
3681 let mut stream = agent
3682 .stream_prompt("use the tool")
3683 .with_hook(invalid_hook.clone())
3684 .multi_turn(3)
3685 .await;
3686 let mut error = None;
3687
3688 while let Some(item) = stream.next().await {
3689 if let Err(err) = item {
3690 error = Some(err);
3691 break;
3692 }
3693 }
3694
3695 assert!(error.is_some(), "invalid name delta should fail");
3696 assert_eq!(recorded.request_count(), 1);
3697 let contexts = invalid_hook.observed();
3698 assert_eq!(contexts.len(), 1);
3699 let context = &contexts[0];
3700 assert_eq!(context.tool_name, "default_api");
3701 assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
3702 assert_eq!(context.internal_call_id.as_deref(), Some("internal_1"));
3703 assert!(context.is_streaming);
3704 assert!(history_contains_text(&context.chat_history, "checking "));
3705 assert!(
3706 assistant_reasoning_precedes_tool_call(
3707 &context.chat_history,
3708 "diagnostic reason",
3709 "add"
3710 ),
3711 "{:?}",
3712 context.chat_history
3713 );
3714 assert!(history_contains_tool_call(&context.chat_history, "add"));
3715 assert!(history_contains_tool_call(
3716 &context.chat_history,
3717 "default_api"
3718 ));
3719 }
3720
3721 #[tokio::test]
3722 async fn invalid_tool_call_delta_retry_resets_streaming_text_delta_state() {
3723 let text_hook = RecordingTextDeltaHook::default();
3724 let model = MockCompletionModel::from_stream_turns([
3725 vec![
3726 MockStreamEvent::text("stale "),
3727 MockStreamEvent::tool_call_arguments_delta(
3728 "tool_call_1",
3729 "internal_1",
3730 r#"{"x":2,"y":3}"#,
3731 ),
3732 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3733 MockStreamEvent::final_response_with_total_tokens(4),
3734 ],
3735 vec![
3736 MockStreamEvent::text("fresh"),
3737 MockStreamEvent::final_response_with_total_tokens(6),
3738 ],
3739 ]);
3740 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3741
3742 let mut stream = agent
3743 .stream_prompt("use the tool")
3744 .with_hook(RecordingTextAndRetryInvalidToolHook {
3745 text: text_hook.clone(),
3746 })
3747 .multi_turn(3)
3748 .with_history(Vec::<Message>::new())
3749 .max_invalid_tool_call_retries(1)
3750 .await;
3751
3752 while let Some(item) = stream.next().await {
3753 match item {
3754 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3755 Ok(_) => {}
3756 Err(err) => panic!("unexpected streaming error: {err:?}"),
3757 }
3758 }
3759
3760 assert_eq!(
3761 text_hook.observed(),
3762 vec![
3763 ("stale ".to_string(), "stale ".to_string()),
3764 ("fresh".to_string(), "fresh".to_string()),
3765 ]
3766 );
3767 }
3768
3769 #[tokio::test]
3770 async fn invalid_tool_call_delta_skip_uses_structured_tool_feedback() {
3771 let delta_hook = RecordingToolCallDeltaHook::default();
3772 let add_calls = Arc::new(AtomicU32::new(0));
3773 let model = MockCompletionModel::from_stream_turns([
3774 vec![
3775 MockStreamEvent::text("checking "),
3776 MockStreamEvent::tool_call(
3777 "tool_call_0",
3778 "add",
3779 serde_json::json!({"x": 1, "y": 2}),
3780 )
3781 .with_call_id("call_0"),
3782 MockStreamEvent::tool_call_arguments_delta(
3783 "tool_call_1",
3784 "internal_1",
3785 r#"{"x":2,"y":3}"#,
3786 ),
3787 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3788 MockStreamEvent::final_response_with_total_tokens(4),
3789 ],
3790 vec![
3791 MockStreamEvent::text("continued"),
3792 MockStreamEvent::final_response_with_total_tokens(6),
3793 ],
3794 ]);
3795 let recorded = model.clone();
3796 let agent = AgentBuilder::new(model)
3797 .tool(CountingAddTool {
3798 calls: add_calls.clone(),
3799 })
3800 .build();
3801
3802 let mut stream = agent
3803 .stream_prompt("use the tool")
3804 .with_hook(RecordingDeltaAndSkipInvalidToolHook {
3805 delta: delta_hook.clone(),
3806 })
3807 .multi_turn(3)
3808 .with_history(Vec::<Message>::new())
3809 .await;
3810 let mut skipped_tool_result = None;
3811 let mut final_response_text = None;
3812
3813 while let Some(item) = stream.next().await {
3814 match item {
3815 Ok(MultiTurnStreamItem::StreamAssistantItem(
3816 StreamedAssistantContent::ToolCallDelta { .. },
3817 )) => panic!("invalid tool-call delta should not be emitted"),
3818 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3819 tool_result,
3820 internal_call_id,
3821 })) => {
3822 assert_eq!(internal_call_id, "internal_1");
3823 skipped_tool_result = Some(tool_result);
3824 }
3825 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3826 final_response_text = Some(response.response().to_string());
3827 break;
3828 }
3829 Ok(_) => {}
3830 Err(err) => panic!("unexpected streaming error: {err:?}"),
3831 }
3832 }
3833
3834 let skipped_tool_result =
3835 skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3836 assert_eq!(skipped_tool_result.id, "tool_call_1");
3837 assert!(skipped_tool_result.call_id.is_none());
3838 assert!(skipped_tool_result.content.iter().any(|content| matches!(
3839 content,
3840 ToolResultContent::Text(text) if text.text == "default_api was skipped"
3841 )));
3842 assert_eq!(final_response_text.as_deref(), Some("continued"));
3843 assert!(delta_hook.observed().is_empty());
3844 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3845
3846 let requests = recorded.requests();
3847 assert_eq!(requests.len(), 2);
3848 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3849 assert!(matches!(
3850 follow_up_history.get(1),
3851 Some(Message::Assistant { content, .. })
3852 if content.iter().any(|item| matches!(
3853 item,
3854 AssistantContent::Text(text) if text.text == "checking "
3855 ))
3856 && content.iter().any(|item| matches!(
3857 item,
3858 AssistantContent::ToolCall(tool_call)
3859 if tool_call.id == "tool_call_0"
3860 && tool_call.function.name == "add"
3861 ))
3862 && content.iter().any(|item| matches!(
3863 item,
3864 AssistantContent::ToolCall(tool_call)
3865 if tool_call.id == "tool_call_1"
3866 && tool_call.function.name == "default_api"
3867 && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3868 ))
3869 ));
3870 assert!(matches!(
3871 follow_up_history.get(2),
3872 Some(Message::User { content })
3873 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3874 && content.iter().any(|item| matches!(
3875 item,
3876 UserContent::ToolResult(result)
3877 if result.id == "tool_call_0"
3878 && result.call_id.as_deref() == Some("call_0")
3879 && result.content.iter().any(|content| matches!(
3880 content,
3881 ToolResultContent::Text(text)
3882 if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3883 ))
3884 ))
3885 && content.iter().any(|item| matches!(
3886 item,
3887 UserContent::ToolResult(result)
3888 if result.id == "tool_call_1"
3889 && result.content.iter().any(|content| matches!(
3890 content,
3891 ToolResultContent::Text(text)
3892 if text.text == "default_api was skipped"
3893 ))
3894 ))
3895 ));
3896 }
3897
3898 #[tokio::test]
3899 async fn streaming_retry_budget_exhaustion_history_contains_invalid_tool_call() {
3900 let model = MockCompletionModel::from_stream_turns([
3901 vec![
3902 MockStreamEvent::tool_call(
3903 "tool_call_1",
3904 "default_api",
3905 serde_json::json!({"x": 1, "y": 2}),
3906 ),
3907 MockStreamEvent::final_response_with_total_tokens(4),
3908 ],
3909 vec![
3910 MockStreamEvent::text("should not be requested"),
3911 MockStreamEvent::final_response_with_total_tokens(6),
3912 ],
3913 ]);
3914 let recorded = model.clone();
3915 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3916
3917 let mut stream = agent
3918 .stream_prompt("use the tool")
3919 .with_hook(RetryDefaultApiHook)
3920 .multi_turn(3)
3921 .max_invalid_tool_call_retries(0)
3922 .await;
3923 let mut error = None;
3924
3925 while let Some(item) = stream.next().await {
3926 if let Err(err) = item {
3927 error = Some(err);
3928 break;
3929 }
3930 }
3931
3932 let error = error.expect("retry budget exhaustion should fail");
3933 match error {
3934 StreamingError::Prompt(err) => match *err {
3935 PromptError::UnknownToolCall {
3936 tool_name,
3937 chat_history,
3938 ..
3939 } => {
3940 assert_eq!(tool_name, "default_api");
3941 assert!(history_contains_tool_call(&chat_history, "default_api"));
3942 }
3943 other => panic!("expected UnknownToolCall, got {other:?}"),
3944 },
3945 other => panic!("expected prompt streaming error, got {other:?}"),
3946 }
3947 assert_eq!(recorded.request_count(), 1);
3948 }
3949
3950 #[tokio::test]
3951 async fn streaming_name_delta_retry_budget_exhaustion_history_includes_same_turn_context() {
3952 let model = MockCompletionModel::from_stream_turns([
3953 vec![
3954 MockStreamEvent::text("checking "),
3955 MockStreamEvent::tool_call(
3956 "tool_call_0",
3957 "add",
3958 serde_json::json!({"x": 1, "y": 2}),
3959 )
3960 .with_call_id("call_0"),
3961 MockStreamEvent::tool_call_arguments_delta(
3962 "tool_call_1",
3963 "internal_1",
3964 r#"{"x":2,"y":3}"#,
3965 ),
3966 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3967 MockStreamEvent::final_response_with_total_tokens(4),
3968 ],
3969 vec![
3970 MockStreamEvent::text("should not be requested"),
3971 MockStreamEvent::final_response_with_total_tokens(6),
3972 ],
3973 ]);
3974 let recorded = model.clone();
3975 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3976
3977 let mut stream = agent
3978 .stream_prompt("use the tool")
3979 .with_hook(RetryDefaultApiHook)
3980 .multi_turn(3)
3981 .max_invalid_tool_call_retries(0)
3982 .await;
3983 let mut error = None;
3984
3985 while let Some(item) = stream.next().await {
3986 if let Err(err) = item {
3987 error = Some(err);
3988 break;
3989 }
3990 }
3991
3992 let error = error.expect("retry budget exhaustion should fail");
3993 match error {
3994 StreamingError::Prompt(err) => match *err {
3995 PromptError::UnknownToolCall {
3996 tool_name,
3997 chat_history,
3998 ..
3999 } => {
4000 assert_eq!(tool_name, "default_api");
4001 assert!(history_contains_text(&chat_history, "checking "));
4002 assert!(history_contains_tool_call(&chat_history, "add"));
4003 assert!(history_contains_tool_call(&chat_history, "default_api"));
4004 }
4005 other => panic!("expected UnknownToolCall, got {other:?}"),
4006 },
4007 other => panic!("expected prompt streaming error, got {other:?}"),
4008 }
4009 assert_eq!(recorded.request_count(), 1);
4010 }
4011
4012 #[tokio::test]
4013 async fn completed_unknown_tool_call_after_text_fails_before_finish_hook_or_later_emit() {
4014 let add_calls = Arc::new(AtomicU32::new(0));
4015 let model = MockCompletionModel::from_stream_turns([
4016 vec![
4017 MockStreamEvent::text("thinking "),
4018 MockStreamEvent::tool_call(
4019 "tool_call_1",
4020 "default_api",
4021 serde_json::json!({"x": 1, "y": 2}),
4022 ),
4023 MockStreamEvent::final_response_with_total_tokens(4),
4024 ],
4025 vec![
4026 MockStreamEvent::text("should not be requested"),
4027 MockStreamEvent::final_response_with_total_tokens(6),
4028 ],
4029 ]);
4030 let recorded = model.clone();
4031 let agent = AgentBuilder::new(model)
4032 .tool(CountingAddTool {
4033 calls: add_calls.clone(),
4034 })
4035 .build();
4036
4037 let mut stream = agent
4038 .stream_prompt("use the tool")
4039 .with_hook(PanicOnUnknownToolHook)
4040 .multi_turn(3)
4041 .await;
4042 let mut saw_text = false;
4043 let mut saw_completion_call = false;
4044 let mut saw_final_response = false;
4045 let mut saw_tool_call = false;
4046 let mut saw_tool_result = false;
4047 let mut error = None;
4048
4049 while let Some(item) = stream.next().await {
4050 match item {
4051 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(_))) => {
4052 saw_text = true;
4053 }
4054 Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4055 saw_completion_call = true;
4056 }
4057 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Final(
4058 _,
4059 )))
4060 | Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4061 saw_final_response = true;
4062 }
4063 Ok(MultiTurnStreamItem::StreamAssistantItem(
4064 StreamedAssistantContent::ToolCall { .. },
4065 )) => {
4066 saw_tool_call = true;
4067 }
4068 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4069 ..
4070 })) => {
4071 saw_tool_result = true;
4072 }
4073 Ok(_) => {}
4074 Err(err) => {
4075 error = Some(err);
4076 break;
4077 }
4078 }
4079 }
4080
4081 assert!(saw_text);
4082 assert!(!saw_completion_call);
4083 assert!(!saw_final_response);
4084 assert!(!saw_tool_call);
4085 assert!(!saw_tool_result);
4086 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
4087 let error = error.expect("completed unknown tool call should fail immediately");
4088 match error {
4089 StreamingError::Prompt(err) => match *err {
4090 PromptError::UnknownToolCall {
4091 tool_name,
4092 available_tools,
4093 allowed_tools,
4094 chat_history,
4095 } => {
4096 assert_eq!(tool_name, "default_api");
4097 assert_eq!(available_tools, vec!["add".to_string()]);
4098 assert_eq!(allowed_tools, vec!["add".to_string()]);
4099 assert!(history_contains_tool_call(&chat_history, "default_api"));
4100 }
4101 other => panic!("expected UnknownToolCall, got {other:?}"),
4102 },
4103 other => panic!("expected prompt streaming error, got {other:?}"),
4104 }
4105 assert_eq!(recorded.request_count(), 1);
4106 }
4107
4108 #[tokio::test]
4109 async fn mixed_streaming_tool_calls_fail_before_any_tool_execution() {
4110 let add_calls = Arc::new(AtomicU32::new(0));
4111 let model = MockCompletionModel::from_stream_turns([
4112 vec![
4113 MockStreamEvent::tool_call(
4114 "tool_call_1",
4115 "add",
4116 serde_json::json!({"x": 1, "y": 2}),
4117 )
4118 .with_call_id("call_1"),
4119 MockStreamEvent::tool_call(
4120 "tool_call_2",
4121 "default_api",
4122 serde_json::json!({"x": 3, "y": 4}),
4123 ),
4124 MockStreamEvent::final_response_with_total_tokens(4),
4125 ],
4126 vec![
4127 MockStreamEvent::text("should not be requested"),
4128 MockStreamEvent::final_response_with_total_tokens(6),
4129 ],
4130 ]);
4131 let recorded = model.clone();
4132 let agent = AgentBuilder::new(model)
4133 .tool(CountingAddTool {
4134 calls: add_calls.clone(),
4135 })
4136 .build();
4137
4138 let mut stream = agent
4139 .stream_prompt("use tools")
4140 .with_hook(PanicOnUnknownToolHook)
4141 .multi_turn(3)
4142 .await;
4143 let mut saw_completion_call = false;
4144 let mut saw_tool_call = false;
4145 let mut saw_tool_result = false;
4146 let mut error = None;
4147
4148 while let Some(item) = stream.next().await {
4149 match item {
4150 Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4151 saw_completion_call = true;
4152 }
4153 Ok(MultiTurnStreamItem::StreamAssistantItem(
4154 StreamedAssistantContent::ToolCall { .. },
4155 )) => {
4156 saw_tool_call = true;
4157 }
4158 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4159 ..
4160 })) => {
4161 saw_tool_result = true;
4162 }
4163 Ok(_) => {}
4164 Err(err) => {
4165 error = Some(err);
4166 break;
4167 }
4168 }
4169 }
4170
4171 assert!(!saw_completion_call);
4172 assert!(!saw_tool_call);
4173 assert!(!saw_tool_result);
4174 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
4175 let error = error.expect("mixed unknown streamed tool call should fail");
4176 match error {
4177 StreamingError::Prompt(err) => match *err {
4178 PromptError::UnknownToolCall {
4179 tool_name,
4180 available_tools,
4181 allowed_tools,
4182 chat_history,
4183 } => {
4184 assert_eq!(tool_name, "default_api");
4185 assert_eq!(available_tools, vec!["add".to_string()]);
4186 assert_eq!(allowed_tools, vec!["add".to_string()]);
4187 assert!(history_contains_tool_call(&chat_history, "default_api"));
4188 }
4189 other => panic!("expected UnknownToolCall, got {other:?}"),
4190 },
4191 other => panic!("expected prompt streaming error, got {other:?}"),
4192 }
4193 assert_eq!(recorded.request_count(), 1);
4194 }
4195
4196 #[tokio::test]
4197 async fn multiple_valid_streaming_tool_calls_execute_after_batch_validation() {
4198 let add_calls = Arc::new(AtomicU32::new(0));
4199 let subtract_calls = Arc::new(AtomicU32::new(0));
4200 let model = MockCompletionModel::from_stream_turns([
4201 vec![
4202 MockStreamEvent::tool_call(
4203 "tool_call_1",
4204 "add",
4205 serde_json::json!({"x": 1, "y": 2}),
4206 )
4207 .with_call_id("call_1"),
4208 MockStreamEvent::tool_call(
4209 "tool_call_2",
4210 "subtract",
4211 serde_json::json!({"x": 8, "y": 3}),
4212 )
4213 .with_call_id("call_2"),
4214 MockStreamEvent::final_response_with_total_tokens(4),
4215 ],
4216 vec![
4217 MockStreamEvent::text("done"),
4218 MockStreamEvent::final_response_with_total_tokens(6),
4219 ],
4220 ]);
4221 let recorded = model.clone();
4222 let agent = AgentBuilder::new(model)
4223 .tool(CountingAddTool {
4224 calls: add_calls.clone(),
4225 })
4226 .tool(CountingSubtractTool {
4227 calls: subtract_calls.clone(),
4228 })
4229 .build();
4230
4231 let mut stream = agent.stream_prompt("use tools").multi_turn(3).await;
4232 let mut tool_call_names = Vec::new();
4233 let mut tool_result_ids = Vec::new();
4234 let mut final_response_text = None;
4235
4236 while let Some(item) = stream.next().await {
4237 match item {
4238 Ok(MultiTurnStreamItem::StreamAssistantItem(
4239 StreamedAssistantContent::ToolCall { tool_call, .. },
4240 )) => {
4241 tool_call_names.push(tool_call.function.name);
4242 }
4243 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4244 tool_result,
4245 ..
4246 })) => {
4247 tool_result_ids.push(tool_result.id);
4248 }
4249 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4250 final_response_text = Some(response.response().to_owned());
4251 break;
4252 }
4253 Ok(_) => {}
4254 Err(err) => panic!("unexpected streaming error: {err:?}"),
4255 }
4256 }
4257
4258 assert_eq!(
4259 tool_call_names,
4260 vec!["add".to_string(), "subtract".to_string()]
4261 );
4262 assert_eq!(
4263 tool_result_ids,
4264 vec!["tool_call_1".to_string(), "tool_call_2".to_string()]
4265 );
4266 assert_eq!(add_calls.load(Ordering::SeqCst), 1);
4267 assert_eq!(subtract_calls.load(Ordering::SeqCst), 1);
4268 assert_eq!(final_response_text.as_deref(), Some("done"));
4269 assert_eq!(recorded.request_count(), 2);
4270 }
4271
4272 #[tokio::test]
4273 async fn disallowed_specific_tool_call_fails_before_streaming_second_request() {
4274 let model = MockCompletionModel::from_stream_turns([
4275 vec![
4276 MockStreamEvent::tool_call(
4277 "tool_call_1",
4278 "subtract",
4279 serde_json::json!({"x": 3, "y": 1}),
4280 ),
4281 MockStreamEvent::final_response_with_total_tokens(4),
4282 ],
4283 vec![
4284 MockStreamEvent::text("should not be requested"),
4285 MockStreamEvent::final_response_with_total_tokens(6),
4286 ],
4287 ]);
4288 let recorded = model.clone();
4289 let agent = AgentBuilder::new(model)
4290 .tool(MockAddTool)
4291 .tool(MockSubtractTool)
4292 .tool_choice(ToolChoice::Specific {
4293 function_names: vec!["add".to_string()],
4294 })
4295 .build();
4296
4297 let mut stream = agent
4298 .stream_prompt("use the allowed tool")
4299 .with_hook(PanicOnUnknownToolHook)
4300 .multi_turn(3)
4301 .await;
4302 let mut saw_tool_call = false;
4303 let mut error = None;
4304
4305 while let Some(item) = stream.next().await {
4306 match item {
4307 Ok(MultiTurnStreamItem::StreamAssistantItem(
4308 StreamedAssistantContent::ToolCall { .. },
4309 )) => {
4310 saw_tool_call = true;
4311 }
4312 Ok(_) => {}
4313 Err(err) => {
4314 error = Some(err);
4315 break;
4316 }
4317 }
4318 }
4319
4320 assert!(!saw_tool_call);
4321 let error = error.expect("disallowed model-emitted tool should fail");
4322 match error {
4323 StreamingError::Prompt(err) => match *err {
4324 PromptError::UnknownToolCall {
4325 tool_name,
4326 available_tools,
4327 allowed_tools,
4328 chat_history,
4329 } => {
4330 assert_eq!(tool_name, "subtract");
4331 assert_eq!(
4332 available_tools,
4333 vec!["add".to_string(), "subtract".to_string()]
4334 );
4335 assert_eq!(allowed_tools, vec!["add".to_string()]);
4336 assert!(history_contains_tool_call(&chat_history, "subtract"));
4337 }
4338 other => panic!("expected UnknownToolCall, got {other:?}"),
4339 },
4340 other => panic!("expected prompt streaming error, got {other:?}"),
4341 }
4342 assert_eq!(recorded.request_count(), 1);
4343 }
4344
4345 #[tokio::test]
4346 async fn mixed_specific_tool_calls_fail_before_any_tool_execution() {
4347 let add_calls = Arc::new(AtomicU32::new(0));
4348 let model = MockCompletionModel::from_stream_turns([
4349 vec![
4350 MockStreamEvent::tool_call(
4351 "tool_call_1",
4352 "add",
4353 serde_json::json!({"x": 1, "y": 2}),
4354 ),
4355 MockStreamEvent::tool_call(
4356 "tool_call_2",
4357 "subtract",
4358 serde_json::json!({"x": 3, "y": 1}),
4359 ),
4360 MockStreamEvent::final_response_with_total_tokens(4),
4361 ],
4362 vec![
4363 MockStreamEvent::text("should not be requested"),
4364 MockStreamEvent::final_response_with_total_tokens(6),
4365 ],
4366 ]);
4367 let recorded = model.clone();
4368 let agent = AgentBuilder::new(model)
4369 .tool(CountingAddTool {
4370 calls: add_calls.clone(),
4371 })
4372 .tool(MockSubtractTool)
4373 .tool_choice(ToolChoice::Specific {
4374 function_names: vec!["add".to_string()],
4375 })
4376 .build();
4377
4378 let mut stream = agent
4379 .stream_prompt("use the allowed tool")
4380 .with_hook(PanicOnUnknownToolHook)
4381 .multi_turn(3)
4382 .await;
4383 let mut saw_tool_call = false;
4384 let mut saw_tool_result = false;
4385 let mut error = None;
4386
4387 while let Some(item) = stream.next().await {
4388 match item {
4389 Ok(MultiTurnStreamItem::StreamAssistantItem(
4390 StreamedAssistantContent::ToolCall { .. },
4391 )) => {
4392 saw_tool_call = true;
4393 }
4394 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4395 ..
4396 })) => {
4397 saw_tool_result = true;
4398 }
4399 Ok(_) => {}
4400 Err(err) => {
4401 error = Some(err);
4402 break;
4403 }
4404 }
4405 }
4406
4407 assert!(!saw_tool_call);
4408 assert!(!saw_tool_result);
4409 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
4410 let error = error.expect("mixed disallowed streamed tool call should fail");
4411 match error {
4412 StreamingError::Prompt(err) => match *err {
4413 PromptError::UnknownToolCall {
4414 tool_name,
4415 available_tools,
4416 allowed_tools,
4417 chat_history,
4418 } => {
4419 assert_eq!(tool_name, "subtract");
4420 assert_eq!(
4421 available_tools,
4422 vec!["add".to_string(), "subtract".to_string()]
4423 );
4424 assert_eq!(allowed_tools, vec!["add".to_string()]);
4425 assert!(history_contains_tool_call(&chat_history, "subtract"));
4426 }
4427 other => panic!("expected UnknownToolCall, got {other:?}"),
4428 },
4429 other => panic!("expected prompt streaming error, got {other:?}"),
4430 }
4431 assert_eq!(recorded.request_count(), 1);
4432 }
4433
4434 #[tokio::test]
4435 async fn tool_choice_none_rejects_streaming_tool_call() {
4436 let model = MockCompletionModel::from_stream_turns([
4437 vec![
4438 MockStreamEvent::tool_call(
4439 "tool_call_1",
4440 "add",
4441 serde_json::json!({"x": 1, "y": 2}),
4442 ),
4443 MockStreamEvent::final_response_with_total_tokens(4),
4444 ],
4445 vec![
4446 MockStreamEvent::text("should not be requested"),
4447 MockStreamEvent::final_response_with_total_tokens(6),
4448 ],
4449 ]);
4450 let recorded = model.clone();
4451 let agent = AgentBuilder::new(model)
4452 .tool(MockAddTool)
4453 .tool_choice(ToolChoice::None)
4454 .build();
4455
4456 let mut stream = agent
4457 .stream_prompt("do not use tools")
4458 .with_hook(PanicOnUnknownToolHook)
4459 .multi_turn(3)
4460 .await;
4461 let mut saw_tool_call = false;
4462 let mut error = None;
4463
4464 while let Some(item) = stream.next().await {
4465 match item {
4466 Ok(MultiTurnStreamItem::StreamAssistantItem(
4467 StreamedAssistantContent::ToolCall { .. },
4468 )) => {
4469 saw_tool_call = true;
4470 }
4471 Ok(_) => {}
4472 Err(err) => {
4473 error = Some(err);
4474 break;
4475 }
4476 }
4477 }
4478
4479 assert!(!saw_tool_call);
4480 let error = error.expect("ToolChoice::None should reject returned tool calls");
4481 match error {
4482 StreamingError::Prompt(err) => match *err {
4483 PromptError::UnknownToolCall {
4484 tool_name,
4485 available_tools,
4486 allowed_tools,
4487 chat_history,
4488 } => {
4489 assert_eq!(tool_name, "add");
4490 assert_eq!(available_tools, vec!["add".to_string()]);
4491 assert!(allowed_tools.is_empty());
4492 assert!(history_contains_tool_call(&chat_history, "add"));
4493 }
4494 other => panic!("expected UnknownToolCall, got {other:?}"),
4495 },
4496 other => panic!("expected prompt streaming error, got {other:?}"),
4497 }
4498 assert_eq!(recorded.request_count(), 1);
4499 }
4500
4501 #[tokio::test]
4502 async fn tool_choice_none_rejects_streaming_tool_call_name_delta_before_hook_or_emit() {
4503 let model = MockCompletionModel::from_stream_turns([
4504 vec![
4505 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4506 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4507 MockStreamEvent::final_response_with_total_tokens(4),
4508 ],
4509 vec![
4510 MockStreamEvent::text("should not be requested"),
4511 MockStreamEvent::final_response_with_total_tokens(6),
4512 ],
4513 ]);
4514 let recorded = model.clone();
4515 let agent = AgentBuilder::new(model)
4516 .tool(MockAddTool)
4517 .tool_choice(ToolChoice::None)
4518 .build();
4519
4520 let mut stream = agent
4521 .stream_prompt("do not use tools")
4522 .with_hook(PanicOnUnknownToolHook)
4523 .multi_turn(3)
4524 .await;
4525 let mut saw_delta = false;
4526 let mut error = None;
4527
4528 while let Some(item) = stream.next().await {
4529 match item {
4530 Ok(MultiTurnStreamItem::StreamAssistantItem(
4531 StreamedAssistantContent::ToolCallDelta { .. },
4532 )) => {
4533 saw_delta = true;
4534 }
4535 Ok(_) => {}
4536 Err(err) => {
4537 error = Some(err);
4538 break;
4539 }
4540 }
4541 }
4542
4543 assert!(!saw_delta);
4544 let error = error.expect("ToolChoice::None should reject returned tool-call deltas");
4545 match error {
4546 StreamingError::Prompt(err) => match *err {
4547 PromptError::UnknownToolCall {
4548 tool_name,
4549 available_tools,
4550 allowed_tools,
4551 chat_history,
4552 } => {
4553 assert_eq!(tool_name, "add");
4554 assert_eq!(available_tools, vec!["add".to_string()]);
4555 assert!(allowed_tools.is_empty());
4556 assert!(history_contains_tool_call(&chat_history, "add"));
4557 }
4558 other => panic!("expected UnknownToolCall, got {other:?}"),
4559 },
4560 other => panic!("expected prompt streaming error, got {other:?}"),
4561 }
4562 assert_eq!(recorded.request_count(), 1);
4563 }
4564
4565 #[tokio::test]
4566 async fn unknown_tool_call_name_delta_fails_before_streaming_delta_hook_or_emit() {
4567 let model = MockCompletionModel::from_stream_turns([
4568 vec![
4569 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4570 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4571 MockStreamEvent::final_response_with_total_tokens(4),
4572 ],
4573 vec![
4574 MockStreamEvent::text("should not be requested"),
4575 MockStreamEvent::final_response_with_total_tokens(6),
4576 ],
4577 ]);
4578 let recorded = model.clone();
4579 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4580
4581 let mut stream = agent
4582 .stream_prompt("stream a bad tool call")
4583 .with_hook(PanicOnUnknownToolHook)
4584 .multi_turn(3)
4585 .await;
4586 let mut saw_delta = false;
4587 let mut error = None;
4588
4589 while let Some(item) = stream.next().await {
4590 match item {
4591 Ok(MultiTurnStreamItem::StreamAssistantItem(
4592 StreamedAssistantContent::ToolCallDelta { .. },
4593 )) => {
4594 saw_delta = true;
4595 }
4596 Ok(_) => {}
4597 Err(err) => {
4598 error = Some(err);
4599 break;
4600 }
4601 }
4602 }
4603
4604 assert!(!saw_delta);
4605 let error = error.expect("unknown tool-call name delta should fail");
4606 match error {
4607 StreamingError::Prompt(err) => match *err {
4608 PromptError::UnknownToolCall {
4609 tool_name,
4610 available_tools,
4611 allowed_tools,
4612 chat_history,
4613 } => {
4614 assert_eq!(tool_name, "default_api");
4615 assert_eq!(available_tools, vec!["add".to_string()]);
4616 assert_eq!(allowed_tools, vec!["add".to_string()]);
4617 assert!(history_contains_tool_call(&chat_history, "default_api"));
4618 }
4619 other => panic!("expected UnknownToolCall, got {other:?}"),
4620 },
4621 other => panic!("expected prompt streaming error, got {other:?}"),
4622 }
4623 assert_eq!(recorded.request_count(), 1);
4624 }
4625
4626 #[tokio::test]
4627 async fn tool_call_args_delta_before_unknown_name_fails_before_hook_or_emit() {
4628 let model = MockCompletionModel::from_stream_turns([
4629 vec![
4630 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4631 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4632 MockStreamEvent::final_response_with_total_tokens(4),
4633 ],
4634 vec![
4635 MockStreamEvent::text("should not be requested"),
4636 MockStreamEvent::final_response_with_total_tokens(6),
4637 ],
4638 ]);
4639 let recorded = model.clone();
4640 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4641
4642 let mut stream = agent
4643 .stream_prompt("stream a bad tool call")
4644 .with_hook(PanicOnUnknownToolHook)
4645 .multi_turn(3)
4646 .await;
4647 let mut saw_delta = false;
4648 let mut error = None;
4649
4650 while let Some(item) = stream.next().await {
4651 match item {
4652 Ok(MultiTurnStreamItem::StreamAssistantItem(
4653 StreamedAssistantContent::ToolCallDelta { .. },
4654 )) => {
4655 saw_delta = true;
4656 }
4657 Ok(_) => {}
4658 Err(err) => {
4659 error = Some(err);
4660 break;
4661 }
4662 }
4663 }
4664
4665 assert!(!saw_delta);
4666 let error = error.expect("unknown tool-call name should reject buffered args");
4667 match error {
4668 StreamingError::Prompt(err) => match *err {
4669 PromptError::UnknownToolCall {
4670 tool_name,
4671 available_tools,
4672 allowed_tools,
4673 chat_history,
4674 } => {
4675 assert_eq!(tool_name, "default_api");
4676 assert_eq!(available_tools, vec!["add".to_string()]);
4677 assert_eq!(allowed_tools, vec!["add".to_string()]);
4678 assert!(history_contains_tool_call(&chat_history, "default_api"));
4679 }
4680 other => panic!("expected UnknownToolCall, got {other:?}"),
4681 },
4682 other => panic!("expected prompt streaming error, got {other:?}"),
4683 }
4684 assert_eq!(recorded.request_count(), 1);
4685 }
4686
4687 #[tokio::test]
4688 async fn tool_call_args_delta_before_valid_name_buffers_then_emits_in_safe_order() {
4689 let model = MockCompletionModel::from_stream_turns([[
4690 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4691 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4692 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4693 MockStreamEvent::final_response_with_total_tokens(3),
4694 ]]);
4695 let hook = RecordingToolCallDeltaHook::default();
4696 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4697
4698 let mut stream = agent
4699 .stream_prompt("stream a tool call")
4700 .with_hook(hook.clone())
4701 .await;
4702 let mut stream_deltas = Vec::new();
4703
4704 while let Some(item) = stream.next().await {
4705 match item {
4706 Ok(MultiTurnStreamItem::StreamAssistantItem(
4707 StreamedAssistantContent::ToolCallDelta {
4708 id,
4709 internal_call_id,
4710 content,
4711 },
4712 )) => {
4713 stream_deltas.push((id, internal_call_id, content));
4714 }
4715 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4716 Ok(_) => {}
4717 Err(err) => panic!("unexpected streaming error: {err:?}"),
4718 }
4719 }
4720
4721 assert_eq!(
4722 hook.observed(),
4723 vec![
4724 (
4725 "tool_1".to_string(),
4726 "internal_1".to_string(),
4727 Some("add".to_string()),
4728 String::new()
4729 ),
4730 (
4731 "tool_1".to_string(),
4732 "internal_1".to_string(),
4733 None,
4734 "{\"x\":".to_string()
4735 ),
4736 (
4737 "tool_1".to_string(),
4738 "internal_1".to_string(),
4739 None,
4740 "1}".to_string()
4741 ),
4742 ]
4743 );
4744 assert_eq!(
4745 stream_deltas,
4746 vec![
4747 (
4748 "tool_1".to_string(),
4749 "internal_1".to_string(),
4750 ToolCallDeltaContent::Name("add".to_string())
4751 ),
4752 (
4753 "tool_1".to_string(),
4754 "internal_1".to_string(),
4755 ToolCallDeltaContent::Delta("{\"x\":".to_string())
4756 ),
4757 (
4758 "tool_1".to_string(),
4759 "internal_1".to_string(),
4760 ToolCallDeltaContent::Delta("1}".to_string())
4761 ),
4762 ]
4763 );
4764 }
4765
4766 #[tokio::test]
4767 async fn tool_call_args_delta_without_name_errors_at_stream_end() {
4768 let model = MockCompletionModel::from_stream_turns([
4769 vec![
4770 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4771 MockStreamEvent::final_response_with_total_tokens(4),
4772 ],
4773 vec![
4774 MockStreamEvent::text("should not be requested"),
4775 MockStreamEvent::final_response_with_total_tokens(6),
4776 ],
4777 ]);
4778 let recorded = model.clone();
4779 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4780
4781 let mut stream = agent
4782 .stream_prompt("stream an incomplete tool call")
4783 .with_hook(PanicOnUnknownToolHook)
4784 .multi_turn(3)
4785 .await;
4786 let mut saw_delta = false;
4787 let mut saw_completion_call = false;
4788 let mut saw_final_response = false;
4789 let mut error = None;
4790
4791 while let Some(item) = stream.next().await {
4792 match item {
4793 Ok(MultiTurnStreamItem::StreamAssistantItem(
4794 StreamedAssistantContent::ToolCallDelta { .. },
4795 )) => {
4796 saw_delta = true;
4797 }
4798 Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4799 saw_completion_call = true;
4800 }
4801 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4802 saw_final_response = true;
4803 }
4804 Ok(_) => {}
4805 Err(err) => {
4806 error = Some(err);
4807 break;
4808 }
4809 }
4810 }
4811
4812 assert!(!saw_delta);
4813 assert!(!saw_completion_call);
4814 assert!(!saw_final_response);
4815 let error = error.expect("unterminated tool-call args delta should fail");
4816 match error {
4817 StreamingError::Completion(CompletionError::ResponseError(message)) => {
4818 assert!(
4819 message.contains("streamed tool call arguments"),
4820 "{message}"
4821 );
4822 assert!(message.contains("tool_1"), "{message}");
4823 assert!(message.contains("internal_1"), "{message}");
4824 }
4825 other => panic!("expected completion response error, got {other:?}"),
4826 }
4827 assert_eq!(recorded.request_count(), 1);
4828 }
4829
4830 #[tokio::test]
4831 async fn tool_choice_none_buffers_args_then_rejects_name_without_emit() {
4832 let model = MockCompletionModel::from_stream_turns([
4833 vec![
4834 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4835 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4836 MockStreamEvent::final_response_with_total_tokens(4),
4837 ],
4838 vec![
4839 MockStreamEvent::text("should not be requested"),
4840 MockStreamEvent::final_response_with_total_tokens(6),
4841 ],
4842 ]);
4843 let recorded = model.clone();
4844 let agent = AgentBuilder::new(model)
4845 .tool(MockAddTool)
4846 .tool_choice(ToolChoice::None)
4847 .build();
4848
4849 let mut stream = agent
4850 .stream_prompt("do not use tools")
4851 .with_hook(PanicOnUnknownToolHook)
4852 .multi_turn(3)
4853 .await;
4854 let mut saw_delta = false;
4855 let mut error = None;
4856
4857 while let Some(item) = stream.next().await {
4858 match item {
4859 Ok(MultiTurnStreamItem::StreamAssistantItem(
4860 StreamedAssistantContent::ToolCallDelta { .. },
4861 )) => {
4862 saw_delta = true;
4863 }
4864 Ok(_) => {}
4865 Err(err) => {
4866 error = Some(err);
4867 break;
4868 }
4869 }
4870 }
4871
4872 assert!(!saw_delta);
4873 let error = error.expect("ToolChoice::None should reject buffered tool-call deltas");
4874 match error {
4875 StreamingError::Prompt(err) => match *err {
4876 PromptError::UnknownToolCall {
4877 tool_name,
4878 available_tools,
4879 allowed_tools,
4880 chat_history,
4881 } => {
4882 assert_eq!(tool_name, "add");
4883 assert_eq!(available_tools, vec!["add".to_string()]);
4884 assert!(allowed_tools.is_empty());
4885 assert!(history_contains_tool_call(&chat_history, "add"));
4886 }
4887 other => panic!("expected UnknownToolCall, got {other:?}"),
4888 },
4889 other => panic!("expected prompt streaming error, got {other:?}"),
4890 }
4891 assert_eq!(recorded.request_count(), 1);
4892 }
4893
4894 #[tokio::test]
4895 async fn stream_prompt_emits_tool_call_deltas_without_hook() {
4896 let model = MockCompletionModel::from_stream_turns([[
4897 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4898 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4899 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4900 MockStreamEvent::final_response_with_total_tokens(3),
4901 ]]);
4902 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4903
4904 let mut stream = agent.stream_prompt("stream a tool call").await;
4905 let mut deltas = Vec::new();
4906
4907 while let Some(item) = stream.next().await {
4908 match item {
4909 Ok(MultiTurnStreamItem::StreamAssistantItem(
4910 StreamedAssistantContent::ToolCallDelta {
4911 id,
4912 internal_call_id,
4913 content,
4914 },
4915 )) => {
4916 deltas.push((id, internal_call_id, content));
4917 }
4918 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4919 Ok(_) => {}
4920 Err(err) => panic!("unexpected streaming error: {err:?}"),
4921 }
4922 }
4923
4924 assert_eq!(
4925 deltas,
4926 vec![
4927 (
4928 "tool_1".to_string(),
4929 "internal_1".to_string(),
4930 ToolCallDeltaContent::Name("add".to_string())
4931 ),
4932 (
4933 "tool_1".to_string(),
4934 "internal_1".to_string(),
4935 ToolCallDeltaContent::Delta("{\"x\":".to_string())
4936 ),
4937 (
4938 "tool_1".to_string(),
4939 "internal_1".to_string(),
4940 ToolCallDeltaContent::Delta("1}".to_string())
4941 ),
4942 ]
4943 );
4944 }
4945
4946 #[tokio::test]
4947 async fn stream_prompt_emits_tool_call_deltas_after_hook_continue() {
4948 let model = MockCompletionModel::from_stream_turns([[
4949 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4950 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4951 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4952 MockStreamEvent::final_response_with_total_tokens(3),
4953 ]]);
4954 let hook = RecordingToolCallDeltaHook::default();
4955 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4956
4957 let mut stream = agent
4958 .stream_prompt("stream a tool call")
4959 .with_hook(hook.clone())
4960 .await;
4961 let mut stream_deltas = Vec::new();
4962
4963 while let Some(item) = stream.next().await {
4964 match item {
4965 Ok(MultiTurnStreamItem::StreamAssistantItem(
4966 StreamedAssistantContent::ToolCallDelta {
4967 id,
4968 internal_call_id,
4969 content,
4970 },
4971 )) => {
4972 stream_deltas.push((id, internal_call_id, content));
4973 }
4974 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4975 Ok(_) => {}
4976 Err(err) => panic!("unexpected streaming error: {err:?}"),
4977 }
4978 }
4979
4980 assert_eq!(
4981 hook.observed(),
4982 vec![
4983 (
4984 "tool_1".to_string(),
4985 "internal_1".to_string(),
4986 Some("add".to_string()),
4987 String::new()
4988 ),
4989 (
4990 "tool_1".to_string(),
4991 "internal_1".to_string(),
4992 None,
4993 "{\"x\":".to_string()
4994 ),
4995 (
4996 "tool_1".to_string(),
4997 "internal_1".to_string(),
4998 None,
4999 "1}".to_string()
5000 ),
5001 ]
5002 );
5003 assert_eq!(
5004 stream_deltas,
5005 vec![
5006 (
5007 "tool_1".to_string(),
5008 "internal_1".to_string(),
5009 ToolCallDeltaContent::Name("add".to_string())
5010 ),
5011 (
5012 "tool_1".to_string(),
5013 "internal_1".to_string(),
5014 ToolCallDeltaContent::Delta("{\"x\":".to_string())
5015 ),
5016 (
5017 "tool_1".to_string(),
5018 "internal_1".to_string(),
5019 ToolCallDeltaContent::Delta("1}".to_string())
5020 ),
5021 ]
5022 );
5023 }
5024
5025 #[tokio::test]
5026 async fn stream_prompt_tool_call_deltas_hook_termination_prevents_delta_emit() {
5027 let model = MockCompletionModel::from_stream_turns([[
5028 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
5029 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
5030 MockStreamEvent::final_response_with_total_tokens(3),
5031 ]]);
5032 let hook = TerminatingToolCallDeltaHook::default();
5033 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5034
5035 let mut stream = agent
5036 .stream_prompt("stream a tool call")
5037 .with_hook(hook.clone())
5038 .await;
5039 let mut saw_delta = false;
5040 let mut saw_final_response = false;
5041 let mut error_message = None;
5042
5043 while let Some(item) = stream.next().await {
5044 match item {
5045 Ok(MultiTurnStreamItem::StreamAssistantItem(
5046 StreamedAssistantContent::ToolCallDelta { .. },
5047 )) => {
5048 saw_delta = true;
5049 }
5050 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
5051 saw_final_response = true;
5052 }
5053 Ok(_) => {}
5054 Err(err) => {
5055 error_message = Some(err.to_string());
5056 break;
5057 }
5058 }
5059 }
5060
5061 assert_eq!(
5062 hook.observed(),
5063 vec![(
5064 "tool_1".to_string(),
5065 "internal_1".to_string(),
5066 Some("add".to_string()),
5067 String::new()
5068 )]
5069 );
5070 assert!(!saw_delta);
5071 assert!(!saw_final_response);
5072 assert!(
5073 error_message
5074 .as_deref()
5075 .is_some_and(|message| message.contains("PromptCancelled: stop on tool call delta")),
5076 "expected hook termination error, got {error_message:?}"
5077 );
5078 }
5079
5080 #[tokio::test]
5081 async fn stream_prompt_exposes_completion_calls() {
5082 let first_call_usage = usage(10, 2);
5083 let second_call_usage = usage(25, 5);
5084 let model = MockCompletionModel::from_stream_turns([
5085 vec![
5086 MockStreamEvent::tool_call(
5087 "tool_call_1",
5088 "add",
5089 serde_json::json!({"x": 1, "y": 2}),
5090 )
5091 .with_call_id("call_1"),
5092 MockStreamEvent::final_response(first_call_usage),
5093 ],
5094 vec![
5095 MockStreamEvent::text("done"),
5096 MockStreamEvent::final_response(second_call_usage),
5097 ],
5098 ]);
5099 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5100 let empty_history: &[Message] = &[];
5101
5102 let mut stream = agent
5103 .stream_prompt("do tool work")
5104 .with_history(empty_history)
5105 .multi_turn(3)
5106 .await;
5107 let mut completion_calls_events = Vec::new();
5108 let mut final_response = None;
5109
5110 while let Some(item) = stream.next().await {
5111 match item {
5112 Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
5113 completion_calls_events.push(call_usage);
5114 }
5115 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
5116 final_response = Some(response);
5117 break;
5118 }
5119 Ok(_) => {}
5120 Err(err) => panic!("unexpected streaming error: {err:?}"),
5121 }
5122 }
5123
5124 assert_eq!(
5125 completion_calls_events,
5126 vec![
5127 CompletionCall::new(0, Some(first_call_usage)),
5128 CompletionCall::new(1, Some(second_call_usage))
5129 ]
5130 );
5131
5132 let final_response = final_response.expect("expected final response");
5133 assert_eq!(
5134 final_response.usage(),
5135 Usage {
5136 input_tokens: 35,
5137 output_tokens: 7,
5138 total_tokens: 42,
5139 cached_input_tokens: 0,
5140 cache_creation_input_tokens: 0,
5141 tool_use_prompt_tokens: 0,
5142 reasoning_tokens: 0,
5143 }
5144 );
5145 assert_eq!(
5146 final_response.completion_calls(),
5147 &[
5148 CompletionCall::new(0, Some(first_call_usage)),
5149 CompletionCall::new(1, Some(second_call_usage))
5150 ]
5151 );
5152 }
5153
5154 #[tokio::test(flavor = "current_thread")]
5155 async fn stream_prompt_records_single_call_usage_on_chat_span_under_outer_span() {
5156 let call_usage = usage(10, 2);
5157 let model = MockCompletionModel::from_stream_turns([[
5158 MockStreamEvent::text("done"),
5159 MockStreamEvent::final_response(call_usage),
5160 ]]);
5161 let agent = AgentBuilder::new(model).build();
5162
5163 assert_stream_usage_recorded_on_chat_spans(agent, "say done", 1, &[call_usage]).await;
5164 }
5165
5166 #[tokio::test(flavor = "current_thread")]
5167 async fn stream_prompt_records_multi_turn_usage_on_chat_spans_under_outer_span() {
5168 let first_call_usage = usage(10, 2);
5169 let second_call_usage = usage(25, 5);
5170 let model = MockCompletionModel::from_stream_turns([
5171 vec![
5172 MockStreamEvent::tool_call(
5173 "tool_call_1",
5174 "add",
5175 serde_json::json!({"x": 1, "y": 2}),
5176 )
5177 .with_call_id("call_1"),
5178 MockStreamEvent::final_response(first_call_usage),
5179 ],
5180 vec![
5181 MockStreamEvent::text("done"),
5182 MockStreamEvent::final_response(second_call_usage),
5183 ],
5184 ]);
5185 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5186
5187 assert_stream_usage_recorded_on_chat_spans(
5188 agent,
5189 "do tool work",
5190 3,
5191 &[first_call_usage, second_call_usage],
5192 )
5193 .await;
5194 }
5195
5196 #[tokio::test]
5197 async fn stream_prompt_emits_completion_call_before_finish_hook_termination() {
5198 let call_usage = usage(10, 2);
5199 let model = MockCompletionModel::from_stream_turns([[
5200 MockStreamEvent::text("done"),
5201 MockStreamEvent::final_response(call_usage),
5202 ]]);
5203 let agent = AgentBuilder::new(model).build();
5204
5205 let mut stream = agent
5206 .stream_prompt("say done")
5207 .with_hook(TerminateOnStreamFinish)
5208 .await;
5209 let mut completion_calls = Vec::new();
5210 let mut saw_error = false;
5211
5212 while let Some(item) = stream.next().await {
5213 match item {
5214 Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
5215 completion_calls.push(completion_call);
5216 }
5217 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
5218 panic!("unexpected final response after hook termination: {response:?}");
5219 }
5220 Ok(_) => {}
5221 Err(_) => {
5222 saw_error = true;
5223 break;
5224 }
5225 }
5226 }
5227
5228 assert_eq!(
5229 completion_calls,
5230 vec![CompletionCall::new(0, Some(call_usage))]
5231 );
5232 assert!(saw_error);
5233 }
5234
5235 #[tokio::test]
5236 async fn stream_prompt_completion_calls_records_unreported_usage() {
5237 let second_call_usage = usage(25, 5);
5238 let model = MockCompletionModel::from_stream_turns([
5239 vec![
5240 MockStreamEvent::tool_call(
5241 "tool_call_1",
5242 "add",
5243 serde_json::json!({"x": 1, "y": 2}),
5244 )
5245 .with_call_id("call_1"),
5246 ],
5247 vec![
5248 MockStreamEvent::text("done"),
5249 MockStreamEvent::final_response(second_call_usage),
5250 ],
5251 ]);
5252 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5253 let empty_history: &[Message] = &[];
5254
5255 let mut stream = agent
5256 .stream_prompt("do tool work")
5257 .with_history(empty_history)
5258 .multi_turn(3)
5259 .await;
5260 let mut completion_calls_events = Vec::new();
5261 let mut final_response = None;
5262
5263 while let Some(item) = stream.next().await {
5264 match item {
5265 Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
5266 completion_calls_events.push(call_usage);
5267 }
5268 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
5269 final_response = Some(response);
5270 break;
5271 }
5272 Ok(_) => {}
5273 Err(err) => panic!("unexpected streaming error: {err:?}"),
5274 }
5275 }
5276
5277 let expected_usage = vec![
5278 CompletionCall::new(0, None),
5279 CompletionCall::new(1, Some(second_call_usage)),
5280 ];
5281 assert_eq!(completion_calls_events, expected_usage);
5282
5283 let final_response = final_response.expect("expected final response");
5284 assert_eq!(final_response.completion_calls(), expected_usage.as_slice());
5285 }
5286
5287 #[tokio::test]
5288 async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
5289 let agent = AgentBuilder::new(streaming_text_then_final_model()).build();
5290
5291 let mut stream = agent.stream_prompt("say hello").await;
5292 let mut streamed_text = String::new();
5293 let mut final_response_text = None;
5294
5295 while let Some(item) = stream.next().await {
5296 match item {
5297 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5298 text,
5299 ))) => streamed_text.push_str(&text.text),
5300 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5301 final_response_text = Some(res.response().to_owned());
5302 break;
5303 }
5304 Ok(_) => {}
5305 Err(err) => panic!("unexpected streaming error: {err:?}"),
5306 }
5307 }
5308
5309 assert_eq!(streamed_text, "hello world");
5310 assert_eq!(final_response_text.as_deref(), Some("hello world"));
5311 }
5312
5313 #[tokio::test]
5314 async fn final_response_preserves_structured_text_metadata() {
5315 let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
5316
5317 let mut stream = agent.stream_prompt("answer with citations").await;
5318 let mut final_response = None;
5319
5320 while let Some(item) = stream.next().await {
5321 match item {
5322 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5323 final_response = Some(res);
5324 break;
5325 }
5326 Ok(_) => {}
5327 Err(err) => panic!("unexpected streaming error: {err:?}"),
5328 }
5329 }
5330
5331 let final_response = final_response.expect("expected final response");
5332 assert_eq!(final_response.response(), "cited answer");
5333 let metadata = text_metadata(final_response.content())
5334 .expect("expected text metadata in final content");
5335 assert_eq!(
5336 metadata["citations"][0]["encrypted_index"],
5337 "encrypted-reference"
5338 );
5339 }
5340
5341 #[tokio::test]
5342 async fn final_response_history_preserves_structured_text_metadata() {
5343 let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
5344
5345 let empty_history: &[Message] = &[];
5346 let mut stream = agent
5347 .stream_prompt("answer with citations")
5348 .with_history(empty_history)
5349 .await;
5350 let mut final_response = None;
5351
5352 while let Some(item) = stream.next().await {
5353 match item {
5354 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5355 final_response = Some(res);
5356 break;
5357 }
5358 Ok(_) => {}
5359 Err(err) => panic!("unexpected streaming error: {err:?}"),
5360 }
5361 }
5362
5363 let final_response = final_response.expect("expected final response");
5364 let history = final_response
5365 .history()
5366 .expect("with_history should include final history");
5367 let assistant_content = history
5368 .iter()
5369 .find_map(|message| match message {
5370 Message::Assistant { content, .. } => Some(content),
5371 _ => None,
5372 })
5373 .expect("expected assistant message in history");
5374 let metadata =
5375 text_metadata(assistant_content).expect("expected text metadata in assistant history");
5376 assert_eq!(
5377 metadata["citations"][0]["encrypted_index"],
5378 "encrypted-reference"
5379 );
5380 }
5381
5382 #[tokio::test]
5383 async fn tool_follow_up_history_preserves_structured_text_metadata() {
5384 let model = streaming_cited_text_then_tool_model();
5385 let recorded = model.clone();
5386 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5387 let empty_history: &[Message] = &[];
5388
5389 let mut stream = agent
5390 .stream_prompt("use a tool with citations")
5391 .with_history(empty_history)
5392 .multi_turn(3)
5393 .await;
5394
5395 while let Some(item) = stream.next().await {
5396 match item {
5397 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
5398 Ok(_) => {}
5399 Err(err) => panic!("unexpected streaming error: {err:?}"),
5400 }
5401 }
5402
5403 let requests = recorded.requests();
5404 assert_eq!(requests.len(), 2);
5405 let follow_up_history = requests[1].chat_history.iter().collect::<Vec<_>>();
5406 let assistant_content = follow_up_history
5407 .iter()
5408 .find_map(|message| match message {
5409 Message::Assistant { content, .. } => Some(content),
5410 _ => None,
5411 })
5412 .expect("expected assistant message in follow-up history");
5413 let metadata = text_metadata(assistant_content)
5414 .expect("expected citation metadata in follow-up assistant history");
5415 assert_eq!(
5416 metadata["citations"][0]["encrypted_index"],
5417 "encrypted-reference"
5418 );
5419 }
5420
5421 #[tokio::test]
5422 async fn final_response_can_remain_empty_for_truly_textless_turns() {
5423 let agent = AgentBuilder::new(streaming_final_only_model()).build();
5424
5425 let mut stream = agent.stream_prompt("say nothing").await;
5426 let mut streamed_text = String::new();
5427 let mut final_response_text = None;
5428
5429 while let Some(item) = stream.next().await {
5430 match item {
5431 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5432 text,
5433 ))) => streamed_text.push_str(&text.text),
5434 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5435 final_response_text = Some(res.response().to_owned());
5436 break;
5437 }
5438 Ok(_) => {}
5439 Err(err) => panic!("unexpected streaming error: {err:?}"),
5440 }
5441 }
5442
5443 assert!(streamed_text.is_empty());
5444 assert_eq!(final_response_text.as_deref(), Some(""));
5445 }
5446
5447 async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
5450 let mut interval = tokio::time::interval(Duration::from_millis(50));
5451 let mut count = 0u32;
5452
5453 while !stop.load(Ordering::Relaxed) {
5454 interval.tick().await;
5455 count += 1;
5456
5457 tracing::event!(
5458 target: "background_logger",
5459 tracing::Level::INFO,
5460 count = count,
5461 "Background tick"
5462 );
5463
5464 let current = tracing::Span::current();
5466 if !current.is_disabled() && !current.is_none() {
5467 leak_count.fetch_add(1, Ordering::Relaxed);
5468 }
5469 }
5470
5471 tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
5472 }
5473
5474 #[tokio::test(flavor = "current_thread")]
5482 #[ignore = "This requires an API key"]
5483 async fn test_span_context_isolation() -> anyhow::Result<()> {
5484 let stop = Arc::new(AtomicBool::new(false));
5485 let leak_count = Arc::new(AtomicU32::new(0));
5486
5487 let bg_stop = stop.clone();
5489 let bg_leak = leak_count.clone();
5490 let bg_handle = tokio::spawn(async move {
5491 background_logger(bg_stop, bg_leak).await;
5492 });
5493
5494 tokio::time::sleep(Duration::from_millis(100)).await;
5496
5497 let client = anthropic::Client::from_env()?;
5500 let agent = client
5501 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5502 .preamble("You are a helpful assistant.")
5503 .temperature(0.1)
5504 .max_tokens(100)
5505 .build();
5506
5507 let mut stream = agent
5508 .stream_prompt("Say 'hello world' and nothing else.")
5509 .await;
5510
5511 let mut full_content = String::new();
5512 while let Some(item) = stream.next().await {
5513 match item {
5514 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5515 text,
5516 ))) => {
5517 full_content.push_str(&text.text);
5518 }
5519 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
5520 break;
5521 }
5522 Err(e) => {
5523 tracing::warn!("Error: {:?}", e);
5524 break;
5525 }
5526 _ => {}
5527 }
5528 }
5529
5530 tracing::info!("Got response: {:?}", full_content);
5531
5532 stop.store(true, Ordering::Relaxed);
5534 bg_handle.await?;
5535
5536 let leaks = leak_count.load(Ordering::Relaxed);
5537 anyhow::ensure!(
5538 leaks == 0,
5539 "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
5540 This indicates that span.enter() is being used inside async_stream instead of .instrument()"
5541 );
5542
5543 Ok(())
5544 }
5545
5546 #[tokio::test]
5552 #[ignore = "This requires an API key"]
5553 async fn test_chat_history_in_final_response() -> anyhow::Result<()> {
5554 use crate::message::Message;
5555
5556 let client = anthropic::Client::from_env()?;
5557 let agent = client
5558 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5559 .preamble("You are a helpful assistant. Keep responses brief.")
5560 .temperature(0.1)
5561 .max_tokens(50)
5562 .build();
5563
5564 let empty_history: &[Message] = &[];
5566 let mut stream = agent
5567 .stream_prompt("Say 'hello' and nothing else.")
5568 .with_history(empty_history)
5569 .await;
5570
5571 let mut response_text = String::new();
5573 let mut final_history = None;
5574 while let Some(item) = stream.next().await {
5575 match item {
5576 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5577 text,
5578 ))) => {
5579 response_text.push_str(&text.text);
5580 }
5581 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5582 final_history = res.history().map(|h| h.to_vec());
5583 break;
5584 }
5585 Err(e) => {
5586 return Err(e.into());
5587 }
5588 _ => {}
5589 }
5590 }
5591
5592 let history = final_history
5593 .ok_or_else(|| anyhow::anyhow!("final response should include history"))?;
5594
5595 anyhow::ensure!(
5597 history.iter().any(|m| matches!(m, Message::User { .. })),
5598 "History should contain the user message"
5599 );
5600
5601 anyhow::ensure!(
5603 history
5604 .iter()
5605 .any(|m| matches!(m, Message::Assistant { .. })),
5606 "History should contain the assistant response"
5607 );
5608
5609 tracing::info!(
5610 "History after streaming: {} messages, response: {:?}",
5611 history.len(),
5612 response_text
5613 );
5614
5615 Ok(())
5616 }
5617
5618 #[tokio::test]
5619 async fn streaming_appends_to_memory_after_final_response() {
5620 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5621
5622 let memory = InMemoryConversationMemory::new();
5623 let agent = AgentBuilder::new(streaming_text_then_final_model())
5624 .memory(memory.clone())
5625 .build();
5626
5627 let mut stream = agent
5628 .stream_prompt("hi there")
5629 .conversation("stream-thread")
5630 .await;
5631
5632 let mut history_in_final = None;
5633 while let Some(item) = stream.next().await {
5634 match item {
5635 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5636 history_in_final = res.history().map(|h| h.to_vec());
5637 break;
5638 }
5639 Ok(_) => {}
5640 Err(err) => panic!("unexpected streaming error: {err:?}"),
5641 }
5642 }
5643
5644 let final_history = history_in_final
5645 .expect("FinalResponse.history should be populated when memory is configured");
5646 assert_eq!(
5647 final_history.len(),
5648 2,
5649 "user prompt + assistant response in final history: {final_history:?}"
5650 );
5651
5652 let stored = memory.load("stream-thread").await.unwrap();
5653 assert_eq!(stored.len(), 2, "memory should contain user + assistant");
5654 }
5655
5656 #[tokio::test]
5657 async fn streaming_reasoning_without_tools_does_not_duplicate_final_history() {
5658 let agent = AgentBuilder::new(MockCompletionModel::from_stream_turns([[
5659 MockStreamEvent::text("final answer"),
5660 MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
5661 MockStreamEvent::final_response_with_total_tokens(3),
5662 ]]))
5663 .build();
5664
5665 let mut stream = agent
5666 .stream_prompt("think before answering")
5667 .with_history(Vec::<Message>::new())
5668 .await;
5669
5670 let mut history_in_final = None;
5671 while let Some(item) = stream.next().await {
5672 match item {
5673 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5674 history_in_final = res.history().map(|h| h.to_vec());
5675 break;
5676 }
5677 Ok(_) => {}
5678 Err(err) => panic!("unexpected streaming error: {err:?}"),
5679 }
5680 }
5681
5682 let final_history = history_in_final
5683 .expect("FinalResponse.history should be populated when with_history is used");
5684 assert_eq!(
5685 final_history.len(),
5686 2,
5687 "user prompt + one assistant response in final history: {final_history:?}"
5688 );
5689
5690 assert!(matches!(
5691 final_history.first(),
5692 Some(Message::User { content })
5693 if matches!(
5694 content.first(),
5695 UserContent::Text(text) if text.text == "think before answering"
5696 )
5697 ));
5698
5699 let assistant_messages = final_history
5700 .iter()
5701 .filter_map(|message| match message {
5702 Message::Assistant { content, .. } => Some(content),
5703 _ => None,
5704 })
5705 .collect::<Vec<_>>();
5706 assert_eq!(
5707 assistant_messages.len(),
5708 1,
5709 "reasoning turn should produce exactly one assistant history message: {final_history:?}"
5710 );
5711 let assistant_content = assistant_messages
5712 .first()
5713 .expect("expected assistant history message");
5714 assert!(assistant_content.iter().any(|item| matches!(
5715 item,
5716 AssistantContent::Text(text) if text.text == "final answer"
5717 )));
5718 assert!(assistant_content.iter().any(|item| matches!(
5719 item,
5720 AssistantContent::Reasoning(reasoning)
5721 if reasoning.id.as_deref() == Some("rs_1")
5722 && reasoning.content.iter().any(|content| matches!(
5723 content,
5724 ReasoningContent::Text { text, .. } if text == "reasoned step"
5725 ))
5726 )));
5727 }
5728
5729 #[tokio::test]
5730 async fn streaming_with_history_overrides_memory() {
5731 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5732
5733 let memory = InMemoryConversationMemory::new();
5734 memory
5735 .append("t1", vec![Message::user("from-memory")])
5736 .await
5737 .unwrap();
5738
5739 let agent = AgentBuilder::new(streaming_text_then_final_model())
5740 .memory(memory.clone())
5741 .build();
5742
5743 let mut stream = agent
5744 .stream_prompt("hi")
5745 .conversation("t1")
5746 .with_history(vec![Message::user("from-caller")])
5747 .await;
5748
5749 while let Some(item) = stream.next().await {
5750 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5751 break;
5752 }
5753 }
5754
5755 let stored = memory.load("t1").await.unwrap();
5756 assert_eq!(
5757 stored.len(),
5758 1,
5759 "with_history bypasses memory; only the pre-seeded entry remains: {stored:?}"
5760 );
5761 }
5762
5763 #[tokio::test]
5764 async fn streaming_without_memory_disables_for_request() {
5765 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5766
5767 let memory = InMemoryConversationMemory::new();
5768 let agent = AgentBuilder::new(streaming_text_then_final_model())
5769 .memory(memory.clone())
5770 .conversation_id("default")
5771 .build();
5772
5773 let mut stream = agent.stream_prompt("hi").without_memory().await;
5774
5775 while let Some(item) = stream.next().await {
5776 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5777 break;
5778 }
5779 }
5780
5781 let stored = memory.load("default").await.unwrap();
5782 assert!(stored.is_empty(), "without_memory disables save");
5783 }
5784
5785 #[tokio::test]
5786 async fn streaming_load_error_yields_memory_error() {
5787 let agent = AgentBuilder::new(streaming_text_then_final_model())
5788 .memory(FailingMemory::default())
5789 .build();
5790
5791 let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5792
5793 let first = stream.next().await.expect("at least one item");
5794 match first {
5795 Err(err) => {
5796 let msg = format!("{err:?}");
5797 assert!(
5798 msg.contains("Memory") || msg.contains("memory") || msg.contains("load boom"),
5799 "expected memory error, got: {msg}"
5800 );
5801 }
5802 Ok(other) => panic!("expected memory error, got {other:?}"),
5803 }
5804 }
5805
5806 #[tokio::test]
5807 async fn streaming_with_filter_shapes_loaded_history() {
5808 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5809
5810 let memory = InMemoryConversationMemory::new()
5811 .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
5812 memory
5813 .append(
5814 "t1",
5815 vec![
5816 Message::user("1"),
5817 Message::assistant("2"),
5818 Message::user("3"),
5819 Message::assistant("4"),
5820 ],
5821 )
5822 .await
5823 .unwrap();
5824
5825 let model = MockCompletionModel::from_stream_turns([[
5826 MockStreamEvent::text("ok"),
5827 MockStreamEvent::final_response_with_total_tokens(1),
5828 ]]);
5829 let recorded = model.clone();
5830 let agent = AgentBuilder::new(model).memory(memory).build();
5831
5832 let mut stream = agent.stream_prompt("ping").conversation("t1").await;
5833 while let Some(item) = stream.next().await {
5834 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5835 break;
5836 }
5837 }
5838
5839 let received = recorded.requests()[0]
5840 .chat_history
5841 .iter()
5842 .cloned()
5843 .collect::<Vec<_>>();
5844 assert_eq!(
5845 received.len(),
5846 3,
5847 "window-truncated history (2) + current prompt: {received:?}"
5848 );
5849 }
5850
5851 #[tokio::test]
5852 async fn streaming_append_error_does_not_suppress_final_response() {
5853 let agent = AgentBuilder::new(streaming_text_then_final_model())
5854 .memory(AppendFailingMemory::default())
5855 .build();
5856
5857 let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5858
5859 let mut saw_final = false;
5860 while let Some(item) = stream.next().await {
5861 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5862 saw_final = true;
5863 break;
5864 }
5865 }
5866 assert!(
5867 saw_final,
5868 "FinalResponse must be yielded even when memory.append fails"
5869 );
5870 }
5871}