1use crate::{
2 OneOrMany,
3 agent::completion::{DynamicContextStore, build_prepared_completion_request},
4 agent::prompt_request::{
5 HookAction, InvalidToolCallResolution, TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER,
6 hooks::PromptHook, resolve_invalid_tool_call, validate_tool_call_name,
7 },
8 completion::{Document, GetTokenUsage},
9 json_utils,
10 memory::ConversationMemory,
11 message::{
12 AssistantContent, ToolCall, ToolChoice, ToolFunction, ToolResult, ToolResultContent,
13 UserContent,
14 },
15 streaming::{StreamedAssistantContent, StreamedUserContent, ToolCallDeltaContent},
16 tool::server::ToolServerHandle,
17 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
18};
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{collections::HashMap, pin::Pin, sync::Arc};
22use tracing::info_span;
23use tracing_futures::Instrument;
24
25use super::{CompletionCall, ToolCallHookAction, reported_usage};
26use crate::{
27 agent::Agent,
28 completion::{CompletionError, CompletionModel, PromptError},
29 message::{Message, Text},
30 tool::ToolSetError,
31};
32
/// Pinned, boxed stream produced by a streaming prompt request.
///
/// Each item is either a [`MultiTurnStreamItem`] or a [`StreamingError`].
/// This variant is selected everywhere except the `wasm` feature on
/// `wasm32`, and additionally requires the stream to be `Send`.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// Pinned, boxed stream produced by a streaming prompt request.
///
/// Each item is either a [`MultiTurnStreamItem`] or a [`StreamingError`].
/// This variant is selected for the `wasm` feature on `wasm32` and carries
/// no `Send` bound.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
40
/// One item yielded by a multi-turn streaming prompt.
///
/// Serialized with a `type` tag and camelCase variant names. The enum is
/// `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// Content streamed from the assistant side of the conversation.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// Content streamed for the user side of the conversation, such as a
    /// tool result.
    StreamUserItem(StreamedUserContent),
    /// Record describing a single completion call.
    CompletionCall(CompletionCall),
    /// The final response of the request.
    FinalResponse(FinalResponse),
}
67
/// Final result of a streaming prompt request.
///
/// Serialized with camelCase field names.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// Assistant content of the final response.
    content: OneOrMany<AssistantContent>,
    /// Concatenation of the text items found in `content`
    /// (computed in [`FinalResponse::new`]).
    response: String,
    /// Token usage reported for the request as a whole.
    aggregated_usage: crate::completion::Usage,
    /// Per-call records; omitted from serialized output when empty and
    /// defaulted to empty when missing on deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    completion_calls: Vec<CompletionCall>,
    /// Optional message history; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}
