//! Streaming prompt request implementation (`rig/agent/prompt_request/streaming.rs`):
//! drives multi-turn agent conversations as an async stream of assistant/user items.
1use crate::{
2    OneOrMany,
3    agent::completion::{DynamicContextStore, build_completion_request},
4    agent::prompt_request::{HookAction, hooks::PromptHook},
5    completion::{Document, GetTokenUsage},
6    json_utils,
7    message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
8    streaming::{StreamedAssistantContent, StreamedUserContent},
9    tool::server::ToolServerHandle,
10    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
11};
12use futures::{Stream, StreamExt};
13use serde::{Deserialize, Serialize};
14use std::{pin::Pin, sync::Arc};
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::ToolCallHookAction;
19use crate::{
20    agent::Agent,
21    completion::{CompletionError, CompletionModel, PromptError},
22    message::{Message, Text},
23    tool::ToolSetError,
24};
25
/// Boxed stream of multi-turn items produced by a streaming prompt request.
/// Native targets carry a `Send` bound so the stream can move across tasks/threads.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// Boxed stream of multi-turn items produced by a streaming prompt request.
/// The wasm variant drops the `Send` bound (wasm32 futures are not required to be `Send`).
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
33
/// One item yielded by the multi-turn prompt stream: incremental assistant
/// content, tool-result user content, or the terminal [`FinalResponse`].
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed assistant content item.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// A streamed user content item (mostly for tool results).
    StreamUserItem(StreamedUserContent),
    /// The final result from the stream.
    FinalResponse(FinalResponse),
}
45
/// Terminal payload of a multi-turn stream: the final response text, token
/// usage aggregated over all turns, and (optionally) the updated chat history.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// Full text of the assistant's final answer.
    response: String,
    /// Token usage accumulated across every turn of the conversation.
    aggregated_usage: crate::completion::Usage,
    /// Updated chat history; only populated when the caller supplied history,
    /// and omitted from serialization when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}
