Skip to main content

rig_core/agent/prompt_request/
streaming.rs

1use crate::{
2    OneOrMany,
3    agent::completion::{DynamicContextStore, build_prepared_completion_request},
4    agent::prompt_request::{
5        HookAction, assistant_text_from_choice,
6        hooks::{InvalidToolCallHookAction, PromptHook},
7        is_empty_assistant_turn, tool_result_user_content,
8    },
9    agent::run::{
10        AgentRun, AgentRunStep,
11        streamed::{StreamedResolution, StreamedTurnAssembler, StreamedTurnEvent},
12    },
13    completion::{Document, GetTokenUsage},
14    json_utils,
15    memory::ConversationMemory,
16    message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
17    streaming::{StreamedAssistantContent, StreamedUserContent, ToolCallDeltaContent},
18    tool::server::ToolServerHandle,
19    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
20};
21use futures::{Stream, StreamExt};
22use serde::{Deserialize, Serialize};
23use std::{collections::VecDeque, pin::Pin, sync::Arc};
24use tracing::info_span;
25use tracing_futures::Instrument;
26
27use super::{CompletionCall, ToolCallHookAction};
28use crate::{
29    agent::Agent,
30    completion::{CompletionError, CompletionModel, PromptError},
31    message::{Message, Text},
32    tool::ToolSetError,
33};
34
35#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
36pub type StreamingResult<R> =
37    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;
38
39#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
40pub type StreamingResult<R> =
41    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
42
43#[derive(Deserialize, Serialize, Debug, Clone)]
44#[serde(tag = "type", rename_all = "camelCase")]
45#[non_exhaustive]
46pub enum MultiTurnStreamItem<R> {
47    /// A streamed assistant content item.
48    StreamAssistantItem(StreamedAssistantContent<R>),
49    /// A streamed user content item (mostly for tool results).
50    StreamUserItem(StreamedUserContent),
51    /// Details for one successfully completed completion request made by this agent stream.
52    ///
53    /// This is emitted when a provider call finishes. Usage is the provider's
54    /// final usage for that completion request when available; it is not
55    /// incremental per streamed token.
56    ///
57    /// ```rust,ignore
58    /// match item {
59    ///     MultiTurnStreamItem::CompletionCall(completion_call) => {
60    ///         // Zero-valued usage means the provider reported no metrics.
61    ///         if completion_call.usage.has_values() {
62    ///             let context_tokens = completion_call.usage.input_tokens;
63    ///         }
64    ///     }
65    ///     _ => {}
66    /// }
67    /// ```
68    CompletionCall(CompletionCall),
69    /// The final result from the stream.
70    FinalResponse(FinalResponse),
71}
72
73#[derive(Deserialize, Serialize, Debug, Clone)]
74#[serde(rename_all = "camelCase")]
75pub struct FinalResponse {
76    /// Structured assistant content for the final turn.
77    content: OneOrMany<AssistantContent>,
78    /// Concatenated assistant text for the final turn.
79    /// This is empty only when the turn completed without emitting any text.
80    response: String,
81    aggregated_usage: crate::completion::Usage,
82    /// Successfully completed completion requests made by this agent stream.
83    #[serde(default, skip_serializing_if = "Vec::is_empty")]
84    completion_calls: Vec<CompletionCall>,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    history: Option<Vec<Message>>,
87}
88
89impl FinalResponse {
90    pub fn empty() -> Self {
91        Self::new(
92            OneOrMany::one(AssistantContent::text("")),
93            crate::completion::Usage::new(),
94            None,
95        )
96    }
97
98    pub fn new(
99        content: OneOrMany<AssistantContent>,
100        aggregated_usage: crate::completion::Usage,
101        history: Option<Vec<Message>>,
102    ) -> Self {
103        let response = assistant_text_from_choice(&content);
104        Self {
105            content,
106            response,
107            aggregated_usage,
108            completion_calls: Vec::new(),
109            history,
110        }
111    }
112
113    /// Returns the concatenated assistant text for the final turn.
114    pub fn response(&self) -> &str {
115        &self.response
116    }
117
118    /// Returns the structured assistant content for the final turn.
119    pub fn content(&self) -> &OneOrMany<AssistantContent> {
120        &self.content
121    }
122
123    /// Returns the structured assistant content for the final turn.
124    pub fn assistant_content(&self) -> &OneOrMany<AssistantContent> {
125        &self.content
126    }
127
128    pub fn usage(&self) -> crate::completion::Usage {
129        self.aggregated_usage
130    }
131
132    /// Returns successfully completed completion requests made by this agent stream, with usage when available.
133    ///
134    /// Each entry represents one provider completion request. Usage is a
135    /// whole-request provider snapshot, not incremental usage per streamed
136    /// token. Streaming providers may omit usage for some calls; those calls
137    /// have an entry with zero-valued usage.
138    pub fn completion_calls(&self) -> &[CompletionCall] {
139        &self.completion_calls
140    }
141
142    /// Number of completion requests this agent run made.
143    pub fn requests(&self) -> usize {
144        self.completion_calls.len()
145    }
146
147    pub fn history(&self) -> Option<&[Message]> {
148        self.history.as_deref()
149    }
150}
151
152impl<R> MultiTurnStreamItem<R> {
153    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
154        Self::StreamAssistantItem(item)
155    }
156
157    pub fn final_response(
158        content: OneOrMany<AssistantContent>,
159        aggregated_usage: crate::completion::Usage,
160    ) -> Self {
161        Self::FinalResponse(FinalResponse::new(content, aggregated_usage, None))
162    }
163
164    pub fn final_response_with_history(
165        content: OneOrMany<AssistantContent>,
166        aggregated_usage: crate::completion::Usage,
167        history: Option<Vec<Message>>,
168    ) -> Self {
169        Self::FinalResponse(FinalResponse::new(content, aggregated_usage, history))
170    }
171
172    pub(crate) fn final_response_with_completion_calls(
173        content: OneOrMany<AssistantContent>,
174        aggregated_usage: crate::completion::Usage,
175        completion_calls: Vec<CompletionCall>,
176        history: Option<Vec<Message>>,
177    ) -> Self {
178        let mut response = FinalResponse::new(content, aggregated_usage, history);
179        response.completion_calls = completion_calls;
180        Self::FinalResponse(response)
181    }
182}
183
184/// Drain a provider stream abandoned by invalid tool-call recovery so the
185/// reported usage for the recovered completion call is not lost.
186async fn drain_stream_usage<R>(
187    stream: &mut crate::streaming::StreamingCompletionResponse<R>,
188) -> Result<crate::completion::Usage, StreamingError>
189where
190    R: Clone + Unpin + GetTokenUsage,
191{
192    while let Some(content) = stream.next().await {
193        match content {
194            Ok(StreamedAssistantContent::Final(final_resp)) => {
195                return Ok(final_resp.token_usage());
196            }
197            Ok(_) => {}
198            Err(err) => return Err(err.into()),
199        }
200    }
201
202    Ok(crate::completion::Usage::new())
203}
204
205fn record_usage_on_span(span: &tracing::Span, usage: crate::completion::Usage) {
206    span.record("gen_ai.usage.input_tokens", usage.input_tokens);
207    span.record("gen_ai.usage.output_tokens", usage.output_tokens);
208    span.record(
209        "gen_ai.usage.cache_read.input_tokens",
210        usage.cached_input_tokens,
211    );
212    span.record(
213        "gen_ai.usage.cache_creation.input_tokens",
214        usage.cache_creation_input_tokens,
215    );
216    span.record(
217        "gen_ai.usage.tool_use_prompt_tokens",
218        usage.tool_use_prompt_tokens,
219    );
220    span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);
221}
222
223#[derive(Debug, thiserror::Error)]
224pub enum StreamingError {
225    #[error("CompletionError: {0}")]
226    Completion(#[from] CompletionError),
227    #[error("PromptError: {0}")]
228    Prompt(#[from] Box<PromptError>),
229    #[error("ToolSetError: {0}")]
230    Tool(#[from] ToolSetError),
231}
232
233/// Surface [`crate::memory::ConversationMemory`] failures through the existing
234/// [`CompletionError::RequestError`] variant so adding memory support does not
235/// require a new top-level [`StreamingError`] arm.
236impl From<crate::memory::MemoryError> for StreamingError {
237    fn from(err: crate::memory::MemoryError) -> Self {
238        Self::Completion(CompletionError::RequestError(Box::new(err)))
239    }
240}
241
242const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
243
244/// A builder for creating prompt requests with customizable options.
245/// Uses generics to track which options have been set during the build process.
246///
247/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
248/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
249/// attempting to await (which will send the prompt request) can potentially return
250/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
251/// back to back.
252pub struct StreamingPromptRequest<M, P>
253where
254    M: CompletionModel,
255    P: PromptHook<M> + 'static,
256{
257    /// The prompt message to send to the model
258    prompt: Message,
259    /// Optional chat history provided by the caller.
260    chat_history: Option<Vec<Message>>,
261    /// Maximum Turns for multi-turn conversations (0 means no multi-turn)
262    max_turns: usize,
263
264    // Agent data (cloned from agent to allow hook type transitions):
265    /// The completion model
266    model: Arc<M>,
267    /// Agent name for logging
268    agent_name: Option<String>,
269    /// System prompt
270    preamble: Option<String>,
271    /// Static context documents
272    static_context: Vec<Document>,
273    /// Temperature setting
274    temperature: Option<f64>,
275    /// Max tokens setting
276    max_tokens: Option<u64>,
277    /// Additional model parameters
278    additional_params: Option<serde_json::Value>,
279    /// Tool server handle for tool execution
280    tool_server_handle: ToolServerHandle,
281    /// Dynamic context store
282    dynamic_context: DynamicContextStore,
283    /// Tool choice setting
284    tool_choice: Option<ToolChoice>,
285    /// Optional JSON Schema for structured output
286    output_schema: Option<schemars::Schema>,
287    /// Optional per-request hook for events
288    hook: Option<P>,
289    /// Maximum number of invalid tool-call retries for this request.
290    max_invalid_tool_call_retries: usize,
291    /// Optional conversation memory backend cloned from the agent.
292    memory: Option<Arc<dyn ConversationMemory>>,
293    /// Optional conversation id used for loading and saving memory.
294    conversation_id: Option<String>,
295}
296
297impl<M, P> StreamingPromptRequest<M, P>
298where
299    M: CompletionModel + 'static,
300    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
301    P: PromptHook<M>,
302{
303    /// Create a new StreamingPromptRequest with the given prompt and model.
304    /// Note: This creates a request without an agent hook. Use `from_agent` to include the agent's hook.
305    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
306        StreamingPromptRequest {
307            prompt: prompt.into(),
308            chat_history: None,
309            max_turns: agent.default_max_turns.unwrap_or_default(),
310            model: agent.model.clone(),
311            agent_name: agent.name.clone(),
312            preamble: agent.preamble.clone(),
313            static_context: agent.static_context.clone(),
314            temperature: agent.temperature,
315            max_tokens: agent.max_tokens,
316            additional_params: agent.additional_params.clone(),
317            tool_server_handle: agent.tool_server_handle.clone(),
318            dynamic_context: agent.dynamic_context.clone(),
319            tool_choice: agent.tool_choice.clone(),
320            output_schema: agent.output_schema.clone(),
321            hook: None,
322            max_invalid_tool_call_retries: 0,
323            memory: agent.memory.clone(),
324            conversation_id: agent.default_conversation_id.clone(),
325        }
326    }
327
328    /// Create a new StreamingPromptRequest from an agent, cloning the agent's data and default hook.
329    pub fn from_agent<P2>(
330        agent: &Agent<M, P2>,
331        prompt: impl Into<Message>,
332    ) -> StreamingPromptRequest<M, P2>
333    where
334        P2: PromptHook<M>,
335    {
336        StreamingPromptRequest {
337            prompt: prompt.into(),
338            chat_history: None,
339            max_turns: agent.default_max_turns.unwrap_or_default(),
340            model: agent.model.clone(),
341            agent_name: agent.name.clone(),
342            preamble: agent.preamble.clone(),
343            static_context: agent.static_context.clone(),
344            temperature: agent.temperature,
345            max_tokens: agent.max_tokens,
346            additional_params: agent.additional_params.clone(),
347            tool_server_handle: agent.tool_server_handle.clone(),
348            dynamic_context: agent.dynamic_context.clone(),
349            tool_choice: agent.tool_choice.clone(),
350            output_schema: agent.output_schema.clone(),
351            hook: agent.hook.clone(),
352            max_invalid_tool_call_retries: 0,
353            memory: agent.memory.clone(),
354            conversation_id: agent.default_conversation_id.clone(),
355        }
356    }
357
358    fn agent_name(&self) -> &str {
359        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
360    }
361
362    /// Set the maximum Turns for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
363    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
364    pub fn multi_turn(mut self, turns: usize) -> Self {
365        self.max_turns = turns;
366        self
367    }
368
369    /// Add chat history to the prompt request.
370    ///
371    /// When history is provided, the final [`FinalResponse`] will include the
372    /// updated chat history (original messages + new user prompt + assistant response).
373    /// ```ignore
374    /// let mut stream = agent
375    ///     .stream_prompt("Hello")
376    ///     .with_history(vec![])
377    ///     .await;
378    /// // ... consume stream ...
379    /// // Access updated history from FinalResponse::history()
380    /// ```
381    pub fn with_history<H, T>(mut self, history: H) -> Self
382    where
383        H: IntoIterator<Item = T>,
384        T: Into<Message>,
385    {
386        self.chat_history = Some(history.into_iter().map(Into::into).collect());
387        self
388    }
389
390    /// Attach a per-request hook for tool call events.
391    /// This overrides any default hook set on the agent.
392    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
393    where
394        P2: PromptHook<M>,
395    {
396        StreamingPromptRequest {
397            prompt: self.prompt,
398            chat_history: self.chat_history,
399            max_turns: self.max_turns,
400            model: self.model,
401            agent_name: self.agent_name,
402            preamble: self.preamble,
403            static_context: self.static_context,
404            temperature: self.temperature,
405            max_tokens: self.max_tokens,
406            additional_params: self.additional_params,
407            tool_server_handle: self.tool_server_handle,
408            dynamic_context: self.dynamic_context,
409            tool_choice: self.tool_choice,
410            output_schema: self.output_schema,
411            hook: Some(hook),
412            max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
413            memory: self.memory,
414            conversation_id: self.conversation_id,
415        }
416    }
417
418    /// Set the retry budget for [`crate::agent::prompt_request::hooks::InvalidToolCallHookAction::Retry`].
419    ///
420    /// Invalid tool-call retries also consume normal multi-turn depth.
421    pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
422        self.max_invalid_tool_call_retries = retries;
423        self
424    }
425
426    /// Set the conversation id used to load and persist memory for this request.
427    ///
428    /// Overrides any default conversation id set on the agent. If memory is not
429    /// configured on the agent, this has no effect.
430    pub fn conversation(mut self, id: impl Into<String>) -> Self {
431        self.conversation_id = Some(id.into());
432        self
433    }
434
435    /// Disable conversation memory for this request.
436    ///
437    /// History will neither be loaded from nor saved to the agent's memory backend.
438    pub fn without_memory(mut self) -> Self {
439        self.memory = None;
440        self.conversation_id = None;
441        self
442    }
443
444    async fn send(self) -> StreamingResult<M::StreamingResponse> {
445        let (agent_span, created_agent_span) = if tracing::Span::current().is_disabled() {
446            (
447                info_span!(
448                    "invoke_agent",
449                    gen_ai.operation.name = "invoke_agent",
450                    gen_ai.agent.name = self.agent_name(),
451                    gen_ai.system_instructions = self.preamble,
452                    gen_ai.prompt = tracing::field::Empty,
453                    gen_ai.completion = tracing::field::Empty,
454                    gen_ai.usage.input_tokens = tracing::field::Empty,
455                    gen_ai.usage.output_tokens = tracing::field::Empty,
456                    gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
457                    gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
458                    gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
459                    gen_ai.usage.reasoning_tokens = tracing::field::Empty,
460                ),
461                true,
462            )
463        } else {
464            (tracing::Span::current(), false)
465        };
466
467        let prompt = self.prompt;
468        if let Some(text) = prompt.rag_text() {
469            agent_span.record("gen_ai.prompt", text);
470        }
471
472        // Clone fields needed inside the stream
473        let model = self.model.clone();
474        let preamble = self.preamble.clone();
475        let static_context = self.static_context.clone();
476        let temperature = self.temperature;
477        let max_tokens = self.max_tokens;
478        let additional_params = self.additional_params.clone();
479        let tool_server_handle = self.tool_server_handle.clone();
480        let dynamic_context = self.dynamic_context.clone();
481        let tool_choice = self.tool_choice.clone();
482        let agent_name = self.agent_name.clone();
483        let output_schema = self.output_schema;
484        // When the caller passes explicit history, memory is fully bypassed for
485        // this request (no load AND no save). Otherwise, if a memory backend and
486        // conversation id are both configured, load prior history; if either is
487        // missing, behave as if no memory is configured.
488        let (chat_history, memory_handle) = match self.chat_history {
489            Some(history) => (Some(history), None),
490            None => match (self.memory, self.conversation_id) {
491                (Some(memory), Some(id)) => match memory.load(&id).await {
492                    Ok(loaded) => (Some(loaded), Some((memory, id))),
493                    Err(err) => {
494                        let stream = async_stream::stream! {
495                            yield Err(StreamingError::from(err));
496                        };
497                        return Box::pin(stream);
498                    }
499                },
500                _ => (None, None),
501            },
502        };
503        let has_history = chat_history.is_some();
504
505        let mut run = AgentRun::new(prompt.clone())
506            .max_turns(self.max_turns)
507            .max_invalid_tool_call_retries(self.max_invalid_tool_call_retries);
508        if let Some(history) = chat_history {
509            run = run.with_history(history);
510        }
511        if let Some(tool_choice) = tool_choice.clone() {
512            run = run.with_tool_choice(tool_choice);
513        }
514
515        // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
516        // span context leaking to other concurrent tasks. Using span.enter() inside
517        // async_stream::stream! holds the guard across yield points, which causes
518        // thread-local span context to leak when other tasks run on the same thread.
519        // See: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#in-asynchronous-code
520        // See also: https://github.com/rust-lang/rust-clippy/issues/8722
521        let stream = async_stream::stream! {
522            // The raw provider choice of the most recent turn; the final
523            // response surfaces it as-is, even when canonical reordering was
524            // recorded in history.
525            let mut last_final_choice: OneOrMany<AssistantContent> =
526                OneOrMany::one(AssistantContent::text(""));
527            let mut last_message_id: Option<String> = None;
528
529            'outer: loop {
530                let step = match run.next_step() {
531                    Ok(step) => step,
532                    Err(err) => {
533                        yield Err(Box::new(err).into());
534                        break 'outer;
535                    }
536                };
537
538                match step {
539                    AgentRunStep::CallModel { prompt: current_prompt, history, turn } => {
540                        if self.max_turns > 1 {
541                            tracing::info!(
542                                "Current conversation Turns: {}/{}",
543                                turn,
544                                self.max_turns
545                            );
546                        }
547
548                        if let Some(ref hook) = self.hook
549                            && let HookAction::Terminate { reason } =
550                                hook.on_completion_call(&current_prompt, &history).await
551                        {
552                            yield Err(StreamingError::Prompt(Box::new(run.cancel_error(reason))));
553                            break 'outer;
554                        }
555
556                        let chat_stream_span = info_span!(
557                            target: "rig::agent_chat",
558                            parent: tracing::Span::current(),
559                            "chat_streaming",
560                            gen_ai.operation.name = "chat",
561                            gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
562                            gen_ai.system_instructions = preamble,
563                            gen_ai.provider.name = tracing::field::Empty,
564                            gen_ai.request.model = tracing::field::Empty,
565                            gen_ai.response.id = tracing::field::Empty,
566                            gen_ai.response.model = tracing::field::Empty,
567                            gen_ai.usage.output_tokens = tracing::field::Empty,
568                            gen_ai.usage.input_tokens = tracing::field::Empty,
569                            gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
570                            gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
571                            gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
572                            gen_ai.usage.reasoning_tokens = tracing::field::Empty,
573                            gen_ai.input.messages = tracing::field::Empty,
574                            gen_ai.output.messages = tracing::field::Empty,
575                        );
576
577                        let prepared_request = build_prepared_completion_request(
578                            &model,
579                            current_prompt.clone(),
580                            &history,
581                            preamble.as_deref(),
582                            &static_context,
583                            temperature,
584                            max_tokens,
585                            additional_params.as_ref(),
586                            tool_choice.as_ref(),
587                            &tool_server_handle,
588                            &dynamic_context,
589                            output_schema.as_ref(),
590                        )
591                        .await?;
592
593                        let mut stream = prepared_request
594                            .builder
595                            .stream()
596                            .instrument(chat_stream_span.clone())
597                            .await?;
598
599                        let mut assembler = StreamedTurnAssembler::new(
600                            prepared_request.executable_tool_names.clone(),
601                            prepared_request.allowed_tool_names.clone(),
602                        );
603                        let mut completion_call_emitted = false;
604                        let mut turn_abandoned = false;
605
606                        'turn: while let Some(item) = stream.next().await {
607                            let item = match item {
608                                Ok(item) => item,
609                                Err(err) => {
610                                    yield Err(err.into());
611                                    break 'outer;
612                                }
613                            };
614                            let mut events: VecDeque<StreamedTurnEvent> =
615                                match assembler.ingest(&item) {
616                                    Ok(events) => events.into(),
617                                    Err(err) => {
618                                        yield Err(err.into());
619                                        break 'outer;
620                                    }
621                                };
622                            // At most one event per ingested item forwards the
623                            // item itself; moving it out of the slot avoids a
624                            // clone per streamed delta.
625                            let mut item_slot = Some(item);
626                            while let Some(event) = events.pop_front() {
627                                match event {
628                                    StreamedTurnEvent::EmitIngested => {
629                                        if let Some(StreamedAssistantContent::Text(text)) =
630                                            item_slot.as_ref()
631                                            && let Some(ref hook) = self.hook
632                                            && let HookAction::Terminate { reason } = hook
633                                                .on_text_delta(
634                                                    &text.text,
635                                                    assembler.aggregated_text(),
636                                                )
637                                                .await
638                                        {
639                                            yield Err(StreamingError::Prompt(Box::new(
640                                                run.cancel_error(reason),
641                                            )));
642                                            break 'outer;
643                                        }
644                                        if let Some(item) = item_slot.take() {
645                                            yield Ok(MultiTurnStreamItem::stream_item(item));
646                                        }
647                                    }
648                                    StreamedTurnEvent::EmitToolCallDelta {
649                                        id,
650                                        internal_call_id,
651                                        content,
652                                    } => {
653                                        if let Some(ref hook) = self.hook {
654                                            let (name, delta) = match &content {
655                                                ToolCallDeltaContent::Name(name) => {
656                                                    (Some(name.as_str()), "")
657                                                }
658                                                ToolCallDeltaContent::Delta(delta) => {
659                                                    (None, delta.as_str())
660                                                }
661                                            };
662
663                                            if let HookAction::Terminate { reason } = hook
664                                                .on_tool_call_delta(
665                                                    &id,
666                                                    &internal_call_id,
667                                                    name,
668                                                    delta,
669                                                )
670                                                .await
671                                            {
672                                                yield Err(StreamingError::Prompt(Box::new(
673                                                    run.cancel_error(reason),
674                                                )));
675                                                break 'outer;
676                                            }
677                                        }
678
679                                        yield Ok(MultiTurnStreamItem::StreamAssistantItem(
680                                            StreamedAssistantContent::ToolCallDelta {
681                                                id,
682                                                internal_call_id,
683                                                content,
684                                            },
685                                        ));
686                                    }
687                                    StreamedTurnEvent::Completed { usage, emit_final } => {
688                                        if !completion_call_emitted {
689                                            if usage.has_values() {
690                                                record_usage_on_span(&chat_stream_span, usage);
691                                            }
692                                            let completion_call =
693                                                match run.record_streamed_completion_call(usage) {
694                                                    Ok(call) => call,
695                                                    Err(err) => {
696                                                        yield Err(Box::new(err).into());
697                                                        break 'outer;
698                                                    }
699                                                };
700                                            completion_call_emitted = true;
701                                            yield Ok(MultiTurnStreamItem::CompletionCall(
702                                                completion_call,
703                                            ));
704                                        }
705
706                                        if emit_final
707                                            && let Some(StreamedAssistantContent::Final(
708                                                final_resp,
709                                            )) = item_slot.as_ref()
710                                        {
711                                            if let Some(ref hook) = self.hook
712                                                && let HookAction::Terminate { reason } = hook
713                                                    .on_stream_completion_response_finish(
714                                                        &current_prompt,
715                                                        final_resp,
716                                                    )
717                                                    .await
718                                            {
719                                                yield Err(StreamingError::Prompt(Box::new(
720                                                    run.cancel_error(reason),
721                                                )));
722                                                break 'outer;
723                                            }
724                                            if let Some(item) = item_slot.take() {
725                                                yield Ok(MultiTurnStreamItem::stream_item(item));
726                                            }
727                                        }
728                                    }
729                                    StreamedTurnEvent::InvalidToolCall(invalid) => {
730                                        let partial =
731                                            assembler.partial_turn(stream.message_id.clone());
732                                        let action = match self.hook.as_ref() {
733                                            Some(hook) => {
734                                                let context = run
735                                                    .streamed_invalid_tool_call_context(
736                                                        &partial, &invalid,
737                                                    );
738                                                hook.on_invalid_tool_call(&context).await
739                                            }
740                                            None => InvalidToolCallHookAction::fail(),
741                                        };
742
743                                        let resolution = match run
744                                            .resolve_streamed_invalid_tool_call(
745                                                &partial, &invalid, action,
746                                            ) {
747                                            Ok(resolution) => resolution,
748                                            Err(err) => {
749                                                yield Err(Box::new(err).into());
750                                                break 'outer;
751                                            }
752                                        };
753
754                                        match resolution {
755                                            StreamedResolution::Repaired { .. } => {
756                                                // Replayed name/argument deltas flow through
757                                                // the same event handling above.
758                                                events.extend(
759                                                    assembler.resolve_pending_invalid(&resolution),
760                                                );
761                                            }
762                                            StreamedResolution::TurnAbandoned {
763                                                ref skipped_tool_result,
764                                            } => {
765                                                let skipped_tool_result =
766                                                    skipped_tool_result.clone();
767                                                assembler.resolve_pending_invalid(&resolution);
768
769                                                if let Some(err) = assembler.pending_delta_error() {
770                                                    yield Err(err.into());
771                                                    break 'outer;
772                                                }
773                                                let drained_usage =
774                                                    match drain_stream_usage(&mut stream).await {
775                                                        Ok(usage) => usage,
776                                                        Err(err) => {
777                                                            yield Err(err);
778                                                            break 'outer;
779                                                        }
780                                                    };
781                                                if !completion_call_emitted {
782                                                    if drained_usage.has_values() {
783                                                        record_usage_on_span(
784                                                            &chat_stream_span,
785                                                            drained_usage,
786                                                        );
787                                                    }
788                                                    let completion_call = match run
789                                                        .record_streamed_completion_call(
790                                                            drained_usage,
791                                                        ) {
792                                                        Ok(call) => call,
793                                                        Err(err) => {
794                                                            yield Err(Box::new(err).into());
795                                                            break 'outer;
796                                                        }
797                                                    };
798                                                    completion_call_emitted = true;
799                                                    yield Ok(MultiTurnStreamItem::CompletionCall(
800                                                        completion_call,
801                                                    ));
802                                                }
803                                                if let Some(tool_result) = skipped_tool_result {
804                                                    yield Ok(MultiTurnStreamItem::StreamUserItem(
805                                                        StreamedUserContent::ToolResult {
806                                                            tool_result,
807                                                            internal_call_id: invalid
808                                                                .internal_call_id
809                                                                .clone(),
810                                                        },
811                                                    ));
812                                                }
813                                                turn_abandoned = true;
814                                                break 'turn;
815                                            }
816                                        }
817                                    }
818                                }
819                            }
820                        }
821
822                        if turn_abandoned {
823                            continue 'outer;
824                        }
825
826                        if let Some(err) = assembler.pending_delta_error() {
827                            yield Err(err.into());
828                            break 'outer;
829                        }
830
831                        if !completion_call_emitted {
832                            let completion_call =
833                                match run
834                                    .record_streamed_completion_call(crate::completion::Usage::new())
835                                {
836                                    Ok(call) => call,
837                                    Err(err) => {
838                                        yield Err(Box::new(err).into());
839                                        break 'outer;
840                                    }
841                                };
842                            yield Ok(MultiTurnStreamItem::CompletionCall(completion_call));
843                        }
844
845                        let final_turn_content = stream.choice.clone();
846                        tracing::Span::current().record(
847                            "gen_ai.completion",
848                            assistant_text_from_choice(&final_turn_content),
849                        );
850
851                        last_message_id = stream.message_id.clone();
852                        let streamed_turn =
853                            assembler.finish(stream.message_id.clone(), &final_turn_content);
854                        if let Err(err) = run.streamed_turn(streamed_turn) {
855                            yield Err(Box::new(err).into());
856                            break 'outer;
857                        }
858                        last_final_choice = final_turn_content;
859                    }
860                    AgentRunStep::CallTools { calls } => {
861                        let full_history_for_errors = run.full_history();
862                        let mut results: Vec<UserContent> = Vec::with_capacity(calls.len());
863
864                        for pending in calls {
865                            let tool_call = pending.tool_call;
866                            if let Some(result) = pending.preresolved_result {
867                                // Pre-resolved results only occur when invalid
868                                // tool-call recovery suppressed execution; the
869                                // streamed path abandons such turns instead, so
870                                // this arm only serves machine-level drivers
871                                // mixing streamed and non-streamed turns.
872                                results.push(result);
873                                continue;
874                            }
875                            let internal_call_id = pending
876                                .internal_call_id
877                                .unwrap_or_else(|| nanoid::nanoid!());
878
879                            let tool_span = info_span!(
880                                parent: tracing::Span::current(),
881                                "execute_tool",
882                                gen_ai.operation.name = "execute_tool",
883                                gen_ai.tool.type = "function",
884                                gen_ai.tool.name = tracing::field::Empty,
885                                gen_ai.tool.call.id = tracing::field::Empty,
886                                gen_ai.tool.call.arguments = tracing::field::Empty,
887                                gen_ai.tool.call.result = tracing::field::Empty
888                            );
889
890                            yield Ok(MultiTurnStreamItem::stream_item(
891                                StreamedAssistantContent::ToolCall {
892                                    tool_call: tool_call.clone(),
893                                    internal_call_id: internal_call_id.clone(),
894                                },
895                            ));
896
897                            let tc_result = async {
898                                let tool_span = tracing::Span::current();
899                                let tool_args =
900                                    json_utils::value_to_json_string(&tool_call.function.arguments);
901                                if let Some(ref hook) = self.hook {
902                                    let action = hook
903                                        .on_tool_call(
904                                            &tool_call.function.name,
905                                            tool_call.call_id.clone(),
906                                            &internal_call_id,
907                                            &tool_args,
908                                        )
909                                        .await;
910
911                                    if let ToolCallHookAction::Terminate { reason } = action {
912                                        return Err(StreamingError::Prompt(Box::new(
913                                            PromptError::prompt_cancelled(
914                                                full_history_for_errors.clone(),
915                                                reason,
916                                            ),
917                                        )));
918                                    }
919
920                                    if let ToolCallHookAction::Skip { reason } = action {
921                                        // Tool execution rejected, return rejection message as tool result
922                                        tracing::info!(
923                                            tool_name = tool_call.function.name.as_str(),
924                                            reason = reason,
925                                            "Tool call rejected"
926                                        );
927                                        return Ok(reason);
928                                    }
929                                }
930
931                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
932                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);
933
934                                let tool_result = match tool_server_handle
935                                    .call_tool(&tool_call.function.name, &tool_args)
936                                    .await
937                                {
938                                    Ok(result) => result,
939                                    Err(err) => {
940                                        tracing::warn!("Error while calling tool: {err}");
941                                        err.to_string()
942                                    }
943                                };
944
945                                tool_span.record("gen_ai.tool.call.result", &tool_result);
946
947                                if let Some(ref hook) = self.hook
948                                    && let HookAction::Terminate { reason } = hook
949                                        .on_tool_result(
950                                            &tool_call.function.name,
951                                            tool_call.call_id.clone(),
952                                            &internal_call_id,
953                                            &tool_args,
954                                            &tool_result.to_string(),
955                                        )
956                                        .await
957                                {
958                                    return Err(StreamingError::Prompt(Box::new(
959                                        PromptError::prompt_cancelled(
960                                            full_history_for_errors.clone(),
961                                            reason,
962                                        ),
963                                    )));
964                                }
965
966                                Ok(tool_result)
967                            }
968                            .instrument(tool_span)
969                            .await;
970
971                            match tc_result {
972                                Ok(text) => {
973                                    results.push(tool_result_user_content(
974                                        tool_call.id.clone(),
975                                        tool_call.call_id.clone(),
976                                        text.clone(),
977                                    ));
978                                    let tool_result = ToolResult {
979                                        id: tool_call.id,
980                                        call_id: tool_call.call_id,
981                                        content: ToolResultContent::from_tool_output(text),
982                                    };
983                                    yield Ok(MultiTurnStreamItem::StreamUserItem(
984                                        StreamedUserContent::ToolResult {
985                                            tool_result,
986                                            internal_call_id,
987                                        },
988                                    ));
989                                }
990                                Err(err) => {
991                                    yield Err(err);
992                                    break 'outer;
993                                }
994                            }
995                        }
996
997                        if let Err(err) = run.tool_results(results) {
998                            yield Err(Box::new(err).into());
999                            break 'outer;
1000                        }
1001                    }
1002                    AgentRunStep::Done(response) => {
1003                        if is_empty_assistant_turn(&last_final_choice) {
1004                            tracing::warn!(
1005                                agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
1006                                message_id = ?last_message_id,
1007                                "Streaming turn completed without assistant text; final response will be empty"
1008                            );
1009                        }
1010
1011                        if created_agent_span {
1012                            let current_span = tracing::Span::current();
1013                            record_usage_on_span(&current_span, response.usage);
1014                        }
1015                        tracing::info!("Agent multi-turn stream finished");
1016                        if let Some((memory, id)) = memory_handle.as_ref()
1017                            && let Err(err) = memory
1018                                .append(id, response.messages.clone().unwrap_or_default())
1019                                .await
1020                        {
1021                            tracing::warn!(
1022                                error = %err,
1023                                conversation_id = %id,
1024                                "conversation memory append failed; yielding final response anyway"
1025                            );
1026                        }
1027                        let final_messages: Option<Vec<Message>> = if has_history {
1028                            Some(response.messages.clone().unwrap_or_default())
1029                        } else {
1030                            None
1031                        };
1032                        yield Ok(MultiTurnStreamItem::final_response_with_completion_calls(
1033                            last_final_choice.clone(),
1034                            response.usage,
1035                            response.completion_calls.clone(),
1036                            final_messages,
1037                        ));
1038                        break 'outer;
1039                    }
1040                }
1041            }
1042        };
1043
1044        Box::pin(stream.instrument(agent_span))
1045    }
1046}
1047
1048impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
1049where
1050    M: CompletionModel + 'static,
1051    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
1052    P: PromptHook<M> + 'static,
1053{
1054    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
1055    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1056
1057    fn into_future(self) -> Self::IntoFuture {
1058        // Wrap send() in a future, because send() returns a stream immediately
1059        Box::pin(async move { self.send().await })
1060    }
1061}
1062
1063/// Helper function to stream assistant-visible completion output to stdout.
1064///
1065/// This helper prints streamed assistant text and reasoning only. Streaming
1066/// metadata events, such as `MultiTurnStreamItem::CompletionCall`, are not
1067/// printed; metadata is returned on the `FinalResponse` via accessors such as
1068/// `FinalResponse::completion_calls`.
1069pub async fn stream_to_stdout<R>(
1070    stream: &mut StreamingResult<R>,
1071) -> Result<FinalResponse, std::io::Error> {
1072    let mut final_res = FinalResponse::empty();
1073    print!("Response: ");
1074    while let Some(content) = stream.next().await {
1075        match content {
1076            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1077                Text { text, .. },
1078            ))) => {
1079                print!("{text}");
1080                std::io::Write::flush(&mut std::io::stdout())?;
1081            }
1082            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
1083                reasoning,
1084            ))) => {
1085                let reasoning = reasoning.display_text();
1086                print!("{reasoning}");
1087                std::io::Write::flush(&mut std::io::stdout())?;
1088            }
1089            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1090                final_res = res;
1091            }
1092            Err(err) => {
1093                eprintln!("Error: {err}");
1094            }
1095            _ => {}
1096        }
1097    }
1098
1099    Ok(final_res)
1100}
1101
1102#[cfg(test)]
1103mod tests {
1104    use super::*;
1105    use crate::agent::AgentBuilder;
1106    use crate::agent::prompt_request::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER;
1107    use crate::agent::prompt_request::hooks::{
1108        InvalidToolCallContext, InvalidToolCallHookAction, PromptHook, ToolCallHookAction,
1109    };
1110    use crate::agent::run::streamed::merge_reasoning_blocks;
1111    use crate::client::ProviderClient;
1112    use crate::client::completion::CompletionClient;
1113    use crate::completion::{CompletionRequest, PromptError, ToolDefinition, Usage};
1114    use crate::message::{
1115        AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent,
1116        ToolChoice, ToolResultContent, UserContent,
1117    };
1118    use crate::providers::anthropic;
1119    use crate::streaming::{StreamingPrompt, ToolCallDeltaContent};
1120    use crate::test_utils::{
1121        AppendFailingMemory, FailingMemory, MockAddTool, MockCompletionModel, MockResponse,
1122        MockStreamEvent, MockSubtractTool, MockToolError,
1123    };
1124    use crate::tool::Tool;
1125    use futures::{StreamExt, TryStreamExt};
1126    use serde::Deserialize;
1127    use std::collections::HashMap;
1128    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
1129    use std::sync::{Arc, Mutex};
1130    use std::time::Duration;
1131    use tracing::field::{Field, Visit};
1132    use tracing::{Id, Subscriber};
1133    use tracing_subscriber::layer::{Context, SubscriberExt};
1134    use tracing_subscriber::{Layer, Registry, registry::LookupSpan};
1135
1136    #[test]
1137    fn merge_reasoning_blocks_preserves_order_and_signatures() {
1138        let mut accumulated = Vec::new();
1139        let first = crate::message::Reasoning {
1140            id: Some("rs_1".to_string()),
1141            content: vec![ReasoningContent::Text {
1142                text: "step-1".to_string(),
1143                signature: Some("sig-1".to_string()),
1144            }],
1145        };
1146        let second = crate::message::Reasoning {
1147            id: Some("rs_1".to_string()),
1148            content: vec![
1149                ReasoningContent::Text {
1150                    text: "step-2".to_string(),
1151                    signature: Some("sig-2".to_string()),
1152                },
1153                ReasoningContent::Summary("summary".to_string()),
1154            ],
1155        };
1156
1157        merge_reasoning_blocks(&mut accumulated, &first);
1158        merge_reasoning_blocks(&mut accumulated, &second);
1159
1160        assert_eq!(accumulated.len(), 1);
1161        let merged = accumulated.first().expect("expected accumulated reasoning");
1162        assert_eq!(merged.id.as_deref(), Some("rs_1"));
1163        assert_eq!(merged.content.len(), 3);
1164        assert!(matches!(
1165            merged.content.first(),
1166            Some(ReasoningContent::Text { text, signature: Some(sig) })
1167                if text == "step-1" && sig == "sig-1"
1168        ));
1169        assert!(matches!(
1170            merged.content.get(1),
1171            Some(ReasoningContent::Text { text, signature: Some(sig) })
1172                if text == "step-2" && sig == "sig-2"
1173        ));
1174    }
1175
1176    #[test]
1177    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
1178        let mut accumulated = vec![crate::message::Reasoning {
1179            id: Some("rs_a".to_string()),
1180            content: vec![ReasoningContent::Text {
1181                text: "step-1".to_string(),
1182                signature: None,
1183            }],
1184        }];
1185        let incoming = crate::message::Reasoning {
1186            id: Some("rs_b".to_string()),
1187            content: vec![ReasoningContent::Text {
1188                text: "step-2".to_string(),
1189                signature: None,
1190            }],
1191        };
1192
1193        merge_reasoning_blocks(&mut accumulated, &incoming);
1194        assert_eq!(accumulated.len(), 2);
1195        assert_eq!(
1196            accumulated.first().and_then(|r| r.id.as_deref()),
1197            Some("rs_a")
1198        );
1199        assert_eq!(
1200            accumulated.get(1).and_then(|r| r.id.as_deref()),
1201            Some("rs_b")
1202        );
1203    }
1204
1205    #[test]
1206    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
1207        let mut accumulated = vec![crate::message::Reasoning {
1208            id: None,
1209            content: vec![ReasoningContent::Text {
1210                text: "first".to_string(),
1211                signature: None,
1212            }],
1213        }];
1214        let incoming = crate::message::Reasoning {
1215            id: None,
1216            content: vec![ReasoningContent::Text {
1217                text: "second".to_string(),
1218                signature: None,
1219            }],
1220        };
1221
1222        merge_reasoning_blocks(&mut accumulated, &incoming);
1223        assert_eq!(accumulated.len(), 2);
1224        assert!(matches!(
1225            accumulated.first(),
1226            Some(crate::message::Reasoning {
1227                id: None,
1228                content
1229            }) if matches!(
1230                content.first(),
1231                Some(ReasoningContent::Text { text, .. }) if text == "first"
1232            )
1233        ));
1234        assert!(matches!(
1235            accumulated.get(1),
1236            Some(crate::message::Reasoning {
1237                id: None,
1238                content
1239            }) if matches!(
1240                content.first(),
1241                Some(ReasoningContent::Text { text, .. }) if text == "second"
1242            )
1243        ));
1244    }
1245
1246    #[test]
1247    fn tool_result_user_content_preserves_multimodal_tool_output() {
1248        let user_content = tool_result_user_content(
1249            "tool_call_1".to_string(),
1250            Some("call_1".to_string()),
1251            serde_json::json!({
1252                "response": {
1253                    "instruction": "Use the image part to answer."
1254                },
1255                "parts": [
1256                    {
1257                        "type": "image",
1258                        "data": "base64data==",
1259                        "mimeType": "image/png"
1260                    }
1261                ]
1262            })
1263            .to_string(),
1264        );
1265
1266        let tool_result = match user_content {
1267            UserContent::ToolResult(tool_result) => tool_result,
1268            other => panic!("expected tool result content, got {other:?}"),
1269        };
1270
1271        assert_eq!(tool_result.id, "tool_call_1");
1272        assert_eq!(tool_result.call_id.as_deref(), Some("call_1"));
1273        assert_eq!(tool_result.content.len(), 2);
1274
1275        let mut items = tool_result.content.iter();
1276        match items.next() {
1277            Some(ToolResultContent::Text(text)) => {
1278                assert!(text.text.contains("Use the image part to answer."));
1279            }
1280            other => panic!("expected structured text payload first, got {other:?}"),
1281        }
1282
1283        match items.next() {
1284            Some(ToolResultContent::Image(image)) => {
1285                assert_eq!(image.media_type, Some(ImageMediaType::PNG));
1286                assert!(matches!(
1287                    image.data,
1288                    DocumentSourceKind::Base64(ref data) if data == "base64data=="
1289                ));
1290            }
1291            other => panic!("expected image payload second, got {other:?}"),
1292        }
1293    }
1294
1295    fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
1296        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1297        if history.len() != 3 {
1298            return Err(format!(
1299                "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
1300            ));
1301        }
1302
1303        if !matches!(
1304            history.first(),
1305            Some(Message::User { content })
1306                if matches!(
1307                    content.first(),
1308                    UserContent::Text(text) if text.text == "do tool work"
1309                )
1310        ) {
1311            return Err(format!(
1312                "follow-up request should begin with the original user prompt: {history:?}"
1313            ));
1314        }
1315
1316        if !matches!(
1317            history.get(1),
1318            Some(Message::Assistant { content, .. })
1319                if matches!(
1320                    content.first(),
1321                    AssistantContent::ToolCall(tool_call)
1322                        if tool_call.id == "tool_call_1"
1323                            && tool_call.call_id.as_deref() == Some("call_1")
1324                )
1325        ) {
1326            return Err(format!(
1327                "follow-up request is missing the assistant tool call in position 2: {history:?}"
1328            ));
1329        }
1330
1331        if !matches!(
1332            history.get(2),
1333            Some(Message::User { content })
1334                if matches!(
1335                    content.first(),
1336                    UserContent::ToolResult(tool_result)
1337                        if tool_result.id == "tool_call_1"
1338                            && tool_result.call_id.as_deref() == Some("call_1")
1339                )
1340        ) {
1341            return Err(format!(
1342                "follow-up request should end with the user tool result: {history:?}"
1343            ));
1344        }
1345
1346        Ok(())
1347    }
1348
1349    fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1350        history.iter().any(|message| {
1351            matches!(
1352                message,
1353                Message::Assistant { content, .. }
1354                    if content.iter().any(|item| matches!(
1355                        item,
1356                        AssistantContent::ToolCall(tool_call)
1357                            if tool_call.function.name == tool_name
1358                    ))
1359            )
1360        })
1361    }
1362
1363    fn history_contains_text(history: &[Message], expected: &str) -> bool {
1364        history.iter().any(|message| {
1365            matches!(
1366                message,
1367                Message::Assistant { content, .. }
1368                    if content.iter().any(|item| matches!(
1369                        item,
1370                        AssistantContent::Text(text) if text.text == expected
1371                    ))
1372            )
1373        })
1374    }
1375
1376    fn assistant_reasoning_precedes_tool_call(
1377        history: &[Message],
1378        expected_reasoning: &str,
1379        tool_name: &str,
1380    ) -> bool {
1381        history.iter().any(|message| {
1382            let Message::Assistant { content, .. } = message else {
1383                return false;
1384            };
1385
1386            let reasoning_index = content.iter().position(|item| {
1387                matches!(
1388                    item,
1389                    AssistantContent::Reasoning(reasoning)
1390                        if reasoning.content.iter().any(|content| matches!(
1391                            content,
1392                            ReasoningContent::Text { text, .. }
1393                                if text == expected_reasoning
1394                        ))
1395                )
1396            });
1397            let tool_index = content.iter().position(|item| {
1398                matches!(
1399                    item,
1400                    AssistantContent::ToolCall(tool_call)
1401                        if tool_call.function.name == tool_name
1402                )
1403            });
1404
1405            matches!((reasoning_index, tool_index), (Some(reasoning), Some(tool)) if reasoning < tool)
1406        })
1407    }
1408
1409    fn assistant_reasoning_precedes_text_and_tool_call(
1410        history: &[Message],
1411        expected_reasoning: &str,
1412        expected_text: &str,
1413        tool_name: &str,
1414    ) -> bool {
1415        history.iter().any(|message| {
1416            let Message::Assistant { content, .. } = message else {
1417                return false;
1418            };
1419
1420            let reasoning_index = content.iter().position(|item| {
1421                matches!(
1422                    item,
1423                    AssistantContent::Reasoning(reasoning)
1424                        if reasoning.content.iter().any(|content| matches!(
1425                            content,
1426                            ReasoningContent::Text { text, .. }
1427                                if text == expected_reasoning
1428                        ))
1429                )
1430            });
1431            let text_index = content.iter().position(|item| {
1432                matches!(
1433                    item,
1434                    AssistantContent::Text(text) if text.text == expected_text
1435                )
1436            });
1437            let tool_index = content.iter().position(|item| {
1438                matches!(
1439                    item,
1440                    AssistantContent::ToolCall(tool_call)
1441                        if tool_call.function.name == tool_name
1442                )
1443            });
1444
1445            matches!(
1446                (reasoning_index, text_index, tool_index),
1447                (Some(reasoning), Some(text), Some(tool))
1448                    if reasoning < text && text < tool
1449            )
1450        })
1451    }
1452
    /// Hook whose delta, tool-call, and stream-finish callbacks all panic.
    ///
    /// Used to prove that an unknown tool call is rejected before any of these
    /// hooks are reached: if the driver ever invokes one, the test panics with
    /// the corresponding message.
    #[derive(Clone)]
    struct PanicOnUnknownToolHook;

    impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
        // Must not run: the unknown tool should fail before deltas reach hooks.
        async fn on_tool_call_delta(
            &self,
            _tool_call_id: &str,
            _internal_call_id: &str,
            _tool_name: Option<&str>,
            _tool_call_delta: &str,
        ) -> HookAction {
            panic!("unknown tool call delta should fail before delta hooks run")
        }

        // Must not run: the unknown tool should fail before tool-call hooks.
        async fn on_tool_call(
            &self,
            _tool_name: &str,
            _tool_call_id: Option<String>,
            _internal_call_id: &str,
            _args: &str,
        ) -> ToolCallHookAction {
            panic!("unknown tool call should fail before tool hooks run")
        }

        // Must not run: the unknown tool should fail before the stream-finish hook.
        async fn on_stream_completion_response_finish(
            &self,
            _prompt: &Message,
            _response: &MockResponse,
        ) -> HookAction {
            panic!("unknown tool call should fail before stream finish hooks run")
        }
    }
