Skip to main content

rig_core/agent/prompt_request/
streaming.rs

1use crate::{
2    OneOrMany,
3    agent::completion::{DynamicContextStore, build_prepared_completion_request},
4    agent::prompt_request::{
5        HookAction, InvalidToolCallResolution, TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER,
6        hooks::PromptHook, resolve_invalid_tool_call, validate_tool_call_name,
7    },
8    completion::{Document, GetTokenUsage},
9    json_utils,
10    memory::ConversationMemory,
11    message::{
12        AssistantContent, ToolCall, ToolChoice, ToolFunction, ToolResult, ToolResultContent,
13        UserContent,
14    },
15    streaming::{StreamedAssistantContent, StreamedUserContent, ToolCallDeltaContent},
16    tool::server::ToolServerHandle,
17    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
18};
19use futures::{Stream, StreamExt};
20use serde::{Deserialize, Serialize};
21use std::{collections::HashMap, pin::Pin, sync::Arc};
22use tracing::info_span;
23use tracing_futures::Instrument;
24
25use super::{CompletionCall, ToolCallHookAction, reported_usage};
26use crate::{
27    agent::Agent,
28    completion::{CompletionError, CompletionModel, PromptError},
29    message::{Message, Text},
30    tool::ToolSetError,
31};
32
/// Boxed, pinned stream yielded by a streaming prompt request.
///
/// Every item is either a [`MultiTurnStreamItem`] or a [`StreamingError`].
/// This variant is compiled for every target except wasm32 with the `wasm`
/// feature enabled, and additionally requires the stream to be `Send`.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// Boxed, pinned stream yielded by a streaming prompt request.
///
/// This variant is compiled only for wasm32 with the `wasm` feature enabled;
/// it carries no `Send` bound.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
40
/// One item produced by a multi-turn agent stream.
///
/// `R` is the payload type carried by [`StreamedAssistantContent`].
/// The serde representation is internally tagged: a `type` field holds the
/// camelCase variant name.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed assistant content item.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// A streamed user content item (mostly for tool results).
    StreamUserItem(StreamedUserContent),
    /// Details for one successfully completed completion request made by this agent stream.
    ///
    /// This is emitted when a provider call finishes. Usage is the provider's
    /// final usage for that completion request when available; it is not
    /// incremental per streamed token.
    ///
    /// ```rust,ignore
    /// match item {
    ///     MultiTurnStreamItem::CompletionCall(completion_call) => {
    ///         let context_tokens = completion_call.usage.map(|usage| usage.input_tokens);
    ///     }
    ///     _ => {}
    /// }
    /// ```
    CompletionCall(CompletionCall),
    /// The final result from the stream; see [`FinalResponse`] for its accessors.
    FinalResponse(FinalResponse),
}
67
68#[derive(Deserialize, Serialize, Debug, Clone)]
69#[serde(rename_all = "camelCase")]
70pub struct FinalResponse {
71    /// Structured assistant content for the final turn.
72    content: OneOrMany<AssistantContent>,
73    /// Concatenated assistant text for the final turn.
74    /// This is empty only when the turn completed without emitting any text.
75    response: String,
76    aggregated_usage: crate::completion::Usage,
77    /// Successfully completed completion requests made by this agent stream.
78    #[serde(default, skip_serializing_if = "Vec::is_empty")]
79    completion_calls: Vec<CompletionCall>,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    history: Option<Vec<Message>>,
82}
83
84impl FinalResponse {
85    pub fn empty() -> Self {
86        Self::new(
87            OneOrMany::one(AssistantContent::text("")),
88            crate::completion::Usage::new(),
89            None,
90        )
91    }
92
93    pub fn new(
94        content: OneOrMany<AssistantContent>,
95        aggregated_usage: crate::completion::Usage,
96        history: Option<Vec<Message>>,
97    ) -> Self {
98        let response = assistant_text_from_choice(&content);
99        Self {
100            content,
101            response,
102            aggregated_usage,
103            completion_calls: Vec::new(),
104            history,
105        }
106    }
107
108    /// Returns the concatenated assistant text for the final turn.
109    pub fn response(&self) -> &str {
110        &self.response
111    }
112
113    /// Returns the structured assistant content for the final turn.
114    pub fn content(&self) -> &OneOrMany<AssistantContent> {
115        &self.content
116    }
117
118    /// Returns the structured assistant content for the final turn.
119    pub fn assistant_content(&self) -> &OneOrMany<AssistantContent> {
120        &self.content
121    }
122
123    pub fn usage(&self) -> crate::completion::Usage {
124        self.aggregated_usage
125    }
126
127    /// Returns successfully completed completion requests made by this agent stream, with usage when available.
128    ///
129    /// Each entry represents one provider completion request. When present,
130    /// usage is a whole-request provider snapshot, not incremental usage per
131    /// streamed token. Streaming providers may omit usage for some calls; those
132    /// calls have an entry with `None` usage.
133    pub fn completion_calls(&self) -> &[CompletionCall] {
134        &self.completion_calls
135    }
136
137    pub fn history(&self) -> Option<&[Message]> {
138        self.history.as_deref()
139    }
140}
141
142impl<R> MultiTurnStreamItem<R> {
143    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
144        Self::StreamAssistantItem(item)
145    }
146
147    pub fn final_response(
148        content: OneOrMany<AssistantContent>,
149        aggregated_usage: crate::completion::Usage,
150    ) -> Self {
151        Self::FinalResponse(FinalResponse::new(content, aggregated_usage, None))
152    }
153
154    pub fn final_response_with_history(
155        content: OneOrMany<AssistantContent>,
156        aggregated_usage: crate::completion::Usage,
157        history: Option<Vec<Message>>,
158    ) -> Self {
159        Self::FinalResponse(FinalResponse::new(content, aggregated_usage, history))
160    }
161
162    pub(crate) fn final_response_with_completion_calls(
163        content: OneOrMany<AssistantContent>,
164        aggregated_usage: crate::completion::Usage,
165        completion_calls: Vec<CompletionCall>,
166        history: Option<Vec<Message>>,
167    ) -> Self {
168        let mut response = FinalResponse::new(content, aggregated_usage, history);
169        response.completion_calls = completion_calls;
170        Self::FinalResponse(response)
171    }
172}
173
174fn merge_reasoning_blocks(
175    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
176    incoming: &crate::message::Reasoning,
177) {
178    let ids_match = |existing: &crate::message::Reasoning| {
179        matches!(
180            (&existing.id, &incoming.id),
181            (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
182        )
183    };
184
185    if let Some(existing) = accumulated_reasoning
186        .iter_mut()
187        .rev()
188        .find(|existing| ids_match(existing))
189    {
190        existing.content.extend(incoming.content.clone());
191    } else {
192        accumulated_reasoning.push(incoming.clone());
193    }
194}
195
196fn flush_pending_reasoning_delta(
197    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
198    pending_reasoning_delta_text: &mut String,
199    pending_reasoning_delta_id: &mut Option<String>,
200) {
201    if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
202        let mut assembled = crate::message::Reasoning::new(&*pending_reasoning_delta_text);
203        if let Some(id) = pending_reasoning_delta_id.take() {
204            assembled = assembled.with_id(id);
205        }
206        accumulated_reasoning.push(assembled);
207        pending_reasoning_delta_text.clear();
208    }
209}
210
211/// Build full history for error reporting (input + new messages).
212fn build_full_history(
213    chat_history: Option<&[Message]>,
214    new_messages: Vec<Message>,
215) -> Vec<Message> {
216    let input = chat_history.unwrap_or(&[]);
217    input.iter().cloned().chain(new_messages).collect()
218}
219
220struct ToolCallValidationHistory<'a> {
221    chat_history: Option<&'a [Message]>,
222    new_messages: &'a [Message],
223    assistant_message_id: &'a Option<String>,
224    final_turn_content: Option<&'a OneOrMany<AssistantContent>>,
225    text_delta_response: Option<&'a str>,
226    accumulated_reasoning: &'a [crate::message::Reasoning],
227    pending_reasoning_delta_text: &'a str,
228    pending_reasoning_delta_id: &'a Option<String>,
229    pending_tool_calls: &'a [(ToolCall, String)],
230    current_tool_call: Option<ToolCall>,
231}
232
233fn build_tool_call_validation_history(input: ToolCallValidationHistory<'_>) -> Vec<Message> {
234    let mut messages = input.new_messages.to_vec();
235
236    if let Some(final_turn_content) = input.final_turn_content
237        && !is_empty_assistant_choice(final_turn_content)
238    {
239        messages.push(Message::Assistant {
240            id: input.assistant_message_id.clone(),
241            content: final_turn_content.clone(),
242        });
243        return build_full_history(input.chat_history, messages);
244    }
245
246    let mut content_items = Vec::new();
247    if let Some(text) = input.text_delta_response
248        && !text.is_empty()
249    {
250        content_items.push(AssistantContent::text(text.to_string()));
251    }
252    content_items.extend(
253        input
254            .accumulated_reasoning
255            .iter()
256            .cloned()
257            .map(AssistantContent::Reasoning),
258    );
259    if input.accumulated_reasoning.is_empty() && !input.pending_reasoning_delta_text.is_empty() {
260        let mut reasoning = crate::message::Reasoning::new(input.pending_reasoning_delta_text);
261        if let Some(id) = input.pending_reasoning_delta_id.clone() {
262            reasoning = reasoning.with_id(id);
263        }
264        content_items.push(AssistantContent::Reasoning(reasoning));
265    }
266    content_items.extend(
267        input
268            .pending_tool_calls
269            .iter()
270            .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
271    );
272    if let Some(tool_call) = input.current_tool_call {
273        content_items.push(AssistantContent::ToolCall(tool_call));
274    }
275
276    if let Some(content) = OneOrMany::from_iter_optional(content_items) {
277        messages.push(Message::Assistant {
278            id: input.assistant_message_id.clone(),
279            content,
280        });
281    }
282
283    build_full_history(input.chat_history, messages)
284}
285
286async fn drain_recovered_stream_usage<R>(
287    stream: &mut crate::streaming::StreamingCompletionResponse<R>,
288    tool_call_delta_states: &HashMap<(String, String), ToolCallDeltaState>,
289    current_call_usage: &mut Option<crate::completion::Usage>,
290    aggregated_usage: &mut crate::completion::Usage,
291) -> Result<(), StreamingError>
292where
293    R: Clone + Unpin + GetTokenUsage,
294{
295    if let Some(err) = pending_tool_call_delta_error(tool_call_delta_states) {
296        return Err(err.into());
297    }
298
299    while let Some(content) = stream.next().await {
300        match content {
301            Ok(StreamedAssistantContent::Final(final_resp)) => {
302                if let Some(usage) = final_resp.token_usage() {
303                    *current_call_usage = reported_usage(usage);
304                }
305                if let Some(usage) = *current_call_usage {
306                    *aggregated_usage += usage;
307                }
308                return Ok(());
309            }
310            Ok(_) => {}
311            Err(err) => return Err(err.into()),
312        }
313    }
314
315    Ok(())
316}
317
318fn record_completion_call_if_needed(
319    completion_calls: &mut Vec<CompletionCall>,
320    completion_call_emitted: &mut bool,
321    call_index: usize,
322    current_call_usage: Option<crate::completion::Usage>,
323) -> Option<CompletionCall> {
324    if *completion_call_emitted {
325        return None;
326    }
327
328    let completion_call = CompletionCall::new(call_index, current_call_usage);
329    completion_calls.push(completion_call);
330    *completion_call_emitted = true;
331    Some(completion_call)
332}
333
334/// Combine input history with new messages for building completion requests.
335fn build_history_for_request(
336    chat_history: Option<&[Message]>,
337    new_messages: &[Message],
338) -> Vec<Message> {
339    let input = chat_history.unwrap_or(&[]);
340    input.iter().chain(new_messages.iter()).cloned().collect()
341}
342
343async fn cancelled_prompt_error(
344    chat_history: Option<&[Message]>,
345    new_messages: Vec<Message>,
346    reason: String,
347) -> StreamingError {
348    StreamingError::Prompt(
349        PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
350            .into(),
351    )
352}
353
354fn tool_result_to_user_message(
355    id: String,
356    call_id: Option<String>,
357    tool_result: String,
358) -> Message {
359    let content = ToolResultContent::from_tool_output(tool_result);
360    let user_content = match call_id {
361        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
362        None => UserContent::tool_result(id, content),
363    };
364
365    Message::User {
366        content: OneOrMany::one(user_content),
367    }
368}
369
370fn tool_result_user_content(
371    id: String,
372    call_id: Option<String>,
373    tool_result: String,
374) -> UserContent {
375    let content = ToolResultContent::from_tool_output(tool_result);
376    match call_id {
377        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
378        None => UserContent::tool_result(id, content),
379    }
380}
381
382fn invalid_streaming_tool_retry_messages(
383    assistant_message_id: &Option<String>,
384    text_delta_response: Option<&str>,
385    accumulated_reasoning: &[crate::message::Reasoning],
386    pending_tool_calls: &[(ToolCall, String)],
387    invalid_tool_call: ToolCall,
388    feedback: String,
389) -> Option<(Message, Message)> {
390    let mut assistant_content = Vec::new();
391    if let Some(text) = text_delta_response
392        && !text.is_empty()
393    {
394        assistant_content.push(AssistantContent::text(text.to_string()));
395    }
396    assistant_content.extend(
397        accumulated_reasoning
398            .iter()
399            .cloned()
400            .map(AssistantContent::Reasoning),
401    );
402    assistant_content.extend(
403        pending_tool_calls
404            .iter()
405            .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
406    );
407    assistant_content.push(AssistantContent::ToolCall(invalid_tool_call.clone()));
408
409    let assistant_content = OneOrMany::from_iter_optional(assistant_content)?;
410    let assistant_message = Message::Assistant {
411        id: assistant_message_id.clone(),
412        content: assistant_content,
413    };
414
415    let mut retry_results = pending_tool_calls
416        .iter()
417        .map(|(tool_call, _)| {
418            tool_result_user_content(
419                tool_call.id.clone(),
420                tool_call.call_id.clone(),
421                TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
422            )
423        })
424        .collect::<Vec<_>>();
425    retry_results.push(tool_result_user_content(
426        invalid_tool_call.id,
427        invalid_tool_call.call_id,
428        feedback,
429    ));
430
431    let user_message = Message::User {
432        content: OneOrMany::from_iter_optional(retry_results)?,
433    };
434
435    Some((assistant_message, user_message))
436}
437
438fn invalid_streaming_name_delta_retry_messages(
439    assistant_message_id: &Option<String>,
440    text_delta_response: Option<&str>,
441    accumulated_reasoning: &[crate::message::Reasoning],
442    pending_tool_calls: &[(ToolCall, String)],
443    invalid_tool_call: ToolCall,
444    feedback: String,
445) -> Option<(Message, Message)> {
446    let mut assistant_content = Vec::new();
447    if let Some(text) = text_delta_response
448        && !text.is_empty()
449    {
450        assistant_content.push(AssistantContent::text(text.to_string()));
451    }
452    assistant_content.extend(
453        accumulated_reasoning
454            .iter()
455            .cloned()
456            .map(AssistantContent::Reasoning),
457    );
458    assistant_content.extend(
459        pending_tool_calls
460            .iter()
461            .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())),
462    );
463    assistant_content.push(AssistantContent::ToolCall(invalid_tool_call.clone()));
464
465    let assistant_message = Message::Assistant {
466        id: assistant_message_id.clone(),
467        content: OneOrMany::from_iter_optional(assistant_content)?,
468    };
469    let mut retry_results = pending_tool_calls
470        .iter()
471        .map(|(tool_call, _)| {
472            tool_result_user_content(
473                tool_call.id.clone(),
474                tool_call.call_id.clone(),
475                TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
476            )
477        })
478        .collect::<Vec<_>>();
479    retry_results.push(tool_result_user_content(
480        invalid_tool_call.id,
481        invalid_tool_call.call_id,
482        feedback,
483    ));
484    let user_message = Message::User {
485        content: OneOrMany::from_iter_optional(retry_results)?,
486    };
487
488    Some((assistant_message, user_message))
489}
490
491fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
492    choice
493        .iter()
494        .filter_map(|content| match content {
495            AssistantContent::Text(text) => Some(text.text.as_str()),
496            _ => None,
497        })
498        .collect()
499}
500
501fn assistant_text_items_from_choice(choice: &OneOrMany<AssistantContent>) -> Vec<AssistantContent> {
502    choice
503        .iter()
504        .filter_map(|content| match content {
505            AssistantContent::Text(text) => (!text.text.is_empty()
506                || text.additional_params.is_some())
507            .then(|| AssistantContent::Text(text.clone())),
508            _ => None,
509        })
510        .collect()
511}
512
513fn is_empty_assistant_choice(choice: &OneOrMany<AssistantContent>) -> bool {
514    choice.len() == 1
515        && matches!(
516            choice.first(),
517            AssistantContent::Text(text)
518                if text.text.is_empty() && text.additional_params.is_none()
519        )
520}
521
522#[derive(Default)]
523struct ToolCallDeltaState {
524    name_validated: bool,
525    buffered_arguments: Vec<String>,
526}
527
528fn pending_tool_call_delta_error(
529    states: &HashMap<(String, String), ToolCallDeltaState>,
530) -> Option<CompletionError> {
531    states
532        .iter()
533        .find(|(_, state)| !state.name_validated && !state.buffered_arguments.is_empty())
534        .map(|((id, internal_call_id), state)| {
535            CompletionError::ResponseError(format!(
536                "streamed tool call arguments received before a validated tool name for id `{id}` and internal_call_id `{internal_call_id}` ({} buffered argument delta(s))",
537                state.buffered_arguments.len()
538            ))
539        })
540}
541
/// Errors that a streaming prompt request can yield.
///
/// Each variant wraps an underlying error type and can be created from it
/// through the generated `From` conversion.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// A completion-level failure.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// A prompt-level failure; boxed.
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool-set failure.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
551
552/// Surface [`crate::memory::ConversationMemory`] failures through the existing
553/// [`CompletionError::RequestError`] variant so adding memory support does not
554/// require a new top-level [`StreamingError`] arm.
555impl From<crate::memory::MemoryError> for StreamingError {
556    fn from(err: crate::memory::MemoryError) -> Self {
557        Self::Completion(CompletionError::RequestError(Box::new(err)))
558    }
559}
560
/// Fallback agent name used when the agent has no name of its own.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
562
/// A builder for creating streaming prompt requests with customizable options.
///
/// The `P` type parameter is the hook type; `with_hook` returns a request with
/// a different `P`, which is why the agent's data is copied into this struct.
///
/// If you expect the model to call tools repeatedly, use `.multi_turn()` to
/// raise the turn budget: it defaults to the agent's `default_max_turns`, or 0
/// when the agent sets none. With too small a budget, awaiting the request
/// (which sends the prompt) can return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history provided by the caller.
    chat_history: Option<Vec<Message>>,
    /// Maximum turns for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// Maximum number of invalid tool-call retries for this request.
    max_invalid_tool_call_retries: usize,
    /// Optional conversation memory backend cloned from the agent.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Optional conversation id used for loading and saving memory.
    conversation_id: Option<String>,
}
615
616impl<M, P> StreamingPromptRequest<M, P>
617where
618    M: CompletionModel + 'static,
619    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
620    P: PromptHook<M>,
621{
622    /// Create a new StreamingPromptRequest with the given prompt and model.
623    /// Note: This creates a request without an agent hook. Use `from_agent` to include the agent's hook.
624    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
625        StreamingPromptRequest {
626            prompt: prompt.into(),
627            chat_history: None,
628            max_turns: agent.default_max_turns.unwrap_or_default(),
629            model: agent.model.clone(),
630            agent_name: agent.name.clone(),
631            preamble: agent.preamble.clone(),
632            static_context: agent.static_context.clone(),
633            temperature: agent.temperature,
634            max_tokens: agent.max_tokens,
635            additional_params: agent.additional_params.clone(),
636            tool_server_handle: agent.tool_server_handle.clone(),
637            dynamic_context: agent.dynamic_context.clone(),
638            tool_choice: agent.tool_choice.clone(),
639            output_schema: agent.output_schema.clone(),
640            hook: None,
641            max_invalid_tool_call_retries: 0,
642            memory: agent.memory.clone(),
643            conversation_id: agent.default_conversation_id.clone(),
644        }
645    }
646
647    /// Create a new StreamingPromptRequest from an agent, cloning the agent's data and default hook.
648    pub fn from_agent<P2>(
649        agent: &Agent<M, P2>,
650        prompt: impl Into<Message>,
651    ) -> StreamingPromptRequest<M, P2>
652    where
653        P2: PromptHook<M>,
654    {
655        StreamingPromptRequest {
656            prompt: prompt.into(),
657            chat_history: None,
658            max_turns: agent.default_max_turns.unwrap_or_default(),
659            model: agent.model.clone(),
660            agent_name: agent.name.clone(),
661            preamble: agent.preamble.clone(),
662            static_context: agent.static_context.clone(),
663            temperature: agent.temperature,
664            max_tokens: agent.max_tokens,
665            additional_params: agent.additional_params.clone(),
666            tool_server_handle: agent.tool_server_handle.clone(),
667            dynamic_context: agent.dynamic_context.clone(),
668            tool_choice: agent.tool_choice.clone(),
669            output_schema: agent.output_schema.clone(),
670            hook: agent.hook.clone(),
671            max_invalid_tool_call_retries: 0,
672            memory: agent.memory.clone(),
673            conversation_id: agent.default_conversation_id.clone(),
674        }
675    }
676
677    fn agent_name(&self) -> &str {
678        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
679    }
680
681    /// Set the maximum Turns for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
682    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
683    pub fn multi_turn(mut self, turns: usize) -> Self {
684        self.max_turns = turns;
685        self
686    }
687
688    /// Add chat history to the prompt request.
689    ///
690    /// When history is provided, the final [`FinalResponse`] will include the
691    /// updated chat history (original messages + new user prompt + assistant response).
692    /// ```ignore
693    /// let mut stream = agent
694    ///     .stream_prompt("Hello")
695    ///     .with_history(vec![])
696    ///     .await;
697    /// // ... consume stream ...
698    /// // Access updated history from FinalResponse::history()
699    /// ```
700    pub fn with_history<H, T>(mut self, history: H) -> Self
701    where
702        H: IntoIterator<Item = T>,
703        T: Into<Message>,
704    {
705        self.chat_history = Some(history.into_iter().map(Into::into).collect());
706        self
707    }
708
709    /// Attach a per-request hook for tool call events.
710    /// This overrides any default hook set on the agent.
711    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
712    where
713        P2: PromptHook<M>,
714    {
715        StreamingPromptRequest {
716            prompt: self.prompt,
717            chat_history: self.chat_history,
718            max_turns: self.max_turns,
719            model: self.model,
720            agent_name: self.agent_name,
721            preamble: self.preamble,
722            static_context: self.static_context,
723            temperature: self.temperature,
724            max_tokens: self.max_tokens,
725            additional_params: self.additional_params,
726            tool_server_handle: self.tool_server_handle,
727            dynamic_context: self.dynamic_context,
728            tool_choice: self.tool_choice,
729            output_schema: self.output_schema,
730            hook: Some(hook),
731            max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
732            memory: self.memory,
733            conversation_id: self.conversation_id,
734        }
735    }
736
737    /// Set the retry budget for [`crate::agent::prompt_request::hooks::InvalidToolCallHookAction::Retry`].
738    ///
739    /// Invalid tool-call retries also consume normal multi-turn depth.
740    pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
741        self.max_invalid_tool_call_retries = retries;
742        self
743    }
744
745    /// Set the conversation id used to load and persist memory for this request.
746    ///
747    /// Overrides any default conversation id set on the agent. If memory is not
748    /// configured on the agent, this has no effect.
749    pub fn conversation(mut self, id: impl Into<String>) -> Self {
750        self.conversation_id = Some(id.into());
751        self
752    }
753
754    /// Disable conversation memory for this request.
755    ///
756    /// History will neither be loaded from nor saved to the agent's memory backend.
757    pub fn without_memory(mut self) -> Self {
758        self.memory = None;
759        self.conversation_id = None;
760        self
761    }
762
763    async fn send(self) -> StreamingResult<M::StreamingResponse> {
764        let agent_span = if tracing::Span::current().is_disabled() {
765            info_span!(
766                "invoke_agent",
767                gen_ai.operation.name = "invoke_agent",
768                gen_ai.agent.name = self.agent_name(),
769                gen_ai.system_instructions = self.preamble,
770                gen_ai.prompt = tracing::field::Empty,
771                gen_ai.completion = tracing::field::Empty,
772                gen_ai.usage.input_tokens = tracing::field::Empty,
773                gen_ai.usage.output_tokens = tracing::field::Empty,
774                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
775                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
776                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
777                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
778            )
779        } else {
780            tracing::Span::current()
781        };
782
783        let prompt = self.prompt;
784        if let Some(text) = prompt.rag_text() {
785            agent_span.record("gen_ai.prompt", text);
786        }
787
788        // Clone fields needed inside the stream
789        let model = self.model.clone();
790        let preamble = self.preamble.clone();
791        let static_context = self.static_context.clone();
792        let temperature = self.temperature;
793        let max_tokens = self.max_tokens;
794        let additional_params = self.additional_params.clone();
795        let tool_server_handle = self.tool_server_handle.clone();
796        let dynamic_context = self.dynamic_context.clone();
797        let tool_choice = self.tool_choice.clone();
798        let agent_name = self.agent_name.clone();
799        // When the caller passes explicit history, memory is fully bypassed for
800        // this request (no load AND no save). Otherwise, if a memory backend and
801        // conversation id are both configured, load prior history; if either is
802        // missing, behave as if no memory is configured.
803        let (chat_history, memory_handle) = match self.chat_history {
804            Some(history) => (Some(history), None),
805            None => match (self.memory, self.conversation_id) {
806                (Some(memory), Some(id)) => match memory.load(&id).await {
807                    Ok(loaded) => (Some(loaded), Some((memory, id))),
808                    Err(err) => {
809                        let stream = async_stream::stream! {
810                            yield Err(StreamingError::from(err));
811                        };
812                        return Box::pin(stream);
813                    }
814                },
815                _ => (None, None),
816            },
817        };
818        let has_history = chat_history.is_some();
819        let mut new_messages: Vec<Message> = vec![prompt.clone()];
820
821        let mut current_max_turns = 0;
822        let mut last_prompt_error = String::new();
823
824        let mut text_delta_response = String::new();
825        let mut saw_text_this_turn = false;
826        let mut max_turns_reached = false;
827        let output_schema = self.output_schema;
828
829        let mut aggregated_usage = crate::completion::Usage::new();
830        let mut completion_calls = Vec::new();
831        let mut completion_call_index = 0;
832        let mut invalid_tool_call_retries = 0;
833
834        // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
835        // span context leaking to other concurrent tasks. Using span.enter() inside
836        // async_stream::stream! holds the guard across yield points, which causes
837        // thread-local span context to leak when other tasks run on the same thread.
838        // See: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#in-asynchronous-code
839        // See also: https://github.com/rust-lang/rust-clippy/issues/8722
840        let stream = async_stream::stream! {
841            'outer: loop {
842                let Some((current_prompt_ref, previous_messages)) = new_messages.split_last() else {
843                    yield Err(cancelled_prompt_error(
844                        chat_history.as_deref(),
845                        new_messages.clone(),
846                        "streaming loop lost its pending prompt".to_string(),
847                    ).await);
848                    break 'outer;
849                };
850                let current_prompt = current_prompt_ref.clone();
851
852                if current_max_turns > self.max_turns + 1 {
853                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
854                    max_turns_reached = true;
855                    break;
856                }
857
858                current_max_turns += 1;
859
860                if self.max_turns > 1 {
861                    tracing::info!(
862                        "Current conversation Turns: {}/{}",
863                        current_max_turns,
864                        self.max_turns
865                    );
866                }
867
868                let history_snapshot: Vec<Message> = build_history_for_request(
869                    chat_history.as_deref(),
870                    previous_messages,
871                );
872
873                if let Some(ref hook) = self.hook
874                    && let HookAction::Terminate { reason } =
875                        hook.on_completion_call(&current_prompt, &history_snapshot).await
876                {
877                    yield Err(
878                        cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason)
879                            .await,
880                    );
881                    break 'outer;
882                }
883
884                let chat_stream_span = info_span!(
885                    target: "rig::agent_chat",
886                    parent: tracing::Span::current(),
887                    "chat_streaming",
888                    gen_ai.operation.name = "chat",
889                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
890                    gen_ai.system_instructions = preamble,
891                    gen_ai.provider.name = tracing::field::Empty,
892                    gen_ai.request.model = tracing::field::Empty,
893                    gen_ai.response.id = tracing::field::Empty,
894                    gen_ai.response.model = tracing::field::Empty,
895                    gen_ai.usage.output_tokens = tracing::field::Empty,
896                    gen_ai.usage.input_tokens = tracing::field::Empty,
897                    gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
898                    gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
899                    gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
900                    gen_ai.usage.reasoning_tokens = tracing::field::Empty,
901                    gen_ai.input.messages = tracing::field::Empty,
902                    gen_ai.output.messages = tracing::field::Empty,
903                );
904
905                let prepared_request = build_prepared_completion_request(
906                    &model,
907                    current_prompt.clone(),
908                    &history_snapshot,
909                    preamble.as_deref(),
910                    &static_context,
911                    temperature,
912                    max_tokens,
913                    additional_params.as_ref(),
914                    tool_choice.as_ref(),
915                    &tool_server_handle,
916                    &dynamic_context,
917                    output_schema.as_ref(),
918                )
919                .await?;
920                let executable_tool_names = prepared_request.executable_tool_names.clone();
921                let allowed_tool_names = prepared_request.allowed_tool_names.clone();
922
923                let mut stream = prepared_request
924                    .builder
925                    .stream()
926                    .instrument(chat_stream_span)
927                    .await?;
928
929                let call_index = completion_call_index;
930                completion_call_index += 1;
931                let mut current_call_usage = None;
932                let mut completion_call_emitted = false;
933                let mut pending_tool_calls: Vec<(ToolCall, String)> = vec![];
934                let mut tool_calls = vec![];
935                let mut tool_results = vec![];
936                let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
937                // Kept separate from accumulated_reasoning so providers requiring
938                // signatures (e.g. Anthropic) never see unsigned blocks.
939                let mut pending_reasoning_delta_text = String::new();
940                let mut pending_reasoning_delta_id: Option<String> = None;
941                let mut tool_call_delta_states: HashMap<(String, String), ToolCallDeltaState> =
942                    HashMap::new();
943                let mut saw_tool_call_this_turn = false;
944
945                while let Some(content) = stream.next().await {
946                    match content {
947                        Ok(StreamedAssistantContent::Text(text)) => {
948                            if !saw_text_this_turn {
949                                text_delta_response.clear();
950                                saw_text_this_turn = true;
951                            }
952                            text_delta_response.push_str(&text.text);
953                            if let Some(ref hook) = self.hook &&
954                                let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &text_delta_response).await {
955                                    yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
956                                    break 'outer;
957                            }
958
959                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
960                        },
961                        Ok(StreamedAssistantContent::ToolCall { mut tool_call, internal_call_id }) => {
962                            let diagnostic_history =
963                                build_tool_call_validation_history(ToolCallValidationHistory {
964                                    chat_history: chat_history.as_deref(),
965                                    new_messages: &new_messages,
966                                    assistant_message_id: &stream.message_id,
967                                    final_turn_content: None,
968                                    text_delta_response: saw_text_this_turn
969                                        .then_some(text_delta_response.as_str()),
970                                    accumulated_reasoning: &accumulated_reasoning,
971                                    pending_reasoning_delta_text: &pending_reasoning_delta_text,
972                                    pending_reasoning_delta_id: &pending_reasoning_delta_id,
973                                    pending_tool_calls: &pending_tool_calls,
974                                    current_tool_call: Some(tool_call.clone()),
975                                });
976
977                            if !allowed_tool_names.contains(&tool_call.function.name) {
978                                let args = json_utils::value_to_json_string(&tool_call.function.arguments);
979                                let emitted_tool_name = tool_call.function.name.clone();
980                                match resolve_invalid_tool_call::<M, P>(
981                                    self.hook.as_ref(),
982                                    &emitted_tool_name,
983                                    Some(tool_call.id.clone()),
984                                    Some(internal_call_id.clone()),
985                                    Some(args),
986                                    &executable_tool_names,
987                                    &allowed_tool_names,
988                                    self.tool_choice.as_ref(),
989                                    diagnostic_history.clone(),
990                                    true,
991                                ).await {
992                                    InvalidToolCallResolution::Fail(err) => {
993                                        yield Err(Box::new(err).into());
994                                        break 'outer;
995                                    }
996                                    InvalidToolCallResolution::Retry(feedback) => {
997                                        if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
998                                            yield Err(Box::new(PromptError::UnknownToolCall {
999                                                tool_name: emitted_tool_name,
1000                                                available_tools: executable_tool_names.iter().cloned().collect(),
1001                                                allowed_tools: allowed_tool_names.iter().cloned().collect(),
1002                                                chat_history: Box::new(diagnostic_history),
1003                                            }).into());
1004                                            break 'outer;
1005                                        }
1006
1007                                        invalid_tool_call_retries += 1;
1008                                        flush_pending_reasoning_delta(
1009                                            &mut accumulated_reasoning,
1010                                            &mut pending_reasoning_delta_text,
1011                                            &mut pending_reasoning_delta_id,
1012                                        );
1013                                        let Some((assistant_message, user_message)) =
1014                                            invalid_streaming_tool_retry_messages(
1015                                                &stream.message_id,
1016                                                saw_text_this_turn.then_some(text_delta_response.as_str()),
1017                                                &accumulated_reasoning,
1018                                                &pending_tool_calls,
1019                                                tool_call,
1020                                                feedback,
1021                                            )
1022                                        else {
1023                                            yield Err(cancelled_prompt_error(
1024                                                chat_history.as_deref(),
1025                                                new_messages.clone(),
1026                                                "invalid tool call retry produced no retry messages".to_string(),
1027                                            ).await);
1028                                            break 'outer;
1029                                        };
1030                                        new_messages.push(assistant_message);
1031                                        new_messages.push(user_message);
1032                                        if let Err(err) = drain_recovered_stream_usage(
1033                                            &mut stream,
1034                                            &tool_call_delta_states,
1035                                            &mut current_call_usage,
1036                                            &mut aggregated_usage,
1037                                        )
1038                                        .await
1039                                        {
1040                                            yield Err(err);
1041                                            break 'outer;
1042                                        }
1043                                        if let Some(completion_call) = record_completion_call_if_needed(
1044                                            &mut completion_calls,
1045                                            &mut completion_call_emitted,
1046                                            call_index,
1047                                            current_call_usage,
1048                                        ) {
1049                                            yield Ok(MultiTurnStreamItem::CompletionCall(
1050                                                completion_call,
1051                                            ));
1052                                        }
1053                                        text_delta_response.clear();
1054                                        saw_text_this_turn = false;
1055                                        continue 'outer;
1056                                    }
1057                                    InvalidToolCallResolution::Repair(repaired_name) => {
1058                                        tool_call.function.name = repaired_name;
1059                                    }
1060                                    InvalidToolCallResolution::Skip(reason) => {
1061                                        let skipped_tool_result = ToolResult {
1062                                            id: tool_call.id.clone(),
1063                                            call_id: tool_call.call_id.clone(),
1064                                            content: ToolResultContent::from_tool_output(
1065                                                reason.clone(),
1066                                            ),
1067                                        };
1068                                        flush_pending_reasoning_delta(
1069                                            &mut accumulated_reasoning,
1070                                            &mut pending_reasoning_delta_text,
1071                                            &mut pending_reasoning_delta_id,
1072                                        );
1073                                        let Some((assistant_message, user_message)) =
1074                                            invalid_streaming_tool_retry_messages(
1075                                                &stream.message_id,
1076                                                saw_text_this_turn
1077                                                    .then_some(text_delta_response.as_str()),
1078                                                &accumulated_reasoning,
1079                                                &pending_tool_calls,
1080                                                tool_call,
1081                                                reason,
1082                                            )
1083                                        else {
1084                                            yield Err(cancelled_prompt_error(
1085                                                chat_history.as_deref(),
1086                                                new_messages.clone(),
1087                                                "invalid tool call skip produced no recovery messages".to_string(),
1088                                            ).await);
1089                                            break 'outer;
1090                                        };
1091                                        new_messages.push(assistant_message);
1092                                        new_messages.push(user_message);
1093                                        let tool_result = ToolResult {
1094                                            id: skipped_tool_result.id,
1095                                            call_id: skipped_tool_result.call_id,
1096                                            content: skipped_tool_result.content,
1097                                        };
1098                                        if let Err(err) = drain_recovered_stream_usage(
1099                                            &mut stream,
1100                                            &tool_call_delta_states,
1101                                            &mut current_call_usage,
1102                                            &mut aggregated_usage,
1103                                        )
1104                                        .await
1105                                        {
1106                                            yield Err(err);
1107                                            break 'outer;
1108                                        }
1109                                        if let Some(completion_call) = record_completion_call_if_needed(
1110                                            &mut completion_calls,
1111                                            &mut completion_call_emitted,
1112                                            call_index,
1113                                            current_call_usage,
1114                                        ) {
1115                                            yield Ok(MultiTurnStreamItem::CompletionCall(
1116                                                completion_call,
1117                                            ));
1118                                        }
1119                                        yield Ok(MultiTurnStreamItem::StreamUserItem(
1120                                            StreamedUserContent::ToolResult {
1121                                                tool_result,
1122                                                internal_call_id,
1123                                            },
1124                                        ));
1125                                        text_delta_response.clear();
1126                                        saw_text_this_turn = false;
1127                                        continue 'outer;
1128                                    }
1129                                }
1130                            }
1131
1132                            pending_tool_calls.push((tool_call, internal_call_id));
1133                        },
1134                        Ok(StreamedAssistantContent::ToolCallDelta {
1135                            id,
1136                            internal_call_id,
1137                            content,
1138                        }) => {
1139                            let key = (id.clone(), internal_call_id.clone());
1140                            let mut deltas_to_emit = Vec::new();
1141
1142                            match content {
1143                                ToolCallDeltaContent::Name(mut name) => {
1144                                    let buffered_args = tool_call_delta_states
1145                                        .get(&key)
1146                                        .map(|state| state.buffered_arguments.join(""))
1147                                        .unwrap_or_default();
1148                                    let diagnostic_args = if buffered_args.trim().is_empty() {
1149                                        serde_json::Value::Null
1150                                    } else {
1151                                        serde_json::from_str(&buffered_args)
1152                                            .unwrap_or(serde_json::Value::Null)
1153                                    };
1154                                    let diagnostic_tool_call = ToolCall::new(
1155                                        id.clone(),
1156                                        ToolFunction::new(name.clone(), diagnostic_args),
1157                                    );
1158                                    let diagnostic_history =
1159                                        build_tool_call_validation_history(ToolCallValidationHistory {
1160                                            chat_history: chat_history.as_deref(),
1161                                            new_messages: &new_messages,
1162                                            assistant_message_id: &stream.message_id,
1163                                            final_turn_content: None,
1164                                            text_delta_response: saw_text_this_turn
1165                                                .then_some(text_delta_response.as_str()),
1166                                            accumulated_reasoning: &accumulated_reasoning,
1167                                            pending_reasoning_delta_text: &pending_reasoning_delta_text,
1168                                            pending_reasoning_delta_id: &pending_reasoning_delta_id,
1169                                            pending_tool_calls: &pending_tool_calls,
1170                                            current_tool_call: Some(diagnostic_tool_call.clone()),
1171                                        });
1172
1173                                    if !allowed_tool_names.contains(&name) {
1174                                        let emitted_tool_name = name.clone();
1175                                        match resolve_invalid_tool_call::<M, P>(
1176                                            self.hook.as_ref(),
1177                                            &emitted_tool_name,
1178                                            Some(id.clone()),
1179                                            Some(internal_call_id.clone()),
1180                                            Some(buffered_args.clone()),
1181                                            &executable_tool_names,
1182                                            &allowed_tool_names,
1183                                            self.tool_choice.as_ref(),
1184                                            diagnostic_history.clone(),
1185                                            true,
1186                                        ).await {
1187                                            InvalidToolCallResolution::Fail(err) => {
1188                                                yield Err(Box::new(err).into());
1189                                                break 'outer;
1190                                            }
1191                                            InvalidToolCallResolution::Skip(reason) => {
1192                                                tool_call_delta_states.remove(&key);
1193                                                flush_pending_reasoning_delta(
1194                                                    &mut accumulated_reasoning,
1195                                                    &mut pending_reasoning_delta_text,
1196                                                    &mut pending_reasoning_delta_id,
1197                                                );
1198                                                let Some((assistant_message, user_message)) =
1199                                                    invalid_streaming_name_delta_retry_messages(
1200                                                        &stream.message_id,
1201                                                        saw_text_this_turn
1202                                                            .then_some(text_delta_response.as_str()),
1203                                                        &accumulated_reasoning,
1204                                                        &pending_tool_calls,
1205                                                        diagnostic_tool_call.clone(),
1206                                                        reason.clone(),
1207                                                    )
1208                                                else {
1209                                                    yield Err(cancelled_prompt_error(
1210                                                        chat_history.as_deref(),
1211                                                        new_messages.clone(),
1212                                                        "invalid tool call skip produced no recovery messages".to_string(),
1213                                                    ).await);
1214                                                    break 'outer;
1215                                                };
1216                                                new_messages.push(assistant_message);
1217                                                new_messages.push(user_message);
1218                                                let tool_result = ToolResult {
1219                                                    id,
1220                                                    call_id: None,
1221                                                    content: ToolResultContent::from_tool_output(
1222                                                        reason,
1223                                                    ),
1224                                                };
1225                                                if let Err(err) = drain_recovered_stream_usage(
1226                                                    &mut stream,
1227                                                    &tool_call_delta_states,
1228                                                    &mut current_call_usage,
1229                                                    &mut aggregated_usage,
1230                                                )
1231                                                .await
1232                                                {
1233                                                    yield Err(err);
1234                                                    break 'outer;
1235                                                }
1236                                                if let Some(completion_call) = record_completion_call_if_needed(
1237                                                    &mut completion_calls,
1238                                                    &mut completion_call_emitted,
1239                                                    call_index,
1240                                                    current_call_usage,
1241                                                ) {
1242                                                    yield Ok(MultiTurnStreamItem::CompletionCall(
1243                                                        completion_call,
1244                                                    ));
1245                                                }
1246                                                yield Ok(MultiTurnStreamItem::StreamUserItem(
1247                                                    StreamedUserContent::ToolResult {
1248                                                        tool_result,
1249                                                        internal_call_id,
1250                                                    },
1251                                                ));
1252                                                text_delta_response.clear();
1253                                                saw_text_this_turn = false;
1254                                                continue 'outer;
1255                                            }
1256                                            InvalidToolCallResolution::Retry(feedback) => {
1257                                                tool_call_delta_states.remove(&key);
1258                                                if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
1259                                                    yield Err(Box::new(PromptError::UnknownToolCall {
1260                                                        tool_name: emitted_tool_name,
1261                                                        available_tools: executable_tool_names.iter().cloned().collect(),
1262                                                        allowed_tools: allowed_tool_names.iter().cloned().collect(),
1263                                                        chat_history: Box::new(diagnostic_history),
1264                                                    }).into());
1265                                                    break 'outer;
1266                                                }
1267
1268                                                invalid_tool_call_retries += 1;
1269                                                flush_pending_reasoning_delta(
1270                                                    &mut accumulated_reasoning,
1271                                                    &mut pending_reasoning_delta_text,
1272                                                    &mut pending_reasoning_delta_id,
1273                                                );
1274                                                let Some((assistant_message, user_message)) =
1275                                                    invalid_streaming_name_delta_retry_messages(
1276                                                        &stream.message_id,
1277                                                        saw_text_this_turn
1278                                                            .then_some(text_delta_response.as_str()),
1279                                                        &accumulated_reasoning,
1280                                                        &pending_tool_calls,
1281                                                        diagnostic_tool_call.clone(),
1282                                                        feedback,
1283                                                    )
1284                                                else {
1285                                                    yield Err(cancelled_prompt_error(
1286                                                        chat_history.as_deref(),
1287                                                        new_messages.clone(),
1288                                                        "invalid tool call retry produced no retry messages".to_string(),
1289                                                    ).await);
1290                                                    break 'outer;
1291                                                };
1292                                                new_messages.push(assistant_message);
1293                                                new_messages.push(user_message);
1294                                                if let Err(err) = drain_recovered_stream_usage(
1295                                                    &mut stream,
1296                                                    &tool_call_delta_states,
1297                                                    &mut current_call_usage,
1298                                                    &mut aggregated_usage,
1299                                                )
1300                                                .await
1301                                                {
1302                                                    yield Err(err);
1303                                                    break 'outer;
1304                                                }
1305                                                if let Some(completion_call) = record_completion_call_if_needed(
1306                                                    &mut completion_calls,
1307                                                    &mut completion_call_emitted,
1308                                                    call_index,
1309                                                    current_call_usage,
1310                                                ) {
1311                                                    yield Ok(MultiTurnStreamItem::CompletionCall(
1312                                                        completion_call,
1313                                                    ));
1314                                                }
1315                                                text_delta_response.clear();
1316                                                saw_text_this_turn = false;
1317                                                continue 'outer;
1318                                            }
1319                                            InvalidToolCallResolution::Repair(repaired_name) => {
1320                                                name = repaired_name;
1321                                            }
1322                                        }
1323                                    }
1324
1325                                    let state =
1326                                        tool_call_delta_states.entry(key.clone()).or_default();
1327                                    state.name_validated = true;
1328                                    let buffered_arguments =
1329                                        std::mem::take(&mut state.buffered_arguments);
1330
1331                                    deltas_to_emit.push(ToolCallDeltaContent::Name(name));
1332                                    deltas_to_emit.extend(
1333                                        buffered_arguments
1334                                            .into_iter()
1335                                            .map(ToolCallDeltaContent::Delta),
1336                                    );
1337                                }
1338                                ToolCallDeltaContent::Delta(arguments) => {
1339                                    let state =
1340                                        tool_call_delta_states.entry(key.clone()).or_default();
1341                                    if state.name_validated {
1342                                        deltas_to_emit.push(ToolCallDeltaContent::Delta(arguments));
1343                                    } else {
1344                                        state.buffered_arguments.push(arguments);
1345                                    }
1346                                }
1347                            }
1348
1349                            for content in deltas_to_emit {
1350                                if let Some(ref hook) = self.hook {
1351                                    let (name, delta) = match &content {
1352                                        ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
1353                                        ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
1354                                    };
1355
1356                                    if let HookAction::Terminate { reason } = hook
1357                                        .on_tool_call_delta(
1358                                            &id,
1359                                            &internal_call_id,
1360                                            name,
1361                                            delta,
1362                                        )
1363                                        .await
1364                                    {
1365                                        yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1366                                        break 'outer;
1367                                    }
1368                                }
1369
1370                                yield Ok(MultiTurnStreamItem::StreamAssistantItem(
1371                                    StreamedAssistantContent::ToolCallDelta {
1372                                        id: id.clone(),
1373                                        internal_call_id: internal_call_id.clone(),
1374                                        content,
1375                                    },
1376                                ));
1377                            }
1378                        }
1379                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
1380                            // Accumulate reasoning for inclusion in chat history with tool calls.
1381                            // OpenAI Responses API requires reasoning items to be sent back
1382                            // alongside function_call items in multi-turn conversations.
1383                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
1384                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
1385                        },
1386                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
1387                            // Deltas lack signatures/encrypted content that full
1388                            // blocks carry; mixing them into accumulated_reasoning
1389                            // causes Anthropic to reject with "signature required".
1390                            pending_reasoning_delta_text.push_str(&reasoning);
1391                            if pending_reasoning_delta_id.is_none() {
1392                                pending_reasoning_delta_id = id.clone();
1393                            }
1394                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
1395                        },
1396                        Ok(StreamedAssistantContent::Final(final_resp)) => {
1397                            if let Some(err) =
1398                                pending_tool_call_delta_error(&tool_call_delta_states)
1399                            {
1400                                yield Err(err.into());
1401                                break 'outer;
1402                            }
1403
1404                            if let Some(usage) = final_resp.token_usage() {
1405                                current_call_usage = reported_usage(usage);
1406                            }
1407                            if let Some(usage) = current_call_usage {
1408                                aggregated_usage += usage;
1409                            }
1410                            let completion_call = CompletionCall::new(call_index, current_call_usage);
1411                            completion_calls.push(completion_call);
1412                            completion_call_emitted = true;
1413                            yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
1414
1415                            if saw_text_this_turn {
1416                                if let Some(ref hook) = self.hook &&
1417                                     let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&current_prompt, &final_resp).await {
1418                                        yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1419                                        break 'outer;
1420                                    }
1421
1422                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
1423                                saw_text_this_turn = false;
1424                            }
1425                        }
1426                        Err(e) => {
1427                            yield Err(e.into());
1428                            break 'outer;
1429                        }
1430                    }
1431                }
1432
1433                if let Some(err) = pending_tool_call_delta_error(&tool_call_delta_states) {
1434                    yield Err(err.into());
1435                    break 'outer;
1436                }
1437
1438                if !completion_call_emitted {
1439                    let completion_call = CompletionCall::new(call_index, current_call_usage);
1440                    completion_calls.push(completion_call);
1441                    yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
1442                }
1443
1444                // Providers like Gemini emit thinking as incremental deltas
1445                // without signatures; assemble into a single block so
1446                // reasoning survives into the next turn's chat history.
1447                flush_pending_reasoning_delta(
1448                    &mut accumulated_reasoning,
1449                    &mut pending_reasoning_delta_text,
1450                    &mut pending_reasoning_delta_id,
1451                );
1452
1453                let final_turn_content = stream.choice.clone();
1454                let turn_text_response = assistant_text_from_choice(&final_turn_content);
1455                tracing::Span::current().record("gen_ai.completion", &turn_text_response);
1456
1457                if !pending_tool_calls.is_empty() {
1458                    let diagnostic_history =
1459                        build_tool_call_validation_history(ToolCallValidationHistory {
1460                            chat_history: chat_history.as_deref(),
1461                            new_messages: &new_messages,
1462                            assistant_message_id: &stream.message_id,
1463                            final_turn_content: Some(&final_turn_content),
1464                            text_delta_response: None,
1465                            accumulated_reasoning: &accumulated_reasoning,
1466                            pending_reasoning_delta_text: "",
1467                            pending_reasoning_delta_id: &None,
1468                            pending_tool_calls: &pending_tool_calls,
1469                            current_tool_call: None,
1470                        });
1471
1472                    for (tool_call, _) in &pending_tool_calls {
1473                        if let Err(err) = validate_tool_call_name(
1474                            &tool_call.function.name,
1475                            &executable_tool_names,
1476                            &allowed_tool_names,
1477                            diagnostic_history.clone(),
1478                        ) {
1479                            yield Err(Box::new(err).into());
1480                            break 'outer;
1481                        }
1482                    }
1483
1484                    for (tool_call, internal_call_id) in pending_tool_calls {
1485                        let tool_span = info_span!(
1486                            parent: tracing::Span::current(),
1487                            "execute_tool",
1488                            gen_ai.operation.name = "execute_tool",
1489                            gen_ai.tool.type = "function",
1490                            gen_ai.tool.name = tracing::field::Empty,
1491                            gen_ai.tool.call.id = tracing::field::Empty,
1492                            gen_ai.tool.call.arguments = tracing::field::Empty,
1493                            gen_ai.tool.call.result = tracing::field::Empty
1494                        );
1495
1496                        yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));
1497
1498                        let tc_result = async {
1499                            let tool_span = tracing::Span::current();
1500                            let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
1501                            if let Some(ref hook) = self.hook {
1502                                let action = hook
1503                                    .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
1504                                    .await;
1505
1506                                if let ToolCallHookAction::Terminate { reason } = action {
1507                                    return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1508                                }
1509
1510                                if let ToolCallHookAction::Skip { reason } = action {
1511                                    // Tool execution rejected, return rejection message as tool result
1512                                    tracing::info!(
1513                                        tool_name = tool_call.function.name.as_str(),
1514                                        reason = reason,
1515                                        "Tool call rejected"
1516                                    );
1517                                    let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
1518                                    tool_calls.push(tool_call_msg);
1519                                    tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
1520                                    saw_tool_call_this_turn = true;
1521                                    return Ok(reason);
1522                                }
1523                            }
1524
1525                            tool_span.record("gen_ai.tool.name", &tool_call.function.name);
1526                            tool_span.record("gen_ai.tool.call.arguments", &tool_args);
1527
1528                            let tool_result = match
1529                            tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
1530                                Ok(thing) => thing,
1531                                Err(e) => {
1532                                    tracing::warn!("Error while calling tool: {e}");
1533                                    e.to_string()
1534                                }
1535                            };
1536
1537                            tool_span.record("gen_ai.tool.call.result", &tool_result);
1538
1539                            if let Some(ref hook) = self.hook &&
1540                                let HookAction::Terminate { reason } =
1541                                hook.on_tool_result(
1542                                    &tool_call.function.name,
1543                                    tool_call.call_id.clone(),
1544                                    &internal_call_id,
1545                                    &tool_args,
1546                                    &tool_result.to_string()
1547                                )
1548                                .await {
1549                                    return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
1550                                }
1551
1552                            let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
1553
1554                            tool_calls.push(tool_call_msg);
1555                            tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));
1556
1557                            saw_tool_call_this_turn = true;
1558                            Ok(tool_result)
1559                        }.instrument(tool_span).await;
1560
1561                        match tc_result {
1562                            Ok(text) => {
1563                                let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
1564                                yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
1565                            }
1566                            Err(e) => {
1567                                yield Err(e);
1568                                break 'outer;
1569                            }
1570                        }
1571                    }
1572                }
1573
1574                // Add text, reasoning, and tool calls to chat history.
1575                // OpenAI Responses API requires reasoning items to precede function_call items.
1576                let mut assistant_turn_added_to_history = false;
1577                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
1578                    // Text before tool calls so the model sees its own prior output.
1579                    let mut content_items = assistant_text_items_from_choice(&final_turn_content);
1580
1581                    // Reasoning must come before tool calls (OpenAI requirement)
1582                    for reasoning in accumulated_reasoning.drain(..) {
1583                        content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
1584                    }
1585
1586                    content_items.extend(tool_calls.clone());
1587
1588                    if let Some(content) = OneOrMany::from_iter_optional(content_items) {
1589                        new_messages.push(Message::Assistant {
1590                            id: stream.message_id.clone(),
1591                            content,
1592                        });
1593                        assistant_turn_added_to_history = true;
1594                    }
1595                }
1596
1597                for (id, call_id, tool_result) in tool_results {
1598                    new_messages.push(tool_result_to_user_message(id, call_id, tool_result));
1599                }
1600
1601                if !saw_tool_call_this_turn {
1602                    // Add user message and assistant response to history before finishing
1603                    let should_add_final_assistant = !assistant_turn_added_to_history
1604                        && !is_empty_assistant_choice(&final_turn_content);
1605                    if should_add_final_assistant {
1606                        new_messages.push(Message::Assistant {
1607                            id: stream.message_id.clone(),
1608                            content: final_turn_content.clone(),
1609                        });
1610                    } else if !assistant_turn_added_to_history {
1611                        tracing::warn!(
1612                            agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
1613                            message_id = ?stream.message_id,
1614                            "Streaming turn completed without assistant text; final response will be empty"
1615                        );
1616                    }
1617
1618                    let current_span = tracing::Span::current();
1619                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
1620                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
1621                    current_span.record("gen_ai.usage.cache_read.input_tokens", aggregated_usage.cached_input_tokens);
1622                    current_span.record("gen_ai.usage.cache_creation.input_tokens", aggregated_usage.cache_creation_input_tokens);
1623                    current_span.record("gen_ai.usage.tool_use_prompt_tokens", aggregated_usage.tool_use_prompt_tokens);
1624                    current_span.record("gen_ai.usage.reasoning_tokens", aggregated_usage.reasoning_tokens);
1625                    tracing::info!("Agent multi-turn stream finished");
1626                    if let Some((memory, id)) = memory_handle.as_ref()
1627                        && let Err(err) = memory.append(id, new_messages.clone()).await
1628                    {
1629                        tracing::warn!(
1630                            error = %err,
1631                            conversation_id = %id,
1632                            "conversation memory append failed; yielding final response anyway"
1633                        );
1634                    }
1635                    let final_messages: Option<Vec<Message>> = if has_history {
1636                        Some(new_messages.clone())
1637                    } else {
1638                        None
1639                    };
1640                    yield Ok(MultiTurnStreamItem::final_response_with_completion_calls(
1641                        final_turn_content,
1642                        aggregated_usage,
1643                        completion_calls,
1644                        final_messages,
1645                    ));
1646                    break;
1647                }
1648            }
1649
1650            if max_turns_reached {
1651                yield Err(Box::new(PromptError::MaxTurnsError {
1652                    max_turns: self.max_turns,
1653                    chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
1654                    prompt: Box::new(last_prompt_error.clone().into()),
1655                }).into());
1656            }
1657        };
1658
1659        Box::pin(stream.instrument(agent_span))
1660    }
1661}
1662
1663impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
1664where
1665    M: CompletionModel + 'static,
1666    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
1667    P: PromptHook<M> + 'static,
1668{
1669    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
1670    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1671
1672    fn into_future(self) -> Self::IntoFuture {
1673        // Wrap send() in a future, because send() returns a stream immediately
1674        Box::pin(async move { self.send().await })
1675    }
1676}
1677
1678/// Helper function to stream assistant-visible completion output to stdout.
1679///
1680/// This helper prints streamed assistant text and reasoning only. Streaming
1681/// metadata events, such as `MultiTurnStreamItem::CompletionCall`, are not
1682/// printed; metadata is returned on the `FinalResponse` via accessors such as
1683/// `FinalResponse::completion_calls`.
1684pub async fn stream_to_stdout<R>(
1685    stream: &mut StreamingResult<R>,
1686) -> Result<FinalResponse, std::io::Error> {
1687    let mut final_res = FinalResponse::empty();
1688    print!("Response: ");
1689    while let Some(content) = stream.next().await {
1690        match content {
1691            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1692                Text { text, .. },
1693            ))) => {
1694                print!("{text}");
1695                std::io::Write::flush(&mut std::io::stdout())?;
1696            }
1697            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
1698                reasoning,
1699            ))) => {
1700                let reasoning = reasoning.display_text();
1701                print!("{reasoning}");
1702                std::io::Write::flush(&mut std::io::stdout())?;
1703            }
1704            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1705                final_res = res;
1706            }
1707            Err(err) => {
1708                eprintln!("Error: {err}");
1709            }
1710            _ => {}
1711        }
1712    }
1713
1714    Ok(final_res)
1715}
1716
1717#[cfg(test)]
1718mod tests {
1719    use super::*;
1720    use crate::agent::AgentBuilder;
1721    use crate::agent::prompt_request::hooks::{
1722        InvalidToolCallContext, InvalidToolCallHookAction, PromptHook, ToolCallHookAction,
1723    };
1724    use crate::client::ProviderClient;
1725    use crate::client::completion::CompletionClient;
1726    use crate::completion::{CompletionRequest, PromptError, ToolDefinition, Usage};
1727    use crate::message::{
1728        AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent,
1729        ToolChoice, ToolResultContent, UserContent,
1730    };
1731    use crate::providers::anthropic;
1732    use crate::streaming::{StreamingPrompt, ToolCallDeltaContent};
1733    use crate::test_utils::{
1734        AppendFailingMemory, FailingMemory, MockAddTool, MockCompletionModel, MockResponse,
1735        MockStreamEvent, MockSubtractTool, MockToolError,
1736    };
1737    use crate::tool::Tool;
1738    use futures::StreamExt;
1739    use serde::Deserialize;
1740    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
1741    use std::sync::{Arc, Mutex};
1742    use std::time::Duration;
1743
1744    #[test]
1745    fn merge_reasoning_blocks_preserves_order_and_signatures() {
1746        let mut accumulated = Vec::new();
1747        let first = crate::message::Reasoning {
1748            id: Some("rs_1".to_string()),
1749            content: vec![ReasoningContent::Text {
1750                text: "step-1".to_string(),
1751                signature: Some("sig-1".to_string()),
1752            }],
1753        };
1754        let second = crate::message::Reasoning {
1755            id: Some("rs_1".to_string()),
1756            content: vec![
1757                ReasoningContent::Text {
1758                    text: "step-2".to_string(),
1759                    signature: Some("sig-2".to_string()),
1760                },
1761                ReasoningContent::Summary("summary".to_string()),
1762            ],
1763        };
1764
1765        merge_reasoning_blocks(&mut accumulated, &first);
1766        merge_reasoning_blocks(&mut accumulated, &second);
1767
1768        assert_eq!(accumulated.len(), 1);
1769        let merged = accumulated.first().expect("expected accumulated reasoning");
1770        assert_eq!(merged.id.as_deref(), Some("rs_1"));
1771        assert_eq!(merged.content.len(), 3);
1772        assert!(matches!(
1773            merged.content.first(),
1774            Some(ReasoningContent::Text { text, signature: Some(sig) })
1775                if text == "step-1" && sig == "sig-1"
1776        ));
1777        assert!(matches!(
1778            merged.content.get(1),
1779            Some(ReasoningContent::Text { text, signature: Some(sig) })
1780                if text == "step-2" && sig == "sig-2"
1781        ));
1782    }
1783
1784    #[test]
1785    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
1786        let mut accumulated = vec![crate::message::Reasoning {
1787            id: Some("rs_a".to_string()),
1788            content: vec![ReasoningContent::Text {
1789                text: "step-1".to_string(),
1790                signature: None,
1791            }],
1792        }];
1793        let incoming = crate::message::Reasoning {
1794            id: Some("rs_b".to_string()),
1795            content: vec![ReasoningContent::Text {
1796                text: "step-2".to_string(),
1797                signature: None,
1798            }],
1799        };
1800
1801        merge_reasoning_blocks(&mut accumulated, &incoming);
1802        assert_eq!(accumulated.len(), 2);
1803        assert_eq!(
1804            accumulated.first().and_then(|r| r.id.as_deref()),
1805            Some("rs_a")
1806        );
1807        assert_eq!(
1808            accumulated.get(1).and_then(|r| r.id.as_deref()),
1809            Some("rs_b")
1810        );
1811    }
1812
1813    #[test]
1814    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
1815        let mut accumulated = vec![crate::message::Reasoning {
1816            id: None,
1817            content: vec![ReasoningContent::Text {
1818                text: "first".to_string(),
1819                signature: None,
1820            }],
1821        }];
1822        let incoming = crate::message::Reasoning {
1823            id: None,
1824            content: vec![ReasoningContent::Text {
1825                text: "second".to_string(),
1826                signature: None,
1827            }],
1828        };
1829
1830        merge_reasoning_blocks(&mut accumulated, &incoming);
1831        assert_eq!(accumulated.len(), 2);
1832        assert!(matches!(
1833            accumulated.first(),
1834            Some(crate::message::Reasoning {
1835                id: None,
1836                content
1837            }) if matches!(
1838                content.first(),
1839                Some(ReasoningContent::Text { text, .. }) if text == "first"
1840            )
1841        ));
1842        assert!(matches!(
1843            accumulated.get(1),
1844            Some(crate::message::Reasoning {
1845                id: None,
1846                content
1847            }) if matches!(
1848                content.first(),
1849                Some(ReasoningContent::Text { text, .. }) if text == "second"
1850            )
1851        ));
1852    }
1853
1854    #[test]
1855    fn tool_result_to_user_message_preserves_multimodal_tool_output() {
1856        let message = tool_result_to_user_message(
1857            "tool_call_1".to_string(),
1858            Some("call_1".to_string()),
1859            serde_json::json!({
1860                "response": {
1861                    "instruction": "Use the image part to answer."
1862                },
1863                "parts": [
1864                    {
1865                        "type": "image",
1866                        "data": "base64data==",
1867                        "mimeType": "image/png"
1868                    }
1869                ]
1870            })
1871            .to_string(),
1872        );
1873
1874        let tool_result = match message {
1875            Message::User { content } => match content.first() {
1876                UserContent::ToolResult(tool_result) => tool_result,
1877                other => panic!("expected tool result content, got {other:?}"),
1878            },
1879            other => panic!("expected user message, got {other:?}"),
1880        };
1881
1882        assert_eq!(tool_result.id, "tool_call_1");
1883        assert_eq!(tool_result.call_id.as_deref(), Some("call_1"));
1884        assert_eq!(tool_result.content.len(), 2);
1885
1886        let mut items = tool_result.content.iter();
1887        match items.next() {
1888            Some(ToolResultContent::Text(text)) => {
1889                assert!(text.text.contains("Use the image part to answer."));
1890            }
1891            other => panic!("expected structured text payload first, got {other:?}"),
1892        }
1893
1894        match items.next() {
1895            Some(ToolResultContent::Image(image)) => {
1896                assert_eq!(image.media_type, Some(ImageMediaType::PNG));
1897                assert!(matches!(
1898                    image.data,
1899                    DocumentSourceKind::Base64(ref data) if data == "base64data=="
1900                ));
1901            }
1902            other => panic!("expected image payload second, got {other:?}"),
1903        }
1904    }
1905
1906    fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
1907        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1908        if history.len() != 3 {
1909            return Err(format!(
1910                "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
1911            ));
1912        }
1913
1914        if !matches!(
1915            history.first(),
1916            Some(Message::User { content })
1917                if matches!(
1918                    content.first(),
1919                    UserContent::Text(text) if text.text == "do tool work"
1920                )
1921        ) {
1922            return Err(format!(
1923                "follow-up request should begin with the original user prompt: {history:?}"
1924            ));
1925        }
1926
1927        if !matches!(
1928            history.get(1),
1929            Some(Message::Assistant { content, .. })
1930                if matches!(
1931                    content.first(),
1932                    AssistantContent::ToolCall(tool_call)
1933                        if tool_call.id == "tool_call_1"
1934                            && tool_call.call_id.as_deref() == Some("call_1")
1935                )
1936        ) {
1937            return Err(format!(
1938                "follow-up request is missing the assistant tool call in position 2: {history:?}"
1939            ));
1940        }
1941
1942        if !matches!(
1943            history.get(2),
1944            Some(Message::User { content })
1945                if matches!(
1946                    content.first(),
1947                    UserContent::ToolResult(tool_result)
1948                        if tool_result.id == "tool_call_1"
1949                            && tool_result.call_id.as_deref() == Some("call_1")
1950                )
1951        ) {
1952            return Err(format!(
1953                "follow-up request should end with the user tool result: {history:?}"
1954            ));
1955        }
1956
1957        Ok(())
1958    }
1959
1960    fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1961        history.iter().any(|message| {
1962            matches!(
1963                message,
1964                Message::Assistant { content, .. }
1965                    if content.iter().any(|item| matches!(
1966                        item,
1967                        AssistantContent::ToolCall(tool_call)
1968                            if tool_call.function.name == tool_name
1969                    ))
1970            )
1971        })
1972    }
1973
1974    fn history_contains_text(history: &[Message], expected: &str) -> bool {
1975        history.iter().any(|message| {
1976            matches!(
1977                message,
1978                Message::Assistant { content, .. }
1979                    if content.iter().any(|item| matches!(
1980                        item,
1981                        AssistantContent::Text(text) if text.text == expected
1982                    ))
1983            )
1984        })
1985    }
1986
1987    fn assistant_reasoning_precedes_tool_call(
1988        history: &[Message],
1989        expected_reasoning: &str,
1990        tool_name: &str,
1991    ) -> bool {
1992        history.iter().any(|message| {
1993            let Message::Assistant { content, .. } = message else {
1994                return false;
1995            };
1996
1997            let reasoning_index = content.iter().position(|item| {
1998                matches!(
1999                    item,
2000                    AssistantContent::Reasoning(reasoning)
2001                        if reasoning.content.iter().any(|content| matches!(
2002                            content,
2003                            ReasoningContent::Text { text, .. }
2004                                if text == expected_reasoning
2005                        ))
2006                )
2007            });
2008            let tool_index = content.iter().position(|item| {
2009                matches!(
2010                    item,
2011                    AssistantContent::ToolCall(tool_call)
2012                        if tool_call.function.name == tool_name
2013                )
2014            });
2015
2016            matches!((reasoning_index, tool_index), (Some(reasoning), Some(tool)) if reasoning < tool)
2017        })
2018    }
2019
2020    #[derive(Clone)]
2021    struct PanicOnUnknownToolHook;
2022
2023    impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
2024        async fn on_tool_call_delta(
2025            &self,
2026            _tool_call_id: &str,
2027            _internal_call_id: &str,
2028            _tool_name: Option<&str>,
2029            _tool_call_delta: &str,
2030        ) -> HookAction {
2031            panic!("unknown tool call delta should fail before delta hooks run")
2032        }
2033
2034        async fn on_tool_call(
2035            &self,
2036            _tool_name: &str,
2037            _tool_call_id: Option<String>,
2038            _internal_call_id: &str,
2039            _args: &str,
2040        ) -> ToolCallHookAction {
2041            panic!("unknown tool call should fail before tool hooks run")
2042        }
2043
2044        async fn on_stream_completion_response_finish(
2045            &self,
2046            _prompt: &Message,
2047            _response: &MockResponse,
2048        ) -> HookAction {
2049            panic!("unknown tool call should fail before stream finish hooks run")
2050        }
2051    }
2052
2053    #[derive(Clone)]
2054    struct CountingAddTool {
2055        calls: Arc<AtomicU32>,
2056    }
2057
2058    #[derive(Clone)]
2059    struct CountingSubtractTool {
2060        calls: Arc<AtomicU32>,
2061    }
2062
2063    #[derive(Deserialize)]
2064    struct CountingOperationArgs {
2065        x: i32,
2066        y: i32,
2067    }
2068
2069    fn arithmetic_tool_definition(name: &str, description: &str) -> ToolDefinition {
2070        ToolDefinition {
2071            name: name.to_string(),
2072            description: description.to_string(),
2073            parameters: serde_json::json!({
2074                "type": "object",
2075                "properties": {
2076                    "x": {
2077                        "type": "number",
2078                        "description": "The first operand"
2079                    },
2080                    "y": {
2081                        "type": "number",
2082                        "description": "The second operand"
2083                    }
2084                },
2085                "required": ["x", "y"],
2086            }),
2087        }
2088    }
2089
2090    impl Tool for CountingAddTool {
2091        const NAME: &'static str = "add";
2092        type Error = MockToolError;
2093        type Args = CountingOperationArgs;
2094        type Output = i32;
2095
2096        async fn definition(&self, _prompt: String) -> ToolDefinition {
2097            arithmetic_tool_definition(Self::NAME, "Add x and y together")
2098        }
2099
2100        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
2101            self.calls.fetch_add(1, Ordering::SeqCst);
2102            Ok(args.x + args.y)
2103        }
2104    }
2105
2106    impl Tool for CountingSubtractTool {
2107        const NAME: &'static str = "subtract";
2108        type Error = MockToolError;
2109        type Args = CountingOperationArgs;
2110        type Output = i32;
2111
2112        async fn definition(&self, _prompt: String) -> ToolDefinition {
2113            arithmetic_tool_definition(Self::NAME, "Subtract y from x")
2114        }
2115
2116        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
2117            self.calls.fetch_add(1, Ordering::SeqCst);
2118            Ok(args.x - args.y)
2119        }
2120    }
2121
2122    fn streaming_tool_then_text_model() -> MockCompletionModel {
2123        MockCompletionModel::from_stream_turns([
2124            vec![
2125                MockStreamEvent::tool_call(
2126                    "tool_call_1",
2127                    "add",
2128                    serde_json::json!({"x": 1, "y": 2}),
2129                )
2130                .with_call_id("call_1"),
2131                MockStreamEvent::final_response_with_total_tokens(4),
2132            ],
2133            vec![
2134                MockStreamEvent::text("done"),
2135                MockStreamEvent::final_response_with_total_tokens(6),
2136            ],
2137        ])
2138    }
2139
2140    fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
2141        Usage {
2142            input_tokens,
2143            output_tokens,
2144            total_tokens: input_tokens + output_tokens,
2145            cached_input_tokens: 0,
2146            cache_creation_input_tokens: 0,
2147            tool_use_prompt_tokens: 0,
2148            reasoning_tokens: 0,
2149        }
2150    }
2151
2152    #[test]
2153    fn completion_calls_stream_item_serializes_and_deserializes_expected_shape() {
2154        let item: MultiTurnStreamItem<MockResponse> =
2155            MultiTurnStreamItem::CompletionCall(CompletionCall::new(2, Some(usage(3, 4))));
2156
2157        let value = serde_json::to_value(&item).expect("serialize completion call event");
2158
2159        assert_eq!(
2160            value,
2161            serde_json::json!({
2162                "type": "completionCall",
2163                "call_index": 2,
2164                "usage": {
2165                    "input_tokens": 3,
2166                    "output_tokens": 4,
2167                    "total_tokens": 7,
2168                    "cached_input_tokens": 0,
2169                    "cache_creation_input_tokens": 0,
2170                    "tool_use_prompt_tokens": 0,
2171                    "reasoning_tokens": 0,
2172                }
2173            })
2174        );
2175
2176        let item: MultiTurnStreamItem<MockResponse> =
2177            serde_json::from_value(value).expect("deserialize completion call event");
2178        match item {
2179            MultiTurnStreamItem::CompletionCall(call_usage) => {
2180                assert_eq!(call_usage, CompletionCall::new(2, Some(usage(3, 4))));
2181            }
2182            other => panic!("expected completion call event, got {other:?}"),
2183        }
2184
2185        let item: MultiTurnStreamItem<MockResponse> =
2186            MultiTurnStreamItem::CompletionCall(CompletionCall::new(3, None));
2187        let value = serde_json::to_value(&item).expect("serialize missing usage event");
2188
2189        assert_eq!(
2190            value,
2191            serde_json::json!({
2192                "type": "completionCall",
2193                "call_index": 3,
2194                "usage": null
2195            })
2196        );
2197    }
2198
2199    #[test]
2200    fn final_response_serializes_completion_calls_with_missing_usage() {
2201        let item: MultiTurnStreamItem<MockResponse> =
2202            MultiTurnStreamItem::final_response_with_completion_calls(
2203                OneOrMany::one(AssistantContent::text("done")),
2204                usage(3, 4),
2205                vec![
2206                    CompletionCall::new(0, None),
2207                    CompletionCall::new(1, Some(usage(3, 4))),
2208                ],
2209                None,
2210            );
2211
2212        let value = serde_json::to_value(&item).expect("serialize final response");
2213
2214        assert_eq!(
2215            value.get("completionCalls"),
2216            Some(&serde_json::json!([
2217                {
2218                    "call_index": 0,
2219                    "usage": null,
2220                },
2221                {
2222                    "call_index": 1,
2223                    "usage": {
2224                        "input_tokens": 3,
2225                        "output_tokens": 4,
2226                        "total_tokens": 7,
2227                        "cached_input_tokens": 0,
2228                        "cache_creation_input_tokens": 0,
2229                        "tool_use_prompt_tokens": 0,
2230                        "reasoning_tokens": 0,
2231                    }
2232                }
2233            ]))
2234        );
2235    }
2236
2237    fn streaming_text_then_final_model() -> MockCompletionModel {
2238        MockCompletionModel::from_stream_turns([[
2239            MockStreamEvent::text("hello"),
2240            MockStreamEvent::text(" world"),
2241            MockStreamEvent::final_response_with_total_tokens(3),
2242        ]])
2243    }
2244
2245    fn citation_metadata() -> serde_json::Value {
2246        serde_json::json!({
2247            "citations": [{
2248                "type": "web_search_result_location",
2249                "cited_text": "Claude Shannon was born in 1916.",
2250                "url": "https://example.com/shannon",
2251                "title": "Claude Shannon",
2252                "encrypted_index": "encrypted-reference"
2253            }]
2254        })
2255    }
2256
2257    fn streaming_cited_text_then_final_model() -> MockCompletionModel {
2258        MockCompletionModel::from_stream_turns([[
2259            MockStreamEvent::text_start(Some(citation_metadata())),
2260            MockStreamEvent::text("cited "),
2261            MockStreamEvent::text_start(None),
2262            MockStreamEvent::text("answer"),
2263            MockStreamEvent::final_response_with_total_tokens(3),
2264        ]])
2265    }
2266
2267    fn streaming_cited_text_then_tool_model() -> MockCompletionModel {
2268        MockCompletionModel::from_stream_turns([
2269            vec![
2270                MockStreamEvent::text_start(Some(citation_metadata())),
2271                MockStreamEvent::text("I need a tool. "),
2272                MockStreamEvent::tool_call(
2273                    "tool_call_1",
2274                    "add",
2275                    serde_json::json!({"x": 1, "y": 2}),
2276                )
2277                .with_call_id("call_1"),
2278                MockStreamEvent::final_response_with_total_tokens(4),
2279            ],
2280            vec![
2281                MockStreamEvent::text("done"),
2282                MockStreamEvent::final_response_with_total_tokens(6),
2283            ],
2284        ])
2285    }
2286
2287    fn streaming_final_only_model() -> MockCompletionModel {
2288        MockCompletionModel::from_stream_turns([[
2289            MockStreamEvent::final_response_with_total_tokens(1),
2290        ]])
2291    }
2292
2293    #[derive(Clone)]
2294    struct TerminateOnStreamFinish;
2295
2296    impl PromptHook<MockCompletionModel> for TerminateOnStreamFinish {
2297        async fn on_stream_completion_response_finish(
2298            &self,
2299            _prompt: &Message,
2300            _response: &<MockCompletionModel as CompletionModel>::StreamingResponse,
2301        ) -> HookAction {
2302            HookAction::terminate("stop after completion call")
2303        }
2304    }
2305
2306    type RecordedToolCallDelta = (String, String, Option<String>, String);
2307
2308    #[derive(Clone)]
2309    struct RepairDefaultApiHook;
2310
2311    impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
2312        fn on_invalid_tool_call(
2313            &self,
2314            context: &InvalidToolCallContext,
2315        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2316            let tool_name = context.tool_name.clone();
2317            async move {
2318                assert_eq!(tool_name, "default_api");
2319                InvalidToolCallHookAction::repair("add")
2320            }
2321        }
2322    }
2323
2324    #[derive(Clone)]
2325    struct RetryDefaultApiHook;
2326
2327    impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
2328        fn on_invalid_tool_call(
2329            &self,
2330            context: &InvalidToolCallContext,
2331        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2332            let tool_name = context.tool_name.clone();
2333            let args = context.args.clone();
2334            async move {
2335                assert_eq!(tool_name, "default_api");
2336                if let Some(args) = args {
2337                    assert!(!args.is_empty());
2338                }
2339                InvalidToolCallHookAction::retry("Use the add tool instead")
2340            }
2341        }
2342    }
2343
2344    #[derive(Clone)]
2345    struct SkipDefaultApiHook;
2346
2347    impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
2348        fn on_invalid_tool_call(
2349            &self,
2350            context: &InvalidToolCallContext,
2351        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2352            let tool_name = context.tool_name.clone();
2353            async move {
2354                assert_eq!(tool_name, "default_api");
2355                InvalidToolCallHookAction::skip("default_api was skipped")
2356            }
2357        }
2358    }
2359
2360    #[derive(Clone, Default)]
2361    struct RecordingInvalidToolCallHook {
2362        contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
2363    }
2364
2365    impl RecordingInvalidToolCallHook {
2366        fn observed(&self) -> Vec<InvalidToolCallContext> {
2367            self.contexts
2368                .lock()
2369                .expect("invalid tool context records mutex was poisoned")
2370                .clone()
2371        }
2372    }
2373
2374    impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
2375        fn on_invalid_tool_call(
2376            &self,
2377            context: &InvalidToolCallContext,
2378        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2379            let contexts = self.contexts.clone();
2380            let context = context.clone();
2381
2382            async move {
2383                contexts
2384                    .lock()
2385                    .expect("invalid tool context records mutex was poisoned")
2386                    .push(context);
2387                InvalidToolCallHookAction::fail()
2388            }
2389        }
2390    }
2391
2392    #[derive(Clone, Default)]
2393    struct RecordingToolCallDeltaHook {
2394        deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2395    }
2396
2397    impl RecordingToolCallDeltaHook {
2398        fn observed(&self) -> Vec<RecordedToolCallDelta> {
2399            self.deltas
2400                .lock()
2401                .expect("tool call delta hook records mutex was poisoned")
2402                .clone()
2403        }
2404    }
2405
2406    impl PromptHook<MockCompletionModel> for RecordingToolCallDeltaHook {
2407        fn on_tool_call_delta(
2408            &self,
2409            tool_call_id: &str,
2410            internal_call_id: &str,
2411            tool_name: Option<&str>,
2412            tool_call_delta: &str,
2413        ) -> impl Future<Output = HookAction> + Send {
2414            let deltas = self.deltas.clone();
2415            let event = (
2416                tool_call_id.to_string(),
2417                internal_call_id.to_string(),
2418                tool_name.map(str::to_string),
2419                tool_call_delta.to_string(),
2420            );
2421
2422            async move {
2423                deltas
2424                    .lock()
2425                    .expect("tool call delta hook records mutex was poisoned")
2426                    .push(event);
2427                HookAction::cont()
2428            }
2429        }
2430    }
2431
2432    #[derive(Clone, Default)]
2433    struct RecordingTextDeltaHook {
2434        deltas: Arc<Mutex<Vec<(String, String)>>>,
2435    }
2436
2437    impl RecordingTextDeltaHook {
2438        fn observed(&self) -> Vec<(String, String)> {
2439            self.deltas
2440                .lock()
2441                .expect("text delta hook records mutex was poisoned")
2442                .clone()
2443        }
2444    }
2445
2446    impl PromptHook<MockCompletionModel> for RecordingTextDeltaHook {
2447        fn on_text_delta(
2448            &self,
2449            text_delta: &str,
2450            full_text: &str,
2451        ) -> impl Future<Output = HookAction> + Send {
2452            let deltas = self.deltas.clone();
2453            let event = (text_delta.to_string(), full_text.to_string());
2454
2455            async move {
2456                deltas
2457                    .lock()
2458                    .expect("text delta hook records mutex was poisoned")
2459                    .push(event);
2460                HookAction::cont()
2461            }
2462        }
2463    }
2464
2465    #[derive(Clone)]
2466    struct RecordingTextAndSkipInvalidToolHook {
2467        text: RecordingTextDeltaHook,
2468    }
2469
2470    impl PromptHook<MockCompletionModel> for RecordingTextAndSkipInvalidToolHook {
2471        fn on_text_delta(
2472            &self,
2473            text_delta: &str,
2474            full_text: &str,
2475        ) -> impl Future<Output = HookAction> + Send {
2476            self.text.on_text_delta(text_delta, full_text)
2477        }
2478
2479        fn on_invalid_tool_call(
2480            &self,
2481            context: &InvalidToolCallContext,
2482        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2483            SkipDefaultApiHook.on_invalid_tool_call(context)
2484        }
2485    }
2486
2487    #[derive(Clone)]
2488    struct RecordingTextAndRetryInvalidToolHook {
2489        text: RecordingTextDeltaHook,
2490    }
2491
2492    impl PromptHook<MockCompletionModel> for RecordingTextAndRetryInvalidToolHook {
2493        fn on_text_delta(
2494            &self,
2495            text_delta: &str,
2496            full_text: &str,
2497        ) -> impl Future<Output = HookAction> + Send {
2498            self.text.on_text_delta(text_delta, full_text)
2499        }
2500
2501        fn on_invalid_tool_call(
2502            &self,
2503            context: &InvalidToolCallContext,
2504        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2505            RetryDefaultApiHook.on_invalid_tool_call(context)
2506        }
2507    }
2508
2509    #[derive(Clone)]
2510    struct RecordingDeltaAndRetryInvalidToolHook {
2511        delta: RecordingToolCallDeltaHook,
2512    }
2513
2514    impl PromptHook<MockCompletionModel> for RecordingDeltaAndRetryInvalidToolHook {
2515        fn on_tool_call_delta(
2516            &self,
2517            tool_call_id: &str,
2518            internal_call_id: &str,
2519            tool_name: Option<&str>,
2520            tool_call_delta: &str,
2521        ) -> impl Future<Output = HookAction> + Send {
2522            self.delta.on_tool_call_delta(
2523                tool_call_id,
2524                internal_call_id,
2525                tool_name,
2526                tool_call_delta,
2527            )
2528        }
2529
2530        fn on_invalid_tool_call(
2531            &self,
2532            context: &InvalidToolCallContext,
2533        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2534            RetryDefaultApiHook.on_invalid_tool_call(context)
2535        }
2536    }
2537
2538    #[derive(Clone)]
2539    struct RecordingDeltaAndSkipInvalidToolHook {
2540        delta: RecordingToolCallDeltaHook,
2541    }
2542
2543    impl PromptHook<MockCompletionModel> for RecordingDeltaAndSkipInvalidToolHook {
2544        fn on_tool_call_delta(
2545            &self,
2546            tool_call_id: &str,
2547            internal_call_id: &str,
2548            tool_name: Option<&str>,
2549            tool_call_delta: &str,
2550        ) -> impl Future<Output = HookAction> + Send {
2551            self.delta.on_tool_call_delta(
2552                tool_call_id,
2553                internal_call_id,
2554                tool_name,
2555                tool_call_delta,
2556            )
2557        }
2558
2559        fn on_invalid_tool_call(
2560            &self,
2561            context: &InvalidToolCallContext,
2562        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2563            SkipDefaultApiHook.on_invalid_tool_call(context)
2564        }
2565    }
2566
2567    #[derive(Clone, Default)]
2568    struct TerminatingToolCallDeltaHook {
2569        deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2570    }
2571
2572    impl TerminatingToolCallDeltaHook {
2573        fn observed(&self) -> Vec<RecordedToolCallDelta> {
2574            self.deltas
2575                .lock()
2576                .expect("tool call delta hook records mutex was poisoned")
2577                .clone()
2578        }
2579    }
2580
2581    impl PromptHook<MockCompletionModel> for TerminatingToolCallDeltaHook {
2582        fn on_tool_call_delta(
2583            &self,
2584            tool_call_id: &str,
2585            internal_call_id: &str,
2586            tool_name: Option<&str>,
2587            tool_call_delta: &str,
2588        ) -> impl Future<Output = HookAction> + Send {
2589            let deltas = self.deltas.clone();
2590            let event = (
2591                tool_call_id.to_string(),
2592                internal_call_id.to_string(),
2593                tool_name.map(str::to_string),
2594                tool_call_delta.to_string(),
2595            );
2596
2597            async move {
2598                deltas
2599                    .lock()
2600                    .expect("tool call delta hook records mutex was poisoned")
2601                    .push(event);
2602                HookAction::terminate("stop on tool call delta")
2603            }
2604        }
2605    }
2606
2607    fn text_metadata(content: &OneOrMany<AssistantContent>) -> Option<&serde_json::Value> {
2608        content.iter().find_map(|item| match item {
2609            AssistantContent::Text(text) => text.additional_params.as_ref(),
2610            _ => None,
2611        })
2612    }
2613
2614    #[tokio::test]
2615    async fn stream_prompt_continues_after_tool_call_turn() {
2616        let model = streaming_tool_then_text_model();
2617        let recorded = model.clone();
2618        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2619        let empty_history: &[Message] = &[];
2620
2621        let mut stream = agent
2622            .stream_prompt("do tool work")
2623            .with_history(empty_history)
2624            .multi_turn(3)
2625            .await;
2626        let mut saw_tool_call = false;
2627        let mut saw_tool_result = false;
2628        let mut saw_final_response = false;
2629        let mut final_text = String::new();
2630        let mut final_response_text = None;
2631        let mut final_history = None;
2632
2633        while let Some(item) = stream.next().await {
2634            match item {
2635                Ok(MultiTurnStreamItem::StreamAssistantItem(
2636                    StreamedAssistantContent::ToolCall { .. },
2637                )) => {
2638                    saw_tool_call = true;
2639                }
2640                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2641                    ..
2642                })) => {
2643                    saw_tool_result = true;
2644                }
2645                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
2646                    text,
2647                ))) => {
2648                    final_text.push_str(&text.text);
2649                }
2650                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
2651                    saw_final_response = true;
2652                    final_response_text = Some(res.response().to_owned());
2653                    final_history = res.history().map(|history| history.to_vec());
2654                    break;
2655                }
2656                Ok(_) => {}
2657                Err(err) => panic!("unexpected streaming error: {err:?}"),
2658            }
2659        }
2660
2661        assert!(saw_tool_call);
2662        assert!(saw_tool_result);
2663        assert!(saw_final_response);
2664        assert_eq!(final_text, "done");
2665        assert_eq!(final_response_text.as_deref(), Some("done"));
2666        let history = final_history.expect("expected final response history");
2667        assert!(history.iter().any(|message| matches!(
2668            message,
2669            Message::Assistant { content, .. }
2670                if content.iter().any(|item| matches!(
2671                    item,
2672                    AssistantContent::Text(text) if text.text == "done"
2673                ))
2674        )));
2675        let requests = recorded.requests();
2676        assert_eq!(requests.len(), 2);
2677        assert!(validate_follow_up_tool_history(&requests[1]).is_ok());
2678    }
2679
2680    #[tokio::test]
2681    async fn unknown_tool_call_fails_before_streaming_second_request() {
2682        let model = MockCompletionModel::from_stream_turns([
2683            vec![
2684                MockStreamEvent::tool_call(
2685                    "tool_call_1",
2686                    "default_api",
2687                    serde_json::json!({"x": 1, "y": 2}),
2688                ),
2689                MockStreamEvent::final_response_with_total_tokens(4),
2690            ],
2691            vec![
2692                MockStreamEvent::text("should not be requested"),
2693                MockStreamEvent::final_response_with_total_tokens(6),
2694            ],
2695        ]);
2696        let recorded = model.clone();
2697        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2698
2699        let mut stream = agent
2700            .stream_prompt("use the tool")
2701            .with_hook(PanicOnUnknownToolHook)
2702            .multi_turn(3)
2703            .await;
2704        let mut saw_tool_call = false;
2705        let mut error = None;
2706
2707        while let Some(item) = stream.next().await {
2708            match item {
2709                Ok(MultiTurnStreamItem::StreamAssistantItem(
2710                    StreamedAssistantContent::ToolCall { .. },
2711                )) => {
2712                    saw_tool_call = true;
2713                }
2714                Ok(_) => {}
2715                Err(err) => {
2716                    error = Some(err);
2717                    break;
2718                }
2719            }
2720        }
2721
2722        assert!(!saw_tool_call);
2723        let error = error.expect("unknown model-emitted tool should fail");
2724        match error {
2725            StreamingError::Prompt(err) => match *err {
2726                PromptError::UnknownToolCall {
2727                    tool_name,
2728                    available_tools,
2729                    allowed_tools,
2730                    chat_history,
2731                } => {
2732                    assert_eq!(tool_name, "default_api");
2733                    assert_eq!(available_tools, vec!["add".to_string()]);
2734                    assert_eq!(allowed_tools, vec!["add".to_string()]);
2735                    assert!(history_contains_tool_call(&chat_history, "default_api"));
2736                }
2737                other => panic!("expected UnknownToolCall, got {other:?}"),
2738            },
2739            other => panic!("expected prompt streaming error, got {other:?}"),
2740        }
2741        assert_eq!(recorded.request_count(), 1);
2742    }
2743
2744    #[tokio::test]
2745    async fn invalid_tool_call_hook_can_repair_streaming_tool_name() {
2746        let model = MockCompletionModel::from_stream_turns([
2747            vec![
2748                MockStreamEvent::tool_call(
2749                    "tool_call_1",
2750                    "default_api",
2751                    serde_json::json!({"x": 2, "y": 3}),
2752                ),
2753                MockStreamEvent::final_response_with_total_tokens(4),
2754            ],
2755            vec![
2756                MockStreamEvent::text("done"),
2757                MockStreamEvent::final_response_with_total_tokens(6),
2758            ],
2759        ]);
2760        let recorded = model.clone();
2761        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2762
2763        let mut stream = agent
2764            .stream_prompt("use the tool")
2765            .with_hook(RepairDefaultApiHook)
2766            .multi_turn(3)
2767            .with_history(Vec::<Message>::new())
2768            .await;
2769        let mut saw_repaired_tool_call = false;
2770        let mut saw_tool_result = false;
2771        let mut final_response_text = None;
2772
2773        while let Some(item) = stream.next().await {
2774            match item {
2775                Ok(MultiTurnStreamItem::StreamAssistantItem(
2776                    StreamedAssistantContent::ToolCall { tool_call, .. },
2777                )) => {
2778                    assert_eq!(tool_call.function.name, "add");
2779                    saw_repaired_tool_call = true;
2780                }
2781                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2782                    tool_result,
2783                    ..
2784                })) => {
2785                    assert!(tool_result.content.iter().any(|content| {
2786                        matches!(
2787                            content,
2788                            ToolResultContent::Text(text) if text.text == "5"
2789                        )
2790                    }));
2791                    saw_tool_result = true;
2792                }
2793                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2794                    final_response_text = Some(response.response().to_string());
2795                    break;
2796                }
2797                Ok(_) => {}
2798                Err(err) => panic!("unexpected streaming error: {err:?}"),
2799            }
2800        }
2801
2802        assert!(saw_repaired_tool_call);
2803        assert!(saw_tool_result);
2804        assert_eq!(final_response_text.as_deref(), Some("done"));
2805        assert_eq!(recorded.request_count(), 2);
2806    }
2807
2808    #[tokio::test]
2809    async fn invalid_tool_call_context_uses_completed_streaming_tool_call_provider_id() {
2810        let invalid_hook = RecordingInvalidToolCallHook::default();
2811        let model = MockCompletionModel::from_stream_turns([
2812            vec![
2813                MockStreamEvent::tool_call(
2814                    "tool_call_1",
2815                    "default_api",
2816                    serde_json::json!({"x": 2, "y": 3}),
2817                )
2818                .with_call_id("provider_call_1"),
2819                MockStreamEvent::final_response_with_total_tokens(4),
2820            ],
2821            vec![
2822                MockStreamEvent::text("should not be requested"),
2823                MockStreamEvent::final_response_with_total_tokens(6),
2824            ],
2825        ]);
2826        let recorded = model.clone();
2827        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2828
2829        let mut stream = agent
2830            .stream_prompt("use the tool")
2831            .with_hook(invalid_hook.clone())
2832            .multi_turn(3)
2833            .await;
2834        let mut error = None;
2835
2836        while let Some(item) = stream.next().await {
2837            if let Err(err) = item {
2838                error = Some(err);
2839                break;
2840            }
2841        }
2842
2843        assert!(error.is_some(), "invalid tool should fail");
2844        assert_eq!(recorded.request_count(), 1);
2845        let contexts = invalid_hook.observed();
2846        assert_eq!(contexts.len(), 1);
2847        let context = &contexts[0];
2848        assert_eq!(context.tool_name, "default_api");
2849        assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
2850        assert!(context.internal_call_id.is_some());
2851        assert!(context.is_streaming);
2852    }
2853
2854    #[tokio::test]
2855    async fn invalid_tool_call_hook_skip_emits_streaming_tool_result() {
2856        let add_calls = Arc::new(AtomicU32::new(0));
2857        let model = MockCompletionModel::from_stream_turns([
2858            vec![
2859                MockStreamEvent::tool_call(
2860                    "tool_call_1",
2861                    "default_api",
2862                    serde_json::json!({"x": 2, "y": 3}),
2863                )
2864                .with_call_id("call_1"),
2865                MockStreamEvent::final_response_with_total_tokens(4),
2866            ],
2867            vec![
2868                MockStreamEvent::text("continued"),
2869                MockStreamEvent::final_response_with_total_tokens(6),
2870            ],
2871        ]);
2872        let recorded = model.clone();
2873        let agent = AgentBuilder::new(model)
2874            .tool(CountingAddTool {
2875                calls: add_calls.clone(),
2876            })
2877            .build();
2878
2879        let mut stream = agent
2880            .stream_prompt("use the tool")
2881            .with_hook(SkipDefaultApiHook)
2882            .multi_turn(3)
2883            .with_history(Vec::<Message>::new())
2884            .await;
2885        let mut skipped_tool_result = None;
2886        let mut final_response_text = None;
2887
2888        while let Some(item) = stream.next().await {
2889            match item {
2890                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2891                    tool_result,
2892                    internal_call_id,
2893                })) => {
2894                    assert!(!internal_call_id.is_empty());
2895                    skipped_tool_result = Some(tool_result);
2896                }
2897                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2898                    final_response_text = Some(response.response().to_string());
2899                    break;
2900                }
2901                Ok(_) => {}
2902                Err(err) => panic!("unexpected streaming error: {err:?}"),
2903            }
2904        }
2905
2906        let skipped_tool_result =
2907            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
2908        assert_eq!(skipped_tool_result.id, "tool_call_1");
2909        assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_1"));
2910        assert!(skipped_tool_result.content.iter().any(|content| matches!(
2911            content,
2912            ToolResultContent::Text(text) if text.text == "default_api was skipped"
2913        )));
2914        assert_eq!(final_response_text.as_deref(), Some("continued"));
2915        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
2916
2917        let requests = recorded.requests();
2918        assert_eq!(requests.len(), 2);
2919        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2920        assert!(matches!(
2921            follow_up_history.get(2),
2922            Some(Message::User { content })
2923                if content.iter().any(|item| matches!(
2924                    item,
2925                    UserContent::ToolResult(result)
2926                        if result.id == "tool_call_1"
2927                            && result.content.iter().any(|content| matches!(
2928                                content,
2929                                ToolResultContent::Text(text)
2930                                    if text.text == "default_api was skipped"
2931                            ))
2932                ))
2933        ));
2934    }
2935
2936    #[tokio::test]
2937    async fn invalid_tool_call_hook_retries_mixed_streaming_turn_without_executing_valid_call() {
2938        let add_calls = Arc::new(AtomicU32::new(0));
2939        let model = MockCompletionModel::from_stream_turns([
2940            vec![
2941                MockStreamEvent::text("checking "),
2942                MockStreamEvent::tool_call(
2943                    "tool_call_1",
2944                    "add",
2945                    serde_json::json!({"x": 2, "y": 3}),
2946                )
2947                .with_call_id("call_1"),
2948                MockStreamEvent::tool_call(
2949                    "tool_call_2",
2950                    "default_api",
2951                    serde_json::json!({"x": 4, "y": 5}),
2952                )
2953                .with_call_id("call_2"),
2954                MockStreamEvent::final_response_with_total_tokens(4),
2955            ],
2956            vec![
2957                MockStreamEvent::text("retried"),
2958                MockStreamEvent::final_response_with_total_tokens(6),
2959            ],
2960        ]);
2961        let recorded = model.clone();
2962        let agent = AgentBuilder::new(model)
2963            .tool(CountingAddTool {
2964                calls: add_calls.clone(),
2965            })
2966            .build();
2967
2968        let mut stream = agent
2969            .stream_prompt("use the tool")
2970            .with_hook(RetryDefaultApiHook)
2971            .multi_turn(3)
2972            .with_history(Vec::<Message>::new())
2973            .max_invalid_tool_call_retries(1)
2974            .await;
2975        let mut completion_call_events = Vec::new();
2976        let mut final_response_text = None;
2977        let mut final_response_usage = Usage::new();
2978        let mut final_completion_calls = Vec::new();
2979
2980        while let Some(item) = stream.next().await {
2981            match item {
2982                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
2983                    completion_call_events.push(completion_call);
2984                }
2985                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2986                    final_response_text = Some(response.response().to_string());
2987                    final_response_usage = response.usage();
2988                    final_completion_calls = response.completion_calls().to_vec();
2989                    break;
2990                }
2991                Ok(_) => {}
2992                Err(err) => panic!("unexpected streaming error: {err:?}"),
2993            }
2994        }
2995
2996        assert_eq!(final_response_text.as_deref(), Some("retried"));
2997        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
2998        let mut first_usage = Usage::new();
2999        first_usage.total_tokens = 4;
3000        let mut second_usage = Usage::new();
3001        second_usage.total_tokens = 6;
3002        let expected_completion_calls = vec![
3003            CompletionCall::new(0, Some(first_usage)),
3004            CompletionCall::new(1, Some(second_usage)),
3005        ];
3006        assert_eq!(completion_call_events, expected_completion_calls);
3007        assert_eq!(final_completion_calls, expected_completion_calls);
3008        assert_eq!(final_response_usage.total_tokens, 10);
3009
3010        let requests = recorded.requests();
3011        assert_eq!(requests.len(), 2);
3012        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3013        assert_eq!(retry_history.len(), 3);
3014        assert!(matches!(
3015            retry_history.get(1),
3016            Some(Message::Assistant { content, .. })
3017                if content.iter().any(|item| matches!(
3018                    item,
3019                    AssistantContent::Text(text) if text.text == "checking "
3020                ))
3021                    && content.iter().any(|item| matches!(
3022                        item,
3023                        AssistantContent::ToolCall(tool_call)
3024                            if tool_call.id == "tool_call_1"
3025                                && tool_call.function.name == "add"
3026                    ))
3027                    && content.iter().any(|item| matches!(
3028                        item,
3029                        AssistantContent::ToolCall(tool_call)
3030                            if tool_call.id == "tool_call_2"
3031                                && tool_call.function.name == "default_api"
3032                    ))
3033        ));
3034        assert!(matches!(
3035            retry_history.get(2),
3036            Some(Message::User { content })
3037                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3038                    && content.iter().any(|item| matches!(
3039                        item,
3040                        UserContent::ToolResult(result)
3041                            if result.id == "tool_call_1"
3042                                && result.content.iter().any(|content| matches!(
3043                                    content,
3044                                    ToolResultContent::Text(text)
3045                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3046                                ))
3047                    ))
3048                    && content.iter().any(|item| matches!(
3049                        item,
3050                        UserContent::ToolResult(result)
3051                            if result.id == "tool_call_2"
3052                                && result.content.iter().any(|content| matches!(
3053                                    content,
3054                                    ToolResultContent::Text(text)
3055                                        if text.text == "Use the add tool instead"
3056                                ))
3057                    ))
3058        ));
3059    }
3060
3061    #[tokio::test]
3062    async fn invalid_tool_call_hook_skips_mixed_streaming_turn_without_executing_valid_call() {
3063        let add_calls = Arc::new(AtomicU32::new(0));
3064        let model = MockCompletionModel::from_stream_turns([
3065            vec![
3066                MockStreamEvent::text("checking "),
3067                MockStreamEvent::tool_call(
3068                    "tool_call_1",
3069                    "add",
3070                    serde_json::json!({"x": 2, "y": 3}),
3071                )
3072                .with_call_id("call_1"),
3073                MockStreamEvent::tool_call(
3074                    "tool_call_2",
3075                    "default_api",
3076                    serde_json::json!({"x": 4, "y": 5}),
3077                )
3078                .with_call_id("call_2"),
3079                MockStreamEvent::final_response_with_total_tokens(4),
3080            ],
3081            vec![
3082                MockStreamEvent::text("continued"),
3083                MockStreamEvent::final_response_with_total_tokens(6),
3084            ],
3085        ]);
3086        let recorded = model.clone();
3087        let agent = AgentBuilder::new(model)
3088            .tool(CountingAddTool {
3089                calls: add_calls.clone(),
3090            })
3091            .build();
3092
3093        let mut stream = agent
3094            .stream_prompt("use the tool")
3095            .with_hook(SkipDefaultApiHook)
3096            .multi_turn(3)
3097            .with_history(Vec::<Message>::new())
3098            .await;
3099        let mut skipped_tool_result = None;
3100        let mut final_response_text = None;
3101
3102        while let Some(item) = stream.next().await {
3103            match item {
3104                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3105                    tool_result,
3106                    ..
3107                })) => {
3108                    skipped_tool_result = Some(tool_result);
3109                }
3110                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3111                    final_response_text = Some(response.response().to_string());
3112                    break;
3113                }
3114                Ok(_) => {}
3115                Err(err) => panic!("unexpected streaming error: {err:?}"),
3116            }
3117        }
3118
3119        let skipped_tool_result =
3120            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3121        assert_eq!(skipped_tool_result.id, "tool_call_2");
3122        assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_2"));
3123        assert_eq!(final_response_text.as_deref(), Some("continued"));
3124        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3125
3126        let requests = recorded.requests();
3127        assert_eq!(requests.len(), 2);
3128        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3129        assert_eq!(follow_up_history.len(), 3);
3130        assert!(matches!(
3131            follow_up_history.get(1),
3132            Some(Message::Assistant { content, .. })
3133                if content.iter().any(|item| matches!(
3134                    item,
3135                    AssistantContent::Text(text) if text.text == "checking "
3136                ))
3137                    && content.iter().any(|item| matches!(
3138                        item,
3139                        AssistantContent::ToolCall(tool_call)
3140                            if tool_call.id == "tool_call_1"
3141                                && tool_call.function.name == "add"
3142                    ))
3143                    && content.iter().any(|item| matches!(
3144                        item,
3145                        AssistantContent::ToolCall(tool_call)
3146                            if tool_call.id == "tool_call_2"
3147                                && tool_call.function.name == "default_api"
3148                    ))
3149        ));
3150        assert!(matches!(
3151            follow_up_history.get(2),
3152            Some(Message::User { content })
3153                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3154                    && content.iter().any(|item| matches!(
3155                        item,
3156                        UserContent::ToolResult(result)
3157                            if result.id == "tool_call_1"
3158                                && result.call_id.as_deref() == Some("call_1")
3159                                && result.content.iter().any(|content| matches!(
3160                                    content,
3161                                    ToolResultContent::Text(text)
3162                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3163                                ))
3164                    ))
3165                    && content.iter().any(|item| matches!(
3166                        item,
3167                        UserContent::ToolResult(result)
3168                            if result.id == "tool_call_2"
3169                                && result.call_id.as_deref() == Some("call_2")
3170                                && result.content.iter().any(|content| matches!(
3171                                    content,
3172                                    ToolResultContent::Text(text)
3173                                        if text.text == "default_api was skipped"
3174                                ))
3175            ))
3176        ));
3177    }
3178
3179    #[tokio::test]
3180    async fn invalid_completed_tool_call_skip_preserves_streaming_reasoning_history() {
3181        let model = MockCompletionModel::from_stream_turns([
3182            vec![
3183                MockStreamEvent::text("checking "),
3184                MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
3185                MockStreamEvent::tool_call(
3186                    "tool_call_1",
3187                    "default_api",
3188                    serde_json::json!({"x": 2, "y": 3}),
3189                ),
3190                MockStreamEvent::final_response_with_total_tokens(4),
3191            ],
3192            vec![
3193                MockStreamEvent::text("continued"),
3194                MockStreamEvent::final_response_with_total_tokens(6),
3195            ],
3196        ]);
3197        let recorded = model.clone();
3198        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3199
3200        let mut stream = agent
3201            .stream_prompt("use the tool")
3202            .with_hook(SkipDefaultApiHook)
3203            .multi_turn(3)
3204            .with_history(Vec::<Message>::new())
3205            .await;
3206
3207        while let Some(item) = stream.next().await {
3208            match item {
3209                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3210                Ok(_) => {}
3211                Err(err) => panic!("unexpected streaming error: {err:?}"),
3212            }
3213        }
3214
3215        let requests = recorded.requests();
3216        assert_eq!(requests.len(), 2);
3217        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3218        assert!(history_contains_text(&follow_up_history, "checking "));
3219        assert!(assistant_reasoning_precedes_tool_call(
3220            &follow_up_history,
3221            "reasoned step",
3222            "default_api"
3223        ));
3224    }
3225
3226    #[tokio::test]
3227    async fn invalid_name_delta_retry_preserves_streaming_reasoning_history() {
3228        let model = MockCompletionModel::from_stream_turns([
3229            vec![
3230                MockStreamEvent::reasoning_delta(Some("rs_1"), "delta reason"),
3231                MockStreamEvent::tool_call_arguments_delta(
3232                    "tool_call_1",
3233                    "internal_1",
3234                    r#"{"x":2,"y":3}"#,
3235                ),
3236                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3237                MockStreamEvent::final_response_with_total_tokens(4),
3238            ],
3239            vec![
3240                MockStreamEvent::text("retried"),
3241                MockStreamEvent::final_response_with_total_tokens(6),
3242            ],
3243        ]);
3244        let recorded = model.clone();
3245        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3246
3247        let mut stream = agent
3248            .stream_prompt("use the tool")
3249            .with_hook(RetryDefaultApiHook)
3250            .multi_turn(3)
3251            .with_history(Vec::<Message>::new())
3252            .max_invalid_tool_call_retries(1)
3253            .await;
3254
3255        while let Some(item) = stream.next().await {
3256            match item {
3257                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3258                Ok(_) => {}
3259                Err(err) => panic!("unexpected streaming error: {err:?}"),
3260            }
3261        }
3262
3263        let requests = recorded.requests();
3264        assert_eq!(requests.len(), 2);
3265        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3266        assert!(assistant_reasoning_precedes_tool_call(
3267            &retry_history,
3268            "delta reason",
3269            "default_api"
3270        ));
3271    }
3272
3273    #[tokio::test]
3274    async fn invalid_tool_call_hook_skip_resets_streaming_text_delta_state() {
3275        let text_hook = RecordingTextDeltaHook::default();
3276        let model = MockCompletionModel::from_stream_turns([
3277            vec![
3278                MockStreamEvent::text("stale "),
3279                MockStreamEvent::tool_call(
3280                    "tool_call_1",
3281                    "default_api",
3282                    serde_json::json!({"x": 2, "y": 3}),
3283                ),
3284                MockStreamEvent::final_response_with_total_tokens(4),
3285            ],
3286            vec![
3287                MockStreamEvent::text("fresh"),
3288                MockStreamEvent::final_response_with_total_tokens(6),
3289            ],
3290        ]);
3291        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3292
3293        let mut stream = agent
3294            .stream_prompt("use the tool")
3295            .with_hook(RecordingTextAndSkipInvalidToolHook {
3296                text: text_hook.clone(),
3297            })
3298            .multi_turn(3)
3299            .with_history(Vec::<Message>::new())
3300            .await;
3301
3302        while let Some(item) = stream.next().await {
3303            match item {
3304                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3305                Ok(_) => {}
3306                Err(err) => panic!("unexpected streaming error: {err:?}"),
3307            }
3308        }
3309
3310        assert_eq!(
3311            text_hook.observed(),
3312            vec![
3313                ("stale ".to_string(), "stale ".to_string()),
3314                ("fresh".to_string(), "fresh".to_string()),
3315            ]
3316        );
3317    }
3318
3319    #[tokio::test]
3320    async fn invalid_tool_call_delta_retry_uses_structured_tool_feedback() {
3321        let delta_hook = RecordingToolCallDeltaHook::default();
3322        let add_calls = Arc::new(AtomicU32::new(0));
3323        let model = MockCompletionModel::from_stream_turns([
3324            vec![
3325                MockStreamEvent::text("checking "),
3326                MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3327                MockStreamEvent::tool_call(
3328                    "tool_call_0",
3329                    "add",
3330                    serde_json::json!({"x": 1, "y": 2}),
3331                )
3332                .with_call_id("call_0"),
3333                MockStreamEvent::tool_call_arguments_delta(
3334                    "tool_call_1",
3335                    "internal_1",
3336                    r#"{"x":2,"y":3}"#,
3337                ),
3338                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3339                MockStreamEvent::final_response_with_total_tokens(4),
3340            ],
3341            vec![
3342                MockStreamEvent::text("retried"),
3343                MockStreamEvent::final_response_with_total_tokens(6),
3344            ],
3345        ]);
3346        let recorded = model.clone();
3347        let agent = AgentBuilder::new(model)
3348            .tool(CountingAddTool {
3349                calls: add_calls.clone(),
3350            })
3351            .build();
3352
3353        let mut stream = agent
3354            .stream_prompt("use the tool")
3355            .with_hook(RecordingDeltaAndRetryInvalidToolHook {
3356                delta: delta_hook.clone(),
3357            })
3358            .multi_turn(3)
3359            .with_history(Vec::<Message>::new())
3360            .max_invalid_tool_call_retries(1)
3361            .await;
3362        let mut completion_call_events = Vec::new();
3363        let mut final_response_text = None;
3364        let mut final_response_usage = Usage::new();
3365        let mut final_completion_calls = Vec::new();
3366
3367        while let Some(item) = stream.next().await {
3368            match item {
3369                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
3370                    completion_call_events.push(completion_call);
3371                }
3372                Ok(MultiTurnStreamItem::StreamAssistantItem(
3373                    StreamedAssistantContent::ToolCallDelta { .. },
3374                )) => panic!("invalid tool-call delta should not be emitted"),
3375                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3376                    final_response_text = Some(response.response().to_string());
3377                    final_response_usage = response.usage();
3378                    final_completion_calls = response.completion_calls().to_vec();
3379                    break;
3380                }
3381                Ok(_) => {}
3382                Err(err) => panic!("unexpected streaming error: {err:?}"),
3383            }
3384        }
3385
3386        assert_eq!(final_response_text.as_deref(), Some("retried"));
3387        assert!(delta_hook.observed().is_empty());
3388        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3389        let mut first_usage = Usage::new();
3390        first_usage.total_tokens = 4;
3391        let mut second_usage = Usage::new();
3392        second_usage.total_tokens = 6;
3393        let expected_completion_calls = vec![
3394            CompletionCall::new(0, Some(first_usage)),
3395            CompletionCall::new(1, Some(second_usage)),
3396        ];
3397        assert_eq!(completion_call_events, expected_completion_calls);
3398        assert_eq!(final_completion_calls, expected_completion_calls);
3399        assert_eq!(final_response_usage.total_tokens, 10);
3400
3401        let requests = recorded.requests();
3402        assert_eq!(requests.len(), 2);
3403        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3404        assert!(matches!(
3405            retry_history.get(1),
3406            Some(Message::Assistant { content, .. })
3407                if content.iter().any(|item| matches!(
3408                    item,
3409                    AssistantContent::Text(text) if text.text == "checking "
3410                ))
3411                    && content.iter().any(|item| matches!(
3412                        item,
3413                        AssistantContent::ToolCall(tool_call)
3414                            if tool_call.id == "tool_call_0"
3415                                && tool_call.function.name == "add"
3416                    ))
3417                    && content.iter().any(|item| matches!(
3418                    item,
3419                    AssistantContent::ToolCall(tool_call)
3420                        if tool_call.id == "tool_call_1"
3421                            && tool_call.function.name == "default_api"
3422                            && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3423                ))
3424        ));
3425        assert!(matches!(
3426            retry_history.get(2),
3427            Some(Message::User { content })
3428                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3429                    && content.iter().any(|item| matches!(
3430                        item,
3431                        UserContent::ToolResult(result)
3432                            if result.id == "tool_call_0"
3433                                && result.call_id.as_deref() == Some("call_0")
3434                                && result.content.iter().any(|content| matches!(
3435                                    content,
3436                                    ToolResultContent::Text(text)
3437                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3438                                ))
3439                    ))
3440                    && content.iter().any(|item| matches!(
3441                    item,
3442                    UserContent::ToolResult(result)
3443                        if result.id == "tool_call_1"
3444                            && result.content.iter().any(|content| matches!(
3445                                content,
3446                                ToolResultContent::Text(text)
3447                                    if text.text == "Use the add tool instead"
3448                            ))
3449                ))
3450        ));
3451    }
3452
3453    #[tokio::test]
        // Streams a single turn holding text, a reasoning delta, a complete `add` call and
        // argument/name deltas for `default_api` (the agent is built with `MockAddTool` only),
        // then checks the context recorded by `RecordingInvalidToolCallHook`.