54
55impl FinalResponse {
56    pub fn empty() -> Self {
57        Self {
58            response: String::new(),
59            aggregated_usage: crate::completion::Usage::new(),
60            history: None,
61        }
62    }
63
64    pub fn response(&self) -> &str {
65        &self.response
66    }
67
68    pub fn usage(&self) -> crate::completion::Usage {
69        self.aggregated_usage
70    }
71
72    pub fn history(&self) -> Option<&[Message]> {
73        self.history.as_deref()
74    }
75}
76
77impl<R> MultiTurnStreamItem<R> {
78    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
79        Self::StreamAssistantItem(item)
80    }
81
82    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
83        Self::FinalResponse(FinalResponse {
84            response: response.to_string(),
85            aggregated_usage,
86            history: None,
87        })
88    }
89
90    pub fn final_response_with_history(
91        response: &str,
92        aggregated_usage: crate::completion::Usage,
93        history: Option<Vec<Message>>,
94    ) -> Self {
95        Self::FinalResponse(FinalResponse {
96            response: response.to_string(),
97            aggregated_usage,
98            history,
99        })
100    }
101}
102
103fn merge_reasoning_blocks(
104    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
105    incoming: &crate::message::Reasoning,
106) {
107    let ids_match = |existing: &crate::message::Reasoning| {
108        matches!(
109            (&existing.id, &incoming.id),
110            (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
111        )
112    };
113
114    if let Some(existing) = accumulated_reasoning
115        .iter_mut()
116        .rev()
117        .find(|existing| ids_match(existing))
118    {
119        existing.content.extend(incoming.content.clone());
120    } else {
121        accumulated_reasoning.push(incoming.clone());
122    }
123}
124
125/// Build full history for error reporting (input + new messages).
126fn build_full_history(
127    chat_history: Option<&[Message]>,
128    new_messages: Vec<Message>,
129) -> Vec<Message> {
130    let input = chat_history.unwrap_or(&[]);
131    input.iter().cloned().chain(new_messages).collect()
132}
133
134/// Combine input history with new messages for building completion requests.
135fn build_history_for_request(
136    chat_history: Option<&[Message]>,
137    new_messages: &[Message],
138) -> Vec<Message> {
139    let input = chat_history.unwrap_or(&[]);
140    input.iter().chain(new_messages.iter()).cloned().collect()
141}
142
143async fn cancelled_prompt_error(
144    chat_history: Option<&[Message]>,
145    new_messages: Vec<Message>,
146    reason: String,
147) -> StreamingError {
148    StreamingError::Prompt(
149        PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
150            .into(),
151    )
152}
153
154fn tool_result_to_user_message(
155    id: String,
156    call_id: Option<String>,
157    tool_result: String,
158) -> Message {
159    let content = OneOrMany::one(ToolResultContent::text(tool_result));
160    let user_content = match call_id {
161        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
162        None => UserContent::tool_result(id, content),
163    };
164
165    Message::User {
166        content: OneOrMany::one(user_content),
167    }
168}
169
/// Errors that can surface while driving a multi-turn prompt stream.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying model/completion call failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// The prompt loop failed (e.g. cancelled by a hook); boxed — construction
    /// sites in this file wrap `PromptError` (which carries chat history) in a `Box`.
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool invocation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
179
/// Fallback display name used in tracing spans/logs when the agent is unnamed.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
181
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history provided by the caller.
    /// When present, the final response also carries the updated history.
    chat_history: Option<Vec<Message>>,
    /// Maximum Turns for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging and tracing spans
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
    /// Optional per-request hook for events
    hook: Option<P>,
}
228
229impl<M, P> StreamingPromptRequest<M, P>
230where
231    M: CompletionModel + 'static,
232    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
233    P: PromptHook<M>,
234{
235    /// Create a new StreamingPromptRequest with the given prompt and model.
236    /// Note: This creates a request without an agent hook. Use `from_agent` to include the agent's hook.
237    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
238        StreamingPromptRequest {
239            prompt: prompt.into(),
240            chat_history: None,
241            max_turns: agent.default_max_turns.unwrap_or_default(),
242            model: agent.model.clone(),
243            agent_name: agent.name.clone(),
244            preamble: agent.preamble.clone(),
245            static_context: agent.static_context.clone(),
246            temperature: agent.temperature,
247            max_tokens: agent.max_tokens,
248            additional_params: agent.additional_params.clone(),
249            tool_server_handle: agent.tool_server_handle.clone(),
250            dynamic_context: agent.dynamic_context.clone(),
251            tool_choice: agent.tool_choice.clone(),
252            output_schema: agent.output_schema.clone(),
253            hook: None,
254        }
255    }
256
257    /// Create a new StreamingPromptRequest from an agent, cloning the agent's data and default hook.
258    pub fn from_agent<P2>(
259        agent: &Agent<M, P2>,
260        prompt: impl Into<Message>,
261    ) -> StreamingPromptRequest<M, P2>
262    where
263        P2: PromptHook<M>,
264    {
265        StreamingPromptRequest {
266            prompt: prompt.into(),
267            chat_history: None,
268            max_turns: agent.default_max_turns.unwrap_or_default(),
269            model: agent.model.clone(),
270            agent_name: agent.name.clone(),
271            preamble: agent.preamble.clone(),
272            static_context: agent.static_context.clone(),
273            temperature: agent.temperature,
274            max_tokens: agent.max_tokens,
275            additional_params: agent.additional_params.clone(),
276            tool_server_handle: agent.tool_server_handle.clone(),
277            dynamic_context: agent.dynamic_context.clone(),
278            tool_choice: agent.tool_choice.clone(),
279            output_schema: agent.output_schema.clone(),
280            hook: agent.hook.clone(),
281        }
282    }
283
284    fn agent_name(&self) -> &str {
285        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
286    }
287
288    /// Set the maximum Turns for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
289    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
290    pub fn multi_turn(mut self, turns: usize) -> Self {
291        self.max_turns = turns;
292        self
293    }
294
295    /// Add chat history to the prompt request.
296    ///
297    /// When history is provided, the final [`FinalResponse`] will include the
298    /// updated chat history (original messages + new user prompt + assistant response).
299    /// ```ignore
300    /// let mut stream = agent
301    ///     .stream_prompt("Hello")
302    ///     .with_history(vec![])
303    ///     .await;
304    /// // ... consume stream ...
305    /// // Access updated history from FinalResponse::history()
306    /// ```
307    pub fn with_history<I, T>(mut self, history: I) -> Self
308    where
309        I: IntoIterator<Item = T>,
310        T: Into<Message>,
311    {
312        self.chat_history = Some(history.into_iter().map(Into::into).collect());
313        self
314    }
315
316    /// Attach a per-request hook for tool call events.
317    /// This overrides any default hook set on the agent.
318    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
319    where
320        P2: PromptHook<M>,
321    {
322        StreamingPromptRequest {
323            prompt: self.prompt,
324            chat_history: self.chat_history,
325            max_turns: self.max_turns,
326            model: self.model,
327            agent_name: self.agent_name,
328            preamble: self.preamble,
329            static_context: self.static_context,
330            temperature: self.temperature,
331            max_tokens: self.max_tokens,
332            additional_params: self.additional_params,
333            tool_server_handle: self.tool_server_handle,
334            dynamic_context: self.dynamic_context,
335            tool_choice: self.tool_choice,
336            output_schema: self.output_schema,
337            hook: Some(hook),
338        }
339    }
340
341    async fn send(self) -> StreamingResult<M::StreamingResponse> {
342        let agent_span = if tracing::Span::current().is_disabled() {
343            info_span!(
344                "invoke_agent",
345                gen_ai.operation.name = "invoke_agent",
346                gen_ai.agent.name = self.agent_name(),
347                gen_ai.system_instructions = self.preamble,
348                gen_ai.prompt = tracing::field::Empty,
349                gen_ai.completion = tracing::field::Empty,
350                gen_ai.usage.input_tokens = tracing::field::Empty,
351                gen_ai.usage.output_tokens = tracing::field::Empty,
352                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
353                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
354            )
355        } else {
356            tracing::Span::current()
357        };
358
359        let prompt = self.prompt;
360        if let Some(text) = prompt.rag_text() {
361            agent_span.record("gen_ai.prompt", text);
362        }
363
364        // Clone fields needed inside the stream
365        let model = self.model.clone();
366        let preamble = self.preamble.clone();
367        let static_context = self.static_context.clone();
368        let temperature = self.temperature;
369        let max_tokens = self.max_tokens;
370        let additional_params = self.additional_params.clone();
371        let tool_server_handle = self.tool_server_handle.clone();
372        let dynamic_context = self.dynamic_context.clone();
373        let tool_choice = self.tool_choice.clone();
374        let agent_name = self.agent_name.clone();
375        let has_history = self.chat_history.is_some();
376        let chat_history = self.chat_history;
377        let mut new_messages: Vec<Message> = vec![prompt.clone()];
378
379        let mut current_max_turns = 0;
380        let mut last_prompt_error = String::new();
381
382        let mut last_text_response = String::new();
383        let mut is_text_response = false;
384        let mut max_turns_reached = false;
385        let output_schema = self.output_schema;
386
387        let mut aggregated_usage = crate::completion::Usage::new();
388
389        // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
390        // span context leaking to other concurrent tasks. Using span.enter() inside
391        // async_stream::stream! holds the guard across yield points, which causes
392        // thread-local span context to leak when other tasks run on the same thread.
393        // See: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#in-asynchronous-code
394        // See also: https://github.com/rust-lang/rust-clippy/issues/8722
395        let stream = async_stream::stream! {
396            let mut current_prompt = prompt.clone();
397
398            'outer: loop {
399                if current_max_turns > self.max_turns + 1 {
400                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
401                    max_turns_reached = true;
402                    break;
403                }
404
405                current_max_turns += 1;
406
407                if self.max_turns > 1 {
408                    tracing::info!(
409                        "Current conversation Turns: {}/{}",
410                        current_max_turns,
411                        self.max_turns
412                    );
413                }
414
415                if let Some(ref hook) = self.hook {
416                    let history_snapshot: Vec<Message> = build_history_for_request(chat_history.as_deref(), &new_messages[..new_messages.len().saturating_sub(1)]);
417                    if let HookAction::Terminate { reason } = hook.on_completion_call(&current_prompt, &history_snapshot)
418                        .await {
419                        yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
420                        break 'outer;
421                    }
422                }
423
424                let chat_stream_span = info_span!(
425                    target: "rig::agent_chat",
426                    parent: tracing::Span::current(),
427                    "chat_streaming",
428                    gen_ai.operation.name = "chat",
429                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
430                    gen_ai.system_instructions = preamble,
431                    gen_ai.provider.name = tracing::field::Empty,
432                    gen_ai.request.model = tracing::field::Empty,
433                    gen_ai.response.id = tracing::field::Empty,
434                    gen_ai.response.model = tracing::field::Empty,
435                    gen_ai.usage.output_tokens = tracing::field::Empty,
436                    gen_ai.usage.input_tokens = tracing::field::Empty,
437                    gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
438                    gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
439                    gen_ai.input.messages = tracing::field::Empty,
440                    gen_ai.output.messages = tracing::field::Empty,
441                );
442
443                let history_snapshot: Vec<Message> = build_history_for_request(chat_history.as_deref(), &new_messages[..new_messages.len().saturating_sub(1)]);
444                let mut stream = tracing::Instrument::instrument(
445                    build_completion_request(
446                        &model,
447                        current_prompt.clone(),
448                        &history_snapshot,
449                        preamble.as_deref(),
450                        &static_context,
451                        temperature,
452                        max_tokens,
453                        additional_params.as_ref(),
454                        tool_choice.as_ref(),
455                        &tool_server_handle,
456                        &dynamic_context,
457                        output_schema.as_ref(),
458                    )
459                    .await?
460                    .stream(), chat_stream_span
461                )
462
463                .await?;
464
465                new_messages.push(current_prompt.clone());
466
467                let mut tool_calls = vec![];
468                let mut tool_results = vec![];
469                let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
470                // Kept separate from accumulated_reasoning so providers requiring
471                // signatures (e.g. Anthropic) never see unsigned blocks.
472                let mut pending_reasoning_delta_text = String::new();
473                let mut pending_reasoning_delta_id: Option<String> = None;
474                let mut saw_tool_call_this_turn = false;
475
476                while let Some(content) = stream.next().await {
477                    match content {
478                        Ok(StreamedAssistantContent::Text(text)) => {
479                            if !is_text_response {
480                                last_text_response = String::new();
481                                is_text_response = true;
482                            }
483                            last_text_response.push_str(&text.text);
484                            if let Some(ref hook) = self.hook &&
485                                let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &last_text_response).await {
486                                    yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
487                                    break 'outer;
488                            }
489
490                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
491                        },
492                        Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
493                            let tool_span = info_span!(
494                                parent: tracing::Span::current(),
495                                "execute_tool",
496                                gen_ai.operation.name = "execute_tool",
497                                gen_ai.tool.type = "function",
498                                gen_ai.tool.name = tracing::field::Empty,
499                                gen_ai.tool.call.id = tracing::field::Empty,
500                                gen_ai.tool.call.arguments = tracing::field::Empty,
501                                gen_ai.tool.call.result = tracing::field::Empty
502                            );
503
504                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));
505
506                            let tc_result = async {
507                                let tool_span = tracing::Span::current();
508                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
509                                if let Some(ref hook) = self.hook {
510                                    let action = hook
511                                        .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
512                                        .await;
513
514                                    if let ToolCallHookAction::Terminate { reason } = action {
515                                        return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
516                                    }
517
518                                    if let ToolCallHookAction::Skip { reason } = action {
519                                        // Tool execution rejected, return rejection message as tool result
520                                        tracing::info!(
521                                            tool_name = tool_call.function.name.as_str(),
522                                            reason = reason,
523                                            "Tool call rejected"
524                                        );
525                                        let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
526                                        tool_calls.push(tool_call_msg);
527                                        tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
528                                        saw_tool_call_this_turn = true;
529                                        return Ok(reason);
530                                    }
531                                }
532
533                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
534                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);
535
536                                let tool_result = match
537                                tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
538                                    Ok(thing) => thing,
539                                    Err(e) => {
540                                        tracing::warn!("Error while calling tool: {e}");
541                                        e.to_string()
542                                    }
543                                };
544
545                                tool_span.record("gen_ai.tool.call.result", &tool_result);
546
547                                if let Some(ref hook) = self.hook &&
548                                    let HookAction::Terminate { reason } =
549                                    hook.on_tool_result(
550                                        &tool_call.function.name,
551                                        tool_call.call_id.clone(),
552                                        &internal_call_id,
553                                        &tool_args,
554                                        &tool_result.to_string()
555                                    )
556                                    .await {
557                                        return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
558                                    }
559
560                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
561
562                                tool_calls.push(tool_call_msg);
563                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));
564
565                                saw_tool_call_this_turn = true;
566                                Ok(tool_result)
567                            }.instrument(tool_span).await;
568
569                            match tc_result {
570                                Ok(text) => {
571                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
572                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
573                                }
574                                Err(e) => {
575                                    yield Err(e);
576                                    break 'outer;
577                                }
578                            }
579                        },
580                        Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
581                            if let Some(ref hook) = self.hook {
582                                let (name, delta) = match &content {
583                                    rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
584                                    rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
585                                };
586
587                                if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
588                                .await {
589                                    yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
590                                    break 'outer;
591                                }
592                            }
593                        }
594                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
595                            // Accumulate reasoning for inclusion in chat history with tool calls.
596                            // OpenAI Responses API requires reasoning items to be sent back
597                            // alongside function_call items in multi-turn conversations.
598                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
599                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
600                        },
601                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
602                            // Deltas lack signatures/encrypted content that full
603                            // blocks carry; mixing them into accumulated_reasoning
604                            // causes Anthropic to reject with "signature required".
605                            pending_reasoning_delta_text.push_str(&reasoning);
606                            if pending_reasoning_delta_id.is_none() {
607                                pending_reasoning_delta_id = id.clone();
608                            }
609                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
610                        },
611                        Ok(StreamedAssistantContent::Final(final_resp)) => {
612                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
613                            if is_text_response {
614                                if let Some(ref hook) = self.hook &&
615                                     let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&prompt, &final_resp).await {
616                                        yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
617                                        break 'outer;
618                                    }
619
620                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
621                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
622                                is_text_response = false;
623                            }
624                        }
625                        Err(e) => {
626                            yield Err(e.into());
627                            break 'outer;
628                        }
629                    }
630                }
631
632                // Providers like Gemini emit thinking as incremental deltas
633                // without signatures; assemble into a single block so
634                // reasoning survives into the next turn's chat history.
635                if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
636                    let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
637                    if let Some(id) = pending_reasoning_delta_id.take() {
638                        assembled = assembled.with_id(id);
639                    }
640                    accumulated_reasoning.push(assembled);
641                }
642
643                // Add text, reasoning, and tool calls to chat history.
644                // OpenAI Responses API requires reasoning items to precede function_call items.
645                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
646                    let mut content_items: Vec<rig::message::AssistantContent> = vec![];
647
648                    // Text before tool calls so the model sees its own prior output
649                    if !last_text_response.is_empty() {
650                        content_items.push(rig::message::AssistantContent::text(&last_text_response));
651                        last_text_response.clear();
652                    }
653
654                    // Reasoning must come before tool calls (OpenAI requirement)
655                    for reasoning in accumulated_reasoning.drain(..) {
656                        content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
657                    }
658
659                    content_items.extend(tool_calls.clone());
660
661                    if !content_items.is_empty() {
662                        new_messages.push(Message::Assistant {
663                            id: stream.message_id.clone(),
664                            content: OneOrMany::many(content_items).expect("Should have at least one item"),
665                        });
666                    }
667                }
668
669                for (id, call_id, tool_result) in tool_results {
670                    new_messages.push(tool_result_to_user_message(id, call_id, tool_result));
671                }
672
673                // Set the current prompt to the last message in new_messages
674                current_prompt = match new_messages.pop() {
675                    Some(prompt) => prompt,
676                    None => unreachable!("New messages should never be empty at this point"),
677                };
678
679                if !saw_tool_call_this_turn {
680                    // Add user message and assistant response to history before finishing
681                    new_messages.push(current_prompt.clone());
682                    if !last_text_response.is_empty() {
683                        new_messages.push(Message::assistant(&last_text_response));
684                    }
685
686                    let current_span = tracing::Span::current();
687                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
688                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
689                    current_span.record("gen_ai.usage.cache_read.input_tokens", aggregated_usage.cached_input_tokens);
690                    current_span.record("gen_ai.usage.cache_creation.input_tokens", aggregated_usage.cache_creation_input_tokens);
691                    tracing::info!("Agent multi-turn stream finished");
692                    let final_messages: Option<Vec<Message>> = if has_history {
693                        Some(new_messages.clone())
694                    } else {
695                        None
696                    };
697                    yield Ok(MultiTurnStreamItem::final_response_with_history(
698                        &last_text_response,
699                        aggregated_usage,
700                        final_messages,
701                    ));
702                    break;
703                }
704            }
705
706            if max_turns_reached {
707                yield Err(Box::new(PromptError::MaxTurnsError {
708                    max_turns: self.max_turns,
709                    chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
710                    prompt: Box::new(last_prompt_error.clone().into()),
711                }).into());
712            }
713        };
714
715        Box::pin(stream.instrument(agent_span))
716    }
717}
718
719impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
720where
721    M: CompletionModel + 'static,
722    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
723    P: PromptHook<M> + 'static,
724{
725    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
726    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
727
728    fn into_future(self) -> Self::IntoFuture {
729        // Wrap send() in a future, because send() returns a stream immediately
730        Box::pin(async move { self.send().await })
731    }
732}
733
734/// Helper function to stream a completion request to stdout.
735pub async fn stream_to_stdout<R>(
736    stream: &mut StreamingResult<R>,
737) -> Result<FinalResponse, std::io::Error> {
738    let mut final_res = FinalResponse::empty();
739    print!("Response: ");
740    while let Some(content) = stream.next().await {
741        match content {
742            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
743                Text { text },
744            ))) => {
745                print!("{text}");
746                std::io::Write::flush(&mut std::io::stdout()).unwrap();
747            }
748            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
749                reasoning,
750            ))) => {
751                let reasoning = reasoning.display_text();
752                print!("{reasoning}");
753                std::io::Write::flush(&mut std::io::stdout()).unwrap();
754            }
755            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
756                final_res = res;
757            }
758            Err(err) => {
759                eprintln!("Error: {err}");
760            }
761            _ => {}
762        }
763    }
764
765    Ok(final_res)
766}
767
768#[cfg(test)]
769mod tests {
770    use super::*;
771    use crate::agent::AgentBuilder;
772    use crate::client::ProviderClient;
773    use crate::client::completion::CompletionClient;
774    use crate::completion::{
775        CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
776    };
777    use crate::message::ReasoningContent;
778    use crate::providers::anthropic;
779    use crate::streaming::StreamingPrompt;
780    use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
781    use futures::StreamExt;
782    use serde::{Deserialize, Serialize};
783    use std::sync::Arc;
784    use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
785    use std::time::Duration;
786
    #[test]
    fn merge_reasoning_blocks_preserves_order_and_signatures() {
        // Two reasoning deltas sharing the same provider id ("rs_1") must be
        // merged into a single accumulated item, preserving both the arrival
        // order of the content blocks and each block's signature.
        let mut accumulated = Vec::new();
        let first = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: Some("sig-1".to_string()),
            }],
        };
        let second = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![
                ReasoningContent::Text {
                    text: "step-2".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                ReasoningContent::Summary("summary".to_string()),
            ],
        };

        merge_reasoning_blocks(&mut accumulated, &first);
        merge_reasoning_blocks(&mut accumulated, &second);

        // One merged item: "step-1", "step-2", then the summary (3 blocks).
        assert_eq!(accumulated.len(), 1);
        let merged = accumulated.first().expect("expected accumulated reasoning");
        assert_eq!(merged.id.as_deref(), Some("rs_1"));
        assert_eq!(merged.content.len(), 3);
        assert!(matches!(
            merged.content.first(),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-1" && sig == "sig-1"
        ));
        assert!(matches!(
            merged.content.get(1),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-2" && sig == "sig-2"
        ));
    }