1485
    /// `add` tool that counts how many times `call` has been invoked.
    #[derive(Clone)]
    struct CountingAddTool {
        // Shared invocation counter; incremented once per `call`.
        calls: Arc<AtomicU32>,
    }

    /// `subtract` tool that counts how many times `call` has been invoked.
    #[derive(Clone)]
    struct CountingSubtractTool {
        // Shared invocation counter; incremented once per `call`.
        calls: Arc<AtomicU32>,
    }

    /// Deserialized operands shared by both counting arithmetic tools.
    #[derive(Deserialize)]
    struct CountingOperationArgs {
        x: i32,
        y: i32,
    }
1501
1502    fn arithmetic_tool_definition(name: &str, description: &str) -> ToolDefinition {
1503        ToolDefinition {
1504            name: name.to_string(),
1505            description: description.to_string(),
1506            parameters: serde_json::json!({
1507                "type": "object",
1508                "properties": {
1509                    "x": {
1510                        "type": "number",
1511                        "description": "The first operand"
1512                    },
1513                    "y": {
1514                        "type": "number",
1515                        "description": "The second operand"
1516                    }
1517                },
1518                "required": ["x", "y"],
1519            }),
1520        }
1521    }
1522
1523    impl Tool for CountingAddTool {
1524        const NAME: &'static str = "add";
1525        type Error = MockToolError;
1526        type Args = CountingOperationArgs;
1527        type Output = i32;
1528
1529        async fn definition(&self, _prompt: String) -> ToolDefinition {
1530            arithmetic_tool_definition(Self::NAME, "Add x and y together")
1531        }
1532
1533        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
1534            self.calls.fetch_add(1, Ordering::SeqCst);
1535            Ok(args.x + args.y)
1536        }
1537    }
1538
1539    impl Tool for CountingSubtractTool {
1540        const NAME: &'static str = "subtract";
1541        type Error = MockToolError;
1542        type Args = CountingOperationArgs;
1543        type Output = i32;
1544
1545        async fn definition(&self, _prompt: String) -> ToolDefinition {
1546            arithmetic_tool_definition(Self::NAME, "Subtract y from x")
1547        }
1548
1549        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
1550            self.calls.fetch_add(1, Ordering::SeqCst);
1551            Ok(args.x - args.y)
1552        }
1553    }
1554
1555    fn streaming_tool_then_text_model() -> MockCompletionModel {
1556        MockCompletionModel::from_stream_turns([
1557            vec![
1558                MockStreamEvent::tool_call(
1559                    "tool_call_1",
1560                    "add",
1561                    serde_json::json!({"x": 1, "y": 2}),
1562                )
1563                .with_call_id("call_1"),
1564                MockStreamEvent::final_response_with_total_tokens(4),
1565            ],
1566            vec![
1567                MockStreamEvent::text("done"),
1568                MockStreamEvent::final_response_with_total_tokens(6),
1569            ],
1570        ])
1571    }
1572
1573    fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
1574        Usage {
1575            input_tokens,
1576            output_tokens,
1577            total_tokens: input_tokens + output_tokens,
1578            cached_input_tokens: 0,
1579            cache_creation_input_tokens: 0,
1580            tool_use_prompt_tokens: 0,
1581            reasoning_tokens: 0,
1582        }
1583    }
1584
    /// Test-side record of one span observed by the span-capturing layer.
    #[derive(Clone, Debug, Default)]
    struct CapturedSpan {
        // `Id::into_u64` of the span.
        id: u64,
        // Span metadata name (e.g. "outer", "chat_streaming").
        name: String,
        // Parent span id resolved when the span was created, if any.
        parent_id: Option<u64>,
        // `u64` field values recorded on the span via `on_record`; values
        // supplied at span creation are not captured here.
        fields: HashMap<String, u64>,
    }

    /// Shared, cloneable store of captured spans, kept in creation order.
    #[derive(Clone, Default)]
    struct CapturedSpans(Arc<Mutex<Vec<CapturedSpan>>>);
