Skip to main content

rig_core/agent/prompt_request/
streaming.rs

1use crate::{
2    OneOrMany,
3    agent::completion::{DynamicContextStore, build_prepared_completion_request},
4    agent::prompt_request::{
5        HookAction, InvalidToolCallResolution, TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER,
6        hooks::PromptHook, resolve_invalid_tool_call, validate_tool_call_name,
7    },
8    completion::{Document, GetTokenUsage},
9    json_utils,
10    memory::ConversationMemory,
11    message::{
12        AssistantContent, ToolCall, ToolChoice, ToolFunction, ToolResult, ToolResultContent,
13        UserContent,
14    },
15    streaming::{StreamedAssistantContent, StreamedUserContent, ToolCallDeltaContent},
16    tool::server::ToolServerHandle,
17    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
18};
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{collections::HashMap, pin::Pin, sync::Arc};
22use tracing::info_span;
23use tracing_futures::Instrument;
24
25use super::{CompletionCall, ToolCallHookAction, reported_usage};
26use crate::{
27    agent::Agent,
28    completion::{CompletionError, CompletionModel, PromptError},
29    message::{Message, Text},
30    tool::ToolSetError,
31};
32
33#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
34pub type StreamingResult<R> =
35    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;
36
37#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
38pub type StreamingResult<R> =
39    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
40
41#[derive(Deserialize, Serialize, Debug, Clone)]
42#[serde(tag = "type", rename_all = "camelCase")]
43#[non_exhaustive]
44pub enum MultiTurnStreamItem<R> {
45    /// A streamed assistant content item.
46    StreamAssistantItem(StreamedAssistantContent<R>),
47    /// A streamed user content item (mostly for tool results).
48    StreamUserItem(StreamedUserContent),
49    /// Details for one successfully completed completion request made by this agent stream.
50    ///
51    /// This is emitted when a provider call finishes. Usage is the provider's
52    /// final usage for that completion request when available; it is not
53    /// incremental per streamed token.
54    ///
55    /// ```rust,ignore
56    /// match item {
57    ///     MultiTurnStreamItem::CompletionCall(completion_call) => {
58    ///         let context_tokens = completion_call.usage.map(|usage| usage.input_tokens);
59    ///     }
60    ///     _ => {}
61    /// }
62    /// ```
63    CompletionCall(CompletionCall),
64    /// The final result from the stream.
65    FinalResponse(FinalResponse),
66}
67
68#[derive(Deserialize, Serialize, Debug, Clone)]
69#[serde(rename_all = "camelCase")]
70pub struct FinalResponse {
71    /// Structured assistant content for the final turn.
72    content: OneOrMany<AssistantContent>,
73    /// Concatenated assistant text for the final turn.
74    /// This is empty only when the turn completed without emitting any text.
75    response: String,
76    aggregated_usage: crate::completion::Usage,
77    /// Successfully completed completion requests made by this agent stream.
78    #[serde(default, skip_serializing_if = "Vec::is_empty")]
79    completion_calls: Vec<CompletionCall>,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    history: Option<Vec<Message>>,
82}
83
84impl FinalResponse {
85    pub fn empty() -> Self {
86        Self::new(
87            OneOrMany::one(AssistantContent::text("")),
88            crate::completion::Usage::new(),
89            None,
90        )
91    }
92
93    pub fn new(
94        content: OneOrMany<AssistantContent>,
95        aggregated_usage: crate::completion::Usage,
96        history: Option<Vec<Message>>,
97    ) -> Self {
98        let response = assistant_text_from_choice(&content);
99        Self {
100            content,
101            response,
102            aggregated_usage,
103            completion_calls: Vec::new(),
104            history,
105        }
106    }
107
108    /// Returns the concatenated assistant text for the final turn.
109    pub fn response(&self) -> &str {
110        &self.response
111    }
112
113    /// Returns the structured assistant content for the final turn.
114    pub fn content(&self) -> &OneOrMany<AssistantContent> {
115        &self.content
116    }
117
118    /// Returns the structured assistant content for the final turn.
119    pub fn assistant_content(&self) -> &OneOrMany<AssistantContent> {
120        &self.content
121    }
122
123    pub fn usage(&self) -> crate::completion::Usage {
124        self.aggregated_usage
125    }
126
127    /// Returns successfully completed completion requests made by this agent stream, with usage when available.
128    ///
129    /// Each entry represents one provider completion request. When present,
130    /// usage is a whole-request provider snapshot, not incremental usage per
131    /// streamed token. Streaming providers may omit usage for some calls; those
132    /// calls have an entry with `None` usage.
133    pub fn completion_calls(&self) -> &[CompletionCall] {
134        &self.completion_calls
135    }
136
137    pub fn history(&self) -> Option<&[Message]> {
138        self.history.as_deref()
139    }
140}
141
142impl<R> MultiTurnStreamItem<R> {
143    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
144        Self::StreamAssistantItem(item)
145    }
146
147    pub fn final_response(
148        content: OneOrMany<AssistantContent>,
149        aggregated_usage: crate::completion::Usage,
150    ) -> Self {
151        Self::FinalResponse(FinalResponse::new(content, aggregated_usage, None))
152    }
153
154    pub fn final_response_with_history(
155        content: OneOrMany<AssistantContent>,
156        aggregated_usage: crate::completion::Usage,
157        history: Option<Vec<Message>>,
158    ) -> Self {
159        Self::FinalResponse(FinalResponse::new(content, aggregated_usage, history))
160    }
161
162    pub(crate) fn final_response_with_completion_calls(
163        content: OneOrMany<AssistantContent>,
164        aggregated_usage: crate::completion::Usage,
165        completion_calls: Vec<CompletionCall>,
166        history: Option<Vec<Message>>,
167    ) -> Self {
168        let mut response = FinalResponse::new(content, aggregated_usage, history);
169        response.completion_calls = completion_calls;
170        Self::FinalResponse(response)
171    }
172}
173
174fn merge_reasoning_blocks(
175    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
176    incoming: &crate::message::Reasoning,
177) {
178    let ids_match = |existing: &crate::message::Reasoning| {
179        matches!(
180            (&existing.id, &incoming.id),
181            (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
182        )
183    };
184
185    if let Some(existing) = accumulated_reasoning
186        .iter_mut()
187        .rev()
188        .find(|existing| ids_match(existing))
189    {
190        existing.content.extend(incoming.content.clone());
191    } else {
192        accumulated_reasoning.push(incoming.clone());
193    }
194}
195
196fn flush_pending_reasoning_delta(
197    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
198    pending_reasoning_delta_text: &mut String,
199    pending_reasoning_delta_id: &mut Option<String>,
200) {
201    if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
202        let mut assembled = crate::message::Reasoning::new(&*pending_reasoning_delta_text);
203        if let Some(id) = pending_reasoning_delta_id.take() {
204            assembled = assembled.with_id(id);
205        }
206        accumulated_reasoning.push(assembled);
207        pending_reasoning_delta_text.clear();
208    }
209}
210
211/// Build full history for error reporting (input + new messages).
212fn build_full_history(
213    chat_history: Option<&[Message]>,
214    new_messages: Vec<Message>,
215) -> Vec<Message> {
216    let input = chat_history.unwrap_or(&[]);
217    input.iter().cloned().chain(new_messages).collect()
218}
219
220struct ToolCallValidationHistory<'a> {
221    chat_history: Option<&'a [Message]>,
222    new_messages: &'a [Message],
223    assistant_message_id: &'a Option<String>,
224    final_turn_content: Option<&'a OneOrMany<AssistantContent>>,
225    text_delta_response: Option<&'a str>,
226    accumulated_reasoning: &'a [crate::message::Reasoning],
227    pending_reasoning_delta_text: &'a str,
228    pending_reasoning_delta_id: &'a Option<String>,
229    pending_tool_calls: &'a [(ToolCall, String)],
230    current_tool_call: Option<ToolCall>,
231}
232
233fn build_tool_call_validation_history(input: ToolCallValidationHistory<'_>) -> Vec<Message> {
234    let mut messages = input.new_messages.to_vec();
235
236    if let Some(final_turn_content) = input.final_turn_content
237        && !is_empty_assistant_choice(final_turn_content)
238    {
239        messages.push(Message::Assistant {
240            id: input.assistant_message_id.clone(),
241            content: final_turn_content.clone(),
242        });
243        return build_full_history(input.chat_history, messages);
244    }
245
246    let mut content_items = Vec::new();
247    if let Some(text) = input.text_delta_response
248        && !text.is_empty()
249    {
250        content_items.push(AssistantContent::text(text.to_string()));
251    }
252    content_items.extend(
253        input
254            .accumulated_reasoning
255            .iter()
256            .cloned()
257            .map(AssistantContent::Reasoning),
258    );
259    if input.accumulated_reasoning.is_empty() && !input.pending_reasoning_delta_text.is_empty() {
260        let mut reasoning = crate::message::Reasoning::new(input.pending_reasoning_delta_text);
261        if let Some(id) = input.pending_reasoning_delta_id.clone() {
262            reasoning = reasoning.with_id(id);
263        }
264        content_items.push(AssistantContent::Reasoning(reasoning));
265    }
266    content_items.extend(
267        input
268            .pending_tool_calls
269            .iter()
270            .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
271    );
272    if let Some(tool_call) = input.current_tool_call {
273        content_items.push(AssistantContent::ToolCall(tool_call));
274    }
275
276    if let Some(content) = OneOrMany::from_iter_optional(content_items) {
277        messages.push(Message::Assistant {
278            id: input.assistant_message_id.clone(),
279            content,
280        });
281    }
282
283    build_full_history(input.chat_history, messages)
284}
285
286async fn drain_recovered_stream_usage<R>(
287    stream: &mut crate::streaming::StreamingCompletionResponse<R>,
288    tool_call_delta_states: &HashMap<(String, String), ToolCallDeltaState>,
289    current_call_usage: &mut Option<crate::completion::Usage>,
290    aggregated_usage: &mut crate::completion::Usage,
291) -> Result<(), StreamingError>
292where
293    R: Clone + Unpin + GetTokenUsage,
294{
295    if let Some(err) = pending_tool_call_delta_error(tool_call_delta_states) {
296        return Err(err.into());
297    }
298
299    while let Some(content) = stream.next().await {
300        match content {
301            Ok(StreamedAssistantContent::Final(final_resp)) => {
302                if let Some(usage) = final_resp.token_usage() {
303                    *current_call_usage = reported_usage(usage);
304                }
305                if let Some(usage) = *current_call_usage {
306                    *aggregated_usage += usage;
307                }
308                return Ok(());
309            }
310            Ok(_) => {}
311            Err(err) => return Err(err.into()),
312        }
313    }
314
315    Ok(())
316}
317
318fn record_completion_call_if_needed(
319    completion_calls: &mut Vec<CompletionCall>,
320    completion_call_emitted: &mut bool,
321    call_index: usize,
322    current_call_usage: Option<crate::completion::Usage>,
323    chat_stream_span: &tracing::Span,
324) -> Option<CompletionCall> {
325    if *completion_call_emitted {
326        return None;
327    }
328
329    if let Some(usage) = current_call_usage {
330        record_usage_on_span(chat_stream_span, usage);
331    }
332
333    let completion_call = CompletionCall::new(call_index, current_call_usage);
334    completion_calls.push(completion_call);
335    *completion_call_emitted = true;
336    Some(completion_call)
337}
338
339fn record_usage_on_span(span: &tracing::Span, usage: crate::completion::Usage) {
340    span.record("gen_ai.usage.input_tokens", usage.input_tokens);
341    span.record("gen_ai.usage.output_tokens", usage.output_tokens);
342    span.record(
343        "gen_ai.usage.cache_read.input_tokens",
344        usage.cached_input_tokens,
345    );
346    span.record(
347        "gen_ai.usage.cache_creation.input_tokens",
348        usage.cache_creation_input_tokens,
349    );
350    span.record(
351        "gen_ai.usage.tool_use_prompt_tokens",
352        usage.tool_use_prompt_tokens,
353    );
354    span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);
355}
356
357/// Combine input history with new messages for building completion requests.
358fn build_history_for_request(
359    chat_history: Option<&[Message]>,
360    new_messages: &[Message],
361) -> Vec<Message> {
362    let input = chat_history.unwrap_or(&[]);
363    input.iter().chain(new_messages.iter()).cloned().collect()
364}
365
366async fn cancelled_prompt_error(
367    chat_history: Option<&[Message]>,
368    new_messages: Vec<Message>,
369    reason: String,
370) -> StreamingError {
371    StreamingError::Prompt(
372        PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
373            .into(),
374    )
375}
376
377fn tool_result_user_content(
378    id: String,
379    call_id: Option<String>,
380    tool_result: String,
381) -> UserContent {
382    let content = ToolResultContent::from_tool_output(tool_result);
383    match call_id {
384        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
385        None => UserContent::tool_result(id, content),
386    }
387}
388
389fn invalid_streaming_tool_retry_messages(
390    assistant_message_id: &Option<String>,
391    text_delta_response: Option<&str>,
392    accumulated_reasoning: &[crate::message::Reasoning],
393    pending_tool_calls: &[(ToolCall, String)],
394    invalid_tool_call: ToolCall,
395    feedback: String,
396) -> Option<(Message, Message)> {
397    let mut assistant_content = Vec::new();
398    if let Some(text) = text_delta_response
399        && !text.is_empty()
400    {
401        assistant_content.push(AssistantContent::text(text.to_string()));
402    }
403    assistant_content.extend(
404        accumulated_reasoning
405            .iter()
406            .cloned()
407            .map(AssistantContent::Reasoning),
408    );
409    assistant_content.extend(
410        pending_tool_calls
411            .iter()
412            .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
413    );
414    assistant_content.push(AssistantContent::ToolCall(invalid_tool_call.clone()));
415
416    let assistant_content = OneOrMany::from_iter_optional(assistant_content)?;
417    let assistant_message = Message::Assistant {
418        id: assistant_message_id.clone(),
419        content: assistant_content,
420    };
421
422    let mut retry_results = pending_tool_calls
423        .iter()
424        .map(|(tool_call, _)| {
425            tool_result_user_content(
426                tool_call.id.clone(),
427                tool_call.call_id.clone(),
428                TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
429            )
430        })
431        .collect::<Vec<_>>();
432    retry_results.push(tool_result_user_content(
433        invalid_tool_call.id,
434        invalid_tool_call.call_id,
435        feedback,
436    ));
437
438    let user_message = Message::User {
439        content: OneOrMany::from_iter_optional(retry_results)?,
440    };
441
442    Some((assistant_message, user_message))
443}
444
445fn invalid_streaming_name_delta_retry_messages(
446    assistant_message_id: &Option<String>,
447    text_delta_response: Option<&str>,
448    accumulated_reasoning: &[crate::message::Reasoning],
449    pending_tool_calls: &[(ToolCall, String)],
450    invalid_tool_call: ToolCall,
451    feedback: String,
452) -> Option<(Message, Message)> {
453    let mut assistant_content = Vec::new();
454    if let Some(text) = text_delta_response
455        && !text.is_empty()
456    {
457        assistant_content.push(AssistantContent::text(text.to_string()));
458    }
459    assistant_content.extend(
460        accumulated_reasoning
461            .iter()
462            .cloned()
463            .map(AssistantContent::Reasoning),
464    );
465    assistant_content.extend(
466        pending_tool_calls
467            .iter()
468            .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
469    );
470    assistant_content.push(AssistantContent::ToolCall(invalid_tool_call.clone()));
471
472    let assistant_message = Message::Assistant {
473        id: assistant_message_id.clone(),
474        content: OneOrMany::from_iter_optional(assistant_content)?,
475    };
476    let mut retry_results = pending_tool_calls
477        .iter()
478        .map(|(tool_call, _)| {
479            tool_result_user_content(
480                tool_call.id.clone(),
481                tool_call.call_id.clone(),
482                TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
483            )
484        })
485        .collect::<Vec<_>>();
486    retry_results.push(tool_result_user_content(
487        invalid_tool_call.id,
488        invalid_tool_call.call_id,
489        feedback,
490    ));
491    let user_message = Message::User {
492        content: OneOrMany::from_iter_optional(retry_results)?,
493    };
494
495    Some((assistant_message, user_message))
496}
497
498fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
499    choice
500        .iter()
501        .filter_map(|content| match content {
502            AssistantContent::Text(text) => Some(text.text.as_str()),
503            _ => None,
504        })
505        .collect()
506}
507
508fn assistant_text_items_from_choice(choice: &OneOrMany<AssistantContent>) -> Vec<AssistantContent> {
509    choice
510        .iter()
511        .filter_map(|content| match content {
512            AssistantContent::Text(text) => (!text.text.is_empty()
513                || text.additional_params.is_some())
514            .then(|| AssistantContent::Text(text.clone())),
515            _ => None,
516        })
517        .collect()
518}
519
520fn is_empty_assistant_choice(choice: &OneOrMany<AssistantContent>) -> bool {
521    choice.len() == 1
522        && matches!(
523            choice.first(),
524            AssistantContent::Text(text)
525                if text.text.is_empty() && text.additional_params.is_none()
526        )
527}
528
/// Bookkeeping for one streamed tool call.
///
/// Stored in maps keyed by `(id, internal_call_id)`.
#[derive(Default)]
struct ToolCallDeltaState {
    /// Whether this call's tool name has passed validation.
    name_validated: bool,
    /// Argument deltas received but not yet forwarded; an error is reported if
    /// any remain while `name_validated` is still `false`.
    buffered_arguments: Vec<String>,
}
534
535fn pending_tool_call_delta_error(
536    states: &HashMap<(String, String), ToolCallDeltaState>,
537) -> Option<CompletionError> {
538    states
539        .iter()
540        .find(|(_, state)| !state.name_validated && !state.buffered_arguments.is_empty())
541        .map(|((id, internal_call_id), state)| {
542            CompletionError::ResponseError(format!(
543                "streamed tool call arguments received before a validated tool name for id `{id}` and internal_call_id `{internal_call_id}` ({} buffered argument delta(s))",
544                state.buffered_arguments.len()
545            ))
546        })
547}
548
549#[derive(Debug, thiserror::Error)]
550pub enum StreamingError {
551    #[error("CompletionError: {0}")]
552    Completion(#[from] CompletionError),
553    #[error("PromptError: {0}")]
554    Prompt(#[from] Box<PromptError>),
555    #[error("ToolSetError: {0}")]
556    Tool(#[from] ToolSetError),
557}
558
559/// Surface [`crate::memory::ConversationMemory`] failures through the existing
560/// [`CompletionError::RequestError`] variant so adding memory support does not
561/// require a new top-level [`StreamingError`] arm.
562impl From<crate::memory::MemoryError> for StreamingError {
563    fn from(err: crate::memory::MemoryError) -> Self {
564        Self::Completion(CompletionError::RequestError(Box::new(err)))
565    }
566}
567
/// Fallback name reported for tracing when the agent has no configured name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
569
570/// A builder for creating prompt requests with customizable options.
571/// Uses generics to track which options have been set during the build process.
572///
573/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
574/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
575/// attempting to await (which will send the prompt request) can potentially return
576/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
577/// back to back.
578pub struct StreamingPromptRequest<M, P>
579where
580    M: CompletionModel,
581    P: PromptHook<M> + 'static,
582{
583    /// The prompt message to send to the model
584    prompt: Message,
585    /// Optional chat history provided by the caller.
586    chat_history: Option<Vec<Message>>,
587    /// Maximum Turns for multi-turn conversations (0 means no multi-turn)
588    max_turns: usize,
589
590    // Agent data (cloned from agent to allow hook type transitions):
591    /// The completion model
592    model: Arc<M>,
593    /// Agent name for logging
594    agent_name: Option<String>,
595    /// System prompt
596    preamble: Option<String>,
597    /// Static context documents
598    static_context: Vec<Document>,
599    /// Temperature setting
600    temperature: Option<f64>,
601    /// Max tokens setting
602    max_tokens: Option<u64>,
603    /// Additional model parameters
604    additional_params: Option<serde_json::Value>,
605    /// Tool server handle for tool execution
606    tool_server_handle: ToolServerHandle,
607    /// Dynamic context store
608    dynamic_context: DynamicContextStore,
609    /// Tool choice setting
610    tool_choice: Option<ToolChoice>,
611    /// Optional JSON Schema for structured output
612    output_schema: Option<schemars::Schema>,
613    /// Optional per-request hook for events
614    hook: Option<P>,
615    /// Maximum number of invalid tool-call retries for this request.
616    max_invalid_tool_call_retries: usize,
617    /// Optional conversation memory backend cloned from the agent.
618    memory: Option<Arc<dyn ConversationMemory>>,
619    /// Optional conversation id used for loading and saving memory.
620    conversation_id: Option<String>,
621}
622
623impl<M, P> StreamingPromptRequest<M, P>
624where
625    M: CompletionModel + 'static,
626    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
627    P: PromptHook<M>,
628{
629    /// Create a new StreamingPromptRequest with the given prompt and model.
630    /// Note: This creates a request without an agent hook. Use `from_agent` to include the agent's hook.
631    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
632        StreamingPromptRequest {
633            prompt: prompt.into(),
634            chat_history: None,
635            max_turns: agent.default_max_turns.unwrap_or_default(),
636            model: agent.model.clone(),
637            agent_name: agent.name.clone(),
638            preamble: agent.preamble.clone(),
639            static_context: agent.static_context.clone(),
640            temperature: agent.temperature,
641            max_tokens: agent.max_tokens,
642            additional_params: agent.additional_params.clone(),
643            tool_server_handle: agent.tool_server_handle.clone(),
644            dynamic_context: agent.dynamic_context.clone(),
645            tool_choice: agent.tool_choice.clone(),
646            output_schema: agent.output_schema.clone(),
647            hook: None,
648            max_invalid_tool_call_retries: 0,
649            memory: agent.memory.clone(),
650            conversation_id: agent.default_conversation_id.clone(),
651        }
652    }
653
654    /// Create a new StreamingPromptRequest from an agent, cloning the agent's data and default hook.
655    pub fn from_agent<P2>(
656        agent: &Agent<M, P2>,
657        prompt: impl Into<Message>,
658    ) -> StreamingPromptRequest<M, P2>
659    where
660        P2: PromptHook<M>,
661    {
662        StreamingPromptRequest {
663            prompt: prompt.into(),
664            chat_history: None,
665            max_turns: agent.default_max_turns.unwrap_or_default(),
666            model: agent.model.clone(),
667            agent_name: agent.name.clone(),
668            preamble: agent.preamble.clone(),
669            static_context: agent.static_context.clone(),
670            temperature: agent.temperature,
671            max_tokens: agent.max_tokens,
672            additional_params: agent.additional_params.clone(),
673            tool_server_handle: agent.tool_server_handle.clone(),
674            dynamic_context: agent.dynamic_context.clone(),
675            tool_choice: agent.tool_choice.clone(),
676            output_schema: agent.output_schema.clone(),
677            hook: agent.hook.clone(),
678            max_invalid_tool_call_retries: 0,
679            memory: agent.memory.clone(),
680            conversation_id: agent.default_conversation_id.clone(),
681        }
682    }
683
684    fn agent_name(&self) -> &str {
685        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
686    }
687
688    /// Set the maximum Turns for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
689    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
690    pub fn multi_turn(mut self, turns: usize) -> Self {
691        self.max_turns = turns;
692        self
693    }
694
695    /// Add chat history to the prompt request.
696    ///
697    /// When history is provided, the final [`FinalResponse`] will include the
698    /// updated chat history (original messages + new user prompt + assistant response).
699    /// ```ignore
700    /// let mut stream = agent
701    ///     .stream_prompt("Hello")
702    ///     .with_history(vec![])
703    ///     .await;
704    /// // ... consume stream ...
705    /// // Access updated history from FinalResponse::history()
706    /// ```
707    pub fn with_history<H, T>(mut self, history: H) -> Self
708    where
709        H: IntoIterator<Item = T>,
710        T: Into<Message>,
711    {
712        self.chat_history = Some(history.into_iter().map(Into::into).collect());
713        self
714    }
715
716    /// Attach a per-request hook for tool call events.
717    /// This overrides any default hook set on the agent.
718    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
719    where
720        P2: PromptHook<M>,
721    {
722        StreamingPromptRequest {
723            prompt: self.prompt,
724            chat_history: self.chat_history,
725            max_turns: self.max_turns,
726            model: self.model,
727            agent_name: self.agent_name,
728            preamble: self.preamble,
729            static_context: self.static_context,
730            temperature: self.temperature,
731            max_tokens: self.max_tokens,
732            additional_params: self.additional_params,
733            tool_server_handle: self.tool_server_handle,
734            dynamic_context: self.dynamic_context,
735            tool_choice: self.tool_choice,
736            output_schema: self.output_schema,
737            hook: Some(hook),
738            max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
739            memory: self.memory,
740            conversation_id: self.conversation_id,
741        }
742    }
743
744    /// Set the retry budget for [`crate::agent::prompt_request::hooks::InvalidToolCallHookAction::Retry`].
745    ///
746    /// Invalid tool-call retries also consume normal multi-turn depth.
747    pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
748        self.max_invalid_tool_call_retries = retries;
749        self
750    }
751
752    /// Set the conversation id used to load and persist memory for this request.
753    ///
754    /// Overrides any default conversation id set on the agent. If memory is not
755    /// configured on the agent, this has no effect.
756    pub fn conversation(mut self, id: impl Into<String>) -> Self {
757        self.conversation_id = Some(id.into());
758        self
759    }
760
761    /// Disable conversation memory for this request.
762    ///
763    /// History will neither be loaded from nor saved to the agent's memory backend.
764    pub fn without_memory(mut self) -> Self {
765        self.memory = None;
766        self.conversation_id = None;
767        self
768    }
769
770    async fn send(self) -> StreamingResult<M::StreamingResponse> {
771        let (agent_span, created_agent_span) = if tracing::Span::current().is_disabled() {
772            (
773                info_span!(
774                    "invoke_agent",
775                    gen_ai.operation.name = "invoke_agent",
776                    gen_ai.agent.name = self.agent_name(),
777                    gen_ai.system_instructions = self.preamble,
778                    gen_ai.prompt = tracing::field::Empty,
779                    gen_ai.completion = tracing::field::Empty,
780                    gen_ai.usage.input_tokens = tracing::field::Empty,
781                    gen_ai.usage.output_tokens = tracing::field::Empty,
782                    gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
783                    gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
784                    gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
785                    gen_ai.usage.reasoning_tokens = tracing::field::Empty,
786                ),
787                true,
788            )
789        } else {
790            (tracing::Span::current(), false)
791        };
792
793        let prompt = self.prompt;
794        if let Some(text) = prompt.rag_text() {
795            agent_span.record("gen_ai.prompt", text);
796        }
797
798        // Clone fields needed inside the stream
799        let model = self.model.clone();
800        let preamble = self.preamble.clone();
801        let static_context = self.static_context.clone();
802        let temperature = self.temperature;
803        let max_tokens = self.max_tokens;
804        let additional_params = self.additional_params.clone();
805        let tool_server_handle = self.tool_server_handle.clone();
806        let dynamic_context = self.dynamic_context.clone();
807        let tool_choice = self.tool_choice.clone();
808        let agent_name = self.agent_name.clone();
809        // When the caller passes explicit history, memory is fully bypassed for
810        // this request (no load AND no save). Otherwise, if a memory backend and
811        // conversation id are both configured, load prior history; if either is
812        // missing, behave as if no memory is configured.
813        let (chat_history, memory_handle) = match self.chat_history {
814            Some(history) => (Some(history), None),
815            None => match (self.memory, self.conversation_id) {
816                (Some(memory), Some(id)) => match memory.load(&id).await {
817                    Ok(loaded) => (Some(loaded), Some((memory, id))),
818                    Err(err) => {
819                        let stream = async_stream::stream! {
820                            yield Err(StreamingError::from(err));
821                        };
822                        return Box::pin(stream);
823                    }
824                },
825                _ => (None, None),
826            },
827        };
828        let has_history = chat_history.is_some();
829        let mut new_messages: Vec<Message> = vec![prompt.clone()];
830
831        let mut current_max_turns = 0;
832        let mut last_prompt_error = String::new();
833
834        let mut text_delta_response = String::new();
835        let mut saw_text_this_turn = false;
836        let mut max_turns_reached = false;
837        let output_schema = self.output_schema;
838
839        let mut aggregated_usage = crate::completion::Usage::new();
840        let mut completion_calls = Vec::new();
841        let mut completion_call_index = 0;
842        let mut invalid_tool_call_retries = 0;
843
844        // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
845        // span context leaking to other concurrent tasks. Using span.enter() inside
846        // async_stream::stream! holds the guard across yield points, which causes
847        // thread-local span context to leak when other tasks run on the same thread.
848        // See: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#in-asynchronous-code
849        // See also: https://github.com/rust-lang/rust-clippy/issues/8722
850        let stream = async_stream::stream! {
851            'outer: loop {
852                let Some((current_prompt_ref, previous_messages)) = new_messages.split_last() else {
853                    yield Err(cancelled_prompt_error(
854                        chat_history.as_deref(),
855                        new_messages.clone(),
856                        "streaming loop lost its pending prompt".to_string(),
857                    ).await);
858                    break 'outer;
859                };
860                let current_prompt = current_prompt_ref.clone();
861
862                if current_max_turns > self.max_turns + 1 {
863                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
864                    max_turns_reached = true;
865                    break;
866                }
867
868                current_max_turns += 1;
869
870                if self.max_turns > 1 {
871                    tracing::info!(
872                        "Current conversation Turns: {}/{}",
873                        current_max_turns,
874                        self.max_turns
875                    );
876                }
877
878                let history_snapshot: Vec<Message> = build_history_for_request(
879                    chat_history.as_deref(),
880                    previous_messages,
881                );
882
883                if let Some(ref hook) = self.hook
884                    && let HookAction::Terminate { reason } =
885                        hook.on_completion_call(&current_prompt, &history_snapshot).await
886                {
887                    yield Err(
888                        cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason)
889                            .await,
890                    );
891                    break 'outer;
892                }
893
894                let chat_stream_span = info_span!(
895                    target: "rig::agent_chat",
896                    parent: tracing::Span::current(),
897                    "chat_streaming",
898                    gen_ai.operation.name = "chat",
899                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
900                    gen_ai.system_instructions = preamble,
901                    gen_ai.provider.name = tracing::field::Empty,
902                    gen_ai.request.model = tracing::field::Empty,
903                    gen_ai.response.id = tracing::field::Empty,
904                    gen_ai.response.model = tracing::field::Empty,
905                    gen_ai.usage.output_tokens = tracing::field::Empty,
906                    gen_ai.usage.input_tokens = tracing::field::Empty,
907                    gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
908                    gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
909                    gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
910                    gen_ai.usage.reasoning_tokens = tracing::field::Empty,
911                    gen_ai.input.messages = tracing::field::Empty,
912                    gen_ai.output.messages = tracing::field::Empty,
913                );
914
915                let prepared_request = build_prepared_completion_request(
916                    &model,
917                    current_prompt.clone(),
918                    &history_snapshot,
919                    preamble.as_deref(),
920                    &static_context,
921                    temperature,
922                    max_tokens,
923                    additional_params.as_ref(),
924                    tool_choice.as_ref(),
925                    &tool_server_handle,
926                    &dynamic_context,
927                    output_schema.as_ref(),
928                )
929                .await?;
930                let executable_tool_names = prepared_request.executable_tool_names.clone();
931                let allowed_tool_names = prepared_request.allowed_tool_names.clone();
932
933                let mut stream = prepared_request
934                    .builder
935                    .stream()
936                    .instrument(chat_stream_span.clone())
937                    .await?;
938
939                let call_index = completion_call_index;
940                completion_call_index += 1;
941                let mut current_call_usage = None;
942                let mut completion_call_emitted = false;
943                let mut pending_tool_calls: Vec<(ToolCall, String)> = vec![];
944                let mut tool_calls = vec![];
945                let mut tool_results = vec![];
946                let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
947                // Kept separate from accumulated_reasoning so providers requiring
948                // signatures (e.g. Anthropic) never see unsigned blocks.
949                let mut pending_reasoning_delta_text = String::new();
950                let mut pending_reasoning_delta_id: Option<String> = None;
951                let mut tool_call_delta_states: HashMap<(String, String), ToolCallDeltaState> =
952                    HashMap::new();
953                let mut saw_tool_call_this_turn = false;
954
955                while let Some(content) = stream.next().await {
956                    match content {
957                        Ok(StreamedAssistantContent::Text(text)) => {
958                            if !saw_text_this_turn {
959                                text_delta_response.clear();
960                                saw_text_this_turn = true;
961                            }
962                            text_delta_response.push_str(&text.text);
963                            if let Some(ref hook) = self.hook &&
964                                let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &text_delta_response).await {
965                                    yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
966                                    break 'outer;
967                            }
968
969                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
970                        },
971                        Ok(StreamedAssistantContent::ToolCall { mut tool_call, internal_call_id }) => {
972                            let diagnostic_history =
973                                build_tool_call_validation_history(ToolCallValidationHistory {
974                                    chat_history: chat_history.as_deref(),
975                                    new_messages: &new_messages,
976                                    assistant_message_id: &stream.message_id,
977                                    final_turn_content: None,
978                                    text_delta_response: saw_text_this_turn
979                                        .then_some(text_delta_response.as_str()),
980                                    accumulated_reasoning: &accumulated_reasoning,
981                                    pending_reasoning_delta_text: &pending_reasoning_delta_text,
982                                    pending_reasoning_delta_id: &pending_reasoning_delta_id,
983                                    pending_tool_calls: &pending_tool_calls,
984                                    current_tool_call: Some(tool_call.clone()),
985                                });
986
987                            if !allowed_tool_names.contains(&tool_call.function.name) {
988                                let args = json_utils::value_to_json_string(&tool_call.function.arguments);
989                                let emitted_tool_name = tool_call.function.name.clone();
990                                match resolve_invalid_tool_call::<M, P>(
991                                    self.hook.as_ref(),
992                                    &emitted_tool_name,
993                                    Some(tool_call.id.clone()),
994                                    Some(internal_call_id.clone()),
995                                    Some(args),
996                                    &executable_tool_names,
997                                    &allowed_tool_names,
998                                    self.tool_choice.as_ref(),
999                                    diagnostic_history.clone(),
1000                                    true,
1001                                ).await {
1002                                    InvalidToolCallResolution::Fail(err) => {
1003                                        yield Err(Box::new(err).into());
1004                                        break 'outer;
1005                                    }
1006                                    InvalidToolCallResolution::Retry(feedback) => {
1007                                        if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
1008                                            yield Err(Box::new(PromptError::UnknownToolCall {
1009                                                tool_name: emitted_tool_name,
1010                                                available_tools: executable_tool_names.iter().cloned().collect(),
1011                                                allowed_tools: allowed_tool_names.iter().cloned().collect(),
1012                                                chat_history: Box::new(diagnostic_history),
1013                                            }).into());
1014                                            break 'outer;
1015                                        }
1016
1017                                        invalid_tool_call_retries += 1;
1018                                        flush_pending_reasoning_delta(
1019                                            &mut accumulated_reasoning,
1020                                            &mut pending_reasoning_delta_text,
1021                                            &mut pending_reasoning_delta_id,
1022                                        );
1023                                        let Some((assistant_message, user_message)) =
1024                                            invalid_streaming_tool_retry_messages(
1025                                                &stream.message_id,
1026                                                saw_text_this_turn.then_some(text_delta_response.as_str()),
1027                                                &accumulated_reasoning,
1028                                                &pending_tool_calls,
1029                                                tool_call,
1030                                                feedback,
1031                                            )
1032                                        else {
1033                                            yield Err(cancelled_prompt_error(
1034                                                chat_history.as_deref(),
1035                                                new_messages.clone(),
1036                                                "invalid tool call retry produced no retry messages".to_string(),
1037                                            ).await);
1038                                            break 'outer;
1039                                        };
1040                                        new_messages.push(assistant_message);
1041                                        new_messages.push(user_message);
1042                                        if let Err(err) = drain_recovered_stream_usage(
1043                                            &mut stream,
1044                                            &tool_call_delta_states,
1045                                            &mut current_call_usage,
1046                                            &mut aggregated_usage,
1047                                        )
1048                                        .await
1049                                        {
1050                                            yield Err(err);
1051                                            break 'outer;
1052                                        }
1053                                        if let Some(completion_call) = record_completion_call_if_needed(
1054                                            &mut completion_calls,
1055                                            &mut completion_call_emitted,
1056                                            call_index,
1057                                            current_call_usage,
1058                                            &chat_stream_span,
1059                                        ) {
1060                                            yield Ok(MultiTurnStreamItem::CompletionCall(
1061                                                completion_call,
1062                                            ));
1063                                        }
1064                                        text_delta_response.clear();
1065                                        saw_text_this_turn = false;
1066                                        continue 'outer;
1067                                    }
1068                                    InvalidToolCallResolution::Repair(repaired_name) => {
1069                                        tool_call.function.name = repaired_name;
1070                                    }
1071                                    InvalidToolCallResolution::Skip(reason) => {
1072                                        let skipped_tool_result = ToolResult {
1073                                            id: tool_call.id.clone(),
1074                                            call_id: tool_call.call_id.clone(),
1075                                            content: ToolResultContent::from_tool_output(
1076                                                reason.clone(),
1077                                            ),
1078                                        };
1079                                        flush_pending_reasoning_delta(
1080                                            &mut accumulated_reasoning,
1081                                            &mut pending_reasoning_delta_text,
1082                                            &mut pending_reasoning_delta_id,
1083                                        );
1084                                        let Some((assistant_message, user_message)) =
1085                                            invalid_streaming_tool_retry_messages(
1086                                                &stream.message_id,
1087                                                saw_text_this_turn
1088                                                    .then_some(text_delta_response.as_str()),
1089                                                &accumulated_reasoning,
1090                                                &pending_tool_calls,
1091                                                tool_call,
1092                                                reason,
1093                                            )
1094                                        else {
1095                                            yield Err(cancelled_prompt_error(
1096                                                chat_history.as_deref(),
1097                                                new_messages.clone(),
1098                                                "invalid tool call skip produced no recovery messages".to_string(),
1099                                            ).await);
1100                                            break 'outer;
1101                                        };
1102                                        new_messages.push(assistant_message);
1103                                        new_messages.push(user_message);
1104                                        let tool_result = ToolResult {
1105                                            id: skipped_tool_result.id,
1106                                            call_id: skipped_tool_result.call_id,
1107                                            content: skipped_tool_result.content,
1108                                        };
1109                                        if let Err(err) = drain_recovered_stream_usage(
1110                                            &mut stream,
1111                                            &tool_call_delta_states,
1112                                            &mut current_call_usage,
1113                                            &mut aggregated_usage,
1114                                        )
1115                                        .await
1116                                        {
1117                                            yield Err(err);
1118                                            break 'outer;
1119                                        }
1120                                        if let Some(completion_call) = record_completion_call_if_needed(
1121                                            &mut completion_calls,
1122                                            &mut completion_call_emitted,
1123                                            call_index,
1124                                            current_call_usage,
1125                                            &chat_stream_span,
1126                                        ) {
1127                                            yield Ok(MultiTurnStreamItem::CompletionCall(
1128                                                completion_call,
1129                                            ));
1130                                        }
1131                                        yield Ok(MultiTurnStreamItem::StreamUserItem(
1132                                            StreamedUserContent::ToolResult {
1133                                                tool_result,
1134                                                internal_call_id,
1135                                            },
1136                                        ));
1137                                        text_delta_response.clear();
1138                                        saw_text_this_turn = false;
1139                                        continue 'outer;
1140                                    }
1141                                }
1142                            }
1143
1144                            pending_tool_calls.push((tool_call, internal_call_id));
1145                        },
1146                        Ok(StreamedAssistantContent::ToolCallDelta {
1147                            id,
1148                            internal_call_id,
1149                            content,
1150                        }) => {
1151                            let key = (id.clone(), internal_call_id.clone());
1152                            let mut deltas_to_emit = Vec::new();
1153
1154                            match content {
1155                                ToolCallDeltaContent::Name(mut name) => {
1156                                    let buffered_args = tool_call_delta_states
1157                                        .get(&key)
1158                                        .map(|state| state.buffered_arguments.join(""))
1159                                        .unwrap_or_default();
1160                                    let diagnostic_args = if buffered_args.trim().is_empty() {
1161                                        serde_json::Value::Null
1162                                    } else {
1163                                        serde_json::from_str(&buffered_args)
1164                                            .unwrap_or(serde_json::Value::Null)
1165                                    };
1166                                    let diagnostic_tool_call = ToolCall::new(
1167                                        id.clone(),
1168                                        ToolFunction::new(name.clone(), diagnostic_args),
1169                                    );
1170                                    let diagnostic_history =
1171                                        build_tool_call_validation_history(ToolCallValidationHistory {
1172                                            chat_history: chat_history.as_deref(),
1173                                            new_messages: &new_messages,
1174                                            assistant_message_id: &stream.message_id,
1175                                            final_turn_content: None,
1176                                            text_delta_response: saw_text_this_turn
1177                                                .then_some(text_delta_response.as_str()),
1178                                            accumulated_reasoning: &accumulated_reasoning,
1179                                            pending_reasoning_delta_text: &pending_reasoning_delta_text,
1180                                            pending_reasoning_delta_id: &pending_reasoning_delta_id,
1181                                            pending_tool_calls: &pending_tool_calls,
1182                                            current_tool_call: Some(diagnostic_tool_call.clone()),
1183                                        });
1184
1185                                    if !allowed_tool_names.contains(&name) {
1186                                        let emitted_tool_name = name.clone();
1187                                        match resolve_invalid_tool_call::<M, P>(
1188                                            self.hook.as_ref(),
1189                                            &emitted_tool_name,
1190                                            Some(id.clone()),
1191                                            Some(internal_call_id.clone()),
1192                                            Some(buffered_args.clone()),
1193                                            &executable_tool_names,
1194                                            &allowed_tool_names,
1195                                            self.tool_choice.as_ref(),
1196                                            diagnostic_history.clone(),
1197                                            true,
1198                                        ).await {
1199                                            InvalidToolCallResolution::Fail(err) => {
1200                                                yield Err(Box::new(err).into());
1201                                                break 'outer;
1202                                            }
1203                                            InvalidToolCallResolution::Skip(reason) => {
1204                                                tool_call_delta_states.remove(&key);
1205                                                flush_pending_reasoning_delta(
1206                                                    &mut accumulated_reasoning,
1207                                                    &mut pending_reasoning_delta_text,
1208                                                    &mut pending_reasoning_delta_id,
1209                                                );
1210                                                let Some((assistant_message, user_message)) =
1211                                                    invalid_streaming_name_delta_retry_messages(
1212                                                        &stream.message_id,
1213                                                        saw_text_this_turn
1214                                                            .then_some(text_delta_response.as_str()),
1215                                                        &accumulated_reasoning,
1216                                                        &pending_tool_calls,
1217                                                        diagnostic_tool_call.clone(),
1218                                                        reason.clone(),
1219                                                    )
1220                                                else {
1221                                                    yield Err(cancelled_prompt_error(
1222                                                        chat_history.as_deref(),
1223                                                        new_messages.clone(),
1224                                                        "invalid tool call skip produced no recovery messages".to_string(),
1225                                                    ).await);
1226                                                    break 'outer;
1227                                                };
1228                                                new_messages.push(assistant_message);
1229                                                new_messages.push(user_message);
1230                                                let tool_result = ToolResult {
1231                                                    id,
1232                                                    call_id: None,
1233                                                    content: ToolResultContent::from_tool_output(
1234                                                        reason,
1235                                                    ),
1236                                                };
1237                                                if let Err(err) = drain_recovered_stream_usage(
1238                                                    &mut stream,
1239                                                    &tool_call_delta_states,
1240                                                    &mut current_call_usage,
1241                                                    &mut aggregated_usage,
1242                                                )
1243                                                .await
1244                                                {
1245                                                    yield Err(err);
1246                                                    break 'outer;
1247                                                }
1248                                                if let Some(completion_call) = record_completion_call_if_needed(
1249                                                    &mut completion_calls,
1250                                                    &mut completion_call_emitted,
1251                                                    call_index,
1252                                                    current_call_usage,
1253                                                    &chat_stream_span,
1254                                                ) {
1255                                                    yield Ok(MultiTurnStreamItem::CompletionCall(
1256                                                        completion_call,
1257                                                    ));
1258                                                }
1259                                                yield Ok(MultiTurnStreamItem::StreamUserItem(
1260                                                    StreamedUserContent::ToolResult {
1261                                                        tool_result,
1262                                                        internal_call_id,
1263                                                    },
1264                                                ));
1265                                                text_delta_response.clear();
1266                                                saw_text_this_turn = false;
1267                                                continue 'outer;
1268                                            }
1269                                            InvalidToolCallResolution::Retry(feedback) => {
1270                                                tool_call_delta_states.remove(&key);
1271                                                if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
1272                                                    yield Err(Box::new(PromptError::UnknownToolCall {
1273                                                        tool_name: emitted_tool_name,
1274                                                        available_tools: executable_tool_names.iter().cloned().collect(),
1275                                                        allowed_tools: allowed_tool_names.iter().cloned().collect(),
1276                                                        chat_history: Box::new(diagnostic_history),
1277                                                    }).into());
1278                                                    break 'outer;
1279                                                }
1280
1281                                                invalid_tool_call_retries += 1;
1282                                                flush_pending_reasoning_delta(
1283                                                    &mut accumulated_reasoning,
1284                                                    &mut pending_reasoning_delta_text,
1285                                                    &mut pending_reasoning_delta_id,
1286                                                );
1287                                                let Some((assistant_message, user_message)) =
1288                                                    invalid_streaming_name_delta_retry_messages(
1289                                                        &stream.message_id,
1290                                                        saw_text_this_turn
1291                                                            .then_some(text_delta_response.as_str()),
1292                                                        &accumulated_reasoning,
1293                                                        &pending_tool_calls,
1294                                                        diagnostic_tool_call.clone(),
1295                                                        feedback,
1296                                                    )
1297                                                else {
1298                                                    yield Err(cancelled_prompt_error(
1299                                                        chat_history.as_deref(),
1300                                                        new_messages.clone(),
1301                                                        "invalid tool call retry produced no retry messages".to_string(),
1302                                                    ).await);
1303                                                    break 'outer;
1304                                                };
1305                                                new_messages.push(assistant_message);
1306                                                new_messages.push(user_message);
1307                                                if let Err(err) = drain_recovered_stream_usage(
1308                                                    &mut stream,
1309                                                    &tool_call_delta_states,
1310                                                    &mut current_call_usage,
1311                                                    &mut aggregated_usage,
1312                                                )
1313                                                .await
1314                                                {
1315                                                    yield Err(err);
1316                                                    break 'outer;
1317                                                }
1318                                                if let Some(completion_call) = record_completion_call_if_needed(
1319                                                    &mut completion_calls,
1320                                                    &mut completion_call_emitted,
1321                                                    call_index,
1322                                                    current_call_usage,
1323                                                    &chat_stream_span,
1324                                                ) {
1325                                                    yield Ok(MultiTurnStreamItem::CompletionCall(
1326                                                        completion_call,
1327                                                    ));
1328                                                }
1329                                                text_delta_response.clear();
1330                                                saw_text_this_turn = false;
1331                                                continue 'outer;
1332                                            }
1333                                            InvalidToolCallResolution::Repair(repaired_name) => {
1334                                                name = repaired_name;
1335                                            }
1336                                        }
1337                                    }
1338
1339                                    let state =
1340                                        tool_call_delta_states.entry(key.clone()).or_default();
1341                                    state.name_validated = true;
1342                                    let buffered_arguments =
1343                                        std::mem::take(&mut state.buffered_arguments);
1344
1345                                    deltas_to_emit.push(ToolCallDeltaContent::Name(name));
1346                                    deltas_to_emit.extend(
1347                                        buffered_arguments
1348                                            .into_iter()
1349                                            .map(ToolCallDeltaContent::Delta),
1350                                    );
1351                                }
1352                                ToolCallDeltaContent::Delta(arguments) => {
1353                                    let state =
1354                                        tool_call_delta_states.entry(key.clone()).or_default();
1355                                    if state.name_validated {
1356                                        deltas_to_emit.push(ToolCallDeltaContent::Delta(arguments));
1357                                    } else {
1358                                        state.buffered_arguments.push(arguments);
1359                                    }
1360                                }
1361                            }
1362
1363                            for content in deltas_to_emit {
1364                                if let Some(ref hook) = self.hook {
1365                                    let (name, delta) = match &content {
1366                                        ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
1367                                        ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
1368                                    };
1369
1370                                    if let HookAction::Terminate { reason } = hook
1371                                        .on_tool_call_delta(
1372                                            &id,
1373                                            &internal_call_id,
1374                                            name,
1375                                            delta,
1376                                        )
1377                                        .await
1378                                    {
1379                                        yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1380                                        break 'outer;
1381                                    }
1382                                }
1383
1384                                yield Ok(MultiTurnStreamItem::StreamAssistantItem(
1385                                    StreamedAssistantContent::ToolCallDelta {
1386                                        id: id.clone(),
1387                                        internal_call_id: internal_call_id.clone(),
1388                                        content,
1389                                    },
1390                                ));
1391                            }
1392                        }
1393                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
1394                            // Accumulate reasoning for inclusion in chat history with tool calls.
1395                            // OpenAI Responses API requires reasoning items to be sent back
1396                            // alongside function_call items in multi-turn conversations.
1397                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
1398                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
1399                        },
1400                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
1401                            // Deltas lack signatures/encrypted content that full
1402                            // blocks carry; mixing them into accumulated_reasoning
1403                            // causes Anthropic to reject with "signature required".
1404                            pending_reasoning_delta_text.push_str(&reasoning);
1405                            if pending_reasoning_delta_id.is_none() {
1406                                pending_reasoning_delta_id = id.clone();
1407                            }
1408                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
1409                        },
1410                        Ok(StreamedAssistantContent::Final(final_resp)) => {
1411                            if let Some(err) =
1412                                pending_tool_call_delta_error(&tool_call_delta_states)
1413                            {
1414                                yield Err(err.into());
1415                                break 'outer;
1416                            }
1417
1418                            if let Some(usage) = final_resp.token_usage() {
1419                                current_call_usage = reported_usage(usage);
1420                            }
1421                            if let Some(usage) = current_call_usage {
1422                                aggregated_usage += usage;
1423                            }
1424                            if let Some(completion_call) = record_completion_call_if_needed(
1425                                &mut completion_calls,
1426                                &mut completion_call_emitted,
1427                                call_index,
1428                                current_call_usage,
1429                                &chat_stream_span,
1430                            ) {
1431                                yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
1432                            }
1433
1434                            if saw_text_this_turn {
1435                                if let Some(ref hook) = self.hook &&
1436                                     let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&current_prompt, &final_resp).await {
1437                                        yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1438                                        break 'outer;
1439                                    }
1440
1441                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
1442                                saw_text_this_turn = false;
1443                            }
1444                        }
1445                        Err(e) => {
1446                            yield Err(e.into());
1447                            break 'outer;
1448                        }
1449                    }
1450                }
1451
1452                if let Some(err) = pending_tool_call_delta_error(&tool_call_delta_states) {
1453                    yield Err(err.into());
1454                    break 'outer;
1455                }
1456
1457                if let Some(completion_call) = record_completion_call_if_needed(
1458                    &mut completion_calls,
1459                    &mut completion_call_emitted,
1460                    call_index,
1461                    current_call_usage,
1462                    &chat_stream_span,
1463                ) {
1464                    yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
1465                }
1466
1467                // Providers like Gemini emit thinking as incremental deltas
1468                // without signatures; assemble into a single block so
1469                // reasoning survives into the next turn's chat history.
1470                flush_pending_reasoning_delta(
1471                    &mut accumulated_reasoning,
1472                    &mut pending_reasoning_delta_text,
1473                    &mut pending_reasoning_delta_id,
1474                );
1475
1476                let final_turn_content = stream.choice.clone();
1477                let turn_text_response = assistant_text_from_choice(&final_turn_content);
1478                tracing::Span::current().record("gen_ai.completion", &turn_text_response);
1479
1480                if !pending_tool_calls.is_empty() {
1481                    let diagnostic_history =
1482                        build_tool_call_validation_history(ToolCallValidationHistory {
1483                            chat_history: chat_history.as_deref(),
1484                            new_messages: &new_messages,
1485                            assistant_message_id: &stream.message_id,
1486                            final_turn_content: Some(&final_turn_content),
1487                            text_delta_response: None,
1488                            accumulated_reasoning: &accumulated_reasoning,
1489                            pending_reasoning_delta_text: "",
1490                            pending_reasoning_delta_id: &None,
1491                            pending_tool_calls: &pending_tool_calls,
1492                            current_tool_call: None,
1493                        });
1494
1495                    for (tool_call, _) in &pending_tool_calls {
1496                        if let Err(err) = validate_tool_call_name(
1497                            &tool_call.function.name,
1498                            &executable_tool_names,
1499                            &allowed_tool_names,
1500                            diagnostic_history.clone(),
1501                        ) {
1502                            yield Err(Box::new(err).into());
1503                            break 'outer;
1504                        }
1505                    }
1506
1507                    for (tool_call, internal_call_id) in pending_tool_calls {
1508                        let tool_span = info_span!(
1509                            parent: tracing::Span::current(),
1510                            "execute_tool",
1511                            gen_ai.operation.name = "execute_tool",
1512                            gen_ai.tool.type = "function",
1513                            gen_ai.tool.name = tracing::field::Empty,
1514                            gen_ai.tool.call.id = tracing::field::Empty,
1515                            gen_ai.tool.call.arguments = tracing::field::Empty,
1516                            gen_ai.tool.call.result = tracing::field::Empty
1517                        );
1518
1519                        yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));
1520
1521                        let tc_result = async {
1522                            let tool_span = tracing::Span::current();
1523                            let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
1524                            if let Some(ref hook) = self.hook {
1525                                let action = hook
1526                                    .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
1527                                    .await;
1528
1529                                if let ToolCallHookAction::Terminate { reason } = action {
1530                                    return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1531                                }
1532
1533                                if let ToolCallHookAction::Skip { reason } = action {
1534                                    // Tool execution rejected, return rejection message as tool result
1535                                    tracing::info!(
1536                                        tool_name = tool_call.function.name.as_str(),
1537                                        reason = reason,
1538                                        "Tool call rejected"
1539                                    );
1540                                    let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
1541                                    tool_calls.push(tool_call_msg);
1542                                    tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
1543                                    saw_tool_call_this_turn = true;
1544                                    return Ok(reason);
1545                                }
1546                            }
1547
1548                            tool_span.record("gen_ai.tool.name", &tool_call.function.name);
1549                            tool_span.record("gen_ai.tool.call.arguments", &tool_args);
1550
1551                            let tool_result = match
1552                            tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
1553                                Ok(thing) => thing,
1554                                Err(e) => {
1555                                    tracing::warn!("Error while calling tool: {e}");
1556                                    e.to_string()
1557                                }
1558                            };
1559
1560                            tool_span.record("gen_ai.tool.call.result", &tool_result);
1561
1562                            if let Some(ref hook) = self.hook &&
1563                                let HookAction::Terminate { reason } =
1564                                hook.on_tool_result(
1565                                    &tool_call.function.name,
1566                                    tool_call.call_id.clone(),
1567                                    &internal_call_id,
1568                                    &tool_args,
1569                                    &tool_result.to_string()
1570                                )
1571                                .await {
1572                                    return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1573                                }
1574
1575                            let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
1576
1577                            tool_calls.push(tool_call_msg);
1578                            tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));
1579
1580                            saw_tool_call_this_turn = true;
1581                            Ok(tool_result)
1582                        }.instrument(tool_span).await;
1583
1584                        match tc_result {
1585                            Ok(text) => {
1586                                let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
1587                                yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
1588                            }
1589                            Err(e) => {
1590                                yield Err(e);
1591                                break 'outer;
1592                            }
1593                        }
1594                    }
1595                }
1596
1597                // Add text, reasoning, and tool calls to chat history.
1598                // OpenAI Responses API requires reasoning items to precede function_call items.
1599                let mut assistant_turn_added_to_history = false;
1600                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
1601                    // Text before tool calls so the model sees its own prior output.
1602                    let mut content_items = assistant_text_items_from_choice(&final_turn_content);
1603
1604                    // Reasoning must come before tool calls (OpenAI requirement)
1605                    for reasoning in accumulated_reasoning.drain(..) {
1606                        content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
1607                    }
1608
1609                    content_items.extend(tool_calls.clone());
1610
1611                    if let Some(content) = OneOrMany::from_iter_optional(content_items) {
1612                        new_messages.push(Message::Assistant {
1613                            id: stream.message_id.clone(),
1614                            content,
1615                        });
1616                        assistant_turn_added_to_history = true;
1617                    }
1618                }
1619
1620                // Combine all tool results into a single User message (required by Anthropic)
1621                let tool_result_contents: Vec<UserContent> = tool_results
1622                    .into_iter()
1623                    .map(|(id, call_id, tool_result)| {
1624                        let content = ToolResultContent::from_tool_output(tool_result);
1625                        match call_id {
1626                            Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
1627                            None => UserContent::tool_result(id, content),
1628                        }
1629                    })
1630                    .collect();
1631
1632                if let Some(content) = OneOrMany::from_iter_optional(tool_result_contents) {
1633                    new_messages.push(Message::User { content });
1634                }
1635
1636                if !saw_tool_call_this_turn {
1637                    // Add user message and assistant response to history before finishing
1638                    let should_add_final_assistant = !assistant_turn_added_to_history
1639                        && !is_empty_assistant_choice(&final_turn_content);
1640                    if should_add_final_assistant {
1641                        new_messages.push(Message::Assistant {
1642                            id: stream.message_id.clone(),
1643                            content: final_turn_content.clone(),
1644                        });
1645                    } else if !assistant_turn_added_to_history {
1646                        tracing::warn!(
1647                            agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
1648                            message_id = ?stream.message_id,
1649                            "Streaming turn completed without assistant text; final response will be empty"
1650                        );
1651                    }
1652
1653                    if created_agent_span {
1654                        let current_span = tracing::Span::current();
1655                        record_usage_on_span(&current_span, aggregated_usage);
1656                    }
1657                    tracing::info!("Agent multi-turn stream finished");
1658                    if let Some((memory, id)) = memory_handle.as_ref()
1659                        && let Err(err) = memory.append(id, new_messages.clone()).await
1660                    {
1661                        tracing::warn!(
1662                            error = %err,
1663                            conversation_id = %id,
1664                            "conversation memory append failed; yielding final response anyway"
1665                        );
1666                    }
1667                    let final_messages: Option<Vec<Message>> = if has_history {
1668                        Some(new_messages.clone())
1669                    } else {
1670                        None
1671                    };
1672                    yield Ok(MultiTurnStreamItem::final_response_with_completion_calls(
1673                        final_turn_content,
1674                        aggregated_usage,
1675                        completion_calls,
1676                        final_messages,
1677                    ));
1678                    break;
1679                }
1680            }
1681
1682            if max_turns_reached {
1683                yield Err(Box::new(PromptError::MaxTurnsError {
1684                    max_turns: self.max_turns,
1685                    chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
1686                    prompt: Box::new(last_prompt_error.clone().into()),
1687                }).into());
1688            }
1689        };
1690
1691        Box::pin(stream.instrument(agent_span))
1692    }
1693}
1694
1695impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
1696where
1697    M: CompletionModel + 'static,
1698    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
1699    P: PromptHook<M> + 'static,
1700{
1701    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
1702    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1703
1704    fn into_future(self) -> Self::IntoFuture {
1705        // Wrap send() in a future, because send() returns a stream immediately
1706        Box::pin(async move { self.send().await })
1707    }
1708}
1709
1710/// Helper function to stream assistant-visible completion output to stdout.
1711///
1712/// This helper prints streamed assistant text and reasoning only. Streaming
1713/// metadata events, such as `MultiTurnStreamItem::CompletionCall`, are not
1714/// printed; metadata is returned on the `FinalResponse` via accessors such as
1715/// `FinalResponse::completion_calls`.
1716pub async fn stream_to_stdout<R>(
1717    stream: &mut StreamingResult<R>,
1718) -> Result<FinalResponse, std::io::Error> {
1719    let mut final_res = FinalResponse::empty();
1720    print!("Response: ");
1721    while let Some(content) = stream.next().await {
1722        match content {
1723            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1724                Text { text, .. },
1725            ))) => {
1726                print!("{text}");
1727                std::io::Write::flush(&mut std::io::stdout())?;
1728            }
1729            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
1730                reasoning,
1731            ))) => {
1732                let reasoning = reasoning.display_text();
1733                print!("{reasoning}");
1734                std::io::Write::flush(&mut std::io::stdout())?;
1735            }
1736            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1737                final_res = res;
1738            }
1739            Err(err) => {
1740                eprintln!("Error: {err}");
1741            }
1742            _ => {}
1743        }
1744    }
1745
1746    Ok(final_res)
1747}
1748
1749#[cfg(test)]
1750mod tests {
1751    use super::*;
1752    use crate::agent::AgentBuilder;
1753    use crate::agent::prompt_request::hooks::{
1754        InvalidToolCallContext, InvalidToolCallHookAction, PromptHook, ToolCallHookAction,
1755    };
1756    use crate::client::ProviderClient;
1757    use crate::client::completion::CompletionClient;
1758    use crate::completion::{CompletionRequest, PromptError, ToolDefinition, Usage};
1759    use crate::message::{
1760        AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent,
1761        ToolChoice, ToolResultContent, UserContent,
1762    };
1763    use crate::providers::anthropic;
1764    use crate::streaming::{StreamingPrompt, ToolCallDeltaContent};
1765    use crate::test_utils::{
1766        AppendFailingMemory, FailingMemory, MockAddTool, MockCompletionModel, MockResponse,
1767        MockStreamEvent, MockSubtractTool, MockToolError,
1768    };
1769    use crate::tool::Tool;
1770    use futures::{StreamExt, TryStreamExt};
1771    use serde::Deserialize;
1772    use std::collections::HashMap;
1773    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
1774    use std::sync::{Arc, Mutex};
1775    use std::time::Duration;
1776    use tracing::field::{Field, Visit};
1777    use tracing::{Id, Subscriber};
1778    use tracing_subscriber::layer::{Context, SubscriberExt};
1779    use tracing_subscriber::{Layer, Registry, registry::LookupSpan};
1780
1781    #[test]
1782    fn merge_reasoning_blocks_preserves_order_and_signatures() {
1783        let mut accumulated = Vec::new();
1784        let first = crate::message::Reasoning {
1785            id: Some("rs_1".to_string()),
1786            content: vec![ReasoningContent::Text {
1787                text: "step-1".to_string(),
1788                signature: Some("sig-1".to_string()),
1789            }],
1790        };
1791        let second = crate::message::Reasoning {
1792            id: Some("rs_1".to_string()),
1793            content: vec![
1794                ReasoningContent::Text {
1795                    text: "step-2".to_string(),
1796                    signature: Some("sig-2".to_string()),
1797                },
1798                ReasoningContent::Summary("summary".to_string()),
1799            ],
1800        };
1801
1802        merge_reasoning_blocks(&mut accumulated, &first);
1803        merge_reasoning_blocks(&mut accumulated, &second);
1804
1805        assert_eq!(accumulated.len(), 1);
1806        let merged = accumulated.first().expect("expected accumulated reasoning");
1807        assert_eq!(merged.id.as_deref(), Some("rs_1"));
1808        assert_eq!(merged.content.len(), 3);
1809        assert!(matches!(
1810            merged.content.first(),
1811            Some(ReasoningContent::Text { text, signature: Some(sig) })
1812                if text == "step-1" && sig == "sig-1"
1813        ));
1814        assert!(matches!(
1815            merged.content.get(1),
1816            Some(ReasoningContent::Text { text, signature: Some(sig) })
1817                if text == "step-2" && sig == "sig-2"
1818        ));
1819    }
1820
1821    #[test]
1822    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
1823        let mut accumulated = vec![crate::message::Reasoning {
1824            id: Some("rs_a".to_string()),
1825            content: vec![ReasoningContent::Text {
1826                text: "step-1".to_string(),
1827                signature: None,
1828            }],
1829        }];
1830        let incoming = crate::message::Reasoning {
1831            id: Some("rs_b".to_string()),
1832            content: vec![ReasoningContent::Text {
1833                text: "step-2".to_string(),
1834                signature: None,
1835            }],
1836        };
1837
1838        merge_reasoning_blocks(&mut accumulated, &incoming);
1839        assert_eq!(accumulated.len(), 2);
1840        assert_eq!(
1841            accumulated.first().and_then(|r| r.id.as_deref()),
1842            Some("rs_a")
1843        );
1844        assert_eq!(
1845            accumulated.get(1).and_then(|r| r.id.as_deref()),
1846            Some("rs_b")
1847        );
1848    }
1849
1850    #[test]
1851    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
1852        let mut accumulated = vec![crate::message::Reasoning {
1853            id: None,
1854            content: vec![ReasoningContent::Text {
1855                text: "first".to_string(),
1856                signature: None,
1857            }],
1858        }];
1859        let incoming = crate::message::Reasoning {
1860            id: None,
1861            content: vec![ReasoningContent::Text {
1862                text: "second".to_string(),
1863                signature: None,
1864            }],
1865        };
1866
1867        merge_reasoning_blocks(&mut accumulated, &incoming);
1868        assert_eq!(accumulated.len(), 2);
1869        assert!(matches!(
1870            accumulated.first(),
1871            Some(crate::message::Reasoning {
1872                id: None,
1873                content
1874            }) if matches!(
1875                content.first(),
1876                Some(ReasoningContent::Text { text, .. }) if text == "first"
1877            )
1878        ));
1879        assert!(matches!(
1880            accumulated.get(1),
1881            Some(crate::message::Reasoning {
1882                id: None,
1883                content
1884            }) if matches!(
1885                content.first(),
1886                Some(ReasoningContent::Text { text, .. }) if text == "second"
1887            )
1888        ));
1889    }
1890
1891    #[test]
1892    fn tool_result_user_content_preserves_multimodal_tool_output() {
1893        let user_content = tool_result_user_content(
1894            "tool_call_1".to_string(),
1895            Some("call_1".to_string()),
1896            serde_json::json!({
1897                "response": {
1898                    "instruction": "Use the image part to answer."
1899                },
1900                "parts": [
1901                    {
1902                        "type": "image",
1903                        "data": "base64data==",
1904                        "mimeType": "image/png"
1905                    }
1906                ]
1907            })
1908            .to_string(),
1909        );
1910
1911        let tool_result = match user_content {
1912            UserContent::ToolResult(tool_result) => tool_result,
1913            other => panic!("expected tool result content, got {other:?}"),
1914        };
1915
1916        assert_eq!(tool_result.id, "tool_call_1");
1917        assert_eq!(tool_result.call_id.as_deref(), Some("call_1"));
1918        assert_eq!(tool_result.content.len(), 2);
1919
1920        let mut items = tool_result.content.iter();
1921        match items.next() {
1922            Some(ToolResultContent::Text(text)) => {
1923                assert!(text.text.contains("Use the image part to answer."));
1924            }
1925            other => panic!("expected structured text payload first, got {other:?}"),
1926        }
1927
1928        match items.next() {
1929            Some(ToolResultContent::Image(image)) => {
1930                assert_eq!(image.media_type, Some(ImageMediaType::PNG));
1931                assert!(matches!(
1932                    image.data,
1933                    DocumentSourceKind::Base64(ref data) if data == "base64data=="
1934                ));
1935            }
1936            other => panic!("expected image payload second, got {other:?}"),
1937        }
1938    }
1939
1940    fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
1941        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1942        if history.len() != 3 {
1943            return Err(format!(
1944                "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
1945            ));
1946        }
1947
1948        if !matches!(
1949            history.first(),
1950            Some(Message::User { content })
1951                if matches!(
1952                    content.first(),
1953                    UserContent::Text(text) if text.text == "do tool work"
1954                )
1955        ) {
1956            return Err(format!(
1957                "follow-up request should begin with the original user prompt: {history:?}"
1958            ));
1959        }
1960
1961        if !matches!(
1962            history.get(1),
1963            Some(Message::Assistant { content, .. })
1964                if matches!(
1965                    content.first(),
1966                    AssistantContent::ToolCall(tool_call)
1967                        if tool_call.id == "tool_call_1"
1968                            && tool_call.call_id.as_deref() == Some("call_1")
1969                )
1970        ) {
1971            return Err(format!(
1972                "follow-up request is missing the assistant tool call in position 2: {history:?}"
1973            ));
1974        }
1975
1976        if !matches!(
1977            history.get(2),
1978            Some(Message::User { content })
1979                if matches!(
1980                    content.first(),
1981                    UserContent::ToolResult(tool_result)
1982                        if tool_result.id == "tool_call_1"
1983                            && tool_result.call_id.as_deref() == Some("call_1")
1984                )
1985        ) {
1986            return Err(format!(
1987                "follow-up request should end with the user tool result: {history:?}"
1988            ));
1989        }
1990
1991        Ok(())
1992    }
1993
1994    fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1995        history.iter().any(|message| {
1996            matches!(
1997                message,
1998                Message::Assistant { content, .. }
1999                    if content.iter().any(|item| matches!(
2000                        item,
2001                        AssistantContent::ToolCall(tool_call)
2002                            if tool_call.function.name == tool_name
2003                    ))
2004            )
2005        })
2006    }
2007
2008    fn history_contains_text(history: &[Message], expected: &str) -> bool {
2009        history.iter().any(|message| {
2010            matches!(
2011                message,
2012                Message::Assistant { content, .. }
2013                    if content.iter().any(|item| matches!(
2014                        item,
2015                        AssistantContent::Text(text) if text.text == expected
2016                    ))
2017            )
2018        })
2019    }
2020
2021    fn assistant_reasoning_precedes_tool_call(
2022        history: &[Message],
2023        expected_reasoning: &str,
2024        tool_name: &str,
2025    ) -> bool {
2026        history.iter().any(|message| {
2027            let Message::Assistant { content, .. } = message else {
2028                return false;
2029            };
2030
2031            let reasoning_index = content.iter().position(|item| {
2032                matches!(
2033                    item,
2034                    AssistantContent::Reasoning(reasoning)
2035                        if reasoning.content.iter().any(|content| matches!(
2036                            content,
2037                            ReasoningContent::Text { text, .. }
2038                                if text == expected_reasoning
2039                        ))
2040                )
2041            });
2042            let tool_index = content.iter().position(|item| {
2043                matches!(
2044                    item,
2045                    AssistantContent::ToolCall(tool_call)
2046                        if tool_call.function.name == tool_name
2047                )
2048            });
2049
2050            matches!((reasoning_index, tool_index), (Some(reasoning), Some(tool)) if reasoning < tool)
2051        })
2052    }
2053
    /// Hook whose tool-call-delta, tool-call, and stream-finish callbacks all
    /// panic. Judging by its panic messages, it checks that an unknown tool
    /// call fails the stream before any of these hooks run.
    #[derive(Clone)]
    struct PanicOnUnknownToolHook;

    impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
        async fn on_tool_call_delta(
            &self,
            _tool_call_id: &str,
            _internal_call_id: &str,
            _tool_name: Option<&str>,
            _tool_call_delta: &str,
        ) -> HookAction {
            panic!("unknown tool call delta should fail before delta hooks run")
        }

        async fn on_tool_call(
            &self,
            _tool_name: &str,
            _tool_call_id: Option<String>,
            _internal_call_id: &str,
            _args: &str,
        ) -> ToolCallHookAction {
            panic!("unknown tool call should fail before tool hooks run")
        }

        async fn on_stream_completion_response_finish(
            &self,
            _prompt: &Message,
            _response: &MockResponse,
        ) -> HookAction {
            panic!("unknown tool call should fail before stream finish hooks run")
        }
    }