826
    #[test]
    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
        // Reasoning items with different provider ids ("rs_a" vs "rs_b") must
        // NOT be merged: each keeps its own entry, in insertion order.
        let mut accumulated = vec![crate::message::Reasoning {
            id: Some("rs_a".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: Some("rs_b".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-2".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert_eq!(
            accumulated.first().and_then(|r| r.id.as_deref()),
            Some("rs_a")
        );
        assert_eq!(
            accumulated.get(1).and_then(|r| r.id.as_deref()),
            Some("rs_b")
        );
    }
855
    #[test]
    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
        // Items without a provider id cannot be correlated, so two `id: None`
        // reasoning items must stay separate rather than being merged.
        let mut accumulated = vec![crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "first".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "second".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert!(matches!(
            accumulated.first(),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "first"
            )
        ));
        assert!(matches!(
            accumulated.get(1),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "second"
            )
        ));
    }
896
    /// Minimal streaming-response payload for tests: carries only token usage.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct MockStreamingResponse {
        // Usage reported to the aggregator via `GetTokenUsage`.
        usage: crate::completion::Usage,
    }
901
902    impl MockStreamingResponse {
903        fn new(total_tokens: u64) -> Self {
904            let mut usage = crate::completion::Usage::new();
905            usage.total_tokens = total_tokens;
906            Self { usage }
907        }
908    }
909
    // Expose the stored usage so the streaming loop can aggregate token counts.
    impl crate::completion::GetTokenUsage for MockStreamingResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            Some(self.usage)
        }
    }