83
84impl FinalResponse {
85 pub fn empty() -> Self {
86 Self::new(
87 OneOrMany::one(AssistantContent::text("")),
88 crate::completion::Usage::new(),
89 None,
90 )
91 }
92
93 pub fn new(
94 content: OneOrMany<AssistantContent>,
95 aggregated_usage: crate::completion::Usage,
96 history: Option<Vec<Message>>,
97 ) -> Self {
98 let response = assistant_text_from_choice(&content);
99 Self {
100 content,
101 response,
102 aggregated_usage,
103 completion_calls: Vec::new(),
104 history,
105 }
106 }
107
108 pub fn response(&self) -> &str {
110 &self.response
111 }
112
113 pub fn content(&self) -> &OneOrMany<AssistantContent> {
115 &self.content
116 }
117
118 pub fn assistant_content(&self) -> &OneOrMany<AssistantContent> {
120 &self.content
121 }
122
123 pub fn usage(&self) -> crate::completion::Usage {
124 self.aggregated_usage
125 }
126
127 pub fn completion_calls(&self) -> &[CompletionCall] {
134 &self.completion_calls
135 }
136
137 pub fn history(&self) -> Option<&[Message]> {
138 self.history.as_deref()
139 }
140}
141
142impl<R> MultiTurnStreamItem<R> {
143 pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
144 Self::StreamAssistantItem(item)
145 }
146
147 pub fn final_response(
148 content: OneOrMany<AssistantContent>,
149 aggregated_usage: crate::completion::Usage,
150 ) -> Self {
151 Self::FinalResponse(FinalResponse::new(content, aggregated_usage, None))
152 }
153
154 pub fn final_response_with_history(
155 content: OneOrMany<AssistantContent>,
156 aggregated_usage: crate::completion::Usage,
157 history: Option<Vec<Message>>,
158 ) -> Self {
159 Self::FinalResponse(FinalResponse::new(content, aggregated_usage, history))
160 }
161
162 pub(crate) fn final_response_with_completion_calls(
163 content: OneOrMany<AssistantContent>,
164 aggregated_usage: crate::completion::Usage,
165 completion_calls: Vec<CompletionCall>,
166 history: Option<Vec<Message>>,
167 ) -> Self {
168 let mut response = FinalResponse::new(content, aggregated_usage, history);
169 response.completion_calls = completion_calls;
170 Self::FinalResponse(response)
171 }
172}
173
174fn merge_reasoning_blocks(
175 accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
176 incoming: &crate::message::Reasoning,
177) {
178 let ids_match = |existing: &crate::message::Reasoning| {
179 matches!(
180 (&existing.id, &incoming.id),
181 (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
182 )
183 };
184
185 if let Some(existing) = accumulated_reasoning
186 .iter_mut()
187 .rev()
188 .find(|existing| ids_match(existing))
189 {
190 existing.content.extend(incoming.content.clone());
191 } else {
192 accumulated_reasoning.push(incoming.clone());
193 }
194}
195
196fn flush_pending_reasoning_delta(
197 accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
198 pending_reasoning_delta_text: &mut String,
199 pending_reasoning_delta_id: &mut Option<String>,
200) {
201 if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
202 let mut assembled = crate::message::Reasoning::new(&*pending_reasoning_delta_text);
203 if let Some(id) = pending_reasoning_delta_id.take() {
204 assembled = assembled.with_id(id);
205 }
206 accumulated_reasoning.push(assembled);
207 pending_reasoning_delta_text.clear();
208 }
209}
210
211fn build_full_history(
213 chat_history: Option<&[Message]>,
214 new_messages: Vec<Message>,
215) -> Vec<Message> {
216 let input = chat_history.unwrap_or(&[]);
217 input.iter().cloned().chain(new_messages).collect()
218}
219
/// Borrowed inputs for [`build_tool_call_validation_history`].
struct ToolCallValidationHistory<'a> {
    /// Prior chat history, placed before every other message.
    chat_history: Option<&'a [Message]>,
    /// Messages produced so far by the current request.
    new_messages: &'a [Message],
    /// Id attached to the assistant message that gets assembled.
    assistant_message_id: &'a Option<String>,
    /// Complete assistant content for the turn; when present and not empty
    /// it is used as-is instead of the partial pieces below.
    final_turn_content: Option<&'a OneOrMany<AssistantContent>>,
    /// Text streamed so far in the turn, if any.
    text_delta_response: Option<&'a str>,
    /// Reasoning blocks accumulated so far in the turn.
    accumulated_reasoning: &'a [crate::message::Reasoning],
    /// Buffered reasoning delta text; only used when no reasoning block has
    /// been accumulated.
    pending_reasoning_delta_text: &'a str,
    /// Id applied to the block built from the pending reasoning delta text.
    pending_reasoning_delta_id: &'a Option<String>,
    /// Tool calls already collected in the turn (second tuple element is
    /// ignored when building the history).
    pending_tool_calls: &'a [(ToolCall, String)],
    /// The tool call currently under validation, appended last.
    current_tool_call: Option<ToolCall>,
}
232
233fn build_tool_call_validation_history(input: ToolCallValidationHistory<'_>) -> Vec<Message> {
234 let mut messages = input.new_messages.to_vec();
235
236 if let Some(final_turn_content) = input.final_turn_content
237 && !is_empty_assistant_choice(final_turn_content)
238 {
239 messages.push(Message::Assistant {
240 id: input.assistant_message_id.clone(),
241 content: final_turn_content.clone(),
242 });
243 return build_full_history(input.chat_history, messages);
244 }
245
246 let mut content_items = Vec::new();
247 if let Some(text) = input.text_delta_response
248 && !text.is_empty()
249 {
250 content_items.push(AssistantContent::text(text.to_string()));
251 }
252 content_items.extend(
253 input
254 .accumulated_reasoning
255 .iter()
256 .cloned()
257 .map(AssistantContent::Reasoning),
258 );
259 if input.accumulated_reasoning.is_empty() && !input.pending_reasoning_delta_text.is_empty() {
260 let mut reasoning = crate::message::Reasoning::new(input.pending_reasoning_delta_text);
261 if let Some(id) = input.pending_reasoning_delta_id.clone() {
262 reasoning = reasoning.with_id(id);
263 }
264 content_items.push(AssistantContent::Reasoning(reasoning));
265 }
266 content_items.extend(
267 input
268 .pending_tool_calls
269 .iter()
270 .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
271 );
272 if let Some(tool_call) = input.current_tool_call {
273 content_items.push(AssistantContent::ToolCall(tool_call));
274 }
275
276 if let Some(content) = OneOrMany::from_iter_optional(content_items) {
277 messages.push(Message::Assistant {
278 id: input.assistant_message_id.clone(),
279 content,
280 });
281 }
282
283 build_full_history(input.chat_history, messages)
284}
285
286async fn drain_recovered_stream_usage<R>(
287 stream: &mut crate::streaming::StreamingCompletionResponse<R>,
288 tool_call_delta_states: &HashMap<(String, String), ToolCallDeltaState>,
289 current_call_usage: &mut Option<crate::completion::Usage>,
290 aggregated_usage: &mut crate::completion::Usage,
291) -> Result<(), StreamingError>
292where
293 R: Clone + Unpin + GetTokenUsage,
294{
295 if let Some(err) = pending_tool_call_delta_error(tool_call_delta_states) {
296 return Err(err.into());
297 }
298
299 while let Some(content) = stream.next().await {
300 match content {
301 Ok(StreamedAssistantContent::Final(final_resp)) => {
302 if let Some(usage) = final_resp.token_usage() {
303 *current_call_usage = reported_usage(usage);
304 }
305 if let Some(usage) = *current_call_usage {
306 *aggregated_usage += usage;
307 }
308 return Ok(());
309 }
310 Ok(_) => {}
311 Err(err) => return Err(err.into()),
312 }
313 }
314
315 Ok(())
316}
317
318fn record_completion_call_if_needed(
319 completion_calls: &mut Vec<CompletionCall>,
320 completion_call_emitted: &mut bool,
321 call_index: usize,
322 current_call_usage: Option<crate::completion::Usage>,
323) -> Option<CompletionCall> {
324 if *completion_call_emitted {
325 return None;
326 }
327
328 let completion_call = CompletionCall::new(call_index, current_call_usage);
329 completion_calls.push(completion_call);
330 *completion_call_emitted = true;
331 Some(completion_call)
332}
333
334fn build_history_for_request(
336 chat_history: Option<&[Message]>,
337 new_messages: &[Message],
338) -> Vec<Message> {
339 let input = chat_history.unwrap_or(&[]);
340 input.iter().chain(new_messages.iter()).cloned().collect()
341}
342
343async fn cancelled_prompt_error(
344 chat_history: Option<&[Message]>,
345 new_messages: Vec<Message>,
346 reason: String,
347) -> StreamingError {
348 StreamingError::Prompt(
349 PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
350 .into(),
351 )
352}
353
354fn tool_result_to_user_message(
355 id: String,
356 call_id: Option<String>,
357 tool_result: String,
358) -> Message {
359 let content = ToolResultContent::from_tool_output(tool_result);
360 let user_content = match call_id {
361 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
362 None => UserContent::tool_result(id, content),
363 };
364
365 Message::User {
366 content: OneOrMany::one(user_content),
367 }
368}
369
370fn tool_result_user_content(
371 id: String,
372 call_id: Option<String>,
373 tool_result: String,
374) -> UserContent {
375 let content = ToolResultContent::from_tool_output(tool_result);
376 match call_id {
377 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
378 None => UserContent::tool_result(id, content),
379 }
380}
381
382fn invalid_streaming_tool_retry_messages(
383 assistant_message_id: &Option<String>,
384 text_delta_response: Option<&str>,
385 accumulated_reasoning: &[crate::message::Reasoning],
386 pending_tool_calls: &[(ToolCall, String)],
387 invalid_tool_call: ToolCall,
388 feedback: String,
389) -> Option<(Message, Message)> {
390 let mut assistant_content = Vec::new();
391 if let Some(text) = text_delta_response
392 && !text.is_empty()
393 {
394 assistant_content.push(AssistantContent::text(text.to_string()));
395 }
396 assistant_content.extend(
397 accumulated_reasoning
398 .iter()
399 .cloned()
400 .map(AssistantContent::Reasoning),
401 );
402 assistant_content.extend(
403 pending_tool_calls
404 .iter()
405 .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
406 );
407 assistant_content.push(AssistantContent::ToolCall(invalid_tool_call.clone()));
408
409 let assistant_content = OneOrMany::from_iter_optional(assistant_content)?;
410 let assistant_message = Message::Assistant {
411 id: assistant_message_id.clone(),
412 content: assistant_content,
413 };
414
415 let mut retry_results = pending_tool_calls
416 .iter()
417 .map(|(tool_call, _)| {
418 tool_result_user_content(
419 tool_call.id.clone(),
420 tool_call.call_id.clone(),
421 TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
422 )
423 })
424 .collect::<Vec<_>>();
425 retry_results.push(tool_result_user_content(
426 invalid_tool_call.id,
427 invalid_tool_call.call_id,
428 feedback,
429 ));
430
431 let user_message = Message::User {
432 content: OneOrMany::from_iter_optional(retry_results)?,
433 };
434
435 Some((assistant_message, user_message))
436}
437
438fn invalid_streaming_name_delta_retry_messages(
439 assistant_message_id: &Option<String>,
440 text_delta_response: Option<&str>,
441 accumulated_reasoning: &[crate::message::Reasoning],
442 pending_tool_calls: &[(ToolCall, String)],
443 invalid_tool_call: ToolCall,
444 feedback: String,
445) -> Option<(Message, Message)> {
446 let mut assistant_content = Vec::new();
447 if let Some(text) = text_delta_response
448 && !text.is_empty()
449 {
450 assistant_content.push(AssistantContent::text(text.to_string()));
451 }
452 assistant_content.extend(
453 accumulated_reasoning
454 .iter()
455 .cloned()
456 .map(AssistantContent::Reasoning),
457 );
458 assistant_content.extend(
459 pending_tool_calls
460 .iter()
461 .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
462 );
463 assistant_content.push(AssistantContent::ToolCall(invalid_tool_call.clone()));
464
465 let assistant_message = Message::Assistant {
466 id: assistant_message_id.clone(),
467 content: OneOrMany::from_iter_optional(assistant_content)?,
468 };
469 let mut retry_results = pending_tool_calls
470 .iter()
471 .map(|(tool_call, _)| {
472 tool_result_user_content(
473 tool_call.id.clone(),
474 tool_call.call_id.clone(),
475 TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
476 )
477 })
478 .collect::<Vec<_>>();
479 retry_results.push(tool_result_user_content(
480 invalid_tool_call.id,
481 invalid_tool_call.call_id,
482 feedback,
483 ));
484 let user_message = Message::User {
485 content: OneOrMany::from_iter_optional(retry_results)?,
486 };
487
488 Some((assistant_message, user_message))
489}
490
491fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
492 choice
493 .iter()
494 .filter_map(|content| match content {
495 AssistantContent::Text(text) => Some(text.text.as_str()),
496 _ => None,
497 })
498 .collect()
499}
500
501fn assistant_text_items_from_choice(choice: &OneOrMany<AssistantContent>) -> Vec<AssistantContent> {
502 choice
503 .iter()
504 .filter_map(|content| match content {
505 AssistantContent::Text(text) => (!text.text.is_empty()
506 || text.additional_params.is_some())
507 .then(|| AssistantContent::Text(text.clone())),
508 _ => None,
509 })
510 .collect()
511}
512
513fn is_empty_assistant_choice(choice: &OneOrMany<AssistantContent>) -> bool {
514 choice.len() == 1
515 && matches!(
516 choice.first(),
517 AssistantContent::Text(text)
518 if text.text.is_empty() && text.additional_params.is_none()
519 )
520}
521
/// Per-tool-call state tracked while tool call deltas are streamed.
#[derive(Default)]
struct ToolCallDeltaState {
    /// Whether the tool name for this call has been validated.
    name_validated: bool,
    /// Argument deltas held back for this call.
    buffered_arguments: Vec<String>,
}
527
528fn pending_tool_call_delta_error(
529 states: &HashMap<(String, String), ToolCallDeltaState>,
530) -> Option<CompletionError> {
531 states
532 .iter()
533 .find(|(_, state)| !state.name_validated && !state.buffered_arguments.is_empty())
534 .map(|((id, internal_call_id), state)| {
535 CompletionError::ResponseError(format!(
536 "streamed tool call arguments received before a validated tool name for id `{id}` and internal_call_id `{internal_call_id}` ({} buffered argument delta(s))",
537 state.buffered_arguments.len()
538 ))
539 })
540}
541
/// Errors yielded by a streaming prompt request.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// A completion error.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// A prompt error; boxed inside the variant.
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool set error.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
551
552impl From<crate::memory::MemoryError> for StreamingError {
556 fn from(err: crate::memory::MemoryError) -> Self {
557 Self::Completion(CompletionError::RequestError(Box::new(err)))
558 }
559}
560
/// Name used for agents that have no configured name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
562
/// Builder for a streaming prompt request against an agent.
///
/// `M` is the completion model and `P` the prompt hook type. Most fields
/// are copied from the agent when the request is created.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send.
    prompt: Message,
    /// Explicit chat history; when set, conversation memory is not loaded.
    chat_history: Option<Vec<Message>>,
    /// Turn limit for the request (set via `multi_turn`).
    max_turns: usize,

    /// The completion model.
    model: Arc<M>,
    /// Agent name; `UNKNOWN_AGENT_NAME` is used when absent.
    agent_name: Option<String>,
    /// System preamble copied from the agent.
    preamble: Option<String>,
    /// Static context documents copied from the agent.
    static_context: Vec<Document>,
    /// Sampling temperature copied from the agent.
    temperature: Option<f64>,
    /// Maximum token setting copied from the agent.
    max_tokens: Option<u64>,
    /// Additional provider parameters copied from the agent.
    additional_params: Option<serde_json::Value>,
    /// Handle to the agent's tool server.
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store copied from the agent.
    dynamic_context: DynamicContextStore,
    /// Tool choice setting copied from the agent.
    tool_choice: Option<ToolChoice>,
    /// Output schema copied from the agent.
    output_schema: Option<schemars::Schema>,
    /// Optional prompt hook.
    hook: Option<P>,
    /// How many invalid tool calls may be retried (0 by default).
    max_invalid_tool_call_retries: usize,
    /// Conversation memory; consulted only when no explicit history is set
    /// and a conversation id is present.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Conversation id used to load history from `memory`.
    conversation_id: Option<String>,
}
615
616impl<M, P> StreamingPromptRequest<M, P>
617where
618 M: CompletionModel + 'static,
619 <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
620 P: PromptHook<M>,
621{
622 pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
625 StreamingPromptRequest {
626 prompt: prompt.into(),
627 chat_history: None,
628 max_turns: agent.default_max_turns.unwrap_or_default(),
629 model: agent.model.clone(),
630 agent_name: agent.name.clone(),
631 preamble: agent.preamble.clone(),
632 static_context: agent.static_context.clone(),
633 temperature: agent.temperature,
634 max_tokens: agent.max_tokens,
635 additional_params: agent.additional_params.clone(),
636 tool_server_handle: agent.tool_server_handle.clone(),
637 dynamic_context: agent.dynamic_context.clone(),
638 tool_choice: agent.tool_choice.clone(),
639 output_schema: agent.output_schema.clone(),
640 hook: None,
641 max_invalid_tool_call_retries: 0,
642 memory: agent.memory.clone(),
643 conversation_id: agent.default_conversation_id.clone(),
644 }
645 }
646
647 pub fn from_agent<P2>(
649 agent: &Agent<M, P2>,
650 prompt: impl Into<Message>,
651 ) -> StreamingPromptRequest<M, P2>
652 where
653 P2: PromptHook<M>,
654 {
655 StreamingPromptRequest {
656 prompt: prompt.into(),
657 chat_history: None,
658 max_turns: agent.default_max_turns.unwrap_or_default(),
659 model: agent.model.clone(),
660 agent_name: agent.name.clone(),
661 preamble: agent.preamble.clone(),
662 static_context: agent.static_context.clone(),
663 temperature: agent.temperature,
664 max_tokens: agent.max_tokens,
665 additional_params: agent.additional_params.clone(),
666 tool_server_handle: agent.tool_server_handle.clone(),
667 dynamic_context: agent.dynamic_context.clone(),
668 tool_choice: agent.tool_choice.clone(),
669 output_schema: agent.output_schema.clone(),
670 hook: agent.hook.clone(),
671 max_invalid_tool_call_retries: 0,
672 memory: agent.memory.clone(),
673 conversation_id: agent.default_conversation_id.clone(),
674 }
675 }
676
677 fn agent_name(&self) -> &str {
678 self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
679 }
680
681 pub fn multi_turn(mut self, turns: usize) -> Self {
684 self.max_turns = turns;
685 self
686 }
687
688 pub fn with_history<H, T>(mut self, history: H) -> Self
701 where
702 H: IntoIterator<Item = T>,
703 T: Into<Message>,
704 {
705 self.chat_history = Some(history.into_iter().map(Into::into).collect());
706 self
707 }
708
709 pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
712 where
713 P2: PromptHook<M>,
714 {
715 StreamingPromptRequest {
716 prompt: self.prompt,
717 chat_history: self.chat_history,
718 max_turns: self.max_turns,
719 model: self.model,
720 agent_name: self.agent_name,
721 preamble: self.preamble,
722 static_context: self.static_context,
723 temperature: self.temperature,
724 max_tokens: self.max_tokens,
725 additional_params: self.additional_params,
726 tool_server_handle: self.tool_server_handle,
727 dynamic_context: self.dynamic_context,
728 tool_choice: self.tool_choice,
729 output_schema: self.output_schema,
730 hook: Some(hook),
731 max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
732 memory: self.memory,
733 conversation_id: self.conversation_id,
734 }
735 }
736
737 pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
741 self.max_invalid_tool_call_retries = retries;
742 self
743 }
744
745 pub fn conversation(mut self, id: impl Into<String>) -> Self {
750 self.conversation_id = Some(id.into());
751 self
752 }
753
754 pub fn without_memory(mut self) -> Self {
758 self.memory = None;
759 self.conversation_id = None;
760 self
761 }
762
763 async fn send(self) -> StreamingResult<M::StreamingResponse> {
764 let agent_span = if tracing::Span::current().is_disabled() {
765 info_span!(
766 "invoke_agent",
767 gen_ai.operation.name = "invoke_agent",
768 gen_ai.agent.name = self.agent_name(),
769 gen_ai.system_instructions = self.preamble,
770 gen_ai.prompt = tracing::field::Empty,
771 gen_ai.completion = tracing::field::Empty,
772 gen_ai.usage.input_tokens = tracing::field::Empty,
773 gen_ai.usage.output_tokens = tracing::field::Empty,
774 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
775 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
776 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
777 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
778 )
779 } else {
780 tracing::Span::current()
781 };
782
783 let prompt = self.prompt;
784 if let Some(text) = prompt.rag_text() {
785 agent_span.record("gen_ai.prompt", text);
786 }
787
788 let model = self.model.clone();
790 let preamble = self.preamble.clone();
791 let static_context = self.static_context.clone();
792 let temperature = self.temperature;
793 let max_tokens = self.max_tokens;
794 let additional_params = self.additional_params.clone();
795 let tool_server_handle = self.tool_server_handle.clone();
796 let dynamic_context = self.dynamic_context.clone();
797 let tool_choice = self.tool_choice.clone();
798 let agent_name = self.agent_name.clone();
799 let (chat_history, memory_handle) = match self.chat_history {
804 Some(history) => (Some(history), None),
805 None => match (self.memory, self.conversation_id) {
806 (Some(memory), Some(id)) => match memory.load(&id).await {
807 Ok(loaded) => (Some(loaded), Some((memory, id))),
808 Err(err) => {
809 let stream = async_stream::stream! {
810 yield Err(StreamingError::from(err));
811 };
812 return Box::pin(stream);
813 }
814 },
815 _ => (None, None),
816 },
817 };
818 let has_history = chat_history.is_some();
819 let mut new_messages: Vec<Message> = vec![prompt.clone()];
820
821 let mut current_max_turns = 0;
822 let mut last_prompt_error = String::new();
823
824 let mut text_delta_response = String::new();
825 let mut saw_text_this_turn = false;
826 let mut max_turns_reached = false;
827 let output_schema = self.output_schema;
828
829 let mut aggregated_usage = crate::completion::Usage::new();
830 let mut completion_calls = Vec::new();
831 let mut completion_call_index = 0;
832 let mut invalid_tool_call_retries = 0;
833
834 let stream = async_stream::stream! {
841 'outer: loop {
842 let Some((current_prompt_ref, previous_messages)) = new_messages.split_last() else {
843 yield Err(cancelled_prompt_error(
844 chat_history.as_deref(),
845 new_messages.clone(),
846 "streaming loop lost its pending prompt".to_string(),
847 ).await);
848 break 'outer;
849 };
850 let current_prompt = current_prompt_ref.clone();
851
852 if current_max_turns > self.max_turns + 1 {
853 last_prompt_error = current_prompt.rag_text().unwrap_or_default();
854 max_turns_reached = true;
855 break;
856 }
857
858 current_max_turns += 1;
859
860 if self.max_turns > 1 {
861 tracing::info!(
862 "Current conversation Turns: {}/{}",
863 current_max_turns,
864 self.max_turns
865 );
866 }
867
868 let history_snapshot: Vec<Message> = build_history_for_request(
869 chat_history.as_deref(),
870 previous_messages,
871 );
872
873 if let Some(ref hook) = self.hook
874 && let HookAction::Terminate { reason } =
875 hook.on_completion_call(¤t_prompt, &history_snapshot).await
876 {
877 yield Err(
878 cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason)
879 .await,
880 );
881 break 'outer;
882 }
883
884 let chat_stream_span = info_span!(
885 target: "rig::agent_chat",
886 parent: tracing::Span::current(),
887 "chat_streaming",
888 gen_ai.operation.name = "chat",
889 gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
890 gen_ai.system_instructions = preamble,
891 gen_ai.provider.name = tracing::field::Empty,
892 gen_ai.request.model = tracing::field::Empty,
893 gen_ai.response.id = tracing::field::Empty,
894 gen_ai.response.model = tracing::field::Empty,
895 gen_ai.usage.output_tokens = tracing::field::Empty,
896 gen_ai.usage.input_tokens = tracing::field::Empty,
897 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
898 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
899 gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
900 gen_ai.usage.reasoning_tokens = tracing::field::Empty,
901 gen_ai.input.messages = tracing::field::Empty,
902 gen_ai.output.messages = tracing::field::Empty,
903 );
904
905 let prepared_request = build_prepared_completion_request(
906 &model,
907 current_prompt.clone(),
908 &history_snapshot,
909 preamble.as_deref(),
910 &static_context,
911 temperature,
912 max_tokens,
913 additional_params.as_ref(),
914 tool_choice.as_ref(),
915 &tool_server_handle,
916 &dynamic_context,
917 output_schema.as_ref(),
918 )
919 .await?;
920 let executable_tool_names = prepared_request.executable_tool_names.clone();
921 let allowed_tool_names = prepared_request.allowed_tool_names.clone();
922
923 let mut stream = prepared_request
924 .builder
925 .stream()
926 .instrument(chat_stream_span)
927 .await?;
928
929 let call_index = completion_call_index;
930 completion_call_index += 1;
931 let mut current_call_usage = None;
932 let mut completion_call_emitted = false;
933 let mut pending_tool_calls: Vec<(ToolCall, String)> = vec![];
934 let mut tool_calls = vec![];
935 let mut tool_results = vec![];
936 let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
937 let mut pending_reasoning_delta_text = String::new();
940 let mut pending_reasoning_delta_id: Option<String> = None;
941 let mut tool_call_delta_states: HashMap<(String, String), ToolCallDeltaState> =
942 HashMap::new();
943 let mut saw_tool_call_this_turn = false;
944
945 while let Some(content) = stream.next().await {
946 match content {
947 Ok(StreamedAssistantContent::Text(text)) => {
948 if !saw_text_this_turn {
949 text_delta_response.clear();
950 saw_text_this_turn = true;
951 }
952 text_delta_response.push_str(&text.text);
953 if let Some(ref hook) = self.hook &&
954 let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &text_delta_response).await {
955 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
956 break 'outer;
957 }
958
959 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
960 },
961 Ok(StreamedAssistantContent::ToolCall { mut tool_call, internal_call_id }) => {
962 let diagnostic_history =
963 build_tool_call_validation_history(ToolCallValidationHistory {
964 chat_history: chat_history.as_deref(),
965 new_messages: &new_messages,
966 assistant_message_id: &stream.message_id,
967 final_turn_content: None,
968 text_delta_response: saw_text_this_turn
969 .then_some(text_delta_response.as_str()),
970 accumulated_reasoning: &accumulated_reasoning,
971 pending_reasoning_delta_text: &pending_reasoning_delta_text,
972 pending_reasoning_delta_id: &pending_reasoning_delta_id,
973 pending_tool_calls: &pending_tool_calls,
974 current_tool_call: Some(tool_call.clone()),
975 });
976
977 if !allowed_tool_names.contains(&tool_call.function.name) {
978 let args = json_utils::value_to_json_string(&tool_call.function.arguments);
979 let emitted_tool_name = tool_call.function.name.clone();
980 match resolve_invalid_tool_call::<M, P>(
981 self.hook.as_ref(),
982 &emitted_tool_name,
983 Some(tool_call.id.clone()),
984 Some(internal_call_id.clone()),
985 Some(args),
986 &executable_tool_names,
987 &allowed_tool_names,
988 self.tool_choice.as_ref(),
989 diagnostic_history.clone(),
990 true,
991 ).await {
992 InvalidToolCallResolution::Fail(err) => {
993 yield Err(Box::new(err).into());
994 break 'outer;
995 }
996 InvalidToolCallResolution::Retry(feedback) => {
997 if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
998 yield Err(Box::new(PromptError::UnknownToolCall {
999 tool_name: emitted_tool_name,
1000 available_tools: executable_tool_names.iter().cloned().collect(),
1001 allowed_tools: allowed_tool_names.iter().cloned().collect(),
1002 chat_history: Box::new(diagnostic_history),
1003 }).into());
1004 break 'outer;
1005 }
1006
1007 invalid_tool_call_retries += 1;
1008 flush_pending_reasoning_delta(
1009 &mut accumulated_reasoning,
1010 &mut pending_reasoning_delta_text,
1011 &mut pending_reasoning_delta_id,
1012 );
1013 let Some((assistant_message, user_message)) =
1014 invalid_streaming_tool_retry_messages(
1015 &stream.message_id,
1016 saw_text_this_turn.then_some(text_delta_response.as_str()),
1017 &accumulated_reasoning,
1018 &pending_tool_calls,
1019 tool_call,
1020 feedback,
1021 )
1022 else {
1023 yield Err(cancelled_prompt_error(
1024 chat_history.as_deref(),
1025 new_messages.clone(),
1026 "invalid tool call retry produced no retry messages".to_string(),
1027 ).await);
1028 break 'outer;
1029 };
1030 new_messages.push(assistant_message);
1031 new_messages.push(user_message);
1032 if let Err(err) = drain_recovered_stream_usage(
1033 &mut stream,
1034 &tool_call_delta_states,
1035 &mut current_call_usage,
1036 &mut aggregated_usage,
1037 )
1038 .await
1039 {
1040 yield Err(err);
1041 break 'outer;
1042 }
1043 if let Some(completion_call) = record_completion_call_if_needed(
1044 &mut completion_calls,
1045 &mut completion_call_emitted,
1046 call_index,
1047 current_call_usage,
1048 ) {
1049 yield Ok(MultiTurnStreamItem::CompletionCall(
1050 completion_call,
1051 ));
1052 }
1053 text_delta_response.clear();
1054 saw_text_this_turn = false;
1055 continue 'outer;
1056 }
1057 InvalidToolCallResolution::Repair(repaired_name) => {
1058 tool_call.function.name = repaired_name;
1059 }
1060 InvalidToolCallResolution::Skip(reason) => {
1061 let skipped_tool_result = ToolResult {
1062 id: tool_call.id.clone(),
1063 call_id: tool_call.call_id.clone(),
1064 content: ToolResultContent::from_tool_output(
1065 reason.clone(),
1066 ),
1067 };
1068 flush_pending_reasoning_delta(
1069 &mut accumulated_reasoning,
1070 &mut pending_reasoning_delta_text,
1071 &mut pending_reasoning_delta_id,
1072 );
1073 let Some((assistant_message, user_message)) =
1074 invalid_streaming_tool_retry_messages(
1075 &stream.message_id,
1076 saw_text_this_turn
1077 .then_some(text_delta_response.as_str()),
1078 &accumulated_reasoning,
1079 &pending_tool_calls,
1080 tool_call,
1081 reason,
1082 )
1083 else {
1084 yield Err(cancelled_prompt_error(
1085 chat_history.as_deref(),
1086 new_messages.clone(),
1087 "invalid tool call skip produced no recovery messages".to_string(),
1088 ).await);
1089 break 'outer;
1090 };
1091 new_messages.push(assistant_message);
1092 new_messages.push(user_message);
1093 let tool_result = ToolResult {
1094 id: skipped_tool_result.id,
1095 call_id: skipped_tool_result.call_id,
1096 content: skipped_tool_result.content,
1097 };
1098 if let Err(err) = drain_recovered_stream_usage(
1099 &mut stream,
1100 &tool_call_delta_states,
1101 &mut current_call_usage,
1102 &mut aggregated_usage,
1103 )
1104 .await
1105 {
1106 yield Err(err);
1107 break 'outer;
1108 }
1109 if let Some(completion_call) = record_completion_call_if_needed(
1110 &mut completion_calls,
1111 &mut completion_call_emitted,
1112 call_index,
1113 current_call_usage,
1114 ) {
1115 yield Ok(MultiTurnStreamItem::CompletionCall(
1116 completion_call,
1117 ));
1118 }
1119 yield Ok(MultiTurnStreamItem::StreamUserItem(
1120 StreamedUserContent::ToolResult {
1121 tool_result,
1122 internal_call_id,
1123 },
1124 ));
1125 text_delta_response.clear();
1126 saw_text_this_turn = false;
1127 continue 'outer;
1128 }
1129 }
1130 }
1131
1132 pending_tool_calls.push((tool_call, internal_call_id));
1133 },
1134 Ok(StreamedAssistantContent::ToolCallDelta {
1135 id,
1136 internal_call_id,
1137 content,
1138 }) => {
1139 let key = (id.clone(), internal_call_id.clone());
1140 let mut deltas_to_emit = Vec::new();
1141
1142 match content {
1143 ToolCallDeltaContent::Name(mut name) => {
1144 let buffered_args = tool_call_delta_states
1145 .get(&key)
1146 .map(|state| state.buffered_arguments.join(""))
1147 .unwrap_or_default();
1148 let diagnostic_args = if buffered_args.trim().is_empty() {
1149 serde_json::Value::Null
1150 } else {
1151 serde_json::from_str(&buffered_args)
1152 .unwrap_or(serde_json::Value::Null)
1153 };
1154 let diagnostic_tool_call = ToolCall::new(
1155 id.clone(),
1156 ToolFunction::new(name.clone(), diagnostic_args),
1157 );
1158 let diagnostic_history =
1159 build_tool_call_validation_history(ToolCallValidationHistory {
1160 chat_history: chat_history.as_deref(),
1161 new_messages: &new_messages,
1162 assistant_message_id: &stream.message_id,
1163 final_turn_content: None,
1164 text_delta_response: saw_text_this_turn
1165 .then_some(text_delta_response.as_str()),
1166 accumulated_reasoning: &accumulated_reasoning,
1167 pending_reasoning_delta_text: &pending_reasoning_delta_text,
1168 pending_reasoning_delta_id: &pending_reasoning_delta_id,
1169 pending_tool_calls: &pending_tool_calls,
1170 current_tool_call: Some(diagnostic_tool_call.clone()),
1171 });
1172
1173 if !allowed_tool_names.contains(&name) {
1174 let emitted_tool_name = name.clone();
1175 match resolve_invalid_tool_call::<M, P>(
1176 self.hook.as_ref(),
1177 &emitted_tool_name,
1178 Some(id.clone()),
1179 Some(internal_call_id.clone()),
1180 Some(buffered_args.clone()),
1181 &executable_tool_names,
1182 &allowed_tool_names,
1183 self.tool_choice.as_ref(),
1184 diagnostic_history.clone(),
1185 true,
1186 ).await {
1187 InvalidToolCallResolution::Fail(err) => {
1188 yield Err(Box::new(err).into());
1189 break 'outer;
1190 }
1191 InvalidToolCallResolution::Skip(reason) => {
1192 tool_call_delta_states.remove(&key);
1193 flush_pending_reasoning_delta(
1194 &mut accumulated_reasoning,
1195 &mut pending_reasoning_delta_text,
1196 &mut pending_reasoning_delta_id,
1197 );
1198 let Some((assistant_message, user_message)) =
1199 invalid_streaming_name_delta_retry_messages(
1200 &stream.message_id,
1201 saw_text_this_turn
1202 .then_some(text_delta_response.as_str()),
1203 &accumulated_reasoning,
1204 &pending_tool_calls,
1205 diagnostic_tool_call.clone(),
1206 reason.clone(),
1207 )
1208 else {
1209 yield Err(cancelled_prompt_error(
1210 chat_history.as_deref(),
1211 new_messages.clone(),
1212 "invalid tool call skip produced no recovery messages".to_string(),
1213 ).await);
1214 break 'outer;
1215 };
1216 new_messages.push(assistant_message);
1217 new_messages.push(user_message);
1218 let tool_result = ToolResult {
1219 id,
1220 call_id: None,
1221 content: ToolResultContent::from_tool_output(
1222 reason,
1223 ),
1224 };
1225 if let Err(err) = drain_recovered_stream_usage(
1226 &mut stream,
1227 &tool_call_delta_states,
1228 &mut current_call_usage,
1229 &mut aggregated_usage,
1230 )
1231 .await
1232 {
1233 yield Err(err);
1234 break 'outer;
1235 }
1236 if let Some(completion_call) = record_completion_call_if_needed(
1237 &mut completion_calls,
1238 &mut completion_call_emitted,
1239 call_index,
1240 current_call_usage,
1241 ) {
1242 yield Ok(MultiTurnStreamItem::CompletionCall(
1243 completion_call,
1244 ));
1245 }
1246 yield Ok(MultiTurnStreamItem::StreamUserItem(
1247 StreamedUserContent::ToolResult {
1248 tool_result,
1249 internal_call_id,
1250 },
1251 ));
1252 text_delta_response.clear();
1253 saw_text_this_turn = false;
1254 continue 'outer;
1255 }
1256 InvalidToolCallResolution::Retry(feedback) => {
1257 tool_call_delta_states.remove(&key);
1258 if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
1259 yield Err(Box::new(PromptError::UnknownToolCall {
1260 tool_name: emitted_tool_name,
1261 available_tools: executable_tool_names.iter().cloned().collect(),
1262 allowed_tools: allowed_tool_names.iter().cloned().collect(),
1263 chat_history: Box::new(diagnostic_history),
1264 }).into());
1265 break 'outer;
1266 }
1267
1268 invalid_tool_call_retries += 1;
1269 flush_pending_reasoning_delta(
1270 &mut accumulated_reasoning,
1271 &mut pending_reasoning_delta_text,
1272 &mut pending_reasoning_delta_id,
1273 );
1274 let Some((assistant_message, user_message)) =
1275 invalid_streaming_name_delta_retry_messages(
1276 &stream.message_id,
1277 saw_text_this_turn
1278 .then_some(text_delta_response.as_str()),
1279 &accumulated_reasoning,
1280 &pending_tool_calls,
1281 diagnostic_tool_call.clone(),
1282 feedback,
1283 )
1284 else {
1285 yield Err(cancelled_prompt_error(
1286 chat_history.as_deref(),
1287 new_messages.clone(),
1288 "invalid tool call retry produced no retry messages".to_string(),
1289 ).await);
1290 break 'outer;
1291 };
1292 new_messages.push(assistant_message);
1293 new_messages.push(user_message);
1294 if let Err(err) = drain_recovered_stream_usage(
1295 &mut stream,
1296 &tool_call_delta_states,
1297 &mut current_call_usage,
1298 &mut aggregated_usage,
1299 )
1300 .await
1301 {
1302 yield Err(err);
1303 break 'outer;
1304 }
1305 if let Some(completion_call) = record_completion_call_if_needed(
1306 &mut completion_calls,
1307 &mut completion_call_emitted,
1308 call_index,
1309 current_call_usage,
1310 ) {
1311 yield Ok(MultiTurnStreamItem::CompletionCall(
1312 completion_call,
1313 ));
1314 }
1315 text_delta_response.clear();
1316 saw_text_this_turn = false;
1317 continue 'outer;
1318 }
1319 InvalidToolCallResolution::Repair(repaired_name) => {
1320 name = repaired_name;
1321 }
1322 }
1323 }
1324
1325 let state =
1326 tool_call_delta_states.entry(key.clone()).or_default();
1327 state.name_validated = true;
1328 let buffered_arguments =
1329 std::mem::take(&mut state.buffered_arguments);
1330
1331 deltas_to_emit.push(ToolCallDeltaContent::Name(name));
1332 deltas_to_emit.extend(
1333 buffered_arguments
1334 .into_iter()
1335 .map(ToolCallDeltaContent::Delta),
1336 );
1337 }
1338 ToolCallDeltaContent::Delta(arguments) => {
1339 let state =
1340 tool_call_delta_states.entry(key.clone()).or_default();
1341 if state.name_validated {
1342 deltas_to_emit.push(ToolCallDeltaContent::Delta(arguments));
1343 } else {
1344 state.buffered_arguments.push(arguments);
1345 }
1346 }
1347 }
1348
1349 for content in deltas_to_emit {
1350 if let Some(ref hook) = self.hook {
1351 let (name, delta) = match &content {
1352 ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
1353 ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
1354 };
1355
1356 if let HookAction::Terminate { reason } = hook
1357 .on_tool_call_delta(
1358 &id,
1359 &internal_call_id,
1360 name,
1361 delta,
1362 )
1363 .await
1364 {
1365 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1366 break 'outer;
1367 }
1368 }
1369
1370 yield Ok(MultiTurnStreamItem::StreamAssistantItem(
1371 StreamedAssistantContent::ToolCallDelta {
1372 id: id.clone(),
1373 internal_call_id: internal_call_id.clone(),
1374 content,
1375 },
1376 ));
1377 }
1378 }
1379 Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
1380 merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
1384 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
1385 },
1386 Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
1387 pending_reasoning_delta_text.push_str(&reasoning);
1391 if pending_reasoning_delta_id.is_none() {
1392 pending_reasoning_delta_id = id.clone();
1393 }
1394 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
1395 },
1396 Ok(StreamedAssistantContent::Final(final_resp)) => {
1397 if let Some(err) =
1398 pending_tool_call_delta_error(&tool_call_delta_states)
1399 {
1400 yield Err(err.into());
1401 break 'outer;
1402 }
1403
1404 if let Some(usage) = final_resp.token_usage() {
1405 current_call_usage = reported_usage(usage);
1406 }
1407 if let Some(usage) = current_call_usage {
1408 aggregated_usage += usage;
1409 }
1410 let completion_call = CompletionCall::new(call_index, current_call_usage);
1411 completion_calls.push(completion_call);
1412 completion_call_emitted = true;
1413 yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
1414
1415 if saw_text_this_turn {
1416 if let Some(ref hook) = self.hook &&
1417 let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(¤t_prompt, &final_resp).await {
1418 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1419 break 'outer;
1420 }
1421
1422 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
1423 saw_text_this_turn = false;
1424 }
1425 }
1426 Err(e) => {
1427 yield Err(e.into());
1428 break 'outer;
1429 }
1430 }
1431 }
1432
1433 if let Some(err) = pending_tool_call_delta_error(&tool_call_delta_states) {
1434 yield Err(err.into());
1435 break 'outer;
1436 }
1437
1438 if !completion_call_emitted {
1439 let completion_call = CompletionCall::new(call_index, current_call_usage);
1440 completion_calls.push(completion_call);
1441 yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
1442 }
1443
1444 flush_pending_reasoning_delta(
1448 &mut accumulated_reasoning,
1449 &mut pending_reasoning_delta_text,
1450 &mut pending_reasoning_delta_id,
1451 );
1452
1453 let final_turn_content = stream.choice.clone();
1454 let turn_text_response = assistant_text_from_choice(&final_turn_content);
1455 tracing::Span::current().record("gen_ai.completion", &turn_text_response);
1456
1457 if !pending_tool_calls.is_empty() {
1458 let diagnostic_history =
1459 build_tool_call_validation_history(ToolCallValidationHistory {
1460 chat_history: chat_history.as_deref(),
1461 new_messages: &new_messages,
1462 assistant_message_id: &stream.message_id,
1463 final_turn_content: Some(&final_turn_content),
1464 text_delta_response: None,
1465 accumulated_reasoning: &accumulated_reasoning,
1466 pending_reasoning_delta_text: "",
1467 pending_reasoning_delta_id: &None,
1468 pending_tool_calls: &pending_tool_calls,
1469 current_tool_call: None,
1470 });
1471
1472 for (tool_call, _) in &pending_tool_calls {
1473 if let Err(err) = validate_tool_call_name(
1474 &tool_call.function.name,
1475 &executable_tool_names,
1476 &allowed_tool_names,
1477 diagnostic_history.clone(),
1478 ) {
1479 yield Err(Box::new(err).into());
1480 break 'outer;
1481 }
1482 }
1483
1484 for (tool_call, internal_call_id) in pending_tool_calls {
1485 let tool_span = info_span!(
1486 parent: tracing::Span::current(),
1487 "execute_tool",
1488 gen_ai.operation.name = "execute_tool",
1489 gen_ai.tool.type = "function",
1490 gen_ai.tool.name = tracing::field::Empty,
1491 gen_ai.tool.call.id = tracing::field::Empty,
1492 gen_ai.tool.call.arguments = tracing::field::Empty,
1493 gen_ai.tool.call.result = tracing::field::Empty
1494 );
1495
1496 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));
1497
1498 let tc_result = async {
1499 let tool_span = tracing::Span::current();
1500 let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
1501 if let Some(ref hook) = self.hook {
1502 let action = hook
1503 .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
1504 .await;
1505
1506 if let ToolCallHookAction::Terminate { reason } = action {
1507 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1508 }
1509
1510 if let ToolCallHookAction::Skip { reason } = action {
1511 tracing::info!(
1513 tool_name = tool_call.function.name.as_str(),
1514 reason = reason,
1515 "Tool call rejected"
1516 );
1517 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
1518 tool_calls.push(tool_call_msg);
1519 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
1520 saw_tool_call_this_turn = true;
1521 return Ok(reason);
1522 }
1523 }
1524
1525 tool_span.record("gen_ai.tool.name", &tool_call.function.name);
1526 tool_span.record("gen_ai.tool.call.arguments", &tool_args);
1527
1528 let tool_result = match
1529 tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
1530 Ok(thing) => thing,
1531 Err(e) => {
1532 tracing::warn!("Error while calling tool: {e}");
1533 e.to_string()
1534 }
1535 };
1536
1537 tool_span.record("gen_ai.tool.call.result", &tool_result);
1538
1539 if let Some(ref hook) = self.hook &&
1540 let HookAction::Terminate { reason } =
1541 hook.on_tool_result(
1542 &tool_call.function.name,
1543 tool_call.call_id.clone(),
1544 &internal_call_id,
1545 &tool_args,
1546 &tool_result.to_string()
1547 )
1548 .await {
1549 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1550 }
1551
1552 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
1553
1554 tool_calls.push(tool_call_msg);
1555 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));
1556
1557 saw_tool_call_this_turn = true;
1558 Ok(tool_result)
1559 }.instrument(tool_span).await;
1560
1561 match tc_result {
1562 Ok(text) => {
1563 let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
1564 yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
1565 }
1566 Err(e) => {
1567 yield Err(e);
1568 break 'outer;
1569 }
1570 }
1571 }
1572 }
1573
1574 let mut assistant_turn_added_to_history = false;
1577 if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
1578 let mut content_items = assistant_text_items_from_choice(&final_turn_content);
1580
1581 for reasoning in accumulated_reasoning.drain(..) {
1583 content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
1584 }
1585
1586 content_items.extend(tool_calls.clone());
1587
1588 if let Some(content) = OneOrMany::from_iter_optional(content_items) {
1589 new_messages.push(Message::Assistant {
1590 id: stream.message_id.clone(),
1591 content,
1592 });
1593 assistant_turn_added_to_history = true;
1594 }
1595 }
1596
1597 for (id, call_id, tool_result) in tool_results {
1598 new_messages.push(tool_result_to_user_message(id, call_id, tool_result));
1599 }
1600
1601 if !saw_tool_call_this_turn {
1602 let should_add_final_assistant = !assistant_turn_added_to_history
1604 && !is_empty_assistant_choice(&final_turn_content);
1605 if should_add_final_assistant {
1606 new_messages.push(Message::Assistant {
1607 id: stream.message_id.clone(),
1608 content: final_turn_content.clone(),
1609 });
1610 } else if !assistant_turn_added_to_history {
1611 tracing::warn!(
1612 agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
1613 message_id = ?stream.message_id,
1614 "Streaming turn completed without assistant text; final response will be empty"
1615 );
1616 }
1617
1618 let current_span = tracing::Span::current();
1619 current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
1620 current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
1621 current_span.record("gen_ai.usage.cache_read.input_tokens", aggregated_usage.cached_input_tokens);
1622 current_span.record("gen_ai.usage.cache_creation.input_tokens", aggregated_usage.cache_creation_input_tokens);
1623 current_span.record("gen_ai.usage.tool_use_prompt_tokens", aggregated_usage.tool_use_prompt_tokens);
1624 current_span.record("gen_ai.usage.reasoning_tokens", aggregated_usage.reasoning_tokens);
1625 tracing::info!("Agent multi-turn stream finished");
1626 if let Some((memory, id)) = memory_handle.as_ref()
1627 && let Err(err) = memory.append(id, new_messages.clone()).await
1628 {
1629 tracing::warn!(
1630 error = %err,
1631 conversation_id = %id,
1632 "conversation memory append failed; yielding final response anyway"
1633 );
1634 }
1635 let final_messages: Option<Vec<Message>> = if has_history {
1636 Some(new_messages.clone())
1637 } else {
1638 None
1639 };
1640 yield Ok(MultiTurnStreamItem::final_response_with_completion_calls(
1641 final_turn_content,
1642 aggregated_usage,
1643 completion_calls,
1644 final_messages,
1645 ));
1646 break;
1647 }
1648 }
1649
1650 if max_turns_reached {
1651 yield Err(Box::new(PromptError::MaxTurnsError {
1652 max_turns: self.max_turns,
1653 chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
1654 prompt: Box::new(last_prompt_error.clone().into()),
1655 }).into());
1656 }
1657 };
1658
1659 Box::pin(stream.instrument(agent_span))
1660 }
1661}
1662
1663impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
1664where
1665 M: CompletionModel + 'static,
1666 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
1667 P: PromptHook<M> + 'static,
1668{
1669 type Output = StreamingResult<M::StreamingResponse>; type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1671
1672 fn into_future(self) -> Self::IntoFuture {
1673 Box::pin(async move { self.send().await })
1675 }
1676}
1677
1678pub async fn stream_to_stdout<R>(
1685 stream: &mut StreamingResult<R>,
1686) -> Result<FinalResponse, std::io::Error> {
1687 let mut final_res = FinalResponse::empty();
1688 print!("Response: ");
1689 while let Some(content) = stream.next().await {
1690 match content {
1691 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1692 Text { text, .. },
1693 ))) => {
1694 print!("{text}");
1695 std::io::Write::flush(&mut std::io::stdout())?;
1696 }
1697 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
1698 reasoning,
1699 ))) => {
1700 let reasoning = reasoning.display_text();
1701 print!("{reasoning}");
1702 std::io::Write::flush(&mut std::io::stdout())?;
1703 }
1704 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1705 final_res = res;
1706 }
1707 Err(err) => {
1708 eprintln!("Error: {err}");
1709 }
1710 _ => {}
1711 }
1712 }
1713
1714 Ok(final_res)
1715}
1716
1717#[cfg(test)]
1718mod tests {
1719 use super::*;
1720 use crate::agent::AgentBuilder;
1721 use crate::agent::prompt_request::hooks::{
1722 InvalidToolCallContext, InvalidToolCallHookAction, PromptHook, ToolCallHookAction,
1723 };
1724 use crate::client::ProviderClient;
1725 use crate::client::completion::CompletionClient;
1726 use crate::completion::{CompletionRequest, PromptError, ToolDefinition, Usage};
1727 use crate::message::{
1728 AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent,
1729 ToolChoice, ToolResultContent, UserContent,
1730 };
1731 use crate::providers::anthropic;
1732 use crate::streaming::{StreamingPrompt, ToolCallDeltaContent};
1733 use crate::test_utils::{
1734 AppendFailingMemory, FailingMemory, MockAddTool, MockCompletionModel, MockResponse,
1735 MockStreamEvent, MockSubtractTool, MockToolError,
1736 };
1737 use crate::tool::Tool;
1738 use futures::StreamExt;
1739 use serde::Deserialize;
1740 use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
1741 use std::sync::{Arc, Mutex};
1742 use std::time::Duration;
1743
1744 #[test]
1745 fn merge_reasoning_blocks_preserves_order_and_signatures() {
1746 let mut accumulated = Vec::new();
1747 let first = crate::message::Reasoning {
1748 id: Some("rs_1".to_string()),
1749 content: vec![ReasoningContent::Text {
1750 text: "step-1".to_string(),
1751 signature: Some("sig-1".to_string()),
1752 }],
1753 };
1754 let second = crate::message::Reasoning {
1755 id: Some("rs_1".to_string()),
1756 content: vec![
1757 ReasoningContent::Text {
1758 text: "step-2".to_string(),
1759 signature: Some("sig-2".to_string()),
1760 },
1761 ReasoningContent::Summary("summary".to_string()),
1762 ],
1763 };
1764
1765 merge_reasoning_blocks(&mut accumulated, &first);
1766 merge_reasoning_blocks(&mut accumulated, &second);
1767
1768 assert_eq!(accumulated.len(), 1);
1769 let merged = accumulated.first().expect("expected accumulated reasoning");
1770 assert_eq!(merged.id.as_deref(), Some("rs_1"));
1771 assert_eq!(merged.content.len(), 3);
1772 assert!(matches!(
1773 merged.content.first(),
1774 Some(ReasoningContent::Text { text, signature: Some(sig) })
1775 if text == "step-1" && sig == "sig-1"
1776 ));
1777 assert!(matches!(
1778 merged.content.get(1),
1779 Some(ReasoningContent::Text { text, signature: Some(sig) })
1780 if text == "step-2" && sig == "sig-2"
1781 ));
1782 }
1783
1784 #[test]
1785 fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
1786 let mut accumulated = vec![crate::message::Reasoning {
1787 id: Some("rs_a".to_string()),
1788 content: vec![ReasoningContent::Text {
1789 text: "step-1".to_string(),
1790 signature: None,
1791 }],
1792 }];
1793 let incoming = crate::message::Reasoning {
1794 id: Some("rs_b".to_string()),
1795 content: vec![ReasoningContent::Text {
1796 text: "step-2".to_string(),
1797 signature: None,
1798 }],
1799 };
1800
1801 merge_reasoning_blocks(&mut accumulated, &incoming);
1802 assert_eq!(accumulated.len(), 2);
1803 assert_eq!(
1804 accumulated.first().and_then(|r| r.id.as_deref()),
1805 Some("rs_a")
1806 );
1807 assert_eq!(
1808 accumulated.get(1).and_then(|r| r.id.as_deref()),
1809 Some("rs_b")
1810 );
1811 }
1812
1813 #[test]
1814 fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
1815 let mut accumulated = vec![crate::message::Reasoning {
1816 id: None,
1817 content: vec![ReasoningContent::Text {
1818 text: "first".to_string(),
1819 signature: None,
1820 }],
1821 }];
1822 let incoming = crate::message::Reasoning {
1823 id: None,
1824 content: vec![ReasoningContent::Text {
1825 text: "second".to_string(),
1826 signature: None,
1827 }],
1828 };
1829
1830 merge_reasoning_blocks(&mut accumulated, &incoming);
1831 assert_eq!(accumulated.len(), 2);
1832 assert!(matches!(
1833 accumulated.first(),
1834 Some(crate::message::Reasoning {
1835 id: None,
1836 content
1837 }) if matches!(
1838 content.first(),
1839 Some(ReasoningContent::Text { text, .. }) if text == "first"
1840 )
1841 ));
1842 assert!(matches!(
1843 accumulated.get(1),
1844 Some(crate::message::Reasoning {
1845 id: None,
1846 content
1847 }) if matches!(
1848 content.first(),
1849 Some(ReasoningContent::Text { text, .. }) if text == "second"
1850 )
1851 ));
1852 }
1853
1854 #[test]
1855 fn tool_result_to_user_message_preserves_multimodal_tool_output() {
1856 let message = tool_result_to_user_message(
1857 "tool_call_1".to_string(),
1858 Some("call_1".to_string()),
1859 serde_json::json!({
1860 "response": {
1861 "instruction": "Use the image part to answer."
1862 },
1863 "parts": [
1864 {
1865 "type": "image",
1866 "data": "base64data==",
1867 "mimeType": "image/png"
1868 }
1869 ]
1870 })
1871 .to_string(),
1872 );
1873
1874 let tool_result = match message {
1875 Message::User { content } => match content.first() {
1876 UserContent::ToolResult(tool_result) => tool_result,
1877 other => panic!("expected tool result content, got {other:?}"),
1878 },
1879 other => panic!("expected user message, got {other:?}"),
1880 };
1881
1882 assert_eq!(tool_result.id, "tool_call_1");
1883 assert_eq!(tool_result.call_id.as_deref(), Some("call_1"));
1884 assert_eq!(tool_result.content.len(), 2);
1885
1886 let mut items = tool_result.content.iter();
1887 match items.next() {
1888 Some(ToolResultContent::Text(text)) => {
1889 assert!(text.text.contains("Use the image part to answer."));
1890 }
1891 other => panic!("expected structured text payload first, got {other:?}"),
1892 }
1893
1894 match items.next() {
1895 Some(ToolResultContent::Image(image)) => {
1896 assert_eq!(image.media_type, Some(ImageMediaType::PNG));
1897 assert!(matches!(
1898 image.data,
1899 DocumentSourceKind::Base64(ref data) if data == "base64data=="
1900 ));
1901 }
1902 other => panic!("expected image payload second, got {other:?}"),
1903 }
1904 }
1905
1906 fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
1907 let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1908 if history.len() != 3 {
1909 return Err(format!(
1910 "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
1911 ));
1912 }
1913
1914 if !matches!(
1915 history.first(),
1916 Some(Message::User { content })
1917 if matches!(
1918 content.first(),
1919 UserContent::Text(text) if text.text == "do tool work"
1920 )
1921 ) {
1922 return Err(format!(
1923 "follow-up request should begin with the original user prompt: {history:?}"
1924 ));
1925 }
1926
1927 if !matches!(
1928 history.get(1),
1929 Some(Message::Assistant { content, .. })
1930 if matches!(
1931 content.first(),
1932 AssistantContent::ToolCall(tool_call)
1933 if tool_call.id == "tool_call_1"
1934 && tool_call.call_id.as_deref() == Some("call_1")
1935 )
1936 ) {
1937 return Err(format!(
1938 "follow-up request is missing the assistant tool call in position 2: {history:?}"
1939 ));
1940 }
1941
1942 if !matches!(
1943 history.get(2),
1944 Some(Message::User { content })
1945 if matches!(
1946 content.first(),
1947 UserContent::ToolResult(tool_result)
1948 if tool_result.id == "tool_call_1"
1949 && tool_result.call_id.as_deref() == Some("call_1")
1950 )
1951 ) {
1952 return Err(format!(
1953 "follow-up request should end with the user tool result: {history:?}"
1954 ));
1955 }
1956
1957 Ok(())
1958 }
1959
1960 fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1961 history.iter().any(|message| {
1962 matches!(
1963 message,
1964 Message::Assistant { content, .. }
1965 if content.iter().any(|item| matches!(
1966 item,
1967 AssistantContent::ToolCall(tool_call)
1968 if tool_call.function.name == tool_name
1969 ))
1970 )
1971 })
1972 }
1973
1974 fn history_contains_text(history: &[Message], expected: &str) -> bool {
1975 history.iter().any(|message| {
1976 matches!(
1977 message,
1978 Message::Assistant { content, .. }
1979 if content.iter().any(|item| matches!(
1980 item,
1981 AssistantContent::Text(text) if text.text == expected
1982 ))
1983 )
1984 })
1985 }
1986
1987 fn assistant_reasoning_precedes_tool_call(
1988 history: &[Message],
1989 expected_reasoning: &str,
1990 tool_name: &str,
1991 ) -> bool {
1992 history.iter().any(|message| {
1993 let Message::Assistant { content, .. } = message else {
1994 return false;
1995 };
1996
1997 let reasoning_index = content.iter().position(|item| {
1998 matches!(
1999 item,
2000 AssistantContent::Reasoning(reasoning)
2001 if reasoning.content.iter().any(|content| matches!(
2002 content,
2003 ReasoningContent::Text { text, .. }
2004 if text == expected_reasoning
2005 ))
2006 )
2007 });
2008 let tool_index = content.iter().position(|item| {
2009 matches!(
2010 item,
2011 AssistantContent::ToolCall(tool_call)
2012 if tool_call.function.name == tool_name
2013 )
2014 });
2015
2016 matches!((reasoning_index, tool_index), (Some(reasoning), Some(tool)) if reasoning < tool)
2017 })
2018 }
2019
2020 #[derive(Clone)]
2021 struct PanicOnUnknownToolHook;
2022
2023 impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
2024 async fn on_tool_call_delta(
2025 &self,
2026 _tool_call_id: &str,
2027 _internal_call_id: &str,
2028 _tool_name: Option<&str>,
2029 _tool_call_delta: &str,
2030 ) -> HookAction {
2031 panic!("unknown tool call delta should fail before delta hooks run")
2032 }
2033
2034 async fn on_tool_call(
2035 &self,
2036 _tool_name: &str,
2037 _tool_call_id: Option<String>,
2038 _internal_call_id: &str,
2039 _args: &str,
2040 ) -> ToolCallHookAction {
2041 panic!("unknown tool call should fail before tool hooks run")
2042 }
2043
2044 async fn on_stream_completion_response_finish(
2045 &self,
2046 _prompt: &Message,
2047 _response: &MockResponse,
2048 ) -> HookAction {
2049 panic!("unknown tool call should fail before stream finish hooks run")
2050 }
2051 }
2052
    /// Test tool (named `add` by its `Tool` impl) that records how often it runs.
    #[derive(Clone)]
    struct CountingAddTool {
        /// Bumped once per `call`; being an `Arc`, clones of the tool share one counter.
        calls: Arc<AtomicU32>,
    }