3454    async fn invalid_tool_call_delta_context_includes_same_turn_history_and_tool_call_id() {
3455        let invalid_hook = RecordingInvalidToolCallHook::default();
3456        let model = MockCompletionModel::from_stream_turns([
3457            vec![
3458                MockStreamEvent::text("checking "),
3459                MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3460                MockStreamEvent::tool_call(
3461                    "tool_call_0",
3462                    "add",
3463                    serde_json::json!({"x": 1, "y": 2}),
3464                )
3465                .with_call_id("call_0"),
                    // The arguments delta for `tool_call_1` is scripted ahead of its name delta.
3466                MockStreamEvent::tool_call_arguments_delta(
3467                    "tool_call_1",
3468                    "internal_1",
3469                    r#"{"x":2,"y":3}"#,
3470                ),
3471                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3472                MockStreamEvent::final_response_with_total_tokens(4),
3473            ],
                // Second scripted turn; the `request_count() == 1` assertion below shows it is
                // never requested.
3474            vec![
3475                MockStreamEvent::text("should not be requested"),
3476                MockStreamEvent::final_response_with_total_tokens(6),
3477            ],
3478        ]);
            // Clone kept so the request count can be read after `model` moves into the agent.
3479        let recorded = model.clone();
3480        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3481
3482        let mut stream = agent
3483            .stream_prompt("use the tool")
3484            .with_hook(invalid_hook.clone())
3485            .multi_turn(3)
3486            .await;
3487        let mut error = None;
3488
            // Drain the stream and stop at the first error item.
3489        while let Some(item) = stream.next().await {
3490            if let Err(err) = item {
3491                error = Some(err);
3492                break;
3493            }
3494        }
3495
3496        assert!(error.is_some(), "invalid name delta should fail");
3497        assert_eq!(recorded.request_count(), 1);
            // Exactly one context is expected, and it must identify the `default_api` call.
3498        let contexts = invalid_hook.observed();
3499        assert_eq!(contexts.len(), 1);
3500        let context = &contexts[0];
3501        assert_eq!(context.tool_name, "default_api");
3502        assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
3503        assert_eq!(context.internal_call_id.as_deref(), Some("internal_1"));
3504        assert!(context.is_streaming);
            // The recorded history must already hold this turn's text, reasoning and tool calls.
3505        assert!(history_contains_text(&context.chat_history, "checking "));
3506        assert!(
3507            assistant_reasoning_precedes_tool_call(
3508                &context.chat_history,
3509                "diagnostic reason",
3510                "add"
3511            ),
3512            "{:?}",
3513            context.chat_history
3514        );
3515        assert!(history_contains_tool_call(&context.chat_history, "add"));
3516        assert!(history_contains_tool_call(
3517            &context.chat_history,
3518            "default_api"
3519        ));
3520    }
3521
3522    #[tokio::test]
        // Runs a turn that emits text followed by deltas naming `default_api`, with one invalid
        // tool-call retry allowed, and checks what `RecordingTextDeltaHook` observed across the
        // first turn and the following one.