2086
    /// `add` tool fixture that counts its invocations in a shared counter.
    #[derive(Clone)]
    struct CountingAddTool {
        /// Incremented once per `call`.
        calls: Arc<AtomicU32>,
    }

    /// `subtract` tool fixture that counts its invocations in a shared counter.
    #[derive(Clone)]
    struct CountingSubtractTool {
        /// Incremented once per `call`.
        calls: Arc<AtomicU32>,
    }

    /// Integer operands shared by the counting arithmetic tools.
    #[derive(Deserialize)]
    struct CountingOperationArgs {
        x: i32,
        y: i32,
    }
2102
2103    fn arithmetic_tool_definition(name: &str, description: &str) -> ToolDefinition {
2104        ToolDefinition {
2105            name: name.to_string(),
2106            description: description.to_string(),
2107            parameters: serde_json::json!({
2108                "type": "object",
2109                "properties": {
2110                    "x": {
2111                        "type": "number",
2112                        "description": "The first operand"
2113                    },
2114                    "y": {
2115                        "type": "number",
2116                        "description": "The second operand"
2117                    }
2118                },
2119                "required": ["x", "y"],
2120            }),
2121        }
2122    }
2123
2124    impl Tool for CountingAddTool {
2125        const NAME: &'static str = "add";
2126        type Error = MockToolError;
2127        type Args = CountingOperationArgs;
2128        type Output = i32;
2129
2130        async fn definition(&self, _prompt: String) -> ToolDefinition {
2131            arithmetic_tool_definition(Self::NAME, "Add x and y together")
2132        }
2133
2134        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
2135            self.calls.fetch_add(1, Ordering::SeqCst);
2136            Ok(args.x + args.y)
2137        }
2138    }
2139
2140    impl Tool for CountingSubtractTool {
2141        const NAME: &'static str = "subtract";
2142        type Error = MockToolError;
2143        type Args = CountingOperationArgs;
2144        type Output = i32;
2145
2146        async fn definition(&self, _prompt: String) -> ToolDefinition {
2147            arithmetic_tool_definition(Self::NAME, "Subtract y from x")
2148        }
2149
2150        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
2151            self.calls.fetch_add(1, Ordering::SeqCst);
2152            Ok(args.x - args.y)
2153        }
2154    }
2155
2156    fn streaming_tool_then_text_model() -> MockCompletionModel {
2157        MockCompletionModel::from_stream_turns([
2158            vec![
2159                MockStreamEvent::tool_call(
2160                    "tool_call_1",
2161                    "add",
2162                    serde_json::json!({"x": 1, "y": 2}),
2163                )
2164                .with_call_id("call_1"),
2165                MockStreamEvent::final_response_with_total_tokens(4),
2166            ],
2167            vec![
2168                MockStreamEvent::text("done"),
2169                MockStreamEvent::final_response_with_total_tokens(6),
2170            ],
2171        ])
2172    }
2173
2174    fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
2175        Usage {
2176            input_tokens,
2177            output_tokens,
2178            total_tokens: input_tokens + output_tokens,
2179            cached_input_tokens: 0,
2180            cache_creation_input_tokens: 0,
2181            tool_use_prompt_tokens: 0,
2182            reasoning_tokens: 0,
2183        }
2184    }
2185
2186    #[derive(Clone, Debug, Default)]
2187    struct CapturedSpan {
2188        id: u64,
2189        name: String,
2190        parent_id: Option<u64>,
2191        fields: HashMap<String, u64>,
2192    }
2193
2194    #[derive(Clone, Default)]
2195    struct CapturedSpans(Arc<Mutex<Vec<CapturedSpan>>>);
2196
2197    impl CapturedSpans {
2198        fn insert(&self, id: &Id, name: &str, parent_id: Option<u64>) {
2199            let id = id.into_u64();
2200            if let Ok(mut spans) = self.0.lock() {
2201                spans.push(CapturedSpan {
2202                    id,
2203                    name: name.to_string(),
2204                    parent_id,
2205                    fields: HashMap::new(),
2206                });
2207            }
2208        }
2209
2210        fn record(&self, id: &Id, fields: Vec<(String, u64)>) {
2211            if let Ok(mut spans) = self.0.lock()
2212                && let Some(span) = spans.iter_mut().rev().find(|span| span.id == id.into_u64())
2213            {
2214                span.fields.extend(fields);
2215            }
2216        }
2217
2218        fn snapshot(&self) -> Vec<CapturedSpan> {
2219            self.0.lock().map(|spans| spans.clone()).unwrap_or_default()
2220        }
2221    }
2222
2223    struct SpanCaptureLayer {
2224        spans: CapturedSpans,
2225    }
2226
2227    impl<S> Layer<S> for SpanCaptureLayer
2228    where
2229        S: Subscriber,
2230        S: for<'lookup> LookupSpan<'lookup>,
2231    {
2232        fn on_new_span(&self, attrs: &tracing::span::Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
2233            let parent_id = attrs
2234                .parent()
2235                .map(Id::into_u64)
2236                .or_else(|| ctx.current_span().id().map(Id::into_u64));
2237            self.spans.insert(id, attrs.metadata().name(), parent_id);
2238        }
2239
2240        fn on_record(&self, span: &Id, values: &tracing::span::Record<'_>, _ctx: Context<'_, S>) {
2241            let mut fields = Vec::new();
2242            values.record(&mut SpanFieldCaptureVisitor {
2243                fields: &mut fields,
2244            });
2245            self.spans.record(span, fields);
2246        }
2247    }
2248
2249    struct SpanFieldCaptureVisitor<'a> {
2250        fields: &'a mut Vec<(String, u64)>,
2251    }
2252
2253    impl Visit for SpanFieldCaptureVisitor<'_> {
2254        fn record_u64(&mut self, field: &Field, value: u64) {
2255            self.fields.push((field.name().to_string(), value));
2256        }
2257
2258        fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {}
2259    }
2260
    /// Streams `prompt` through `agent` (empty history, `max_turns` turns)
    /// inside an `outer` span. Using `SpanCaptureLayer`, it then asserts that:
    /// - one `chat_streaming` span is captured per entry of `expected_usages`,
    ///   each parented to `outer` and carrying that entry's `gen_ai.usage.*`
    ///   values;
    /// - no `invoke_agent` span is created;
    /// - no `gen_ai.usage.*` field is recorded on `outer` itself.
    ///
    /// The stream is consumed until the first `FinalResponse` item.
    ///
    /// # Panics
    /// Panics on any stream error or failed assertion.
    async fn assert_stream_usage_recorded_on_chat_spans(
        agent: crate::agent::Agent<MockCompletionModel>,
        prompt: &str,
        max_turns: usize,
        expected_usages: &[Usage],
    ) {
        let spans = CapturedSpans::default();
        let subscriber = Registry::default().with(SpanCaptureLayer {
            spans: spans.clone(),
        });
        // Scoped default subscriber; the guard is held until this function returns.
        let _default = tracing::subscriber::set_default(subscriber);
        let empty_history: &[Message] = &[];
        let outer_span = tracing::info_span!("outer");

        // Drive the stream with `outer` as the current span, so spans created
        // by the agent can be checked for `outer` as their parent.
        async {
            let mut stream = agent
                .stream_prompt(prompt)
                .with_history(empty_history)
                .multi_turn(max_turns)
                .await;

            while let Some(item) = stream.try_next().await.expect("stream should not error") {
                if matches!(item, MultiTurnStreamItem::FinalResponse(_)) {
                    break;
                }
            }
        }
        .instrument(outer_span)
        .await;

        let span_snapshot = spans.snapshot();
        let outer_span_id = span_snapshot
            .iter()
            .find(|span| span.name == "outer")
            .map(|span| span.id)
            .expect("outer span should be captured");
        let chat_spans = span_snapshot
            .iter()
            .filter(|span| span.name == "chat_streaming")
            .collect::<Vec<_>>();

        assert_eq!(chat_spans.len(), expected_usages.len());
        assert!(
            span_snapshot.iter().all(|span| span.name != "invoke_agent"),
            "outer span path should not create invoke_agent"
        );

        // Spans are captured in creation order, so the i-th chat span is paired
        // with the i-th expected usage.
        for (chat_span, expected_usage) in chat_spans.into_iter().zip(expected_usages) {
            assert_eq!(chat_span.parent_id, Some(outer_span_id));
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.input_tokens"),
                Some(&expected_usage.input_tokens)
            );
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.output_tokens"),
                Some(&expected_usage.output_tokens)
            );
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.cache_read.input_tokens"),
                Some(&expected_usage.cached_input_tokens)
            );
            assert_eq!(
                chat_span
                    .fields
                    .get("gen_ai.usage.cache_creation.input_tokens"),
                Some(&expected_usage.cache_creation_input_tokens)
            );
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.tool_use_prompt_tokens"),
                Some(&expected_usage.tool_use_prompt_tokens)
            );
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.reasoning_tokens"),
                Some(&expected_usage.reasoning_tokens)
            );
        }

        // Usage must stay on the per-call chat spans, never on the caller's span.
        let outer_span = span_snapshot
            .iter()
            .find(|span| span.id == outer_span_id)
            .expect("outer span should be present");
        assert!(
            outer_span
                .fields
                .keys()
                .all(|field| !field.starts_with("gen_ai.usage.")),
            "usage should not be recorded onto the caller's outer span"
        );
    }