1595
1596    impl CapturedSpans {
1597        fn clear(&self) {
1598            if let Ok(mut spans) = self.0.lock() {
1599                spans.clear();
1600            }
1601        }
1602
1603        fn insert(&self, id: &Id, name: &str, parent_id: Option<u64>) {
1604            let id = id.into_u64();
1605            if let Ok(mut spans) = self.0.lock() {
1606                spans.push(CapturedSpan {
1607                    id,
1608                    name: name.to_string(),
1609                    parent_id,
1610                    fields: HashMap::new(),
1611                });
1612            }
1613        }
1614
1615        fn record(&self, id: &Id, fields: Vec<(String, u64)>) {
1616            if let Ok(mut spans) = self.0.lock()
1617                && let Some(span) = spans.iter_mut().rev().find(|span| span.id == id.into_u64())
1618            {
1619                span.fields.extend(fields);
1620            }
1621        }
1622
1623        fn snapshot(&self) -> Vec<CapturedSpan> {
1624            self.0.lock().map(|spans| spans.clone()).unwrap_or_default()
1625        }
1626    }
1627
    /// `tracing_subscriber` layer that mirrors span creation and `u64` field
    /// recordings into a `CapturedSpans` store for later assertions.
    struct SpanCaptureLayer {
        spans: CapturedSpans,
    }
1631
1632    impl<S> Layer<S> for SpanCaptureLayer
1633    where
1634        S: Subscriber,
1635        S: for<'lookup> LookupSpan<'lookup>,
1636    {
1637        fn on_new_span(&self, attrs: &tracing::span::Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
1638            let parent_id = attrs
1639                .parent()
1640                .map(Id::into_u64)
1641                .or_else(|| ctx.current_span().id().map(Id::into_u64));
1642            self.spans.insert(id, attrs.metadata().name(), parent_id);
1643        }
1644
1645        fn on_record(&self, span: &Id, values: &tracing::span::Record<'_>, _ctx: Context<'_, S>) {
1646            let mut fields = Vec::new();
1647            values.record(&mut SpanFieldCaptureVisitor {
1648                fields: &mut fields,
1649            });
1650            self.spans.record(span, fields);
1651        }
1652    }
1653
    /// Field visitor that collects only `u64`-valued fields as `(name, value)` pairs.
    struct SpanFieldCaptureVisitor<'a> {
        fields: &'a mut Vec<(String, u64)>,
    }

    impl Visit for SpanFieldCaptureVisitor<'_> {
        fn record_u64(&mut self, field: &Field, value: u64) {
            self.fields.push((field.name().to_string(), value));
        }

        // Required by `Visit`; values delivered here are deliberately dropped.
        // NOTE(review): relies on tracing-core's default `record_*` methods
        // forwarding to `record_debug`, so e.g. `i64` usage values would be
        // silently ignored.
        fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {}
    }
1665
    /// Streams `prompt` through `agent` (empty history, `max_turns` turns)
    /// inside an `outer` span, under a span-capturing subscriber, and asserts:
    ///
    /// * exactly one `chat_streaming` span was created per entry in
    ///   `expected_usages`, and each is parented to `outer`;
    /// * each `chat_streaming` span, paired in creation order with its entry in
    ///   `expected_usages`, has every `gen_ai.usage.*` field recorded with that
    ///   entry's value;
    /// * no `invoke_agent` span was created;
    /// * no `gen_ai.usage.*` field was recorded on `outer` itself.
    ///
    /// Panics on any mismatch or if either stream yields an error.
    async fn assert_stream_usage_recorded_on_chat_spans(
        agent: crate::agent::Agent<MockCompletionModel>,
        prompt: &str,
        max_turns: usize,
        expected_usages: &[Usage],
    ) {
        // Scoped-subscriber tests must not run concurrently; the warm-up
        // below explains the callsite-interest hazard this guards against.
        let _isolation = crate::test_utils::scoped_tracing_subscriber_guard().await;
        let spans = CapturedSpans::default();
        let subscriber = Registry::default().with(SpanCaptureLayer {
            spans: spans.clone(),
        });
        let _default = tracing::subscriber::set_default(subscriber);

        // Span callsites in the driver are shared with every other test in
        // this binary. The FIRST thread to hit a callsite caches its interest
        // from that thread's dispatcher (`Dispatchers::Rebuilder::JustOne`
        // consults `dispatcher::get_default`), so a parallel test without a
        // subscriber can permanently cache `Interest::never` for the very
        // spans this harness asserts on. Defend in two steps, both under the
        // isolation guard: (1) warm the whole driver path from THIS thread so
        // unregistered callsites first-register against this subscriber, then
        // (2) rebuild the interest cache to heal callsites a foreign thread
        // already poisoned.
        let warmup_model = MockCompletionModel::from_stream_turns([[
            MockStreamEvent::text("warmup"),
            MockStreamEvent::final_response(Usage::default()),
        ]]);
        let warmup_agent = crate::agent::AgentBuilder::new(warmup_model).build();
        let mut warmup_stream = warmup_agent.stream_prompt("warmup").multi_turn(1).await;
        while let Some(item) = warmup_stream
            .try_next()
            .await
            .expect("warmup stream should not error")
        {
            if matches!(item, MultiTurnStreamItem::FinalResponse(_)) {
                break;
            }
        }
        tracing::callsite::rebuild_interest_cache();
        // Discard spans produced by the warm-up so only the real run is asserted on.
        spans.clear();

        let empty_history: &[Message] = &[];
        let outer_span = tracing::info_span!("outer");

        // Drive the real stream inside `outer` so the driver's spans can be
        // checked for parentage against it.
        async {
            let mut stream = agent
                .stream_prompt(prompt)
                .with_history(empty_history)
                .multi_turn(max_turns)
                .await;

            while let Some(item) = stream.try_next().await.expect("stream should not error") {
                if matches!(item, MultiTurnStreamItem::FinalResponse(_)) {
                    break;
                }
            }
        }
        .instrument(outer_span)
        .await;

        // The snapshot is in span-creation order, so `chat_spans` is zipped
        // with `expected_usages` in that order below.
        let span_snapshot = spans.snapshot();
        let outer_span_id = span_snapshot
            .iter()
            .find(|span| span.name == "outer")
            .map(|span| span.id)
            .expect("outer span should be captured");
        let chat_spans = span_snapshot
            .iter()
            .filter(|span| span.name == "chat_streaming")
            .collect::<Vec<_>>();

        assert_eq!(chat_spans.len(), expected_usages.len());
        assert!(
            span_snapshot.iter().all(|span| span.name != "invoke_agent"),
            "outer span path should not create invoke_agent"
        );

        for (chat_span, expected_usage) in chat_spans.into_iter().zip(expected_usages) {
            assert_eq!(chat_span.parent_id, Some(outer_span_id));
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.input_tokens"),
                Some(&expected_usage.input_tokens)
            );
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.output_tokens"),
                Some(&expected_usage.output_tokens)
            );
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.cache_read.input_tokens"),
                Some(&expected_usage.cached_input_tokens)
            );
            assert_eq!(
                chat_span
                    .fields
                    .get("gen_ai.usage.cache_creation.input_tokens"),
                Some(&expected_usage.cache_creation_input_tokens)
            );
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.tool_use_prompt_tokens"),
                Some(&expected_usage.tool_use_prompt_tokens)
            );
            assert_eq!(
                chat_span.fields.get("gen_ai.usage.reasoning_tokens"),
                Some(&expected_usage.reasoning_tokens)
            );
        }

        let outer_span = span_snapshot
            .iter()
            .find(|span| span.id == outer_span_id)
            .expect("outer span should be present");
        assert!(
            outer_span
                .fields
                .keys()
                .all(|field| !field.starts_with("gen_ai.usage.")),
            "usage should not be recorded onto the caller's outer span"
        );
    }
1787
1788    #[test]
1789    fn completion_calls_stream_item_serializes_and_deserializes_expected_shape() {
1790        let item: MultiTurnStreamItem<MockResponse> =
1791            MultiTurnStreamItem::CompletionCall(CompletionCall::new(2, usage(3, 4)));
1792
1793        let value = serde_json::to_value(&item).expect("serialize completion call event");
1794
1795        assert_eq!(
1796            value,
1797            serde_json::json!({
1798                "type": "completionCall",
1799                "call_index": 2,
1800                "usage": {
1801                    "input_tokens": 3,
1802                    "output_tokens": 4,
1803                    "total_tokens": 7,
1804                    "cached_input_tokens": 0,
1805                    "cache_creation_input_tokens": 0,
1806                    "tool_use_prompt_tokens": 0,
1807                    "reasoning_tokens": 0,
1808                }
1809            })
1810        );
1811
1812        let item: MultiTurnStreamItem<MockResponse> =
1813            serde_json::from_value(value).expect("deserialize completion call event");
1814        match item {
1815            MultiTurnStreamItem::CompletionCall(call_usage) => {
1816                assert_eq!(call_usage, CompletionCall::new(2, usage(3, 4)));
1817            }
1818            other => panic!("expected completion call event, got {other:?}"),
1819        }
1820
1821        let item: MultiTurnStreamItem<MockResponse> =
1822            MultiTurnStreamItem::CompletionCall(CompletionCall::new(3, Usage::new()));
1823        let value = serde_json::to_value(&item).expect("serialize missing usage event");
1824
1825        // Unreported usage serializes as a plain zero-valued object (Usage's
1826        // documented sentinel for missing provider metrics).
1827        assert_eq!(
1828            value,
1829            serde_json::json!({
1830                "type": "completionCall",
1831                "call_index": 3,
1832                "usage": {
1833                    "input_tokens": 0,
1834                    "output_tokens": 0,
1835                    "total_tokens": 0,
1836                    "cached_input_tokens": 0,
1837                    "cache_creation_input_tokens": 0,
1838                    "tool_use_prompt_tokens": 0,
1839                    "reasoning_tokens": 0,
1840                }
1841            })
1842        );
1843
1844        // Stream items serialized before the Option encoding was dropped used
1845        // `"usage": null`; they must still deserialize.
1846        let legacy: MultiTurnStreamItem<MockResponse> = serde_json::from_value(serde_json::json!({
1847            "type": "completionCall",
1848            "call_index": 3,
1849            "usage": null
1850        }))
1851        .expect("legacy null-usage event should deserialize");
1852        match legacy {
1853            MultiTurnStreamItem::CompletionCall(call) => {
1854                assert_eq!(call, CompletionCall::new(3, Usage::new()));
1855            }
1856            other => panic!("expected completion call event, got {other:?}"),
1857        }
1858    }
1859
1860    #[test]
1861    fn final_response_serializes_completion_calls_with_missing_usage() {
1862        let item: MultiTurnStreamItem<MockResponse> =
1863            MultiTurnStreamItem::final_response_with_completion_calls(
1864                OneOrMany::one(AssistantContent::text("done")),
1865                usage(3, 4),
1866                vec![
1867                    CompletionCall::new(0, Usage::new()),
1868                    CompletionCall::new(1, usage(3, 4)),
1869                ],
1870                None,
1871            );
1872
1873        if let MultiTurnStreamItem::FinalResponse(response) = &item {
1874            assert_eq!(response.requests(), 2);
1875        }
1876
1877        let value = serde_json::to_value(&item).expect("serialize final response");
1878
1879        assert_eq!(
1880            value.get("completionCalls"),
1881            Some(&serde_json::json!([
1882                {
1883                    "call_index": 0,
1884                    "usage": {
1885                        "input_tokens": 0,
1886                        "output_tokens": 0,
1887                        "total_tokens": 0,
1888                        "cached_input_tokens": 0,
1889                        "cache_creation_input_tokens": 0,
1890                        "tool_use_prompt_tokens": 0,
1891                        "reasoning_tokens": 0,
1892                    }
1893                },
1894                {
1895                    "call_index": 1,
1896                    "usage": {
1897                        "input_tokens": 3,
1898                        "output_tokens": 4,
1899                        "total_tokens": 7,
1900                        "cached_input_tokens": 0,
1901                        "cache_creation_input_tokens": 0,
1902                        "tool_use_prompt_tokens": 0,
1903                        "reasoning_tokens": 0,
1904                    }
1905                }
1906            ]))
1907        );
1908    }
1909
1910    fn streaming_text_then_final_model() -> MockCompletionModel {
1911        MockCompletionModel::from_stream_turns([[
1912            MockStreamEvent::text("hello"),
1913            MockStreamEvent::text(" world"),
1914            MockStreamEvent::final_response_with_total_tokens(3),
1915        ]])
1916    }
1917
1918    fn citation_metadata() -> serde_json::Value {
1919        serde_json::json!({
1920            "citations": [{
1921                "type": "web_search_result_location",
1922                "cited_text": "Claude Shannon was born in 1916.",
1923                "url": "https://example.com/shannon",
1924                "title": "Claude Shannon",
1925                "encrypted_index": "encrypted-reference"
1926            }]
1927        })
1928    }
1929
1930    fn streaming_cited_text_then_final_model() -> MockCompletionModel {
1931        MockCompletionModel::from_stream_turns([[
1932            MockStreamEvent::text_start(Some(citation_metadata())),
1933            MockStreamEvent::text("cited "),
1934            MockStreamEvent::text_start(None),
1935            MockStreamEvent::text("answer"),
1936            MockStreamEvent::final_response_with_total_tokens(3),
1937        ]])
1938    }
1939
1940    fn streaming_cited_text_then_tool_model() -> MockCompletionModel {
1941        MockCompletionModel::from_stream_turns([
1942            vec![
1943                MockStreamEvent::text_start(Some(citation_metadata())),
1944                MockStreamEvent::text("I need a tool. "),
1945                MockStreamEvent::tool_call(
1946                    "tool_call_1",
1947                    "add",
1948                    serde_json::json!({"x": 1, "y": 2}),
1949                )
1950                .with_call_id("call_1"),
1951                MockStreamEvent::final_response_with_total_tokens(4),
1952            ],
1953            vec![
1954                MockStreamEvent::text("done"),
1955                MockStreamEvent::final_response_with_total_tokens(6),
1956            ],
1957        ])
1958    }
1959
1960    fn streaming_final_only_model() -> MockCompletionModel {
1961        MockCompletionModel::from_stream_turns([[
1962            MockStreamEvent::final_response_with_total_tokens(1),
1963        ]])
1964    }
1965
1966    #[derive(Clone)]
1967    struct TerminateOnStreamFinish;
1968
1969    impl PromptHook<MockCompletionModel> for TerminateOnStreamFinish {
1970        async fn on_stream_completion_response_finish(
1971            &self,
1972            _prompt: &Message,
1973            _response: &<MockCompletionModel as CompletionModel>::StreamingResponse,
1974        ) -> HookAction {
1975            HookAction::terminate("stop after completion call")
1976        }
1977    }
1978
    /// One recorded `on_tool_call_delta` invocation:
    /// `(tool_call_id, internal_call_id, tool_name, tool_call_delta)`.
    type RecordedToolCallDelta = (String, String, Option<String>, String);