3523    async fn invalid_tool_call_delta_retry_resets_streaming_text_delta_state() {
3524        let text_hook = RecordingTextDeltaHook::default();
3525        let model = MockCompletionModel::from_stream_turns([
                // First turn: text, then argument/name deltas for `default_api`.
3526            vec![
3527                MockStreamEvent::text("stale "),
3528                MockStreamEvent::tool_call_arguments_delta(
3529                    "tool_call_1",
3530                    "internal_1",
3531                    r#"{"x":2,"y":3}"#,
3532                ),
3533                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3534                MockStreamEvent::final_response_with_total_tokens(4),
3535            ],
                // Second turn: plain text only.
3536            vec![
3537                MockStreamEvent::text("fresh"),
3538                MockStreamEvent::final_response_with_total_tokens(6),
3539            ],
3540        ]);
3541        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3542
3543        let mut stream = agent
3544            .stream_prompt("use the tool")
3545            .with_hook(RecordingTextAndRetryInvalidToolHook {
3546                text: text_hook.clone(),
3547            })
3548            .multi_turn(3)
3549            .with_history(Vec::<Message>::new())
3550            .max_invalid_tool_call_retries(1)
3551            .await;
3552
            // Consume items until the final response; any error item fails the test.
3553        while let Some(item) = stream.next().await {
3554            match item {
3555                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3556                Ok(_) => {}
3557                Err(err) => panic!("unexpected streaming error: {err:?}"),
3558            }
3559        }
3560
            // NOTE(review): each pair is presumably (delta, text accumulated so far) - confirm
            // against `RecordingTextDeltaHook`. The second pair is ("fresh", "fresh"), so no
            // "stale " text is expected to carry over into the second turn.
3561        assert_eq!(
3562            text_hook.observed(),
3563            vec![
3564                ("stale ".to_string(), "stale ".to_string()),
3565                ("fresh".to_string(), "fresh".to_string()),
3566            ]
3567        );
3568    }
3569
3570    #[tokio::test]
        // Runs a turn with a complete `add` call plus deltas naming `default_api` under
        // `RecordingDeltaAndSkipInvalidToolHook`, then checks the emitted tool result, the final
        // response, and the history sent with the second request.