2350
    /// `MultiTurnStreamItem::CompletionCall` serializes as a
    /// `"completionCall"`-tagged object with snake_case `call_index` and
    /// `usage` fields. It round-trips through serde, and missing usage is
    /// encoded as `null`.
    #[test]
    fn completion_calls_stream_item_serializes_and_deserializes_expected_shape() {
        let item: MultiTurnStreamItem<MockResponse> =
            MultiTurnStreamItem::CompletionCall(CompletionCall::new(2, Some(usage(3, 4))));

        let value = serde_json::to_value(&item).expect("serialize completion call event");

        assert_eq!(
            value,
            serde_json::json!({
                "type": "completionCall",
                "call_index": 2,
                "usage": {
                    "input_tokens": 3,
                    "output_tokens": 4,
                    "total_tokens": 7,
                    "cached_input_tokens": 0,
                    "cache_creation_input_tokens": 0,
                    "tool_use_prompt_tokens": 0,
                    "reasoning_tokens": 0,
                }
            })
        );

        // Round-trip: the serialized value deserializes back to the same call.
        let item: MultiTurnStreamItem<MockResponse> =
            serde_json::from_value(value).expect("deserialize completion call event");
        match item {
            MultiTurnStreamItem::CompletionCall(call_usage) => {
                assert_eq!(call_usage, CompletionCall::new(2, Some(usage(3, 4))));
            }
            other => panic!("expected completion call event, got {other:?}"),
        }

        // A call without reported usage keeps the `usage` key, set to `null`.
        let item: MultiTurnStreamItem<MockResponse> =
            MultiTurnStreamItem::CompletionCall(CompletionCall::new(3, None));
        let value = serde_json::to_value(&item).expect("serialize missing usage event");

        assert_eq!(
            value,
            serde_json::json!({
                "type": "completionCall",
                "call_index": 3,
                "usage": null
            })
        );
    }