915
    /// Mock completion model that changes behavior by turn: see its
    /// `CompletionModel::stream` impl. The shared counter records how many
    /// times `stream` was invoked so tests can assert on the turn count.
    #[derive(Clone, Default)]
    struct MultiTurnMockModel {
        // Incremented once per `stream()` call; shared with the test body.
        turn_counter: Arc<AtomicUsize>,
    }
920
    #[allow(refining_impl_trait)]
    impl CompletionModel for MultiTurnMockModel {
        type Response = ();
        type StreamingResponse = MockStreamingResponse;
        type Client = ();

        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self::default()
        }

        // Non-streaming completion is not exercised by these tests; return a
        // provider error so accidental use fails loudly.
        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "completion is unused in this streaming test".to_string(),
            ))
        }

        // Turn 0: emit a tool call (to a tool that doesn't exist, so the agent
        // produces a tool-error result) followed by a final response.
        // Later turns: emit plain text "done" and finish.
        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            // fetch_add returns the pre-increment value: 0 on the first call.
            let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
            let stream = async_stream::stream! {
                if turn == 0 {
                    yield Ok(RawStreamingChoice::ToolCall(
                        RawStreamingToolCall::new(
                            "tool_call_1".to_string(),
                            "missing_tool".to_string(),
                            serde_json::json!({"input": "value"}),
                        )
                        .with_call_id("call_1".to_string()),
                    ));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(4)));
                } else {
                    yield Ok(RawStreamingChoice::Message("done".to_string()));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(6)));
                }
            };

            let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
                Box::pin(stream);
            Ok(StreamingCompletionResponse::stream(pinned_stream))
        }
    }