3571    async fn invalid_tool_call_delta_skip_uses_structured_tool_feedback() {
3572        let delta_hook = RecordingToolCallDeltaHook::default();
            // Counter handed to `CountingAddTool`; asserted to stay at 0 below.
3573        let add_calls = Arc::new(AtomicU32::new(0));
3574        let model = MockCompletionModel::from_stream_turns([
3575            vec![
3576                MockStreamEvent::text("checking "),
3577                MockStreamEvent::tool_call(
3578                    "tool_call_0",
3579                    "add",
3580                    serde_json::json!({"x": 1, "y": 2}),
3581                )
3582                .with_call_id("call_0"),
3583                MockStreamEvent::tool_call_arguments_delta(
3584                    "tool_call_1",
3585                    "internal_1",
3586                    r#"{"x":2,"y":3}"#,
3587                ),
3588                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3589                MockStreamEvent::final_response_with_total_tokens(4),
3590            ],
                // Follow-up turn whose text becomes the expected final response.
3591            vec![
3592                MockStreamEvent::text("continued"),
3593                MockStreamEvent::final_response_with_total_tokens(6),
3594            ],
3595        ]);
3596        let recorded = model.clone();
3597        let agent = AgentBuilder::new(model)
3598            .tool(CountingAddTool {
3599                calls: add_calls.clone(),
3600            })
3601            .build();
3602
3603        let mut stream = agent
3604            .stream_prompt("use the tool")
3605            .with_hook(RecordingDeltaAndSkipInvalidToolHook {
3606                delta: delta_hook.clone(),
3607            })
3608            .multi_turn(3)
3609            .with_history(Vec::<Message>::new())
3610            .await;
3611        let mut skipped_tool_result = None;
3612        let mut final_response_text = None;
3613
            // A tool-call delta or an error item fails the test; tool results and the final
            // response text are captured for the assertions below.
3614        while let Some(item) = stream.next().await {
3615            match item {
3616                Ok(MultiTurnStreamItem::StreamAssistantItem(
3617                    StreamedAssistantContent::ToolCallDelta { .. },
3618                )) => panic!("invalid tool-call delta should not be emitted"),
3619                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3620                    tool_result,
3621                    internal_call_id,
3622                })) => {
                        // Every streamed tool result must belong to the `internal_1` call.
3623                    assert_eq!(internal_call_id, "internal_1");
3624                    skipped_tool_result = Some(tool_result);
3625                }
3626                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3627                    final_response_text = Some(response.response().to_string());
3628                    break;
3629                }
3630                Ok(_) => {}
3631                Err(err) => panic!("unexpected streaming error: {err:?}"),
3632            }
3633        }
3634
3635        let skipped_tool_result =
3636            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3637        assert_eq!(skipped_tool_result.id, "tool_call_1");
3638        assert!(skipped_tool_result.call_id.is_none());
3639        assert!(skipped_tool_result.content.iter().any(|content| matches!(
3640            content,
3641            ToolResultContent::Text(text) if text.text == "default_api was skipped"
3642        )));
3643        assert_eq!(final_response_text.as_deref(), Some("continued"));
            // The delta hook recorded nothing and the `add` tool was never executed.