2057
    /// Test tool (named `subtract` by its `Tool` impl) that records how often it runs.
    #[derive(Clone)]
    struct CountingSubtractTool {
        /// Bumped once per `call`; being an `Arc`, clones of the tool share one counter.
        calls: Arc<AtomicU32>,
    }
2062
    /// Arguments accepted by the counting arithmetic test tools.
    #[derive(Deserialize)]
    struct CountingOperationArgs {
        /// First operand.
        x: i32,
        /// Second operand.
        y: i32,
    }
2068
2069 fn arithmetic_tool_definition(name: &str, description: &str) -> ToolDefinition {
2070 ToolDefinition {
2071 name: name.to_string(),
2072 description: description.to_string(),
2073 parameters: serde_json::json!({
2074 "type": "object",
2075 "properties": {
2076 "x": {
2077 "type": "number",
2078 "description": "The first operand"
2079 },
2080 "y": {
2081 "type": "number",
2082 "description": "The second operand"
2083 }
2084 },
2085 "required": ["x", "y"],
2086 }),
2087 }
2088 }
2089
2090 impl Tool for CountingAddTool {
2091 const NAME: &'static str = "add";
2092 type Error = MockToolError;
2093 type Args = CountingOperationArgs;
2094 type Output = i32;
2095
2096 async fn definition(&self, _prompt: String) -> ToolDefinition {
2097 arithmetic_tool_definition(Self::NAME, "Add x and y together")
2098 }
2099
2100 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
2101 self.calls.fetch_add(1, Ordering::SeqCst);
2102 Ok(args.x + args.y)
2103 }
2104 }
2105
2106 impl Tool for CountingSubtractTool {
2107 const NAME: &'static str = "subtract";
2108 type Error = MockToolError;
2109 type Args = CountingOperationArgs;
2110 type Output = i32;
2111
2112 async fn definition(&self, _prompt: String) -> ToolDefinition {
2113 arithmetic_tool_definition(Self::NAME, "Subtract y from x")
2114 }
2115
2116 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
2117 self.calls.fetch_add(1, Ordering::SeqCst);
2118 Ok(args.x - args.y)
2119 }
2120 }
2121
2122 fn streaming_tool_then_text_model() -> MockCompletionModel {
2123 MockCompletionModel::from_stream_turns([
2124 vec![
2125 MockStreamEvent::tool_call(
2126 "tool_call_1",
2127 "add",
2128 serde_json::json!({"x": 1, "y": 2}),
2129 )
2130 .with_call_id("call_1"),
2131 MockStreamEvent::final_response_with_total_tokens(4),
2132 ],
2133 vec![
2134 MockStreamEvent::text("done"),
2135 MockStreamEvent::final_response_with_total_tokens(6),
2136 ],
2137 ])
2138 }
2139
2140 fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
2141 Usage {
2142 input_tokens,
2143 output_tokens,
2144 total_tokens: input_tokens + output_tokens,
2145 cached_input_tokens: 0,
2146 cache_creation_input_tokens: 0,
2147 tool_use_prompt_tokens: 0,
2148 reasoning_tokens: 0,
2149 }
2150 }
2151
/// Checks the JSON shape of `MultiTurnStreamItem::CompletionCall`.
///
/// A call carrying usage must serialize with the `completionCall` tag and a
/// full usage object, and must deserialize back to an equal `CompletionCall`.
/// A call without usage must serialize `usage` as `null`.
#[test]
fn completion_calls_stream_item_serializes_and_deserializes_expected_shape() {
    // Case 1: usage is present.
    let with_usage: MultiTurnStreamItem<MockResponse> =
        MultiTurnStreamItem::CompletionCall(CompletionCall::new(2, Some(usage(3, 4))));
    let encoded = serde_json::to_value(&with_usage).expect("serialize completion call event");
    let expected_with_usage = serde_json::json!({
        "type": "completionCall",
        "call_index": 2,
        "usage": {
            "input_tokens": 3,
            "output_tokens": 4,
            "total_tokens": 7,
            "cached_input_tokens": 0,
            "cache_creation_input_tokens": 0,
            "tool_use_prompt_tokens": 0,
            "reasoning_tokens": 0,
        }
    });
    assert_eq!(encoded, expected_with_usage);

    // The same JSON must decode back into the original variant and payload.
    let decoded: MultiTurnStreamItem<MockResponse> =
        serde_json::from_value(encoded).expect("deserialize completion call event");
    let call_usage = match decoded {
        MultiTurnStreamItem::CompletionCall(call_usage) => call_usage,
        other => panic!("expected completion call event, got {other:?}"),
    };
    assert_eq!(call_usage, CompletionCall::new(2, Some(usage(3, 4))));

    // Case 2: usage is absent and must be emitted as an explicit null.
    let without_usage: MultiTurnStreamItem<MockResponse> =
        MultiTurnStreamItem::CompletionCall(CompletionCall::new(3, None));
    let encoded = serde_json::to_value(&without_usage).expect("serialize missing usage event");
    let expected_without_usage = serde_json::json!({
        "type": "completionCall",
        "call_index": 3,
        "usage": null
    });
    assert_eq!(encoded, expected_without_usage);
}
2198
/// Checks that a final response serializes its per-call records under the
/// `completionCalls` key, with `usage: null` for a call that reported none and
/// a full usage object for a call that did.
#[test]
fn final_response_serializes_completion_calls_with_missing_usage() {
    let completion_calls = vec![
        CompletionCall::new(0, None),
        CompletionCall::new(1, Some(usage(3, 4))),
    ];
    let item: MultiTurnStreamItem<MockResponse> =
        MultiTurnStreamItem::final_response_with_completion_calls(
            OneOrMany::one(AssistantContent::text("done")),
            usage(3, 4),
            completion_calls,
            None,
        );

    let value = serde_json::to_value(&item).expect("serialize final response");

    let expected_calls = serde_json::json!([
        {
            "call_index": 0,
            "usage": null,
        },
        {
            "call_index": 1,
            "usage": {
                "input_tokens": 3,
                "output_tokens": 4,
                "total_tokens": 7,
                "cached_input_tokens": 0,
                "cache_creation_input_tokens": 0,
                "tool_use_prompt_tokens": 0,
                "reasoning_tokens": 0,
            }
        }
    ]);
    assert_eq!(value.get("completionCalls"), Some(&expected_calls));
}
2236
2237 fn streaming_text_then_final_model() -> MockCompletionModel {
2238 MockCompletionModel::from_stream_turns([[
2239 MockStreamEvent::text("hello"),
2240 MockStreamEvent::text(" world"),
2241 MockStreamEvent::final_response_with_total_tokens(3),
2242 ]])
2243 }
2244
2245 fn citation_metadata() -> serde_json::Value {
2246 serde_json::json!({
2247 "citations": [{
2248 "type": "web_search_result_location",
2249 "cited_text": "Claude Shannon was born in 1916.",
2250 "url": "https://example.com/shannon",
2251 "title": "Claude Shannon",
2252 "encrypted_index": "encrypted-reference"
2253 }]
2254 })
2255 }
2256
2257 fn streaming_cited_text_then_final_model() -> MockCompletionModel {
2258 MockCompletionModel::from_stream_turns([[
2259 MockStreamEvent::text_start(Some(citation_metadata())),
2260 MockStreamEvent::text("cited "),
2261 MockStreamEvent::text_start(None),
2262 MockStreamEvent::text("answer"),
2263 MockStreamEvent::final_response_with_total_tokens(3),
2264 ]])
2265 }
2266
2267 fn streaming_cited_text_then_tool_model() -> MockCompletionModel {
2268 MockCompletionModel::from_stream_turns([
2269 vec![
2270 MockStreamEvent::text_start(Some(citation_metadata())),
2271 MockStreamEvent::text("I need a tool. "),
2272 MockStreamEvent::tool_call(
2273 "tool_call_1",
2274 "add",
2275 serde_json::json!({"x": 1, "y": 2}),
2276 )
2277 .with_call_id("call_1"),
2278 MockStreamEvent::final_response_with_total_tokens(4),
2279 ],
2280 vec![
2281 MockStreamEvent::text("done"),
2282 MockStreamEvent::final_response_with_total_tokens(6),
2283 ],
2284 ])
2285 }
2286
2287 fn streaming_final_only_model() -> MockCompletionModel {
2288 MockCompletionModel::from_stream_turns([[
2289 MockStreamEvent::final_response_with_total_tokens(1),
2290 ]])
2291 }
2292
2293 #[derive(Clone)]
2294 struct TerminateOnStreamFinish;
2295
2296 impl PromptHook<MockCompletionModel> for TerminateOnStreamFinish {
2297 async fn on_stream_completion_response_finish(
2298 &self,
2299 _prompt: &Message,
2300 _response: &<MockCompletionModel as CompletionModel>::StreamingResponse,
2301 ) -> HookAction {
2302 HookAction::terminate("stop after completion call")
2303 }
2304 }
2305
/// One tool-call delta as captured by the recording hooks in this module.
type RecordedToolCallDelta = (
    String,         // tool_call_id
    String,         // internal_call_id
    Option<String>, // tool_name
    String,         // tool_call_delta
);
2307
2308 #[derive(Clone)]
2309 struct RepairDefaultApiHook;
2310
2311 impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
2312 fn on_invalid_tool_call(
2313 &self,
2314 context: &InvalidToolCallContext,
2315 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2316 let tool_name = context.tool_name.clone();
2317 async move {
2318 assert_eq!(tool_name, "default_api");
2319 InvalidToolCallHookAction::repair("add")
2320 }
2321 }
2322 }
2323
2324 #[derive(Clone)]
2325 struct RetryDefaultApiHook;
2326
2327 impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
2328 fn on_invalid_tool_call(
2329 &self,
2330 context: &InvalidToolCallContext,
2331 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2332 let tool_name = context.tool_name.clone();
2333 let args = context.args.clone();
2334 async move {
2335 assert_eq!(tool_name, "default_api");
2336 if let Some(args) = args {
2337 assert!(!args.is_empty());
2338 }
2339 InvalidToolCallHookAction::retry("Use the add tool instead")
2340 }
2341 }
2342 }
2343
2344 #[derive(Clone)]
2345 struct SkipDefaultApiHook;
2346
2347 impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
2348 fn on_invalid_tool_call(
2349 &self,
2350 context: &InvalidToolCallContext,
2351 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2352 let tool_name = context.tool_name.clone();
2353 async move {
2354 assert_eq!(tool_name, "default_api");
2355 InvalidToolCallHookAction::skip("default_api was skipped")
2356 }
2357 }
2358 }
2359
2360 #[derive(Clone, Default)]
2361 struct RecordingInvalidToolCallHook {
2362 contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
2363 }
2364
2365 impl RecordingInvalidToolCallHook {
2366 fn observed(&self) -> Vec<InvalidToolCallContext> {
2367 self.contexts
2368 .lock()
2369 .expect("invalid tool context records mutex was poisoned")
2370 .clone()
2371 }
2372 }
2373
2374 impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
2375 fn on_invalid_tool_call(
2376 &self,
2377 context: &InvalidToolCallContext,
2378 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2379 let contexts = self.contexts.clone();
2380 let context = context.clone();
2381
2382 async move {
2383 contexts
2384 .lock()
2385 .expect("invalid tool context records mutex was poisoned")
2386 .push(context);
2387 InvalidToolCallHookAction::fail()
2388 }
2389 }
2390 }
2391
2392 #[derive(Clone, Default)]
2393 struct RecordingToolCallDeltaHook {
2394 deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2395 }
2396
2397 impl RecordingToolCallDeltaHook {
2398 fn observed(&self) -> Vec<RecordedToolCallDelta> {
2399 self.deltas
2400 .lock()
2401 .expect("tool call delta hook records mutex was poisoned")
2402 .clone()
2403 }
2404 }
2405
2406 impl PromptHook<MockCompletionModel> for RecordingToolCallDeltaHook {
2407 fn on_tool_call_delta(
2408 &self,
2409 tool_call_id: &str,
2410 internal_call_id: &str,
2411 tool_name: Option<&str>,
2412 tool_call_delta: &str,
2413 ) -> impl Future<Output = HookAction> + Send {
2414 let deltas = self.deltas.clone();
2415 let event = (
2416 tool_call_id.to_string(),
2417 internal_call_id.to_string(),
2418 tool_name.map(str::to_string),
2419 tool_call_delta.to_string(),
2420 );
2421
2422 async move {
2423 deltas
2424 .lock()
2425 .expect("tool call delta hook records mutex was poisoned")
2426 .push(event);
2427 HookAction::cont()
2428 }
2429 }
2430 }
2431
2432 #[derive(Clone, Default)]
2433 struct RecordingTextDeltaHook {
2434 deltas: Arc<Mutex<Vec<(String, String)>>>,
2435 }
2436
2437 impl RecordingTextDeltaHook {
2438 fn observed(&self) -> Vec<(String, String)> {
2439 self.deltas
2440 .lock()
2441 .expect("text delta hook records mutex was poisoned")
2442 .clone()
2443 }
2444 }
2445
2446 impl PromptHook<MockCompletionModel> for RecordingTextDeltaHook {
2447 fn on_text_delta(
2448 &self,
2449 text_delta: &str,
2450 full_text: &str,
2451 ) -> impl Future<Output = HookAction> + Send {
2452 let deltas = self.deltas.clone();
2453 let event = (text_delta.to_string(), full_text.to_string());
2454
2455 async move {
2456 deltas
2457 .lock()
2458 .expect("text delta hook records mutex was poisoned")
2459 .push(event);
2460 HookAction::cont()
2461 }
2462 }
2463 }
2464
2465 #[derive(Clone)]
2466 struct RecordingTextAndSkipInvalidToolHook {
2467 text: RecordingTextDeltaHook,
2468 }
2469
2470 impl PromptHook<MockCompletionModel> for RecordingTextAndSkipInvalidToolHook {
2471 fn on_text_delta(
2472 &self,
2473 text_delta: &str,
2474 full_text: &str,
2475 ) -> impl Future<Output = HookAction> + Send {
2476 self.text.on_text_delta(text_delta, full_text)
2477 }
2478
2479 fn on_invalid_tool_call(
2480 &self,
2481 context: &InvalidToolCallContext,
2482 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2483 SkipDefaultApiHook.on_invalid_tool_call(context)
2484 }
2485 }
2486
2487 #[derive(Clone)]
2488 struct RecordingTextAndRetryInvalidToolHook {
2489 text: RecordingTextDeltaHook,
2490 }
2491
2492 impl PromptHook<MockCompletionModel> for RecordingTextAndRetryInvalidToolHook {
2493 fn on_text_delta(
2494 &self,
2495 text_delta: &str,
2496 full_text: &str,
2497 ) -> impl Future<Output = HookAction> + Send {
2498 self.text.on_text_delta(text_delta, full_text)
2499 }
2500
2501 fn on_invalid_tool_call(
2502 &self,
2503 context: &InvalidToolCallContext,
2504 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2505 RetryDefaultApiHook.on_invalid_tool_call(context)
2506 }
2507 }
2508
2509 #[derive(Clone)]
2510 struct RecordingDeltaAndRetryInvalidToolHook {
2511 delta: RecordingToolCallDeltaHook,
2512 }
2513
2514 impl PromptHook<MockCompletionModel> for RecordingDeltaAndRetryInvalidToolHook {
2515 fn on_tool_call_delta(
2516 &self,
2517 tool_call_id: &str,
2518 internal_call_id: &str,
2519 tool_name: Option<&str>,
2520 tool_call_delta: &str,
2521 ) -> impl Future<Output = HookAction> + Send {
2522 self.delta.on_tool_call_delta(
2523 tool_call_id,
2524 internal_call_id,
2525 tool_name,
2526 tool_call_delta,
2527 )
2528 }
2529
2530 fn on_invalid_tool_call(
2531 &self,
2532 context: &InvalidToolCallContext,
2533 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2534 RetryDefaultApiHook.on_invalid_tool_call(context)
2535 }
2536 }
2537
2538 #[derive(Clone)]
2539 struct RecordingDeltaAndSkipInvalidToolHook {
2540 delta: RecordingToolCallDeltaHook,
2541 }
2542
2543 impl PromptHook<MockCompletionModel> for RecordingDeltaAndSkipInvalidToolHook {
2544 fn on_tool_call_delta(
2545 &self,
2546 tool_call_id: &str,
2547 internal_call_id: &str,
2548 tool_name: Option<&str>,
2549 tool_call_delta: &str,
2550 ) -> impl Future<Output = HookAction> + Send {
2551 self.delta.on_tool_call_delta(
2552 tool_call_id,
2553 internal_call_id,
2554 tool_name,
2555 tool_call_delta,
2556 )
2557 }
2558
2559 fn on_invalid_tool_call(
2560 &self,
2561 context: &InvalidToolCallContext,
2562 ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2563 SkipDefaultApiHook.on_invalid_tool_call(context)
2564 }
2565 }
2566
2567 #[derive(Clone, Default)]
2568 struct TerminatingToolCallDeltaHook {
2569 deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2570 }
2571
2572 impl TerminatingToolCallDeltaHook {
2573 fn observed(&self) -> Vec<RecordedToolCallDelta> {
2574 self.deltas
2575 .lock()
2576 .expect("tool call delta hook records mutex was poisoned")
2577 .clone()
2578 }
2579 }
2580
2581 impl PromptHook<MockCompletionModel> for TerminatingToolCallDeltaHook {
2582 fn on_tool_call_delta(
2583 &self,
2584 tool_call_id: &str,
2585 internal_call_id: &str,
2586 tool_name: Option<&str>,
2587 tool_call_delta: &str,
2588 ) -> impl Future<Output = HookAction> + Send {
2589 let deltas = self.deltas.clone();
2590 let event = (
2591 tool_call_id.to_string(),
2592 internal_call_id.to_string(),
2593 tool_name.map(str::to_string),
2594 tool_call_delta.to_string(),
2595 );
2596
2597 async move {
2598 deltas
2599 .lock()
2600 .expect("tool call delta hook records mutex was poisoned")
2601 .push(event);
2602 HookAction::terminate("stop on tool call delta")
2603 }
2604 }
2605 }
2606
2607 fn text_metadata(content: &OneOrMany<AssistantContent>) -> Option<&serde_json::Value> {
2608 content.iter().find_map(|item| match item {
2609 AssistantContent::Text(text) => text.additional_params.as_ref(),
2610 _ => None,
2611 })
2612 }
2613
2614 #[tokio::test]
2615 async fn stream_prompt_continues_after_tool_call_turn() {
2616 let model = streaming_tool_then_text_model();
2617 let recorded = model.clone();
2618 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2619 let empty_history: &[Message] = &[];
2620
2621 let mut stream = agent
2622 .stream_prompt("do tool work")
2623 .with_history(empty_history)
2624 .multi_turn(3)
2625 .await;
2626 let mut saw_tool_call = false;
2627 let mut saw_tool_result = false;
2628 let mut saw_final_response = false;
2629 let mut final_text = String::new();
2630 let mut final_response_text = None;
2631 let mut final_history = None;
2632
2633 while let Some(item) = stream.next().await {
2634 match item {
2635 Ok(MultiTurnStreamItem::StreamAssistantItem(
2636 StreamedAssistantContent::ToolCall { .. },
2637 )) => {
2638 saw_tool_call = true;
2639 }
2640 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2641 ..
2642 })) => {
2643 saw_tool_result = true;
2644 }
2645 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
2646 text,
2647 ))) => {
2648 final_text.push_str(&text.text);
2649 }
2650 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
2651 saw_final_response = true;
2652 final_response_text = Some(res.response().to_owned());
2653 final_history = res.history().map(|history| history.to_vec());
2654 break;
2655 }
2656 Ok(_) => {}
2657 Err(err) => panic!("unexpected streaming error: {err:?}"),
2658 }
2659 }
2660
2661 assert!(saw_tool_call);
2662 assert!(saw_tool_result);
2663 assert!(saw_final_response);
2664 assert_eq!(final_text, "done");
2665 assert_eq!(final_response_text.as_deref(), Some("done"));
2666 let history = final_history.expect("expected final response history");
2667 assert!(history.iter().any(|message| matches!(
2668 message,
2669 Message::Assistant { content, .. }
2670 if content.iter().any(|item| matches!(
2671 item,
2672 AssistantContent::Text(text) if text.text == "done"
2673 ))
2674 )));
2675 let requests = recorded.requests();
2676 assert_eq!(requests.len(), 2);
2677 assert!(validate_follow_up_tool_history(&requests[1]).is_ok());
2678 }
2679
2680 #[tokio::test]
2681 async fn unknown_tool_call_fails_before_streaming_second_request() {
2682 let model = MockCompletionModel::from_stream_turns([
2683 vec![
2684 MockStreamEvent::tool_call(
2685 "tool_call_1",
2686 "default_api",
2687 serde_json::json!({"x": 1, "y": 2}),
2688 ),
2689 MockStreamEvent::final_response_with_total_tokens(4),
2690 ],
2691 vec![
2692 MockStreamEvent::text("should not be requested"),
2693 MockStreamEvent::final_response_with_total_tokens(6),
2694 ],
2695 ]);
2696 let recorded = model.clone();
2697 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2698
2699 let mut stream = agent
2700 .stream_prompt("use the tool")
2701 .with_hook(PanicOnUnknownToolHook)
2702 .multi_turn(3)
2703 .await;
2704 let mut saw_tool_call = false;
2705 let mut error = None;
2706
2707 while let Some(item) = stream.next().await {
2708 match item {
2709 Ok(MultiTurnStreamItem::StreamAssistantItem(
2710 StreamedAssistantContent::ToolCall { .. },
2711 )) => {
2712 saw_tool_call = true;
2713 }
2714 Ok(_) => {}
2715 Err(err) => {
2716 error = Some(err);
2717 break;
2718 }
2719 }
2720 }
2721
2722 assert!(!saw_tool_call);
2723 let error = error.expect("unknown model-emitted tool should fail");
2724 match error {
2725 StreamingError::Prompt(err) => match *err {
2726 PromptError::UnknownToolCall {
2727 tool_name,
2728 available_tools,
2729 allowed_tools,
2730 chat_history,
2731 } => {
2732 assert_eq!(tool_name, "default_api");
2733 assert_eq!(available_tools, vec!["add".to_string()]);
2734 assert_eq!(allowed_tools, vec!["add".to_string()]);
2735 assert!(history_contains_tool_call(&chat_history, "default_api"));
2736 }
2737 other => panic!("expected UnknownToolCall, got {other:?}"),
2738 },
2739 other => panic!("expected prompt streaming error, got {other:?}"),
2740 }
2741 assert_eq!(recorded.request_count(), 1);
2742 }
2743
2744 #[tokio::test]
2745 async fn invalid_tool_call_hook_can_repair_streaming_tool_name() {
2746 let model = MockCompletionModel::from_stream_turns([
2747 vec![
2748 MockStreamEvent::tool_call(
2749 "tool_call_1",
2750 "default_api",
2751 serde_json::json!({"x": 2, "y": 3}),
2752 ),
2753 MockStreamEvent::final_response_with_total_tokens(4),
2754 ],
2755 vec![
2756 MockStreamEvent::text("done"),
2757 MockStreamEvent::final_response_with_total_tokens(6),
2758 ],
2759 ]);
2760 let recorded = model.clone();
2761 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2762
2763 let mut stream = agent
2764 .stream_prompt("use the tool")
2765 .with_hook(RepairDefaultApiHook)
2766 .multi_turn(3)
2767 .with_history(Vec::<Message>::new())
2768 .await;
2769 let mut saw_repaired_tool_call = false;
2770 let mut saw_tool_result = false;
2771 let mut final_response_text = None;
2772
2773 while let Some(item) = stream.next().await {
2774 match item {
2775 Ok(MultiTurnStreamItem::StreamAssistantItem(
2776 StreamedAssistantContent::ToolCall { tool_call, .. },
2777 )) => {
2778 assert_eq!(tool_call.function.name, "add");
2779 saw_repaired_tool_call = true;
2780 }
2781 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2782 tool_result,
2783 ..
2784 })) => {
2785 assert!(tool_result.content.iter().any(|content| {
2786 matches!(
2787 content,
2788 ToolResultContent::Text(text) if text.text == "5"
2789 )
2790 }));
2791 saw_tool_result = true;
2792 }
2793 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2794 final_response_text = Some(response.response().to_string());
2795 break;
2796 }
2797 Ok(_) => {}
2798 Err(err) => panic!("unexpected streaming error: {err:?}"),
2799 }
2800 }
2801
2802 assert!(saw_repaired_tool_call);
2803 assert!(saw_tool_result);
2804 assert_eq!(final_response_text.as_deref(), Some("done"));
2805 assert_eq!(recorded.request_count(), 2);
2806 }
2807
2808 #[tokio::test]
2809 async fn invalid_tool_call_context_uses_completed_streaming_tool_call_provider_id() {
2810 let invalid_hook = RecordingInvalidToolCallHook::default();
2811 let model = MockCompletionModel::from_stream_turns([
2812 vec![
2813 MockStreamEvent::tool_call(
2814 "tool_call_1",
2815 "default_api",
2816 serde_json::json!({"x": 2, "y": 3}),
2817 )
2818 .with_call_id("provider_call_1"),
2819 MockStreamEvent::final_response_with_total_tokens(4),
2820 ],
2821 vec![
2822 MockStreamEvent::text("should not be requested"),
2823 MockStreamEvent::final_response_with_total_tokens(6),
2824 ],
2825 ]);
2826 let recorded = model.clone();
2827 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2828
2829 let mut stream = agent
2830 .stream_prompt("use the tool")
2831 .with_hook(invalid_hook.clone())
2832 .multi_turn(3)
2833 .await;
2834 let mut error = None;
2835
2836 while let Some(item) = stream.next().await {
2837 if let Err(err) = item {
2838 error = Some(err);
2839 break;
2840 }
2841 }
2842
2843 assert!(error.is_some(), "invalid tool should fail");
2844 assert_eq!(recorded.request_count(), 1);
2845 let contexts = invalid_hook.observed();
2846 assert_eq!(contexts.len(), 1);
2847 let context = &contexts[0];
2848 assert_eq!(context.tool_name, "default_api");
2849 assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
2850 assert!(context.internal_call_id.is_some());
2851 assert!(context.is_streaming);
2852 }
2853
2854 #[tokio::test]
2855 async fn invalid_tool_call_hook_skip_emits_streaming_tool_result() {
2856 let add_calls = Arc::new(AtomicU32::new(0));
2857 let model = MockCompletionModel::from_stream_turns([
2858 vec![
2859 MockStreamEvent::tool_call(
2860 "tool_call_1",
2861 "default_api",
2862 serde_json::json!({"x": 2, "y": 3}),
2863 )
2864 .with_call_id("call_1"),
2865 MockStreamEvent::final_response_with_total_tokens(4),
2866 ],
2867 vec![
2868 MockStreamEvent::text("continued"),
2869 MockStreamEvent::final_response_with_total_tokens(6),
2870 ],
2871 ]);
2872 let recorded = model.clone();
2873 let agent = AgentBuilder::new(model)
2874 .tool(CountingAddTool {
2875 calls: add_calls.clone(),
2876 })
2877 .build();
2878
2879 let mut stream = agent
2880 .stream_prompt("use the tool")
2881 .with_hook(SkipDefaultApiHook)
2882 .multi_turn(3)
2883 .with_history(Vec::<Message>::new())
2884 .await;
2885 let mut skipped_tool_result = None;
2886 let mut final_response_text = None;
2887
2888 while let Some(item) = stream.next().await {
2889 match item {
2890 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2891 tool_result,
2892 internal_call_id,
2893 })) => {
2894 assert!(!internal_call_id.is_empty());
2895 skipped_tool_result = Some(tool_result);
2896 }
2897 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2898 final_response_text = Some(response.response().to_string());
2899 break;
2900 }
2901 Ok(_) => {}
2902 Err(err) => panic!("unexpected streaming error: {err:?}"),
2903 }
2904 }
2905
2906 let skipped_tool_result =
2907 skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
2908 assert_eq!(skipped_tool_result.id, "tool_call_1");
2909 assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_1"));
2910 assert!(skipped_tool_result.content.iter().any(|content| matches!(
2911 content,
2912 ToolResultContent::Text(text) if text.text == "default_api was skipped"
2913 )));
2914 assert_eq!(final_response_text.as_deref(), Some("continued"));
2915 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
2916
2917 let requests = recorded.requests();
2918 assert_eq!(requests.len(), 2);
2919 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2920 assert!(matches!(
2921 follow_up_history.get(2),
2922 Some(Message::User { content })
2923 if content.iter().any(|item| matches!(
2924 item,
2925 UserContent::ToolResult(result)
2926 if result.id == "tool_call_1"
2927 && result.content.iter().any(|content| matches!(
2928 content,
2929 ToolResultContent::Text(text)
2930 if text.text == "default_api was skipped"
2931 ))
2932 ))
2933 ));
2934 }
2935
2936 #[tokio::test]
2937 async fn invalid_tool_call_hook_retries_mixed_streaming_turn_without_executing_valid_call() {
2938 let add_calls = Arc::new(AtomicU32::new(0));
2939 let model = MockCompletionModel::from_stream_turns([
2940 vec![
2941 MockStreamEvent::text("checking "),
2942 MockStreamEvent::tool_call(
2943 "tool_call_1",
2944 "add",
2945 serde_json::json!({"x": 2, "y": 3}),
2946 )
2947 .with_call_id("call_1"),
2948 MockStreamEvent::tool_call(
2949 "tool_call_2",
2950 "default_api",
2951 serde_json::json!({"x": 4, "y": 5}),
2952 )
2953 .with_call_id("call_2"),
2954 MockStreamEvent::final_response_with_total_tokens(4),
2955 ],
2956 vec![
2957 MockStreamEvent::text("retried"),
2958 MockStreamEvent::final_response_with_total_tokens(6),
2959 ],
2960 ]);
2961 let recorded = model.clone();
2962 let agent = AgentBuilder::new(model)
2963 .tool(CountingAddTool {
2964 calls: add_calls.clone(),
2965 })
2966 .build();
2967
2968 let mut stream = agent
2969 .stream_prompt("use the tool")
2970 .with_hook(RetryDefaultApiHook)
2971 .multi_turn(3)
2972 .with_history(Vec::<Message>::new())
2973 .max_invalid_tool_call_retries(1)
2974 .await;
2975 let mut completion_call_events = Vec::new();
2976 let mut final_response_text = None;
2977 let mut final_response_usage = Usage::new();
2978 let mut final_completion_calls = Vec::new();
2979
2980 while let Some(item) = stream.next().await {
2981 match item {
2982 Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
2983 completion_call_events.push(completion_call);
2984 }
2985 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2986 final_response_text = Some(response.response().to_string());
2987 final_response_usage = response.usage();
2988 final_completion_calls = response.completion_calls().to_vec();
2989 break;
2990 }
2991 Ok(_) => {}
2992 Err(err) => panic!("unexpected streaming error: {err:?}"),
2993 }
2994 }
2995
2996 assert_eq!(final_response_text.as_deref(), Some("retried"));
2997 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
2998 let mut first_usage = Usage::new();
2999 first_usage.total_tokens = 4;
3000 let mut second_usage = Usage::new();
3001 second_usage.total_tokens = 6;
3002 let expected_completion_calls = vec![
3003 CompletionCall::new(0, Some(first_usage)),
3004 CompletionCall::new(1, Some(second_usage)),
3005 ];
3006 assert_eq!(completion_call_events, expected_completion_calls);
3007 assert_eq!(final_completion_calls, expected_completion_calls);
3008 assert_eq!(final_response_usage.total_tokens, 10);
3009
3010 let requests = recorded.requests();
3011 assert_eq!(requests.len(), 2);
3012 let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3013 assert_eq!(retry_history.len(), 3);
3014 assert!(matches!(
3015 retry_history.get(1),
3016 Some(Message::Assistant { content, .. })
3017 if content.iter().any(|item| matches!(
3018 item,
3019 AssistantContent::Text(text) if text.text == "checking "
3020 ))
3021 && content.iter().any(|item| matches!(
3022 item,
3023 AssistantContent::ToolCall(tool_call)
3024 if tool_call.id == "tool_call_1"
3025 && tool_call.function.name == "add"
3026 ))
3027 && content.iter().any(|item| matches!(
3028 item,
3029 AssistantContent::ToolCall(tool_call)
3030 if tool_call.id == "tool_call_2"
3031 && tool_call.function.name == "default_api"
3032 ))
3033 ));
3034 assert!(matches!(
3035 retry_history.get(2),
3036 Some(Message::User { content })
3037 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3038 && content.iter().any(|item| matches!(
3039 item,
3040 UserContent::ToolResult(result)
3041 if result.id == "tool_call_1"
3042 && result.content.iter().any(|content| matches!(
3043 content,
3044 ToolResultContent::Text(text)
3045 if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3046 ))
3047 ))
3048 && content.iter().any(|item| matches!(
3049 item,
3050 UserContent::ToolResult(result)
3051 if result.id == "tool_call_2"
3052 && result.content.iter().any(|content| matches!(
3053 content,
3054 ToolResultContent::Text(text)
3055 if text.text == "Use the add tool instead"
3056 ))
3057 ))
3058 ));
3059 }
3060
3061 #[tokio::test]
3062 async fn invalid_tool_call_hook_skips_mixed_streaming_turn_without_executing_valid_call() {
3063 let add_calls = Arc::new(AtomicU32::new(0));
3064 let model = MockCompletionModel::from_stream_turns([
3065 vec![
3066 MockStreamEvent::text("checking "),
3067 MockStreamEvent::tool_call(
3068 "tool_call_1",
3069 "add",
3070 serde_json::json!({"x": 2, "y": 3}),
3071 )
3072 .with_call_id("call_1"),
3073 MockStreamEvent::tool_call(
3074 "tool_call_2",
3075 "default_api",
3076 serde_json::json!({"x": 4, "y": 5}),
3077 )
3078 .with_call_id("call_2"),
3079 MockStreamEvent::final_response_with_total_tokens(4),
3080 ],
3081 vec![
3082 MockStreamEvent::text("continued"),
3083 MockStreamEvent::final_response_with_total_tokens(6),
3084 ],
3085 ]);
3086 let recorded = model.clone();
3087 let agent = AgentBuilder::new(model)
3088 .tool(CountingAddTool {
3089 calls: add_calls.clone(),
3090 })
3091 .build();
3092
3093 let mut stream = agent
3094 .stream_prompt("use the tool")
3095 .with_hook(SkipDefaultApiHook)
3096 .multi_turn(3)
3097 .with_history(Vec::<Message>::new())
3098 .await;
3099 let mut skipped_tool_result = None;
3100 let mut final_response_text = None;
3101
3102 while let Some(item) = stream.next().await {
3103 match item {
3104 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3105 tool_result,
3106 ..
3107 })) => {
3108 skipped_tool_result = Some(tool_result);
3109 }
3110 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3111 final_response_text = Some(response.response().to_string());
3112 break;
3113 }
3114 Ok(_) => {}
3115 Err(err) => panic!("unexpected streaming error: {err:?}"),
3116 }
3117 }
3118
3119 let skipped_tool_result =
3120 skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3121 assert_eq!(skipped_tool_result.id, "tool_call_2");
3122 assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_2"));
3123 assert_eq!(final_response_text.as_deref(), Some("continued"));
3124 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3125
3126 let requests = recorded.requests();
3127 assert_eq!(requests.len(), 2);
3128 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3129 assert_eq!(follow_up_history.len(), 3);
3130 assert!(matches!(
3131 follow_up_history.get(1),
3132 Some(Message::Assistant { content, .. })
3133 if content.iter().any(|item| matches!(
3134 item,
3135 AssistantContent::Text(text) if text.text == "checking "
3136 ))
3137 && content.iter().any(|item| matches!(
3138 item,
3139 AssistantContent::ToolCall(tool_call)
3140 if tool_call.id == "tool_call_1"
3141 && tool_call.function.name == "add"
3142 ))
3143 && content.iter().any(|item| matches!(
3144 item,
3145 AssistantContent::ToolCall(tool_call)
3146 if tool_call.id == "tool_call_2"
3147 && tool_call.function.name == "default_api"
3148 ))
3149 ));
3150 assert!(matches!(
3151 follow_up_history.get(2),
3152 Some(Message::User { content })
3153 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3154 && content.iter().any(|item| matches!(
3155 item,
3156 UserContent::ToolResult(result)
3157 if result.id == "tool_call_1"
3158 && result.call_id.as_deref() == Some("call_1")
3159 && result.content.iter().any(|content| matches!(
3160 content,
3161 ToolResultContent::Text(text)
3162 if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3163 ))
3164 ))
3165 && content.iter().any(|item| matches!(
3166 item,
3167 UserContent::ToolResult(result)
3168 if result.id == "tool_call_2"
3169 && result.call_id.as_deref() == Some("call_2")
3170 && result.content.iter().any(|content| matches!(
3171 content,
3172 ToolResultContent::Text(text)
3173 if text.text == "default_api was skipped"
3174 ))
3175 ))
3176 ));
3177 }
3178
3179 #[tokio::test]
3180 async fn invalid_completed_tool_call_skip_preserves_streaming_reasoning_history() {
3181 let model = MockCompletionModel::from_stream_turns([
3182 vec![
3183 MockStreamEvent::text("checking "),
3184 MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
3185 MockStreamEvent::tool_call(
3186 "tool_call_1",
3187 "default_api",
3188 serde_json::json!({"x": 2, "y": 3}),
3189 ),
3190 MockStreamEvent::final_response_with_total_tokens(4),
3191 ],
3192 vec![
3193 MockStreamEvent::text("continued"),
3194 MockStreamEvent::final_response_with_total_tokens(6),
3195 ],
3196 ]);
3197 let recorded = model.clone();
3198 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3199
3200 let mut stream = agent
3201 .stream_prompt("use the tool")
3202 .with_hook(SkipDefaultApiHook)
3203 .multi_turn(3)
3204 .with_history(Vec::<Message>::new())
3205 .await;
3206
3207 while let Some(item) = stream.next().await {
3208 match item {
3209 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3210 Ok(_) => {}
3211 Err(err) => panic!("unexpected streaming error: {err:?}"),
3212 }
3213 }
3214
3215 let requests = recorded.requests();
3216 assert_eq!(requests.len(), 2);
3217 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3218 assert!(history_contains_text(&follow_up_history, "checking "));
3219 assert!(assistant_reasoning_precedes_tool_call(
3220 &follow_up_history,
3221 "reasoned step",
3222 "default_api"
3223 ));
3224 }
3225
3226 #[tokio::test]
3227 async fn invalid_name_delta_retry_preserves_streaming_reasoning_history() {
3228 let model = MockCompletionModel::from_stream_turns([
3229 vec![
3230 MockStreamEvent::reasoning_delta(Some("rs_1"), "delta reason"),
3231 MockStreamEvent::tool_call_arguments_delta(
3232 "tool_call_1",
3233 "internal_1",
3234 r#"{"x":2,"y":3}"#,
3235 ),
3236 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3237 MockStreamEvent::final_response_with_total_tokens(4),
3238 ],
3239 vec![
3240 MockStreamEvent::text("retried"),
3241 MockStreamEvent::final_response_with_total_tokens(6),
3242 ],
3243 ]);
3244 let recorded = model.clone();
3245 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3246
3247 let mut stream = agent
3248 .stream_prompt("use the tool")
3249 .with_hook(RetryDefaultApiHook)
3250 .multi_turn(3)
3251 .with_history(Vec::<Message>::new())
3252 .max_invalid_tool_call_retries(1)
3253 .await;
3254
3255 while let Some(item) = stream.next().await {
3256 match item {
3257 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3258 Ok(_) => {}
3259 Err(err) => panic!("unexpected streaming error: {err:?}"),
3260 }
3261 }
3262
3263 let requests = recorded.requests();
3264 assert_eq!(requests.len(), 2);
3265 let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3266 assert!(assistant_reasoning_precedes_tool_call(
3267 &retry_history,
3268 "delta reason",
3269 "default_api"
3270 ));
3271 }
3272
3273 #[tokio::test]
3274 async fn invalid_tool_call_hook_skip_resets_streaming_text_delta_state() {
3275 let text_hook = RecordingTextDeltaHook::default();
3276 let model = MockCompletionModel::from_stream_turns([
3277 vec![
3278 MockStreamEvent::text("stale "),
3279 MockStreamEvent::tool_call(
3280 "tool_call_1",
3281 "default_api",
3282 serde_json::json!({"x": 2, "y": 3}),
3283 ),
3284 MockStreamEvent::final_response_with_total_tokens(4),
3285 ],
3286 vec![
3287 MockStreamEvent::text("fresh"),
3288 MockStreamEvent::final_response_with_total_tokens(6),
3289 ],
3290 ]);
3291 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3292
3293 let mut stream = agent
3294 .stream_prompt("use the tool")
3295 .with_hook(RecordingTextAndSkipInvalidToolHook {
3296 text: text_hook.clone(),
3297 })
3298 .multi_turn(3)
3299 .with_history(Vec::<Message>::new())
3300 .await;
3301
3302 while let Some(item) = stream.next().await {
3303 match item {
3304 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3305 Ok(_) => {}
3306 Err(err) => panic!("unexpected streaming error: {err:?}"),
3307 }
3308 }
3309
3310 assert_eq!(
3311 text_hook.observed(),
3312 vec![
3313 ("stale ".to_string(), "stale ".to_string()),
3314 ("fresh".to_string(), "fresh".to_string()),
3315 ]
3316 );
3317 }
3318
3319 #[tokio::test]
3320 async fn invalid_tool_call_delta_retry_uses_structured_tool_feedback() {
3321 let delta_hook = RecordingToolCallDeltaHook::default();
3322 let add_calls = Arc::new(AtomicU32::new(0));
3323 let model = MockCompletionModel::from_stream_turns([
3324 vec![
3325 MockStreamEvent::text("checking "),
3326 MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3327 MockStreamEvent::tool_call(
3328 "tool_call_0",
3329 "add",
3330 serde_json::json!({"x": 1, "y": 2}),
3331 )
3332 .with_call_id("call_0"),
3333 MockStreamEvent::tool_call_arguments_delta(
3334 "tool_call_1",
3335 "internal_1",
3336 r#"{"x":2,"y":3}"#,
3337 ),
3338 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3339 MockStreamEvent::final_response_with_total_tokens(4),
3340 ],
3341 vec![
3342 MockStreamEvent::text("retried"),
3343 MockStreamEvent::final_response_with_total_tokens(6),
3344 ],
3345 ]);
3346 let recorded = model.clone();
3347 let agent = AgentBuilder::new(model)
3348 .tool(CountingAddTool {
3349 calls: add_calls.clone(),
3350 })
3351 .build();
3352
3353 let mut stream = agent
3354 .stream_prompt("use the tool")
3355 .with_hook(RecordingDeltaAndRetryInvalidToolHook {
3356 delta: delta_hook.clone(),
3357 })
3358 .multi_turn(3)
3359 .with_history(Vec::<Message>::new())
3360 .max_invalid_tool_call_retries(1)
3361 .await;
3362 let mut completion_call_events = Vec::new();
3363 let mut final_response_text = None;
3364 let mut final_response_usage = Usage::new();
3365 let mut final_completion_calls = Vec::new();
3366
3367 while let Some(item) = stream.next().await {
3368 match item {
3369 Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
3370 completion_call_events.push(completion_call);
3371 }
3372 Ok(MultiTurnStreamItem::StreamAssistantItem(
3373 StreamedAssistantContent::ToolCallDelta { .. },
3374 )) => panic!("invalid tool-call delta should not be emitted"),
3375 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3376 final_response_text = Some(response.response().to_string());
3377 final_response_usage = response.usage();
3378 final_completion_calls = response.completion_calls().to_vec();
3379 break;
3380 }
3381 Ok(_) => {}
3382 Err(err) => panic!("unexpected streaming error: {err:?}"),
3383 }
3384 }
3385
3386 assert_eq!(final_response_text.as_deref(), Some("retried"));
3387 assert!(delta_hook.observed().is_empty());
3388 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3389 let mut first_usage = Usage::new();
3390 first_usage.total_tokens = 4;
3391 let mut second_usage = Usage::new();
3392 second_usage.total_tokens = 6;
3393 let expected_completion_calls = vec![
3394 CompletionCall::new(0, Some(first_usage)),
3395 CompletionCall::new(1, Some(second_usage)),
3396 ];
3397 assert_eq!(completion_call_events, expected_completion_calls);
3398 assert_eq!(final_completion_calls, expected_completion_calls);
3399 assert_eq!(final_response_usage.total_tokens, 10);
3400
3401 let requests = recorded.requests();
3402 assert_eq!(requests.len(), 2);
3403 let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3404 assert!(matches!(
3405 retry_history.get(1),
3406 Some(Message::Assistant { content, .. })
3407 if content.iter().any(|item| matches!(
3408 item,
3409 AssistantContent::Text(text) if text.text == "checking "
3410 ))
3411 && content.iter().any(|item| matches!(
3412 item,
3413 AssistantContent::ToolCall(tool_call)
3414 if tool_call.id == "tool_call_0"
3415 && tool_call.function.name == "add"
3416 ))
3417 && content.iter().any(|item| matches!(
3418 item,
3419 AssistantContent::ToolCall(tool_call)
3420 if tool_call.id == "tool_call_1"
3421 && tool_call.function.name == "default_api"
3422 && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3423 ))
3424 ));
3425 assert!(matches!(
3426 retry_history.get(2),
3427 Some(Message::User { content })
3428 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3429 && content.iter().any(|item| matches!(
3430 item,
3431 UserContent::ToolResult(result)
3432 if result.id == "tool_call_0"
3433 && result.call_id.as_deref() == Some("call_0")
3434 && result.content.iter().any(|content| matches!(
3435 content,
3436 ToolResultContent::Text(text)
3437 if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3438 ))
3439 ))
3440 && content.iter().any(|item| matches!(
3441 item,
3442 UserContent::ToolResult(result)
3443 if result.id == "tool_call_1"
3444 && result.content.iter().any(|content| matches!(
3445 content,
3446 ToolResultContent::Text(text)
3447 if text.text == "Use the add tool instead"
3448 ))
3449 ))
3450 ));
3451 }
3452
3453 #[tokio::test]
3454 async fn invalid_tool_call_delta_context_includes_same_turn_history_and_tool_call_id() {
3455 let invalid_hook = RecordingInvalidToolCallHook::default();
3456 let model = MockCompletionModel::from_stream_turns([
3457 vec![
3458 MockStreamEvent::text("checking "),
3459 MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3460 MockStreamEvent::tool_call(
3461 "tool_call_0",
3462 "add",
3463 serde_json::json!({"x": 1, "y": 2}),
3464 )
3465 .with_call_id("call_0"),
3466 MockStreamEvent::tool_call_arguments_delta(
3467 "tool_call_1",
3468 "internal_1",
3469 r#"{"x":2,"y":3}"#,
3470 ),
3471 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3472 MockStreamEvent::final_response_with_total_tokens(4),
3473 ],
3474 vec![
3475 MockStreamEvent::text("should not be requested"),
3476 MockStreamEvent::final_response_with_total_tokens(6),
3477 ],
3478 ]);
3479 let recorded = model.clone();
3480 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3481
3482 let mut stream = agent
3483 .stream_prompt("use the tool")
3484 .with_hook(invalid_hook.clone())
3485 .multi_turn(3)
3486 .await;
3487 let mut error = None;
3488
3489 while let Some(item) = stream.next().await {
3490 if let Err(err) = item {
3491 error = Some(err);
3492 break;
3493 }
3494 }
3495
3496 assert!(error.is_some(), "invalid name delta should fail");
3497 assert_eq!(recorded.request_count(), 1);
3498 let contexts = invalid_hook.observed();
3499 assert_eq!(contexts.len(), 1);
3500 let context = &contexts[0];
3501 assert_eq!(context.tool_name, "default_api");
3502 assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
3503 assert_eq!(context.internal_call_id.as_deref(), Some("internal_1"));
3504 assert!(context.is_streaming);
3505 assert!(history_contains_text(&context.chat_history, "checking "));
3506 assert!(
3507 assistant_reasoning_precedes_tool_call(
3508 &context.chat_history,
3509 "diagnostic reason",
3510 "add"
3511 ),
3512 "{:?}",
3513 context.chat_history
3514 );
3515 assert!(history_contains_tool_call(&context.chat_history, "add"));
3516 assert!(history_contains_tool_call(
3517 &context.chat_history,
3518 "default_api"
3519 ));
3520 }
3521
3522 #[tokio::test]
3523 async fn invalid_tool_call_delta_retry_resets_streaming_text_delta_state() {
3524 let text_hook = RecordingTextDeltaHook::default();
3525 let model = MockCompletionModel::from_stream_turns([
3526 vec![
3527 MockStreamEvent::text("stale "),
3528 MockStreamEvent::tool_call_arguments_delta(
3529 "tool_call_1",
3530 "internal_1",
3531 r#"{"x":2,"y":3}"#,
3532 ),
3533 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3534 MockStreamEvent::final_response_with_total_tokens(4),
3535 ],
3536 vec![
3537 MockStreamEvent::text("fresh"),
3538 MockStreamEvent::final_response_with_total_tokens(6),
3539 ],
3540 ]);
3541 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3542
3543 let mut stream = agent
3544 .stream_prompt("use the tool")
3545 .with_hook(RecordingTextAndRetryInvalidToolHook {
3546 text: text_hook.clone(),
3547 })
3548 .multi_turn(3)
3549 .with_history(Vec::<Message>::new())
3550 .max_invalid_tool_call_retries(1)
3551 .await;
3552
3553 while let Some(item) = stream.next().await {
3554 match item {
3555 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3556 Ok(_) => {}
3557 Err(err) => panic!("unexpected streaming error: {err:?}"),
3558 }
3559 }
3560
3561 assert_eq!(
3562 text_hook.observed(),
3563 vec![
3564 ("stale ".to_string(), "stale ".to_string()),
3565 ("fresh".to_string(), "fresh".to_string()),
3566 ]
3567 );
3568 }
3569
3570 #[tokio::test]
3571 async fn invalid_tool_call_delta_skip_uses_structured_tool_feedback() {
3572 let delta_hook = RecordingToolCallDeltaHook::default();
3573 let add_calls = Arc::new(AtomicU32::new(0));
3574 let model = MockCompletionModel::from_stream_turns([
3575 vec![
3576 MockStreamEvent::text("checking "),
3577 MockStreamEvent::tool_call(
3578 "tool_call_0",
3579 "add",
3580 serde_json::json!({"x": 1, "y": 2}),
3581 )
3582 .with_call_id("call_0"),
3583 MockStreamEvent::tool_call_arguments_delta(
3584 "tool_call_1",
3585 "internal_1",
3586 r#"{"x":2,"y":3}"#,
3587 ),
3588 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3589 MockStreamEvent::final_response_with_total_tokens(4),
3590 ],
3591 vec![
3592 MockStreamEvent::text("continued"),
3593 MockStreamEvent::final_response_with_total_tokens(6),
3594 ],
3595 ]);
3596 let recorded = model.clone();
3597 let agent = AgentBuilder::new(model)
3598 .tool(CountingAddTool {
3599 calls: add_calls.clone(),
3600 })
3601 .build();
3602
3603 let mut stream = agent
3604 .stream_prompt("use the tool")
3605 .with_hook(RecordingDeltaAndSkipInvalidToolHook {
3606 delta: delta_hook.clone(),
3607 })
3608 .multi_turn(3)
3609 .with_history(Vec::<Message>::new())
3610 .await;
3611 let mut skipped_tool_result = None;
3612 let mut final_response_text = None;
3613
3614 while let Some(item) = stream.next().await {
3615 match item {
3616 Ok(MultiTurnStreamItem::StreamAssistantItem(
3617 StreamedAssistantContent::ToolCallDelta { .. },
3618 )) => panic!("invalid tool-call delta should not be emitted"),
3619 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3620 tool_result,
3621 internal_call_id,
3622 })) => {
3623 assert_eq!(internal_call_id, "internal_1");
3624 skipped_tool_result = Some(tool_result);
3625 }
3626 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3627 final_response_text = Some(response.response().to_string());
3628 break;
3629 }
3630 Ok(_) => {}
3631 Err(err) => panic!("unexpected streaming error: {err:?}"),
3632 }
3633 }
3634
3635 let skipped_tool_result =
3636 skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3637 assert_eq!(skipped_tool_result.id, "tool_call_1");
3638 assert!(skipped_tool_result.call_id.is_none());
3639 assert!(skipped_tool_result.content.iter().any(|content| matches!(
3640 content,
3641 ToolResultContent::Text(text) if text.text == "default_api was skipped"
3642 )));
3643 assert_eq!(final_response_text.as_deref(), Some("continued"));
3644 assert!(delta_hook.observed().is_empty());
3645 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3646
3647 let requests = recorded.requests();
3648 assert_eq!(requests.len(), 2);
3649 let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3650 assert!(matches!(
3651 follow_up_history.get(1),
3652 Some(Message::Assistant { content, .. })
3653 if content.iter().any(|item| matches!(
3654 item,
3655 AssistantContent::Text(text) if text.text == "checking "
3656 ))
3657 && content.iter().any(|item| matches!(
3658 item,
3659 AssistantContent::ToolCall(tool_call)
3660 if tool_call.id == "tool_call_0"
3661 && tool_call.function.name == "add"
3662 ))
3663 && content.iter().any(|item| matches!(
3664 item,
3665 AssistantContent::ToolCall(tool_call)
3666 if tool_call.id == "tool_call_1"
3667 && tool_call.function.name == "default_api"
3668 && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3669 ))
3670 ));
3671 assert!(matches!(
3672 follow_up_history.get(2),
3673 Some(Message::User { content })
3674 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3675 && content.iter().any(|item| matches!(
3676 item,
3677 UserContent::ToolResult(result)
3678 if result.id == "tool_call_0"
3679 && result.call_id.as_deref() == Some("call_0")
3680 && result.content.iter().any(|content| matches!(
3681 content,
3682 ToolResultContent::Text(text)
3683 if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3684 ))
3685 ))
3686 && content.iter().any(|item| matches!(
3687 item,
3688 UserContent::ToolResult(result)
3689 if result.id == "tool_call_1"
3690 && result.content.iter().any(|content| matches!(
3691 content,
3692 ToolResultContent::Text(text)
3693 if text.text == "default_api was skipped"
3694 ))
3695 ))
3696 ));
3697 }
3698
3699 #[tokio::test]
3700 async fn streaming_retry_budget_exhaustion_history_contains_invalid_tool_call() {
3701 let model = MockCompletionModel::from_stream_turns([
3702 vec![
3703 MockStreamEvent::tool_call(
3704 "tool_call_1",
3705 "default_api",
3706 serde_json::json!({"x": 1, "y": 2}),
3707 ),
3708 MockStreamEvent::final_response_with_total_tokens(4),
3709 ],
3710 vec![
3711 MockStreamEvent::text("should not be requested"),
3712 MockStreamEvent::final_response_with_total_tokens(6),
3713 ],
3714 ]);
3715 let recorded = model.clone();
3716 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3717
3718 let mut stream = agent
3719 .stream_prompt("use the tool")
3720 .with_hook(RetryDefaultApiHook)
3721 .multi_turn(3)
3722 .max_invalid_tool_call_retries(0)
3723 .await;
3724 let mut error = None;
3725
3726 while let Some(item) = stream.next().await {
3727 if let Err(err) = item {
3728 error = Some(err);
3729 break;
3730 }
3731 }
3732
3733 let error = error.expect("retry budget exhaustion should fail");
3734 match error {
3735 StreamingError::Prompt(err) => match *err {
3736 PromptError::UnknownToolCall {
3737 tool_name,
3738 chat_history,
3739 ..
3740 } => {
3741 assert_eq!(tool_name, "default_api");
3742 assert!(history_contains_tool_call(&chat_history, "default_api"));
3743 }
3744 other => panic!("expected UnknownToolCall, got {other:?}"),
3745 },
3746 other => panic!("expected prompt streaming error, got {other:?}"),
3747 }
3748 assert_eq!(recorded.request_count(), 1);
3749 }
3750
3751 #[tokio::test]
3752 async fn streaming_name_delta_retry_budget_exhaustion_history_includes_same_turn_context() {
3753 let model = MockCompletionModel::from_stream_turns([
3754 vec![
3755 MockStreamEvent::text("checking "),
3756 MockStreamEvent::tool_call(
3757 "tool_call_0",
3758 "add",
3759 serde_json::json!({"x": 1, "y": 2}),
3760 )
3761 .with_call_id("call_0"),
3762 MockStreamEvent::tool_call_arguments_delta(
3763 "tool_call_1",
3764 "internal_1",
3765 r#"{"x":2,"y":3}"#,
3766 ),
3767 MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3768 MockStreamEvent::final_response_with_total_tokens(4),
3769 ],
3770 vec![
3771 MockStreamEvent::text("should not be requested"),
3772 MockStreamEvent::final_response_with_total_tokens(6),
3773 ],
3774 ]);
3775 let recorded = model.clone();
3776 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3777
3778 let mut stream = agent
3779 .stream_prompt("use the tool")
3780 .with_hook(RetryDefaultApiHook)
3781 .multi_turn(3)
3782 .max_invalid_tool_call_retries(0)
3783 .await;
3784 let mut error = None;
3785
3786 while let Some(item) = stream.next().await {
3787 if let Err(err) = item {
3788 error = Some(err);
3789 break;
3790 }
3791 }
3792
3793 let error = error.expect("retry budget exhaustion should fail");
3794 match error {
3795 StreamingError::Prompt(err) => match *err {
3796 PromptError::UnknownToolCall {
3797 tool_name,
3798 chat_history,
3799 ..
3800 } => {
3801 assert_eq!(tool_name, "default_api");
3802 assert!(history_contains_text(&chat_history, "checking "));
3803 assert!(history_contains_tool_call(&chat_history, "add"));
3804 assert!(history_contains_tool_call(&chat_history, "default_api"));
3805 }
3806 other => panic!("expected UnknownToolCall, got {other:?}"),
3807 },
3808 other => panic!("expected prompt streaming error, got {other:?}"),
3809 }
3810 assert_eq!(recorded.request_count(), 1);
3811 }
3812
3813 #[tokio::test]
3814 async fn completed_unknown_tool_call_after_text_fails_before_finish_hook_or_later_emit() {
3815 let add_calls = Arc::new(AtomicU32::new(0));
3816 let model = MockCompletionModel::from_stream_turns([
3817 vec![
3818 MockStreamEvent::text("thinking "),
3819 MockStreamEvent::tool_call(
3820 "tool_call_1",
3821 "default_api",
3822 serde_json::json!({"x": 1, "y": 2}),
3823 ),
3824 MockStreamEvent::final_response_with_total_tokens(4),
3825 ],
3826 vec![
3827 MockStreamEvent::text("should not be requested"),
3828 MockStreamEvent::final_response_with_total_tokens(6),
3829 ],
3830 ]);
3831 let recorded = model.clone();
3832 let agent = AgentBuilder::new(model)
3833 .tool(CountingAddTool {
3834 calls: add_calls.clone(),
3835 })
3836 .build();
3837
3838 let mut stream = agent
3839 .stream_prompt("use the tool")
3840 .with_hook(PanicOnUnknownToolHook)
3841 .multi_turn(3)
3842 .await;
3843 let mut saw_text = false;
3844 let mut saw_completion_call = false;
3845 let mut saw_final_response = false;
3846 let mut saw_tool_call = false;
3847 let mut saw_tool_result = false;
3848 let mut error = None;
3849
3850 while let Some(item) = stream.next().await {
3851 match item {
3852 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(_))) => {
3853 saw_text = true;
3854 }
3855 Ok(MultiTurnStreamItem::CompletionCall(_)) => {
3856 saw_completion_call = true;
3857 }
3858 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Final(
3859 _,
3860 )))
3861 | Ok(MultiTurnStreamItem::FinalResponse(_)) => {
3862 saw_final_response = true;
3863 }
3864 Ok(MultiTurnStreamItem::StreamAssistantItem(
3865 StreamedAssistantContent::ToolCall { .. },
3866 )) => {
3867 saw_tool_call = true;
3868 }
3869 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3870 ..
3871 })) => {
3872 saw_tool_result = true;
3873 }
3874 Ok(_) => {}
3875 Err(err) => {
3876 error = Some(err);
3877 break;
3878 }
3879 }
3880 }
3881
3882 assert!(saw_text);
3883 assert!(!saw_completion_call);
3884 assert!(!saw_final_response);
3885 assert!(!saw_tool_call);
3886 assert!(!saw_tool_result);
3887 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3888 let error = error.expect("completed unknown tool call should fail immediately");
3889 match error {
3890 StreamingError::Prompt(err) => match *err {
3891 PromptError::UnknownToolCall {
3892 tool_name,
3893 available_tools,
3894 allowed_tools,
3895 chat_history,
3896 } => {
3897 assert_eq!(tool_name, "default_api");
3898 assert_eq!(available_tools, vec!["add".to_string()]);
3899 assert_eq!(allowed_tools, vec!["add".to_string()]);
3900 assert!(history_contains_tool_call(&chat_history, "default_api"));
3901 }
3902 other => panic!("expected UnknownToolCall, got {other:?}"),
3903 },
3904 other => panic!("expected prompt streaming error, got {other:?}"),
3905 }
3906 assert_eq!(recorded.request_count(), 1);
3907 }
3908
3909 #[tokio::test]
3910 async fn mixed_streaming_tool_calls_fail_before_any_tool_execution() {
3911 let add_calls = Arc::new(AtomicU32::new(0));
3912 let model = MockCompletionModel::from_stream_turns([
3913 vec![
3914 MockStreamEvent::tool_call(
3915 "tool_call_1",
3916 "add",
3917 serde_json::json!({"x": 1, "y": 2}),
3918 )
3919 .with_call_id("call_1"),
3920 MockStreamEvent::tool_call(
3921 "tool_call_2",
3922 "default_api",
3923 serde_json::json!({"x": 3, "y": 4}),
3924 ),
3925 MockStreamEvent::final_response_with_total_tokens(4),
3926 ],
3927 vec![
3928 MockStreamEvent::text("should not be requested"),
3929 MockStreamEvent::final_response_with_total_tokens(6),
3930 ],
3931 ]);
3932 let recorded = model.clone();
3933 let agent = AgentBuilder::new(model)
3934 .tool(CountingAddTool {
3935 calls: add_calls.clone(),
3936 })
3937 .build();
3938
3939 let mut stream = agent
3940 .stream_prompt("use tools")
3941 .with_hook(PanicOnUnknownToolHook)
3942 .multi_turn(3)
3943 .await;
3944 let mut saw_completion_call = false;
3945 let mut saw_tool_call = false;
3946 let mut saw_tool_result = false;
3947 let mut error = None;
3948
3949 while let Some(item) = stream.next().await {
3950 match item {
3951 Ok(MultiTurnStreamItem::CompletionCall(_)) => {
3952 saw_completion_call = true;
3953 }
3954 Ok(MultiTurnStreamItem::StreamAssistantItem(
3955 StreamedAssistantContent::ToolCall { .. },
3956 )) => {
3957 saw_tool_call = true;
3958 }
3959 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3960 ..
3961 })) => {
3962 saw_tool_result = true;
3963 }
3964 Ok(_) => {}
3965 Err(err) => {
3966 error = Some(err);
3967 break;
3968 }
3969 }
3970 }
3971
3972 assert!(!saw_completion_call);
3973 assert!(!saw_tool_call);
3974 assert!(!saw_tool_result);
3975 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3976 let error = error.expect("mixed unknown streamed tool call should fail");
3977 match error {
3978 StreamingError::Prompt(err) => match *err {
3979 PromptError::UnknownToolCall {
3980 tool_name,
3981 available_tools,
3982 allowed_tools,
3983 chat_history,
3984 } => {
3985 assert_eq!(tool_name, "default_api");
3986 assert_eq!(available_tools, vec!["add".to_string()]);
3987 assert_eq!(allowed_tools, vec!["add".to_string()]);
3988 assert!(history_contains_tool_call(&chat_history, "default_api"));
3989 }
3990 other => panic!("expected UnknownToolCall, got {other:?}"),
3991 },
3992 other => panic!("expected prompt streaming error, got {other:?}"),
3993 }
3994 assert_eq!(recorded.request_count(), 1);
3995 }
3996
3997 #[tokio::test]
3998 async fn multiple_valid_streaming_tool_calls_execute_after_batch_validation() {
3999 let add_calls = Arc::new(AtomicU32::new(0));
4000 let subtract_calls = Arc::new(AtomicU32::new(0));
4001 let model = MockCompletionModel::from_stream_turns([
4002 vec![
4003 MockStreamEvent::tool_call(
4004 "tool_call_1",
4005 "add",
4006 serde_json::json!({"x": 1, "y": 2}),
4007 )
4008 .with_call_id("call_1"),
4009 MockStreamEvent::tool_call(
4010 "tool_call_2",
4011 "subtract",
4012 serde_json::json!({"x": 8, "y": 3}),
4013 )
4014 .with_call_id("call_2"),
4015 MockStreamEvent::final_response_with_total_tokens(4),
4016 ],
4017 vec![
4018 MockStreamEvent::text("done"),
4019 MockStreamEvent::final_response_with_total_tokens(6),
4020 ],
4021 ]);
4022 let recorded = model.clone();
4023 let agent = AgentBuilder::new(model)
4024 .tool(CountingAddTool {
4025 calls: add_calls.clone(),
4026 })
4027 .tool(CountingSubtractTool {
4028 calls: subtract_calls.clone(),
4029 })
4030 .build();
4031
4032 let mut stream = agent.stream_prompt("use tools").multi_turn(3).await;
4033 let mut tool_call_names = Vec::new();
4034 let mut tool_result_ids = Vec::new();
4035 let mut final_response_text = None;
4036
4037 while let Some(item) = stream.next().await {
4038 match item {
4039 Ok(MultiTurnStreamItem::StreamAssistantItem(
4040 StreamedAssistantContent::ToolCall { tool_call, .. },
4041 )) => {
4042 tool_call_names.push(tool_call.function.name);
4043 }
4044 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4045 tool_result,
4046 ..
4047 })) => {
4048 tool_result_ids.push(tool_result.id);
4049 }
4050 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4051 final_response_text = Some(response.response().to_owned());
4052 break;
4053 }
4054 Ok(_) => {}
4055 Err(err) => panic!("unexpected streaming error: {err:?}"),
4056 }
4057 }
4058
4059 assert_eq!(
4060 tool_call_names,
4061 vec!["add".to_string(), "subtract".to_string()]
4062 );
4063 assert_eq!(
4064 tool_result_ids,
4065 vec!["tool_call_1".to_string(), "tool_call_2".to_string()]
4066 );
4067 assert_eq!(add_calls.load(Ordering::SeqCst), 1);
4068 assert_eq!(subtract_calls.load(Ordering::SeqCst), 1);
4069 assert_eq!(final_response_text.as_deref(), Some("done"));
4070 assert_eq!(recorded.request_count(), 2);
4071 }
4072
4073 #[tokio::test]
4074 async fn disallowed_specific_tool_call_fails_before_streaming_second_request() {
4075 let model = MockCompletionModel::from_stream_turns([
4076 vec![
4077 MockStreamEvent::tool_call(
4078 "tool_call_1",
4079 "subtract",
4080 serde_json::json!({"x": 3, "y": 1}),
4081 ),
4082 MockStreamEvent::final_response_with_total_tokens(4),
4083 ],
4084 vec![
4085 MockStreamEvent::text("should not be requested"),
4086 MockStreamEvent::final_response_with_total_tokens(6),
4087 ],
4088 ]);
4089 let recorded = model.clone();
4090 let agent = AgentBuilder::new(model)
4091 .tool(MockAddTool)
4092 .tool(MockSubtractTool)
4093 .tool_choice(ToolChoice::Specific {
4094 function_names: vec!["add".to_string()],
4095 })
4096 .build();
4097
4098 let mut stream = agent
4099 .stream_prompt("use the allowed tool")
4100 .with_hook(PanicOnUnknownToolHook)
4101 .multi_turn(3)
4102 .await;
4103 let mut saw_tool_call = false;
4104 let mut error = None;
4105
4106 while let Some(item) = stream.next().await {
4107 match item {
4108 Ok(MultiTurnStreamItem::StreamAssistantItem(
4109 StreamedAssistantContent::ToolCall { .. },
4110 )) => {
4111 saw_tool_call = true;
4112 }
4113 Ok(_) => {}
4114 Err(err) => {
4115 error = Some(err);
4116 break;
4117 }
4118 }
4119 }
4120
4121 assert!(!saw_tool_call);
4122 let error = error.expect("disallowed model-emitted tool should fail");
4123 match error {
4124 StreamingError::Prompt(err) => match *err {
4125 PromptError::UnknownToolCall {
4126 tool_name,
4127 available_tools,
4128 allowed_tools,
4129 chat_history,
4130 } => {
4131 assert_eq!(tool_name, "subtract");
4132 assert_eq!(
4133 available_tools,
4134 vec!["add".to_string(), "subtract".to_string()]
4135 );
4136 assert_eq!(allowed_tools, vec!["add".to_string()]);
4137 assert!(history_contains_tool_call(&chat_history, "subtract"));
4138 }
4139 other => panic!("expected UnknownToolCall, got {other:?}"),
4140 },
4141 other => panic!("expected prompt streaming error, got {other:?}"),
4142 }
4143 assert_eq!(recorded.request_count(), 1);
4144 }
4145
4146 #[tokio::test]
4147 async fn mixed_specific_tool_calls_fail_before_any_tool_execution() {
4148 let add_calls = Arc::new(AtomicU32::new(0));
4149 let model = MockCompletionModel::from_stream_turns([
4150 vec![
4151 MockStreamEvent::tool_call(
4152 "tool_call_1",
4153 "add",
4154 serde_json::json!({"x": 1, "y": 2}),
4155 ),
4156 MockStreamEvent::tool_call(
4157 "tool_call_2",
4158 "subtract",
4159 serde_json::json!({"x": 3, "y": 1}),
4160 ),
4161 MockStreamEvent::final_response_with_total_tokens(4),
4162 ],
4163 vec![
4164 MockStreamEvent::text("should not be requested"),
4165 MockStreamEvent::final_response_with_total_tokens(6),
4166 ],
4167 ]);
4168 let recorded = model.clone();
4169 let agent = AgentBuilder::new(model)
4170 .tool(CountingAddTool {
4171 calls: add_calls.clone(),
4172 })
4173 .tool(MockSubtractTool)
4174 .tool_choice(ToolChoice::Specific {
4175 function_names: vec!["add".to_string()],
4176 })
4177 .build();
4178
4179 let mut stream = agent
4180 .stream_prompt("use the allowed tool")
4181 .with_hook(PanicOnUnknownToolHook)
4182 .multi_turn(3)
4183 .await;
4184 let mut saw_tool_call = false;
4185 let mut saw_tool_result = false;
4186 let mut error = None;
4187
4188 while let Some(item) = stream.next().await {
4189 match item {
4190 Ok(MultiTurnStreamItem::StreamAssistantItem(
4191 StreamedAssistantContent::ToolCall { .. },
4192 )) => {
4193 saw_tool_call = true;
4194 }
4195 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4196 ..
4197 })) => {
4198 saw_tool_result = true;
4199 }
4200 Ok(_) => {}
4201 Err(err) => {
4202 error = Some(err);
4203 break;
4204 }
4205 }
4206 }
4207
4208 assert!(!saw_tool_call);
4209 assert!(!saw_tool_result);
4210 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
4211 let error = error.expect("mixed disallowed streamed tool call should fail");
4212 match error {
4213 StreamingError::Prompt(err) => match *err {
4214 PromptError::UnknownToolCall {
4215 tool_name,
4216 available_tools,
4217 allowed_tools,
4218 chat_history,
4219 } => {
4220 assert_eq!(tool_name, "subtract");
4221 assert_eq!(
4222 available_tools,
4223 vec!["add".to_string(), "subtract".to_string()]
4224 );
4225 assert_eq!(allowed_tools, vec!["add".to_string()]);
4226 assert!(history_contains_tool_call(&chat_history, "subtract"));
4227 }
4228 other => panic!("expected UnknownToolCall, got {other:?}"),
4229 },
4230 other => panic!("expected prompt streaming error, got {other:?}"),
4231 }
4232 assert_eq!(recorded.request_count(), 1);
4233 }
4234
4235 #[tokio::test]
4236 async fn tool_choice_none_rejects_streaming_tool_call() {
4237 let model = MockCompletionModel::from_stream_turns([
4238 vec![
4239 MockStreamEvent::tool_call(
4240 "tool_call_1",
4241 "add",
4242 serde_json::json!({"x": 1, "y": 2}),
4243 ),
4244 MockStreamEvent::final_response_with_total_tokens(4),
4245 ],
4246 vec![
4247 MockStreamEvent::text("should not be requested"),
4248 MockStreamEvent::final_response_with_total_tokens(6),
4249 ],
4250 ]);
4251 let recorded = model.clone();
4252 let agent = AgentBuilder::new(model)
4253 .tool(MockAddTool)
4254 .tool_choice(ToolChoice::None)
4255 .build();
4256
4257 let mut stream = agent
4258 .stream_prompt("do not use tools")
4259 .with_hook(PanicOnUnknownToolHook)
4260 .multi_turn(3)
4261 .await;
4262 let mut saw_tool_call = false;
4263 let mut error = None;
4264
4265 while let Some(item) = stream.next().await {
4266 match item {
4267 Ok(MultiTurnStreamItem::StreamAssistantItem(
4268 StreamedAssistantContent::ToolCall { .. },
4269 )) => {
4270 saw_tool_call = true;
4271 }
4272 Ok(_) => {}
4273 Err(err) => {
4274 error = Some(err);
4275 break;
4276 }
4277 }
4278 }
4279
4280 assert!(!saw_tool_call);
4281 let error = error.expect("ToolChoice::None should reject returned tool calls");
4282 match error {
4283 StreamingError::Prompt(err) => match *err {
4284 PromptError::UnknownToolCall {
4285 tool_name,
4286 available_tools,
4287 allowed_tools,
4288 chat_history,
4289 } => {
4290 assert_eq!(tool_name, "add");
4291 assert_eq!(available_tools, vec!["add".to_string()]);
4292 assert!(allowed_tools.is_empty());
4293 assert!(history_contains_tool_call(&chat_history, "add"));
4294 }
4295 other => panic!("expected UnknownToolCall, got {other:?}"),
4296 },
4297 other => panic!("expected prompt streaming error, got {other:?}"),
4298 }
4299 assert_eq!(recorded.request_count(), 1);
4300 }
4301
4302 #[tokio::test]
4303 async fn tool_choice_none_rejects_streaming_tool_call_name_delta_before_hook_or_emit() {
4304 let model = MockCompletionModel::from_stream_turns([
4305 vec![
4306 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4307 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4308 MockStreamEvent::final_response_with_total_tokens(4),
4309 ],
4310 vec![
4311 MockStreamEvent::text("should not be requested"),
4312 MockStreamEvent::final_response_with_total_tokens(6),
4313 ],
4314 ]);
4315 let recorded = model.clone();
4316 let agent = AgentBuilder::new(model)
4317 .tool(MockAddTool)
4318 .tool_choice(ToolChoice::None)
4319 .build();
4320
4321 let mut stream = agent
4322 .stream_prompt("do not use tools")
4323 .with_hook(PanicOnUnknownToolHook)
4324 .multi_turn(3)
4325 .await;
4326 let mut saw_delta = false;
4327 let mut error = None;
4328
4329 while let Some(item) = stream.next().await {
4330 match item {
4331 Ok(MultiTurnStreamItem::StreamAssistantItem(
4332 StreamedAssistantContent::ToolCallDelta { .. },
4333 )) => {
4334 saw_delta = true;
4335 }
4336 Ok(_) => {}
4337 Err(err) => {
4338 error = Some(err);
4339 break;
4340 }
4341 }
4342 }
4343
4344 assert!(!saw_delta);
4345 let error = error.expect("ToolChoice::None should reject returned tool-call deltas");
4346 match error {
4347 StreamingError::Prompt(err) => match *err {
4348 PromptError::UnknownToolCall {
4349 tool_name,
4350 available_tools,
4351 allowed_tools,
4352 chat_history,
4353 } => {
4354 assert_eq!(tool_name, "add");
4355 assert_eq!(available_tools, vec!["add".to_string()]);
4356 assert!(allowed_tools.is_empty());
4357 assert!(history_contains_tool_call(&chat_history, "add"));
4358 }
4359 other => panic!("expected UnknownToolCall, got {other:?}"),
4360 },
4361 other => panic!("expected prompt streaming error, got {other:?}"),
4362 }
4363 assert_eq!(recorded.request_count(), 1);
4364 }
4365
4366 #[tokio::test]
4367 async fn unknown_tool_call_name_delta_fails_before_streaming_delta_hook_or_emit() {
4368 let model = MockCompletionModel::from_stream_turns([
4369 vec![
4370 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4371 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4372 MockStreamEvent::final_response_with_total_tokens(4),
4373 ],
4374 vec![
4375 MockStreamEvent::text("should not be requested"),
4376 MockStreamEvent::final_response_with_total_tokens(6),
4377 ],
4378 ]);
4379 let recorded = model.clone();
4380 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4381
4382 let mut stream = agent
4383 .stream_prompt("stream a bad tool call")
4384 .with_hook(PanicOnUnknownToolHook)
4385 .multi_turn(3)
4386 .await;
4387 let mut saw_delta = false;
4388 let mut error = None;
4389
4390 while let Some(item) = stream.next().await {
4391 match item {
4392 Ok(MultiTurnStreamItem::StreamAssistantItem(
4393 StreamedAssistantContent::ToolCallDelta { .. },
4394 )) => {
4395 saw_delta = true;
4396 }
4397 Ok(_) => {}
4398 Err(err) => {
4399 error = Some(err);
4400 break;
4401 }
4402 }
4403 }
4404
4405 assert!(!saw_delta);
4406 let error = error.expect("unknown tool-call name delta should fail");
4407 match error {
4408 StreamingError::Prompt(err) => match *err {
4409 PromptError::UnknownToolCall {
4410 tool_name,
4411 available_tools,
4412 allowed_tools,
4413 chat_history,
4414 } => {
4415 assert_eq!(tool_name, "default_api");
4416 assert_eq!(available_tools, vec!["add".to_string()]);
4417 assert_eq!(allowed_tools, vec!["add".to_string()]);
4418 assert!(history_contains_tool_call(&chat_history, "default_api"));
4419 }
4420 other => panic!("expected UnknownToolCall, got {other:?}"),
4421 },
4422 other => panic!("expected prompt streaming error, got {other:?}"),
4423 }
4424 assert_eq!(recorded.request_count(), 1);
4425 }
4426
4427 #[tokio::test]
4428 async fn tool_call_args_delta_before_unknown_name_fails_before_hook_or_emit() {
4429 let model = MockCompletionModel::from_stream_turns([
4430 vec![
4431 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4432 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4433 MockStreamEvent::final_response_with_total_tokens(4),
4434 ],
4435 vec![
4436 MockStreamEvent::text("should not be requested"),
4437 MockStreamEvent::final_response_with_total_tokens(6),
4438 ],
4439 ]);
4440 let recorded = model.clone();
4441 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4442
4443 let mut stream = agent
4444 .stream_prompt("stream a bad tool call")
4445 .with_hook(PanicOnUnknownToolHook)
4446 .multi_turn(3)
4447 .await;
4448 let mut saw_delta = false;
4449 let mut error = None;
4450
4451 while let Some(item) = stream.next().await {
4452 match item {
4453 Ok(MultiTurnStreamItem::StreamAssistantItem(
4454 StreamedAssistantContent::ToolCallDelta { .. },
4455 )) => {
4456 saw_delta = true;
4457 }
4458 Ok(_) => {}
4459 Err(err) => {
4460 error = Some(err);
4461 break;
4462 }
4463 }
4464 }
4465
4466 assert!(!saw_delta);
4467 let error = error.expect("unknown tool-call name should reject buffered args");
4468 match error {
4469 StreamingError::Prompt(err) => match *err {
4470 PromptError::UnknownToolCall {
4471 tool_name,
4472 available_tools,
4473 allowed_tools,
4474 chat_history,
4475 } => {
4476 assert_eq!(tool_name, "default_api");
4477 assert_eq!(available_tools, vec!["add".to_string()]);
4478 assert_eq!(allowed_tools, vec!["add".to_string()]);
4479 assert!(history_contains_tool_call(&chat_history, "default_api"));
4480 }
4481 other => panic!("expected UnknownToolCall, got {other:?}"),
4482 },
4483 other => panic!("expected prompt streaming error, got {other:?}"),
4484 }
4485 assert_eq!(recorded.request_count(), 1);
4486 }
4487
4488 #[tokio::test]
4489 async fn tool_call_args_delta_before_valid_name_buffers_then_emits_in_safe_order() {
4490 let model = MockCompletionModel::from_stream_turns([[
4491 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4492 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4493 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4494 MockStreamEvent::final_response_with_total_tokens(3),
4495 ]]);
4496 let hook = RecordingToolCallDeltaHook::default();
4497 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4498
4499 let mut stream = agent
4500 .stream_prompt("stream a tool call")
4501 .with_hook(hook.clone())
4502 .await;
4503 let mut stream_deltas = Vec::new();
4504
4505 while let Some(item) = stream.next().await {
4506 match item {
4507 Ok(MultiTurnStreamItem::StreamAssistantItem(
4508 StreamedAssistantContent::ToolCallDelta {
4509 id,
4510 internal_call_id,
4511 content,
4512 },
4513 )) => {
4514 stream_deltas.push((id, internal_call_id, content));
4515 }
4516 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4517 Ok(_) => {}
4518 Err(err) => panic!("unexpected streaming error: {err:?}"),
4519 }
4520 }
4521
4522 assert_eq!(
4523 hook.observed(),
4524 vec![
4525 (
4526 "tool_1".to_string(),
4527 "internal_1".to_string(),
4528 Some("add".to_string()),
4529 String::new()
4530 ),
4531 (
4532 "tool_1".to_string(),
4533 "internal_1".to_string(),
4534 None,
4535 "{\"x\":".to_string()
4536 ),
4537 (
4538 "tool_1".to_string(),
4539 "internal_1".to_string(),
4540 None,
4541 "1}".to_string()
4542 ),
4543 ]
4544 );
4545 assert_eq!(
4546 stream_deltas,
4547 vec![
4548 (
4549 "tool_1".to_string(),
4550 "internal_1".to_string(),
4551 ToolCallDeltaContent::Name("add".to_string())
4552 ),
4553 (
4554 "tool_1".to_string(),
4555 "internal_1".to_string(),
4556 ToolCallDeltaContent::Delta("{\"x\":".to_string())
4557 ),
4558 (
4559 "tool_1".to_string(),
4560 "internal_1".to_string(),
4561 ToolCallDeltaContent::Delta("1}".to_string())
4562 ),
4563 ]
4564 );
4565 }
4566
4567 #[tokio::test]
4568 async fn tool_call_args_delta_without_name_errors_at_stream_end() {
4569 let model = MockCompletionModel::from_stream_turns([
4570 vec![
4571 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4572 MockStreamEvent::final_response_with_total_tokens(4),
4573 ],
4574 vec![
4575 MockStreamEvent::text("should not be requested"),
4576 MockStreamEvent::final_response_with_total_tokens(6),
4577 ],
4578 ]);
4579 let recorded = model.clone();
4580 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4581
4582 let mut stream = agent
4583 .stream_prompt("stream an incomplete tool call")
4584 .with_hook(PanicOnUnknownToolHook)
4585 .multi_turn(3)
4586 .await;
4587 let mut saw_delta = false;
4588 let mut saw_completion_call = false;
4589 let mut saw_final_response = false;
4590 let mut error = None;
4591
4592 while let Some(item) = stream.next().await {
4593 match item {
4594 Ok(MultiTurnStreamItem::StreamAssistantItem(
4595 StreamedAssistantContent::ToolCallDelta { .. },
4596 )) => {
4597 saw_delta = true;
4598 }
4599 Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4600 saw_completion_call = true;
4601 }
4602 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4603 saw_final_response = true;
4604 }
4605 Ok(_) => {}
4606 Err(err) => {
4607 error = Some(err);
4608 break;
4609 }
4610 }
4611 }
4612
4613 assert!(!saw_delta);
4614 assert!(!saw_completion_call);
4615 assert!(!saw_final_response);
4616 let error = error.expect("unterminated tool-call args delta should fail");
4617 match error {
4618 StreamingError::Completion(CompletionError::ResponseError(message)) => {
4619 assert!(
4620 message.contains("streamed tool call arguments"),
4621 "{message}"
4622 );
4623 assert!(message.contains("tool_1"), "{message}");
4624 assert!(message.contains("internal_1"), "{message}");
4625 }
4626 other => panic!("expected completion response error, got {other:?}"),
4627 }
4628 assert_eq!(recorded.request_count(), 1);
4629 }
4630
4631 #[tokio::test]
4632 async fn tool_choice_none_buffers_args_then_rejects_name_without_emit() {
4633 let model = MockCompletionModel::from_stream_turns([
4634 vec![
4635 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4636 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4637 MockStreamEvent::final_response_with_total_tokens(4),
4638 ],
4639 vec![
4640 MockStreamEvent::text("should not be requested"),
4641 MockStreamEvent::final_response_with_total_tokens(6),
4642 ],
4643 ]);
4644 let recorded = model.clone();
4645 let agent = AgentBuilder::new(model)
4646 .tool(MockAddTool)
4647 .tool_choice(ToolChoice::None)
4648 .build();
4649
4650 let mut stream = agent
4651 .stream_prompt("do not use tools")
4652 .with_hook(PanicOnUnknownToolHook)
4653 .multi_turn(3)
4654 .await;
4655 let mut saw_delta = false;
4656 let mut error = None;
4657
4658 while let Some(item) = stream.next().await {
4659 match item {
4660 Ok(MultiTurnStreamItem::StreamAssistantItem(
4661 StreamedAssistantContent::ToolCallDelta { .. },
4662 )) => {
4663 saw_delta = true;
4664 }
4665 Ok(_) => {}
4666 Err(err) => {
4667 error = Some(err);
4668 break;
4669 }
4670 }
4671 }
4672
4673 assert!(!saw_delta);
4674 let error = error.expect("ToolChoice::None should reject buffered tool-call deltas");
4675 match error {
4676 StreamingError::Prompt(err) => match *err {
4677 PromptError::UnknownToolCall {
4678 tool_name,
4679 available_tools,
4680 allowed_tools,
4681 chat_history,
4682 } => {
4683 assert_eq!(tool_name, "add");
4684 assert_eq!(available_tools, vec!["add".to_string()]);
4685 assert!(allowed_tools.is_empty());
4686 assert!(history_contains_tool_call(&chat_history, "add"));
4687 }
4688 other => panic!("expected UnknownToolCall, got {other:?}"),
4689 },
4690 other => panic!("expected prompt streaming error, got {other:?}"),
4691 }
4692 assert_eq!(recorded.request_count(), 1);
4693 }
4694
4695 #[tokio::test]
4696 async fn stream_prompt_emits_tool_call_deltas_without_hook() {
4697 let model = MockCompletionModel::from_stream_turns([[
4698 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4699 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4700 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4701 MockStreamEvent::final_response_with_total_tokens(3),
4702 ]]);
4703 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4704
4705 let mut stream = agent.stream_prompt("stream a tool call").await;
4706 let mut deltas = Vec::new();
4707
4708 while let Some(item) = stream.next().await {
4709 match item {
4710 Ok(MultiTurnStreamItem::StreamAssistantItem(
4711 StreamedAssistantContent::ToolCallDelta {
4712 id,
4713 internal_call_id,
4714 content,
4715 },
4716 )) => {
4717 deltas.push((id, internal_call_id, content));
4718 }
4719 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4720 Ok(_) => {}
4721 Err(err) => panic!("unexpected streaming error: {err:?}"),
4722 }
4723 }
4724
4725 assert_eq!(
4726 deltas,
4727 vec![
4728 (
4729 "tool_1".to_string(),
4730 "internal_1".to_string(),
4731 ToolCallDeltaContent::Name("add".to_string())
4732 ),
4733 (
4734 "tool_1".to_string(),
4735 "internal_1".to_string(),
4736 ToolCallDeltaContent::Delta("{\"x\":".to_string())
4737 ),
4738 (
4739 "tool_1".to_string(),
4740 "internal_1".to_string(),
4741 ToolCallDeltaContent::Delta("1}".to_string())
4742 ),
4743 ]
4744 );
4745 }
4746
4747 #[tokio::test]
4748 async fn stream_prompt_emits_tool_call_deltas_after_hook_continue() {
4749 let model = MockCompletionModel::from_stream_turns([[
4750 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4751 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4752 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4753 MockStreamEvent::final_response_with_total_tokens(3),
4754 ]]);
4755 let hook = RecordingToolCallDeltaHook::default();
4756 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4757
4758 let mut stream = agent
4759 .stream_prompt("stream a tool call")
4760 .with_hook(hook.clone())
4761 .await;
4762 let mut stream_deltas = Vec::new();
4763
4764 while let Some(item) = stream.next().await {
4765 match item {
4766 Ok(MultiTurnStreamItem::StreamAssistantItem(
4767 StreamedAssistantContent::ToolCallDelta {
4768 id,
4769 internal_call_id,
4770 content,
4771 },
4772 )) => {
4773 stream_deltas.push((id, internal_call_id, content));
4774 }
4775 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4776 Ok(_) => {}
4777 Err(err) => panic!("unexpected streaming error: {err:?}"),
4778 }
4779 }
4780
4781 assert_eq!(
4782 hook.observed(),
4783 vec![
4784 (
4785 "tool_1".to_string(),
4786 "internal_1".to_string(),
4787 Some("add".to_string()),
4788 String::new()
4789 ),
4790 (
4791 "tool_1".to_string(),
4792 "internal_1".to_string(),
4793 None,
4794 "{\"x\":".to_string()
4795 ),
4796 (
4797 "tool_1".to_string(),
4798 "internal_1".to_string(),
4799 None,
4800 "1}".to_string()
4801 ),
4802 ]
4803 );
4804 assert_eq!(
4805 stream_deltas,
4806 vec![
4807 (
4808 "tool_1".to_string(),
4809 "internal_1".to_string(),
4810 ToolCallDeltaContent::Name("add".to_string())
4811 ),
4812 (
4813 "tool_1".to_string(),
4814 "internal_1".to_string(),
4815 ToolCallDeltaContent::Delta("{\"x\":".to_string())
4816 ),
4817 (
4818 "tool_1".to_string(),
4819 "internal_1".to_string(),
4820 ToolCallDeltaContent::Delta("1}".to_string())
4821 ),
4822 ]
4823 );
4824 }
4825
4826 #[tokio::test]
4827 async fn stream_prompt_tool_call_deltas_hook_termination_prevents_delta_emit() {
4828 let model = MockCompletionModel::from_stream_turns([[
4829 MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4830 MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4831 MockStreamEvent::final_response_with_total_tokens(3),
4832 ]]);
4833 let hook = TerminatingToolCallDeltaHook::default();
4834 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4835
4836 let mut stream = agent
4837 .stream_prompt("stream a tool call")
4838 .with_hook(hook.clone())
4839 .await;
4840 let mut saw_delta = false;
4841 let mut saw_final_response = false;
4842 let mut error_message = None;
4843
4844 while let Some(item) = stream.next().await {
4845 match item {
4846 Ok(MultiTurnStreamItem::StreamAssistantItem(
4847 StreamedAssistantContent::ToolCallDelta { .. },
4848 )) => {
4849 saw_delta = true;
4850 }
4851 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4852 saw_final_response = true;
4853 }
4854 Ok(_) => {}
4855 Err(err) => {
4856 error_message = Some(err.to_string());
4857 break;
4858 }
4859 }
4860 }
4861
4862 assert_eq!(
4863 hook.observed(),
4864 vec![(
4865 "tool_1".to_string(),
4866 "internal_1".to_string(),
4867 Some("add".to_string()),
4868 String::new()
4869 )]
4870 );
4871 assert!(!saw_delta);
4872 assert!(!saw_final_response);
4873 assert!(
4874 error_message
4875 .as_deref()
4876 .is_some_and(|message| message.contains("PromptCancelled: stop on tool call delta")),
4877 "expected hook termination error, got {error_message:?}"
4878 );
4879 }
4880
4881 #[tokio::test]
4882 async fn stream_prompt_exposes_completion_calls() {
4883 let first_call_usage = usage(10, 2);
4884 let second_call_usage = usage(25, 5);
4885 let model = MockCompletionModel::from_stream_turns([
4886 vec![
4887 MockStreamEvent::tool_call(
4888 "tool_call_1",
4889 "add",
4890 serde_json::json!({"x": 1, "y": 2}),
4891 )
4892 .with_call_id("call_1"),
4893 MockStreamEvent::final_response(first_call_usage),
4894 ],
4895 vec![
4896 MockStreamEvent::text("done"),
4897 MockStreamEvent::final_response(second_call_usage),
4898 ],
4899 ]);
4900 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4901 let empty_history: &[Message] = &[];
4902
4903 let mut stream = agent
4904 .stream_prompt("do tool work")
4905 .with_history(empty_history)
4906 .multi_turn(3)
4907 .await;
4908 let mut completion_calls_events = Vec::new();
4909 let mut final_response = None;
4910
4911 while let Some(item) = stream.next().await {
4912 match item {
4913 Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
4914 completion_calls_events.push(call_usage);
4915 }
4916 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4917 final_response = Some(response);
4918 break;
4919 }
4920 Ok(_) => {}
4921 Err(err) => panic!("unexpected streaming error: {err:?}"),
4922 }
4923 }
4924
4925 assert_eq!(
4926 completion_calls_events,
4927 vec![
4928 CompletionCall::new(0, Some(first_call_usage)),
4929 CompletionCall::new(1, Some(second_call_usage))
4930 ]
4931 );
4932
4933 let final_response = final_response.expect("expected final response");
4934 assert_eq!(
4935 final_response.usage(),
4936 Usage {
4937 input_tokens: 35,
4938 output_tokens: 7,
4939 total_tokens: 42,
4940 cached_input_tokens: 0,
4941 cache_creation_input_tokens: 0,
4942 tool_use_prompt_tokens: 0,
4943 reasoning_tokens: 0,
4944 }
4945 );
4946 assert_eq!(
4947 final_response.completion_calls(),
4948 &[
4949 CompletionCall::new(0, Some(first_call_usage)),
4950 CompletionCall::new(1, Some(second_call_usage))
4951 ]
4952 );
4953 }
4954
4955 #[tokio::test]
4956 async fn stream_prompt_emits_completion_call_before_finish_hook_termination() {
4957 let call_usage = usage(10, 2);
4958 let model = MockCompletionModel::from_stream_turns([[
4959 MockStreamEvent::text("done"),
4960 MockStreamEvent::final_response(call_usage),
4961 ]]);
4962 let agent = AgentBuilder::new(model).build();
4963
4964 let mut stream = agent
4965 .stream_prompt("say done")
4966 .with_hook(TerminateOnStreamFinish)
4967 .await;
4968 let mut completion_calls = Vec::new();
4969 let mut saw_error = false;
4970
4971 while let Some(item) = stream.next().await {
4972 match item {
4973 Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
4974 completion_calls.push(completion_call);
4975 }
4976 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4977 panic!("unexpected final response after hook termination: {response:?}");
4978 }
4979 Ok(_) => {}
4980 Err(_) => {
4981 saw_error = true;
4982 break;
4983 }
4984 }
4985 }
4986
4987 assert_eq!(
4988 completion_calls,
4989 vec![CompletionCall::new(0, Some(call_usage))]
4990 );
4991 assert!(saw_error);
4992 }
4993
4994 #[tokio::test]
4995 async fn stream_prompt_completion_calls_records_unreported_usage() {
4996 let second_call_usage = usage(25, 5);
4997 let model = MockCompletionModel::from_stream_turns([
4998 vec![
4999 MockStreamEvent::tool_call(
5000 "tool_call_1",
5001 "add",
5002 serde_json::json!({"x": 1, "y": 2}),
5003 )
5004 .with_call_id("call_1"),
5005 ],
5006 vec![
5007 MockStreamEvent::text("done"),
5008 MockStreamEvent::final_response(second_call_usage),
5009 ],
5010 ]);
5011 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5012 let empty_history: &[Message] = &[];
5013
5014 let mut stream = agent
5015 .stream_prompt("do tool work")
5016 .with_history(empty_history)
5017 .multi_turn(3)
5018 .await;
5019 let mut completion_calls_events = Vec::new();
5020 let mut final_response = None;
5021
5022 while let Some(item) = stream.next().await {
5023 match item {
5024 Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
5025 completion_calls_events.push(call_usage);
5026 }
5027 Ok(MultiTurnStreamItem::FinalResponse(response)) => {
5028 final_response = Some(response);
5029 break;
5030 }
5031 Ok(_) => {}
5032 Err(err) => panic!("unexpected streaming error: {err:?}"),
5033 }
5034 }
5035
5036 let expected_usage = vec![
5037 CompletionCall::new(0, None),
5038 CompletionCall::new(1, Some(second_call_usage)),
5039 ];
5040 assert_eq!(completion_calls_events, expected_usage);
5041
5042 let final_response = final_response.expect("expected final response");
5043 assert_eq!(final_response.completion_calls(), expected_usage.as_slice());
5044 }
5045
5046 #[tokio::test]
5047 async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
5048 let agent = AgentBuilder::new(streaming_text_then_final_model()).build();
5049
5050 let mut stream = agent.stream_prompt("say hello").await;
5051 let mut streamed_text = String::new();
5052 let mut final_response_text = None;
5053
5054 while let Some(item) = stream.next().await {
5055 match item {
5056 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5057 text,
5058 ))) => streamed_text.push_str(&text.text),
5059 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5060 final_response_text = Some(res.response().to_owned());
5061 break;
5062 }
5063 Ok(_) => {}
5064 Err(err) => panic!("unexpected streaming error: {err:?}"),
5065 }
5066 }
5067
5068 assert_eq!(streamed_text, "hello world");
5069 assert_eq!(final_response_text.as_deref(), Some("hello world"));
5070 }
5071
5072 #[tokio::test]
5073 async fn final_response_preserves_structured_text_metadata() {
5074 let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
5075
5076 let mut stream = agent.stream_prompt("answer with citations").await;
5077 let mut final_response = None;
5078
5079 while let Some(item) = stream.next().await {
5080 match item {
5081 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5082 final_response = Some(res);
5083 break;
5084 }
5085 Ok(_) => {}
5086 Err(err) => panic!("unexpected streaming error: {err:?}"),
5087 }
5088 }
5089
5090 let final_response = final_response.expect("expected final response");
5091 assert_eq!(final_response.response(), "cited answer");
5092 let metadata = text_metadata(final_response.content())
5093 .expect("expected text metadata in final content");
5094 assert_eq!(
5095 metadata["citations"][0]["encrypted_index"],
5096 "encrypted-reference"
5097 );
5098 }
5099
5100 #[tokio::test]
5101 async fn final_response_history_preserves_structured_text_metadata() {
5102 let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
5103
5104 let empty_history: &[Message] = &[];
5105 let mut stream = agent
5106 .stream_prompt("answer with citations")
5107 .with_history(empty_history)
5108 .await;
5109 let mut final_response = None;
5110
5111 while let Some(item) = stream.next().await {
5112 match item {
5113 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5114 final_response = Some(res);
5115 break;
5116 }
5117 Ok(_) => {}
5118 Err(err) => panic!("unexpected streaming error: {err:?}"),
5119 }
5120 }
5121
5122 let final_response = final_response.expect("expected final response");
5123 let history = final_response
5124 .history()
5125 .expect("with_history should include final history");
5126 let assistant_content = history
5127 .iter()
5128 .find_map(|message| match message {
5129 Message::Assistant { content, .. } => Some(content),
5130 _ => None,
5131 })
5132 .expect("expected assistant message in history");
5133 let metadata =
5134 text_metadata(assistant_content).expect("expected text metadata in assistant history");
5135 assert_eq!(
5136 metadata["citations"][0]["encrypted_index"],
5137 "encrypted-reference"
5138 );
5139 }
5140
5141 #[tokio::test]
5142 async fn tool_follow_up_history_preserves_structured_text_metadata() {
5143 let model = streaming_cited_text_then_tool_model();
5144 let recorded = model.clone();
5145 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5146 let empty_history: &[Message] = &[];
5147
5148 let mut stream = agent
5149 .stream_prompt("use a tool with citations")
5150 .with_history(empty_history)
5151 .multi_turn(3)
5152 .await;
5153
5154 while let Some(item) = stream.next().await {
5155 match item {
5156 Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
5157 Ok(_) => {}
5158 Err(err) => panic!("unexpected streaming error: {err:?}"),
5159 }
5160 }
5161
5162 let requests = recorded.requests();
5163 assert_eq!(requests.len(), 2);
5164 let follow_up_history = requests[1].chat_history.iter().collect::<Vec<_>>();
5165 let assistant_content = follow_up_history
5166 .iter()
5167 .find_map(|message| match message {
5168 Message::Assistant { content, .. } => Some(content),
5169 _ => None,
5170 })
5171 .expect("expected assistant message in follow-up history");
5172 let metadata = text_metadata(assistant_content)
5173 .expect("expected citation metadata in follow-up assistant history");
5174 assert_eq!(
5175 metadata["citations"][0]["encrypted_index"],
5176 "encrypted-reference"
5177 );
5178 }
5179
5180 #[tokio::test]
5181 async fn final_response_can_remain_empty_for_truly_textless_turns() {
5182 let agent = AgentBuilder::new(streaming_final_only_model()).build();
5183
5184 let mut stream = agent.stream_prompt("say nothing").await;
5185 let mut streamed_text = String::new();
5186 let mut final_response_text = None;
5187
5188 while let Some(item) = stream.next().await {
5189 match item {
5190 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5191 text,
5192 ))) => streamed_text.push_str(&text.text),
5193 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5194 final_response_text = Some(res.response().to_owned());
5195 break;
5196 }
5197 Ok(_) => {}
5198 Err(err) => panic!("unexpected streaming error: {err:?}"),
5199 }
5200 }
5201
5202 assert!(streamed_text.is_empty());
5203 assert_eq!(final_response_text.as_deref(), Some(""));
5204 }
5205
5206 async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
5209 let mut interval = tokio::time::interval(Duration::from_millis(50));
5210 let mut count = 0u32;
5211
5212 while !stop.load(Ordering::Relaxed) {
5213 interval.tick().await;
5214 count += 1;
5215
5216 tracing::event!(
5217 target: "background_logger",
5218 tracing::Level::INFO,
5219 count = count,
5220 "Background tick"
5221 );
5222
5223 let current = tracing::Span::current();
5225 if !current.is_disabled() && !current.is_none() {
5226 leak_count.fetch_add(1, Ordering::Relaxed);
5227 }
5228 }
5229
5230 tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
5231 }
5232
5233 #[tokio::test(flavor = "current_thread")]
5241 #[ignore = "This requires an API key"]
5242 async fn test_span_context_isolation() -> anyhow::Result<()> {
5243 let stop = Arc::new(AtomicBool::new(false));
5244 let leak_count = Arc::new(AtomicU32::new(0));
5245
5246 let bg_stop = stop.clone();
5248 let bg_leak = leak_count.clone();
5249 let bg_handle = tokio::spawn(async move {
5250 background_logger(bg_stop, bg_leak).await;
5251 });
5252
5253 tokio::time::sleep(Duration::from_millis(100)).await;
5255
5256 let client = anthropic::Client::from_env()?;
5259 let agent = client
5260 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5261 .preamble("You are a helpful assistant.")
5262 .temperature(0.1)
5263 .max_tokens(100)
5264 .build();
5265
5266 let mut stream = agent
5267 .stream_prompt("Say 'hello world' and nothing else.")
5268 .await;
5269
5270 let mut full_content = String::new();
5271 while let Some(item) = stream.next().await {
5272 match item {
5273 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5274 text,
5275 ))) => {
5276 full_content.push_str(&text.text);
5277 }
5278 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
5279 break;
5280 }
5281 Err(e) => {
5282 tracing::warn!("Error: {:?}", e);
5283 break;
5284 }
5285 _ => {}
5286 }
5287 }
5288
5289 tracing::info!("Got response: {:?}", full_content);
5290
5291 stop.store(true, Ordering::Relaxed);
5293 bg_handle.await?;
5294
5295 let leaks = leak_count.load(Ordering::Relaxed);
5296 anyhow::ensure!(
5297 leaks == 0,
5298 "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
5299 This indicates that span.enter() is being used inside async_stream instead of .instrument()"
5300 );
5301
5302 Ok(())
5303 }
5304
5305 #[tokio::test]
5311 #[ignore = "This requires an API key"]
5312 async fn test_chat_history_in_final_response() -> anyhow::Result<()> {
5313 use crate::message::Message;
5314
5315 let client = anthropic::Client::from_env()?;
5316 let agent = client
5317 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5318 .preamble("You are a helpful assistant. Keep responses brief.")
5319 .temperature(0.1)
5320 .max_tokens(50)
5321 .build();
5322
5323 let empty_history: &[Message] = &[];
5325 let mut stream = agent
5326 .stream_prompt("Say 'hello' and nothing else.")
5327 .with_history(empty_history)
5328 .await;
5329
5330 let mut response_text = String::new();
5332 let mut final_history = None;
5333 while let Some(item) = stream.next().await {
5334 match item {
5335 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5336 text,
5337 ))) => {
5338 response_text.push_str(&text.text);
5339 }
5340 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5341 final_history = res.history().map(|h| h.to_vec());
5342 break;
5343 }
5344 Err(e) => {
5345 return Err(e.into());
5346 }
5347 _ => {}
5348 }
5349 }
5350
5351 let history = final_history
5352 .ok_or_else(|| anyhow::anyhow!("final response should include history"))?;
5353
5354 anyhow::ensure!(
5356 history.iter().any(|m| matches!(m, Message::User { .. })),
5357 "History should contain the user message"
5358 );
5359
5360 anyhow::ensure!(
5362 history
5363 .iter()
5364 .any(|m| matches!(m, Message::Assistant { .. })),
5365 "History should contain the assistant response"
5366 );
5367
5368 tracing::info!(
5369 "History after streaming: {} messages, response: {:?}",
5370 history.len(),
5371 response_text
5372 );
5373
5374 Ok(())
5375 }
5376
5377 #[tokio::test]
5378 async fn streaming_appends_to_memory_after_final_response() {
5379 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5380
5381 let memory = InMemoryConversationMemory::new();
5382 let agent = AgentBuilder::new(streaming_text_then_final_model())
5383 .memory(memory.clone())
5384 .build();
5385
5386 let mut stream = agent
5387 .stream_prompt("hi there")
5388 .conversation("stream-thread")
5389 .await;
5390
5391 let mut history_in_final = None;
5392 while let Some(item) = stream.next().await {
5393 match item {
5394 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5395 history_in_final = res.history().map(|h| h.to_vec());
5396 break;
5397 }
5398 Ok(_) => {}
5399 Err(err) => panic!("unexpected streaming error: {err:?}"),
5400 }
5401 }
5402
5403 let final_history = history_in_final
5404 .expect("FinalResponse.history should be populated when memory is configured");
5405 assert_eq!(
5406 final_history.len(),
5407 2,
5408 "user prompt + assistant response in final history: {final_history:?}"
5409 );
5410
5411 let stored = memory.load("stream-thread").await.unwrap();
5412 assert_eq!(stored.len(), 2, "memory should contain user + assistant");
5413 }
5414
5415 #[tokio::test]
5416 async fn streaming_reasoning_without_tools_does_not_duplicate_final_history() {
5417 let agent = AgentBuilder::new(MockCompletionModel::from_stream_turns([[
5418 MockStreamEvent::text("final answer"),
5419 MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
5420 MockStreamEvent::final_response_with_total_tokens(3),
5421 ]]))
5422 .build();
5423
5424 let mut stream = agent
5425 .stream_prompt("think before answering")
5426 .with_history(Vec::<Message>::new())
5427 .await;
5428
5429 let mut history_in_final = None;
5430 while let Some(item) = stream.next().await {
5431 match item {
5432 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5433 history_in_final = res.history().map(|h| h.to_vec());
5434 break;
5435 }
5436 Ok(_) => {}
5437 Err(err) => panic!("unexpected streaming error: {err:?}"),
5438 }
5439 }
5440
5441 let final_history = history_in_final
5442 .expect("FinalResponse.history should be populated when with_history is used");
5443 assert_eq!(
5444 final_history.len(),
5445 2,
5446 "user prompt + one assistant response in final history: {final_history:?}"
5447 );
5448
5449 assert!(matches!(
5450 final_history.first(),
5451 Some(Message::User { content })
5452 if matches!(
5453 content.first(),
5454 UserContent::Text(text) if text.text == "think before answering"
5455 )
5456 ));
5457
5458 let assistant_messages = final_history
5459 .iter()
5460 .filter_map(|message| match message {
5461 Message::Assistant { content, .. } => Some(content),
5462 _ => None,
5463 })
5464 .collect::<Vec<_>>();
5465 assert_eq!(
5466 assistant_messages.len(),
5467 1,
5468 "reasoning turn should produce exactly one assistant history message: {final_history:?}"
5469 );
5470 let assistant_content = assistant_messages
5471 .first()
5472 .expect("expected assistant history message");
5473 assert!(assistant_content.iter().any(|item| matches!(
5474 item,
5475 AssistantContent::Text(text) if text.text == "final answer"
5476 )));
5477 assert!(assistant_content.iter().any(|item| matches!(
5478 item,
5479 AssistantContent::Reasoning(reasoning)
5480 if reasoning.id.as_deref() == Some("rs_1")
5481 && reasoning.content.iter().any(|content| matches!(
5482 content,
5483 ReasoningContent::Text { text, .. } if text == "reasoned step"
5484 ))
5485 )));
5486 }
5487
5488 #[tokio::test]
5489 async fn streaming_with_history_overrides_memory() {
5490 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5491
5492 let memory = InMemoryConversationMemory::new();
5493 memory
5494 .append("t1", vec![Message::user("from-memory")])
5495 .await
5496 .unwrap();
5497
5498 let agent = AgentBuilder::new(streaming_text_then_final_model())
5499 .memory(memory.clone())
5500 .build();
5501
5502 let mut stream = agent
5503 .stream_prompt("hi")
5504 .conversation("t1")
5505 .with_history(vec![Message::user("from-caller")])
5506 .await;
5507
5508 while let Some(item) = stream.next().await {
5509 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5510 break;
5511 }
5512 }
5513
5514 let stored = memory.load("t1").await.unwrap();
5515 assert_eq!(
5516 stored.len(),
5517 1,
5518 "with_history bypasses memory; only the pre-seeded entry remains: {stored:?}"
5519 );
5520 }
5521
5522 #[tokio::test]
5523 async fn streaming_without_memory_disables_for_request() {
5524 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5525
5526 let memory = InMemoryConversationMemory::new();
5527 let agent = AgentBuilder::new(streaming_text_then_final_model())
5528 .memory(memory.clone())
5529 .conversation_id("default")
5530 .build();
5531
5532 let mut stream = agent.stream_prompt("hi").without_memory().await;
5533
5534 while let Some(item) = stream.next().await {
5535 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5536 break;
5537 }
5538 }
5539
5540 let stored = memory.load("default").await.unwrap();
5541 assert!(stored.is_empty(), "without_memory disables save");
5542 }
5543
5544 #[tokio::test]
5545 async fn streaming_load_error_yields_memory_error() {
5546 let agent = AgentBuilder::new(streaming_text_then_final_model())
5547 .memory(FailingMemory::default())
5548 .build();
5549
5550 let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5551
5552 let first = stream.next().await.expect("at least one item");
5553 match first {
5554 Err(err) => {
5555 let msg = format!("{err:?}");
5556 assert!(
5557 msg.contains("Memory") || msg.contains("memory") || msg.contains("load boom"),
5558 "expected memory error, got: {msg}"
5559 );
5560 }
5561 Ok(other) => panic!("expected memory error, got {other:?}"),
5562 }
5563 }
5564
5565 #[tokio::test]
5566 async fn streaming_with_filter_shapes_loaded_history() {
5567 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5568
5569 let memory = InMemoryConversationMemory::new()
5570 .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
5571 memory
5572 .append(
5573 "t1",
5574 vec![
5575 Message::user("1"),
5576 Message::assistant("2"),
5577 Message::user("3"),
5578 Message::assistant("4"),
5579 ],
5580 )
5581 .await
5582 .unwrap();
5583
5584 let model = MockCompletionModel::from_stream_turns([[
5585 MockStreamEvent::text("ok"),
5586 MockStreamEvent::final_response_with_total_tokens(1),
5587 ]]);
5588 let recorded = model.clone();
5589 let agent = AgentBuilder::new(model).memory(memory).build();
5590
5591 let mut stream = agent.stream_prompt("ping").conversation("t1").await;
5592 while let Some(item) = stream.next().await {
5593 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5594 break;
5595 }
5596 }
5597
5598 let received = recorded.requests()[0]
5599 .chat_history
5600 .iter()
5601 .cloned()
5602 .collect::<Vec<_>>();
5603 assert_eq!(
5604 received.len(),
5605 3,
5606 "window-truncated history (2) + current prompt: {received:?}"
5607 );
5608 }
5609
5610 #[tokio::test]
5611 async fn streaming_append_error_does_not_suppress_final_response() {
5612 let agent = AgentBuilder::new(streaming_text_then_final_model())
5613 .memory(AppendFailingMemory::default())
5614 .build();
5615
5616 let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5617
5618 let mut saw_final = false;
5619 while let Some(item) = stream.next().await {
5620 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5621 saw_final = true;
5622 break;
5623 }
5624 }
5625 assert!(
5626 saw_final,
5627 "FinalResponse must be yielded even when memory.append fails"
5628 );
5629 }
5630}