967
    /// After a turn that yields a tool call (and its tool result), the agent
    /// must start another turn instead of terminating the stream. The mock
    /// model emits a tool call on turn 0 and plain text on turn 1.
    #[tokio::test]
    async fn stream_prompt_continues_after_tool_call_turn() {
        let model = MultiTurnMockModel::default();
        let turn_counter = model.turn_counter.clone();
        let agent = AgentBuilder::new(model).build();

        let mut stream = agent.stream_prompt("do tool work").multi_turn(3).await;
        let mut saw_tool_call = false;
        let mut saw_tool_result = false;
        let mut saw_final_response = false;
        let mut final_text = String::new();

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall { .. },
                )) => {
                    saw_tool_call = true;
                }
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    ..
                })) => {
                    saw_tool_result = true;
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    final_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    saw_final_response = true;
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        // All three stream phases must appear, the second turn's text must be
        // surfaced, and the model must have been called exactly twice.
        assert!(saw_tool_call);
        assert!(saw_tool_result);
        assert!(saw_final_response);
        assert_eq!(final_text, "done");
        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
    }
1012
1013    /// Background task that logs periodically to detect span leakage.
1014    /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
1015    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
1016        let mut interval = tokio::time::interval(Duration::from_millis(50));
1017        let mut count = 0u32;
1018
1019        while !stop.load(Ordering::Relaxed) {
1020            interval.tick().await;
1021            count += 1;
1022
1023            tracing::event!(
1024                target: "background_logger",
1025                tracing::Level::INFO,
1026                count = count,
1027                "Background tick"
1028            );
1029
1030            // Check if we're inside an unexpected span
1031            let current = tracing::Span::current();
1032            if !current.is_disabled() && !current.is_none() {
1033                leak_count.fetch_add(1, Ordering::Relaxed);
1034            }
1035        }
1036
1037        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
1038    }
1039
    /// Test that span context doesn't leak to concurrent tasks during streaming.
    ///
    /// This test verifies that using `.instrument()` instead of `span.enter()` in
    /// async_stream prevents thread-local span context from leaking to other tasks.
    ///
    /// Uses single-threaded runtime to force all tasks onto the same thread,
    /// making the span leak deterministic (it only occurs when tasks share a thread).
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        // Start background logger
        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Small delay to let background logger start
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Make streaming request WITHOUT an outer span so rig creates its own invoke_agent span
        // (rig reuses current span if one exists, so we need to ensure there's no current span)
        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        // Drain the stream; while it is polled on this (single) thread, the
        // background task keeps checking whether a span leaked onto it.
        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        // Stop background logger
        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