3644        assert!(delta_hook.observed().is_empty());
3645        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3646
            // Inspect the history that accompanied the second completion request.
3647        let requests = recorded.requests();
3648        assert_eq!(requests.len(), 2);
3649        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
            // Index 1: assistant message with the text, the `add` call and the `default_api` call
            // (the latter with the arguments taken from the scripted delta).
3650        assert!(matches!(
3651            follow_up_history.get(1),
3652            Some(Message::Assistant { content, .. })
3653                if content.iter().any(|item| matches!(
3654                    item,
3655                    AssistantContent::Text(text) if text.text == "checking "
3656                ))
3657                    && content.iter().any(|item| matches!(
3658                        item,
3659                        AssistantContent::ToolCall(tool_call)
3660                            if tool_call.id == "tool_call_0"
3661                                && tool_call.function.name == "add"
3662                    ))
3663                    && content.iter().any(|item| matches!(
3664                    item,
3665                    AssistantContent::ToolCall(tool_call)
3666                        if tool_call.id == "tool_call_1"
3667                            && tool_call.function.name == "default_api"
3668                            && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3669                ))
3670        ));
            // Index 2: user message with exactly two tool results - the not-executed marker for
            // `tool_call_0` and the skip text for `tool_call_1`.
3671        assert!(matches!(
3672            follow_up_history.get(2),
3673            Some(Message::User { content })
3674                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3675                    && content.iter().any(|item| matches!(
3676                        item,
3677                        UserContent::ToolResult(result)
3678                            if result.id == "tool_call_0"
3679                                && result.call_id.as_deref() == Some("call_0")
3680                                && result.content.iter().any(|content| matches!(
3681                                    content,
3682                                    ToolResultContent::Text(text)
3683                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3684                                ))
3685                    ))
3686                    && content.iter().any(|item| matches!(
3687                    item,
3688                    UserContent::ToolResult(result)
3689                        if result.id == "tool_call_1"
3690                            && result.content.iter().any(|content| matches!(
3691                                content,
3692                                ToolResultContent::Text(text)
3693                                    if text.text == "default_api was skipped"
3694                            ))
3695                ))
3696        ));
3697    }
3698
3699    #[tokio::test]
        // With `max_invalid_tool_call_retries(0)` and `RetryDefaultApiHook`, a complete
        // `default_api` tool call must surface as `PromptError::UnknownToolCall` whose chat
        // history contains that call, after a single completion request.