2397
    /// A final response built with per-call usage serializes the calls under
    /// `completionCalls`, using `null` for calls whose usage was not reported.
    #[test]
    fn final_response_serializes_completion_calls_with_missing_usage() {
        let item: MultiTurnStreamItem<MockResponse> =
            MultiTurnStreamItem::final_response_with_completion_calls(
                OneOrMany::one(AssistantContent::text("done")),
                usage(3, 4),
                vec![
                    CompletionCall::new(0, None),
                    CompletionCall::new(1, Some(usage(3, 4))),
                ],
                None,
            );

        let value = serde_json::to_value(&item).expect("serialize final response");

        assert_eq!(
            value.get("completionCalls"),
            Some(&serde_json::json!([
                {
                    "call_index": 0,
                    "usage": null,
                },
                {
                    "call_index": 1,
                    "usage": {
                        "input_tokens": 3,
                        "output_tokens": 4,
                        "total_tokens": 7,
                        "cached_input_tokens": 0,
                        "cache_creation_input_tokens": 0,
                        "tool_use_prompt_tokens": 0,
                        "reasoning_tokens": 0,
                    }
                }
            ]))
        );
    }
2435
2436    fn streaming_text_then_final_model() -> MockCompletionModel {
2437        MockCompletionModel::from_stream_turns([[
2438            MockStreamEvent::text("hello"),
2439            MockStreamEvent::text(" world"),
2440            MockStreamEvent::final_response_with_total_tokens(3),
2441        ]])
2442    }
2443
2444    fn citation_metadata() -> serde_json::Value {
2445        serde_json::json!({
2446            "citations": [{
2447                "type": "web_search_result_location",
2448                "cited_text": "Claude Shannon was born in 1916.",
2449                "url": "https://example.com/shannon",
2450                "title": "Claude Shannon",
2451                "encrypted_index": "encrypted-reference"
2452            }]
2453        })
2454    }
2455
2456    fn streaming_cited_text_then_final_model() -> MockCompletionModel {
2457        MockCompletionModel::from_stream_turns([[
2458            MockStreamEvent::text_start(Some(citation_metadata())),
2459            MockStreamEvent::text("cited "),
2460            MockStreamEvent::text_start(None),
2461            MockStreamEvent::text("answer"),
2462            MockStreamEvent::final_response_with_total_tokens(3),
2463        ]])
2464    }
2465
2466    fn streaming_cited_text_then_tool_model() -> MockCompletionModel {
2467        MockCompletionModel::from_stream_turns([
2468            vec![
2469                MockStreamEvent::text_start(Some(citation_metadata())),
2470                MockStreamEvent::text("I need a tool. "),
2471                MockStreamEvent::tool_call(
2472                    "tool_call_1",
2473                    "add",
2474                    serde_json::json!({"x": 1, "y": 2}),
2475                )
2476                .with_call_id("call_1"),
2477                MockStreamEvent::final_response_with_total_tokens(4),
2478            ],
2479            vec![
2480                MockStreamEvent::text("done"),
2481                MockStreamEvent::final_response_with_total_tokens(6),
2482            ],
2483        ])
2484    }
2485
2486    fn streaming_final_only_model() -> MockCompletionModel {
2487        MockCompletionModel::from_stream_turns([[
2488            MockStreamEvent::final_response_with_total_tokens(1),
2489        ]])
2490    }
2491
    /// Hook that returns `HookAction::terminate("stop after completion call")`
    /// from `on_stream_completion_response_finish`. In other words, it asks the
    /// stream to stop as soon as a streamed completion response finishes.
    #[derive(Clone)]
    struct TerminateOnStreamFinish;

    impl PromptHook<MockCompletionModel> for TerminateOnStreamFinish {
        async fn on_stream_completion_response_finish(
            &self,
            _prompt: &Message,
            _response: &<MockCompletionModel as CompletionModel>::StreamingResponse,
        ) -> HookAction {
            HookAction::terminate("stop after completion call")
        }
    }
2504
    /// One observed tool-call delta:
    /// `(tool_call_id, internal_call_id, tool_name, tool_call_delta)`.
    type RecordedToolCallDelta = (String, String, Option<String>, String);
2506
    /// Invalid-tool-call hook that asserts the offending tool name is
    /// `default_api` and asks for the call to be repaired into an `add` call.
    #[derive(Clone)]
    struct RepairDefaultApiHook;

    impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
        fn on_invalid_tool_call(
            &self,
            context: &InvalidToolCallContext,
        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
            // Copy what the assertion needs so the returned future does not
            // borrow `context`.
            let tool_name = context.tool_name.clone();
            async move {
                assert_eq!(tool_name, "default_api");
                InvalidToolCallHookAction::repair("add")
            }
        }
    }
2522
    /// Invalid-tool-call hook that asserts the offending tool name is
    /// `default_api` and that its arguments, when present, are non-empty. It
    /// then requests a retry with the hint `"Use the add tool instead"`.
    #[derive(Clone)]
    struct RetryDefaultApiHook;

    impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
        fn on_invalid_tool_call(
            &self,
            context: &InvalidToolCallContext,
        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
            // Copy what the assertions need so the returned future does not
            // borrow `context`.
            let tool_name = context.tool_name.clone();
            let args = context.args.clone();
            async move {
                assert_eq!(tool_name, "default_api");
                if let Some(args) = args {
                    assert!(!args.is_empty());
                }
                InvalidToolCallHookAction::retry("Use the add tool instead")
            }
        }
    }
2542
    /// Invalid-tool-call hook that asserts the offending tool name is
    /// `default_api`, then skips the call with the reason
    /// `"default_api was skipped"`.
    #[derive(Clone)]
    struct SkipDefaultApiHook;

    impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
        fn on_invalid_tool_call(
            &self,
            context: &InvalidToolCallContext,
        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
            // Copy what the assertion needs so the returned future does not
            // borrow `context`.
            let tool_name = context.tool_name.clone();
            async move {
                assert_eq!(tool_name, "default_api");
                InvalidToolCallHookAction::skip("default_api was skipped")
            }
        }
    }
2558
    /// Invalid-tool-call hook that records every `InvalidToolCallContext` it
    /// receives and then fails the call with `InvalidToolCallHookAction::fail()`.
    #[derive(Clone, Default)]
    struct RecordingInvalidToolCallHook {
        contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
    }

    impl RecordingInvalidToolCallHook {
        /// Returns a copy of the contexts recorded so far, in arrival order.
        ///
        /// # Panics
        /// Panics if the recording mutex is poisoned.
        fn observed(&self) -> Vec<InvalidToolCallContext> {
            self.contexts
                .lock()
                .expect("invalid tool context records mutex was poisoned")
                .clone()
        }
    }

    impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
        fn on_invalid_tool_call(
            &self,
            context: &InvalidToolCallContext,
        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
            // Cloned up front so the returned future owns its data instead of
            // borrowing `self` or `context`.
            let contexts = self.contexts.clone();
            let context = context.clone();

            async move {
                contexts
                    .lock()
                    .expect("invalid tool context records mutex was poisoned")
                    .push(context);
                InvalidToolCallHookAction::fail()
            }
        }
    }
2590
    /// Hook that records every tool-call delta as a `RecordedToolCallDelta`
    /// and always continues with `HookAction::cont()`.
    #[derive(Clone, Default)]
    struct RecordingToolCallDeltaHook {
        deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
    }

    impl RecordingToolCallDeltaHook {
        /// Returns a copy of the deltas recorded so far, in arrival order.
        ///
        /// # Panics
        /// Panics if the recording mutex is poisoned.
        fn observed(&self) -> Vec<RecordedToolCallDelta> {
            self.deltas
                .lock()
                .expect("tool call delta hook records mutex was poisoned")
                .clone()
        }
    }

    impl PromptHook<MockCompletionModel> for RecordingToolCallDeltaHook {
        fn on_tool_call_delta(
            &self,
            tool_call_id: &str,
            internal_call_id: &str,
            tool_name: Option<&str>,
            tool_call_delta: &str,
        ) -> impl Future<Output = HookAction> + Send {
            // Build the owned record eagerly so the returned future does not
            // borrow the `&str` arguments.
            let deltas = self.deltas.clone();
            let event = (
                tool_call_id.to_string(),
                internal_call_id.to_string(),
                tool_name.map(str::to_string),
                tool_call_delta.to_string(),
            );

            async move {
                deltas
                    .lock()
                    .expect("tool call delta hook records mutex was poisoned")
                    .push(event);
                HookAction::cont()
            }
        }
    }
2630
    /// Hook that records every `(text_delta, full_text)` pair it is given and
    /// always continues with `HookAction::cont()`.
    #[derive(Clone, Default)]
    struct RecordingTextDeltaHook {
        deltas: Arc<Mutex<Vec<(String, String)>>>,
    }

    impl RecordingTextDeltaHook {
        /// Returns a copy of the `(text_delta, full_text)` pairs recorded so
        /// far, in arrival order.
        ///
        /// # Panics
        /// Panics if the recording mutex is poisoned.
        fn observed(&self) -> Vec<(String, String)> {
            self.deltas
                .lock()
                .expect("text delta hook records mutex was poisoned")
                .clone()
        }
    }

    impl PromptHook<MockCompletionModel> for RecordingTextDeltaHook {
        fn on_text_delta(
            &self,
            text_delta: &str,
            full_text: &str,
        ) -> impl Future<Output = HookAction> + Send {
            // Owned copies so the returned future does not borrow the arguments.
            let deltas = self.deltas.clone();
            let event = (text_delta.to_string(), full_text.to_string());

            async move {
                deltas
                    .lock()
                    .expect("text delta hook records mutex was poisoned")
                    .push(event);
                HookAction::cont()
            }
        }
    }
2663
    /// Composite hook: records text deltas via the wrapped
    /// `RecordingTextDeltaHook` and handles invalid tool calls like
    /// `SkipDefaultApiHook`.
    #[derive(Clone)]
    struct RecordingTextAndSkipInvalidToolHook {
        text: RecordingTextDeltaHook,
    }

    impl PromptHook<MockCompletionModel> for RecordingTextAndSkipInvalidToolHook {
        fn on_text_delta(
            &self,
            text_delta: &str,
            full_text: &str,
        ) -> impl Future<Output = HookAction> + Send {
            self.text.on_text_delta(text_delta, full_text)
        }

        fn on_invalid_tool_call(
            &self,
            context: &InvalidToolCallContext,
        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
            SkipDefaultApiHook.on_invalid_tool_call(context)
        }
    }
2685
    /// Composite hook: records text deltas via the wrapped
    /// `RecordingTextDeltaHook` and handles invalid tool calls like
    /// `RetryDefaultApiHook`.
    #[derive(Clone)]
    struct RecordingTextAndRetryInvalidToolHook {
        text: RecordingTextDeltaHook,
    }

    impl PromptHook<MockCompletionModel> for RecordingTextAndRetryInvalidToolHook {
        fn on_text_delta(
            &self,
            text_delta: &str,
            full_text: &str,
        ) -> impl Future<Output = HookAction> + Send {
            self.text.on_text_delta(text_delta, full_text)
        }

        fn on_invalid_tool_call(
            &self,
            context: &InvalidToolCallContext,
        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
            RetryDefaultApiHook.on_invalid_tool_call(context)
        }
    }
2707
    /// Composite hook: records tool-call deltas via the wrapped
    /// `RecordingToolCallDeltaHook` and handles invalid tool calls like
    /// `RetryDefaultApiHook`.
    #[derive(Clone)]
    struct RecordingDeltaAndRetryInvalidToolHook {
        delta: RecordingToolCallDeltaHook,
    }

    impl PromptHook<MockCompletionModel> for RecordingDeltaAndRetryInvalidToolHook {
        fn on_tool_call_delta(
            &self,
            tool_call_id: &str,
            internal_call_id: &str,
            tool_name: Option<&str>,
            tool_call_delta: &str,
        ) -> impl Future<Output = HookAction> + Send {
            self.delta.on_tool_call_delta(
                tool_call_id,
                internal_call_id,
                tool_name,
                tool_call_delta,
            )
        }

        fn on_invalid_tool_call(
            &self,
            context: &InvalidToolCallContext,
        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
            RetryDefaultApiHook.on_invalid_tool_call(context)
        }
    }
2736
    /// Composite hook: records tool-call deltas via the wrapped
    /// `RecordingToolCallDeltaHook` and handles invalid tool calls like
    /// `SkipDefaultApiHook`.
    #[derive(Clone)]
    struct RecordingDeltaAndSkipInvalidToolHook {
        delta: RecordingToolCallDeltaHook,
    }

    impl PromptHook<MockCompletionModel> for RecordingDeltaAndSkipInvalidToolHook {
        fn on_tool_call_delta(
            &self,
            tool_call_id: &str,
            internal_call_id: &str,
            tool_name: Option<&str>,
            tool_call_delta: &str,
        ) -> impl Future<Output = HookAction> + Send {
            self.delta.on_tool_call_delta(
                tool_call_id,
                internal_call_id,
                tool_name,
                tool_call_delta,
            )
        }

        fn on_invalid_tool_call(
            &self,
            context: &InvalidToolCallContext,
        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
            SkipDefaultApiHook.on_invalid_tool_call(context)
        }
    }
2765
    /// Hook that records each tool-call delta as a `RecordedToolCallDelta`
    /// and then returns `HookAction::terminate("stop on tool call delta")`.
    /// The first delta it sees therefore asks the stream to stop.
    #[derive(Clone, Default)]
    struct TerminatingToolCallDeltaHook {
        deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
    }

    impl TerminatingToolCallDeltaHook {
        /// Returns a copy of the deltas recorded so far, in arrival order.
        ///
        /// # Panics
        /// Panics if the recording mutex is poisoned.
        fn observed(&self) -> Vec<RecordedToolCallDelta> {
            self.deltas
                .lock()
                .expect("tool call delta hook records mutex was poisoned")
                .clone()
        }
    }

    impl PromptHook<MockCompletionModel> for TerminatingToolCallDeltaHook {
        fn on_tool_call_delta(
            &self,
            tool_call_id: &str,
            internal_call_id: &str,
            tool_name: Option<&str>,
            tool_call_delta: &str,
        ) -> impl Future<Output = HookAction> + Send {
            // Build the owned record eagerly so the returned future does not
            // borrow the `&str` arguments.
            let deltas = self.deltas.clone();
            let event = (
                tool_call_id.to_string(),
                internal_call_id.to_string(),
                tool_name.map(str::to_string),
                tool_call_delta.to_string(),
            );

            async move {
                deltas
                    .lock()
                    .expect("tool call delta hook records mutex was poisoned")
                    .push(event);
                HookAction::terminate("stop on tool call delta")
            }
        }
    }