1980
1981    #[derive(Clone)]
1982    struct RepairDefaultApiHook;
1983
1984    impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
1985        fn on_invalid_tool_call(
1986            &self,
1987            context: &InvalidToolCallContext,
1988        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
1989            let tool_name = context.tool_name.clone();
1990            async move {
1991                assert_eq!(tool_name, "default_api");
1992                InvalidToolCallHookAction::repair("add")
1993            }
1994        }
1995    }
1996
1997    #[derive(Clone)]
1998    struct RetryDefaultApiHook;
1999
2000    impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
2001        fn on_invalid_tool_call(
2002            &self,
2003            context: &InvalidToolCallContext,
2004        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2005            let tool_name = context.tool_name.clone();
2006            let args = context.args.clone();
2007            async move {
2008                assert_eq!(tool_name, "default_api");
2009                if let Some(args) = args {
2010                    assert!(!args.is_empty());
2011                }
2012                InvalidToolCallHookAction::retry("Use the add tool instead")
2013            }
2014        }
2015    }
2016
2017    #[derive(Clone)]
2018    struct SkipDefaultApiHook;
2019
2020    impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
2021        fn on_invalid_tool_call(
2022            &self,
2023            context: &InvalidToolCallContext,
2024        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2025            let tool_name = context.tool_name.clone();
2026            async move {
2027                assert_eq!(tool_name, "default_api");
2028                InvalidToolCallHookAction::skip("default_api was skipped")
2029            }
2030        }
2031    }
2032
2033    #[derive(Clone, Default)]
2034    struct RecordingInvalidToolCallHook {
2035        contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
2036    }
2037
2038    impl RecordingInvalidToolCallHook {
2039        fn observed(&self) -> Vec<InvalidToolCallContext> {
2040            self.contexts
2041                .lock()
2042                .expect("invalid tool context records mutex was poisoned")
2043                .clone()
2044        }
2045    }
2046
2047    impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
2048        fn on_invalid_tool_call(
2049            &self,
2050            context: &InvalidToolCallContext,
2051        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2052            let contexts = self.contexts.clone();
2053            let context = context.clone();
2054
2055            async move {
2056                contexts
2057                    .lock()
2058                    .expect("invalid tool context records mutex was poisoned")
2059                    .push(context);
2060                InvalidToolCallHookAction::fail()
2061            }
2062        }
2063    }
2064
2065    #[derive(Clone, Default)]
2066    struct RecordingToolCallDeltaHook {
2067        deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2068    }
2069
2070    impl RecordingToolCallDeltaHook {
2071        fn observed(&self) -> Vec<RecordedToolCallDelta> {
2072            self.deltas
2073                .lock()
2074                .expect("tool call delta hook records mutex was poisoned")
2075                .clone()
2076        }
2077    }
2078
2079    impl PromptHook<MockCompletionModel> for RecordingToolCallDeltaHook {
2080        fn on_tool_call_delta(
2081            &self,
2082            tool_call_id: &str,
2083            internal_call_id: &str,
2084            tool_name: Option<&str>,
2085            tool_call_delta: &str,
2086        ) -> impl Future<Output = HookAction> + Send {
2087            let deltas = self.deltas.clone();
2088            let event = (
2089                tool_call_id.to_string(),
2090                internal_call_id.to_string(),
2091                tool_name.map(str::to_string),
2092                tool_call_delta.to_string(),
2093            );
2094
2095            async move {
2096                deltas
2097                    .lock()
2098                    .expect("tool call delta hook records mutex was poisoned")
2099                    .push(event);
2100                HookAction::cont()
2101            }
2102        }
2103    }
2104
2105    #[derive(Clone, Default)]
2106    struct RecordingTextDeltaHook {
2107        deltas: Arc<Mutex<Vec<(String, String)>>>,
2108    }
2109
2110    impl RecordingTextDeltaHook {
2111        fn observed(&self) -> Vec<(String, String)> {
2112            self.deltas
2113                .lock()
2114                .expect("text delta hook records mutex was poisoned")
2115                .clone()
2116        }
2117    }
2118
2119    impl PromptHook<MockCompletionModel> for RecordingTextDeltaHook {
2120        fn on_text_delta(
2121            &self,
2122            text_delta: &str,
2123            full_text: &str,
2124        ) -> impl Future<Output = HookAction> + Send {
2125            let deltas = self.deltas.clone();
2126            let event = (text_delta.to_string(), full_text.to_string());
2127
2128            async move {
2129                deltas
2130                    .lock()
2131                    .expect("text delta hook records mutex was poisoned")
2132                    .push(event);
2133                HookAction::cont()
2134            }
2135        }
2136    }
2137
2138    #[derive(Clone)]
2139    struct RecordingTextAndSkipInvalidToolHook {
2140        text: RecordingTextDeltaHook,
2141    }
2142
2143    impl PromptHook<MockCompletionModel> for RecordingTextAndSkipInvalidToolHook {
2144        fn on_text_delta(
2145            &self,
2146            text_delta: &str,
2147            full_text: &str,
2148        ) -> impl Future<Output = HookAction> + Send {
2149            self.text.on_text_delta(text_delta, full_text)
2150        }
2151
2152        fn on_invalid_tool_call(
2153            &self,
2154            context: &InvalidToolCallContext,
2155        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2156            SkipDefaultApiHook.on_invalid_tool_call(context)
2157        }
2158    }
2159
2160    #[derive(Clone)]
2161    struct RecordingTextAndRetryInvalidToolHook {
2162        text: RecordingTextDeltaHook,
2163    }
2164
2165    impl PromptHook<MockCompletionModel> for RecordingTextAndRetryInvalidToolHook {
2166        fn on_text_delta(
2167            &self,
2168            text_delta: &str,
2169            full_text: &str,
2170        ) -> impl Future<Output = HookAction> + Send {
2171            self.text.on_text_delta(text_delta, full_text)
2172        }
2173
2174        fn on_invalid_tool_call(
2175            &self,
2176            context: &InvalidToolCallContext,
2177        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2178            RetryDefaultApiHook.on_invalid_tool_call(context)
2179        }
2180    }
2181
2182    #[derive(Clone)]
2183    struct RecordingDeltaAndRetryInvalidToolHook {
2184        delta: RecordingToolCallDeltaHook,
2185    }
2186
2187    impl PromptHook<MockCompletionModel> for RecordingDeltaAndRetryInvalidToolHook {
2188        fn on_tool_call_delta(
2189            &self,
2190            tool_call_id: &str,
2191            internal_call_id: &str,
2192            tool_name: Option<&str>,
2193            tool_call_delta: &str,
2194        ) -> impl Future<Output = HookAction> + Send {
2195            self.delta.on_tool_call_delta(
2196                tool_call_id,
2197                internal_call_id,
2198                tool_name,
2199                tool_call_delta,
2200            )
2201        }
2202
2203        fn on_invalid_tool_call(
2204            &self,
2205            context: &InvalidToolCallContext,
2206        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2207            RetryDefaultApiHook.on_invalid_tool_call(context)
2208        }
2209    }
2210
2211    #[derive(Clone)]
2212    struct RecordingDeltaAndSkipInvalidToolHook {
2213        delta: RecordingToolCallDeltaHook,
2214    }
2215
2216    impl PromptHook<MockCompletionModel> for RecordingDeltaAndSkipInvalidToolHook {
2217        fn on_tool_call_delta(
2218            &self,
2219            tool_call_id: &str,
2220            internal_call_id: &str,
2221            tool_name: Option<&str>,
2222            tool_call_delta: &str,
2223        ) -> impl Future<Output = HookAction> + Send {
2224            self.delta.on_tool_call_delta(
2225                tool_call_id,
2226                internal_call_id,
2227                tool_name,
2228                tool_call_delta,
2229            )
2230        }
2231
2232        fn on_invalid_tool_call(
2233            &self,
2234            context: &InvalidToolCallContext,
2235        ) -> impl Future<Output = InvalidToolCallHookAction> + Send {
2236            SkipDefaultApiHook.on_invalid_tool_call(context)
2237        }
2238    }
2239
2240    #[derive(Clone, Default)]
2241    struct TerminatingToolCallDeltaHook {
2242        deltas: Arc<Mutex<Vec<RecordedToolCallDelta>>>,
2243    }
2244
2245    impl TerminatingToolCallDeltaHook {
2246        fn observed(&self) -> Vec<RecordedToolCallDelta> {
2247            self.deltas
2248                .lock()
2249                .expect("tool call delta hook records mutex was poisoned")
2250                .clone()
2251        }
2252    }
2253
2254    impl PromptHook<MockCompletionModel> for TerminatingToolCallDeltaHook {
2255        fn on_tool_call_delta(
2256            &self,
2257            tool_call_id: &str,
2258            internal_call_id: &str,
2259            tool_name: Option<&str>,
2260            tool_call_delta: &str,
2261        ) -> impl Future<Output = HookAction> + Send {
2262            let deltas = self.deltas.clone();
2263            let event = (
2264                tool_call_id.to_string(),
2265                internal_call_id.to_string(),
2266                tool_name.map(str::to_string),
2267                tool_call_delta.to_string(),
2268            );
2269
2270            async move {
2271                deltas
2272                    .lock()
2273                    .expect("tool call delta hook records mutex was poisoned")
2274                    .push(event);
2275                HookAction::terminate("stop on tool call delta")
2276            }
2277        }
2278    }
2279
2280    fn text_metadata(content: &OneOrMany<AssistantContent>) -> Option<&serde_json::Value> {
2281        content.iter().find_map(|item| match item {
2282            AssistantContent::Text(text) => text.additional_params.as_ref(),
2283            _ => None,
2284        })
2285    }
2286
2287    #[tokio::test]
2288    async fn stream_prompt_continues_after_tool_call_turn() {
2289        let model = streaming_tool_then_text_model();
2290        let recorded = model.clone();
2291        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2292        let empty_history: &[Message] = &[];
2293
2294        let mut stream = agent
2295            .stream_prompt("do tool work")
2296            .with_history(empty_history)
2297            .multi_turn(3)
2298            .await;
2299        let mut saw_tool_call = false;
2300        let mut saw_tool_result = false;
2301        let mut saw_final_response = false;
2302        let mut final_text = String::new();
2303        let mut final_response_text = None;
2304        let mut final_history = None;
2305
2306        while let Some(item) = stream.next().await {
2307            match item {
2308                Ok(MultiTurnStreamItem::StreamAssistantItem(
2309                    StreamedAssistantContent::ToolCall { .. },
2310                )) => {
2311                    saw_tool_call = true;
2312                }
2313                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2314                    ..
2315                })) => {
2316                    saw_tool_result = true;
2317                }
2318                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
2319                    text,
2320                ))) => {
2321                    final_text.push_str(&text.text);
2322                }
2323                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
2324                    saw_final_response = true;
2325                    final_response_text = Some(res.response().to_owned());
2326                    final_history = res.history().map(|history| history.to_vec());
2327                    break;
2328                }
2329                Ok(_) => {}
2330                Err(err) => panic!("unexpected streaming error: {err:?}"),
2331            }
2332        }
2333
2334        assert!(saw_tool_call);
2335        assert!(saw_tool_result);
2336        assert!(saw_final_response);
2337        assert_eq!(final_text, "done");
2338        assert_eq!(final_response_text.as_deref(), Some("done"));
2339        let history = final_history.expect("expected final response history");
2340        assert!(history.iter().any(|message| matches!(
2341            message,
2342            Message::Assistant { content, .. }
2343                if content.iter().any(|item| matches!(
2344                    item,
2345                    AssistantContent::Text(text) if text.text == "done"
2346                ))
2347        )));
2348        let requests = recorded.requests();
2349        assert_eq!(requests.len(), 2);
2350        assert!(validate_follow_up_tool_history(&requests[1]).is_ok());
2351    }
2352
2353    #[tokio::test]
2354    async fn unknown_tool_call_fails_before_streaming_second_request() {
2355        let model = MockCompletionModel::from_stream_turns([
2356            vec![
2357                MockStreamEvent::tool_call(
2358                    "tool_call_1",
2359                    "default_api",
2360                    serde_json::json!({"x": 1, "y": 2}),
2361                ),
2362                MockStreamEvent::final_response_with_total_tokens(4),
2363            ],
2364            vec![
2365                MockStreamEvent::text("should not be requested"),
2366                MockStreamEvent::final_response_with_total_tokens(6),
2367            ],
2368        ]);
2369        let recorded = model.clone();
2370        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2371
2372        let mut stream = agent
2373            .stream_prompt("use the tool")
2374            .with_hook(PanicOnUnknownToolHook)
2375            .multi_turn(3)
2376            .await;
2377        let mut saw_tool_call = false;
2378        let mut error = None;
2379
2380        while let Some(item) = stream.next().await {
2381            match item {
2382                Ok(MultiTurnStreamItem::StreamAssistantItem(
2383                    StreamedAssistantContent::ToolCall { .. },
2384                )) => {
2385                    saw_tool_call = true;
2386                }
2387                Ok(_) => {}
2388                Err(err) => {
2389                    error = Some(err);
2390                    break;
2391                }
2392            }
2393        }
2394
2395        assert!(!saw_tool_call);
2396        let error = error.expect("unknown model-emitted tool should fail");
2397        match error {
2398            StreamingError::Prompt(err) => match *err {
2399                PromptError::UnknownToolCall {
2400                    tool_name,
2401                    available_tools,
2402                    allowed_tools,
2403                    chat_history,
2404                } => {
2405                    assert_eq!(tool_name, "default_api");
2406                    assert_eq!(available_tools, vec!["add".to_string()]);
2407                    assert_eq!(allowed_tools, vec!["add".to_string()]);
2408                    assert!(history_contains_tool_call(&chat_history, "default_api"));
2409                }
2410                other => panic!("expected UnknownToolCall, got {other:?}"),
2411            },
2412            other => panic!("expected prompt streaming error, got {other:?}"),
2413        }
2414        assert_eq!(recorded.request_count(), 1);
2415    }
2416
2417    #[tokio::test]
2418    async fn invalid_tool_call_hook_can_repair_streaming_tool_name() {
2419        let model = MockCompletionModel::from_stream_turns([
2420            vec![
2421                MockStreamEvent::tool_call(
2422                    "tool_call_1",
2423                    "default_api",
2424                    serde_json::json!({"x": 2, "y": 3}),
2425                ),
2426                MockStreamEvent::final_response_with_total_tokens(4),
2427            ],
2428            vec![
2429                MockStreamEvent::text("done"),
2430                MockStreamEvent::final_response_with_total_tokens(6),
2431            ],
2432        ]);
2433        let recorded = model.clone();
2434        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2435
2436        let mut stream = agent
2437            .stream_prompt("use the tool")
2438            .with_hook(RepairDefaultApiHook)
2439            .multi_turn(3)
2440            .with_history(Vec::<Message>::new())
2441            .await;
2442        let mut saw_repaired_tool_call = false;
2443        let mut saw_tool_result = false;
2444        let mut final_response_text = None;
2445
2446        while let Some(item) = stream.next().await {
2447            match item {
2448                Ok(MultiTurnStreamItem::StreamAssistantItem(
2449                    StreamedAssistantContent::ToolCall { tool_call, .. },
2450                )) => {
2451                    assert_eq!(tool_call.function.name, "add");
2452                    saw_repaired_tool_call = true;
2453                }
2454                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2455                    tool_result,
2456                    ..
2457                })) => {
2458                    assert!(tool_result.content.iter().any(|content| {
2459                        matches!(
2460                            content,
2461                            ToolResultContent::Text(text) if text.text == "5"
2462                        )
2463                    }));
2464                    saw_tool_result = true;
2465                }
2466                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2467                    final_response_text = Some(response.response().to_string());
2468                    break;
2469                }
2470                Ok(_) => {}
2471                Err(err) => panic!("unexpected streaming error: {err:?}"),
2472            }
2473        }
2474
2475        assert!(saw_repaired_tool_call);
2476        assert!(saw_tool_result);
2477        assert_eq!(final_response_text.as_deref(), Some("done"));
2478        assert_eq!(recorded.request_count(), 2);
2479    }
2480
2481    #[tokio::test]
2482    async fn invalid_tool_call_context_uses_completed_streaming_tool_call_provider_id() {
2483        let invalid_hook = RecordingInvalidToolCallHook::default();
2484        let model = MockCompletionModel::from_stream_turns([
2485            vec![
2486                MockStreamEvent::tool_call(
2487                    "tool_call_1",
2488                    "default_api",
2489                    serde_json::json!({"x": 2, "y": 3}),
2490                )
2491                .with_call_id("provider_call_1"),
2492                MockStreamEvent::final_response_with_total_tokens(4),
2493            ],
2494            vec![
2495                MockStreamEvent::text("should not be requested"),
2496                MockStreamEvent::final_response_with_total_tokens(6),
2497            ],
2498        ]);
2499        let recorded = model.clone();
2500        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2501
2502        let mut stream = agent
2503            .stream_prompt("use the tool")
2504            .with_hook(invalid_hook.clone())
2505            .multi_turn(3)
2506            .await;
2507        let mut error = None;
2508
2509        while let Some(item) = stream.next().await {
2510            if let Err(err) = item {
2511                error = Some(err);
2512                break;
2513            }
2514        }
2515
2516        assert!(error.is_some(), "invalid tool should fail");
2517        assert_eq!(recorded.request_count(), 1);
2518        let contexts = invalid_hook.observed();
2519        assert_eq!(contexts.len(), 1);
2520        let context = &contexts[0];
2521        assert_eq!(context.tool_name, "default_api");
2522        assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
2523        assert!(context.internal_call_id.is_some());
2524        assert!(context.is_streaming);
2525    }
2526
2527    #[tokio::test]
2528    async fn invalid_tool_call_hook_skip_emits_streaming_tool_result() {
2529        let add_calls = Arc::new(AtomicU32::new(0));
2530        let model = MockCompletionModel::from_stream_turns([
2531            vec![
2532                MockStreamEvent::tool_call(
2533                    "tool_call_1",
2534                    "default_api",
2535                    serde_json::json!({"x": 2, "y": 3}),
2536                )
2537                .with_call_id("call_1"),
2538                MockStreamEvent::final_response_with_total_tokens(4),
2539            ],
2540            vec![
2541                MockStreamEvent::text("continued"),
2542                MockStreamEvent::final_response_with_total_tokens(6),
2543            ],
2544        ]);
2545        let recorded = model.clone();
2546        let agent = AgentBuilder::new(model)
2547            .tool(CountingAddTool {
2548                calls: add_calls.clone(),
2549            })
2550            .build();
2551
2552        let mut stream = agent
2553            .stream_prompt("use the tool")
2554            .with_hook(SkipDefaultApiHook)
2555            .multi_turn(3)
2556            .with_history(Vec::<Message>::new())
2557            .await;
2558        let mut skipped_tool_result = None;
2559        let mut final_response_text = None;
2560
2561        while let Some(item) = stream.next().await {
2562            match item {
2563                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2564                    tool_result,
2565                    internal_call_id,
2566                })) => {
2567                    assert!(!internal_call_id.is_empty());
2568                    skipped_tool_result = Some(tool_result);
2569                }
2570                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2571                    final_response_text = Some(response.response().to_string());
2572                    break;
2573                }
2574                Ok(_) => {}
2575                Err(err) => panic!("unexpected streaming error: {err:?}"),
2576            }
2577        }
2578
2579        let skipped_tool_result =
2580            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
2581        assert_eq!(skipped_tool_result.id, "tool_call_1");
2582        assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_1"));
2583        assert!(skipped_tool_result.content.iter().any(|content| matches!(
2584            content,
2585            ToolResultContent::Text(text) if text.text == "default_api was skipped"
2586        )));
2587        assert_eq!(final_response_text.as_deref(), Some("continued"));
2588        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
2589
2590        let requests = recorded.requests();
2591        assert_eq!(requests.len(), 2);
2592        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2593        assert!(matches!(
2594            follow_up_history.get(2),
2595            Some(Message::User { content })
2596                if content.iter().any(|item| matches!(
2597                    item,
2598                    UserContent::ToolResult(result)
2599                        if result.id == "tool_call_1"
2600                            && result.content.iter().any(|content| matches!(
2601                                content,
2602                                ToolResultContent::Text(text)
2603                                    if text.text == "default_api was skipped"
2604                            ))
2605                ))
2606        ));
2607    }
2608
2609    #[tokio::test]
2610    async fn invalid_tool_call_hook_retries_mixed_streaming_turn_without_executing_valid_call() {
2611        let add_calls = Arc::new(AtomicU32::new(0));
2612        let model = MockCompletionModel::from_stream_turns([
2613            vec![
2614                MockStreamEvent::text("checking "),
2615                MockStreamEvent::tool_call(
2616                    "tool_call_1",
2617                    "add",
2618                    serde_json::json!({"x": 2, "y": 3}),
2619                )
2620                .with_call_id("call_1"),
2621                MockStreamEvent::tool_call(
2622                    "tool_call_2",
2623                    "default_api",
2624                    serde_json::json!({"x": 4, "y": 5}),
2625                )
2626                .with_call_id("call_2"),
2627                MockStreamEvent::final_response_with_total_tokens(4),
2628            ],
2629            vec![
2630                MockStreamEvent::text("retried"),
2631                MockStreamEvent::final_response_with_total_tokens(6),
2632            ],
2633        ]);
2634        let recorded = model.clone();
2635        let agent = AgentBuilder::new(model)
2636            .tool(CountingAddTool {
2637                calls: add_calls.clone(),
2638            })
2639            .build();
2640
2641        let mut stream = agent
2642            .stream_prompt("use the tool")
2643            .with_hook(RetryDefaultApiHook)
2644            .multi_turn(3)
2645            .with_history(Vec::<Message>::new())
2646            .max_invalid_tool_call_retries(1)
2647            .await;
2648        let mut completion_call_events = Vec::new();
2649        let mut final_response_text = None;
2650        let mut final_response_usage = Usage::new();
2651        let mut final_completion_calls = Vec::new();
2652
2653        while let Some(item) = stream.next().await {
2654            match item {
2655                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
2656                    completion_call_events.push(completion_call);
2657                }
2658                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2659                    final_response_text = Some(response.response().to_string());
2660                    final_response_usage = response.usage();
2661                    final_completion_calls = response.completion_calls().to_vec();
2662                    break;
2663                }
2664                Ok(_) => {}
2665                Err(err) => panic!("unexpected streaming error: {err:?}"),
2666            }
2667        }
2668
2669        assert_eq!(final_response_text.as_deref(), Some("retried"));
2670        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
2671        let mut first_usage = Usage::new();
2672        first_usage.total_tokens = 4;
2673        let mut second_usage = Usage::new();
2674        second_usage.total_tokens = 6;
2675        let expected_completion_calls = vec![
2676            CompletionCall::new(0, first_usage),
2677            CompletionCall::new(1, second_usage),
2678        ];
2679        assert_eq!(completion_call_events, expected_completion_calls);
2680        assert_eq!(final_completion_calls, expected_completion_calls);
2681        assert_eq!(final_response_usage.total_tokens, 10);
2682
2683        let requests = recorded.requests();
2684        assert_eq!(requests.len(), 2);
2685        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2686        assert_eq!(retry_history.len(), 3);
2687        assert!(matches!(
2688            retry_history.get(1),
2689            Some(Message::Assistant { content, .. })
2690                if content.iter().any(|item| matches!(
2691                    item,
2692                    AssistantContent::Text(text) if text.text == "checking "
2693                ))
2694                    && content.iter().any(|item| matches!(
2695                        item,
2696                        AssistantContent::ToolCall(tool_call)
2697                            if tool_call.id == "tool_call_1"
2698                                && tool_call.function.name == "add"
2699                    ))
2700                    && content.iter().any(|item| matches!(
2701                        item,
2702                        AssistantContent::ToolCall(tool_call)
2703                            if tool_call.id == "tool_call_2"
2704                                && tool_call.function.name == "default_api"
2705                    ))
2706        ));
2707        assert!(matches!(
2708            retry_history.get(2),
2709            Some(Message::User { content })
2710                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
2711                    && content.iter().any(|item| matches!(
2712                        item,
2713                        UserContent::ToolResult(result)
2714                            if result.id == "tool_call_1"
2715                                && result.content.iter().any(|content| matches!(
2716                                    content,
2717                                    ToolResultContent::Text(text)
2718                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
2719                                ))
2720                    ))
2721                    && content.iter().any(|item| matches!(
2722                        item,
2723                        UserContent::ToolResult(result)
2724                            if result.id == "tool_call_2"
2725                                && result.content.iter().any(|content| matches!(
2726                                    content,
2727                                    ToolResultContent::Text(text)
2728                                        if text.text == "Use the add tool instead"
2729                                ))
2730                    ))
2731        ));
2732    }
2733
2734    #[tokio::test]
2735    async fn invalid_tool_call_hook_skips_mixed_streaming_turn_without_executing_valid_call() {
2736        let add_calls = Arc::new(AtomicU32::new(0));
2737        let model = MockCompletionModel::from_stream_turns([
2738            vec![
2739                MockStreamEvent::text("checking "),
2740                MockStreamEvent::tool_call(
2741                    "tool_call_1",
2742                    "add",
2743                    serde_json::json!({"x": 2, "y": 3}),
2744                )
2745                .with_call_id("call_1"),
2746                MockStreamEvent::tool_call(
2747                    "tool_call_2",
2748                    "default_api",
2749                    serde_json::json!({"x": 4, "y": 5}),
2750                )
2751                .with_call_id("call_2"),
2752                MockStreamEvent::final_response_with_total_tokens(4),
2753            ],
2754            vec![
2755                MockStreamEvent::text("continued"),
2756                MockStreamEvent::final_response_with_total_tokens(6),
2757            ],
2758        ]);
2759        let recorded = model.clone();
2760        let agent = AgentBuilder::new(model)
2761            .tool(CountingAddTool {
2762                calls: add_calls.clone(),
2763            })
2764            .build();
2765
2766        let mut stream = agent
2767            .stream_prompt("use the tool")
2768            .with_hook(SkipDefaultApiHook)
2769            .multi_turn(3)
2770            .with_history(Vec::<Message>::new())
2771            .await;
2772        let mut skipped_tool_result = None;
2773        let mut final_response_text = None;
2774
2775        while let Some(item) = stream.next().await {
2776            match item {
2777                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
2778                    tool_result,
2779                    ..
2780                })) => {
2781                    skipped_tool_result = Some(tool_result);
2782                }
2783                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
2784                    final_response_text = Some(response.response().to_string());
2785                    break;
2786                }
2787                Ok(_) => {}
2788                Err(err) => panic!("unexpected streaming error: {err:?}"),
2789            }
2790        }
2791
2792        let skipped_tool_result =
2793            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
2794        assert_eq!(skipped_tool_result.id, "tool_call_2");
2795        assert_eq!(skipped_tool_result.call_id.as_deref(), Some("call_2"));
2796        assert_eq!(final_response_text.as_deref(), Some("continued"));
2797        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
2798
2799        let requests = recorded.requests();
2800        assert_eq!(requests.len(), 2);
2801        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2802        assert_eq!(follow_up_history.len(), 3);
2803        assert!(matches!(
2804            follow_up_history.get(1),
2805            Some(Message::Assistant { content, .. })
2806                if content.iter().any(|item| matches!(
2807                    item,
2808                    AssistantContent::Text(text) if text.text == "checking "
2809                ))
2810                    && content.iter().any(|item| matches!(
2811                        item,
2812                        AssistantContent::ToolCall(tool_call)
2813                            if tool_call.id == "tool_call_1"
2814                                && tool_call.function.name == "add"
2815                    ))
2816                    && content.iter().any(|item| matches!(
2817                        item,
2818                        AssistantContent::ToolCall(tool_call)
2819                            if tool_call.id == "tool_call_2"
2820                                && tool_call.function.name == "default_api"
2821                    ))
2822        ));
2823        assert!(matches!(
2824            follow_up_history.get(2),
2825            Some(Message::User { content })
2826                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
2827                    && content.iter().any(|item| matches!(
2828                        item,
2829                        UserContent::ToolResult(result)
2830                            if result.id == "tool_call_1"
2831                                && result.call_id.as_deref() == Some("call_1")
2832                                && result.content.iter().any(|content| matches!(
2833                                    content,
2834                                    ToolResultContent::Text(text)
2835                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
2836                                ))
2837                    ))
2838                    && content.iter().any(|item| matches!(
2839                        item,
2840                        UserContent::ToolResult(result)
2841                            if result.id == "tool_call_2"
2842                                && result.call_id.as_deref() == Some("call_2")
2843                                && result.content.iter().any(|content| matches!(
2844                                    content,
2845                                    ToolResultContent::Text(text)
2846                                        if text.text == "default_api was skipped"
2847                                ))
2848            ))
2849        ));
2850    }
2851
2852    #[tokio::test]
2853    async fn invalid_completed_tool_call_skip_preserves_streaming_reasoning_history() {
2854        let model = MockCompletionModel::from_stream_turns([
2855            vec![
2856                MockStreamEvent::text("checking "),
2857                MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
2858                MockStreamEvent::tool_call(
2859                    "tool_call_1",
2860                    "default_api",
2861                    serde_json::json!({"x": 2, "y": 3}),
2862                ),
2863                MockStreamEvent::final_response_with_total_tokens(4),
2864            ],
2865            vec![
2866                MockStreamEvent::text("continued"),
2867                MockStreamEvent::final_response_with_total_tokens(6),
2868            ],
2869        ]);
2870        let recorded = model.clone();
2871        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2872
2873        let mut stream = agent
2874            .stream_prompt("use the tool")
2875            .with_hook(SkipDefaultApiHook)
2876            .multi_turn(3)
2877            .with_history(Vec::<Message>::new())
2878            .await;
2879
2880        while let Some(item) = stream.next().await {
2881            match item {
2882                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
2883                Ok(_) => {}
2884                Err(err) => panic!("unexpected streaming error: {err:?}"),
2885            }
2886        }
2887
2888        let requests = recorded.requests();
2889        assert_eq!(requests.len(), 2);
2890        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2891        assert!(history_contains_text(&follow_up_history, "checking "));
2892        assert!(assistant_reasoning_precedes_tool_call(
2893            &follow_up_history,
2894            "reasoned step",
2895            "default_api"
2896        ));
2897        assert!(
2898            assistant_reasoning_precedes_text_and_tool_call(
2899                &follow_up_history,
2900                "reasoned step",
2901                "checking ",
2902                "default_api"
2903            ),
2904            "{follow_up_history:?}"
2905        );
2906    }
2907
2908    #[tokio::test]
2909    async fn invalid_name_delta_retry_preserves_streaming_reasoning_history() {
2910        let model = MockCompletionModel::from_stream_turns([
2911            vec![
2912                MockStreamEvent::reasoning_delta(Some("rs_1"), "delta reason"),
2913                MockStreamEvent::tool_call_arguments_delta(
2914                    "tool_call_1",
2915                    "internal_1",
2916                    r#"{"x":2,"y":3}"#,
2917                ),
2918                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
2919                MockStreamEvent::final_response_with_total_tokens(4),
2920            ],
2921            vec![
2922                MockStreamEvent::text("retried"),
2923                MockStreamEvent::final_response_with_total_tokens(6),
2924            ],
2925        ]);
2926        let recorded = model.clone();
2927        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2928
2929        let mut stream = agent
2930            .stream_prompt("use the tool")
2931            .with_hook(RetryDefaultApiHook)
2932            .multi_turn(3)
2933            .with_history(Vec::<Message>::new())
2934            .max_invalid_tool_call_retries(1)
2935            .await;
2936
2937        while let Some(item) = stream.next().await {
2938            match item {
2939                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
2940                Ok(_) => {}
2941                Err(err) => panic!("unexpected streaming error: {err:?}"),
2942            }
2943        }
2944
2945        let requests = recorded.requests();
2946        assert_eq!(requests.len(), 2);
2947        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
2948        assert!(assistant_reasoning_precedes_tool_call(
2949            &retry_history,
2950            "delta reason",
2951            "default_api"
2952        ));
2953    }
2954
2955    #[tokio::test]
2956    async fn invalid_tool_call_hook_skip_resets_streaming_text_delta_state() {
2957        let text_hook = RecordingTextDeltaHook::default();
2958        let model = MockCompletionModel::from_stream_turns([
2959            vec![
2960                MockStreamEvent::text("stale "),
2961                MockStreamEvent::tool_call(
2962                    "tool_call_1",
2963                    "default_api",
2964                    serde_json::json!({"x": 2, "y": 3}),
2965                ),
2966                MockStreamEvent::final_response_with_total_tokens(4),
2967            ],
2968            vec![
2969                MockStreamEvent::text("fresh"),
2970                MockStreamEvent::final_response_with_total_tokens(6),
2971            ],
2972        ]);
2973        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2974
2975        let mut stream = agent
2976            .stream_prompt("use the tool")
2977            .with_hook(RecordingTextAndSkipInvalidToolHook {
2978                text: text_hook.clone(),
2979            })
2980            .multi_turn(3)
2981            .with_history(Vec::<Message>::new())
2982            .await;
2983
2984        while let Some(item) = stream.next().await {
2985            match item {
2986                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
2987                Ok(_) => {}
2988                Err(err) => panic!("unexpected streaming error: {err:?}"),
2989            }
2990        }
2991
2992        assert_eq!(
2993            text_hook.observed(),
2994            vec![
2995                ("stale ".to_string(), "stale ".to_string()),
2996                ("fresh".to_string(), "fresh".to_string()),
2997            ]
2998        );
2999    }
3000
3001    #[tokio::test]
3002    async fn invalid_tool_call_delta_retry_uses_structured_tool_feedback() {
3003        let delta_hook = RecordingToolCallDeltaHook::default();
3004        let add_calls = Arc::new(AtomicU32::new(0));
3005        let model = MockCompletionModel::from_stream_turns([
3006            vec![
3007                MockStreamEvent::text("checking "),
3008                MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3009                MockStreamEvent::tool_call(
3010                    "tool_call_0",
3011                    "add",
3012                    serde_json::json!({"x": 1, "y": 2}),
3013                )
3014                .with_call_id("call_0"),
3015                MockStreamEvent::tool_call_arguments_delta(
3016                    "tool_call_1",
3017                    "internal_1",
3018                    r#"{"x":2,"y":3}"#,
3019                ),
3020                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3021                MockStreamEvent::final_response_with_total_tokens(4),
3022            ],
3023            vec![
3024                MockStreamEvent::text("retried"),
3025                MockStreamEvent::final_response_with_total_tokens(6),
3026            ],
3027        ]);
3028        let recorded = model.clone();
3029        let agent = AgentBuilder::new(model)
3030            .tool(CountingAddTool {
3031                calls: add_calls.clone(),
3032            })
3033            .build();
3034
3035        let mut stream = agent
3036            .stream_prompt("use the tool")
3037            .with_hook(RecordingDeltaAndRetryInvalidToolHook {
3038                delta: delta_hook.clone(),
3039            })
3040            .multi_turn(3)
3041            .with_history(Vec::<Message>::new())
3042            .max_invalid_tool_call_retries(1)
3043            .await;
3044        let mut completion_call_events = Vec::new();
3045        let mut final_response_text = None;
3046        let mut final_response_usage = Usage::new();
3047        let mut final_completion_calls = Vec::new();
3048
3049        while let Some(item) = stream.next().await {
3050            match item {
3051                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
3052                    completion_call_events.push(completion_call);
3053                }
3054                Ok(MultiTurnStreamItem::StreamAssistantItem(
3055                    StreamedAssistantContent::ToolCallDelta { .. },
3056                )) => panic!("invalid tool-call delta should not be emitted"),
3057                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3058                    final_response_text = Some(response.response().to_string());
3059                    final_response_usage = response.usage();
3060                    final_completion_calls = response.completion_calls().to_vec();
3061                    break;
3062                }
3063                Ok(_) => {}
3064                Err(err) => panic!("unexpected streaming error: {err:?}"),
3065            }
3066        }
3067
3068        assert_eq!(final_response_text.as_deref(), Some("retried"));
3069        assert!(delta_hook.observed().is_empty());
3070        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3071        let mut first_usage = Usage::new();
3072        first_usage.total_tokens = 4;
3073        let mut second_usage = Usage::new();
3074        second_usage.total_tokens = 6;
3075        let expected_completion_calls = vec![
3076            CompletionCall::new(0, first_usage),
3077            CompletionCall::new(1, second_usage),
3078        ];
3079        assert_eq!(completion_call_events, expected_completion_calls);
3080        assert_eq!(final_completion_calls, expected_completion_calls);
3081        assert_eq!(final_response_usage.total_tokens, 10);
3082
3083        let requests = recorded.requests();
3084        assert_eq!(requests.len(), 2);
3085        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3086        assert!(matches!(
3087            retry_history.get(1),
3088            Some(Message::Assistant { content, .. })
3089                if content.iter().any(|item| matches!(
3090                    item,
3091                    AssistantContent::Text(text) if text.text == "checking "
3092                ))
3093                    && content.iter().any(|item| matches!(
3094                        item,
3095                        AssistantContent::ToolCall(tool_call)
3096                            if tool_call.id == "tool_call_0"
3097                                && tool_call.function.name == "add"
3098                    ))
3099                    && content.iter().any(|item| matches!(
3100                    item,
3101                    AssistantContent::ToolCall(tool_call)
3102                        if tool_call.id == "tool_call_1"
3103                            && tool_call.function.name == "default_api"
3104                            && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3105                ))
3106        ));
3107        assert!(matches!(
3108            retry_history.get(2),
3109            Some(Message::User { content })
3110                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3111                    && content.iter().any(|item| matches!(
3112                        item,
3113                        UserContent::ToolResult(result)
3114                            if result.id == "tool_call_0"
3115                                && result.call_id.as_deref() == Some("call_0")
3116                                && result.content.iter().any(|content| matches!(
3117                                    content,
3118                                    ToolResultContent::Text(text)
3119                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3120                                ))
3121                    ))
3122                    && content.iter().any(|item| matches!(
3123                    item,
3124                    UserContent::ToolResult(result)
3125                        if result.id == "tool_call_1"
3126                            && result.content.iter().any(|content| matches!(
3127                                content,
3128                                ToolResultContent::Text(text)
3129                                    if text.text == "Use the add tool instead"
3130                            ))
3131                ))
3132        ));
3133    }
3134
3135    #[tokio::test]
3136    async fn invalid_tool_call_delta_context_includes_same_turn_history_and_tool_call_id() {
3137        let invalid_hook = RecordingInvalidToolCallHook::default();
3138        let model = MockCompletionModel::from_stream_turns([
3139            vec![
3140                MockStreamEvent::text("checking "),
3141                MockStreamEvent::reasoning_delta(Some("rs_1"), "diagnostic reason"),
3142                MockStreamEvent::tool_call(
3143                    "tool_call_0",
3144                    "add",
3145                    serde_json::json!({"x": 1, "y": 2}),
3146                )
3147                .with_call_id("call_0"),
3148                MockStreamEvent::tool_call_arguments_delta(
3149                    "tool_call_1",
3150                    "internal_1",
3151                    r#"{"x":2,"y":3}"#,
3152                ),
3153                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3154                MockStreamEvent::final_response_with_total_tokens(4),
3155            ],
3156            vec![
3157                MockStreamEvent::text("should not be requested"),
3158                MockStreamEvent::final_response_with_total_tokens(6),
3159            ],
3160        ]);
3161        let recorded = model.clone();
3162        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3163
3164        let mut stream = agent
3165            .stream_prompt("use the tool")
3166            .with_hook(invalid_hook.clone())
3167            .multi_turn(3)
3168            .await;
3169        let mut error = None;
3170
3171        while let Some(item) = stream.next().await {
3172            if let Err(err) = item {
3173                error = Some(err);
3174                break;
3175            }
3176        }
3177
3178        assert!(error.is_some(), "invalid name delta should fail");
3179        assert_eq!(recorded.request_count(), 1);
3180        let contexts = invalid_hook.observed();
3181        assert_eq!(contexts.len(), 1);
3182        let context = &contexts[0];
3183        assert_eq!(context.tool_name, "default_api");
3184        assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
3185        assert_eq!(context.internal_call_id.as_deref(), Some("internal_1"));
3186        assert!(context.is_streaming);
3187        assert!(history_contains_text(&context.chat_history, "checking "));
3188        assert!(
3189            assistant_reasoning_precedes_tool_call(
3190                &context.chat_history,
3191                "diagnostic reason",
3192                "add"
3193            ),
3194            "{:?}",
3195            context.chat_history
3196        );
3197        assert!(history_contains_tool_call(&context.chat_history, "add"));
3198        assert!(history_contains_tool_call(
3199            &context.chat_history,
3200            "default_api"
3201        ));
3202    }
3203
3204    #[tokio::test]
3205    async fn invalid_tool_call_delta_retry_resets_streaming_text_delta_state() {
3206        let text_hook = RecordingTextDeltaHook::default();
3207        let model = MockCompletionModel::from_stream_turns([
3208            vec![
3209                MockStreamEvent::text("stale "),
3210                MockStreamEvent::tool_call_arguments_delta(
3211                    "tool_call_1",
3212                    "internal_1",
3213                    r#"{"x":2,"y":3}"#,
3214                ),
3215                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3216                MockStreamEvent::final_response_with_total_tokens(4),
3217            ],
3218            vec![
3219                MockStreamEvent::text("fresh"),
3220                MockStreamEvent::final_response_with_total_tokens(6),
3221            ],
3222        ]);
3223        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3224
3225        let mut stream = agent
3226            .stream_prompt("use the tool")
3227            .with_hook(RecordingTextAndRetryInvalidToolHook {
3228                text: text_hook.clone(),
3229            })
3230            .multi_turn(3)
3231            .with_history(Vec::<Message>::new())
3232            .max_invalid_tool_call_retries(1)
3233            .await;
3234
3235        while let Some(item) = stream.next().await {
3236            match item {
3237                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
3238                Ok(_) => {}
3239                Err(err) => panic!("unexpected streaming error: {err:?}"),
3240            }
3241        }
3242
3243        assert_eq!(
3244            text_hook.observed(),
3245            vec![
3246                ("stale ".to_string(), "stale ".to_string()),
3247                ("fresh".to_string(), "fresh".to_string()),
3248            ]
3249        );
3250    }
3251
3252    #[tokio::test]
3253    async fn invalid_tool_call_delta_skip_uses_structured_tool_feedback() {
3254        let delta_hook = RecordingToolCallDeltaHook::default();
3255        let add_calls = Arc::new(AtomicU32::new(0));
3256        let model = MockCompletionModel::from_stream_turns([
3257            vec![
3258                MockStreamEvent::text("checking "),
3259                MockStreamEvent::tool_call(
3260                    "tool_call_0",
3261                    "add",
3262                    serde_json::json!({"x": 1, "y": 2}),
3263                )
3264                .with_call_id("call_0"),
3265                MockStreamEvent::tool_call_arguments_delta(
3266                    "tool_call_1",
3267                    "internal_1",
3268                    r#"{"x":2,"y":3}"#,
3269                ),
3270                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3271                MockStreamEvent::final_response_with_total_tokens(4),
3272            ],
3273            vec![
3274                MockStreamEvent::text("continued"),
3275                MockStreamEvent::final_response_with_total_tokens(6),
3276            ],
3277        ]);
3278        let recorded = model.clone();
3279        let agent = AgentBuilder::new(model)
3280            .tool(CountingAddTool {
3281                calls: add_calls.clone(),
3282            })
3283            .build();
3284
3285        let mut stream = agent
3286            .stream_prompt("use the tool")
3287            .with_hook(RecordingDeltaAndSkipInvalidToolHook {
3288                delta: delta_hook.clone(),
3289            })
3290            .multi_turn(3)
3291            .with_history(Vec::<Message>::new())
3292            .await;
3293        let mut skipped_tool_result = None;
3294        let mut final_response_text = None;
3295
3296        while let Some(item) = stream.next().await {
3297            match item {
3298                Ok(MultiTurnStreamItem::StreamAssistantItem(
3299                    StreamedAssistantContent::ToolCallDelta { .. },
3300                )) => panic!("invalid tool-call delta should not be emitted"),
3301                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3302                    tool_result,
3303                    internal_call_id,
3304                })) => {
3305                    assert_eq!(internal_call_id, "internal_1");
3306                    skipped_tool_result = Some(tool_result);
3307                }
3308                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3309                    final_response_text = Some(response.response().to_string());
3310                    break;
3311                }
3312                Ok(_) => {}
3313                Err(err) => panic!("unexpected streaming error: {err:?}"),
3314            }
3315        }
3316
3317        let skipped_tool_result =
3318            skipped_tool_result.expect("skip recovery should emit a synthetic tool result");
3319        assert_eq!(skipped_tool_result.id, "tool_call_1");
3320        assert!(skipped_tool_result.call_id.is_none());
3321        assert!(skipped_tool_result.content.iter().any(|content| matches!(
3322            content,
3323            ToolResultContent::Text(text) if text.text == "default_api was skipped"
3324        )));
3325        assert_eq!(final_response_text.as_deref(), Some("continued"));
3326        assert!(delta_hook.observed().is_empty());
3327        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3328
3329        let requests = recorded.requests();
3330        assert_eq!(requests.len(), 2);
3331        let follow_up_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
3332        assert!(matches!(
3333            follow_up_history.get(1),
3334            Some(Message::Assistant { content, .. })
3335                if content.iter().any(|item| matches!(
3336                    item,
3337                    AssistantContent::Text(text) if text.text == "checking "
3338                ))
3339                    && content.iter().any(|item| matches!(
3340                        item,
3341                        AssistantContent::ToolCall(tool_call)
3342                            if tool_call.id == "tool_call_0"
3343                                && tool_call.function.name == "add"
3344                    ))
3345                    && content.iter().any(|item| matches!(
3346                    item,
3347                    AssistantContent::ToolCall(tool_call)
3348                        if tool_call.id == "tool_call_1"
3349                            && tool_call.function.name == "default_api"
3350                            && tool_call.function.arguments == serde_json::json!({"x": 2, "y": 3})
3351                ))
3352        ));
3353        assert!(matches!(
3354            follow_up_history.get(2),
3355            Some(Message::User { content })
3356                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
3357                    && content.iter().any(|item| matches!(
3358                        item,
3359                        UserContent::ToolResult(result)
3360                            if result.id == "tool_call_0"
3361                                && result.call_id.as_deref() == Some("call_0")
3362                                && result.content.iter().any(|content| matches!(
3363                                    content,
3364                                    ToolResultContent::Text(text)
3365                                        if text.text == TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
3366                                ))
3367                    ))
3368                    && content.iter().any(|item| matches!(
3369                    item,
3370                    UserContent::ToolResult(result)
3371                        if result.id == "tool_call_1"
3372                            && result.content.iter().any(|content| matches!(
3373                                content,
3374                                ToolResultContent::Text(text)
3375                                    if text.text == "default_api was skipped"
3376                            ))
3377                ))
3378        ));
3379    }
3380
3381    #[tokio::test]
3382    async fn streaming_retry_budget_exhaustion_history_contains_invalid_tool_call() {
3383        let model = MockCompletionModel::from_stream_turns([
3384            vec![
3385                MockStreamEvent::tool_call(
3386                    "tool_call_1",
3387                    "default_api",
3388                    serde_json::json!({"x": 1, "y": 2}),
3389                ),
3390                MockStreamEvent::final_response_with_total_tokens(4),
3391            ],
3392            vec![
3393                MockStreamEvent::text("should not be requested"),
3394                MockStreamEvent::final_response_with_total_tokens(6),
3395            ],
3396        ]);
3397        let recorded = model.clone();
3398        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3399
3400        let mut stream = agent
3401            .stream_prompt("use the tool")
3402            .with_hook(RetryDefaultApiHook)
3403            .multi_turn(3)
3404            .max_invalid_tool_call_retries(0)
3405            .await;
3406        let mut error = None;
3407
3408        while let Some(item) = stream.next().await {
3409            if let Err(err) = item {
3410                error = Some(err);
3411                break;
3412            }
3413        }
3414
3415        let error = error.expect("retry budget exhaustion should fail");
3416        match error {
3417            StreamingError::Prompt(err) => match *err {
3418                PromptError::UnknownToolCall {
3419                    tool_name,
3420                    chat_history,
3421                    ..
3422                } => {
3423                    assert_eq!(tool_name, "default_api");
3424                    assert!(history_contains_tool_call(&chat_history, "default_api"));
3425                }
3426                other => panic!("expected UnknownToolCall, got {other:?}"),
3427            },
3428            other => panic!("expected prompt streaming error, got {other:?}"),
3429        }
3430        assert_eq!(recorded.request_count(), 1);
3431    }
3432
3433    #[tokio::test]
3434    async fn streaming_name_delta_retry_budget_exhaustion_history_includes_same_turn_context() {
3435        let model = MockCompletionModel::from_stream_turns([
3436            vec![
3437                MockStreamEvent::text("checking "),
3438                MockStreamEvent::tool_call(
3439                    "tool_call_0",
3440                    "add",
3441                    serde_json::json!({"x": 1, "y": 2}),
3442                )
3443                .with_call_id("call_0"),
3444                MockStreamEvent::tool_call_arguments_delta(
3445                    "tool_call_1",
3446                    "internal_1",
3447                    r#"{"x":2,"y":3}"#,
3448                ),
3449                MockStreamEvent::tool_call_name_delta("tool_call_1", "internal_1", "default_api"),
3450                MockStreamEvent::final_response_with_total_tokens(4),
3451            ],
3452            vec![
3453                MockStreamEvent::text("should not be requested"),
3454                MockStreamEvent::final_response_with_total_tokens(6),
3455            ],
3456        ]);
3457        let recorded = model.clone();
3458        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
3459
3460        let mut stream = agent
3461            .stream_prompt("use the tool")
3462            .with_hook(RetryDefaultApiHook)
3463            .multi_turn(3)
3464            .max_invalid_tool_call_retries(0)
3465            .await;
3466        let mut error = None;
3467
3468        while let Some(item) = stream.next().await {
3469            if let Err(err) = item {
3470                error = Some(err);
3471                break;
3472            }
3473        }
3474
3475        let error = error.expect("retry budget exhaustion should fail");
3476        match error {
3477            StreamingError::Prompt(err) => match *err {
3478                PromptError::UnknownToolCall {
3479                    tool_name,
3480                    chat_history,
3481                    ..
3482                } => {
3483                    assert_eq!(tool_name, "default_api");
3484                    assert!(history_contains_text(&chat_history, "checking "));
3485                    assert!(history_contains_tool_call(&chat_history, "add"));
3486                    assert!(history_contains_tool_call(&chat_history, "default_api"));
3487                }
3488                other => panic!("expected UnknownToolCall, got {other:?}"),
3489            },
3490            other => panic!("expected prompt streaming error, got {other:?}"),
3491        }
3492        assert_eq!(recorded.request_count(), 1);
3493    }
3494
3495    #[tokio::test]
3496    async fn completed_unknown_tool_call_after_text_fails_before_finish_hook_or_later_emit() {
3497        let add_calls = Arc::new(AtomicU32::new(0));
3498        let model = MockCompletionModel::from_stream_turns([
3499            vec![
3500                MockStreamEvent::text("thinking "),
3501                MockStreamEvent::tool_call(
3502                    "tool_call_1",
3503                    "default_api",
3504                    serde_json::json!({"x": 1, "y": 2}),
3505                ),
3506                MockStreamEvent::final_response_with_total_tokens(4),
3507            ],
3508            vec![
3509                MockStreamEvent::text("should not be requested"),
3510                MockStreamEvent::final_response_with_total_tokens(6),
3511            ],
3512        ]);
3513        let recorded = model.clone();
3514        let agent = AgentBuilder::new(model)
3515            .tool(CountingAddTool {
3516                calls: add_calls.clone(),
3517            })
3518            .build();
3519
3520        let mut stream = agent
3521            .stream_prompt("use the tool")
3522            .with_hook(PanicOnUnknownToolHook)
3523            .multi_turn(3)
3524            .await;
3525        let mut saw_text = false;
3526        let mut saw_completion_call = false;
3527        let mut saw_final_response = false;
3528        let mut saw_tool_call = false;
3529        let mut saw_tool_result = false;
3530        let mut error = None;
3531
3532        while let Some(item) = stream.next().await {
3533            match item {
3534                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(_))) => {
3535                    saw_text = true;
3536                }
3537                Ok(MultiTurnStreamItem::CompletionCall(_)) => {
3538                    saw_completion_call = true;
3539                }
3540                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Final(
3541                    _,
3542                )))
3543                | Ok(MultiTurnStreamItem::FinalResponse(_)) => {
3544                    saw_final_response = true;
3545                }
3546                Ok(MultiTurnStreamItem::StreamAssistantItem(
3547                    StreamedAssistantContent::ToolCall { .. },
3548                )) => {
3549                    saw_tool_call = true;
3550                }
3551                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3552                    ..
3553                })) => {
3554                    saw_tool_result = true;
3555                }
3556                Ok(_) => {}
3557                Err(err) => {
3558                    error = Some(err);
3559                    break;
3560                }
3561            }
3562        }
3563
3564        assert!(saw_text);
3565        assert!(!saw_completion_call);
3566        assert!(!saw_final_response);
3567        assert!(!saw_tool_call);
3568        assert!(!saw_tool_result);
3569        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3570        let error = error.expect("completed unknown tool call should fail immediately");
3571        match error {
3572            StreamingError::Prompt(err) => match *err {
3573                PromptError::UnknownToolCall {
3574                    tool_name,
3575                    available_tools,
3576                    allowed_tools,
3577                    chat_history,
3578                } => {
3579                    assert_eq!(tool_name, "default_api");
3580                    assert_eq!(available_tools, vec!["add".to_string()]);
3581                    assert_eq!(allowed_tools, vec!["add".to_string()]);
3582                    assert!(history_contains_tool_call(&chat_history, "default_api"));
3583                }
3584                other => panic!("expected UnknownToolCall, got {other:?}"),
3585            },
3586            other => panic!("expected prompt streaming error, got {other:?}"),
3587        }
3588        assert_eq!(recorded.request_count(), 1);
3589    }
3590
3591    #[tokio::test]
3592    async fn mixed_streaming_tool_calls_fail_before_any_tool_execution() {
3593        let add_calls = Arc::new(AtomicU32::new(0));
3594        let model = MockCompletionModel::from_stream_turns([
3595            vec![
3596                MockStreamEvent::tool_call(
3597                    "tool_call_1",
3598                    "add",
3599                    serde_json::json!({"x": 1, "y": 2}),
3600                )
3601                .with_call_id("call_1"),
3602                MockStreamEvent::tool_call(
3603                    "tool_call_2",
3604                    "default_api",
3605                    serde_json::json!({"x": 3, "y": 4}),
3606                ),
3607                MockStreamEvent::final_response_with_total_tokens(4),
3608            ],
3609            vec![
3610                MockStreamEvent::text("should not be requested"),
3611                MockStreamEvent::final_response_with_total_tokens(6),
3612            ],
3613        ]);
3614        let recorded = model.clone();
3615        let agent = AgentBuilder::new(model)
3616            .tool(CountingAddTool {
3617                calls: add_calls.clone(),
3618            })
3619            .build();
3620
3621        let mut stream = agent
3622            .stream_prompt("use tools")
3623            .with_hook(PanicOnUnknownToolHook)
3624            .multi_turn(3)
3625            .await;
3626        let mut saw_completion_call = false;
3627        let mut saw_tool_call = false;
3628        let mut saw_tool_result = false;
3629        let mut error = None;
3630
3631        while let Some(item) = stream.next().await {
3632            match item {
3633                Ok(MultiTurnStreamItem::CompletionCall(_)) => {
3634                    saw_completion_call = true;
3635                }
3636                Ok(MultiTurnStreamItem::StreamAssistantItem(
3637                    StreamedAssistantContent::ToolCall { .. },
3638                )) => {
3639                    saw_tool_call = true;
3640                }
3641                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3642                    ..
3643                })) => {
3644                    saw_tool_result = true;
3645                }
3646                Ok(_) => {}
3647                Err(err) => {
3648                    error = Some(err);
3649                    break;
3650                }
3651            }
3652        }
3653
3654        assert!(!saw_completion_call);
3655        assert!(!saw_tool_call);
3656        assert!(!saw_tool_result);
3657        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3658        let error = error.expect("mixed unknown streamed tool call should fail");
3659        match error {
3660            StreamingError::Prompt(err) => match *err {
3661                PromptError::UnknownToolCall {
3662                    tool_name,
3663                    available_tools,
3664                    allowed_tools,
3665                    chat_history,
3666                } => {
3667                    assert_eq!(tool_name, "default_api");
3668                    assert_eq!(available_tools, vec!["add".to_string()]);
3669                    assert_eq!(allowed_tools, vec!["add".to_string()]);
3670                    assert!(history_contains_tool_call(&chat_history, "default_api"));
3671                }
3672                other => panic!("expected UnknownToolCall, got {other:?}"),
3673            },
3674            other => panic!("expected prompt streaming error, got {other:?}"),
3675        }
3676        assert_eq!(recorded.request_count(), 1);
3677    }
3678
3679    #[tokio::test]
3680    async fn multiple_valid_streaming_tool_calls_execute_after_batch_validation() {
3681        let add_calls = Arc::new(AtomicU32::new(0));
3682        let subtract_calls = Arc::new(AtomicU32::new(0));
3683        let model = MockCompletionModel::from_stream_turns([
3684            vec![
3685                MockStreamEvent::tool_call(
3686                    "tool_call_1",
3687                    "add",
3688                    serde_json::json!({"x": 1, "y": 2}),
3689                )
3690                .with_call_id("call_1"),
3691                MockStreamEvent::tool_call(
3692                    "tool_call_2",
3693                    "subtract",
3694                    serde_json::json!({"x": 8, "y": 3}),
3695                )
3696                .with_call_id("call_2"),
3697                MockStreamEvent::final_response_with_total_tokens(4),
3698            ],
3699            vec![
3700                MockStreamEvent::text("done"),
3701                MockStreamEvent::final_response_with_total_tokens(6),
3702            ],
3703        ]);
3704        let recorded = model.clone();
3705        let agent = AgentBuilder::new(model)
3706            .tool(CountingAddTool {
3707                calls: add_calls.clone(),
3708            })
3709            .tool(CountingSubtractTool {
3710                calls: subtract_calls.clone(),
3711            })
3712            .build();
3713
3714        let mut stream = agent.stream_prompt("use tools").multi_turn(3).await;
3715        let mut tool_call_names = Vec::new();
3716        let mut tool_result_ids = Vec::new();
3717        let mut final_response_text = None;
3718
3719        while let Some(item) = stream.next().await {
3720            match item {
3721                Ok(MultiTurnStreamItem::StreamAssistantItem(
3722                    StreamedAssistantContent::ToolCall { tool_call, .. },
3723                )) => {
3724                    tool_call_names.push(tool_call.function.name);
3725                }
3726                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3727                    tool_result,
3728                    ..
3729                })) => {
3730                    tool_result_ids.push(tool_result.id);
3731                }
3732                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
3733                    final_response_text = Some(response.response().to_owned());
3734                    break;
3735                }
3736                Ok(_) => {}
3737                Err(err) => panic!("unexpected streaming error: {err:?}"),
3738            }
3739        }
3740
3741        assert_eq!(
3742            tool_call_names,
3743            vec!["add".to_string(), "subtract".to_string()]
3744        );
3745        assert_eq!(
3746            tool_result_ids,
3747            vec!["tool_call_1".to_string(), "tool_call_2".to_string()]
3748        );
3749        assert_eq!(add_calls.load(Ordering::SeqCst), 1);
3750        assert_eq!(subtract_calls.load(Ordering::SeqCst), 1);
3751        assert_eq!(final_response_text.as_deref(), Some("done"));
3752        assert_eq!(recorded.request_count(), 2);
3753    }
3754
3755    #[tokio::test]
3756    async fn disallowed_specific_tool_call_fails_before_streaming_second_request() {
3757        let model = MockCompletionModel::from_stream_turns([
3758            vec![
3759                MockStreamEvent::tool_call(
3760                    "tool_call_1",
3761                    "subtract",
3762                    serde_json::json!({"x": 3, "y": 1}),
3763                ),
3764                MockStreamEvent::final_response_with_total_tokens(4),
3765            ],
3766            vec![
3767                MockStreamEvent::text("should not be requested"),
3768                MockStreamEvent::final_response_with_total_tokens(6),
3769            ],
3770        ]);
3771        let recorded = model.clone();
3772        let agent = AgentBuilder::new(model)
3773            .tool(MockAddTool)
3774            .tool(MockSubtractTool)
3775            .tool_choice(ToolChoice::Specific {
3776                function_names: vec!["add".to_string()],
3777            })
3778            .build();
3779
3780        let mut stream = agent
3781            .stream_prompt("use the allowed tool")
3782            .with_hook(PanicOnUnknownToolHook)
3783            .multi_turn(3)
3784            .await;
3785        let mut saw_tool_call = false;
3786        let mut error = None;
3787
3788        while let Some(item) = stream.next().await {
3789            match item {
3790                Ok(MultiTurnStreamItem::StreamAssistantItem(
3791                    StreamedAssistantContent::ToolCall { .. },
3792                )) => {
3793                    saw_tool_call = true;
3794                }
3795                Ok(_) => {}
3796                Err(err) => {
3797                    error = Some(err);
3798                    break;
3799                }
3800            }
3801        }
3802
3803        assert!(!saw_tool_call);
3804        let error = error.expect("disallowed model-emitted tool should fail");
3805        match error {
3806            StreamingError::Prompt(err) => match *err {
3807                PromptError::UnknownToolCall {
3808                    tool_name,
3809                    available_tools,
3810                    allowed_tools,
3811                    chat_history,
3812                } => {
3813                    assert_eq!(tool_name, "subtract");
3814                    assert_eq!(
3815                        available_tools,
3816                        vec!["add".to_string(), "subtract".to_string()]
3817                    );
3818                    assert_eq!(allowed_tools, vec!["add".to_string()]);
3819                    assert!(history_contains_tool_call(&chat_history, "subtract"));
3820                }
3821                other => panic!("expected UnknownToolCall, got {other:?}"),
3822            },
3823            other => panic!("expected prompt streaming error, got {other:?}"),
3824        }
3825        assert_eq!(recorded.request_count(), 1);
3826    }
3827
3828    #[tokio::test]
3829    async fn mixed_specific_tool_calls_fail_before_any_tool_execution() {
3830        let add_calls = Arc::new(AtomicU32::new(0));
3831        let model = MockCompletionModel::from_stream_turns([
3832            vec![
3833                MockStreamEvent::tool_call(
3834                    "tool_call_1",
3835                    "add",
3836                    serde_json::json!({"x": 1, "y": 2}),
3837                ),
3838                MockStreamEvent::tool_call(
3839                    "tool_call_2",
3840                    "subtract",
3841                    serde_json::json!({"x": 3, "y": 1}),
3842                ),
3843                MockStreamEvent::final_response_with_total_tokens(4),
3844            ],
3845            vec![
3846                MockStreamEvent::text("should not be requested"),
3847                MockStreamEvent::final_response_with_total_tokens(6),
3848            ],
3849        ]);
3850        let recorded = model.clone();
3851        let agent = AgentBuilder::new(model)
3852            .tool(CountingAddTool {
3853                calls: add_calls.clone(),
3854            })
3855            .tool(MockSubtractTool)
3856            .tool_choice(ToolChoice::Specific {
3857                function_names: vec!["add".to_string()],
3858            })
3859            .build();
3860
3861        let mut stream = agent
3862            .stream_prompt("use the allowed tool")
3863            .with_hook(PanicOnUnknownToolHook)
3864            .multi_turn(3)
3865            .await;
3866        let mut saw_tool_call = false;
3867        let mut saw_tool_result = false;
3868        let mut error = None;
3869
3870        while let Some(item) = stream.next().await {
3871            match item {
3872                Ok(MultiTurnStreamItem::StreamAssistantItem(
3873                    StreamedAssistantContent::ToolCall { .. },
3874                )) => {
3875                    saw_tool_call = true;
3876                }
3877                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
3878                    ..
3879                })) => {
3880                    saw_tool_result = true;
3881                }
3882                Ok(_) => {}
3883                Err(err) => {
3884                    error = Some(err);
3885                    break;
3886                }
3887            }
3888        }
3889
3890        assert!(!saw_tool_call);
3891        assert!(!saw_tool_result);
3892        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
3893        let error = error.expect("mixed disallowed streamed tool call should fail");
3894        match error {
3895            StreamingError::Prompt(err) => match *err {
3896                PromptError::UnknownToolCall {
3897                    tool_name,
3898                    available_tools,
3899                    allowed_tools,
3900                    chat_history,
3901                } => {
3902                    assert_eq!(tool_name, "subtract");
3903                    assert_eq!(
3904                        available_tools,
3905                        vec!["add".to_string(), "subtract".to_string()]
3906                    );
3907                    assert_eq!(allowed_tools, vec!["add".to_string()]);
3908                    assert!(history_contains_tool_call(&chat_history, "subtract"));
3909                }
3910                other => panic!("expected UnknownToolCall, got {other:?}"),
3911            },
3912            other => panic!("expected prompt streaming error, got {other:?}"),
3913        }
3914        assert_eq!(recorded.request_count(), 1);
3915    }
3916
3917    #[tokio::test]
3918    async fn tool_choice_none_rejects_streaming_tool_call() {
3919        let model = MockCompletionModel::from_stream_turns([
3920            vec![
3921                MockStreamEvent::tool_call(
3922                    "tool_call_1",
3923                    "add",
3924                    serde_json::json!({"x": 1, "y": 2}),
3925                ),
3926                MockStreamEvent::final_response_with_total_tokens(4),
3927            ],
3928            vec![
3929                MockStreamEvent::text("should not be requested"),
3930                MockStreamEvent::final_response_with_total_tokens(6),
3931            ],
3932        ]);
3933        let recorded = model.clone();
3934        let agent = AgentBuilder::new(model)
3935            .tool(MockAddTool)
3936            .tool_choice(ToolChoice::None)
3937            .build();
3938
3939        let mut stream = agent
3940            .stream_prompt("do not use tools")
3941            .with_hook(PanicOnUnknownToolHook)
3942            .multi_turn(3)
3943            .await;
3944        let mut saw_tool_call = false;
3945        let mut error = None;
3946
3947        while let Some(item) = stream.next().await {
3948            match item {
3949                Ok(MultiTurnStreamItem::StreamAssistantItem(
3950                    StreamedAssistantContent::ToolCall { .. },
3951                )) => {
3952                    saw_tool_call = true;
3953                }
3954                Ok(_) => {}
3955                Err(err) => {
3956                    error = Some(err);
3957                    break;
3958                }
3959            }
3960        }
3961
3962        assert!(!saw_tool_call);
3963        let error = error.expect("ToolChoice::None should reject returned tool calls");
3964        match error {
3965            StreamingError::Prompt(err) => match *err {
3966                PromptError::UnknownToolCall {
3967                    tool_name,
3968                    available_tools,
3969                    allowed_tools,
3970                    chat_history,
3971                } => {
3972                    assert_eq!(tool_name, "add");
3973                    assert_eq!(available_tools, vec!["add".to_string()]);
3974                    assert!(allowed_tools.is_empty());
3975                    assert!(history_contains_tool_call(&chat_history, "add"));
3976                }
3977                other => panic!("expected UnknownToolCall, got {other:?}"),
3978            },
3979            other => panic!("expected prompt streaming error, got {other:?}"),
3980        }
3981        assert_eq!(recorded.request_count(), 1);
3982    }
3983
3984    #[tokio::test]
3985    async fn tool_choice_none_rejects_streaming_tool_call_name_delta_before_hook_or_emit() {
3986        let model = MockCompletionModel::from_stream_turns([
3987            vec![
3988                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
3989                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
3990                MockStreamEvent::final_response_with_total_tokens(4),
3991            ],
3992            vec![
3993                MockStreamEvent::text("should not be requested"),
3994                MockStreamEvent::final_response_with_total_tokens(6),
3995            ],
3996        ]);
3997        let recorded = model.clone();
3998        let agent = AgentBuilder::new(model)
3999            .tool(MockAddTool)
4000            .tool_choice(ToolChoice::None)
4001            .build();
4002
4003        let mut stream = agent
4004            .stream_prompt("do not use tools")
4005            .with_hook(PanicOnUnknownToolHook)
4006            .multi_turn(3)
4007            .await;
4008        let mut saw_delta = false;
4009        let mut error = None;
4010
4011        while let Some(item) = stream.next().await {
4012            match item {
4013                Ok(MultiTurnStreamItem::StreamAssistantItem(
4014                    StreamedAssistantContent::ToolCallDelta { .. },
4015                )) => {
4016                    saw_delta = true;
4017                }
4018                Ok(_) => {}
4019                Err(err) => {
4020                    error = Some(err);
4021                    break;
4022                }
4023            }
4024        }
4025
4026        assert!(!saw_delta);
4027        let error = error.expect("ToolChoice::None should reject returned tool-call deltas");
4028        match error {
4029            StreamingError::Prompt(err) => match *err {
4030                PromptError::UnknownToolCall {
4031                    tool_name,
4032                    available_tools,
4033                    allowed_tools,
4034                    chat_history,
4035                } => {
4036                    assert_eq!(tool_name, "add");
4037                    assert_eq!(available_tools, vec!["add".to_string()]);
4038                    assert!(allowed_tools.is_empty());
4039                    assert!(history_contains_tool_call(&chat_history, "add"));
4040                }
4041                other => panic!("expected UnknownToolCall, got {other:?}"),
4042            },
4043            other => panic!("expected prompt streaming error, got {other:?}"),
4044        }
4045        assert_eq!(recorded.request_count(), 1);
4046    }
4047
4048    #[tokio::test]
4049    async fn unknown_tool_call_name_delta_fails_before_streaming_delta_hook_or_emit() {
4050        let model = MockCompletionModel::from_stream_turns([
4051            vec![
4052                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4053                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4054                MockStreamEvent::final_response_with_total_tokens(4),
4055            ],
4056            vec![
4057                MockStreamEvent::text("should not be requested"),
4058                MockStreamEvent::final_response_with_total_tokens(6),
4059            ],
4060        ]);
4061        let recorded = model.clone();
4062        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4063
4064        let mut stream = agent
4065            .stream_prompt("stream a bad tool call")
4066            .with_hook(PanicOnUnknownToolHook)
4067            .multi_turn(3)
4068            .await;
4069        let mut saw_delta = false;
4070        let mut error = None;
4071
4072        while let Some(item) = stream.next().await {
4073            match item {
4074                Ok(MultiTurnStreamItem::StreamAssistantItem(
4075                    StreamedAssistantContent::ToolCallDelta { .. },
4076                )) => {
4077                    saw_delta = true;
4078                }
4079                Ok(_) => {}
4080                Err(err) => {
4081                    error = Some(err);
4082                    break;
4083                }
4084            }
4085        }
4086
4087        assert!(!saw_delta);
4088        let error = error.expect("unknown tool-call name delta should fail");
4089        match error {
4090            StreamingError::Prompt(err) => match *err {
4091                PromptError::UnknownToolCall {
4092                    tool_name,
4093                    available_tools,
4094                    allowed_tools,
4095                    chat_history,
4096                } => {
4097                    assert_eq!(tool_name, "default_api");
4098                    assert_eq!(available_tools, vec!["add".to_string()]);
4099                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4100                    assert!(history_contains_tool_call(&chat_history, "default_api"));
4101                }
4102                other => panic!("expected UnknownToolCall, got {other:?}"),
4103            },
4104            other => panic!("expected prompt streaming error, got {other:?}"),
4105        }
4106        assert_eq!(recorded.request_count(), 1);
4107    }
4108
4109    #[tokio::test]
4110    async fn tool_call_args_delta_before_unknown_name_fails_before_hook_or_emit() {
4111        let model = MockCompletionModel::from_stream_turns([
4112            vec![
4113                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4114                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"),
4115                MockStreamEvent::final_response_with_total_tokens(4),
4116            ],
4117            vec![
4118                MockStreamEvent::text("should not be requested"),
4119                MockStreamEvent::final_response_with_total_tokens(6),
4120            ],
4121        ]);
4122        let recorded = model.clone();
4123        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4124
4125        let mut stream = agent
4126            .stream_prompt("stream a bad tool call")
4127            .with_hook(PanicOnUnknownToolHook)
4128            .multi_turn(3)
4129            .await;
4130        let mut saw_delta = false;
4131        let mut error = None;
4132
4133        while let Some(item) = stream.next().await {
4134            match item {
4135                Ok(MultiTurnStreamItem::StreamAssistantItem(
4136                    StreamedAssistantContent::ToolCallDelta { .. },
4137                )) => {
4138                    saw_delta = true;
4139                }
4140                Ok(_) => {}
4141                Err(err) => {
4142                    error = Some(err);
4143                    break;
4144                }
4145            }
4146        }
4147
4148        assert!(!saw_delta);
4149        let error = error.expect("unknown tool-call name should reject buffered args");
4150        match error {
4151            StreamingError::Prompt(err) => match *err {
4152                PromptError::UnknownToolCall {
4153                    tool_name,
4154                    available_tools,
4155                    allowed_tools,
4156                    chat_history,
4157                } => {
4158                    assert_eq!(tool_name, "default_api");
4159                    assert_eq!(available_tools, vec!["add".to_string()]);
4160                    assert_eq!(allowed_tools, vec!["add".to_string()]);
4161                    assert!(history_contains_tool_call(&chat_history, "default_api"));
4162                }
4163                other => panic!("expected UnknownToolCall, got {other:?}"),
4164            },
4165            other => panic!("expected prompt streaming error, got {other:?}"),
4166        }
4167        assert_eq!(recorded.request_count(), 1);
4168    }
4169
4170    #[tokio::test]
4171    async fn tool_call_args_delta_before_valid_name_buffers_then_emits_in_safe_order() {
4172        let model = MockCompletionModel::from_stream_turns([[
4173            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4174            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4175            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4176            MockStreamEvent::final_response_with_total_tokens(3),
4177        ]]);
4178        let hook = RecordingToolCallDeltaHook::default();
4179        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4180
4181        let mut stream = agent
4182            .stream_prompt("stream a tool call")
4183            .with_hook(hook.clone())
4184            .await;
4185        let mut stream_deltas = Vec::new();
4186
4187        while let Some(item) = stream.next().await {
4188            match item {
4189                Ok(MultiTurnStreamItem::StreamAssistantItem(
4190                    StreamedAssistantContent::ToolCallDelta {
4191                        id,
4192                        internal_call_id,
4193                        content,
4194                    },
4195                )) => {
4196                    stream_deltas.push((id, internal_call_id, content));
4197                }
4198                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4199                Ok(_) => {}
4200                Err(err) => panic!("unexpected streaming error: {err:?}"),
4201            }
4202        }
4203
4204        assert_eq!(
4205            hook.observed(),
4206            vec![
4207                (
4208                    "tool_1".to_string(),
4209                    "internal_1".to_string(),
4210                    Some("add".to_string()),
4211                    String::new()
4212                ),
4213                (
4214                    "tool_1".to_string(),
4215                    "internal_1".to_string(),
4216                    None,
4217                    "{\"x\":".to_string()
4218                ),
4219                (
4220                    "tool_1".to_string(),
4221                    "internal_1".to_string(),
4222                    None,
4223                    "1}".to_string()
4224                ),
4225            ]
4226        );
4227        assert_eq!(
4228            stream_deltas,
4229            vec![
4230                (
4231                    "tool_1".to_string(),
4232                    "internal_1".to_string(),
4233                    ToolCallDeltaContent::Name("add".to_string())
4234                ),
4235                (
4236                    "tool_1".to_string(),
4237                    "internal_1".to_string(),
4238                    ToolCallDeltaContent::Delta("{\"x\":".to_string())
4239                ),
4240                (
4241                    "tool_1".to_string(),
4242                    "internal_1".to_string(),
4243                    ToolCallDeltaContent::Delta("1}".to_string())
4244                ),
4245            ]
4246        );
4247    }
4248
4249    #[tokio::test]
4250    async fn tool_call_args_delta_without_name_errors_at_stream_end() {
4251        let model = MockCompletionModel::from_stream_turns([
4252            vec![
4253                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4254                MockStreamEvent::final_response_with_total_tokens(4),
4255            ],
4256            vec![
4257                MockStreamEvent::text("should not be requested"),
4258                MockStreamEvent::final_response_with_total_tokens(6),
4259            ],
4260        ]);
4261        let recorded = model.clone();
4262        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4263
4264        let mut stream = agent
4265            .stream_prompt("stream an incomplete tool call")
4266            .with_hook(PanicOnUnknownToolHook)
4267            .multi_turn(3)
4268            .await;
4269        let mut saw_delta = false;
4270        let mut saw_completion_call = false;
4271        let mut saw_final_response = false;
4272        let mut error = None;
4273
4274        while let Some(item) = stream.next().await {
4275            match item {
4276                Ok(MultiTurnStreamItem::StreamAssistantItem(
4277                    StreamedAssistantContent::ToolCallDelta { .. },
4278                )) => {
4279                    saw_delta = true;
4280                }
4281                Ok(MultiTurnStreamItem::CompletionCall(_)) => {
4282                    saw_completion_call = true;
4283                }
4284                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4285                    saw_final_response = true;
4286                }
4287                Ok(_) => {}
4288                Err(err) => {
4289                    error = Some(err);
4290                    break;
4291                }
4292            }
4293        }
4294
4295        assert!(!saw_delta);
4296        assert!(!saw_completion_call);
4297        assert!(!saw_final_response);
4298        let error = error.expect("unterminated tool-call args delta should fail");
4299        match error {
4300            StreamingError::Completion(CompletionError::ResponseError(message)) => {
4301                assert!(
4302                    message.contains("streamed tool call arguments"),
4303                    "{message}"
4304                );
4305                assert!(message.contains("tool_1"), "{message}");
4306                assert!(message.contains("internal_1"), "{message}");
4307            }
4308            other => panic!("expected completion response error, got {other:?}"),
4309        }
4310        assert_eq!(recorded.request_count(), 1);
4311    }
4312
4313    #[tokio::test]
4314    async fn tool_choice_none_buffers_args_then_rejects_name_without_emit() {
4315        let model = MockCompletionModel::from_stream_turns([
4316            vec![
4317                MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
4318                MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4319                MockStreamEvent::final_response_with_total_tokens(4),
4320            ],
4321            vec![
4322                MockStreamEvent::text("should not be requested"),
4323                MockStreamEvent::final_response_with_total_tokens(6),
4324            ],
4325        ]);
4326        let recorded = model.clone();
4327        let agent = AgentBuilder::new(model)
4328            .tool(MockAddTool)
4329            .tool_choice(ToolChoice::None)
4330            .build();
4331
4332        let mut stream = agent
4333            .stream_prompt("do not use tools")
4334            .with_hook(PanicOnUnknownToolHook)
4335            .multi_turn(3)
4336            .await;
4337        let mut saw_delta = false;
4338        let mut error = None;
4339
4340        while let Some(item) = stream.next().await {
4341            match item {
4342                Ok(MultiTurnStreamItem::StreamAssistantItem(
4343                    StreamedAssistantContent::ToolCallDelta { .. },
4344                )) => {
4345                    saw_delta = true;
4346                }
4347                Ok(_) => {}
4348                Err(err) => {
4349                    error = Some(err);
4350                    break;
4351                }
4352            }
4353        }
4354
4355        assert!(!saw_delta);
4356        let error = error.expect("ToolChoice::None should reject buffered tool-call deltas");
4357        match error {
4358            StreamingError::Prompt(err) => match *err {
4359                PromptError::UnknownToolCall {
4360                    tool_name,
4361                    available_tools,
4362                    allowed_tools,
4363                    chat_history,
4364                } => {
4365                    assert_eq!(tool_name, "add");
4366                    assert_eq!(available_tools, vec!["add".to_string()]);
4367                    assert!(allowed_tools.is_empty());
4368                    assert!(history_contains_tool_call(&chat_history, "add"));
4369                }
4370                other => panic!("expected UnknownToolCall, got {other:?}"),
4371            },
4372            other => panic!("expected prompt streaming error, got {other:?}"),
4373        }
4374        assert_eq!(recorded.request_count(), 1);
4375    }
4376
4377    #[tokio::test]
4378    async fn stream_prompt_emits_tool_call_deltas_without_hook() {
4379        let model = MockCompletionModel::from_stream_turns([[
4380            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4381            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4382            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4383            MockStreamEvent::final_response_with_total_tokens(3),
4384        ]]);
4385        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4386
4387        let mut stream = agent.stream_prompt("stream a tool call").await;
4388        let mut deltas = Vec::new();
4389
4390        while let Some(item) = stream.next().await {
4391            match item {
4392                Ok(MultiTurnStreamItem::StreamAssistantItem(
4393                    StreamedAssistantContent::ToolCallDelta {
4394                        id,
4395                        internal_call_id,
4396                        content,
4397                    },
4398                )) => {
4399                    deltas.push((id, internal_call_id, content));
4400                }
4401                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4402                Ok(_) => {}
4403                Err(err) => panic!("unexpected streaming error: {err:?}"),
4404            }
4405        }
4406
4407        assert_eq!(
4408            deltas,
4409            vec![
4410                (
4411                    "tool_1".to_string(),
4412                    "internal_1".to_string(),
4413                    ToolCallDeltaContent::Name("add".to_string())
4414                ),
4415                (
4416                    "tool_1".to_string(),
4417                    "internal_1".to_string(),
4418                    ToolCallDeltaContent::Delta("{\"x\":".to_string())
4419                ),
4420                (
4421                    "tool_1".to_string(),
4422                    "internal_1".to_string(),
4423                    ToolCallDeltaContent::Delta("1}".to_string())
4424                ),
4425            ]
4426        );
4427    }
4428
4429    #[tokio::test]
4430    async fn stream_prompt_emits_tool_call_deltas_after_hook_continue() {
4431        let model = MockCompletionModel::from_stream_turns([[
4432            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4433            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4434            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"),
4435            MockStreamEvent::final_response_with_total_tokens(3),
4436        ]]);
4437        let hook = RecordingToolCallDeltaHook::default();
4438        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4439
4440        let mut stream = agent
4441            .stream_prompt("stream a tool call")
4442            .with_hook(hook.clone())
4443            .await;
4444        let mut stream_deltas = Vec::new();
4445
4446        while let Some(item) = stream.next().await {
4447            match item {
4448                Ok(MultiTurnStreamItem::StreamAssistantItem(
4449                    StreamedAssistantContent::ToolCallDelta {
4450                        id,
4451                        internal_call_id,
4452                        content,
4453                    },
4454                )) => {
4455                    stream_deltas.push((id, internal_call_id, content));
4456                }
4457                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4458                Ok(_) => {}
4459                Err(err) => panic!("unexpected streaming error: {err:?}"),
4460            }
4461        }
4462
4463        assert_eq!(
4464            hook.observed(),
4465            vec![
4466                (
4467                    "tool_1".to_string(),
4468                    "internal_1".to_string(),
4469                    Some("add".to_string()),
4470                    String::new()
4471                ),
4472                (
4473                    "tool_1".to_string(),
4474                    "internal_1".to_string(),
4475                    None,
4476                    "{\"x\":".to_string()
4477                ),
4478                (
4479                    "tool_1".to_string(),
4480                    "internal_1".to_string(),
4481                    None,
4482                    "1}".to_string()
4483                ),
4484            ]
4485        );
4486        assert_eq!(
4487            stream_deltas,
4488            vec![
4489                (
4490                    "tool_1".to_string(),
4491                    "internal_1".to_string(),
4492                    ToolCallDeltaContent::Name("add".to_string())
4493                ),
4494                (
4495                    "tool_1".to_string(),
4496                    "internal_1".to_string(),
4497                    ToolCallDeltaContent::Delta("{\"x\":".to_string())
4498                ),
4499                (
4500                    "tool_1".to_string(),
4501                    "internal_1".to_string(),
4502                    ToolCallDeltaContent::Delta("1}".to_string())
4503                ),
4504            ]
4505        );
4506    }
4507
4508    #[tokio::test]
4509    async fn stream_prompt_tool_call_deltas_hook_termination_prevents_delta_emit() {
4510        let model = MockCompletionModel::from_stream_turns([[
4511            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"),
4512            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"),
4513            MockStreamEvent::final_response_with_total_tokens(3),
4514        ]]);
4515        let hook = TerminatingToolCallDeltaHook::default();
4516        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4517
4518        let mut stream = agent
4519            .stream_prompt("stream a tool call")
4520            .with_hook(hook.clone())
4521            .await;
4522        let mut saw_delta = false;
4523        let mut saw_final_response = false;
4524        let mut error_message = None;
4525
4526        while let Some(item) = stream.next().await {
4527            match item {
4528                Ok(MultiTurnStreamItem::StreamAssistantItem(
4529                    StreamedAssistantContent::ToolCallDelta { .. },
4530                )) => {
4531                    saw_delta = true;
4532                }
4533                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
4534                    saw_final_response = true;
4535                }
4536                Ok(_) => {}
4537                Err(err) => {
4538                    error_message = Some(err.to_string());
4539                    break;
4540                }
4541            }
4542        }
4543
4544        assert_eq!(
4545            hook.observed(),
4546            vec![(
4547                "tool_1".to_string(),
4548                "internal_1".to_string(),
4549                Some("add".to_string()),
4550                String::new()
4551            )]
4552        );
4553        assert!(!saw_delta);
4554        assert!(!saw_final_response);
4555        assert!(
4556            error_message
4557                .as_deref()
4558                .is_some_and(|message| message.contains("PromptCancelled: stop on tool call delta")),
4559            "expected hook termination error, got {error_message:?}"
4560        );
4561    }
4562
4563    #[tokio::test]
4564    async fn stream_prompt_exposes_completion_calls() {
4565        let first_call_usage = usage(10, 2);
4566        let second_call_usage = usage(25, 5);
4567        let model = MockCompletionModel::from_stream_turns([
4568            vec![
4569                MockStreamEvent::tool_call(
4570                    "tool_call_1",
4571                    "add",
4572                    serde_json::json!({"x": 1, "y": 2}),
4573                )
4574                .with_call_id("call_1"),
4575                MockStreamEvent::final_response(first_call_usage),
4576            ],
4577            vec![
4578                MockStreamEvent::text("done"),
4579                MockStreamEvent::final_response(second_call_usage),
4580            ],
4581        ]);
4582        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4583        let empty_history: &[Message] = &[];
4584
4585        let mut stream = agent
4586            .stream_prompt("do tool work")
4587            .with_history(empty_history)
4588            .multi_turn(3)
4589            .await;
4590        let mut completion_calls_events = Vec::new();
4591        let mut final_response = None;
4592
4593        while let Some(item) = stream.next().await {
4594            match item {
4595                Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
4596                    completion_calls_events.push(call_usage);
4597                }
4598                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4599                    final_response = Some(response);
4600                    break;
4601                }
4602                Ok(_) => {}
4603                Err(err) => panic!("unexpected streaming error: {err:?}"),
4604            }
4605        }
4606
4607        assert_eq!(
4608            completion_calls_events,
4609            vec![
4610                CompletionCall::new(0, first_call_usage),
4611                CompletionCall::new(1, second_call_usage)
4612            ]
4613        );
4614
4615        let final_response = final_response.expect("expected final response");
4616        assert_eq!(
4617            final_response.usage(),
4618            Usage {
4619                input_tokens: 35,
4620                output_tokens: 7,
4621                total_tokens: 42,
4622                cached_input_tokens: 0,
4623                cache_creation_input_tokens: 0,
4624                tool_use_prompt_tokens: 0,
4625                reasoning_tokens: 0,
4626            }
4627        );
4628        assert_eq!(
4629            final_response.completion_calls(),
4630            &[
4631                CompletionCall::new(0, first_call_usage),
4632                CompletionCall::new(1, second_call_usage)
4633            ]
4634        );
4635    }
4636
4637    #[tokio::test(flavor = "current_thread")]
4638    async fn stream_prompt_records_single_call_usage_on_chat_span_under_outer_span() {
4639        let call_usage = usage(10, 2);
4640        let model = MockCompletionModel::from_stream_turns([[
4641            MockStreamEvent::text("done"),
4642            MockStreamEvent::final_response(call_usage),
4643        ]]);
4644        let agent = AgentBuilder::new(model).build();
4645
4646        assert_stream_usage_recorded_on_chat_spans(agent, "say done", 1, &[call_usage]).await;
4647    }
4648
4649    #[tokio::test(flavor = "current_thread")]
4650    async fn stream_prompt_records_multi_turn_usage_on_chat_spans_under_outer_span() {
4651        let first_call_usage = usage(10, 2);
4652        let second_call_usage = usage(25, 5);
4653        let model = MockCompletionModel::from_stream_turns([
4654            vec![
4655                MockStreamEvent::tool_call(
4656                    "tool_call_1",
4657                    "add",
4658                    serde_json::json!({"x": 1, "y": 2}),
4659                )
4660                .with_call_id("call_1"),
4661                MockStreamEvent::final_response(first_call_usage),
4662            ],
4663            vec![
4664                MockStreamEvent::text("done"),
4665                MockStreamEvent::final_response(second_call_usage),
4666            ],
4667        ]);
4668        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4669
4670        assert_stream_usage_recorded_on_chat_spans(
4671            agent,
4672            "do tool work",
4673            3,
4674            &[first_call_usage, second_call_usage],
4675        )
4676        .await;
4677    }
4678
4679    #[tokio::test]
4680    async fn stream_prompt_emits_completion_call_before_finish_hook_termination() {
4681        let call_usage = usage(10, 2);
4682        let model = MockCompletionModel::from_stream_turns([[
4683            MockStreamEvent::text("done"),
4684            MockStreamEvent::final_response(call_usage),
4685        ]]);
4686        let agent = AgentBuilder::new(model).build();
4687
4688        let mut stream = agent
4689            .stream_prompt("say done")
4690            .with_hook(TerminateOnStreamFinish)
4691            .await;
4692        let mut completion_calls = Vec::new();
4693        let mut saw_error = false;
4694
4695        while let Some(item) = stream.next().await {
4696            match item {
4697                Ok(MultiTurnStreamItem::CompletionCall(completion_call)) => {
4698                    completion_calls.push(completion_call);
4699                }
4700                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4701                    panic!("unexpected final response after hook termination: {response:?}");
4702                }
4703                Ok(_) => {}
4704                Err(_) => {
4705                    saw_error = true;
4706                    break;
4707                }
4708            }
4709        }
4710
4711        assert_eq!(completion_calls, vec![CompletionCall::new(0, call_usage)]);
4712        assert!(saw_error);
4713    }
4714
4715    #[tokio::test]
4716    async fn stream_prompt_completion_calls_records_unreported_usage() {
4717        let second_call_usage = usage(25, 5);
4718        let model = MockCompletionModel::from_stream_turns([
4719            vec![
4720                MockStreamEvent::tool_call(
4721                    "tool_call_1",
4722                    "add",
4723                    serde_json::json!({"x": 1, "y": 2}),
4724                )
4725                .with_call_id("call_1"),
4726            ],
4727            vec![
4728                MockStreamEvent::text("done"),
4729                MockStreamEvent::final_response(second_call_usage),
4730            ],
4731        ]);
4732        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4733        let empty_history: &[Message] = &[];
4734
4735        let mut stream = agent
4736            .stream_prompt("do tool work")
4737            .with_history(empty_history)
4738            .multi_turn(3)
4739            .await;
4740        let mut completion_calls_events = Vec::new();
4741        let mut final_response = None;
4742
4743        while let Some(item) = stream.next().await {
4744            match item {
4745                Ok(MultiTurnStreamItem::CompletionCall(call_usage)) => {
4746                    completion_calls_events.push(call_usage);
4747                }
4748                Ok(MultiTurnStreamItem::FinalResponse(response)) => {
4749                    final_response = Some(response);
4750                    break;
4751                }
4752                Ok(_) => {}
4753                Err(err) => panic!("unexpected streaming error: {err:?}"),
4754            }
4755        }
4756
4757        let expected_usage = vec![
4758            CompletionCall::new(0, Usage::new()),
4759            CompletionCall::new(1, second_call_usage),
4760        ];
4761        assert_eq!(completion_calls_events, expected_usage);
4762
4763        let final_response = final_response.expect("expected final response");
4764        assert_eq!(final_response.completion_calls(), expected_usage.as_slice());
4765    }
4766
4767    #[tokio::test]
4768    async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
4769        let agent = AgentBuilder::new(streaming_text_then_final_model()).build();
4770
4771        let mut stream = agent.stream_prompt("say hello").await;
4772        let mut streamed_text = String::new();
4773        let mut final_response_text = None;
4774
4775        while let Some(item) = stream.next().await {
4776            match item {
4777                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
4778                    text,
4779                ))) => streamed_text.push_str(&text.text),
4780                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
4781                    final_response_text = Some(res.response().to_owned());
4782                    break;
4783                }
4784                Ok(_) => {}
4785                Err(err) => panic!("unexpected streaming error: {err:?}"),
4786            }
4787        }
4788
4789        assert_eq!(streamed_text, "hello world");
4790        assert_eq!(final_response_text.as_deref(), Some("hello world"));
4791    }
4792
4793    #[tokio::test]
4794    async fn final_response_preserves_structured_text_metadata() {
4795        let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
4796
4797        let mut stream = agent.stream_prompt("answer with citations").await;
4798        let mut final_response = None;
4799
4800        while let Some(item) = stream.next().await {
4801            match item {
4802                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
4803                    final_response = Some(res);
4804                    break;
4805                }
4806                Ok(_) => {}
4807                Err(err) => panic!("unexpected streaming error: {err:?}"),
4808            }
4809        }
4810
4811        let final_response = final_response.expect("expected final response");
4812        assert_eq!(final_response.response(), "cited answer");
4813        let metadata = text_metadata(final_response.content())
4814            .expect("expected text metadata in final content");
4815        assert_eq!(
4816            metadata["citations"][0]["encrypted_index"],
4817            "encrypted-reference"
4818        );
4819    }
4820
4821    #[tokio::test]
4822    async fn final_response_history_preserves_structured_text_metadata() {
4823        let agent = AgentBuilder::new(streaming_cited_text_then_final_model()).build();
4824
4825        let empty_history: &[Message] = &[];
4826        let mut stream = agent
4827            .stream_prompt("answer with citations")
4828            .with_history(empty_history)
4829            .await;
4830        let mut final_response = None;
4831
4832        while let Some(item) = stream.next().await {
4833            match item {
4834                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
4835                    final_response = Some(res);
4836                    break;
4837                }
4838                Ok(_) => {}
4839                Err(err) => panic!("unexpected streaming error: {err:?}"),
4840            }
4841        }
4842
4843        let final_response = final_response.expect("expected final response");
4844        let history = final_response
4845            .history()
4846            .expect("with_history should include final history");
4847        let assistant_content = history
4848            .iter()
4849            .find_map(|message| match message {
4850                Message::Assistant { content, .. } => Some(content),
4851                _ => None,
4852            })
4853            .expect("expected assistant message in history");
4854        let metadata =
4855            text_metadata(assistant_content).expect("expected text metadata in assistant history");
4856        assert_eq!(
4857            metadata["citations"][0]["encrypted_index"],
4858            "encrypted-reference"
4859        );
4860    }
4861
4862    #[tokio::test]
4863    async fn tool_follow_up_history_preserves_structured_text_metadata() {
4864        let model = streaming_cited_text_then_tool_model();
4865        let recorded = model.clone();
4866        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
4867        let empty_history: &[Message] = &[];
4868
4869        let mut stream = agent
4870            .stream_prompt("use a tool with citations")
4871            .with_history(empty_history)
4872            .multi_turn(3)
4873            .await;
4874
4875        while let Some(item) = stream.next().await {
4876            match item {
4877                Ok(MultiTurnStreamItem::FinalResponse(_)) => break,
4878                Ok(_) => {}
4879                Err(err) => panic!("unexpected streaming error: {err:?}"),
4880            }
4881        }
4882
4883        let requests = recorded.requests();
4884        assert_eq!(requests.len(), 2);
4885        let follow_up_history = requests[1].chat_history.iter().collect::<Vec<_>>();
4886        let assistant_content = follow_up_history
4887            .iter()
4888            .find_map(|message| match message {
4889                Message::Assistant { content, .. } => Some(content),
4890                _ => None,
4891            })
4892            .expect("expected assistant message in follow-up history");
4893        let metadata = text_metadata(assistant_content)
4894            .expect("expected citation metadata in follow-up assistant history");
4895        assert_eq!(
4896            metadata["citations"][0]["encrypted_index"],
4897            "encrypted-reference"
4898        );
4899    }
4900
4901    #[tokio::test]
4902    async fn final_response_can_remain_empty_for_truly_textless_turns() {
4903        let agent = AgentBuilder::new(streaming_final_only_model()).build();
4904
4905        let mut stream = agent.stream_prompt("say nothing").await;
4906        let mut streamed_text = String::new();
4907        let mut final_response_text = None;
4908
4909        while let Some(item) = stream.next().await {
4910            match item {
4911                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
4912                    text,
4913                ))) => streamed_text.push_str(&text.text),
4914                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
4915                    final_response_text = Some(res.response().to_owned());
4916                    break;
4917                }
4918                Ok(_) => {}
4919                Err(err) => panic!("unexpected streaming error: {err:?}"),
4920            }
4921        }
4922
4923        assert!(streamed_text.is_empty());
4924        assert_eq!(final_response_text.as_deref(), Some(""));
4925    }
4926
4927    /// Background task that logs periodically to detect span leakage.
4928    /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
4929    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
4930        let mut interval = tokio::time::interval(Duration::from_millis(50));
4931        let mut count = 0u32;
4932
4933        while !stop.load(Ordering::Relaxed) {
4934            interval.tick().await;
4935            count += 1;
4936
4937            tracing::event!(
4938                target: "background_logger",
4939                tracing::Level::INFO,
4940                count = count,
4941                "Background tick"
4942            );
4943
4944            // Check if we're inside an unexpected span
4945            let current = tracing::Span::current();
4946            if !current.is_disabled() && !current.is_none() {
4947                leak_count.fetch_add(1, Ordering::Relaxed);
4948            }
4949        }
4950
4951        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
4952    }
4953
4954    /// Test that span context doesn't leak to concurrent tasks during streaming.
4955    ///
4956    /// This test verifies that using `.instrument()` instead of `span.enter()` in
4957    /// async_stream prevents thread-local span context from leaking to other tasks.
4958    ///
4959    /// Uses single-threaded runtime to force all tasks onto the same thread,
4960    /// making the span leak deterministic (it only occurs when tasks share a thread).
4961    #[tokio::test(flavor = "current_thread")]
4962    #[ignore = "This requires an API key"]
4963    async fn test_span_context_isolation() -> anyhow::Result<()> {
4964        let stop = Arc::new(AtomicBool::new(false));
4965        let leak_count = Arc::new(AtomicU32::new(0));
4966
4967        // Start background logger
4968        let bg_stop = stop.clone();
4969        let bg_leak = leak_count.clone();
4970        let bg_handle = tokio::spawn(async move {
4971            background_logger(bg_stop, bg_leak).await;
4972        });
4973
4974        // Small delay to let background logger start
4975        tokio::time::sleep(Duration::from_millis(100)).await;
4976
4977        // Make streaming request WITHOUT an outer span so rig creates its own invoke_agent span
4978        // (rig reuses current span if one exists, so we need to ensure there's no current span)
4979        let client = anthropic::Client::from_env()?;
4980        let agent = client
4981            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
4982            .preamble("You are a helpful assistant.")
4983            .temperature(0.1)
4984            .max_tokens(100)
4985            .build();
4986
4987        let mut stream = agent
4988            .stream_prompt("Say 'hello world' and nothing else.")
4989            .await;
4990
4991        let mut full_content = String::new();
4992        while let Some(item) = stream.next().await {
4993            match item {
4994                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
4995                    text,
4996                ))) => {
4997                    full_content.push_str(&text.text);
4998                }
4999                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
5000                    break;
5001                }
5002                Err(e) => {
5003                    tracing::warn!("Error: {:?}", e);
5004                    break;
5005                }
5006                _ => {}
5007            }
5008        }
5009
5010        tracing::info!("Got response: {:?}", full_content);
5011
5012        // Stop background logger
5013        stop.store(true, Ordering::Relaxed);
5014        bg_handle.await?;
5015
5016        let leaks = leak_count.load(Ordering::Relaxed);
5017        anyhow::ensure!(
5018            leaks == 0,
5019            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
5020             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
5021        );
5022
5023        Ok(())
5024    }
5025
5026    /// Test that FinalResponse contains the updated chat history when with_history is used.
5027    ///
5028    /// This verifies that:
5029    /// 1. FinalResponse.history() returns Some when with_history was called
5030    /// 2. The history contains both the user prompt and assistant response
5031    #[tokio::test]
5032    #[ignore = "This requires an API key"]
5033    async fn test_chat_history_in_final_response() -> anyhow::Result<()> {
5034        use crate::message::Message;
5035
5036        let client = anthropic::Client::from_env()?;
5037        let agent = client
5038            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
5039            .preamble("You are a helpful assistant. Keep responses brief.")
5040            .temperature(0.1)
5041            .max_tokens(50)
5042            .build();
5043
5044        // Send streaming request with history
5045        let empty_history: &[Message] = &[];
5046        let mut stream = agent
5047            .stream_prompt("Say 'hello' and nothing else.")
5048            .with_history(empty_history)
5049            .await;
5050
5051        // Consume the stream and collect FinalResponse
5052        let mut response_text = String::new();
5053        let mut final_history = None;
5054        while let Some(item) = stream.next().await {
5055            match item {
5056                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
5057                    text,
5058                ))) => {
5059                    response_text.push_str(&text.text);
5060                }
5061                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5062                    final_history = res.history().map(|h| h.to_vec());
5063                    break;
5064                }
5065                Err(e) => {
5066                    return Err(e.into());
5067                }
5068                _ => {}
5069            }
5070        }
5071
5072        let history = final_history
5073            .ok_or_else(|| anyhow::anyhow!("final response should include history"))?;
5074
5075        // Should contain at least the user message
5076        anyhow::ensure!(
5077            history.iter().any(|m| matches!(m, Message::User { .. })),
5078            "History should contain the user message"
5079        );
5080
5081        // Should contain the assistant response
5082        anyhow::ensure!(
5083            history
5084                .iter()
5085                .any(|m| matches!(m, Message::Assistant { .. })),
5086            "History should contain the assistant response"
5087        );
5088
5089        tracing::info!(
5090            "History after streaming: {} messages, response: {:?}",
5091            history.len(),
5092            response_text
5093        );
5094
5095        Ok(())
5096    }
5097
5098    #[tokio::test]
5099    async fn streaming_appends_to_memory_after_final_response() {
5100        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5101
5102        let memory = InMemoryConversationMemory::new();
5103        let agent = AgentBuilder::new(streaming_text_then_final_model())
5104            .memory(memory.clone())
5105            .build();
5106
5107        let mut stream = agent
5108            .stream_prompt("hi there")
5109            .conversation("stream-thread")
5110            .await;
5111
5112        let mut history_in_final = None;
5113        while let Some(item) = stream.next().await {
5114            match item {
5115                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5116                    history_in_final = res.history().map(|h| h.to_vec());
5117                    break;
5118                }
5119                Ok(_) => {}
5120                Err(err) => panic!("unexpected streaming error: {err:?}"),
5121            }
5122        }
5123
5124        let final_history = history_in_final
5125            .expect("FinalResponse.history should be populated when memory is configured");
5126        assert_eq!(
5127            final_history.len(),
5128            2,
5129            "user prompt + assistant response in final history: {final_history:?}"
5130        );
5131
5132        let stored = memory.load("stream-thread").await.unwrap();
5133        assert_eq!(stored.len(), 2, "memory should contain user + assistant");
5134    }
5135
5136    #[tokio::test]
5137    async fn streaming_reasoning_without_tools_does_not_duplicate_final_history() {
5138        let agent = AgentBuilder::new(MockCompletionModel::from_stream_turns([[
5139            MockStreamEvent::text("final answer"),
5140            MockStreamEvent::reasoning("reasoned step").with_reasoning_id("rs_1"),
5141            MockStreamEvent::final_response_with_total_tokens(3),
5142        ]]))
5143        .build();
5144
5145        let mut stream = agent
5146            .stream_prompt("think before answering")
5147            .with_history(Vec::<Message>::new())
5148            .await;
5149
5150        let mut history_in_final = None;
5151        while let Some(item) = stream.next().await {
5152            match item {
5153                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
5154                    history_in_final = res.history().map(|h| h.to_vec());
5155                    break;
5156                }
5157                Ok(_) => {}
5158                Err(err) => panic!("unexpected streaming error: {err:?}"),
5159            }
5160        }
5161
5162        let final_history = history_in_final
5163            .expect("FinalResponse.history should be populated when with_history is used");
5164        assert_eq!(
5165            final_history.len(),
5166            2,
5167            "user prompt + one assistant response in final history: {final_history:?}"
5168        );
5169
5170        assert!(matches!(
5171            final_history.first(),
5172            Some(Message::User { content })
5173                if matches!(
5174                    content.first(),
5175                    UserContent::Text(text) if text.text == "think before answering"
5176                )
5177        ));
5178
5179        let assistant_messages = final_history
5180            .iter()
5181            .filter_map(|message| match message {
5182                Message::Assistant { content, .. } => Some(content),
5183                _ => None,
5184            })
5185            .collect::<Vec<_>>();
5186        assert_eq!(
5187            assistant_messages.len(),
5188            1,
5189            "reasoning turn should produce exactly one assistant history message: {final_history:?}"
5190        );
5191        let assistant_content = assistant_messages
5192            .first()
5193            .expect("expected assistant history message");
5194        assert!(assistant_content.iter().any(|item| matches!(
5195            item,
5196            AssistantContent::Text(text) if text.text == "final answer"
5197        )));
5198        assert!(assistant_content.iter().any(|item| matches!(
5199            item,
5200            AssistantContent::Reasoning(reasoning)
5201                if reasoning.id.as_deref() == Some("rs_1")
5202                    && reasoning.content.iter().any(|content| matches!(
5203                        content,
5204                        ReasoningContent::Text { text, .. } if text == "reasoned step"
5205                    ))
5206        )));
5207        let reasoning_index = assistant_content
5208            .iter()
5209            .position(|item| matches!(item, AssistantContent::Reasoning(_)))
5210            .expect("assistant history should contain reasoning");
5211        let text_index = assistant_content
5212            .iter()
5213            .position(|item| matches!(item, AssistantContent::Text(_)))
5214            .expect("assistant history should contain text");
5215        assert!(
5216            reasoning_index < text_index,
5217            "assistant reasoning must be stored before assistant text: {assistant_content:?}"
5218        );
5219    }
5220
5221    #[tokio::test]
5222    async fn streaming_with_history_overrides_memory() {
5223        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5224
5225        let memory = InMemoryConversationMemory::new();
5226        memory
5227            .append("t1", vec![Message::user("from-memory")])
5228            .await
5229            .unwrap();
5230
5231        let agent = AgentBuilder::new(streaming_text_then_final_model())
5232            .memory(memory.clone())
5233            .build();
5234
5235        let mut stream = agent
5236            .stream_prompt("hi")
5237            .conversation("t1")
5238            .with_history(vec![Message::user("from-caller")])
5239            .await;
5240
5241        while let Some(item) = stream.next().await {
5242            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5243                break;
5244            }
5245        }
5246
5247        let stored = memory.load("t1").await.unwrap();
5248        assert_eq!(
5249            stored.len(),
5250            1,
5251            "with_history bypasses memory; only the pre-seeded entry remains: {stored:?}"
5252        );
5253    }
5254
5255    #[tokio::test]
5256    async fn streaming_without_memory_disables_for_request() {
5257        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5258
5259        let memory = InMemoryConversationMemory::new();
5260        let agent = AgentBuilder::new(streaming_text_then_final_model())
5261            .memory(memory.clone())
5262            .conversation_id("default")
5263            .build();
5264
5265        let mut stream = agent.stream_prompt("hi").without_memory().await;
5266
5267        while let Some(item) = stream.next().await {
5268            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5269                break;
5270            }
5271        }
5272
5273        let stored = memory.load("default").await.unwrap();
5274        assert!(stored.is_empty(), "without_memory disables save");
5275    }
5276
5277    #[tokio::test]
5278    async fn streaming_load_error_yields_memory_error() {
5279        let agent = AgentBuilder::new(streaming_text_then_final_model())
5280            .memory(FailingMemory::default())
5281            .build();
5282
5283        let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5284
5285        let first = stream.next().await.expect("at least one item");
5286        match first {
5287            Err(err) => {
5288                let msg = format!("{err:?}");
5289                assert!(
5290                    msg.contains("Memory") || msg.contains("memory") || msg.contains("load boom"),
5291                    "expected memory error, got: {msg}"
5292                );
5293            }
5294            Ok(other) => panic!("expected memory error, got {other:?}"),
5295        }
5296    }
5297
5298    #[tokio::test]
5299    async fn streaming_with_filter_shapes_loaded_history() {
5300        use crate::memory::{ConversationMemory, InMemoryConversationMemory};
5301
5302        let memory = InMemoryConversationMemory::new()
5303            .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
5304        memory
5305            .append(
5306                "t1",
5307                vec![
5308                    Message::user("1"),
5309                    Message::assistant("2"),
5310                    Message::user("3"),
5311                    Message::assistant("4"),
5312                ],
5313            )
5314            .await
5315            .unwrap();
5316
5317        let model = MockCompletionModel::from_stream_turns([[
5318            MockStreamEvent::text("ok"),
5319            MockStreamEvent::final_response_with_total_tokens(1),
5320        ]]);
5321        let recorded = model.clone();
5322        let agent = AgentBuilder::new(model).memory(memory).build();
5323
5324        let mut stream = agent.stream_prompt("ping").conversation("t1").await;
5325        while let Some(item) = stream.next().await {
5326            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5327                break;
5328            }
5329        }
5330
5331        let received = recorded.requests()[0]
5332            .chat_history
5333            .iter()
5334            .cloned()
5335            .collect::<Vec<_>>();
5336        assert_eq!(
5337            received.len(),
5338            3,
5339            "window-truncated history (2) + current prompt: {received:?}"
5340        );
5341    }
5342
5343    #[tokio::test]
5344    async fn streaming_append_error_does_not_suppress_final_response() {
5345        let agent = AgentBuilder::new(streaming_text_then_final_model())
5346            .memory(AppendFailingMemory::default())
5347            .build();
5348
5349        let mut stream = agent.stream_prompt("hi").conversation("t1").await;
5350
5351        let mut saw_final = false;
5352        while let Some(item) = stream.next().await {
5353            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
5354                saw_final = true;
5355                break;
5356            }
5357        }
5358        assert!(
5359            saw_final,
5360            "FinalResponse must be yielded even when memory.append fails"
5361        );
5362    }
5363}