3700    async fn streaming_retry_budget_exhaustion_history_contains_invalid_tool_call() {
3701        let model = MockCompletionModel::from_stream_turns([
3702            vec![
3703                MockStreamEvent::tool_call(
3704                    "tool_call_1",
3705                    "default_api",
3706                    serde_json::json!({"x": 1, "y": 2}),
3707                ),
3708                MockStreamEvent::final_response_with_total_tokens(4),
3709            ],
                // Scripted but expected to stay unused (see the request-count assertion).
3710            vec![
3711                MockStreamEvent::text("should not be requested"),
3712                MockStreamEvent::final_response_with_total_tokens(6),
3713            ],
3714        ]);
3715        let recorded = model.clone();
3716        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3717
3718        let mut stream = agent
3719            .stream_prompt("use the tool")
3720            .with_hook(RetryDefaultApiHook)
3721            .multi_turn(3)
3722            .max_invalid_tool_call_retries(0)
3723            .await;
3724        let mut error = None;
3725
            // Drain the stream and stop at the first error item.
3726        while let Some(item) = stream.next().await {
3727            if let Err(err) = item {
3728                error = Some(err);
3729                break;
3730            }
3731        }
3732
3733        let error = error.expect("retry budget exhaustion should fail");
            // Any other error shape fails the test with a descriptive panic.
3734        match error {
3735            StreamingError::Prompt(err) => match *err {
3736                PromptError::UnknownToolCall {
3737                    tool_name,
3738                    chat_history,
3739                    ..
3740                } => {
3741                    assert_eq!(tool_name, "default_api");
3742                    assert!(history_contains_tool_call(&chat_history, "default_api"));
3743                }
3744                other => panic!("expected UnknownToolCall, got {other:?}"),
3745            },
3746            other => panic!("expected prompt streaming error, got {other:?}"),
3747        }
3748        assert_eq!(recorded.request_count(), 1);
3749    }
3750
3751    #[tokio::test]
        // Delta-based variant of the retry-budget-exhaustion case: the `default_api` name
        // arrives as a name delta after text and a complete `add` call, and the resulting
        // `UnknownToolCall` history must contain all three pieces of this turn.
3752    async fn streaming_name_delta_retry_budget_exhaustion_history_includes_same_turn_context() {
3753        let model = MockCompletionModel::from_stream_turns([
3754            vec![
3755                MockStreamEvent::text("checking "),
3756                MockStreamEvent::tool_call(
3757                    "tool_call_0",
3758                    "add",
3759                    serde_json::json!({"x": 1, "y": 2}),
3760                )
3761                .with_call_id("call_0"),
3762                MockStreamEvent::tool_call_arguments_delta(
3763                    "tool_call_1",
3764                    "internal_1",
3765                    r#"{"x":2,"y":3}"#,
3766                ),
3767                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3768                MockStreamEvent::final_response_with_total_tokens(4),
3769            ],
                // Scripted but expected to stay unused (see the request-count assertion).
3770            vec![
3771                MockStreamEvent::text("should not be requested"),
3772                MockStreamEvent::final_response_with_total_tokens(6),
3773            ],
3774        ]);
3775        let recorded = model.clone();
3776        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3777
3778        let mut stream = agent
3779            .stream_prompt("use the tool")
3780            .with_hook(RetryDefaultApiHook)
3781            .multi_turn(3)
3782            .max_invalid_tool_call_retries(0)
3783            .await;
3784        let mut error = None;
3785
            // Drain the stream and stop at the first error item.
3786        while let Some(item) = stream.next().await {
3787            if let Err(err) = item {
3788                error = Some(err);
3789                break;
3790            }
3791        }
3792
3793        let error = error.expect("retry budget exhaustion should fail");
3794        match error {
3795            StreamingError::Prompt(err) => match *err {
3796                PromptError::UnknownToolCall {
3797                    tool_name,
3798                    chat_history,
3799                    ..
3800                } => {
3801                    assert_eq!(tool_name, "default_api");
                        // Same-turn context: the text, the valid call and the invalid call.
3802                    assert!(history_contains_text(&chat_history, "checking "));
3803                    assert!(history_contains_tool_call(&chat_history, "add"));
3804                    assert!(history_contains_tool_call(&chat_history, "default_api"));
3805                }
3806                other => panic!("expected UnknownToolCall, got {other:?}"),
3807            },
3808            other => panic!("expected prompt streaming error, got {other:?}"),
3809        }
3810        assert_eq!(recorded.request_count(), 1);
3811    }
3812
3813    #[tokio::test]
        // A turn that streams text and then a complete `default_api` call must yield the text,
        // then an `UnknownToolCall` error, without emitting a completion call, final response,
        // tool call or tool result, and without running the `add` tool.
3814    async fn completed_unknown_tool_call_after_text_fails_before_finish_hook_or_later_emit() {
            // Counter handed to `CountingAddTool`; asserted to stay at 0 below.
3815        let add_calls = Arc::new(AtomicU32::new(0));
3816        let model = MockCompletionModel::from_stream_turns([
3817            vec![
3818                MockStreamEvent::text("thinking "),
3819                MockStreamEvent::tool_call(
3820                    "tool_call_1",
3821                    "default_api",
3822                    serde_json::json!({"x": 1, "y": 2}),
3823                ),
3824                MockStreamEvent::final_response_with_total_tokens(4),
3825            ],
                // Scripted but expected to stay unused (see the request-count assertion).
3826            vec![
3827                MockStreamEvent::text("should not be requested"),
3828                MockStreamEvent::final_response_with_total_tokens(6),
3829            ],
3830        ]);
3831        let recorded = model.clone();
3832        let agent = AgentBuilder::new(model)
3833            .tool(CountingAddTool {
3834                calls: add_calls.clone(),
3835            })
3836            .build();
3837
            // NOTE(review): `PanicOnUnknownToolHook` is defined outside this view; judging by its
            // name it panics if consulted about an unknown tool - confirm.
3838        let mut stream = agent
3839            .stream_prompt("use the tool")
3840            .with_hook(PanicOnUnknownToolHook)
3841            .multi_turn(3)
3842            .await;
3843        let mut saw_text = false;
3844        let mut saw_completion_call = false;
3845        let mut saw_final_response = false;
3846        let mut saw_tool_call = false;
3847        let mut saw_tool_result = false;
3848        let mut error = None;
3849
            // Record which item kinds appear before the first error.
3850        while let Some(item) = stream.next().await {
3851            match item {
3852                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(_))) => {
3853                    saw_text = true;
3854                }
3855                Ok(MultiTurnStreamItem::CompletionCall(_)) => {
3856                    saw_completion_call = true;
3857                }
                    // Both the per-turn `Final` item and the overall `FinalResponse` count here.
3858                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Final(
3859                    _,
3860                )))
3861                | Ok(MultiTurnStreamItem::FinalResponse(_)) => {
3862                    saw_final_response = true;
3863                }
3864                Ok(MultiTurnStreamItem::StreamAssistantItem(
3865                    StreamedAssistantContent::ToolCall { .. },
3866                )) => {
3867                    saw_tool_call = true;
3868                }
3869                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3870                    ..
3871                })) => {
3872                    saw_tool_result = true;
3873                }
3874                Ok(_) => {}
3875                Err(err) => {
3876                    error = Some(err);
3877                    break;
3878                }
3879            }
3880        }
3881
            // Only the text item may have been emitted before the failure.
3882        assert!(saw_text);
3883        assert!(!saw_completion_call);
3884        assert!(!saw_final_response);
3885        assert!(!saw_tool_call);
3886        assert!(!saw_tool_result);
3887        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3888        let error = error.expect("completed unknown tool call should fail immediately");
3889        match error {
3890            StreamingError::Prompt(err) => match *err {
3891                PromptError::UnknownToolCall {
3892                    tool_name,
3893                    available_tools,
3894                    allowed_tools,
3895                    chat_history,
3896                } => {
3897                    assert_eq!(tool_name, "default_api");
3898                    assert_eq!(available_tools, vec!["add".to_string()]);
3899                    assert_eq!(allowed_tools, vec!["add".to_string()]);
3900                    assert!(history_contains_tool_call(&chat_history, "default_api"));
3901                }
3902                other => panic!("expected UnknownToolCall, got {other:?}"),
3903            },
3904            other => panic!("expected prompt streaming error, got {other:?}"),
3905        }
3906        assert_eq!(recorded.request_count(), 1);
3907    }
3908
3909    #[tokio::test]
        // A turn containing a valid `add` call followed by a `default_api` call must fail with
        // `UnknownToolCall` without emitting a completion call, tool call or tool result, and
        // without executing `add`.
3910    async fn mixed_streaming_tool_calls_fail_before_any_tool_execution() {
            // Counter handed to `CountingAddTool`; asserted to stay at 0 below.
3911        let add_calls = Arc::new(AtomicU32::new(0));
3912        let model = MockCompletionModel::from_stream_turns([
3913            vec![
3914                MockStreamEvent::tool_call(
3915                    "tool_call_1",
3916                    "add",
3917                    serde_json::json!({"x": 1, "y": 2}),
3918                )
3919                .with_call_id("call_1"),
3920                MockStreamEvent::tool_call(
3921                    "tool_call_2",
3922                    "default_api",
3923                    serde_json::json!({"x": 3, "y": 4}),
3924                ),
3925                MockStreamEvent::final_response_with_total_tokens(4),
3926            ],
                // Scripted but expected to stay unused (see the request-count assertion).
3927            vec![
3928                MockStreamEvent::text("should not be requested"),
3929                MockStreamEvent::final_response_with_total_tokens(6),
3930            ],
3931        ]);
3932        let recorded = model.clone();
3933        let agent = AgentBuilder::new(model)
3934            .tool(CountingAddTool {
3935                calls: add_calls.clone(),
3936            })
3937            .build();
3938
3939        let mut stream = agent
3940            .stream_prompt("use tools")
3941            .with_hook(PanicOnUnknownToolHook)
3942            .multi_turn(3)
3943            .await;
3944        let mut saw_completion_call = false;
3945        let mut saw_tool_call = false;
3946        let mut saw_tool_result = false;
3947        let mut error = None;
3948
            // Record which item kinds appear before the first error.
3949        while let Some(item) = stream.next().await {
3950            match item {
3951                Ok(MultiTurnStreamItem::CompletionCall(_)) => {
3952                    saw_completion_call = true;
3953                }
3954                Ok(MultiTurnStreamItem::StreamAssistantItem(
3955                    StreamedAssistantContent::ToolCall { .. },
3956                )) => {
3957                    saw_tool_call = true;
3958                }
3959                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3960                    ..
3961                })) => {
3962                    saw_tool_result = true;
3963                }
3964                Ok(_) => {}
3965                Err(err) => {
3966                    error = Some(err);
3967                    break;
3968                }
3969            }
3970        }
3971
            // Not even the valid `add` call may have been emitted or executed.
3972        assert!(!saw_completion_call);
3973        assert!(!saw_tool_call);
3974        assert!(!saw_tool_result);
3975        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3976        let error = error.expect("mixed unknown streamed tool call should fail");
3977        match error {
3978            StreamingError::Prompt(err) => match *err {
3979                PromptError::UnknownToolCall {
3980                    tool_name,
3981                    available_tools,
3982                    allowed_tools,
3983                    chat_history,
3984                } => {
3985                    assert_eq!(tool_name, "default_api");
3986                    assert_eq!(available_tools, vec!["add".to_string()]);
3987                    assert_eq!(allowed_tools, vec!["add".to_string()]);
3988                    assert!(history_contains_tool_call(&chat_history, "default_api"));
3989                }
3990                other => panic!("expected UnknownToolCall, got {other:?}"),
3991            },
3992            other => panic!("expected prompt streaming error, got {other:?}"),
3993        }
3994        assert_eq!(recorded.request_count(), 1);
3995    }
3996
3997    #[tokio::test]
        // Success path: a turn with `add` and `subtract` calls (both tools registered) must emit
        // both tool calls and both tool results in order, run each tool once, and finish with
        // the second turn's text after two completion requests.
3998    async fn multiple_valid_streaming_tool_calls_execute_after_batch_validation() {
            // Per-tool counters handed to the counting tools; each is asserted to reach 1.
3999        let add_calls = Arc::new(AtomicU32::new(0));
4000        let subtract_calls = Arc::new(AtomicU32::new(0));
4001        let model = MockCompletionModel::from_stream_turns([
4002            vec![
4003                MockStreamEvent::tool_call(
4004                    "tool_call_1",
4005                    "add",
4006                    serde_json::json!({"x": 1, "y": 2}),
4007                )
4008                .with_call_id("call_1"),
4009                MockStreamEvent::tool_call(
4010                    "tool_call_2",
4011                    "subtract",
4012                    serde_json::json!({"x": 8, "y": 3}),
4013                )
4014                .with_call_id("call_2"),
4015                MockStreamEvent::final_response_with_total_tokens(4),
4016            ],
                // Follow-up turn whose text becomes the expected final response.
4017            vec![
4018                MockStreamEvent::text("done"),
4019                MockStreamEvent::final_response_with_total_tokens(6),
4020            ],
4021        ]);
4022        let recorded = model.clone();
4023        let agent = AgentBuilder::new(model)
4024            .tool(CountingAddTool {
4025                calls: add_calls.clone(),
4026            })
4027            .tool(CountingSubtractTool {
4028                calls: subtract_calls.clone(),
4029            })
4030            .build();
4031
4032        let mut stream = agent.stream_prompt("use tools").multi_turn(3).await;
4033        let mut tool_call_names = Vec::new();
4034        let mut tool_result_ids = Vec::new();
4035        let mut final_response_text = None;
4036
            // Collect tool-call names and tool-result ids in emission order until the final
            // response; any error item fails the test.
4037        while let Some(item) = stream.next().await {
4038            match item {
4039                Ok(MultiTurnStreamItem::StreamAssistantItem(
4040                    StreamedAssistantContent::ToolCall { tool_call, .. },
4041                )) => {
4042                    tool_call_names.push(tool_call.function.name);
4043                }
4044                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4045                    tool_result,
4046                    ..
4047                })) => {
4048                    tool_result_ids.push(tool_result.id);
4049                }
4050                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4051                    final_response_text = Some(response.response().to_owned());
4052                    break;
4053                }
4054                Ok(_) => {}
4055                Err(err) => panic!("unexpected streaming error: {err:?}"),
4056            }
4057        }
4058
4059        assert_eq!(
4060            tool_call_names,
4061            vec!["add".to_string(), "subtract".to_string()]
4062        );
4063        assert_eq!(
4064            tool_result_ids,
4065            vec!["tool_call_1".to_string(), "tool_call_2".to_string()]
4066        );
4067        assert_eq!(add_calls.load(Ordering::SeqCst), 1);
4068        assert_eq!(subtract_calls.load(Ordering::SeqCst), 1);
4069        assert_eq!(final_response_text.as_deref(), Some("done"));
4070        assert_eq!(recorded.request_count(), 2);
4071    }
4072
4073    #[tokio::test]
        // With `ToolChoice::Specific` restricted to `add`, a streamed call to the registered
        // `subtract` tool must fail with `UnknownToolCall` before any tool call is emitted and
        // before a second request is made.
4074    async fn disallowed_specific_tool_call_fails_before_streaming_second_request() {
4075        let model = MockCompletionModel::from_stream_turns([
4076            vec![
4077                MockStreamEvent::tool_call(
4078                    "tool_call_1",
4079                    "subtract",
4080                    serde_json::json!({"x": 3, "y": 1}),
4081                ),
4082                MockStreamEvent::final_response_with_total_tokens(4),
4083            ],
                // Scripted but expected to stay unused (see the request-count assertion).
4084            vec![
4085                MockStreamEvent::text("should not be requested"),
4086                MockStreamEvent::final_response_with_total_tokens(6),
4087            ],
4088        ]);
4089        let recorded = model.clone();
            // Both tools are registered, but only `add` is allowed by the tool choice.
4090        let agent = AgentBuilder::new(model)
4091            .tool(MockAddTool)
4092            .tool(MockSubtractTool)
4093            .tool_choice(ToolChoice::Specific {
4094                function_names: vec!["add".to_string()],
4095            })
4096            .build();
4097
4098        let mut stream = agent
4099            .stream_prompt("use the allowed tool")
4100            .with_hook(PanicOnUnknownToolHook)
4101            .multi_turn(3)
4102            .await;
4103        let mut saw_tool_call = false;
4104        let mut error = None;
4105
            // Note whether any tool call is emitted before the first error.
4106        while let Some(item) = stream.next().await {
4107            match item {
4108                Ok(MultiTurnStreamItem::StreamAssistantItem(
4109                    StreamedAssistantContent::ToolCall { .. },
4110                )) => {
4111                    saw_tool_call = true;
4112                }
4113                Ok(_) => {}
4114                Err(err) => {
4115                    error = Some(err);
4116                    break;
4117                }
4118            }
4119        }
4120
4121        assert!(!saw_tool_call);
4122        let error = error.expect("disallowed model-emitted tool should fail");
4123        match error {
4124            StreamingError::Prompt(err) => match *err {
4125                PromptError::UnknownToolCall {
4126                    tool_name,
4127                    available_tools,
4128                    allowed_tools,
4129                    chat_history,
4130                } => {
4131                    assert_eq!(tool_name, "subtract");
                        // `available_tools` lists both registered tools, `allowed_tools` only `add`.
4132                    assert_eq!(
4133                        available_tools,
4134                        vec!["add".to_string(), "subtract".to_string()]
4135                    );
4136                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4137                    assert!(history_contains_tool_call(&chat_history, "subtract"));
4138                }
4139                other => panic!("expected UnknownToolCall, got {other:?}"),
4140            },
4141            other => panic!("expected prompt streaming error, got {other:?}"),
4142        }
4143        assert_eq!(recorded.request_count(), 1);
4144    }
4145
4146    #[tokio::test]
        // With `ToolChoice::Specific` restricted to `add`, a turn containing an allowed `add`
        // call and a disallowed `subtract` call must fail with `UnknownToolCall` without
        // emitting any tool call or tool result and without executing `add`.