2805
2806    fn text_metadata(content: &OneOrMany<AssistantContent>) -> Option<&serde_json::Value> {
2807        content.iter().find_map(|item| match item {
2808            AssistantContent::Text(text) => text.additional_params.as_ref(),
2809            _ => None,
2810        })
2811    }
2812
2813    #[tokio::test]
2814    async fn stream_prompt_continues_after_tool_call_turn() {
2815        let model = streaming_tool_then_text_model();
2816        let recorded = model.clone();
2817        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2818        let empty_history: &[Message] = &[];
2819
2820        let mut stream = agent
2821            .stream_prompt("do tool work")
2822            .with_history(empty_history)
2823            .multi_turn(3)
2824            .await;
2825        let mut saw_tool_call = false;
2826        let mut saw_tool_result = false;
2827        let mut saw_final_response = false;
2828        let mut final_text = String::new();
2829        let mut final_response_text = None;
2830        let mut final_history = None;
2831
2832        while let Some(item) = stream.next().await {
2833            match item {
2834                Ok(MultiTurnStreamItem::StreamAssistantItem(
2835                    StreamedAssistantContent::ToolCall { .. },
2836                )) => {
2837                    saw_tool_call = true;
2838                }
2839                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2840                    ..
2841                })) => {
2842                    saw_tool_result = true;
2843                }
2844                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
2845                    text,
2846                ))) => {
2847                    final_text.push_str(&text.text);
2848                }
2849                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
2850                    saw_final_response = true;
2851                    final_response_text = Some(res.response().to_owned());
2852                    final_history = res.history().map(|history| history.to_vec());
2853                    break;
2854                }
2855                Ok(_) => {}
2856                Err(err) => panic!("unexpected streaming error: {err:?}"),
2857            }
2858        }
2859
2860        assert!(saw_tool_call);
2861        assert!(saw_tool_result);
2862        assert!(saw_final_response);
2863        assert_eq!(final_text, "done");
2864        assert_eq!(final_response_text.as_deref(), Some("done"));
2865        let history = final_history.expect("expected final response history");
2866        assert!(history.iter().any(|message| matches!(
2867            message,
2868            Message::Assistant { content, .. }
2869                if content.iter().any(|item| matches!(
2870                    item,
2871                    AssistantContent::Text(text) if text.text == "done"
2872                ))
2873        )));
2874        let requests = recorded.requests();
2875        assert_eq!(requests.len(), 2);
2876        assert!(validate_follow_up_tool_history(&requests[1]).is_ok());
2877    }
2878
2879    #[tokio::test]
2880    async fn unknown_tool_call_fails_before_streaming_second_request() {
2881        let model = MockCompletionModel::from_stream_turns([
2882            vec![
2883                MockStreamEvent::tool_call(
2884                    "tool_call_1",
2885                    "default_api",
2886                    serde_json::json!({"x": 1, "y": 2}),
2887                ),
2888                MockStreamEvent::final_response_with_total_tokens(4),
2889            ],
2890            vec![
2891                MockStreamEvent::text("should not be requested"),
2892                MockStreamEvent::final_response_with_total_tokens(6),
2893            ],
2894        ]);
2895        let recorded = model.clone();
2896        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2897
2898        let mut stream = agent
2899            .stream_prompt("use the tool")
2900            .with_hook(PanicOnUnknownToolHook)
2901            .multi_turn(3)
2902            .await;
2903        let mut saw_tool_call = false;
2904        let mut error = None;
2905
2906        while let Some(item) = stream.next().await {
2907            match item {
2908                Ok(MultiTurnStreamItem::StreamAssistantItem(
2909                    StreamedAssistantContent::ToolCall { .. },
2910                )) => {
2911                    saw_tool_call = true;
2912                }
2913                Ok(_) => {}
2914                Err(err) => {
2915                    error = Some(err);
2916                    break;
2917                }
2918            }
2919        }
2920
2921        assert!(!saw_tool_call);
2922        let error = error.expect("unknown model-emitted tool should fail");
2923        match error {
2924            StreamingError::Prompt(err) => match *err {
2925                PromptError::UnknownToolCall {
2926                    tool_name,
2927                    available_tools,
2928                    allowed_tools,
2929                    chat_history,
2930                } => {
2931                    assert_eq!(tool_name, "default_api");
2932                    assert_eq!(available_tools, vec!["add".to_string()]);
2933                    assert_eq!(allowed_tools, vec!["add".to_string()]);
2934                    assert!(history_contains_tool_call(&chat_history, "default_api"));
2935                }
2936                other => panic!("expected UnknownToolCall, got {other:?}"),
2937            },
2938            other => panic!("expected prompt streaming error, got {other:?}"),
2939        }
2940        assert_eq!(recorded.request_count(), 1);
2941    }
2942
2943    #[tokio::test]
2944    async fn invalid_tool_call_hook_can_repair_streaming_tool_name() {
2945        let model = MockCompletionModel::from_stream_turns([
2946            vec![
2947                MockStreamEvent::tool_call(
2948                    "tool_call_1",
2949                    "default_api",
2950                    serde_json::json!({"x": 2, "y": 3}),
2951                ),
2952                MockStreamEvent::final_response_with_total_tokens(4),
2953            ],
2954            vec![
2955                MockStreamEvent::text("done"),
2956                MockStreamEvent::final_response_with_total_tokens(6),
2957            ],
2958        ]);
2959        let recorded = model.clone();
2960        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2961
2962        let mut stream = agent
2963            .stream_prompt("use the tool")
2964            .with_hook(RepairDefaultApiHook)
2965            .multi_turn(3)
2966            .with_history(Vec::<Message>::new())
2967            .await;
2968        let mut saw_repaired_tool_call = false;
2969        let mut saw_tool_result = false;
2970        let mut final_response_text = None;
2971
2972        while let Some(item) = stream.next().await {
2973            match item {
2974                Ok(MultiTurnStreamItem::StreamAssistantItem(
2975                    StreamedAssistantContent::ToolCall { tool_call, .. },
2976                )) => {
2977                    assert_eq!(tool_call.function.name, "add");
2978                    saw_repaired_tool_call = true;
2979                }
2980                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2981                    tool_result,
2982                    ..
2983                })) => {
2984                    assert!(tool_result.content.iter().any(|content| {
2985                        matches!(
2986                            content,
2987                            ToolResultContent::Text(text) if text.text == "5"
2988                        )
2989                    }));
2990                    saw_tool_result = true;
2991                }
2992                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2993                    final_response_text = Some(response.response().to_string());
2994                    break;
2995                }
2996                Ok(_) => {}
2997                Err(err) => panic!("unexpected streaming error: {err:?}"),
2998            }
2999        }
3000
3001        assert!(saw_repaired_tool_call);
3002        assert!(saw_tool_result);
3003        assert_eq!(final_response_text.as_deref(), Some("done"));
3004        assert_eq!(recorded.request_count(), 2);
3005    }
3006
3007    #[tokio::test]
3008    async fn invalid_tool_call_context_uses_completed_streaming_tool_call_provider_id() {
3009        let invalid_hook = RecordingInvalidToolCallHook::default();
3010        let model = MockCompletionModel::from_stream_turns([
3011            vec![
3012                MockStreamEvent::tool_call(
3013                    "tool_call_1",
3014                    "default_api",
3015                    serde_json::json!({"x": 2, "y": 3}),
3016                )
3017                .with_call_id("provider_call_1"),
3018                MockStreamEvent::final_response_with_total_tokens(4),
3019            ],
3020            vec![
3021                MockStreamEvent::text("should not be requested"),
3022                MockStreamEvent::final_response_with_total_tokens(6),
3023            ],
3024        ]);
3025        let recorded = model.clone();
3026        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3027
3028        let mut stream = agent
3029            .stream_prompt("use the tool")
3030            .with_hook(invalid_hook.clone())
3031            .multi_turn(3)
3032            .await;
3033        let mut error = None;
3034
3035        while let Some(item) = stream.next().await {
3036            if let Err(err) = item {
3037                error = Some(err);
3038                break;
3039            }
3040        }
3041
3042        assert!(error.is_some(), "invalid tool should fail");
3043        assert_eq!(recorded.request_count(), 1);
3044        let contexts = invalid_hook.observed();
3045        assert_eq!(contexts.len(), 1);
3046        let context = &contexts[0];
3047        assert_eq!(context.tool_name, "default_api");
3048        assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
3049        assert!(context.internal_call_id.is_some());
3050        assert!(context.is_streaming);
3051    }
3052
3053    #[tokio::test]
3054    async fn invalid_tool_call_hook_skip_emits_streaming_tool_result() {
3055        let add_calls = Arc::new(AtomicU32::new(0));
3056        let model = MockCompletionModel::from_stream_turns([
3057            vec![
3058                MockStreamEvent::tool_call(
3059                    "tool_call_1",
3060                    "default_api",
3061                    serde_json::json!({"x": 2, "y": 3}),
3062                )
3063                .with_call_id("call_1"),
3064                MockStreamEvent::final_response_with_total_tokens(4),
3065            ],
3066            vec![
3067                MockStreamEvent::text("continued"),
3068                MockStreamEvent::final_response_with_total_tokens(6),
3069            ],
3070        ]);
3071        let recorded = model.clone();
3072        let agent = AgentBuilder::new(model)
3073            .tool(CountingAddTool {
3074                calls: add_calls.clone(),
3075            })
3076            .build();
3077
3078        let mut stream = agent
3079            .stream_prompt("use the tool")
3080            .with_hook(SkipDefaultApiHook)
3081            .multi_turn(3)
3082            .with_history(Vec::<Message>::new())
3083            .await;
3084        let mut skipped_tool_result = None;
3085        let mut final_response_text = None;
3086
3087        while let Some(item) = stream.next().await {
3088            match item {
3089                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3090                    tool_result,
3091                    internal_call_id,
3092                })) => {
3093                    assert!(!internal_call_id.is_empty());
3094                    skipped_tool_result = Some(tool_result);
3095                }
3096                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3097                    final_response_text = Some(response.response().to_string());
3098                    break;
3099                }
3100                Ok(_) => {}
3101                Err(err) => panic!("unexpected streaming error: {err:?}"),
3102            }
3103        }
3104
3105        let skipped_tool_result =
3106            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3107        assert_eq!(skipped_tool_result.id, "tool_call_1");
3108        assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_1"));
3109        assert!(skipped_tool_result.content.iter().any(|content| matches!(
3110            content,
3111            ToolResultContent::Text(text) if text.text == "default_api was skipped"
3112        )));
3113        assert_eq!(final_response_text.as_deref(), Some("continued"));
3114        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3115
3116        let requests = recorded.requests();
3117        assert_eq!(requests.len(), 2);
3118        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3119        assert!(matches!(
3120            follow_up_history.get(2),
3121            Some(Message::User { content })
3122                if content.iter().any(|item| matches!(
3123                    item,
3124                    UserContent::ToolResult(result)
3125                        if result.id == "tool_call_1"
3126                            && result.content.iter().any(|content| matches!(
3127                                content,
3128                                ToolResultContent::Text(text)
3129                                    if text.text == "default_api was skipped"
3130                            ))
3131                ))
3132        ));
3133    }
3134
3135    #[tokio::test]
3136    async fn invalid_tool_call_hook_retries_mixed_streaming_turn_without_executing_valid_call() {
3137        let add_calls = Arc::new(AtomicU32::new(0));
3138        let model = MockCompletionModel::from_stream_turns([
3139            vec![
3140                MockStreamEvent::text("checking "),
3141                MockStreamEvent::tool_call(
3142                    "tool_call_1",
3143                    "add",
3144                    serde_json::json!({"x": 2, "y": 3}),
3145                )
3146                .with_call_id("call_1"),
3147                MockStreamEvent::tool_call(
3148                    "tool_call_2",
3149                    "default_api",
3150                    serde_json::json!({"x": 4, "y": 5}),
3151                )
3152                .with_call_id("call_2"),
3153                MockStreamEvent::final_response_with_total_tokens(4),
3154            ],
3155            vec![
3156                MockStreamEvent::text("retried"),
3157                MockStreamEvent::final_response_with_total_tokens(6),
3158            ],
3159        ]);
3160        let recorded = model.clone();
3161        let agent = AgentBuilder::new(model)
3162            .tool(CountingAddTool {
3163                calls: add_calls.clone(),
3164            })
3165            .build();
3166
3167        let mut stream = agent
3168            .stream_prompt("use the tool")
3169            .with_hook(RetryDefaultApiHook)
3170            .multi_turn(3)
3171            .with_history(Vec::<Message>::new())
3172            .max_invalid_tool_call_retries(1)
3173            .await;
3174        let mut completion_call_events = Vec::new();
3175        let mut final_response_text = None;
3176        let mut final_response_usage = Usage::new();
3177        let mut final_completion_calls = Vec::new();
3178
3179        while let Some(item) = stream.next().await {
3180            match item {
3181                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
3182                    completion_call_events.push(completion_call);
3183                }
3184                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3185                    final_response_text = Some(response.response().to_string());
3186                    final_response_usage = response.usage();
3187                    final_completion_calls = response.completion_calls().to_vec();
3188                    break;
3189                }
3190                Ok(_) => {}
3191                Err(err) => panic!("unexpected streaming error: {err:?}"),
3192            }
3193        }
3194
3195        assert_eq!(final_response_text.as_deref(), Some("retried"));
3196        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3197        let mut first_usage = Usage::new();
3198        first_usage.total_tokens = 4;
3199        let mut second_usage = Usage::new();
3200        second_usage.total_tokens = 6;
3201        let expected_completion_calls = vec![
3202            CompletionCall::new(0, Some(first_usage)),
3203            CompletionCall::new(1, Some(second_usage)),
3204        ];
3205        assert_eq!(completion_call_events, expected_completion_calls);
3206        assert_eq!(final_completion_calls, expected_completion_calls);
3207        assert_eq!(final_response_usage.total_tokens, 10);
3208
3209        let requests = recorded.requests();
3210        assert_eq!(requests.len(), 2);
3211        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3212        assert_eq!(retry_history.len(), 3);
3213        assert!(matches!(
3214            retry_history.get(1),
3215            Some(Message::Assistant { content, .. })
3216                if content.iter().any(|item| matches!(
3217                    item,
3218                    AssistantContent::Text(text) if text.text == "checking "
3219                ))
3220                    && content.iter().any(|item| matches!(
3221                        item,
3222                        AssistantContent::ToolCall(tool_call)
3223                            if tool_call.id == "tool_call_1"
3224                                && tool_call.function.name == "add"
3225                    ))
3226                    && content.iter().any(|item| matches!(
3227                        item,
3228                        AssistantContent::ToolCall(tool_call)
3229                            if tool_call.id == "tool_call_2"
3230                                && tool_call.function.name == "default_api"
3231                    ))
3232        ));
3233        assert!(matches!(
3234            retry_history.get(2),
3235            Some(Message::User { content })
3236                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3237                    && content.iter().any(|item| matches!(
3238                        item,
3239                        UserContent::ToolResult(result)
3240                            if result.id == "tool_call_1"
3241                                && result.content.iter().any(|content| matches!(
3242                                    content,
3243                                    ToolResultContent::Text(text)
3244                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3245                                ))
3246                    ))
3247                    && content.iter().any(|item| matches!(
3248                        item,
3249                        UserContent::ToolResult(result)
3250                            if result.id == "tool_call_2"
3251                                && result.content.iter().any(|content| matches!(
3252                                    content,
3253                                    ToolResultContent::Text(text)
3254                                        if text.text == "Use the add tool instead"
3255                                ))
3256                    ))
3257        ));
3258    }
3259
3260    #[tokio::test]
3261    async fn invalid_tool_call_hook_skips_mixed_streaming_turn_without_executing_valid_call() {
3262        let add_calls = Arc::new(AtomicU32::new(0));
3263        let model = MockCompletionModel::from_stream_turns([
3264            vec![
3265                MockStreamEvent::text("checking "),
3266                MockStreamEvent::tool_call(
3267                    "tool_call_1",
3268                    "add",
3269                    serde_json::json!({"x": 2, "y": 3}),
3270                )
3271                .with_call_id("call_1"),
3272                MockStreamEvent::tool_call(
3273                    "tool_call_2",
3274                    "default_api",
3275                    serde_json::json!({"x": 4, "y": 5}),
3276                )
3277                .with_call_id("call_2"),
3278                MockStreamEvent::final_response_with_total_tokens(4),
3279            ],
3280            vec![
3281                MockStreamEvent::text("continued"),
3282                MockStreamEvent::final_response_with_total_tokens(6),
3283            ],
3284        ]);
3285        let recorded = model.clone();
3286        let agent = AgentBuilder::new(model)
3287            .tool(CountingAddTool {
3288                calls: add_calls.clone(),
3289            })
3290            .build();
3291
3292        let mut stream = agent
3293            .stream_prompt("use the tool")
3294            .with_hook(SkipDefaultApiHook)
3295            .multi_turn(3)
3296            .with_history(Vec::<Message>::new())
3297            .await;
3298        let mut skipped_tool_result = None;
3299        let mut final_response_text = None;
3300
3301        while let Some(item) = stream.next().await {
3302            match item {
3303                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3304                    tool_result,
3305                    ..
3306                })) => {
3307                    skipped_tool_result = Some(tool_result);
3308                }
3309                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3310                    final_response_text = Some(response.response().to_string());
3311                    break;
3312                }
3313                Ok(_) => {}
3314                Err(err) => panic!("unexpected streaming error: {err:?}"),
3315            }
3316        }
3317
3318        let skipped_tool_result =
3319            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3320        assert_eq!(skipped_tool_result.id, "tool_call_2");
3321        assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_2"));
3322        assert_eq!(final_response_text.as_deref(), Some("continued"));
3323        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3324
3325        let requests = recorded.requests();
3326        assert_eq!(requests.len(), 2);
3327        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3328        assert_eq!(follow_up_history.len(), 3);
3329        assert!(matches!(
3330            follow_up_history.get(1),
3331            Some(Message::Assistant { content, .. })
3332                if content.iter().any(|item| matches!(
3333                    item,
3334                    AssistantContent::Text(text) if text.text == "checking "
3335                ))
3336                    && content.iter().any(|item| matches!(
3337                        item,
3338                        AssistantContent::ToolCall(tool_call)
3339                            if tool_call.id == "tool_call_1"
3340                                && tool_call.function.name == "add"
3341                    ))
3342                    && content.iter().any(|item| matches!(
3343                        item,
3344                        AssistantContent::ToolCall(tool_call)
3345                            if tool_call.id == "tool_call_2"
3346                                && tool_call.function.name == "default_api"
3347                    ))
3348        ));
3349        assert!(matches!(
3350            follow_up_history.get(2),
3351            Some(Message::User { content })
3352                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3353                    && content.iter().any(|item| matches!(
3354                        item,
3355                        UserContent::ToolResult(result)
3356                            if result.id == "tool_call_1"
3357                                && result.call_id.as_deref() == Some("call_1")
3358                                && result.content.iter().any(|content| matches!(
3359                                    content,
3360                                    ToolResultContent::Text(text)
3361                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3362                                ))
3363                    ))
3364                    && content.iter().any(|item| matches!(
3365                        item,
3366                        UserContent::ToolResult(result)
3367                            if result.id == "tool_call_2"
3368                                && result.call_id.as_deref() == Some("call_2")
3369                                && result.content.iter().any(|content| matches!(
3370                                    content,
3371                                    ToolResultContent::Text(text)
3372                                        if text.text == "default_api was skipped"
3373                                ))
3374            ))
3375        ));
3376    }
3377
3378    #[tokio::test]
3379    async fn invalid_completed_tool_call_skip_preserves_streaming_reasoning_history() {
3380        let model = MockCompletionModel::from_stream_turns([
3381            vec![
3382                MockStreamEvent::text("checking "),
3383                MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
3384                MockStreamEvent::tool_call(
3385                    "tool_call_1",
3386                    "default_api",
3387                    serde_json::json!({"x": 2, "y": 3}),
3388                ),
3389                MockStreamEvent::final_response_with_total_tokens(4),
3390            ],
3391            vec![
3392                MockStreamEvent::text("continued"),
3393                MockStreamEvent::final_response_with_total_tokens(6),
3394            ],
3395        ]);
3396        let recorded = model.clone();
3397        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3398
3399        let mut stream = agent
3400            .stream_prompt("use the tool")
3401            .with_hook(SkipDefaultApiHook)
3402            .multi_turn(3)
3403            .with_history(Vec::<Message>::new())
3404            .await;
3405
3406        while let Some(item) = stream.next().await {
3407            match item {
3408                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3409                Ok(_) => {}
3410                Err(err) => panic!("unexpected streaming error: {err:?}"),
3411            }
3412        }
3413
3414        let requests = recorded.requests();
3415        assert_eq!(requests.len(), 2);
3416        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3417        assert!(history_contains_text(&follow_up_history, "checking "));
3418        assert!(assistant_reasoning_precedes_tool_call(
3419            &follow_up_history,
3420            "reasoned step",
3421            "default_api"
3422        ));
3423    }
3424
3425    #[tokio::test]
3426    async fn invalid_name_delta_retry_preserves_streaming_reasoning_history() {
3427        let model = MockCompletionModel::from_stream_turns([
3428            vec![
3429                MockStreamEvent::reasoning_delta(Some("rs_1"), "delta reason"),
3430                MockStreamEvent::tool_call_arguments_delta(
3431                    "tool_call_1",
3432                    "internal_1",
3433                    r#"{"x":2,"y":3}"#,
3434                ),
3435                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3436                MockStreamEvent::final_response_with_total_tokens(4),
3437            ],
3438            vec![
3439                MockStreamEvent::text("retried"),
3440                MockStreamEvent::final_response_with_total_tokens(6),
3441            ],
3442        ]);
3443        let recorded = model.clone();
3444        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3445
3446        let mut stream = agent
3447            .stream_prompt("use the tool")
3448            .with_hook(RetryDefaultApiHook)
3449            .multi_turn(3)
3450            .with_history(Vec::<Message>::new())
3451            .max_invalid_tool_call_retries(1)
3452            .await;
3453
3454        while let Some(item) = stream.next().await {
3455            match item {
3456                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3457                Ok(_) => {}
3458                Err(err) => panic!("unexpected streaming error: {err:?}"),
3459            }
3460        }
3461
3462        let requests = recorded.requests();
3463        assert_eq!(requests.len(), 2);
3464        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3465        assert!(assistant_reasoning_precedes_tool_call(
3466            &retry_history,
3467            "delta reason",
3468            "default_api"
3469        ));
3470    }
3471
3472    #[tokio::test]
3473    async fn invalid_tool_call_hook_skip_resets_streaming_text_delta_state() {
3474        let text_hook = RecordingTextDeltaHook::default();
3475        let model = MockCompletionModel::from_stream_turns([
3476            vec![
3477                MockStreamEvent::text("stale "),
3478                MockStreamEvent::tool_call(
3479                    "tool_call_1",
3480                    "default_api",
3481                    serde_json::json!({"x": 2, "y": 3}),
3482                ),
3483                MockStreamEvent::final_response_with_total_tokens(4),
3484            ],
3485            vec![
3486                MockStreamEvent::text("fresh"),
3487                MockStreamEvent::final_response_with_total_tokens(6),
3488            ],
3489        ]);
3490        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3491
3492        let mut stream = agent
3493            .stream_prompt("use the tool")
3494            .with_hook(RecordingTextAndSkipInvalidToolHook {
3495                text: text_hook.clone(),
3496            })
3497            .multi_turn(3)
3498            .with_history(Vec::<Message>::new())
3499            .await;
3500
3501        while let Some(item) = stream.next().await {
3502            match item {
3503                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3504                Ok(_) => {}
3505                Err(err) => panic!("unexpected streaming error: {err:?}"),
3506            }
3507        }
3508
3509        assert_eq!(
3510            text_hook.observed(),
3511            vec![
3512                ("stale ".to_string(), "stale ".to_string()),
3513                ("fresh".to_string(), "fresh".to_string()),
3514            ]
3515        );
3516    }
3517
3518    #[tokio::test]
3519    async fn invalid_tool_call_delta_retry_uses_structured_tool_feedback() {
3520        let delta_hook = RecordingToolCallDeltaHook::default();
3521        let add_calls = Arc::new(AtomicU32::new(0));
3522        let model = MockCompletionModel::from_stream_turns([
3523            vec![
3524                MockStreamEvent::text("checking "),
3525                MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3526                MockStreamEvent::tool_call(
3527                    "tool_call_0",
3528                    "add",
3529                    serde_json::json!({"x": 1, "y": 2}),
3530                )
3531                .with_call_id("call_0"),
3532                MockStreamEvent::tool_call_arguments_delta(
3533                    "tool_call_1",
3534                    "internal_1",
3535                    r#"{"x":2,"y":3}"#,
3536                ),
3537                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3538                MockStreamEvent::final_response_with_total_tokens(4),
3539            ],
3540            vec![
3541                MockStreamEvent::text("retried"),
3542                MockStreamEvent::final_response_with_total_tokens(6),
3543            ],
3544        ]);
3545        let recorded = model.clone();
3546        let agent = AgentBuilder::new(model)
3547            .tool(CountingAddTool {
3548                calls: add_calls.clone(),
3549            })
3550            .build();
3551
3552        let mut stream = agent
3553            .stream_prompt("use the tool")
3554            .with_hook(RecordingDeltaAndRetryInvalidToolHook {
3555                delta: delta_hook.clone(),
3556            })
3557            .multi_turn(3)
3558            .with_history(Vec::<Message>::new())
3559            .max_invalid_tool_call_retries(1)
3560            .await;
3561        let mut completion_call_events = Vec::new();
3562        let mut final_response_text = None;
3563        let mut final_response_usage = Usage::new();
3564        let mut final_completion_calls = Vec::new();
3565
3566        while let Some(item) = stream.next().await {
3567            match item {
3568                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
3569                    completion_call_events.push(completion_call);
3570                }
3571                Ok(MultiTurnStreamItem::StreamAssistantItem(
3572                    StreamedAssistantContent::ToolCallDelta { .. },
3573                )) => panic!("invalid tool-call delta should not be emitted"),
3574                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3575                    final_response_text = Some(response.response().to_string());
3576                    final_response_usage = response.usage();
3577                    final_completion_calls = response.completion_calls().to_vec();
3578                    break;
3579                }
3580                Ok(_) => {}
3581                Err(err) => panic!("unexpected streaming error: {err:?}"),
3582            }
3583        }
3584
3585        assert_eq!(final_response_text.as_deref(), Some("retried"));
3586        assert!(delta_hook.observed().is_empty());
3587        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3588        let mut first_usage = Usage::new();
3589        first_usage.total_tokens = 4;
3590        let mut second_usage = Usage::new();
3591        second_usage.total_tokens = 6;
3592        let expected_completion_calls = vec![
3593            CompletionCall::new(0, Some(first_usage)),
3594            CompletionCall::new(1, Some(second_usage)),
3595        ];
3596        assert_eq!(completion_call_events, expected_completion_calls);
3597        assert_eq!(final_completion_calls, expected_completion_calls);
3598        assert_eq!(final_response_usage.total_tokens, 10);
3599
3600        let requests = recorded.requests();
3601        assert_eq!(requests.len(), 2);
3602        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3603        assert!(matches!(
3604            retry_history.get(1),
3605            Some(Message::Assistant { content, .. })
3606                if content.iter().any(|item| matches!(
3607                    item,
3608                    AssistantContent::Text(text) if text.text == "checking "
3609                ))
3610                    && content.iter().any(|item| matches!(
3611                        item,
3612                        AssistantContent::ToolCall(tool_call)
3613                            if tool_call.id == "tool_call_0"
3614                                && tool_call.function.name == "add"
3615                    ))
3616                    && content.iter().any(|item| matches!(
3617                    item,
3618                    AssistantContent::ToolCall(tool_call)
3619                        if tool_call.id == "tool_call_1"
3620                            && tool_call.function.name == "default_api"
3621                            && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3622                ))
3623        ));
3624        assert!(matches!(
3625            retry_history.get(2),
3626            Some(Message::User { content })
3627                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3628                    && content.iter().any(|item| matches!(
3629                        item,
3630                        UserContent::ToolResult(result)
3631                            if result.id == "tool_call_0"
3632                                && result.call_id.as_deref() == Some("call_0")
3633                                && result.content.iter().any(|content| matches!(
3634                                    content,
3635                                    ToolResultContent::Text(text)
3636                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3637                                ))
3638                    ))
3639                    && content.iter().any(|item| matches!(
3640                    item,
3641                    UserContent::ToolResult(result)
3642                        if result.id == "tool_call_1"
3643                            && result.content.iter().any(|content| matches!(
3644                                content,
3645                                ToolResultContent::Text(text)
3646                                    if text.text == "Use the add tool instead"
3647                            ))
3648                ))
3649        ));
3650    }
3651
3652    #[tokio::test]
3653    async fn invalid_tool_call_delta_context_includes_same_turn_history_and_tool_call_id() {
3654        let invalid_hook = RecordingInvalidToolCallHook::default();
3655        let model = MockCompletionModel::from_stream_turns([
3656            vec![
3657                MockStreamEvent::text("checking "),
3658                MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3659                MockStreamEvent::tool_call(
3660                    "tool_call_0",
3661                    "add",
3662                    serde_json::json!({"x": 1, "y": 2}),
3663                )
3664                .with_call_id("call_0"),
3665                MockStreamEvent::tool_call_arguments_delta(
3666                    "tool_call_1",
3667                    "internal_1",
3668                    r#"{"x":2,"y":3}"#,
3669                ),
3670                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3671                MockStreamEvent::final_response_with_total_tokens(4),
3672            ],
3673            vec![
3674                MockStreamEvent::text("should not be requested"),
3675                MockStreamEvent::final_response_with_total_tokens(6),
3676            ],
3677        ]);
3678        let recorded = model.clone();
3679        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3680
3681        let mut stream = agent
3682            .stream_prompt("use the tool")
3683            .with_hook(invalid_hook.clone())
3684            .multi_turn(3)
3685            .await;
3686        let mut error = None;
3687
3688        while let Some(item) = stream.next().await {
3689            if let Err(err) = item {
3690                error = Some(err);
3691                break;
3692            }
3693        }
3694
3695        assert!(error.is_some(), "invalid name delta should fail");
3696        assert_eq!(recorded.request_count(), 1);
3697        let contexts = invalid_hook.observed();
3698        assert_eq!(contexts.len(), 1);
3699        let context = &contexts[0];
3700        assert_eq!(context.tool_name, "default_api");
3701        assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
3702        assert_eq!(context.internal_call_id.as_deref(), Some("internal_1"));
3703        assert!(context.is_streaming);
3704        assert!(history_contains_text(&context.chat_history, "checking "));
3705        assert!(
3706            assistant_reasoning_precedes_tool_call(
3707                &context.chat_history,
3708                "diagnostic reason",
3709                "add"
3710            ),
3711            "{:?}",
3712            context.chat_history
3713        );
3714        assert!(history_contains_tool_call(&context.chat_history, "add"));
3715        assert!(history_contains_tool_call(
3716            &context.chat_history,
3717            "default_api"
3718        ));
3719    }
3720
3721    #[tokio::test]
3722    async fn invalid_tool_call_delta_retry_resets_streaming_text_delta_state() {
3723        let text_hook = RecordingTextDeltaHook::default();
3724        let model = MockCompletionModel::from_stream_turns([
3725            vec![
3726                MockStreamEvent::text("stale "),
3727                MockStreamEvent::tool_call_arguments_delta(
3728                    "tool_call_1",
3729                    "internal_1",
3730                    r#"{"x":2,"y":3}"#,
3731                ),
3732                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3733                MockStreamEvent::final_response_with_total_tokens(4),
3734            ],
3735            vec![
3736                MockStreamEvent::text("fresh"),
3737                MockStreamEvent::final_response_with_total_tokens(6),
3738            ],
3739        ]);
3740        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3741
3742        let mut stream = agent
3743            .stream_prompt("use the tool")
3744            .with_hook(RecordingTextAndRetryInvalidToolHook {
3745                text: text_hook.clone(),
3746            })
3747            .multi_turn(3)
3748            .with_history(Vec::<Message>::new())
3749            .max_invalid_tool_call_retries(1)
3750            .await;
3751
3752        while let Some(item) = stream.next().await {
3753            match item {
3754                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3755                Ok(_) => {}
3756                Err(err) => panic!("unexpected streaming error: {err:?}"),
3757            }
3758        }
3759
3760        assert_eq!(
3761            text_hook.observed(),
3762            vec![
3763                ("stale ".to_string(), "stale ".to_string()),
3764                ("fresh".to_string(), "fresh".to_string()),
3765            ]
3766        );
3767    }
3768
3769    #[tokio::test]
3770    async fn invalid_tool_call_delta_skip_uses_structured_tool_feedback() {
3771        let delta_hook = RecordingToolCallDeltaHook::default();
3772        let add_calls = Arc::new(AtomicU32::new(0));
3773        let model = MockCompletionModel::from_stream_turns([
3774            vec![
3775                MockStreamEvent::text("checking "),
3776                MockStreamEvent::tool_call(
3777                    "tool_call_0",
3778                    "add",
3779                    serde_json::json!({"x": 1, "y": 2}),
3780                )
3781                .with_call_id("call_0"),
3782                MockStreamEvent::tool_call_arguments_delta(
3783                    "tool_call_1",
3784                    "internal_1",
3785                    r#"{"x":2,"y":3}"#,
3786                ),
3787                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3788                MockStreamEvent::final_response_with_total_tokens(4),
3789            ],
3790            vec![
3791                MockStreamEvent::text("continued"),
3792                MockStreamEvent::final_response_with_total_tokens(6),
3793            ],
3794        ]);
3795        let recorded = model.clone();
3796        let agent = AgentBuilder::new(model)
3797            .tool(CountingAddTool {
3798                calls: add_calls.clone(),
3799            })
3800            .build();
3801
3802        let mut stream = agent
3803            .stream_prompt("use the tool")
3804            .with_hook(RecordingDeltaAndSkipInvalidToolHook {
3805                delta: delta_hook.clone(),
3806            })
3807            .multi_turn(3)
3808            .with_history(Vec::<Message>::new())
3809            .await;
3810        let mut skipped_tool_result = None;
3811        let mut final_response_text = None;
3812
3813        while let Some(item) = stream.next().await {
3814            match item {
3815                Ok(MultiTurnStreamItem::StreamAssistantItem(
3816                    StreamedAssistantContent::ToolCallDelta { .. },
3817                )) => panic!("invalid tool-call delta should not be emitted"),
3818                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3819                    tool_result,
3820                    internal_call_id,
3821                })) => {
3822                    assert_eq!(internal_call_id, "internal_1");
3823                    skipped_tool_result = Some(tool_result);
3824                }
3825                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3826                    final_response_text = Some(response.response().to_string());
3827                    break;
3828                }
3829                Ok(_) => {}
3830                Err(err) => panic!("unexpected streaming error: {err:?}"),
3831            }
3832        }
3833
3834        let skipped_tool_result =
3835            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3836        assert_eq!(skipped_tool_result.id, "tool_call_1");
3837        assert!(skipped_tool_result.call_id.is_none());
3838        assert!(skipped_tool_result.content.iter().any(|content| matches!(
3839            content,
3840            ToolResultContent::Text(text) if text.text == "default_api was skipped"
3841        )));
3842        assert_eq!(final_response_text.as_deref(), Some("continued"));
3843        assert!(delta_hook.observed().is_empty());
3844        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3845
3846        let requests = recorded.requests();
3847        assert_eq!(requests.len(), 2);
3848        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3849        assert!(matches!(
3850            follow_up_history.get(1),
3851            Some(Message::Assistant { content, .. })
3852                if content.iter().any(|item| matches!(
3853                    item,
3854                    AssistantContent::Text(text) if text.text == "checking "
3855                ))
3856                    && content.iter().any(|item| matches!(
3857                        item,
3858                        AssistantContent::ToolCall(tool_call)
3859                            if tool_call.id == "tool_call_0"
3860                                && tool_call.function.name == "add"
3861                    ))
3862                    && content.iter().any(|item| matches!(
3863                    item,
3864                    AssistantContent::ToolCall(tool_call)
3865                        if tool_call.id == "tool_call_1"
3866                            && tool_call.function.name == "default_api"
3867                            && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3868                ))
3869        ));
3870        assert!(matches!(
3871            follow_up_history.get(2),
3872            Some(Message::User { content })
3873                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3874                    && content.iter().any(|item| matches!(
3875                        item,
3876                        UserContent::ToolResult(result)
3877                            if result.id == "tool_call_0"
3878                                && result.call_id.as_deref() == Some("call_0")
3879                                && result.content.iter().any(|content| matches!(
3880                                    content,
3881                                    ToolResultContent::Text(text)
3882                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3883                                ))
3884                    ))
3885                    && content.iter().any(|item| matches!(
3886                    item,
3887                    UserContent::ToolResult(result)
3888                        if result.id == "tool_call_1"
3889                            && result.content.iter().any(|content| matches!(
3890                                content,
3891                                ToolResultContent::Text(text)
3892                                    if text.text == "default_api was skipped"
3893                            ))
3894                ))
3895        ));
3896    }
3897
3898    #[tokio::test]
3899    async fn streaming_retry_budget_exhaustion_history_contains_invalid_tool_call() {
3900        let model = MockCompletionModel::from_stream_turns([
3901            vec![
3902                MockStreamEvent::tool_call(
3903                    "tool_call_1",
3904                    "default_api",
3905                    serde_json::json!({"x": 1, "y": 2}),
3906                ),
3907                MockStreamEvent::final_response_with_total_tokens(4),
3908            ],
3909            vec![
3910                MockStreamEvent::text("should not be requested"),
3911                MockStreamEvent::final_response_with_total_tokens(6),
3912            ],
3913        ]);
3914        let recorded = model.clone();
3915        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3916
3917        let mut stream = agent
3918            .stream_prompt("use the tool")
3919            .with_hook(RetryDefaultApiHook)
3920            .multi_turn(3)
3921            .max_invalid_tool_call_retries(0)
3922            .await;
3923        let mut error = None;
3924
3925        while let Some(item) = stream.next().await {
3926            if let Err(err) = item {
3927                error = Some(err);
3928                break;
3929            }
3930        }
3931
3932        let error = error.expect("retry budget exhaustion should fail");
3933        match error {
3934            StreamingError::Prompt(err) => match *err {
3935                PromptError::UnknownToolCall {
3936                    tool_name,
3937                    chat_history,
3938                    ..
3939                } => {
3940                    assert_eq!(tool_name, "default_api");
3941                    assert!(history_contains_tool_call(&chat_history, "default_api"));
3942                }
3943                other => panic!("expected UnknownToolCall, got {other:?}"),
3944            },
3945            other => panic!("expected prompt streaming error, got {other:?}"),
3946        }
3947        assert_eq!(recorded.request_count(), 1);
3948    }
3949
3950    #[tokio::test]
3951    async fn streaming_name_delta_retry_budget_exhaustion_history_includes_same_turn_context() {
3952        let model = MockCompletionModel::from_stream_turns([
3953            vec![
3954                MockStreamEvent::text("checking "),
3955                MockStreamEvent::tool_call(
3956                    "tool_call_0",
3957                    "add",
3958                    serde_json::json!({"x": 1, "y": 2}),
3959                )
3960                .with_call_id("call_0"),
3961                MockStreamEvent::tool_call_arguments_delta(
3962                    "tool_call_1",
3963                    "internal_1",
3964                    r#"{"x":2,"y":3}"#,
3965                ),
3966                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3967                MockStreamEvent::final_response_with_total_tokens(4),
3968            ],
3969            vec![
3970                MockStreamEvent::text("should not be requested"),
3971                MockStreamEvent::final_response_with_total_tokens(6),
3972            ],
3973        ]);
3974        let recorded = model.clone();
3975        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3976
3977        let mut stream = agent
3978            .stream_prompt("use the tool")
3979            .with_hook(RetryDefaultApiHook)
3980            .multi_turn(3)
3981            .max_invalid_tool_call_retries(0)
3982            .await;
3983        let mut error = None;
3984
3985        while let Some(item) = stream.next().await {
3986            if let Err(err) = item {
3987                error = Some(err);
3988                break;
3989            }
3990        }
3991
3992        let error = error.expect("retry budget exhaustion should fail");
3993        match error {
3994            StreamingError::Prompt(err) => match *err {
3995                PromptError::UnknownToolCall {
3996                    tool_name,
3997                    chat_history,
3998                    ..
3999                } => {
4000                    assert_eq!(tool_name, "default_api");
4001                    assert!(history_contains_text(&chat_history, "checking "));
4002                    assert!(history_contains_tool_call(&chat_history, "add"));
4003                    assert!(history_contains_tool_call(&chat_history, "default_api"));
4004                }
4005                other => panic!("expected UnknownToolCall, got {other:?}"),
4006            },
4007            other => panic!("expected prompt streaming error, got {other:?}"),
4008        }
4009        assert_eq!(recorded.request_count(), 1);
4010    }
4011
4012    #[tokio::test]
4013    async fn completed_unknown_tool_call_after_text_fails_before_finish_hook_or_later_emit() {
4014        let add_calls = Arc::new(AtomicU32::new(0));
4015        let model = MockCompletionModel::from_stream_turns([
4016            vec![
4017                MockStreamEvent::text("thinking "),
4018                MockStreamEvent::tool_call(
4019                    "tool_call_1",
4020                    "default_api",
4021                    serde_json::json!({"x": 1, "y": 2}),
4022                ),
4023                MockStreamEvent::final_response_with_total_tokens(4),
4024            ],
4025            vec![
4026                MockStreamEvent::text("should not be requested"),
4027                MockStreamEvent::final_response_with_total_tokens(6),
4028            ],
4029        ]);
4030        let recorded = model.clone();
4031        let agent = AgentBuilder::new(model)
4032            .tool(CountingAddTool {
4033                calls: add_calls.clone(),
4034            })
4035            .build();
4036
4037        let mut stream = agent
4038            .stream_prompt("use the tool")
4039            .with_hook(PanicOnUnknownToolHook)
4040            .multi_turn(3)
4041            .await;
4042        let mut saw_text = false;
4043        let mut saw_completion_call = false;
4044        let mut saw_final_response = false;
4045        let mut saw_tool_call = false;
4046        let mut saw_tool_result = false;
4047        let mut error = None;
4048
4049        while let Some(item) = stream.next().await {
4050            match item {
4051                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(_))) => {
4052                    saw_text = true;
4053                }
4054                Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4055                    saw_completion_call = true;
4056                }
4057                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Final(
4058                    _,
4059                )))
4060                | Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4061                    saw_final_response = true;
4062                }
4063                Ok(MultiTurnStreamItem::StreamAssistantItem(
4064                    StreamedAssistantContent::ToolCall { .. },
4065                )) => {
4066                    saw_tool_call = true;
4067                }
4068                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4069                    ..
4070                })) => {
4071                    saw_tool_result = true;
4072                }
4073                Ok(_) => {}
4074                Err(err) => {
4075                    error = Some(err);
4076                    break;
4077                }
4078            }
4079        }
4080
4081        assert!(saw_text);
4082        assert!(!saw_completion_call);
4083        assert!(!saw_final_response);
4084        assert!(!saw_tool_call);
4085        assert!(!saw_tool_result);
4086        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
4087        let error = error.expect("completed unknown tool call should fail immediately");
4088        match error {
4089            StreamingError::Prompt(err) => match *err {
4090                PromptError::UnknownToolCall {
4091                    tool_name,
4092                    available_tools,
4093                    allowed_tools,
4094                    chat_history,
4095                } => {
4096                    assert_eq!(tool_name, "default_api");
4097                    assert_eq!(available_tools, vec!["add".to_string()]);
4098                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4099                    assert!(history_contains_tool_call(&chat_history, "default_api"));
4100                }
4101                other => panic!("expected UnknownToolCall, got {other:?}"),
4102            },
4103            other => panic!("expected prompt streaming error, got {other:?}"),
4104        }
4105        assert_eq!(recorded.request_count(), 1);
4106    }
4107
4108    #[tokio::test]
4109    async fn mixed_streaming_tool_calls_fail_before_any_tool_execution() {
4110        let add_calls = Arc::new(AtomicU32::new(0));
4111        let model = MockCompletionModel::from_stream_turns([
4112            vec![
4113                MockStreamEvent::tool_call(
4114                    "tool_call_1",
4115                    "add",
4116                    serde_json::json!({"x": 1, "y": 2}),
4117                )
4118                .with_call_id("call_1"),
4119                MockStreamEvent::tool_call(
4120                    "tool_call_2",
4121                    "default_api",
4122                    serde_json::json!({"x": 3, "y": 4}),
4123                ),
4124                MockStreamEvent::final_response_with_total_tokens(4),
4125            ],
4126            vec![
4127                MockStreamEvent::text("should not be requested"),
4128                MockStreamEvent::final_response_with_total_tokens(6),
4129            ],
4130        ]);
4131        let recorded = model.clone();
4132        let agent = AgentBuilder::new(model)
4133            .tool(CountingAddTool {
4134                calls: add_calls.clone(),
4135            })
4136            .build();
4137
4138        let mut stream = agent
4139            .stream_prompt("use tools")
4140            .with_hook(PanicOnUnknownToolHook)
4141            .multi_turn(3)
4142            .await;
4143        let mut saw_completion_call = false;
4144        let mut saw_tool_call = false;
4145        let mut saw_tool_result = false;
4146        let mut error = None;
4147
4148        while let Some(item) = stream.next().await {
4149            match item {
4150                Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4151                    saw_completion_call = true;
4152                }
4153                Ok(MultiTurnStreamItem::StreamAssistantItem(
4154                    StreamedAssistantContent::ToolCall { .. },
4155                )) => {
4156                    saw_tool_call = true;
4157                }
4158                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4159                    ..
4160                })) => {
4161                    saw_tool_result = true;
4162                }
4163                Ok(_) => {}
4164                Err(err) => {
4165                    error = Some(err);
4166                    break;
4167                }
4168            }
4169        }
4170
4171        assert!(!saw_completion_call);
4172        assert!(!saw_tool_call);
4173        assert!(!saw_tool_result);
4174        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
4175        let error = error.expect("mixed unknown streamed tool call should fail");
4176        match error {
4177            StreamingError::Prompt(err) => match *err {
4178                PromptError::UnknownToolCall {
4179                    tool_name,
4180                    available_tools,
4181                    allowed_tools,
4182                    chat_history,
4183                } => {
4184                    assert_eq!(tool_name, "default_api");
4185                    assert_eq!(available_tools, vec!["add".to_string()]);
4186                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4187                    assert!(history_contains_tool_call(&chat_history, "default_api"));
4188                }
4189                other => panic!("expected UnknownToolCall, got {other:?}"),
4190            },
4191            other => panic!("expected prompt streaming error, got {other:?}"),
4192        }
4193        assert_eq!(recorded.request_count(), 1);
4194    }
4195
4196    #[tokio::test]
4197    async fn multiple_valid_streaming_tool_calls_execute_after_batch_validation() {
4198        let add_calls = Arc::new(AtomicU32::new(0));
4199        let subtract_calls = Arc::new(AtomicU32::new(0));
4200        let model = MockCompletionModel::from_stream_turns([
4201            vec![
4202                MockStreamEvent::tool_call(
4203                    "tool_call_1",
4204                    "add",
4205                    serde_json::json!({"x": 1, "y": 2}),
4206                )
4207                .with_call_id("call_1"),
4208                MockStreamEvent::tool_call(
4209                    "tool_call_2",
4210                    "subtract",
4211                    serde_json::json!({"x": 8, "y": 3}),
4212                )
4213                .with_call_id("call_2"),
4214                MockStreamEvent::final_response_with_total_tokens(4),
4215            ],
4216            vec![
4217                MockStreamEvent::text("done"),
4218                MockStreamEvent::final_response_with_total_tokens(6),
4219            ],
4220        ]);
4221        let recorded = model.clone();
4222        let agent = AgentBuilder::new(model)
4223            .tool(CountingAddTool {
4224                calls: add_calls.clone(),
4225            })
4226            .tool(CountingSubtractTool {
4227                calls: subtract_calls.clone(),
4228            })
4229            .build();
4230
4231        let mut stream = agent.stream_prompt("use tools").multi_turn(3).await;
4232        let mut tool_call_names = Vec::new();
4233        let mut tool_result_ids = Vec::new();
4234        let mut final_response_text = None;
4235
4236        while let Some(item) = stream.next().await {
4237            match item {
4238                Ok(MultiTurnStreamItem::StreamAssistantItem(
4239                    StreamedAssistantContent::ToolCall { tool_call, .. },
4240                )) => {
4241                    tool_call_names.push(tool_call.function.name);
4242                }
4243                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4244                    tool_result,
4245                    ..
4246                })) => {
4247                    tool_result_ids.push(tool_result.id);
4248                }
4249                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4250                    final_response_text = Some(response.response().to_owned());
4251                    break;
4252                }
4253                Ok(_) => {}
4254                Err(err) => panic!("unexpected streaming error: {err:?}"),
4255            }
4256        }
4257
4258        assert_eq!(
4259            tool_call_names,
4260            vec!["add".to_string(), "subtract".to_string()]
4261        );
4262        assert_eq!(
4263            tool_result_ids,
4264            vec!["tool_call_1".to_string(), "tool_call_2".to_string()]
4265        );
4266        assert_eq!(add_calls.load(Ordering::SeqCst), 1);
4267        assert_eq!(subtract_calls.load(Ordering::SeqCst), 1);
4268        assert_eq!(final_response_text.as_deref(), Some("done"));
4269        assert_eq!(recorded.request_count(), 2);
4270    }
4271
4272    #[tokio::test]
4273    async fn disallowed_specific_tool_call_fails_before_streaming_second_request() {
4274        let model = MockCompletionModel::from_stream_turns([
4275            vec![
4276                MockStreamEvent::tool_call(
4277                    "tool_call_1",
4278                    "subtract",
4279                    serde_json::json!({"x": 3, "y": 1}),
4280                ),
4281                MockStreamEvent::final_response_with_total_tokens(4),
4282            ],
4283            vec![
4284                MockStreamEvent::text("should not be requested"),
4285                MockStreamEvent::final_response_with_total_tokens(6),
4286            ],
4287        ]);
4288        let recorded = model.clone();
4289        let agent = AgentBuilder::new(model)
4290            .tool(MockAddTool)
4291            .tool(MockSubtractTool)
4292            .tool_choice(ToolChoice::Specific {
4293                function_names: vec!["add".to_string()],
4294            })
4295            .build();
4296
4297        let mut stream = agent
4298            .stream_prompt("use the allowed tool")
4299            .with_hook(PanicOnUnknownToolHook)
4300            .multi_turn(3)
4301            .await;
4302        let mut saw_tool_call = false;
4303        let mut error = None;
4304
4305        while let Some(item) = stream.next().await {
4306            match item {
4307                Ok(MultiTurnStreamItem::StreamAssistantItem(
4308                    StreamedAssistantContent::ToolCall { .. },
4309                )) => {
4310                    saw_tool_call = true;
4311                }
4312                Ok(_) => {}
4313                Err(err) => {
4314                    error = Some(err);
4315                    break;
4316                }
4317            }
4318        }
4319
4320        assert!(!saw_tool_call);
4321        let error = error.expect("disallowed model-emitted tool should fail");
4322        match error {
4323            StreamingError::Prompt(err) => match *err {
4324                PromptError::UnknownToolCall {
4325                    tool_name,
4326                    available_tools,
4327                    allowed_tools,
4328                    chat_history,
4329                } => {
4330                    assert_eq!(tool_name, "subtract");
4331                    assert_eq!(
4332                        available_tools,
4333                        vec!["add".to_string(), "subtract".to_string()]
4334                    );
4335                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4336                    assert!(history_contains_tool_call(&chat_history, "subtract"));
4337                }
4338                other => panic!("expected UnknownToolCall, got {other:?}"),
4339            },
4340            other => panic!("expected prompt streaming error, got {other:?}"),
4341        }
4342        assert_eq!(recorded.request_count(), 1);
4343    }
4344
4345    #[tokio::test]
4346    async fn mixed_specific_tool_calls_fail_before_any_tool_execution() {
4347        let add_calls = Arc::new(AtomicU32::new(0));
4348        let model = MockCompletionModel::from_stream_turns([
4349            vec![
4350                MockStreamEvent::tool_call(
4351                    "tool_call_1",
4352                    "add",
4353                    serde_json::json!({"x": 1, "y": 2}),
4354                ),
4355                MockStreamEvent::tool_call(
4356                    "tool_call_2",
4357                    "subtract",
4358                    serde_json::json!({"x": 3, "y": 1}),
4359                ),
4360                MockStreamEvent::final_response_with_total_tokens(4),
4361            ],
4362            vec![
4363                MockStreamEvent::text("should not be requested"),
4364                MockStreamEvent::final_response_with_total_tokens(6),
4365            ],
4366        ]);
4367        let recorded = model.clone();
4368        let agent = AgentBuilder::new(model)
4369            .tool(CountingAddTool {
4370                calls: add_calls.clone(),
4371            })
4372            .tool(MockSubtractTool)
4373            .tool_choice(ToolChoice::Specific {
4374                function_names: vec!["add".to_string()],
4375            })
4376            .build();
4377
4378        let mut stream = agent
4379            .stream_prompt("use the allowed tool")
4380            .with_hook(PanicOnUnknownToolHook)
4381            .multi_turn(3)
4382            .await;
4383        let mut saw_tool_call = false;
4384        let mut saw_tool_result = false;
4385        let mut error = None;
4386
4387        while let Some(item) = stream.next().await {
4388            match item {
4389                Ok(MultiTurnStreamItem::StreamAssistantItem(
4390                    StreamedAssistantContent::ToolCall { .. },
4391                )) => {
4392                    saw_tool_call = true;
4393                }
4394                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4395                    ..
4396                })) => {
4397                    saw_tool_result = true;
4398                }
4399                Ok(_) => {}
4400                Err(err) => {
4401                    error = Some(err);
4402                    break;
4403                }
4404            }
4405        }
4406
4407        assert!(!saw_tool_call);
4408        assert!(!saw_tool_result);
4409        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
4410        let error = error.expect("mixed disallowed streamed tool call should fail");
4411        match error {
4412            StreamingError::Prompt(err) => match *err {
4413                PromptError::UnknownToolCall {
4414                    tool_name,
4415                    available_tools,
4416                    allowed_tools,
4417                    chat_history,
4418                } => {
4419                    assert_eq!(tool_name, "subtract");
4420                    assert_eq!(
4421                        available_tools,
4422                        vec!["add".to_string(), "subtract".to_string()]
4423                    );
4424                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4425                    assert!(history_contains_tool_call(&chat_history, "subtract"));
4426                }
4427                other => panic!("expected UnknownToolCall, got {other:?}"),
4428            },
4429            other => panic!("expected prompt streaming error, got {other:?}"),
4430        }
4431        assert_eq!(recorded.request_count(), 1);
4432    }
4433
4434    #[tokio::test]
4435    async fn tool_choice_none_rejects_streaming_tool_call() {
4436        let model = MockCompletionModel::from_stream_turns([
4437            vec![
4438                MockStreamEvent::tool_call(
4439                    "tool_call_1",
4440                    "add",
4441                    serde_json::json!({"x": 1, "y": 2}),
4442                ),
4443                MockStreamEvent::final_response_with_total_tokens(4),
4444            ],
4445            vec![
4446                MockStreamEvent::text("should not be requested"),
4447                MockStreamEvent::final_response_with_total_tokens(6),
4448            ],
4449        ]);
4450        let recorded = model.clone();
4451        let agent = AgentBuilder::new(model)
4452            .tool(MockAddTool)
4453            .tool_choice(ToolChoice::None)
4454            .build();
4455
4456        let mut stream = agent
4457            .stream_prompt("do not use tools")
4458            .with_hook(PanicOnUnknownToolHook)
4459            .multi_turn(3)
4460            .await;
4461        let mut saw_tool_call = false;
4462        let mut error = None;
4463
4464        while let Some(item) = stream.next().await {
4465            match item {
4466                Ok(MultiTurnStreamItem::StreamAssistantItem(
4467                    StreamedAssistantContent::ToolCall { .. },
4468                )) => {
4469                    saw_tool_call = true;
4470                }
4471                Ok(_) => {}
4472                Err(err) => {
4473                    error = Some(err);
4474                    break;
4475                }
4476            }
4477        }
4478
4479        assert!(!saw_tool_call);
4480        let error = error.expect("ToolChoice::None should reject returned tool calls");
4481        match error {
4482            StreamingError::Prompt(err) => match *err {
4483                PromptError::UnknownToolCall {
4484                    tool_name,
4485                    available_tools,
4486                    allowed_tools,
4487                    chat_history,
4488                } => {
4489                    assert_eq!(tool_name, "add");
4490                    assert_eq!(available_tools, vec!["add".to_string()]);
4491                    assert!(allowed_tools.is_empty());
4492                    assert!(history_contains_tool_call(&chat_history, "add"));
4493                }
4494                other => panic!("expected UnknownToolCall, got {other:?}"),
4495            },
4496            other => panic!("expected prompt streaming error, got {other:?}"),
4497        }
4498        assert_eq!(recorded.request_count(), 1);
4499    }
4500
4501    #[tokio::test]
4502    async fn tool_choice_none_rejects_streaming_tool_call_name_delta_before_hook_or_emit() {
4503        let model = MockCompletionModel::from_stream_turns([
4504            vec![
4505                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4506                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4507                MockStreamEvent::final_response_with_total_tokens(4),
4508            ],
4509            vec![
4510                MockStreamEvent::text("should not be requested"),
4511                MockStreamEvent::final_response_with_total_tokens(6),
4512            ],
4513        ]);
4514        let recorded = model.clone();
4515        let agent = AgentBuilder::new(model)
4516            .tool(MockAddTool)
4517            .tool_choice(ToolChoice::None)
4518            .build();
4519
4520        let mut stream = agent
4521            .stream_prompt("do not use tools")
4522            .with_hook(PanicOnUnknownToolHook)
4523            .multi_turn(3)
4524            .await;
4525        let mut saw_delta = false;
4526        let mut error = None;
4527
4528        while let Some(item) = stream.next().await {
4529            match item {
4530                Ok(MultiTurnStreamItem::StreamAssistantItem(
4531                    StreamedAssistantContent::ToolCallDelta { .. },
4532                )) => {
4533                    saw_delta = true;
4534                }
4535                Ok(_) => {}
4536                Err(err) => {
4537                    error = Some(err);
4538                    break;
4539                }
4540            }
4541        }
4542
4543        assert!(!saw_delta);
4544        let error = error.expect("ToolChoice::None should reject returned tool-call deltas");
4545        match error {
4546            StreamingError::Prompt(err) => match *err {
4547                PromptError::UnknownToolCall {
4548                    tool_name,
4549                    available_tools,
4550                    allowed_tools,
4551                    chat_history,
4552                } => {
4553                    assert_eq!(tool_name, "add");
4554                    assert_eq!(available_tools, vec!["add".to_string()]);
4555                    assert!(allowed_tools.is_empty());
4556                    assert!(history_contains_tool_call(&chat_history, "add"));
4557                }
4558                other => panic!("expected UnknownToolCall, got {other:?}"),
4559            },
4560            other => panic!("expected prompt streaming error, got {other:?}"),
4561        }
4562        assert_eq!(recorded.request_count(), 1);
4563    }
4564
4565    #[tokio::test]
4566    async fn unknown_tool_call_name_delta_fails_before_streaming_delta_hook_or_emit() {
4567        let model = MockCompletionModel::from_stream_turns([
4568            vec![
4569                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4570                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4571                MockStreamEvent::final_response_with_total_tokens(4),
4572            ],
4573            vec![
4574                MockStreamEvent::text("should not be requested"),
4575                MockStreamEvent::final_response_with_total_tokens(6),
4576            ],
4577        ]);
4578        let recorded = model.clone();
4579        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4580
4581        let mut stream = agent
4582            .stream_prompt("stream a bad tool call")
4583            .with_hook(PanicOnUnknownToolHook)
4584            .multi_turn(3)
4585            .await;
4586        let mut saw_delta = false;
4587        let mut error = None;
4588
4589        while let Some(item) = stream.next().await {
4590            match item {
4591                Ok(MultiTurnStreamItem::StreamAssistantItem(
4592                    StreamedAssistantContent::ToolCallDelta { .. },
4593                )) => {
4594                    saw_delta = true;
4595                }
4596                Ok(_) => {}
4597                Err(err) => {
4598                    error = Some(err);
4599                    break;
4600                }
4601            }
4602        }
4603
4604        assert!(!saw_delta);
4605        let error = error.expect("unknown tool-call name delta should fail");
4606        match error {
4607            StreamingError::Prompt(err) => match *err {
4608                PromptError::UnknownToolCall {
4609                    tool_name,
4610                    available_tools,
4611                    allowed_tools,
4612                    chat_history,
4613                } => {
4614                    assert_eq!(tool_name, "default_api");
4615                    assert_eq!(available_tools, vec!["add".to_string()]);
4616                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4617                    assert!(history_contains_tool_call(&chat_history, "default_api"));
4618                }
4619                other => panic!("expected UnknownToolCall, got {other:?}"),
4620            },
4621            other => panic!("expected prompt streaming error, got {other:?}"),
4622        }
4623        assert_eq!(recorded.request_count(), 1);
4624    }
4625
4626    #[tokio::test]
4627    async fn tool_call_args_delta_before_unknown_name_fails_before_hook_or_emit() {
4628        let model = MockCompletionModel::from_stream_turns([
4629            vec![
4630                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4631                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4632                MockStreamEvent::final_response_with_total_tokens(4),
4633            ],
4634            vec![
4635                MockStreamEvent::text("should not be requested"),
4636                MockStreamEvent::final_response_with_total_tokens(6),
4637            ],
4638        ]);
4639        let recorded = model.clone();
4640        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4641
4642        let mut stream = agent
4643            .stream_prompt("stream a bad tool call")
4644            .with_hook(PanicOnUnknownToolHook)
4645            .multi_turn(3)
4646            .await;
4647        let mut saw_delta = false;
4648        let mut error = None;
4649
4650        while let Some(item) = stream.next().await {
4651            match item {
4652                Ok(MultiTurnStreamItem::StreamAssistantItem(
4653                    StreamedAssistantContent::ToolCallDelta { .. },
4654                )) => {
4655                    saw_delta = true;
4656                }
4657                Ok(_) => {}
4658                Err(err) => {
4659                    error = Some(err);
4660                    break;
4661                }
4662            }
4663        }
4664
4665        assert!(!saw_delta);
4666        let error = error.expect("unknown tool-call name should reject buffered args");
4667        match error {
4668            StreamingError::Prompt(err) => match *err {
4669                PromptError::UnknownToolCall {
4670                    tool_name,
4671                    available_tools,
4672                    allowed_tools,
4673                    chat_history,
4674                } => {
4675                    assert_eq!(tool_name, "default_api");
4676                    assert_eq!(available_tools, vec!["add".to_string()]);
4677                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4678                    assert!(history_contains_tool_call(&chat_history, "default_api"));
4679                }
4680                other => panic!("expected UnknownToolCall, got {other:?}"),
4681            },
4682            other => panic!("expected prompt streaming error, got {other:?}"),
4683        }
4684        assert_eq!(recorded.request_count(), 1);
4685    }
4686
4687    #[tokio::test]
4688    async fn tool_call_args_delta_before_valid_name_buffers_then_emits_in_safe_order() {
4689        let model = MockCompletionModel::from_stream_turns([[
4690            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4691            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4692            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4693            MockStreamEvent::final_response_with_total_tokens(3),
4694        ]]);
4695        let hook = RecordingToolCallDeltaHook::default();
4696        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4697
4698        let mut stream = agent
4699            .stream_prompt("stream a tool call")
4700            .with_hook(hook.clone())
4701            .await;
4702        let mut stream_deltas = Vec::new();
4703
4704        while let Some(item) = stream.next().await {
4705            match item {
4706                Ok(MultiTurnStreamItem::StreamAssistantItem(
4707                    StreamedAssistantContent::ToolCallDelta {
4708                        id,
4709                        internal_call_id,
4710                        content,
4711                    },
4712                )) => {
4713                    stream_deltas.push((id, internal_call_id, content));
4714                }
4715                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4716                Ok(_) => {}
4717                Err(err) => panic!("unexpected streaming error: {err:?}"),
4718            }
4719        }
4720
4721        assert_eq!(
4722            hook.observed(),
4723            vec![
4724                (
4725                    "tool_1".to_string(),
4726                    "internal_1".to_string(),
4727                    Some("add".to_string()),
4728                    String::new()
4729                ),
4730                (
4731                    "tool_1".to_string(),
4732                    "internal_1".to_string(),
4733                    None,
4734                    "{\"x\":".to_string()
4735                ),
4736                (
4737                    "tool_1".to_string(),
4738                    "internal_1".to_string(),
4739                    None,
4740                    "1}".to_string()
4741                ),
4742            ]
4743        );
4744        assert_eq!(
4745            stream_deltas,
4746            vec![
4747                (
4748                    "tool_1".to_string(),
4749                    "internal_1".to_string(),
4750                    ToolCallDeltaContent::Name("add".to_string())
4751                ),
4752                (
4753                    "tool_1".to_string(),
4754                    "internal_1".to_string(),
4755                    ToolCallDeltaContent::Delta("{\"x\":".to_string())
4756                ),
4757                (
4758                    "tool_1".to_string(),
4759                    "internal_1".to_string(),
4760                    ToolCallDeltaContent::Delta("1}".to_string())
4761                ),
4762            ]
4763        );
4764    }
4765
4766    #[tokio::test]
4767    async fn tool_call_args_delta_without_name_errors_at_stream_end() {
4768        let model = MockCompletionModel::from_stream_turns([
4769            vec![
4770                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4771                MockStreamEvent::final_response_with_total_tokens(4),
4772            ],
4773            vec![
4774                MockStreamEvent::text("should not be requested"),
4775                MockStreamEvent::final_response_with_total_tokens(6),
4776            ],
4777        ]);
4778        let recorded = model.clone();
4779        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4780
4781        let mut stream = agent
4782            .stream_prompt("stream an incomplete tool call")
4783            .with_hook(PanicOnUnknownToolHook)
4784            .multi_turn(3)
4785            .await;
4786        let mut saw_delta = false;
4787        let mut saw_completion_call = false;
4788        let mut saw_final_response = false;
4789        let mut error = None;
4790
4791        while let Some(item) = stream.next().await {
4792            match item {
4793                Ok(MultiTurnStreamItem::StreamAssistantItem(
4794                    StreamedAssistantContent::ToolCallDelta { .. },
4795                )) => {
4796                    saw_delta = true;
4797                }
4798                Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4799                    saw_completion_call = true;
4800                }
4801                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4802                    saw_final_response = true;
4803                }
4804                Ok(_) => {}
4805                Err(err) => {
4806                    error = Some(err);
4807                    break;
4808                }
4809            }
4810        }
4811
4812        assert!(!saw_delta);
4813        assert!(!saw_completion_call);
4814        assert!(!saw_final_response);
4815        let error = error.expect("unterminated tool-call args delta should fail");
4816        match error {
4817            StreamingError::Completion(CompletionError::ResponseError(message)) => {
4818                assert!(
4819                    message.contains("streamed tool call arguments"),
4820                    "{message}"
4821                );
4822                assert!(message.contains("tool_1"), "{message}");
4823                assert!(message.contains("internal_1"), "{message}");
4824            }
4825            other => panic!("expected completion response error, got {other:?}"),
4826        }
4827        assert_eq!(recorded.request_count(), 1);
4828    }
4829
4830    #[tokio::test]
4831    async fn tool_choice_none_buffers_args_then_rejects_name_without_emit() {
4832        let model = MockCompletionModel::from_stream_turns([
4833            vec![
4834                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4835                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4836                MockStreamEvent::final_response_with_total_tokens(4),
4837            ],
4838            vec![
4839                MockStreamEvent::text("should not be requested"),
4840                MockStreamEvent::final_response_with_total_tokens(6),
4841            ],
4842        ]);
4843        let recorded = model.clone();
4844        let agent = AgentBuilder::new(model)
4845            .tool(MockAddTool)
4846            .tool_choice(ToolChoice::None)
4847            .build();
4848
4849        let mut stream = agent
4850            .stream_prompt("do not use tools")
4851            .with_hook(PanicOnUnknownToolHook)
4852            .multi_turn(3)
4853            .await;
4854        let mut saw_delta = false;
4855        let mut error = None;
4856
4857        while let Some(item) = stream.next().await {
4858            match item {
4859                Ok(MultiTurnStreamItem::StreamAssistantItem(
4860                    StreamedAssistantContent::ToolCallDelta { .. },
4861                )) => {
4862                    saw_delta = true;
4863                }
4864                Ok(_) => {}
4865                Err(err) => {
4866                    error = Some(err);
4867                    break;
4868                }
4869            }
4870        }
4871
4872        assert!(!saw_delta);
4873        let error = error.expect("ToolChoice::None should reject buffered tool-call deltas");
4874        match error {
4875            StreamingError::Prompt(err) => match *err {
4876                PromptError::UnknownToolCall {
4877                    tool_name,
4878                    available_tools,
4879                    allowed_tools,
4880                    chat_history,
4881                } => {
4882                    assert_eq!(tool_name, "add");
4883                    assert_eq!(available_tools, vec!["add".to_string()]);
4884                    assert!(allowed_tools.is_empty());
4885                    assert!(history_contains_tool_call(&chat_history, "add"));
4886                }
4887                other => panic!("expected UnknownToolCall, got {other:?}"),
4888            },
4889            other => panic!("expected prompt streaming error, got {other:?}"),
4890        }
4891        assert_eq!(recorded.request_count(), 1);
4892    }
4893
4894    #[tokio::test]
4895    async fn stream_prompt_emits_tool_call_deltas_without_hook() {
4896        let model = MockCompletionModel::from_stream_turns([[
4897            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4898            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4899            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4900            MockStreamEvent::final_response_with_total_tokens(3),
4901        ]]);
4902        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4903
4904        let mut stream = agent.stream_prompt("stream a tool call").await;
4905        let mut deltas = Vec::new();
4906
4907        while let Some(item) = stream.next().await {
4908            match item {
4909                Ok(MultiTurnStreamItem::StreamAssistantItem(
4910                    StreamedAssistantContent::ToolCallDelta {
4911                        id,
4912                        internal_call_id,
4913                        content,
4914                    },
4915                )) => {
4916                    deltas.push((id, internal_call_id, content));
4917                }
4918                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4919                Ok(_) => {}
4920                Err(err) => panic!("unexpected streaming error: {err:?}"),
4921            }
4922        }
4923
4924        assert_eq!(
4925            deltas,
4926            vec![
4927                (
4928                    "tool_1".to_string(),
4929                    "internal_1".to_string(),
4930                    ToolCallDeltaContent::Name("add".to_string())
4931                ),
4932                (
4933                    "tool_1".to_string(),
4934                    "internal_1".to_string(),
4935                    ToolCallDeltaContent::Delta("{\"x\":".to_string())
4936                ),
4937                (
4938                    "tool_1".to_string(),
4939                    "internal_1".to_string(),
4940                    ToolCallDeltaContent::Delta("1}".to_string())
4941                ),
4942            ]
4943        );
4944    }
4945
4946    #[tokio::test]
4947    async fn stream_prompt_emits_tool_call_deltas_after_hook_continue() {
4948        let model = MockCompletionModel::from_stream_turns([[
4949            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4950            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4951            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4952            MockStreamEvent::final_response_with_total_tokens(3),
4953        ]]);
4954        let hook = RecordingToolCallDeltaHook::default();
4955        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4956
4957        let mut stream = agent
4958            .stream_prompt("stream a tool call")
4959            .with_hook(hook.clone())
4960            .await;
4961        let mut stream_deltas = Vec::new();
4962
4963        while let Some(item) = stream.next().await {
4964            match item {
4965                Ok(MultiTurnStreamItem::StreamAssistantItem(
4966                    StreamedAssistantContent::ToolCallDelta {
4967                        id,
4968                        internal_call_id,
4969                        content,
4970                    },
4971                )) => {
4972                    stream_deltas.push((id, internal_call_id, content));
4973                }
4974                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4975                Ok(_) => {}
4976                Err(err) => panic!("unexpected streaming error: {err:?}"),
4977            }
4978        }
4979
4980        assert_eq!(
4981            hook.observed(),
4982            vec![
4983                (
4984                    "tool_1".to_string(),
4985                    "internal_1".to_string(),
4986                    Some("add".to_string()),
4987                    String::new()
4988                ),
4989                (
4990                    "tool_1".to_string(),
4991                    "internal_1".to_string(),
4992                    None,
4993                    "{\"x\":".to_string()
4994                ),
4995                (
4996                    "tool_1".to_string(),
4997                    "internal_1".to_string(),
4998                    None,
4999                    "1}".to_string()
5000                ),
5001            ]
5002        );
5003        assert_eq!(
5004            stream_deltas,
5005            vec![
5006                (
5007                    "tool_1".to_string(),
5008                    "internal_1".to_string(),
5009                    ToolCallDeltaContent::Name("add".to_string())
5010                ),
5011                (
5012                    "tool_1".to_string(),
5013                    "internal_1".to_string(),
5014                    ToolCallDeltaContent::Delta("{\"x\":".to_string())
5015                ),
5016                (
5017                    "tool_1".to_string(),
5018                    "internal_1".to_string(),
5019                    ToolCallDeltaContent::Delta("1}".to_string())
5020                ),
5021            ]
5022        );
5023    }
5024
5025    #[tokio::test]
5026    async fn stream_prompt_tool_call_deltas_hook_termination_prevents_delta_emit() {
5027        let model = MockCompletionModel::from_stream_turns([[
5028            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
5029            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
5030            MockStreamEvent::final_response_with_total_tokens(3),
5031        ]]);
5032        let hook = TerminatingToolCallDeltaHook::default();
5033        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5034
5035        let mut stream = agent
5036            .stream_prompt("stream a tool call")
5037            .with_hook(hook.clone())
5038            .await;
5039        let mut saw_delta = false;
5040        let mut saw_final_response = false;
5041        let mut error_message = None;
5042
5043        while let Some(item) = stream.next().await {
5044            match item {
5045                Ok(MultiTurnStreamItem::StreamAssistantItem(
5046                    StreamedAssistantContent::ToolCallDelta { .. },
5047                )) => {
5048                    saw_delta = true;
5049                }
5050                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
5051                    saw_final_response = true;
5052                }
5053                Ok(_) => {}
5054                Err(err) => {
5055                    error_message = Some(err.to_string());
5056                    break;
5057                }
5058            }
5059        }
5060
5061        assert_eq!(
5062            hook.observed(),
5063            vec![(
5064                "tool_1".to_string(),
5065                "internal_1".to_string(),
5066                Some("add".to_string()),
5067                String::new()
5068            )]
5069        );
5070        assert!(!saw_delta);
5071        assert!(!saw_final_response);
5072        assert!(
5073            error_message
5074                .as_deref()
5075                .is_some_and(|message| message.contains("PromptCancelled: stop on tool call delta")),
5076            "expected hook termination error, got {error_message:?}"
5077        );
5078    }
5079
5080    #[tokio::test]
5081    async fn stream_prompt_exposes_completion_calls() {
5082        let first_call_usage = usage(10, 2);
5083        let second_call_usage = usage(25, 5);
5084        let model = MockCompletionModel::from_stream_turns([
5085            vec![
5086                MockStreamEvent::tool_call(
5087                    "tool_call_1",
5088                    "add",
5089                    serde_json::json!({"x": 1, "y": 2}),
5090                )
5091                .with_call_id("call_1"),
5092                MockStreamEvent::final_response(first_call_usage),
5093            ],
5094            vec![
5095                MockStreamEvent::text("done"),
5096                MockStreamEvent::final_response(second_call_usage),
5097            ],
5098        ]);
5099        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5100        let empty_history: &[Message] = &[];
5101
5102        let mut stream = agent
5103            .stream_prompt("do tool work")
5104            .with_history(empty_history)
5105            .multi_turn(3)
5106            .await;
5107        let mut completion_calls_events = Vec::new();
5108        let mut final_response = None;
5109
5110        while let Some(item) = stream.next().await {
5111            match item {
5112                Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
5113                    completion_calls_events.push(call_usage);
5114                }
5115                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
5116                    final_response = Some(response);
5117                    break;
5118                }
5119                Ok(_) => {}
5120                Err(err) => panic!("unexpected streaming error: {err:?}"),
5121            }
5122        }
5123
5124        assert_eq!(
5125            completion_calls_events,
5126            vec![
5127                CompletionCall::new(0, Some(first_call_usage)),
5128                CompletionCall::new(1, Some(second_call_usage))
5129            ]
5130        );
5131
5132        let final_response = final_response.expect("expected final response");
5133        assert_eq!(
5134            final_response.usage(),
5135            Usage {
5136                input_tokens: 35,
5137                output_tokens: 7,
5138                total_tokens: 42,
5139                cached_input_tokens: 0,
5140                cache_creation_input_tokens: 0,
5141                tool_use_prompt_tokens: 0,
5142                reasoning_tokens: 0,
5143            }
5144        );
5145        assert_eq!(
5146            final_response.completion_calls(),
5147            &[
5148                CompletionCall::new(0, Some(first_call_usage)),
5149                CompletionCall::new(1, Some(second_call_usage))
5150            ]
5151        );
5152    }
5153
5154    #[tokio::test(flavor = "current_thread")]
5155    async fn stream_prompt_records_single_call_usage_on_chat_span_under_outer_span() {
5156        let call_usage = usage(10, 2);
5157        let model = MockCompletionModel::from_stream_turns([[
5158            MockStreamEvent::text("done"),
5159            MockStreamEvent::final_response(call_usage),
5160        ]]);
5161        let agent = AgentBuilder::new(model).build();
5162
5163        assert_stream_usage_recorded_on_chat_spans(agent, "say done", 1, &[call_usage]).await;
5164    }
5165
5166    #[tokio::test(flavor = "current_thread")]
5167    async fn stream_prompt_records_multi_turn_usage_on_chat_spans_under_outer_span() {
5168        let first_call_usage = usage(10, 2);
5169        let second_call_usage = usage(25, 5);
5170        let model = MockCompletionModel::from_stream_turns([
5171            vec![
5172                MockStreamEvent::tool_call(
5173                    "tool_call_1",
5174                    "add",
5175                    serde_json::json!({"x": 1, "y": 2}),
5176                )
5177                .with_call_id("call_1"),
5178                MockStreamEvent::final_response(first_call_usage),
5179            ],
5180            vec![
5181                MockStreamEvent::text("done"),
5182                MockStreamEvent::final_response(second_call_usage),
5183            ],
5184        ]);
5185        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5186
5187        assert_stream_usage_recorded_on_chat_spans(
5188            agent,
5189            "do tool work",
5190            3,
5191            &[first_call_usage, second_call_usage],
5192        )
5193        .await;
5194    }
5195
5196    #[tokio::test]
5197    async fn stream_prompt_emits_completion_call_before_finish_hook_termination() {
5198        let call_usage = usage(10, 2);
5199        let model = MockCompletionModel::from_stream_turns([[
5200            MockStreamEvent::text("done"),
5201            MockStreamEvent::final_response(call_usage),
5202        ]]);
5203        let agent = AgentBuilder::new(model).build();
5204
5205        let mut stream = agent
5206            .stream_prompt("say done")
5207            .with_hook(TerminateOnStreamFinish)
5208            .await;
5209        let mut completion_calls = Vec::new();
5210        let mut saw_error = false;
5211
5212        while let Some(item) = stream.next().await {
5213            match item {
5214                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
5215                    completion_calls.push(completion_call);
5216                }
5217                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
5218                    panic!("unexpected final response after hook termination: {response:?}");
5219                }
5220                Ok(_) => {}
5221                Err(_) => {
5222                    saw_error = true;
5223                    break;
5224                }
5225            }
5226        }
5227
5228        assert_eq!(
5229            completion_calls,
5230            vec![CompletionCall::new(0, Some(call_usage))]
5231        );
5232        assert!(saw_error);
5233    }
5234
5235    #[tokio::test]
5236    async fn stream_prompt_completion_calls_records_unreported_usage() {
5237        let second_call_usage = usage(25, 5);
5238        let model = MockCompletionModel::from_stream_turns([
5239            vec![
5240                MockStreamEvent::tool_call(
5241                    "tool_call_1",
5242                    "add",
5243                    serde_json::json!({"x": 1, "y": 2}),
5244                )
5245                .with_call_id("call_1"),
5246            ],
5247            vec![
5248                MockStreamEvent::text("done"),
5249                MockStreamEvent::final_response(second_call_usage),
5250            ],
5251        ]);
5252        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5253        let empty_history: &[Message] = &[];
5254
5255        let mut stream = agent
5256            .stream_prompt("do tool work")
5257            .with_history(empty_history)
5258            .multi_turn(3)
5259            .await;
5260        let mut completion_calls_events = Vec::new();
5261        let mut final_response = None;
5262
5263        while let Some(item) = stream.next().await {
5264            match item {
5265                Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
5266                    completion_calls_events.push(call_usage);
5267                }
5268                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
5269                    final_response = Some(response);
5270                    break;
5271                }
5272                Ok(_) => {}
5273                Err(err) => panic!("unexpected streaming error: {err:?}"),
5274            }
5275        }
5276
5277        let expected_usage = vec![
5278            CompletionCall::new(0, None),
5279            CompletionCall::new(1, Some(second_call_usage)),
5280        ];
5281        assert_eq!(completion_calls_events, expected_usage);
5282
5283        let final_response = final_response.expect("expected final response");
5284        assert_eq!(final_response.completion_calls(), expected_usage.as_slice());
5285    }
5286
5287    #[tokio::test]
5288    async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
5289        let agent = AgentBuilder::new(streaming_text_then_final_model()).build();
5290
5291        let mut stream = agent.stream_prompt("say hello").await;
5292        let mut streamed_text = String::new();
5293        let mut final_response_text = None;
5294
5295        while let Some(item) = stream.next().await {
5296            match item {
5297                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5298                    text,
5299                ))) => streamed_text.push_str(&text.text),
5300                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5301                    final_response_text = Some(res.response().to_owned());
5302                    break;
5303                }
5304                Ok(_) => {}
5305                Err(err) => panic!("unexpected streaming error: {err:?}"),
5306            }
5307        }
5308
5309        assert_eq!(streamed_text, "hello world");
5310        assert_eq!(final_response_text.as_deref(), Some("hello world"));
5311    }
5312
5313    #[tokio::test]
5314    async fn final_response_preserves_structured_text_metadata() {
5315        let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
5316
5317        let mut stream = agent.stream_prompt("answer with citations").await;
5318        let mut final_response = None;
5319
5320        while let Some(item) = stream.next().await {
5321            match item {
5322                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5323                    final_response = Some(res);
5324                    break;
5325                }
5326                Ok(_) => {}
5327                Err(err) => panic!("unexpected streaming error: {err:?}"),
5328            }
5329        }
5330
5331        let final_response = final_response.expect("expected final response");
5332        assert_eq!(final_response.response(), "cited answer");
5333        let metadata = text_metadata(final_response.content())
5334            .expect("expected text metadata in final content");
5335        assert_eq!(
5336            metadata["citations"][0]["encrypted_index"],
5337            "encrypted-reference"
5338        );
5339    }
5340
5341    #[tokio::test]
5342    async fn final_response_history_preserves_structured_text_metadata() {
5343        let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
5344
5345        let empty_history: &[Message] = &[];
5346        let mut stream = agent
5347            .stream_prompt("answer with citations")
5348            .with_history(empty_history)
5349            .await;
5350        let mut final_response = None;
5351
5352        while let Some(item) = stream.next().await {
5353            match item {
5354                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5355                    final_response = Some(res);
5356                    break;
5357                }
5358                Ok(_) => {}
5359                Err(err) => panic!("unexpected streaming error: {err:?}"),
5360            }
5361        }
5362
5363        let final_response = final_response.expect("expected final response");
5364        let history = final_response
5365            .history()
5366            .expect("with_history should include final history");
5367        let assistant_content = history
5368            .iter()
5369            .find_map(|message| match message {
5370                Message::Assistant { content, .. } => Some(content),
5371                _ => None,
5372            })
5373            .expect("expected assistant message in history");
5374        let metadata =
5375            text_metadata(assistant_content).expect("expected text metadata in assistant history");
5376        assert_eq!(
5377            metadata["citations"][0]["encrypted_index"],
5378            "encrypted-reference"
5379        );
5380    }
5381
5382    #[tokio::test]
5383    async fn tool_follow_up_history_preserves_structured_text_metadata() {
5384        let model = streaming_cited_text_then_tool_model();
5385        let recorded = model.clone();
5386        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5387        let empty_history: &[Message] = &[];
5388
5389        let mut stream = agent
5390            .stream_prompt("use a tool with citations")
5391            .with_history(empty_history)
5392            .multi_turn(3)
5393            .await;
5394
5395        while let Some(item) = stream.next().await {
5396            match item {
5397                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
5398                Ok(_) => {}
5399                Err(err) => panic!("unexpected streaming error: {err:?}"),
5400            }
5401        }
5402
5403        let requests = recorded.requests();
5404        assert_eq!(requests.len(), 2);
5405        let follow_up_history = requests[1].chat_history.iter().collect::<Vec<_>>();
5406        let assistant_content = follow_up_history
5407            .iter()
5408            .find_map(|message| match message {
5409                Message::Assistant { content, .. } => Some(content),
5410                _ => None,
5411            })
5412            .expect("expected assistant message in follow-up history");
5413        let metadata = text_metadata(assistant_content)
5414            .expect("expected citation metadata in follow-up assistant history");
5415        assert_eq!(
5416            metadata["citations"][0]["encrypted_index"],
5417            "encrypted-reference"
5418        );
5419    }
5420
5421    #[tokio::test]
5422    async fn final_response_can_remain_empty_for_truly_textless_turns() {
5423        let agent = AgentBuilder::new(streaming_final_only_model()).build();
5424
5425        let mut stream = agent.stream_prompt("say nothing").await;
5426        let mut streamed_text = String::new();
5427        let mut final_response_text = None;
5428
5429        while let Some(item) = stream.next().await {
5430            match item {
5431                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5432                    text,
5433                ))) => streamed_text.push_str(&text.text),
5434                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5435                    final_response_text = Some(res.response().to_owned());
5436                    break;
5437                }
5438                Ok(_) => {}
5439                Err(err) => panic!("unexpected streaming error: {err:?}"),
5440            }
5441        }
5442
5443        assert!(streamed_text.is_empty());
5444        assert_eq!(final_response_text.as_deref(), Some(""));
5445    }
5446
5447    /// Background task that logs periodically to detect span leakage.
5448    /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
5449    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
5450        let mut interval = tokio::time::interval(Duration::from_millis(50));
5451        let mut count = 0u32;
5452
5453        while !stop.load(Ordering::Relaxed) {
5454            interval.tick().await;
5455            count += 1;
5456
5457            tracing::event!(
5458                target: "background_logger",
5459                tracing::Level::INFO,
5460                count = count,
5461                "Background tick"
5462            );
5463
5464            // Check if we're inside an unexpected span
5465            let current = tracing::Span::current();
5466            if !current.is_disabled() && !current.is_none() {
5467                leak_count.fetch_add(1, Ordering::Relaxed);
5468            }
5469        }
5470
5471        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
5472    }
5473
5474    /// Test that span context doesn't leak to concurrent tasks during streaming.
5475    ///
5476    /// This test verifies that using `.instrument()` instead of `span.enter()` in
5477    /// async_stream prevents thread-local span context from leaking to other tasks.
5478    ///
5479    /// Uses single-threaded runtime to force all tasks onto the same thread,
5480    /// making the span leak deterministic (it only occurs when tasks share a thread).
5481    #[tokio::test(flavor = "current_thread")]
5482    #[ignore = "This requires an API key"]
5483    async fn test_span_context_isolation() -> anyhow::Result<()> {
5484        let stop = Arc::new(AtomicBool::new(false));
5485        let leak_count = Arc::new(AtomicU32::new(0));
5486
5487        // Start background logger
5488        let bg_stop = stop.clone();
5489        let bg_leak = leak_count.clone();
5490        let bg_handle = tokio::spawn(async move {
5491            background_logger(bg_stop, bg_leak).await;
5492        });
5493
5494        // Small delay to let background logger start
5495        tokio::time::sleep(Duration::from_millis(100)).await;
5496
5497        // Make streaming request WITHOUT an outer span so rig creates its own invoke_agent span
5498        // (rig reuses current span if one exists, so we need to ensure there's no current span)
5499        let client = anthropic::Client::from_env()?;
5500        let agent = client
5501            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5502            .preamble("You are a helpful assistant.")
5503            .temperature(0.1)
5504            .max_tokens(100)
5505            .build();
5506
5507        let mut stream = agent
5508            .stream_prompt("Say 'hello world' and nothing else.")
5509            .await;
5510
5511        let mut full_content = String::new();
5512        while let Some(item) = stream.next().await {
5513            match item {
5514                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5515                    text,
5516                ))) => {
5517                    full_content.push_str(&text.text);
5518                }
5519                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
5520                    break;
5521                }
5522                Err(e) => {
5523                    tracing::warn!("Error: {:?}", e);
5524                    break;
5525                }
5526                _ => {}
5527            }
5528        }
5529
5530        tracing::info!("Got response: {:?}", full_content);
5531
5532        // Stop background logger
5533        stop.store(true, Ordering::Relaxed);
5534        bg_handle.await?;
5535
5536        let leaks = leak_count.load(Ordering::Relaxed);
5537        anyhow::ensure!(
5538            leaks == 0,
5539            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
5540             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
5541        );
5542
5543        Ok(())
5544    }
5545
5546    /// Test that FinalResponse contains the updated chat history when with_history is used.
5547    ///
5548    /// This verifies that:
5549    /// 1. FinalResponse.history() returns Some when with_history was called
5550    /// 2. The history contains both the user prompt and assistant response
5551    #[tokio::test]
5552    #[ignore = "This requires an API key"]
5553    async fn test_chat_history_in_final_response() -> anyhow::Result<()> {
5554        use crate::message::Message;
5555
5556        let client = anthropic::Client::from_env()?;
5557        let agent = client
5558            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5559            .preamble("You are a helpful assistant. Keep responses brief.")
5560            .temperature(0.1)
5561            .max_tokens(50)
5562            .build();
5563
5564        // Send streaming request with history
5565        let empty_history: &[Message] = &[];
5566        let mut stream = agent
5567            .stream_prompt("Say 'hello' and nothing else.")
5568            .with_history(empty_history)
5569            .await;
5570
5571        // Consume the stream and collect FinalResponse
5572        let mut response_text = String::new();
5573        let mut final_history = None;
5574        while let Some(item) = stream.next().await {
5575            match item {
5576                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5577                    text,
5578                ))) => {
5579                    response_text.push_str(&text.text);
5580                }
5581                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5582                    final_history = res.history().map(|h| h.to_vec());
5583                    break;
5584                }
5585                Err(e) => {
5586                    return Err(e.into());
5587                }
5588                _ => {}
5589            }
5590        }
5591
5592        let history = final_history
5593            .ok_or_else(|| anyhow::anyhow!("final response should include history"))?;
5594
5595        // Should contain at least the user message
5596        anyhow::ensure!(
5597            history.iter().any(|m| matches!(m, Message::User { .. })),
5598            "History should contain the user message"
5599        );
5600
5601        // Should contain the assistant response
5602        anyhow::ensure!(
5603            history
5604                .iter()
5605                .any(|m| matches!(m, Message::Assistant { .. })),
5606            "History should contain the assistant response"
5607        );
5608
5609        tracing::info!(
5610            "History after streaming: {} messages, response: {:?}",
5611            history.len(),
5612            response_text
5613        );
5614
5615        Ok(())
5616    }
5617
5618    #[tokio::test]
5619    async fn streaming_appends_to_memory_after_final_response() {
5620        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5621
5622        let memory = InMemoryConversationMemory::new();
5623        let agent = AgentBuilder::new(streaming_text_then_final_model())
5624            .memory(memory.clone())
5625            .build();
5626
5627        let mut stream = agent
5628            .stream_prompt("hi there")
5629            .conversation("stream-thread")
5630            .await;
5631
5632        let mut history_in_final = None;
5633        while let Some(item) = stream.next().await {
5634            match item {
5635                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5636                    history_in_final = res.history().map(|h| h.to_vec());
5637                    break;
5638                }
5639                Ok(_) => {}
5640                Err(err) => panic!("unexpected streaming error: {err:?}"),
5641            }
5642        }
5643
5644        let final_history = history_in_final
5645            .expect("FinalResponse.history should be populated when memory is configured");
5646        assert_eq!(
5647            final_history.len(),
5648            2,
5649            "user prompt + assistant response in final history: {final_history:?}"
5650        );
5651
5652        let stored = memory.load("stream-thread").await.unwrap();
5653        assert_eq!(stored.len(), 2, "memory should contain user + assistant");
5654    }
5655
5656    #[tokio::test]
5657    async fn streaming_reasoning_without_tools_does_not_duplicate_final_history() {
5658        let agent = AgentBuilder::new(MockCompletionModel::from_stream_turns([[
5659            MockStreamEvent::text("final answer"),
5660            MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
5661            MockStreamEvent::final_response_with_total_tokens(3),
5662        ]]))
5663        .build();
5664
5665        let mut stream = agent
5666            .stream_prompt("think before answering")
5667            .with_history(Vec::<Message>::new())
5668            .await;
5669
5670        let mut history_in_final = None;
5671        while let Some(item) = stream.next().await {
5672            match item {
5673                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5674                    history_in_final = res.history().map(|h| h.to_vec());
5675                    break;
5676                }
5677                Ok(_) => {}
5678                Err(err) => panic!("unexpected streaming error: {err:?}"),
5679            }
5680        }
5681
5682        let final_history = history_in_final
5683            .expect("FinalResponse.history should be populated when with_history is used");
5684        assert_eq!(
5685            final_history.len(),
5686            2,
5687            "user prompt + one assistant response in final history: {final_history:?}"
5688        );
5689
5690        assert!(matches!(
5691            final_history.first(),
5692            Some(Message::User { content })
5693                if matches!(
5694                    content.first(),
5695                    UserContent::Text(text) if text.text == "think before answering"
5696                )
5697        ));
5698
5699        let assistant_messages = final_history
5700            .iter()
5701            .filter_map(|message| match message {
5702                Message::Assistant { content, .. } => Some(content),
5703                _ => None,
5704            })
5705            .collect::<Vec<_>>();
5706        assert_eq!(
5707            assistant_messages.len(),
5708            1,
5709            "reasoning turn should produce exactly one assistant history message: {final_history:?}"
5710        );
5711        let assistant_content = assistant_messages
5712            .first()
5713            .expect("expected assistant history message");
5714        assert!(assistant_content.iter().any(|item| matches!(
5715            item,
5716            AssistantContent::Text(text) if text.text == "final answer"
5717        )));
5718        assert!(assistant_content.iter().any(|item| matches!(
5719            item,
5720            AssistantContent::Reasoning(reasoning)
5721                if reasoning.id.as_deref() == Some("rs_1")
5722                    && reasoning.content.iter().any(|content| matches!(
5723                        content,
5724                        ReasoningContent::Text { text, .. } if text == "reasoned step"
5725                    ))
5726        )));
5727    }
5728
5729    #[tokio::test]
5730    async fn streaming_with_history_overrides_memory() {
5731        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5732
5733        let memory = InMemoryConversationMemory::new();
5734        memory
5735            .append("t1", vec![Message::user("from-memory")])
5736            .await
5737            .unwrap();
5738
5739        let agent = AgentBuilder::new(streaming_text_then_final_model())
5740            .memory(memory.clone())
5741            .build();
5742
5743        let mut stream = agent
5744            .stream_prompt("hi")
5745            .conversation("t1")
5746            .with_history(vec![Message::user("from-caller")])
5747            .await;
5748
5749        while let Some(item) = stream.next().await {
5750            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5751                break;
5752            }
5753        }
5754
5755        let stored = memory.load("t1").await.unwrap();
5756        assert_eq!(
5757            stored.len(),
5758            1,
5759            "with_history bypasses memory; only the pre-seeded entry remains: {stored:?}"
5760        );
5761    }
5762
5763    #[tokio::test]
5764    async fn streaming_without_memory_disables_for_request() {
5765        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5766
5767        let memory = InMemoryConversationMemory::new();
5768        let agent = AgentBuilder::new(streaming_text_then_final_model())
5769            .memory(memory.clone())
5770            .conversation_id("default")
5771            .build();
5772
5773        let mut stream = agent.stream_prompt("hi").without_memory().await;
5774
5775        while let Some(item) = stream.next().await {
5776            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5777                break;
5778            }
5779        }
5780
5781        let stored = memory.load("default").await.unwrap();
5782        assert!(stored.is_empty(), "without_memory disables save");
5783    }
5784
5785    #[tokio::test]
5786    async fn streaming_load_error_yields_memory_error() {
5787        let agent = AgentBuilder::new(streaming_text_then_final_model())
5788            .memory(FailingMemory::default())
5789            .build();
5790
5791        let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5792
5793        let first = stream.next().await.expect("at least one item");
5794        match first {
5795            Err(err) => {
5796                let msg = format!("{err:?}");
5797                assert!(
5798                    msg.contains("Memory") || msg.contains("memory") || msg.contains("load boom"),
5799                    "expected memory error, got: {msg}"
5800                );
5801            }
5802            Ok(other) => panic!("expected memory error, got {other:?}"),
5803        }
5804    }
5805
5806    #[tokio::test]
5807    async fn streaming_with_filter_shapes_loaded_history() {
5808        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5809
5810        let memory = InMemoryConversationMemory::new()
5811            .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
5812        memory
5813            .append(
5814                "t1",
5815                vec![
5816                    Message::user("1"),
5817                    Message::assistant("2"),
5818                    Message::user("3"),
5819                    Message::assistant("4"),
5820                ],
5821            )
5822            .await
5823            .unwrap();
5824
5825        let model = MockCompletionModel::from_stream_turns([[
5826            MockStreamEvent::text("ok"),
5827            MockStreamEvent::final_response_with_total_tokens(1),
5828        ]]);
5829        let recorded = model.clone();
5830        let agent = AgentBuilder::new(model).memory(memory).build();
5831
5832        let mut stream = agent.stream_prompt("ping").conversation("t1").await;
5833        while let Some(item) = stream.next().await {
5834            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5835                break;
5836            }
5837        }
5838
5839        let received = recorded.requests()[0]
5840            .chat_history
5841            .iter()
5842            .cloned()
5843            .collect::<Vec<_>>();
5844        assert_eq!(
5845            received.len(),
5846            3,
5847            "window-truncated history (2) + current prompt: {received:?}"
5848        );
5849    }
5850
5851    #[tokio::test]
5852    async fn streaming_append_error_does_not_suppress_final_response() {
5853        let agent = AgentBuilder::new(streaming_text_then_final_model())
5854            .memory(AppendFailingMemory::default())
5855            .build();
5856
5857        let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5858
5859        let mut saw_final = false;
5860        while let Some(item) = stream.next().await {
5861            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5862                saw_final = true;
5863                break;
5864            }
5865        }
5866        assert!(
5867            saw_final,
5868            "FinalResponse must be yielded even when memory.append fails"
5869        );
5870    }
5871}