
// rig_core/agent/prompt_request/streaming.rs

use crate::{
    OneOrMany,
    agent::completion::{DynamicContextStore, build_completion_request},
    agent::prompt_request::{HookAction, hooks::PromptHook},
    completion::{Document, GetTokenUsage},
    json_utils,
    memory::ConversationMemory,
    message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent},
    tool::server::ToolServerHandle,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tracing::info_span;
use tracing_futures::Instrument;

use super::ToolCallHookAction;
use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

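/// An item yielded by the multi-turn streaming loop: streamed assistant
/// content, streamed user content (tool results), or the final response.
///
/// A minimal consumption sketch, assuming a driven `stream` (this mirrors the
/// `stream_to_stdout` helper below; not the only way to consume it):
/// ```ignore
/// while let Some(item) = stream.next().await {
///     match item? {
///         MultiTurnStreamItem::StreamAssistantItem(content) => { /* text, tool calls, reasoning */ }
///         MultiTurnStreamItem::StreamUserItem(user) => { /* tool results */ }
///         MultiTurnStreamItem::FinalResponse(res) => println!("{}", res.response()),
///         _ => {} // the enum is #[non_exhaustive]
///     }
/// }
/// ```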
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed assistant content item.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// A streamed user content item (mostly for tool results).
    StreamUserItem(StreamedUserContent),
    /// The final result from the stream.
    FinalResponse(FinalResponse),
}

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// Concatenated assistant text for the final turn.
    /// This is empty only when the turn completed without emitting any text.
    response: String,
    aggregated_usage: crate::completion::Usage,
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
            history: None,
        }
    }

    /// Returns the concatenated assistant text for the final turn.
    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }

    pub fn history(&self) -> Option<&[Message]> {
        self.history.as_deref()
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
            history: None,
        })
    }

    pub fn final_response_with_history(
        response: &str,
        aggregated_usage: crate::completion::Usage,
        history: Option<Vec<Message>>,
    ) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
            history,
        })
    }
}

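/// Merge an incoming reasoning block into the accumulated list.
///
/// Blocks whose `Some(id)` matches an existing entry (searched from the most
/// recent) have their content appended to that entry; blocks with a distinct
/// id, or with no id at all, are kept as separate items.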
fn merge_reasoning_blocks(
    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
    incoming: &crate::message::Reasoning,
) {
    let ids_match = |existing: &crate::message::Reasoning| {
        matches!(
            (&existing.id, &incoming.id),
            (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
        )
    };

    if let Some(existing) = accumulated_reasoning
        .iter_mut()
        .rev()
        .find(|existing| ids_match(existing))
    {
        existing.content.extend(incoming.content.clone());
    } else {
        accumulated_reasoning.push(incoming.clone());
    }
}

/// Build full history for error reporting (input + new messages).
fn build_full_history(
    chat_history: Option<&[Message]>,
    new_messages: Vec<Message>,
) -> Vec<Message> {
    let input = chat_history.unwrap_or(&[]);
    input.iter().cloned().chain(new_messages).collect()
}

/// Combine input history with new messages for building completion requests.
fn build_history_for_request(
    chat_history: Option<&[Message]>,
    new_messages: &[Message],
) -> Vec<Message> {
    let input = chat_history.unwrap_or(&[]);
    input.iter().chain(new_messages.iter()).cloned().collect()
}

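/// Wrap a hook-requested cancellation in a [`StreamingError`], attaching the
/// full chat history (caller input plus messages produced so far) so callers
/// can inspect how far the conversation got.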
async fn cancelled_prompt_error(
    chat_history: Option<&[Message]>,
    new_messages: Vec<Message>,
    reason: String,
) -> StreamingError {
    StreamingError::Prompt(
        PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
            .into(),
    )
}

fn tool_result_to_user_message(
    id: String,
    call_id: Option<String>,
    tool_result: String,
) -> Message {
    let content = ToolResultContent::from_tool_output(tool_result);
    let user_content = match call_id {
        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
        None => UserContent::tool_result(id, content),
    };

    Message::User {
        content: OneOrMany::one(user_content),
    }
}

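/// Concatenate the text parts of an assistant choice, skipping non-text
/// content such as tool calls and reasoning blocks.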
fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
    choice
        .iter()
        .filter_map(|content| match content {
            AssistantContent::Text(text) => Some(text.text.as_str()),
            _ => None,
        })
        .collect()
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

/// Surface [`crate::memory::ConversationMemory`] failures through the existing
/// [`CompletionError::RequestError`] variant so adding memory support does not
/// require a new top-level [`StreamingError`] arm.
impl From<crate::memory::MemoryError> for StreamingError {
    fn from(err: crate::memory::MemoryError) -> Self {
        Self::Completion(CompletionError::RequestError(Box::new(err)))
    }
}

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, use `.multi_turn()` to allow
/// additional turns; the default is 0, which permits only a single tool round-trip.
/// Otherwise, awaiting the request (which sends it) may return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides
/// to call tools back to back.
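///
/// A minimal usage sketch (assumes an `agent` built elsewhere):
/// ```ignore
/// let mut stream = agent
///     .stream_prompt("Summarize the latest report")
///     .multi_turn(5)
///     .await;
/// while let Some(item) = stream.next().await {
///     // handle each MultiTurnStreamItem
/// }
/// ```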
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history provided by the caller.
    chat_history: Option<Vec<Message>>,
    /// Maximum number of turns for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// Optional conversation memory backend cloned from the agent.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Optional conversation id used for loading and saving memory.
    conversation_id: Option<String>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: PromptHook<M>,
{
    /// Create a new StreamingPromptRequest with the given prompt and model.
    /// Note: This creates a request without an agent hook. Use `from_agent` to include the agent's hook.
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: None,
            memory: agent.memory.clone(),
            conversation_id: agent.default_conversation_id.clone(),
        }
    }

    /// Create a new StreamingPromptRequest from an agent, cloning the agent's data and default hook.
    pub fn from_agent<P2>(
        agent: &Agent<M, P2>,
        prompt: impl Into<Message>,
    ) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: agent.hook.clone(),
            memory: agent.memory.clone(),
            conversation_id: agent.default_conversation_id.clone(),
        }
    }

    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Set the maximum number of turns for multi-turn conversations (i.e., how many consecutive turns the LLM may spend calling tools before it must produce a text response).
    /// If the maximum number of turns is exceeded, the stream yields a [`crate::completion::request::PromptError::MaxTurnsError`].
    pub fn multi_turn(mut self, turns: usize) -> Self {
        self.max_turns = turns;
        self
    }

    /// Add chat history to the prompt request.
    ///
    /// When history is provided, the final [`FinalResponse`] will include the
    /// updated chat history (original messages + new user prompt + assistant response).
    /// ```ignore
    /// let mut stream = agent
    ///     .stream_prompt("Hello")
    ///     .with_history(vec![])
    ///     .await;
    /// // ... consume stream ...
    /// // Access updated history from FinalResponse::history()
    /// ```
    pub fn with_history<I, T>(mut self, history: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        self.chat_history = Some(history.into_iter().map(Into::into).collect());
        self
    }

    /// Attach a per-request hook for tool call events.
    /// This overrides any default hook set on the agent.
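    ///
    /// A hedged sketch (`MyHook` is a hypothetical [`PromptHook`] implementation):
    /// ```ignore
    /// let mut stream = agent
    ///     .stream_prompt("Run the pipeline")
    ///     .with_hook(MyHook::default())
    ///     .await;
    /// ```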
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            output_schema: self.output_schema,
            hook: Some(hook),
            memory: self.memory,
            conversation_id: self.conversation_id,
        }
    }

    /// Set the conversation id used to load and persist memory for this request.
    ///
    /// Overrides any default conversation id set on the agent. If memory is not
    /// configured on the agent, this has no effect.
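    ///
    /// A minimal sketch (assumes the agent was built with a memory backend):
    /// ```ignore
    /// let mut stream = agent
    ///     .stream_prompt("What did we discuss earlier?")
    ///     .conversation("user-42-session")
    ///     .await;
    /// // Prior history is loaded before the request; new messages are appended after.
    /// ```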
    pub fn conversation(mut self, id: impl Into<String>) -> Self {
        self.conversation_id = Some(id.into());
        self
    }

    /// Disable conversation memory for this request.
    ///
    /// History will neither be loaded from nor saved to the agent's memory backend.
    pub fn without_memory(mut self) -> Self {
        self.memory = None;
        self.conversation_id = None;
        self
    }

    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // Clone fields needed inside the stream
        let model = self.model.clone();
        let preamble = self.preamble.clone();
        let static_context = self.static_context.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let additional_params = self.additional_params.clone();
        let tool_server_handle = self.tool_server_handle.clone();
        let dynamic_context = self.dynamic_context.clone();
        let tool_choice = self.tool_choice.clone();
        let agent_name = self.agent_name.clone();
        // When the caller passes explicit history, memory is fully bypassed for
        // this request (no load AND no save). Otherwise, if a memory backend and
        // conversation id are both configured, load prior history; if either is
        // missing, behave as if no memory is configured.
        let (chat_history, memory_handle) = match self.chat_history {
            Some(history) => (Some(history), None),
            None => match (self.memory, self.conversation_id) {
                (Some(memory), Some(id)) => match memory.load(&id).await {
                    Ok(loaded) => (Some(loaded), Some((memory, id))),
                    Err(err) => {
                        let stream = async_stream::stream! {
                            yield Err(StreamingError::from(err));
                        };
                        return Box::pin(stream);
                    }
                },
                _ => (None, None),
            },
        };
        let has_history = chat_history.is_some();
        let mut new_messages: Vec<Message> = vec![prompt.clone()];

        let mut current_max_turns = 0;
        let mut last_prompt_error = String::new();

        let mut text_delta_response = String::new();
        let mut saw_text_this_turn = false;
        let mut max_turns_reached = false;
        let output_schema = self.output_schema;

        let mut aggregated_usage = crate::completion::Usage::new();

        // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
        // span context leaking to other concurrent tasks. Using span.enter() inside
        // async_stream::stream! holds the guard across yield points, which causes
        // thread-local span context to leak when other tasks run on the same thread.
        // See: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#in-asynchronous-code
        // See also: https://github.com/rust-lang/rust-clippy/issues/8722
        let stream = async_stream::stream! {
            'outer: loop {
                let Some((current_prompt_ref, previous_messages)) = new_messages.split_last() else {
                    yield Err(cancelled_prompt_error(
                        chat_history.as_deref(),
                        new_messages.clone(),
                        "streaming loop lost its pending prompt".to_string(),
                    ).await);
                    break 'outer;
                };
                let current_prompt = current_prompt_ref.clone();

                if current_max_turns > self.max_turns + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_turns_reached = true;
                    break;
                }

                current_max_turns += 1;

                if self.max_turns > 1 {
                    tracing::info!(
                        "Current conversation turns: {}/{}",
                        current_max_turns,
                        self.max_turns
                    );
                }

                let history_snapshot: Vec<Message> = build_history_for_request(
                    chat_history.as_deref(),
                    previous_messages,
                );

                if let Some(ref hook) = self.hook
                    && let HookAction::Terminate { reason } =
                        hook.on_completion_call(&current_prompt, &history_snapshot).await
                {
                    yield Err(
                        cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason)
                            .await,
                    );
                    break 'outer;
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                    gen_ai.system_instructions = preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                    gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                    gen_ai.usage.reasoning_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    build_completion_request(
                        &model,
                        current_prompt.clone(),
                        &history_snapshot,
                        preamble.as_deref(),
                        &static_context,
                        temperature,
                        max_tokens,
                        additional_params.as_ref(),
                        tool_choice.as_ref(),
                        &tool_server_handle,
                        &dynamic_context,
                        output_schema.as_ref(),
                    )
                    .await?
                    .stream(),
                    chat_stream_span,
                )
                .await?;

                let mut tool_calls = vec![];
                let mut tool_results = vec![];
                let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
                // Kept separate from accumulated_reasoning so providers requiring
                // signatures (e.g. Anthropic) never see unsigned blocks.
                let mut pending_reasoning_delta_text = String::new();
                let mut pending_reasoning_delta_id: Option<String> = None;
                let mut saw_tool_call_this_turn = false;

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !saw_text_this_turn {
                                text_delta_response.clear();
                                saw_text_this_turn = true;
                            }
                            text_delta_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook &&
                                let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &text_delta_response).await {
                                    yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                    break 'outer;
                            }

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                        },
                        Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    let action = hook
                                        .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
                                        .await;

                                    if let ToolCallHookAction::Terminate { reason } = action {
                                        return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                    }

                                    if let ToolCallHookAction::Skip { reason } = action {
                                        // Tool execution rejected, return rejection message as tool result
                                        tracing::info!(
                                            tool_name = tool_call.function.name.as_str(),
                                            reason = reason,
                                            "Tool call rejected"
                                        );
                                        let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
                                        tool_calls.push(tool_call_msg);
                                        tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
                                        saw_tool_call_this_turn = true;
                                        return Ok(reason);
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook &&
                                    let HookAction::Terminate { reason } =
                                    hook.on_tool_result(
                                        &tool_call.function.name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &tool_args,
                                        &tool_result.to_string()
                                    )
                                    .await {
                                        return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                    }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                saw_tool_call_this_turn = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
                                }
                                Err(e) => {
                                    yield Err(e);
                                    break 'outer;
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    rig::streaming::ToolCallDeltaContent::Name(n) => {
                                        (Some(n.as_str()), "")
                                    }
                                    rig::streaming::ToolCallDeltaContent::Delta(d) => {
                                        (None, d.as_str())
                                    }
                                };

                                if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
                                .await {
                                    yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                    break 'outer;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                            // Accumulate reasoning for inclusion in chat history with tool calls.
                            // OpenAI Responses API requires reasoning items to be sent back
                            // alongside function_call items in multi-turn conversations.
                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            // Deltas lack signatures/encrypted content that full
                            // blocks carry; mixing them into accumulated_reasoning
                            // causes Anthropic to reject with "signature required".
                            pending_reasoning_delta_text.push_str(&reasoning);
                            if pending_reasoning_delta_id.is_none() {
                                pending_reasoning_delta_id = id.clone();
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                            if saw_text_this_turn {
                                if let Some(ref hook) = self.hook &&
                                     let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&current_prompt, &final_resp).await {
                                        yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                        break 'outer;
                                    }

                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                saw_text_this_turn = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // Providers like Gemini emit thinking as incremental deltas
                // without signatures; assemble into a single block so
                // reasoning survives into the next turn's chat history.
                if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
                    let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
                    if let Some(id) = pending_reasoning_delta_id.take() {
                        assembled = assembled.with_id(id);
                    }
                    accumulated_reasoning.push(assembled);
                }

                let turn_text_response = assistant_text_from_choice(&stream.choice);
                tracing::Span::current().record("gen_ai.completion", &turn_text_response);

                // Add text, reasoning, and tool calls to chat history.
                // OpenAI Responses API requires reasoning items to precede function_call items.
                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
                    let mut content_items: Vec<rig::message::AssistantContent> = vec![];

                    // Text before tool calls so the model sees its own prior output
                    if !turn_text_response.is_empty() {
                        content_items.push(rig::message::AssistantContent::text(&turn_text_response));
                    }

                    // Reasoning must come before tool calls (OpenAI requirement)
                    for reasoning in accumulated_reasoning.drain(..) {
                        content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
                    }

                    content_items.extend(tool_calls.clone());

                    if let Some(content) = OneOrMany::from_iter_optional(content_items) {
                        new_messages.push(Message::Assistant {
                            id: stream.message_id.clone(),
                            content,
                        });
                    }
                }

                for (id, call_id, tool_result) in tool_results {
                    new_messages.push(tool_result_to_user_message(id, call_id, tool_result));
                }

                if !saw_tool_call_this_turn {
                    // Record the assistant's final text response in history before finishing
                    // (the user prompt is already in new_messages)
                    if !turn_text_response.is_empty() {
                        new_messages.push(Message::assistant(&turn_text_response));
                    } else {
                        tracing::warn!(
                            agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                            message_id = ?stream.message_id,
                            "Streaming turn completed without assistant text; final response will be empty"
                        );
                    }

                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    current_span.record("gen_ai.usage.cache_read.input_tokens", aggregated_usage.cached_input_tokens);
                    current_span.record("gen_ai.usage.cache_creation.input_tokens", aggregated_usage.cache_creation_input_tokens);
                    current_span.record("gen_ai.usage.reasoning_tokens", aggregated_usage.reasoning_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    if let Some((memory, id)) = memory_handle.as_ref()
                        && let Err(err) = memory.append(id, new_messages.clone()).await
                    {
                        tracing::warn!(
                            error = %err,
                            conversation_id = %id,
                            "conversation memory append failed; yielding final response anyway"
                        );
                    }
                    let final_messages: Option<Vec<Message>> = if has_history {
                        Some(new_messages.clone())
                    } else {
                        None
                    };
                    yield Ok(MultiTurnStreamItem::final_response_with_history(
                        &turn_text_response,
                        aggregated_usage,
                        final_messages,
                    ));
                    break;
                }
            }

            if max_turns_reached {
                yield Err(Box::new(PromptError::MaxTurnsError {
                    max_turns: self.max_turns,
                    chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Wrap send() in a future, because send() returns a stream immediately
        Box::pin(async move { self.send().await })
    }
}

/// Helper function to stream a completion request to stdout.
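///
/// A minimal sketch:
/// ```ignore
/// let mut stream = agent.stream_prompt("Tell me a joke").multi_turn(2).await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\nUsage: {:?}", final_response.usage());
/// ```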
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                reasoning,
            ))) => {
                let reasoning = reasoning.display_text();
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

865#[cfg(test)]
866mod tests {
867    use super::*;
868    use crate::agent::AgentBuilder;
869    use crate::client::ProviderClient;
870    use crate::client::completion::CompletionClient;
871    use crate::completion::CompletionRequest;
872    use crate::message::{
873        AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent,
874        ToolResultContent, UserContent,
875    };
876    use crate::providers::anthropic;
877    use crate::streaming::StreamingPrompt;
878    use crate::test_utils::{
879        AppendFailingMemory, FailingMemory, MockCompletionModel, MockStreamEvent,
880    };
881    use futures::StreamExt;
882    use std::sync::Arc;
883    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
884    use std::time::Duration;
885
886    #[test]
887    fn merge_reasoning_blocks_preserves_order_and_signatures() {
888        let mut accumulated = Vec::new();
889        let first = crate::message::Reasoning {
890            id: Some("rs_1".to_string()),
891            content: vec![ReasoningContent::Text {
892                text: "step-1".to_string(),
893                signature: Some("sig-1".to_string()),
894            }],
895        };
896        let second = crate::message::Reasoning {
897            id: Some("rs_1".to_string()),
898            content: vec![
899                ReasoningContent::Text {
900                    text: "step-2".to_string(),
901                    signature: Some("sig-2".to_string()),
902                },
903                ReasoningContent::Summary("summary".to_string()),
904            ],
905        };
906
907        merge_reasoning_blocks(&mut accumulated, &first);
908        merge_reasoning_blocks(&mut accumulated, &second);
909
910        assert_eq!(accumulated.len(), 1);
911        let merged = accumulated.first().expect("expected accumulated reasoning");
912        assert_eq!(merged.id.as_deref(), Some("rs_1"));
913        assert_eq!(merged.content.len(), 3);
914        assert!(matches!(
915            merged.content.first(),
916            Some(ReasoningContent::Text { text, signature: Some(sig) })
917                if text == "step-1" && sig == "sig-1"
918        ));
919        assert!(matches!(
920            merged.content.get(1),
921            Some(ReasoningContent::Text { text, signature: Some(sig) })
922                if text == "step-2" && sig == "sig-2"
923        ));
924    }
925
926    #[test]
927    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
928        let mut accumulated = vec![crate::message::Reasoning {
929            id: Some("rs_a".to_string()),
930            content: vec![ReasoningContent::Text {
931                text: "step-1".to_string(),
932                signature: None,
933            }],
934        }];
935        let incoming = crate::message::Reasoning {
936            id: Some("rs_b".to_string()),
937            content: vec![ReasoningContent::Text {
938                text: "step-2".to_string(),
939                signature: None,
940            }],
941        };
942
943        merge_reasoning_blocks(&mut accumulated, &incoming);
944        assert_eq!(accumulated.len(), 2);
945        assert_eq!(
946            accumulated.first().and_then(|r| r.id.as_deref()),
947            Some("rs_a")
948        );
949        assert_eq!(
950            accumulated.get(1).and_then(|r| r.id.as_deref()),
951            Some("rs_b")
952        );
953    }
954
955    #[test]
956    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
957        let mut accumulated = vec![crate::message::Reasoning {
958            id: None,
959            content: vec![ReasoningContent::Text {
960                text: "first".to_string(),
961                signature: None,
962            }],
963        }];
964        let incoming = crate::message::Reasoning {
965            id: None,
966            content: vec![ReasoningContent::Text {
967                text: "second".to_string(),
968                signature: None,
969            }],
970        };
971
972        merge_reasoning_blocks(&mut accumulated, &incoming);
973        assert_eq!(accumulated.len(), 2);
974        assert!(matches!(
975            accumulated.first(),
976            Some(crate::message::Reasoning {
977                id: None,
978                content
979            }) if matches!(
980                content.first(),
981                Some(ReasoningContent::Text { text, .. }) if text == "first"
982            )
983        ));
984        assert!(matches!(
985            accumulated.get(1),
986            Some(crate::message::Reasoning {
987                id: None,
988                content
989            }) if matches!(
990                content.first(),
991                Some(ReasoningContent::Text { text, .. }) if text == "second"
992            )
993        ));
994    }
995
996    #[test]
997    fn tool_result_to_user_message_preserves_multimodal_tool_output() {
998        let message = tool_result_to_user_message(
999            "tool_call_1".to_string(),
1000            Some("call_1".to_string()),
1001            serde_json::json!({
1002                "response": {
1003                    "instruction": "Use the image part to answer."
1004                },
1005                "parts": [
1006                    {
1007                        "type": "image",
1008                        "data": "base64data==",
1009                        "mimeType": "image/png"
1010                    }
1011                ]
1012            })
1013            .to_string(),
1014        );
1015
1016        let tool_result = match message {
1017            Message::User { content } => match content.first() {
1018                UserContent::ToolResult(tool_result) => tool_result,
1019                other => panic!("expected tool result content, got {other:?}"),
1020            },
1021            other => panic!("expected user message, got {other:?}"),
1022        };
1023
1024        assert_eq!(tool_result.id, "tool_call_1");
1025        assert_eq!(tool_result.call_id.as_deref(), Some("call_1"));
1026        assert_eq!(tool_result.content.len(), 2);
1027
1028        let mut items = tool_result.content.iter();
1029        match items.next() {
1030            Some(ToolResultContent::Text(text)) => {
1031                assert!(text.text.contains("Use the image part to answer."));
1032            }
1033            other => panic!("expected structured text payload first, got {other:?}"),
1034        }
1035
1036        match items.next() {
1037            Some(ToolResultContent::Image(image)) => {
1038                assert_eq!(image.media_type, Some(ImageMediaType::PNG));
1039                assert!(matches!(
1040                    image.data,
1041                    DocumentSourceKind::Base64(ref data) if data == "base64data=="
1042                ));
1043            }
1044            other => panic!("expected image payload second, got {other:?}"),
1045        }
1046    }
1047
1048    fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
1049        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1050        if history.len() != 3 {
1051            return Err(format!(
1052                "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
1053            ));
1054        }
1055
1056        if !matches!(
1057            history.first(),
1058            Some(Message::User { content })
1059                if matches!(
1060                    content.first(),
1061                    UserContent::Text(text) if text.text == "do tool work"
1062                )
1063        ) {
1064            return Err(format!(
1065                "follow-up request should begin with the original user prompt: {history:?}"
1066            ));
1067        }
1068
1069        if !matches!(
1070            history.get(1),
1071            Some(Message::Assistant { content, .. })
1072                if matches!(
1073                    content.first(),
1074                    AssistantContent::ToolCall(tool_call)
1075                        if tool_call.id == "tool_call_1"
1076                            && tool_call.call_id.as_deref() == Some("call_1")
1077                )
1078        ) {
1079            return Err(format!(
1080                "follow-up request is missing the assistant tool call in position 2: {history:?}"
1081            ));
1082        }
1083
1084        if !matches!(
1085            history.get(2),
1086            Some(Message::User { content })
1087                if matches!(
1088                    content.first(),
1089                    UserContent::ToolResult(tool_result)
1090                        if tool_result.id == "tool_call_1"
1091                            && tool_result.call_id.as_deref() == Some("call_1")
1092                )
1093        ) {
1094            return Err(format!(
1095                "follow-up request should end with the user tool result: {history:?}"
1096            ));
1097        }
1098
1099        Ok(())
1100    }
1101
1102    fn streaming_tool_then_text_model() -> MockCompletionModel {
1103        MockCompletionModel::from_stream_turns([
1104            vec![
1105                MockStreamEvent::tool_call(
1106                    "tool_call_1",
1107                    "missing_tool",
1108                    serde_json::json!({"input": "value"}),
1109                )
1110                .with_call_id("call_1"),
1111                MockStreamEvent::final_response_with_total_tokens(4),
1112            ],
1113            vec![
1114                MockStreamEvent::text("done"),
1115                MockStreamEvent::final_response_with_total_tokens(6),
1116            ],
1117        ])
1118    }
1119
1120    fn streaming_text_then_final_model() -> MockCompletionModel {
1121        MockCompletionModel::from_stream_turns([[
1122            MockStreamEvent::text("hello"),
1123            MockStreamEvent::text(" world"),
1124            MockStreamEvent::final_response_with_total_tokens(3),
1125        ]])
1126    }
1127
1128    fn streaming_final_only_model() -> MockCompletionModel {
1129        MockCompletionModel::from_stream_turns([[
1130            MockStreamEvent::final_response_with_total_tokens(1),
1131        ]])
1132    }
1133
1134    #[tokio::test]
1135    async fn stream_prompt_continues_after_tool_call_turn() {
1136        let model = streaming_tool_then_text_model();
1137        let recorded = model.clone();
1138        let agent = AgentBuilder::new(model).build();
1139        let empty_history: &[Message] = &[];
1140
1141        let mut stream = agent
1142            .stream_prompt("do tool work")
1143            .with_history(empty_history)
1144            .multi_turn(3)
1145            .await;
1146        let mut saw_tool_call = false;
1147        let mut saw_tool_result = false;
1148        let mut saw_final_response = false;
1149        let mut final_text = String::new();
1150        let mut final_response_text = None;
1151        let mut final_history = None;
1152
1153        while let Some(item) = stream.next().await {
1154            match item {
1155                Ok(MultiTurnStreamItem::StreamAssistantItem(
1156                    StreamedAssistantContent::ToolCall { .. },
1157                )) => {
1158                    saw_tool_call = true;
1159                }
1160                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
1161                    ..
1162                })) => {
1163                    saw_tool_result = true;
1164                }
1165                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1166                    text,
1167                ))) => {
1168                    final_text.push_str(&text.text);
1169                }
1170                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1171                    saw_final_response = true;
1172                    final_response_text = Some(res.response().to_owned());
1173                    final_history = res.history().map(|history| history.to_vec());
1174                    break;
1175                }
1176                Ok(_) => {}
1177                Err(err) => panic!("unexpected streaming error: {err:?}"),
1178            }
1179        }
1180
1181        assert!(saw_tool_call);
1182        assert!(saw_tool_result);
1183        assert!(saw_final_response);
1184        assert_eq!(final_text, "done");
1185        assert_eq!(final_response_text.as_deref(), Some("done"));
1186        let history = final_history.expect("expected final response history");
1187        assert!(history.iter().any(|message| matches!(
1188            message,
1189            Message::Assistant { content, .. }
1190                if content.iter().any(|item| matches!(
1191                    item,
1192                    AssistantContent::Text(text) if text.text == "done"
1193                ))
1194        )));
1195        let requests = recorded.requests();
1196        assert_eq!(requests.len(), 2);
1197        assert!(validate_follow_up_tool_history(&requests[1]).is_ok());
1198    }
1199
1200    #[tokio::test]
1201    async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
1202        let agent = AgentBuilder::new(streaming_text_then_final_model()).build();
1203
1204        let mut stream = agent.stream_prompt("say hello").await;
1205        let mut streamed_text = String::new();
1206        let mut final_response_text = None;
1207
1208        while let Some(item) = stream.next().await {
1209            match item {
1210                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1211                    text,
1212                ))) => streamed_text.push_str(&text.text),
1213                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1214                    final_response_text = Some(res.response().to_owned());
1215                    break;
1216                }
1217                Ok(_) => {}
1218                Err(err) => panic!("unexpected streaming error: {err:?}"),
1219            }
1220        }
1221
1222        assert_eq!(streamed_text, "hello world");
1223        assert_eq!(final_response_text.as_deref(), Some("hello world"));
1224    }
1225
1226    #[tokio::test]
1227    async fn final_response_can_remain_empty_for_truly_textless_turns() {
1228        let agent = AgentBuilder::new(streaming_final_only_model()).build();
1229
1230        let mut stream = agent.stream_prompt("say nothing").await;
1231        let mut streamed_text = String::new();
1232        let mut final_response_text = None;
1233
1234        while let Some(item) = stream.next().await {
1235            match item {
1236                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1237                    text,
1238                ))) => streamed_text.push_str(&text.text),
1239                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1240                    final_response_text = Some(res.response().to_owned());
1241                    break;
1242                }
1243                Ok(_) => {}
1244                Err(err) => panic!("unexpected streaming error: {err:?}"),
1245            }
1246        }
1247
1248        assert!(streamed_text.is_empty());
1249        assert_eq!(final_response_text.as_deref(), Some(""));
1250    }
1251
    /// Background task that logs periodically to detect span leakage.
    /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            // Check whether we're inside an unexpected span: an enabled,
            // non-empty current span here means another task's span leaked
            // onto this one.
            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    /// Test that span context doesn't leak to concurrent tasks during streaming.
    ///
    /// This test verifies that using `.instrument()` instead of `span.enter()` in
    /// async_stream prevents thread-local span context from leaking to other tasks.
    ///
    /// Uses a single-threaded runtime to force all tasks onto the same thread,
    /// making the span leak deterministic (it only occurs when tasks share a thread).
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() -> anyhow::Result<()> {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        // Start the background logger.
        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Small delay to let the background logger start.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Make the streaming request WITHOUT an outer span so rig creates its own
        // invoke_agent span (rig reuses the current span if one exists, so we must
        // ensure there is no current span here).
        let client = anthropic::Client::from_env()?;
        let agent = client
            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        // Stop the background logger.
        stop.store(true, Ordering::Relaxed);
        bg_handle.await?;

        let leaks = leak_count.load(Ordering::Relaxed);
        anyhow::ensure!(
            leaks == 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );

        Ok(())
    }

    /// Test that FinalResponse contains the updated chat history when with_history is used.
    ///
    /// This verifies that:
    /// 1. FinalResponse.history() returns Some when with_history was called
    /// 2. The history contains both the user prompt and assistant response
    #[tokio::test]
    #[ignore = "This requires an API key"]
    async fn test_chat_history_in_final_response() -> anyhow::Result<()> {
        use crate::message::Message;

        let client = anthropic::Client::from_env()?;
        let agent = client
            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
            .preamble("You are a helpful assistant. Keep responses brief.")
            .temperature(0.1)
            .max_tokens(50)
            .build();

        // Send streaming request with history
        let empty_history: &[Message] = &[];
        let mut stream = agent
            .stream_prompt("Say 'hello' and nothing else.")
            .with_history(empty_history)
            .await;

        // Consume the stream and collect FinalResponse
        let mut response_text = String::new();
        let mut final_history = None;
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    response_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    final_history = res.history().map(|h| h.to_vec());
                    break;
                }
                Err(e) => {
                    return Err(e.into());
                }
                _ => {}
            }
        }

        let history = final_history
            .ok_or_else(|| anyhow::anyhow!("final response should include history"))?;

        // Should contain at least the user message
        anyhow::ensure!(
            history.iter().any(|m| matches!(m, Message::User { .. })),
            "History should contain the user message"
        );

        // Should contain the assistant response
        anyhow::ensure!(
            history
                .iter()
                .any(|m| matches!(m, Message::Assistant { .. })),
            "History should contain the assistant response"
        );

        tracing::info!(
            "History after streaming: {} messages, response: {:?}",
            history.len(),
            response_text
        );

        Ok(())
    }

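    /// Verifies that a streamed conversation is appended to the configured
    /// memory once the final response is emitted, and that the same history is
    /// exposed via `FinalResponse.history()`.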
    #[tokio::test]
    async fn streaming_appends_to_memory_after_final_response() {
        use crate::memory::{ConversationMemory, InMemoryConversationMemory};

        let memory = InMemoryConversationMemory::new();
        let agent = AgentBuilder::new(streaming_text_then_final_model())
            .memory(memory.clone())
            .build();

        let mut stream = agent
            .stream_prompt("hi there")
            .conversation("stream-thread")
            .await;

        let mut history_in_final = None;
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    history_in_final = res.history().map(|h| h.to_vec());
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        let final_history = history_in_final
            .expect("FinalResponse.history should be populated when memory is configured");
        assert_eq!(
            final_history.len(),
            2,
            "user prompt + assistant response in final history: {final_history:?}"
        );

        let stored = memory.load("stream-thread").await.unwrap();
        assert_eq!(stored.len(), 2, "memory should contain user + assistant");
    }

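    /// Verifies that `with_history` takes precedence over configured memory:
    /// the caller-supplied history is used and nothing new is appended to the
    /// store, so only the pre-seeded entry remains.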
    #[tokio::test]
    async fn streaming_with_history_overrides_memory() {
        use crate::memory::{ConversationMemory, InMemoryConversationMemory};

        let memory = InMemoryConversationMemory::new();
        memory
            .append("t1", vec![Message::user("from-memory")])
            .await
            .unwrap();

        let agent = AgentBuilder::new(streaming_text_then_final_model())
            .memory(memory.clone())
            .build();

        let mut stream = agent
            .stream_prompt("hi")
            .conversation("t1")
            .with_history(vec![Message::user("from-caller")])
            .await;

        while let Some(item) = stream.next().await {
            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
                break;
            }
        }

        let stored = memory.load("t1").await.unwrap();
        assert_eq!(
            stored.len(),
            1,
            "with_history bypasses memory; only the pre-seeded entry remains: {stored:?}"
        );
    }

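    /// Verifies that `without_memory` disables persistence for a single
    /// request, leaving the agent's default conversation empty.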
    #[tokio::test]
    async fn streaming_without_memory_disables_for_request() {
        use crate::memory::{ConversationMemory, InMemoryConversationMemory};

        let memory = InMemoryConversationMemory::new();
        let agent = AgentBuilder::new(streaming_text_then_final_model())
            .memory(memory.clone())
            .conversation_id("default")
            .build();

        let mut stream = agent.stream_prompt("hi").without_memory().await;

        while let Some(item) = stream.next().await {
            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
                break;
            }
        }

        let stored = memory.load("default").await.unwrap();
        assert!(stored.is_empty(), "without_memory disables save");
    }

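    /// Verifies that an error while loading conversation memory is surfaced as
    /// the first item on the stream rather than being swallowed.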
    #[tokio::test]
    async fn streaming_load_error_yields_memory_error() {
        let agent = AgentBuilder::new(streaming_text_then_final_model())
            .memory(FailingMemory::default())
            .build();

        let mut stream = agent.stream_prompt("hi").conversation("t1").await;

        let first = stream.next().await.expect("at least one item");
        match first {
            Err(err) => {
                let msg = format!("{err:?}");
                assert!(
                    msg.contains("Memory") || msg.contains("memory") || msg.contains("load boom"),
                    "expected memory error, got: {msg}"
                );
            }
            Ok(other) => panic!("expected memory error, got {other:?}"),
        }
    }

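    /// Verifies that a memory filter (here a two-message window) shapes the
    /// history sent to the model: the request carries the truncated history
    /// plus the current prompt.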
    #[tokio::test]
    async fn streaming_with_filter_shapes_loaded_history() {
        use crate::memory::{ConversationMemory, InMemoryConversationMemory};

        let memory = InMemoryConversationMemory::new()
            .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
        memory
            .append(
                "t1",
                vec![
                    Message::user("1"),
                    Message::assistant("2"),
                    Message::user("3"),
                    Message::assistant("4"),
                ],
            )
            .await
            .unwrap();

        let model = MockCompletionModel::from_stream_turns([[
            MockStreamEvent::text("ok"),
            MockStreamEvent::final_response_with_total_tokens(1),
        ]]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model).memory(memory).build();

        let mut stream = agent.stream_prompt("ping").conversation("t1").await;
        while let Some(item) = stream.next().await {
            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
                break;
            }
        }

        let received = recorded.requests()[0]
            .chat_history
            .iter()
            .cloned()
            .collect::<Vec<_>>();
        assert_eq!(
            received.len(),
            3,
            "window-truncated history (2) + current prompt: {received:?}"
        );
    }

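    /// Verifies that a failing `memory.append` does not suppress the
    /// `FinalResponse`: the stream still completes normally.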
    #[tokio::test]
    async fn streaming_append_error_does_not_suppress_final_response() {
        let agent = AgentBuilder::new(streaming_text_then_final_model())
            .memory(AppendFailingMemory::default())
            .build();

        let mut stream = agent.stream_prompt("hi").conversation("t1").await;

        let mut saw_final = false;
        while let Some(item) = stream.next().await {
            if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
                saw_final = true;
                break;
            }
        }
        assert!(
            saw_final,
            "FinalResponse must be yielded even when memory.append fails"
        );
    }
}