4147    async fn mixed_specific_tool_calls_fail_before_any_tool_execution() {
            // Counter handed to `CountingAddTool`; asserted to stay at 0 below.
4148        let add_calls = Arc::new(AtomicU32::new(0));
4149        let model = MockCompletionModel::from_stream_turns([
4150            vec![
4151                MockStreamEvent::tool_call(
4152                    "tool_call_1",
4153                    "add",
4154                    serde_json::json!({"x": 1, "y": 2}),
4155                ),
4156                MockStreamEvent::tool_call(
4157                    "tool_call_2",
4158                    "subtract",
4159                    serde_json::json!({"x": 3, "y": 1}),
4160                ),
4161                MockStreamEvent::final_response_with_total_tokens(4),
4162            ],
                // Scripted but expected to stay unused (see the request-count assertion).
4163            vec![
4164                MockStreamEvent::text("should not be requested"),
4165                MockStreamEvent::final_response_with_total_tokens(6),
4166            ],
4167        ]);
4168        let recorded = model.clone();
            // Both tools are registered, but only `add` is allowed by the tool choice.
4169        let agent = AgentBuilder::new(model)
4170            .tool(CountingAddTool {
4171                calls: add_calls.clone(),
4172            })
4173            .tool(MockSubtractTool)
4174            .tool_choice(ToolChoice::Specific {
4175                function_names: vec!["add".to_string()],
4176            })
4177            .build();
4178
4179        let mut stream = agent
4180            .stream_prompt("use the allowed tool")
4181            .with_hook(PanicOnUnknownToolHook)
4182            .multi_turn(3)
4183            .await;
4184        let mut saw_tool_call = false;
4185        let mut saw_tool_result = false;
4186        let mut error = None;
4187
            // Record which item kinds appear before the first error.
4188        while let Some(item) = stream.next().await {
4189            match item {
4190                Ok(MultiTurnStreamItem::StreamAssistantItem(
4191                    StreamedAssistantContent::ToolCall { .. },
4192                )) => {
4193                    saw_tool_call = true;
4194                }
4195                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
4196                    ..
4197                })) => {
4198                    saw_tool_result = true;
4199                }
4200                Ok(_) => {}
4201                Err(err) => {
4202                    error = Some(err);
4203                    break;
4204                }
4205            }
4206        }
4207
            // Not even the allowed `add` call may have been emitted or executed.