1109
    /// Test that FinalResponse contains the updated chat history when with_history is used.
    ///
    /// This verifies that:
    /// 1. FinalResponse.history() returns Some when with_history was called
    /// 2. The history contains both the user prompt and assistant response
    #[tokio::test]
    #[ignore = "This requires an API key"]
    async fn test_chat_history_in_final_response() {
        use crate::message::Message;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant. Keep responses brief.")
            .temperature(0.1)
            .max_tokens(50)
            .build();

        // Send streaming request with history
        // (empty slice is enough to opt in to history tracking)
        let empty_history: &[Message] = &[];
        let mut stream = agent
            .stream_prompt("Say 'hello' and nothing else.")
            .with_history(empty_history)
            .await;

        // Consume the stream and collect FinalResponse
        let mut response_text = String::new();
        let mut final_history = None;
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    response_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    final_history = res.history().map(|h| h.to_vec());
                    break;
                }
                Err(e) => {
                    panic!("Streaming error: {:?}", e);
                }
                _ => {}
            }
        }

        let history =
            final_history.expect("FinalResponse should contain history when with_history is used");

        // Should contain at least the user message
        assert!(
            history.iter().any(|m| matches!(m, Message::User { .. })),
            "History should contain the user message"
        );

        // Should contain the assistant response
        assert!(
            history
                .iter()
                .any(|m| matches!(m, Message::Assistant { .. })),
            "History should contain the assistant response"
        );

        tracing::info!(
            "History after streaming: {} messages, response: {:?}",
            history.len(),
            response_text
        );
    }
1179}