4208        assert!(!saw_tool_call);
4209        assert!(!saw_tool_result);
4210        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
4211        let error = error.expect("mixed disallowed streamed tool call should fail");
4212        match error {
4213            StreamingError::Prompt(err) => match *err {
4214                PromptError::UnknownToolCall {
4215                    tool_name,
4216                    available_tools,
4217                    allowed_tools,
4218                    chat_history,
4219                } => {
4220                    assert_eq!(tool_name, "subtract");
4221                    assert_eq!(
4222                        available_tools,
4223                        vec!["add".to_string(), "subtract".to_string()]
4224                    );
4225                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4226                    assert!(history_contains_tool_call(&chat_history, "subtract"));
4227                }
4228                other => panic!("expected UnknownToolCall, got {other:?}"),
4229            },
4230            other => panic!("expected prompt streaming error, got {other:?}"),
4231        }
4232        assert_eq!(recorded.request_count(), 1);
4233    }
4234
4235    #[tokio::test]
4236    async fn tool_choice_none_rejects_streaming_tool_call() {
4237        let model = MockCompletionModel::from_stream_turns([
4238            vec![
4239                MockStreamEvent::tool_call(
4240                    "tool_call_1",
4241                    "add",
4242                    serde_json::json!({"x": 1, "y": 2}),
4243                ),
4244                MockStreamEvent::final_response_with_total_tokens(4),
4245            ],
4246            vec![
4247                MockStreamEvent::text("should not be requested"),
4248                MockStreamEvent::final_response_with_total_tokens(6),
4249            ],
4250        ]);
4251        let recorded = model.clone();
4252        let agent = AgentBuilder::new(model)
4253            .tool(MockAddTool)
4254            .tool_choice(ToolChoice::None)
4255            .build();
4256
4257        let mut stream = agent
4258            .stream_prompt("do not use tools")
4259            .with_hook(PanicOnUnknownToolHook)
4260            .multi_turn(3)
4261            .await;
4262        let mut saw_tool_call = false;
4263        let mut error = None;
4264
4265        while let Some(item) = stream.next().await {
4266            match item {
4267                Ok(MultiTurnStreamItem::StreamAssistantItem(
4268                    StreamedAssistantContent::ToolCall { .. },
4269                )) => {
4270                    saw_tool_call = true;
4271                }
4272                Ok(_) => {}
4273                Err(err) => {
4274                    error = Some(err);
4275                    break;
4276                }
4277            }
4278        }
4279
4280        assert!(!saw_tool_call);
4281        let error = error.expect("ToolChoice::None should reject returned tool calls");
4282        match error {
4283            StreamingError::Prompt(err) => match *err {
4284                PromptError::UnknownToolCall {
4285                    tool_name,
4286                    available_tools,
4287                    allowed_tools,
4288                    chat_history,
4289                } => {
4290                    assert_eq!(tool_name, "add");
4291                    assert_eq!(available_tools, vec!["add".to_string()]);
4292                    assert!(allowed_tools.is_empty());
4293                    assert!(history_contains_tool_call(&chat_history, "add"));
4294                }
4295                other => panic!("expected UnknownToolCall, got {other:?}"),
4296            },
4297            other => panic!("expected prompt streaming error, got {other:?}"),
4298        }
4299        assert_eq!(recorded.request_count(), 1);
4300    }
4301
4302    #[tokio::test]
4303    async fn tool_choice_none_rejects_streaming_tool_call_name_delta_before_hook_or_emit() {
4304        let model = MockCompletionModel::from_stream_turns([
4305            vec![
4306                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4307                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4308                MockStreamEvent::final_response_with_total_tokens(4),
4309            ],
4310            vec![
4311                MockStreamEvent::text("should not be requested"),
4312                MockStreamEvent::final_response_with_total_tokens(6),
4313            ],
4314        ]);
4315        let recorded = model.clone();
4316        let agent = AgentBuilder::new(model)
4317            .tool(MockAddTool)
4318            .tool_choice(ToolChoice::None)
4319            .build();
4320
4321        let mut stream = agent
4322            .stream_prompt("do not use tools")
4323            .with_hook(PanicOnUnknownToolHook)
4324            .multi_turn(3)
4325            .await;
4326        let mut saw_delta = false;
4327        let mut error = None;
4328
4329        while let Some(item) = stream.next().await {
4330            match item {
4331                Ok(MultiTurnStreamItem::StreamAssistantItem(
4332                    StreamedAssistantContent::ToolCallDelta { .. },
4333                )) => {
4334                    saw_delta = true;
4335                }
4336                Ok(_) => {}
4337                Err(err) => {
4338                    error = Some(err);
4339                    break;
4340                }
4341            }
4342        }
4343
4344        assert!(!saw_delta);
4345        let error = error.expect("ToolChoice::None should reject returned tool-call deltas");
4346        match error {
4347            StreamingError::Prompt(err) => match *err {
4348                PromptError::UnknownToolCall {
4349                    tool_name,
4350                    available_tools,
4351                    allowed_tools,
4352                    chat_history,
4353                } => {
4354                    assert_eq!(tool_name, "add");
4355                    assert_eq!(available_tools, vec!["add".to_string()]);
4356                    assert!(allowed_tools.is_empty());
4357                    assert!(history_contains_tool_call(&chat_history, "add"));
4358                }
4359                other => panic!("expected UnknownToolCall, got {other:?}"),
4360            },
4361            other => panic!("expected prompt streaming error, got {other:?}"),
4362        }
4363        assert_eq!(recorded.request_count(), 1);
4364    }
4365
4366    #[tokio::test]
4367    async fn unknown_tool_call_name_delta_fails_before_streaming_delta_hook_or_emit() {
4368        let model = MockCompletionModel::from_stream_turns([
4369            vec![
4370                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4371                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4372                MockStreamEvent::final_response_with_total_tokens(4),
4373            ],
4374            vec![
4375                MockStreamEvent::text("should not be requested"),
4376                MockStreamEvent::final_response_with_total_tokens(6),
4377            ],
4378        ]);
4379        let recorded = model.clone();
4380        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4381
4382        let mut stream = agent
4383            .stream_prompt("stream a bad tool call")
4384            .with_hook(PanicOnUnknownToolHook)
4385            .multi_turn(3)
4386            .await;
4387        let mut saw_delta = false;
4388        let mut error = None;
4389
4390        while let Some(item) = stream.next().await {
4391            match item {
4392                Ok(MultiTurnStreamItem::StreamAssistantItem(
4393                    StreamedAssistantContent::ToolCallDelta { .. },
4394                )) => {
4395                    saw_delta = true;
4396                }
4397                Ok(_) => {}
4398                Err(err) => {
4399                    error = Some(err);
4400                    break;
4401                }
4402            }
4403        }
4404
4405        assert!(!saw_delta);
4406        let error = error.expect("unknown tool-call name delta should fail");
4407        match error {
4408            StreamingError::Prompt(err) => match *err {
4409                PromptError::UnknownToolCall {
4410                    tool_name,
4411                    available_tools,
4412                    allowed_tools,
4413                    chat_history,
4414                } => {
4415                    assert_eq!(tool_name, "default_api");
4416                    assert_eq!(available_tools, vec!["add".to_string()]);
4417                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4418                    assert!(history_contains_tool_call(&chat_history, "default_api"));
4419                }
4420                other => panic!("expected UnknownToolCall, got {other:?}"),
4421            },
4422            other => panic!("expected prompt streaming error, got {other:?}"),
4423        }
4424        assert_eq!(recorded.request_count(), 1);
4425    }
4426
4427    #[tokio::test]
4428    async fn tool_call_args_delta_before_unknown_name_fails_before_hook_or_emit() {
4429        let model = MockCompletionModel::from_stream_turns([
4430            vec![
4431                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4432                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4433                MockStreamEvent::final_response_with_total_tokens(4),
4434            ],
4435            vec![
4436                MockStreamEvent::text("should not be requested"),
4437                MockStreamEvent::final_response_with_total_tokens(6),
4438            ],
4439        ]);
4440        let recorded = model.clone();
4441        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4442
4443        let mut stream = agent
4444            .stream_prompt("stream a bad tool call")
4445            .with_hook(PanicOnUnknownToolHook)
4446            .multi_turn(3)
4447            .await;
4448        let mut saw_delta = false;
4449        let mut error = None;
4450
4451        while let Some(item) = stream.next().await {
4452            match item {
4453                Ok(MultiTurnStreamItem::StreamAssistantItem(
4454                    StreamedAssistantContent::ToolCallDelta { .. },
4455                )) => {
4456                    saw_delta = true;
4457                }
4458                Ok(_) => {}
4459                Err(err) => {
4460                    error = Some(err);
4461                    break;
4462                }
4463            }
4464        }
4465
4466        assert!(!saw_delta);
4467        let error = error.expect("unknown tool-call name should reject buffered args");
4468        match error {
4469            StreamingError::Prompt(err) => match *err {
4470                PromptError::UnknownToolCall {
4471                    tool_name,
4472                    available_tools,
4473                    allowed_tools,
4474                    chat_history,
4475                } => {
4476                    assert_eq!(tool_name, "default_api");
4477                    assert_eq!(available_tools, vec!["add".to_string()]);
4478                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4479                    assert!(history_contains_tool_call(&chat_history, "default_api"));
4480                }
4481                other => panic!("expected UnknownToolCall, got {other:?}"),
4482            },
4483            other => panic!("expected prompt streaming error, got {other:?}"),
4484        }
4485        assert_eq!(recorded.request_count(), 1);
4486    }
4487
4488    #[tokio::test]
4489    async fn tool_call_args_delta_before_valid_name_buffers_then_emits_in_safe_order() {
4490        let model = MockCompletionModel::from_stream_turns([[
4491            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4492            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4493            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4494            MockStreamEvent::final_response_with_total_tokens(3),
4495        ]]);
4496        let hook = RecordingToolCallDeltaHook::default();
4497        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4498
4499        let mut stream = agent
4500            .stream_prompt("stream a tool call")
4501            .with_hook(hook.clone())
4502            .await;
4503        let mut stream_deltas = Vec::new();
4504
4505        while let Some(item) = stream.next().await {
4506            match item {
4507                Ok(MultiTurnStreamItem::StreamAssistantItem(
4508                    StreamedAssistantContent::ToolCallDelta {
4509                        id,
4510                        internal_call_id,
4511                        content,
4512                    },
4513                )) => {
4514                    stream_deltas.push((id, internal_call_id, content));
4515                }
4516                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4517                Ok(_) => {}
4518                Err(err) => panic!("unexpected streaming error: {err:?}"),
4519            }
4520        }
4521
4522        assert_eq!(
4523            hook.observed(),
4524            vec![
4525                (
4526                    "tool_1".to_string(),
4527                    "internal_1".to_string(),
4528                    Some("add".to_string()),
4529                    String::new()
4530                ),
4531                (
4532                    "tool_1".to_string(),
4533                    "internal_1".to_string(),
4534                    None,
4535                    "{\"x\":".to_string()
4536                ),
4537                (
4538                    "tool_1".to_string(),
4539                    "internal_1".to_string(),
4540                    None,
4541                    "1}".to_string()
4542                ),
4543            ]
4544        );
4545        assert_eq!(
4546            stream_deltas,
4547            vec![
4548                (
4549                    "tool_1".to_string(),
4550                    "internal_1".to_string(),
4551                    ToolCallDeltaContent::Name("add".to_string())
4552                ),
4553                (
4554                    "tool_1".to_string(),
4555                    "internal_1".to_string(),
4556                    ToolCallDeltaContent::Delta("{\"x\":".to_string())
4557                ),
4558                (
4559                    "tool_1".to_string(),
4560                    "internal_1".to_string(),
4561                    ToolCallDeltaContent::Delta("1}".to_string())
4562                ),
4563            ]
4564        );
4565    }
4566
4567    #[tokio::test]
4568    async fn tool_call_args_delta_without_name_errors_at_stream_end() {
4569        let model = MockCompletionModel::from_stream_turns([
4570            vec![
4571                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4572                MockStreamEvent::final_response_with_total_tokens(4),
4573            ],
4574            vec![
4575                MockStreamEvent::text("should not be requested"),
4576                MockStreamEvent::final_response_with_total_tokens(6),
4577            ],
4578        ]);
4579        let recorded = model.clone();
4580        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4581
4582        let mut stream = agent
4583            .stream_prompt("stream an incomplete tool call")
4584            .with_hook(PanicOnUnknownToolHook)
4585            .multi_turn(3)
4586            .await;
4587        let mut saw_delta = false;
4588        let mut saw_completion_call = false;
4589        let mut saw_final_response = false;
4590        let mut error = None;
4591
4592        while let Some(item) = stream.next().await {
4593            match item {
4594                Ok(MultiTurnStreamItem::StreamAssistantItem(
4595                    StreamedAssistantContent::ToolCallDelta { .. },
4596                )) => {
4597                    saw_delta = true;
4598                }
4599                Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4600                    saw_completion_call = true;
4601                }
4602                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4603                    saw_final_response = true;
4604                }
4605                Ok(_) => {}
4606                Err(err) => {
4607                    error = Some(err);
4608                    break;
4609                }
4610            }
4611        }
4612
4613        assert!(!saw_delta);
4614        assert!(!saw_completion_call);
4615        assert!(!saw_final_response);
4616        let error = error.expect("unterminated tool-call args delta should fail");
4617        match error {
4618            StreamingError::Completion(CompletionError::ResponseError(message)) => {
4619                assert!(
4620                    message.contains("streamed tool call arguments"),
4621                    "{message}"
4622                );
4623                assert!(message.contains("tool_1"), "{message}");
4624                assert!(message.contains("internal_1"), "{message}");
4625            }
4626            other => panic!("expected completion response error, got {other:?}"),
4627        }
4628        assert_eq!(recorded.request_count(), 1);
4629    }
4630
4631    #[tokio::test]
4632    async fn tool_choice_none_buffers_args_then_rejects_name_without_emit() {
4633        let model = MockCompletionModel::from_stream_turns([
4634            vec![
4635                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4636                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4637                MockStreamEvent::final_response_with_total_tokens(4),
4638            ],
4639            vec![
4640                MockStreamEvent::text("should not be requested"),
4641                MockStreamEvent::final_response_with_total_tokens(6),
4642            ],
4643        ]);
4644        let recorded = model.clone();
4645        let agent = AgentBuilder::new(model)
4646            .tool(MockAddTool)
4647            .tool_choice(ToolChoice::None)
4648            .build();
4649
4650        let mut stream = agent
4651            .stream_prompt("do not use tools")
4652            .with_hook(PanicOnUnknownToolHook)
4653            .multi_turn(3)
4654            .await;
4655        let mut saw_delta = false;
4656        let mut error = None;
4657
4658        while let Some(item) = stream.next().await {
4659            match item {
4660                Ok(MultiTurnStreamItem::StreamAssistantItem(
4661                    StreamedAssistantContent::ToolCallDelta { .. },
4662                )) => {
4663                    saw_delta = true;
4664                }
4665                Ok(_) => {}
4666                Err(err) => {
4667                    error = Some(err);
4668                    break;
4669                }
4670            }
4671        }
4672
4673        assert!(!saw_delta);
4674        let error = error.expect("ToolChoice::None should reject buffered tool-call deltas");
4675        match error {
4676            StreamingError::Prompt(err) => match *err {
4677                PromptError::UnknownToolCall {
4678                    tool_name,
4679                    available_tools,
4680                    allowed_tools,
4681                    chat_history,
4682                } => {
4683                    assert_eq!(tool_name, "add");
4684                    assert_eq!(available_tools, vec!["add".to_string()]);
4685                    assert!(allowed_tools.is_empty());
4686                    assert!(history_contains_tool_call(&chat_history, "add"));
4687                }
4688                other => panic!("expected UnknownToolCall, got {other:?}"),
4689            },
4690            other => panic!("expected prompt streaming error, got {other:?}"),
4691        }
4692        assert_eq!(recorded.request_count(), 1);
4693    }
4694
4695    #[tokio::test]
4696    async fn stream_prompt_emits_tool_call_deltas_without_hook() {
4697        let model = MockCompletionModel::from_stream_turns([[
4698            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4699            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4700            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4701            MockStreamEvent::final_response_with_total_tokens(3),
4702        ]]);
4703        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4704
4705        let mut stream = agent.stream_prompt("stream a tool call").await;
4706        let mut deltas = Vec::new();
4707
4708        while let Some(item) = stream.next().await {
4709            match item {
4710                Ok(MultiTurnStreamItem::StreamAssistantItem(
4711                    StreamedAssistantContent::ToolCallDelta {
4712                        id,
4713                        internal_call_id,
4714                        content,
4715                    },
4716                )) => {
4717                    deltas.push((id, internal_call_id, content));
4718                }
4719                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4720                Ok(_) => {}
4721                Err(err) => panic!("unexpected streaming error: {err:?}"),
4722            }
4723        }
4724
4725        assert_eq!(
4726            deltas,
4727            vec![
4728                (
4729                    "tool_1".to_string(),
4730                    "internal_1".to_string(),
4731                    ToolCallDeltaContent::Name("add".to_string())
4732                ),
4733                (
4734                    "tool_1".to_string(),
4735                    "internal_1".to_string(),
4736                    ToolCallDeltaContent::Delta("{\"x\":".to_string())
4737                ),
4738                (
4739                    "tool_1".to_string(),
4740                    "internal_1".to_string(),
4741                    ToolCallDeltaContent::Delta("1}".to_string())
4742                ),
4743            ]
4744        );
4745    }
4746
4747    #[tokio::test]
4748    async fn stream_prompt_emits_tool_call_deltas_after_hook_continue() {
4749        let model = MockCompletionModel::from_stream_turns([[
4750            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4751            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4752            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4753            MockStreamEvent::final_response_with_total_tokens(3),
4754        ]]);
4755        let hook = RecordingToolCallDeltaHook::default();
4756        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4757
4758        let mut stream = agent
4759            .stream_prompt("stream a tool call")
4760            .with_hook(hook.clone())
4761            .await;
4762        let mut stream_deltas = Vec::new();
4763
4764        while let Some(item) = stream.next().await {
4765            match item {
4766                Ok(MultiTurnStreamItem::StreamAssistantItem(
4767                    StreamedAssistantContent::ToolCallDelta {
4768                        id,
4769                        internal_call_id,
4770                        content,
4771                    },
4772                )) => {
4773                    stream_deltas.push((id, internal_call_id, content));
4774                }
4775                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4776                Ok(_) => {}
4777                Err(err) => panic!("unexpected streaming error: {err:?}"),
4778            }
4779        }
4780
4781        assert_eq!(
4782            hook.observed(),
4783            vec![
4784                (
4785                    "tool_1".to_string(),
4786                    "internal_1".to_string(),
4787                    Some("add".to_string()),
4788                    String::new()
4789                ),
4790                (
4791                    "tool_1".to_string(),
4792                    "internal_1".to_string(),
4793                    None,
4794                    "{\"x\":".to_string()
4795                ),
4796                (
4797                    "tool_1".to_string(),
4798                    "internal_1".to_string(),
4799                    None,
4800                    "1}".to_string()
4801                ),
4802            ]
4803        );
4804        assert_eq!(
4805            stream_deltas,
4806            vec![
4807                (
4808                    "tool_1".to_string(),
4809                    "internal_1".to_string(),
4810                    ToolCallDeltaContent::Name("add".to_string())
4811                ),
4812                (
4813                    "tool_1".to_string(),
4814                    "internal_1".to_string(),
4815                    ToolCallDeltaContent::Delta("{\"x\":".to_string())
4816                ),
4817                (
4818                    "tool_1".to_string(),
4819                    "internal_1".to_string(),
4820                    ToolCallDeltaContent::Delta("1}".to_string())
4821                ),
4822            ]
4823        );
4824    }
4825
4826    #[tokio::test]
4827    async fn stream_prompt_tool_call_deltas_hook_termination_prevents_delta_emit() {
4828        let model = MockCompletionModel::from_stream_turns([[
4829            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4830            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4831            MockStreamEvent::final_response_with_total_tokens(3),
4832        ]]);
4833        let hook = TerminatingToolCallDeltaHook::default();
4834        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4835
4836        let mut stream = agent
4837            .stream_prompt("stream a tool call")
4838            .with_hook(hook.clone())
4839            .await;
4840        let mut saw_delta = false;
4841        let mut saw_final_response = false;
4842        let mut error_message = None;
4843
4844        while let Some(item) = stream.next().await {
4845            match item {
4846                Ok(MultiTurnStreamItem::StreamAssistantItem(
4847                    StreamedAssistantContent::ToolCallDelta { .. },
4848                )) => {
4849                    saw_delta = true;
4850                }
4851                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4852                    saw_final_response = true;
4853                }
4854                Ok(_) => {}
4855                Err(err) => {
4856                    error_message = Some(err.to_string());
4857                    break;
4858                }
4859            }
4860        }
4861
4862        assert_eq!(
4863            hook.observed(),
4864            vec![(
4865                "tool_1".to_string(),
4866                "internal_1".to_string(),
4867                Some("add".to_string()),
4868                String::new()
4869            )]
4870        );
4871        assert!(!saw_delta);
4872        assert!(!saw_final_response);
4873        assert!(
4874            error_message
4875                .as_deref()
4876                .is_some_and(|message| message.contains("PromptCancelled: stop on tool call delta")),
4877            "expected hook termination error, got {error_message:?}"
4878        );
4879    }
4880
4881    #[tokio::test]
4882    async fn stream_prompt_exposes_completion_calls() {
4883        let first_call_usage = usage(10, 2);
4884        let second_call_usage = usage(25, 5);
4885        let model = MockCompletionModel::from_stream_turns([
4886            vec![
4887                MockStreamEvent::tool_call(
4888                    "tool_call_1",
4889                    "add",
4890                    serde_json::json!({"x": 1, "y": 2}),
4891                )
4892                .with_call_id("call_1"),
4893                MockStreamEvent::final_response(first_call_usage),
4894            ],
4895            vec![
4896                MockStreamEvent::text("done"),
4897                MockStreamEvent::final_response(second_call_usage),
4898            ],
4899        ]);
4900        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4901        let empty_history: &[Message] = &[];
4902
4903        let mut stream = agent
4904            .stream_prompt("do tool work")
4905            .with_history(empty_history)
4906            .multi_turn(3)
4907            .await;
4908        let mut completion_calls_events = Vec::new();
4909        let mut final_response = None;
4910
4911        while let Some(item) = stream.next().await {
4912            match item {
4913                Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
4914                    completion_calls_events.push(call_usage);
4915                }
4916                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4917                    final_response = Some(response);
4918                    break;
4919                }
4920                Ok(_) => {}
4921                Err(err) => panic!("unexpected streaming error: {err:?}"),
4922            }
4923        }
4924
4925        assert_eq!(
4926            completion_calls_events,
4927            vec![
4928                CompletionCall::new(0, Some(first_call_usage)),
4929                CompletionCall::new(1, Some(second_call_usage))
4930            ]
4931        );
4932
4933        let final_response = final_response.expect("expected final response");
4934        assert_eq!(
4935            final_response.usage(),
4936            Usage {
4937                input_tokens: 35,
4938                output_tokens: 7,
4939                total_tokens: 42,
4940                cached_input_tokens: 0,
4941                cache_creation_input_tokens: 0,
4942                tool_use_prompt_tokens: 0,
4943                reasoning_tokens: 0,
4944            }
4945        );
4946        assert_eq!(
4947            final_response.completion_calls(),
4948            &[
4949                CompletionCall::new(0, Some(first_call_usage)),
4950                CompletionCall::new(1, Some(second_call_usage))
4951            ]
4952        );
4953    }
4954
4955    #[tokio::test]
4956    async fn stream_prompt_emits_completion_call_before_finish_hook_termination() {
4957        let call_usage = usage(10, 2);
4958        let model = MockCompletionModel::from_stream_turns([[
4959            MockStreamEvent::text("done"),
4960            MockStreamEvent::final_response(call_usage),
4961        ]]);
4962        let agent = AgentBuilder::new(model).build();
4963
4964        let mut stream = agent
4965            .stream_prompt("say done")
4966            .with_hook(TerminateOnStreamFinish)
4967            .await;
4968        let mut completion_calls = Vec::new();
4969        let mut saw_error = false;
4970
4971        while let Some(item) = stream.next().await {
4972            match item {
4973                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
4974                    completion_calls.push(completion_call);
4975                }
4976                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4977                    panic!("unexpected final response after hook termination: {response:?}");
4978                }
4979                Ok(_) => {}
4980                Err(_) => {
4981                    saw_error = true;
4982                    break;
4983                }
4984            }
4985        }
4986
4987        assert_eq!(
4988            completion_calls,
4989            vec![CompletionCall::new(0, Some(call_usage))]
4990        );
4991        assert!(saw_error);
4992    }
4993
4994    #[tokio::test]
4995    async fn stream_prompt_completion_calls_records_unreported_usage() {
4996        let second_call_usage = usage(25, 5);
4997        let model = MockCompletionModel::from_stream_turns([
4998            vec![
4999                MockStreamEvent::tool_call(
5000                    "tool_call_1",
5001                    "add",
5002                    serde_json::json!({"x": 1, "y": 2}),
5003                )
5004                .with_call_id("call_1"),
5005            ],
5006            vec![
5007                MockStreamEvent::text("done"),
5008                MockStreamEvent::final_response(second_call_usage),
5009            ],
5010        ]);
5011        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5012        let empty_history: &[Message] = &[];
5013
5014        let mut stream = agent
5015            .stream_prompt("do tool work")
5016            .with_history(empty_history)
5017            .multi_turn(3)
5018            .await;
5019        let mut completion_calls_events = Vec::new();
5020        let mut final_response = None;
5021
5022        while let Some(item) = stream.next().await {
5023            match item {
5024                Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
5025                    completion_calls_events.push(call_usage);
5026                }
5027                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
5028                    final_response = Some(response);
5029                    break;
5030                }
5031                Ok(_) => {}
5032                Err(err) => panic!("unexpected streaming error: {err:?}"),
5033            }
5034        }
5035
5036        let expected_usage = vec![
5037            CompletionCall::new(0, None),
5038            CompletionCall::new(1, Some(second_call_usage)),
5039        ];
5040        assert_eq!(completion_calls_events, expected_usage);
5041
5042        let final_response = final_response.expect("expected final response");
5043        assert_eq!(final_response.completion_calls(), expected_usage.as_slice());
5044    }
5045
5046    #[tokio::test]
5047    async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
5048        let agent = AgentBuilder::new(streaming_text_then_final_model()).build();
5049
5050        let mut stream = agent.stream_prompt("say hello").await;
5051        let mut streamed_text = String::new();
5052        let mut final_response_text = None;
5053
5054        while let Some(item) = stream.next().await {
5055            match item {
5056                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5057                    text,
5058                ))) => streamed_text.push_str(&text.text),
5059                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5060                    final_response_text = Some(res.response().to_owned());
5061                    break;
5062                }
5063                Ok(_) => {}
5064                Err(err) => panic!("unexpected streaming error: {err:?}"),
5065            }
5066        }
5067
5068        assert_eq!(streamed_text, "hello world");
5069        assert_eq!(final_response_text.as_deref(), Some("hello world"));
5070    }
5071
5072    #[tokio::test]
5073    async fn final_response_preserves_structured_text_metadata() {
5074        let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
5075
5076        let mut stream = agent.stream_prompt("answer with citations").await;
5077        let mut final_response = None;
5078
5079        while let Some(item) = stream.next().await {
5080            match item {
5081                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5082                    final_response = Some(res);
5083                    break;
5084                }
5085                Ok(_) => {}
5086                Err(err) => panic!("unexpected streaming error: {err:?}"),
5087            }
5088        }
5089
5090        let final_response = final_response.expect("expected final response");
5091        assert_eq!(final_response.response(), "cited answer");
5092        let metadata = text_metadata(final_response.content())
5093            .expect("expected text metadata in final content");
5094        assert_eq!(
5095            metadata["citations"][0]["encrypted_index"],
5096            "encrypted-reference"
5097        );
5098    }
5099
5100    #[tokio::test]
5101    async fn final_response_history_preserves_structured_text_metadata() {
5102        let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
5103
5104        let empty_history: &[Message] = &[];
5105        let mut stream = agent
5106            .stream_prompt("answer with citations")
5107            .with_history(empty_history)
5108            .await;
5109        let mut final_response = None;
5110
5111        while let Some(item) = stream.next().await {
5112            match item {
5113                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5114                    final_response = Some(res);
5115                    break;
5116                }
5117                Ok(_) => {}
5118                Err(err) => panic!("unexpected streaming error: {err:?}"),
5119            }
5120        }
5121
5122        let final_response = final_response.expect("expected final response");
5123        let history = final_response
5124            .history()
5125            .expect("with_history should include final history");
5126        let assistant_content = history
5127            .iter()
5128            .find_map(|message| match message {
5129                Message::Assistant { content, .. } => Some(content),
5130                _ => None,
5131            })
5132            .expect("expected assistant message in history");
5133        let metadata =
5134            text_metadata(assistant_content).expect("expected text metadata in assistant history");
5135        assert_eq!(
5136            metadata["citations"][0]["encrypted_index"],
5137            "encrypted-reference"
5138        );
5139    }
5140
5141    #[tokio::test]
5142    async fn tool_follow_up_history_preserves_structured_text_metadata() {
5143        let model = streaming_cited_text_then_tool_model();
5144        let recorded = model.clone();
5145        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
5146        let empty_history: &[Message] = &[];
5147
5148        let mut stream = agent
5149            .stream_prompt("use a tool with citations")
5150            .with_history(empty_history)
5151            .multi_turn(3)
5152            .await;
5153
5154        while let Some(item) = stream.next().await {
5155            match item {
5156                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
5157                Ok(_) => {}
5158                Err(err) => panic!("unexpected streaming error: {err:?}"),
5159            }
5160        }
5161
5162        let requests = recorded.requests();
5163        assert_eq!(requests.len(), 2);
5164        let follow_up_history = requests[1].chat_history.iter().collect::<Vec<_>>();
5165        let assistant_content = follow_up_history
5166            .iter()
5167            .find_map(|message| match message {
5168                Message::Assistant { content, .. } => Some(content),
5169                _ => None,
5170            })
5171            .expect("expected assistant message in follow-up history");
5172        let metadata = text_metadata(assistant_content)
5173            .expect("expected citation metadata in follow-up assistant history");
5174        assert_eq!(
5175            metadata["citations"][0]["encrypted_index"],
5176            "encrypted-reference"
5177        );
5178    }
5179
5180    #[tokio::test]
5181    async fn final_response_can_remain_empty_for_truly_textless_turns() {
5182        let agent = AgentBuilder::new(streaming_final_only_model()).build();
5183
5184        let mut stream = agent.stream_prompt("say nothing").await;
5185        let mut streamed_text = String::new();
5186        let mut final_response_text = None;
5187
5188        while let Some(item) = stream.next().await {
5189            match item {
5190                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5191                    text,
5192                ))) => streamed_text.push_str(&text.text),
5193                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5194                    final_response_text = Some(res.response().to_owned());
5195                    break;
5196                }
5197                Ok(_) => {}
5198                Err(err) => panic!("unexpected streaming error: {err:?}"),
5199            }
5200        }
5201
5202        assert!(streamed_text.is_empty());
5203        assert_eq!(final_response_text.as_deref(), Some(""));
5204    }
5205
5206    /// Background task that logs periodically to detect span leakage.
5207    /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
5208    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
5209        let mut interval = tokio::time::interval(Duration::from_millis(50));
5210        let mut count = 0u32;
5211
5212        while !stop.load(Ordering::Relaxed) {
5213            interval.tick().await;
5214            count += 1;
5215
5216            tracing::event!(
5217                target: "background_logger",
5218                tracing::Level::INFO,
5219                count = count,
5220                "Background tick"
5221            );
5222
5223            // Check if we're inside an unexpected span
5224            let current = tracing::Span::current();
5225            if !current.is_disabled() && !current.is_none() {
5226                leak_count.fetch_add(1, Ordering::Relaxed);
5227            }
5228        }
5229
5230        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
5231    }
5232
5233    /// Test that span context doesn't leak to concurrent tasks during streaming.
5234    ///
5235    /// This test verifies that using `.instrument()` instead of `span.enter()` in
5236    /// async_stream prevents thread-local span context from leaking to other tasks.
5237    ///
5238    /// Uses single-threaded runtime to force all tasks onto the same thread,
5239    /// making the span leak deterministic (it only occurs when tasks share a thread).
5240    #[tokio::test(flavor = "current_thread")]
5241    #[ignore = "This requires an API key"]
5242    async fn test_span_context_isolation() -> anyhow::Result<()> {
5243        let stop = Arc::new(AtomicBool::new(false));
5244        let leak_count = Arc::new(AtomicU32::new(0));
5245
5246        // Start background logger
5247        let bg_stop = stop.clone();
5248        let bg_leak = leak_count.clone();
5249        let bg_handle = tokio::spawn(async move {
5250            background_logger(bg_stop, bg_leak).await;
5251        });
5252
5253        // Small delay to let background logger start
5254        tokio::time::sleep(Duration::from_millis(100)).await;
5255
5256        // Make streaming request WITHOUT an outer span so rig creates its own invoke_agent span
5257        // (rig reuses current span if one exists, so we need to ensure there's no current span)
5258        let client = anthropic::Client::from_env()?;
5259        let agent = client
5260            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5261            .preamble("You are a helpful assistant.")
5262            .temperature(0.1)
5263            .max_tokens(100)
5264            .build();
5265
5266        let mut stream = agent
5267            .stream_prompt("Say 'hello world' and nothing else.")
5268            .await;
5269
5270        let mut full_content = String::new();
5271        while let Some(item) = stream.next().await {
5272            match item {
5273                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5274                    text,
5275                ))) => {
5276                    full_content.push_str(&text.text);
5277                }
5278                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
5279                    break;
5280                }
5281                Err(e) => {
5282                    tracing::warn!("Error: {:?}", e);
5283                    break;
5284                }
5285                _ => {}
5286            }
5287        }
5288
5289        tracing::info!("Got response: {:?}", full_content);
5290
5291        // Stop background logger
5292        stop.store(true, Ordering::Relaxed);
5293        bg_handle.await?;
5294
5295        let leaks = leak_count.load(Ordering::Relaxed);
5296        anyhow::ensure!(
5297            leaks == 0,
5298            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
5299             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
5300        );
5301
5302        Ok(())
5303    }
5304
5305    /// Test that FinalResponse contains the updated chat history when with_history is used.
5306    ///
5307    /// This verifies that:
5308    /// 1. FinalResponse.history() returns Some when with_history was called
5309    /// 2. The history contains both the user prompt and assistant response
5310    #[tokio::test]
5311    #[ignore = "This requires an API key"]
5312    async fn test_chat_history_in_final_response() -> anyhow::Result<()> {
5313        use crate::message::Message;
5314
5315        let client = anthropic::Client::from_env()?;
5316        let agent = client
5317            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5318            .preamble("You are a helpful assistant. Keep responses brief.")
5319            .temperature(0.1)
5320            .max_tokens(50)
5321            .build();
5322
5323        // Send streaming request with history
5324        let empty_history: &[Message] = &[];
5325        let mut stream = agent
5326            .stream_prompt("Say 'hello' and nothing else.")
5327            .with_history(empty_history)
5328            .await;
5329
5330        // Consume the stream and collect FinalResponse
5331        let mut response_text = String::new();
5332        let mut final_history = None;
5333        while let Some(item) = stream.next().await {
5334            match item {
5335                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5336                    text,
5337                ))) => {
5338                    response_text.push_str(&text.text);
5339                }
5340                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5341                    final_history = res.history().map(|h| h.to_vec());
5342                    break;
5343                }
5344                Err(e) => {
5345                    return Err(e.into());
5346                }
5347                _ => {}
5348            }
5349        }
5350
5351        let history = final_history
5352            .ok_or_else(|| anyhow::anyhow!("final response should include history"))?;
5353
5354        // Should contain at least the user message
5355        anyhow::ensure!(
5356            history.iter().any(|m| matches!(m, Message::User { .. })),
5357            "History should contain the user message"
5358        );
5359
5360        // Should contain the assistant response
5361        anyhow::ensure!(
5362            history
5363                .iter()
5364                .any(|m| matches!(m, Message::Assistant { .. })),
5365            "History should contain the assistant response"
5366        );
5367
5368        tracing::info!(
5369            "History after streaming: {} messages, response: {:?}",
5370            history.len(),
5371            response_text
5372        );
5373
5374        Ok(())
5375    }
5376
5377    #[tokio::test]
5378    async fn streaming_appends_to_memory_after_final_response() {
5379        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5380
5381        let memory = InMemoryConversationMemory::new();
5382        let agent = AgentBuilder::new(streaming_text_then_final_model())
5383            .memory(memory.clone())
5384            .build();
5385
5386        let mut stream = agent
5387            .stream_prompt("hi there")
5388            .conversation("stream-thread")
5389            .await;
5390
5391        let mut history_in_final = None;
5392        while let Some(item) = stream.next().await {
5393            match item {
5394                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5395                    history_in_final = res.history().map(|h| h.to_vec());
5396                    break;
5397                }
5398                Ok(_) => {}
5399                Err(err) => panic!("unexpected streaming error: {err:?}"),
5400            }
5401        }
5402
5403        let final_history = history_in_final
5404            .expect("FinalResponse.history should be populated when memory is configured");
5405        assert_eq!(
5406            final_history.len(),
5407            2,
5408            "user prompt + assistant response in final history: {final_history:?}"
5409        );
5410
5411        let stored = memory.load("stream-thread").await.unwrap();
5412        assert_eq!(stored.len(), 2, "memory should contain user + assistant");
5413    }
5414
5415    #[tokio::test]
5416    async fn streaming_reasoning_without_tools_does_not_duplicate_final_history() {
5417        let agent = AgentBuilder::new(MockCompletionModel::from_stream_turns([[
5418            MockStreamEvent::text("final answer"),
5419            MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
5420            MockStreamEvent::final_response_with_total_tokens(3),
5421        ]]))
5422        .build();
5423
5424        let mut stream = agent
5425            .stream_prompt("think before answering")
5426            .with_history(Vec::<Message>::new())
5427            .await;
5428
5429        let mut history_in_final = None;
5430        while let Some(item) = stream.next().await {
5431            match item {
5432                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5433                    history_in_final = res.history().map(|h| h.to_vec());
5434                    break;
5435                }
5436                Ok(_) => {}
5437                Err(err) => panic!("unexpected streaming error: {err:?}"),
5438            }
5439        }
5440
5441        let final_history = history_in_final
5442            .expect("FinalResponse.history should be populated when with_history is used");
5443        assert_eq!(
5444            final_history.len(),
5445            2,
5446            "user prompt + one assistant response in final history: {final_history:?}"
5447        );
5448
5449        assert!(matches!(
5450            final_history.first(),
5451            Some(Message::User { content })
5452                if matches!(
5453                    content.first(),
5454                    UserContent::Text(text) if text.text == "think before answering"
5455                )
5456        ));
5457
5458        let assistant_messages = final_history
5459            .iter()
5460            .filter_map(|message| match message {
5461                Message::Assistant { content, .. } => Some(content),
5462                _ => None,
5463            })
5464            .collect::<Vec<_>>();
5465        assert_eq!(
5466            assistant_messages.len(),
5467            1,
5468            "reasoning turn should produce exactly one assistant history message: {final_history:?}"
5469        );
5470        let assistant_content = assistant_messages
5471            .first()
5472            .expect("expected assistant history message");
5473        assert!(assistant_content.iter().any(|item| matches!(
5474            item,
5475            AssistantContent::Text(text) if text.text == "final answer"
5476        )));
5477        assert!(assistant_content.iter().any(|item| matches!(
5478            item,
5479            AssistantContent::Reasoning(reasoning)
5480                if reasoning.id.as_deref() == Some("rs_1")
5481                    && reasoning.content.iter().any(|content| matches!(
5482                        content,
5483                        ReasoningContent::Text { text, .. } if text == "reasoned step"
5484                    ))
5485        )));
5486    }
5487
5488    #[tokio::test]
5489    async fn streaming_with_history_overrides_memory() {
5490        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5491
5492        let memory = InMemoryConversationMemory::new();
5493        memory
5494            .append("t1", vec![Message::user("from-memory")])
5495            .await
5496            .unwrap();
5497
5498        let agent = AgentBuilder::new(streaming_text_then_final_model())
5499            .memory(memory.clone())
5500            .build();
5501
5502        let mut stream = agent
5503            .stream_prompt("hi")
5504            .conversation("t1")
5505            .with_history(vec![Message::user("from-caller")])
5506            .await;
5507
5508        while let Some(item) = stream.next().await {
5509            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5510                break;
5511            }
5512        }
5513
5514        let stored = memory.load("t1").await.unwrap();
5515        assert_eq!(
5516            stored.len(),
5517            1,
5518            "with_history bypasses memory; only the pre-seeded entry remains: {stored:?}"
5519        );
5520    }
5521
5522    #[tokio::test]
5523    async fn streaming_without_memory_disables_for_request() {
5524        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5525
5526        let memory = InMemoryConversationMemory::new();
5527        let agent = AgentBuilder::new(streaming_text_then_final_model())
5528            .memory(memory.clone())
5529            .conversation_id("default")
5530            .build();
5531
5532        let mut stream = agent.stream_prompt("hi").without_memory().await;
5533
5534        while let Some(item) = stream.next().await {
5535            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5536                break;
5537            }
5538        }
5539
5540        let stored = memory.load("default").await.unwrap();
5541        assert!(stored.is_empty(), "without_memory disables save");
5542    }
5543
5544    #[tokio::test]
5545    async fn streaming_load_error_yields_memory_error() {
5546        let agent = AgentBuilder::new(streaming_text_then_final_model())
5547            .memory(FailingMemory::default())
5548            .build();
5549
5550        let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5551
5552        let first = stream.next().await.expect("at least one item");
5553        match first {
5554            Err(err) => {
5555                let msg = format!("{err:?}");
5556                assert!(
5557                    msg.contains("Memory") || msg.contains("memory") || msg.contains("load boom"),
5558                    "expected memory error, got: {msg}"
5559                );
5560            }
5561            Ok(other) => panic!("expected memory error, got {other:?}"),
5562        }
5563    }
5564
5565    #[tokio::test]
5566    async fn streaming_with_filter_shapes_loaded_history() {
5567        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5568
5569        let memory = InMemoryConversationMemory::new()
5570            .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
5571        memory
5572            .append(
5573                "t1",
5574                vec![
5575                    Message::user("1"),
5576                    Message::assistant("2"),
5577                    Message::user("3"),
5578                    Message::assistant("4"),
5579                ],
5580            )
5581            .await
5582            .unwrap();
5583
5584        let model = MockCompletionModel::from_stream_turns([[
5585            MockStreamEvent::text("ok"),
5586            MockStreamEvent::final_response_with_total_tokens(1),
5587        ]]);
5588        let recorded = model.clone();
5589        let agent = AgentBuilder::new(model).memory(memory).build();
5590
5591        let mut stream = agent.stream_prompt("ping").conversation("t1").await;
5592        while let Some(item) = stream.next().await {
5593            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5594                break;
5595            }
5596        }
5597
5598        let received = recorded.requests()[0]
5599            .chat_history
5600            .iter()
5601            .cloned()
5602            .collect::<Vec<_>>();
5603        assert_eq!(
5604            received.len(),
5605            3,
5606            "window-truncated history (2) + current prompt: {received:?}"
5607        );
5608    }
5609
5610    #[tokio::test]
5611    async fn streaming_append_error_does_not_suppress_final_response() {
5612        let agent = AgentBuilder::new(streaming_text_then_final_model())
5613            .memory(AppendFailingMemory::default())
5614            .build();
5615
5616        let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5617
5618        let mut saw_final = false;
5619        while let Some(item) = stream.next().await {
5620            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5621                saw_final = true;
5622                break;
5623            }
5624        }
5625        assert!(
5626            saw_final,
5627            "FinalResponse must be yielded even when memory.append fails"
5628        );
5629    }
5630}