// rig/agent/prompt_request/streaming.rs

1use crate::{
2    OneOrMany,
3    agent::completion::{DynamicContextStore, build_completion_request},
4    agent::prompt_request::{HookAction, hooks::PromptHook},
5    completion::{Document, GetTokenUsage},
6    json_utils,
7    message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
8    streaming::{StreamedAssistantContent, StreamedUserContent},
9    tool::server::ToolServerHandle,
10    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
11};
12use futures::{Stream, StreamExt};
13use serde::{Deserialize, Serialize};
14use std::{pin::Pin, sync::Arc};
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::ToolCallHookAction;
19use crate::{
20    agent::Agent,
21    completion::{CompletionError, CompletionModel, PromptError},
22    message::{Message, Text},
23    tool::ToolSetError,
24};
25
/// Boxed stream of multi-turn items (or errors) produced by a streaming
/// prompt request. On native targets the stream must be `Send` so it can be
/// polled from any thread.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// Boxed stream of multi-turn items (or errors) produced by a streaming
/// prompt request. No `Send` bound on wasm32, which is single-threaded.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
33
/// A single item yielded by the multi-turn prompt stream: streamed assistant
/// content, streamed user content (tool results), or the terminal
/// [`FinalResponse`].
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed assistant content item.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// A streamed user content item (mostly for tool results).
    StreamUserItem(StreamedUserContent),
    /// The final result from the stream.
    FinalResponse(FinalResponse),
}
45
/// Terminal item of a prompt stream: the final text response, token usage
/// aggregated across every turn, and (when the request was built with
/// `with_history`) the updated chat history.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    // Full text of the model's final (non-tool) reply.
    response: String,
    // Token usage summed over all turns of the conversation.
    aggregated_usage: crate::completion::Usage,
    // Updated history; `None` unless the request supplied history
    // via `with_history`, so serialization omits it in that case.
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}
54
55impl FinalResponse {
56    pub fn empty() -> Self {
57        Self {
58            response: String::new(),
59            aggregated_usage: crate::completion::Usage::new(),
60            history: None,
61        }
62    }
63
64    pub fn response(&self) -> &str {
65        &self.response
66    }
67
68    pub fn usage(&self) -> crate::completion::Usage {
69        self.aggregated_usage
70    }
71
72    pub fn history(&self) -> Option<&[Message]> {
73        self.history.as_deref()
74    }
75}
76
77impl<R> MultiTurnStreamItem<R> {
78    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
79        Self::StreamAssistantItem(item)
80    }
81
82    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
83        Self::FinalResponse(FinalResponse {
84            response: response.to_string(),
85            aggregated_usage,
86            history: None,
87        })
88    }
89
90    pub fn final_response_with_history(
91        response: &str,
92        aggregated_usage: crate::completion::Usage,
93        history: Option<Vec<Message>>,
94    ) -> Self {
95        Self::FinalResponse(FinalResponse {
96            response: response.to_string(),
97            aggregated_usage,
98            history,
99        })
100    }
101}
102
103fn merge_reasoning_blocks(
104    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
105    incoming: &crate::message::Reasoning,
106) {
107    let ids_match = |existing: &crate::message::Reasoning| {
108        matches!(
109            (&existing.id, &incoming.id),
110            (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
111        )
112    };
113
114    if let Some(existing) = accumulated_reasoning
115        .iter_mut()
116        .rev()
117        .find(|existing| ids_match(existing))
118    {
119        existing.content.extend(incoming.content.clone());
120    } else {
121        accumulated_reasoning.push(incoming.clone());
122    }
123}
124
125async fn cancelled_prompt_error(chat_history: &Vec<Message>, reason: String) -> StreamingError {
126    StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.to_owned(), reason).into())
127}
128
129fn tool_result_to_user_message(
130    id: String,
131    call_id: Option<String>,
132    tool_result: String,
133) -> Message {
134    let content = OneOrMany::one(ToolResultContent::text(tool_result));
135    let user_content = match call_id {
136        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
137        None => UserContent::tool_result(id, content),
138    };
139
140    Message::User {
141        content: OneOrMany::one(user_content),
142    }
143}
144
/// Errors that can be yielded by a [`StreamingResult`].
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying completion request failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// Prompt-level failure (hook cancellation, max-turns exceeded).
    /// Boxed since `PromptError` can carry the full chat history.
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool invocation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
154
/// Fallback display name used in tracing spans when the agent has no name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
156
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
/// attempting to await (which will send the prompt request) can potentially return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt.
    /// When set, the stream's [`FinalResponse`] carries the updated history.
    chat_history: Option<Vec<Message>>,
    /// Maximum Turns for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
    /// Optional per-request hook for events
    hook: Option<P>,
}
203
204impl<M, P> StreamingPromptRequest<M, P>
205where
206    M: CompletionModel + 'static,
207    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
208    P: PromptHook<M>,
209{
210    /// Create a new StreamingPromptRequest with the given prompt and model.
211    /// Note: This creates a request without an agent hook. Use `from_agent` to include the agent's hook.
212    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
213        StreamingPromptRequest {
214            prompt: prompt.into(),
215            chat_history: None,
216            max_turns: agent.default_max_turns.unwrap_or_default(),
217            model: agent.model.clone(),
218            agent_name: agent.name.clone(),
219            preamble: agent.preamble.clone(),
220            static_context: agent.static_context.clone(),
221            temperature: agent.temperature,
222            max_tokens: agent.max_tokens,
223            additional_params: agent.additional_params.clone(),
224            tool_server_handle: agent.tool_server_handle.clone(),
225            dynamic_context: agent.dynamic_context.clone(),
226            tool_choice: agent.tool_choice.clone(),
227            output_schema: agent.output_schema.clone(),
228            hook: None,
229        }
230    }
231
232    /// Create a new StreamingPromptRequest from an agent, cloning the agent's data and default hook.
233    pub fn from_agent<P2>(
234        agent: &Agent<M, P2>,
235        prompt: impl Into<Message>,
236    ) -> StreamingPromptRequest<M, P2>
237    where
238        P2: PromptHook<M>,
239    {
240        StreamingPromptRequest {
241            prompt: prompt.into(),
242            chat_history: None,
243            max_turns: agent.default_max_turns.unwrap_or_default(),
244            model: agent.model.clone(),
245            agent_name: agent.name.clone(),
246            preamble: agent.preamble.clone(),
247            static_context: agent.static_context.clone(),
248            temperature: agent.temperature,
249            max_tokens: agent.max_tokens,
250            additional_params: agent.additional_params.clone(),
251            tool_server_handle: agent.tool_server_handle.clone(),
252            dynamic_context: agent.dynamic_context.clone(),
253            tool_choice: agent.tool_choice.clone(),
254            output_schema: agent.output_schema.clone(),
255            hook: agent.hook.clone(),
256        }
257    }
258
259    fn agent_name(&self) -> &str {
260        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
261    }
262
263    /// Set the maximum Turns for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
264    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
265    pub fn multi_turn(mut self, turns: usize) -> Self {
266        self.max_turns = turns;
267        self
268    }
269
270    /// Add chat history to the prompt request.
271    ///
272    /// When history is provided, the final [`FinalResponse`] will include the
273    /// updated chat history (original messages + new user prompt + assistant response).
274    /// ```ignore
275    /// let mut stream = agent
276    ///     .stream_prompt("Hello")
277    ///     .with_history(vec![])
278    ///     .await;
279    /// // ... consume stream ...
280    /// // Access updated history from FinalResponse::history()
281    /// ```
282    pub fn with_history(mut self, history: Vec<Message>) -> Self {
283        self.chat_history = Some(history);
284        self
285    }
286
287    /// Attach a per-request hook for tool call events.
288    /// This overrides any default hook set on the agent.
289    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
290    where
291        P2: PromptHook<M>,
292    {
293        StreamingPromptRequest {
294            prompt: self.prompt,
295            chat_history: self.chat_history,
296            max_turns: self.max_turns,
297            model: self.model,
298            agent_name: self.agent_name,
299            preamble: self.preamble,
300            static_context: self.static_context,
301            temperature: self.temperature,
302            max_tokens: self.max_tokens,
303            additional_params: self.additional_params,
304            tool_server_handle: self.tool_server_handle,
305            dynamic_context: self.dynamic_context,
306            tool_choice: self.tool_choice,
307            output_schema: self.output_schema,
308            hook: Some(hook),
309        }
310    }
311
    /// Execute the request and return the multi-turn item stream.
    ///
    /// Builds one completion request per turn, forwards assistant content to
    /// the consumer as it arrives, executes tool calls between turns, and
    /// finishes with a [`FinalResponse`] — or an error if a hook cancels the
    /// run, the provider fails, or the turn budget is exceeded.
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        // Reuse the caller's active span when one exists; otherwise open a
        // fresh `invoke_agent` span so the run is always traced.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // Clone fields needed inside the stream
        let model = self.model.clone();
        let preamble = self.preamble.clone();
        let static_context = self.static_context.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let additional_params = self.additional_params.clone();
        let tool_server_handle = self.tool_server_handle.clone();
        let dynamic_context = self.dynamic_context.clone();
        let tool_choice = self.tool_choice.clone();
        let agent_name = self.agent_name.clone();
        // Only return an updated history in the final response when the
        // caller explicitly supplied one via `with_history`.
        let has_history = self.chat_history.is_some();
        let mut chat_history = self.chat_history.unwrap_or_default();

        // Number of completion calls made so far (one per loop iteration).
        let mut current_max_turns = 0;
        // RAG text of the prompt in flight when the turn budget ran out.
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        // True while accumulating a text (non-tool) portion of a response.
        let mut is_text_response = false;
        let mut max_turns_reached = false;
        let output_schema = self.output_schema;

        // Token usage summed across every turn of the conversation.
        let mut aggregated_usage = crate::completion::Usage::new();

        // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
        // span context leaking to other concurrent tasks. Using span.enter() inside
        // async_stream::stream! holds the guard across yield points, which causes
        // thread-local span context to leak when other tasks run on the same thread.
        // See: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#in-asynchronous-code
        // See also: https://github.com/rust-lang/rust-clippy/issues/8722
        let stream = async_stream::stream! {
            let mut current_prompt = prompt.clone();

            'outer: loop {
                // `+ 1` grants one iteration beyond the tool budget so the
                // model can still produce a final text answer after its last
                // permitted tool round-trip.
                if current_max_turns > self.max_turns + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_turns_reached = true;
                    break;
                }

                current_max_turns += 1;

                if self.max_turns > 1 {
                    tracing::info!(
                        "Current conversation Turns: {}/{}",
                        current_max_turns,
                        self.max_turns
                    );
                }

                // Give the hook a chance to veto this completion call.
                if let Some(ref hook) = self.hook {
                    let history_snapshot = chat_history.clone();
                    if let HookAction::Terminate { reason } = hook.on_completion_call(&current_prompt, &history_snapshot)
                        .await {
                        yield Err(cancelled_prompt_error(&chat_history, reason).await);
                        break 'outer;
                    }
                }

                // Per-turn span; most fields are recorded later by the
                // provider implementation.
                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                    gen_ai.system_instructions = preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.usage.cached_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                // Build this turn's completion request and open its provider
                // stream, instrumented with the per-turn chat span.
                let history_snapshot = chat_history.clone();
                let mut stream = tracing::Instrument::instrument(
                    build_completion_request(
                        &model,
                        current_prompt.clone(),
                        history_snapshot,
                        preamble.as_deref(),
                        &static_context,
                        temperature,
                        max_tokens,
                        additional_params.as_ref(),
                        tool_choice.as_ref(),
                        &tool_server_handle,
                        &dynamic_context,
                        output_schema.as_ref(),
                    )
                    .await?
                    .stream(), chat_stream_span
                )
                .await?;

                chat_history.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];
                let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
                // Kept separate from accumulated_reasoning so providers requiring
                // signatures (e.g. Anthropic) never see unsigned blocks.
                let mut pending_reasoning_delta_text = String::new();
                let mut pending_reasoning_delta_id: Option<String> = None;
                let mut saw_tool_call_this_turn = false;

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            // Start a fresh accumulation on the first text item
                            // of a response.
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook &&
                                let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &last_text_response).await {
                                    yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                    break 'outer;
                            }

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                        },
                        Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            // Surface the tool call to the consumer before executing it.
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    let action = hook
                                        .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
                                        .await;

                                    if let ToolCallHookAction::Terminate { reason } = action {
                                        return Err(cancelled_prompt_error(&chat_history, reason).await);
                                    }

                                    if let ToolCallHookAction::Skip { reason } = action {
                                        // Tool execution rejected, return rejection message as tool result
                                        tracing::info!(
                                            tool_name = tool_call.function.name.as_str(),
                                            reason = reason,
                                            "Tool call rejected"
                                        );
                                        let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
                                        tool_calls.push(tool_call_msg);
                                        tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
                                        saw_tool_call_this_turn = true;
                                        return Ok(reason);
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                // Tool failures are fed back to the model as the
                                // result text rather than aborting the stream.
                                let tool_result = match
                                tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook &&
                                    let HookAction::Terminate { reason } =
                                    hook.on_tool_result(
                                        &tool_call.function.name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &tool_args,
                                        &tool_result.to_string()
                                    )
                                    .await {
                                        return Err(cancelled_prompt_error(&chat_history, reason).await);
                                    }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                saw_tool_call_this_turn = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
                                }
                                Err(e) => {
                                    yield Err(e);
                                    break 'outer;
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
                            // Deltas are only forwarded to the hook; they are
                            // neither yielded to the consumer nor stored.
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
                                    rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
                                };

                                if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
                                .await {
                                    yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                    break 'outer;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                            // Accumulate reasoning for inclusion in chat history with tool calls.
                            // OpenAI Responses API requires reasoning items to be sent back
                            // alongside function_call items in multi-turn conversations.
                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            // Deltas lack signatures/encrypted content that full
                            // blocks carry; mixing them into accumulated_reasoning
                            // causes Anthropic to reject with "signature required".
                            pending_reasoning_delta_text.push_str(&reasoning);
                            // Keep the first id seen for the assembled block.
                            if pending_reasoning_delta_id.is_none() {
                                pending_reasoning_delta_id = id.clone();
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                            // Only forward the provider's final marker when this
                            // response produced text (not a tool-only turn).
                            if is_text_response {
                                if let Some(ref hook) = self.hook &&
                                     let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&prompt, &final_resp).await {
                                        yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                        break 'outer;
                                    }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // Providers like Gemini emit thinking as incremental deltas
                // without signatures; assemble into a single block so
                // reasoning survives into the next turn's chat history.
                if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
                    let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
                    if let Some(id) = pending_reasoning_delta_id.take() {
                        assembled = assembled.with_id(id);
                    }
                    accumulated_reasoning.push(assembled);
                }

                // Add reasoning and tool calls to chat history.
                // OpenAI Responses API requires reasoning items to precede function_call items.
                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
                    let mut content_items: Vec<rig::message::AssistantContent> = vec![];

                    // Reasoning must come before tool calls (OpenAI requirement)
                    for reasoning in accumulated_reasoning.drain(..) {
                        content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
                    }

                    content_items.extend(tool_calls.clone());

                    if !content_items.is_empty() {
                        chat_history.push(Message::Assistant {
                            id: stream.message_id.clone(),
                            content: OneOrMany::many(content_items).expect("Should have at least one item"),
                        });
                    }
                }

                for (id, call_id, tool_result) in tool_results {
                    chat_history.push(tool_result_to_user_message(id, call_id, tool_result));
                }

                // Set the current prompt to the last message in the chat history
                current_prompt = match chat_history.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !saw_tool_call_this_turn {
                    // Add user message and assistant response to history before finishing
                    chat_history.push(current_prompt.clone());
                    if !last_text_response.is_empty() {
                        chat_history.push(Message::assistant(&last_text_response));
                    }

                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    current_span.record("gen_ai.usage.cached_tokens", aggregated_usage.cached_input_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    // Only hand back an updated history when the caller
                    // provided one via `with_history`.
                    let history_snapshot = if has_history {
                        Some(chat_history.clone())
                    } else {
                        None
                    };
                    yield Ok(MultiTurnStreamItem::final_response_with_history(
                        &last_text_response,
                        aggregated_usage,
                        history_snapshot,
                    ));
                    break;
                }
            }

            // The loop above broke out of the turn budget without a final
            // text answer: report the overflow as a prompt error.
            if max_turns_reached {
                yield Err(Box::new(PromptError::MaxTurnsError {
                    max_turns: self.max_turns,
                    chat_history: Box::new(chat_history.clone()),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
678}
679
/// Lets a `StreamingPromptRequest` be `.await`ed directly: awaiting the
/// request resolves to the same stream that an explicit `send()` would return.
impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    P: PromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Wrap send() in a future, because send() returns a stream immediately
        Box::pin(async move { self.send().await })
    }
}
694
695/// Helper function to stream a completion request to stdout.
696pub async fn stream_to_stdout<R>(
697    stream: &mut StreamingResult<R>,
698) -> Result<FinalResponse, std::io::Error> {
699    let mut final_res = FinalResponse::empty();
700    print!("Response: ");
701    while let Some(content) = stream.next().await {
702        match content {
703            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
704                Text { text },
705            ))) => {
706                print!("{text}");
707                std::io::Write::flush(&mut std::io::stdout()).unwrap();
708            }
709            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
710                reasoning,
711            ))) => {
712                let reasoning = reasoning.display_text();
713                print!("{reasoning}");
714                std::io::Write::flush(&mut std::io::stdout()).unwrap();
715            }
716            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
717                final_res = res;
718            }
719            Err(err) => {
720                eprintln!("Error: {err}");
721            }
722            _ => {}
723        }
724    }
725
726    Ok(final_res)
727}
728
729#[cfg(test)]
730mod tests {
731    use super::*;
732    use crate::agent::AgentBuilder;
733    use crate::client::ProviderClient;
734    use crate::client::completion::CompletionClient;
735    use crate::completion::{
736        CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
737    };
738    use crate::message::ReasoningContent;
739    use crate::providers::anthropic;
740    use crate::streaming::StreamingPrompt;
741    use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
742    use futures::StreamExt;
743    use serde::{Deserialize, Serialize};
744    use std::sync::Arc;
745    use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
746    use std::time::Duration;
747
748    #[test]
749    fn merge_reasoning_blocks_preserves_order_and_signatures() {
750        let mut accumulated = Vec::new();
751        let first = crate::message::Reasoning {
752            id: Some("rs_1".to_string()),
753            content: vec![ReasoningContent::Text {
754                text: "step-1".to_string(),
755                signature: Some("sig-1".to_string()),
756            }],
757        };
758        let second = crate::message::Reasoning {
759            id: Some("rs_1".to_string()),
760            content: vec![
761                ReasoningContent::Text {
762                    text: "step-2".to_string(),
763                    signature: Some("sig-2".to_string()),
764                },
765                ReasoningContent::Summary("summary".to_string()),
766            ],
767        };
768
769        merge_reasoning_blocks(&mut accumulated, &first);
770        merge_reasoning_blocks(&mut accumulated, &second);
771
772        assert_eq!(accumulated.len(), 1);
773        let merged = accumulated.first().expect("expected accumulated reasoning");
774        assert_eq!(merged.id.as_deref(), Some("rs_1"));
775        assert_eq!(merged.content.len(), 3);
776        assert!(matches!(
777            merged.content.first(),
778            Some(ReasoningContent::Text { text, signature: Some(sig) })
779                if text == "step-1" && sig == "sig-1"
780        ));
781        assert!(matches!(
782            merged.content.get(1),
783            Some(ReasoningContent::Text { text, signature: Some(sig) })
784                if text == "step-2" && sig == "sig-2"
785        ));
786    }
787
788    #[test]
789    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
790        let mut accumulated = vec![crate::message::Reasoning {
791            id: Some("rs_a".to_string()),
792            content: vec![ReasoningContent::Text {
793                text: "step-1".to_string(),
794                signature: None,
795            }],
796        }];
797        let incoming = crate::message::Reasoning {
798            id: Some("rs_b".to_string()),
799            content: vec![ReasoningContent::Text {
800                text: "step-2".to_string(),
801                signature: None,
802            }],
803        };
804
805        merge_reasoning_blocks(&mut accumulated, &incoming);
806        assert_eq!(accumulated.len(), 2);
807        assert_eq!(
808            accumulated.first().and_then(|r| r.id.as_deref()),
809            Some("rs_a")
810        );
811        assert_eq!(
812            accumulated.get(1).and_then(|r| r.id.as_deref()),
813            Some("rs_b")
814        );
815    }
816
817    #[test]
818    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
819        let mut accumulated = vec![crate::message::Reasoning {
820            id: None,
821            content: vec![ReasoningContent::Text {
822                text: "first".to_string(),
823                signature: None,
824            }],
825        }];
826        let incoming = crate::message::Reasoning {
827            id: None,
828            content: vec![ReasoningContent::Text {
829                text: "second".to_string(),
830                signature: None,
831            }],
832        };
833
834        merge_reasoning_blocks(&mut accumulated, &incoming);
835        assert_eq!(accumulated.len(), 2);
836        assert!(matches!(
837            accumulated.first(),
838            Some(crate::message::Reasoning {
839                id: None,
840                content
841            }) if matches!(
842                content.first(),
843                Some(ReasoningContent::Text { text, .. }) if text == "first"
844            )
845        ));
846        assert!(matches!(
847            accumulated.get(1),
848            Some(crate::message::Reasoning {
849                id: None,
850                content
851            }) if matches!(
852                content.first(),
853                Some(ReasoningContent::Text { text, .. }) if text == "second"
854            )
855        ));
856    }
857
    /// Minimal streaming-response payload for the mock model below; it only
    /// carries token usage so usage aggregation can be exercised.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct MockStreamingResponse {
        usage: crate::completion::Usage,
    }
862
863    impl MockStreamingResponse {
864        fn new(total_tokens: u64) -> Self {
865            let mut usage = crate::completion::Usage::new();
866            usage.total_tokens = total_tokens;
867            Self { usage }
868        }
869    }
870
    impl crate::completion::GetTokenUsage for MockStreamingResponse {
        // Surface the mock's stored usage so the streaming loop can aggregate it.
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            Some(self.usage)
        }
    }
876
    /// Mock completion model whose `stream` implementation emits a tool call
    /// on the first turn and plain text afterwards; `turn_counter` records how
    /// many times `stream` has been invoked.
    #[derive(Clone, Default)]
    struct MultiTurnMockModel {
        turn_counter: Arc<AtomicUsize>,
    }
881
    #[allow(refining_impl_trait)]
    impl CompletionModel for MultiTurnMockModel {
        type Response = ();
        type StreamingResponse = MockStreamingResponse;
        type Client = ();

        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self::default()
        }

        /// Non-streaming completion is not exercised by these tests, so it
        /// always fails with a provider error.
        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "completion is unused in this streaming test".to_string(),
            ))
        }

        /// First call yields a tool call (to a tool that doesn't exist);
        /// every subsequent call yields a plain "done" message. Each turn
        /// ends with a final response carrying mock usage numbers.
        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            // fetch_add returns the pre-increment value, so `turn` is 0 on
            // the very first invocation.
            let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
            let stream = async_stream::stream! {
                if turn == 0 {
                    yield Ok(RawStreamingChoice::ToolCall(
                        RawStreamingToolCall::new(
                            "tool_call_1".to_string(),
                            "missing_tool".to_string(),
                            serde_json::json!({"input": "value"}),
                        )
                        .with_call_id("call_1".to_string()),
                    ));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(4)));
                } else {
                    yield Ok(RawStreamingChoice::Message("done".to_string()));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(6)));
                }
            };

            let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
                Box::pin(stream);
            Ok(StreamingCompletionResponse::stream(pinned_stream))
        }
    }
928
929    #[tokio::test]
930    async fn stream_prompt_continues_after_tool_call_turn() {
931        let model = MultiTurnMockModel::default();
932        let turn_counter = model.turn_counter.clone();
933        let agent = AgentBuilder::new(model).build();
934
935        let mut stream = agent.stream_prompt("do tool work").multi_turn(3).await;
936        let mut saw_tool_call = false;
937        let mut saw_tool_result = false;
938        let mut saw_final_response = false;
939        let mut final_text = String::new();
940
941        while let Some(item) = stream.next().await {
942            match item {
943                Ok(MultiTurnStreamItem::StreamAssistantItem(
944                    StreamedAssistantContent::ToolCall { .. },
945                )) => {
946                    saw_tool_call = true;
947                }
948                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
949                    ..
950                })) => {
951                    saw_tool_result = true;
952                }
953                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
954                    text,
955                ))) => {
956                    final_text.push_str(&text.text);
957                }
958                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
959                    saw_final_response = true;
960                    break;
961                }
962                Ok(_) => {}
963                Err(err) => panic!("unexpected streaming error: {err:?}"),
964            }
965        }
966
967        assert!(saw_tool_call);
968        assert!(saw_tool_result);
969        assert!(saw_final_response);
970        assert_eq!(final_text, "done");
971        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
972    }
973
974    /// Background task that logs periodically to detect span leakage.
975    /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
976    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
977        let mut interval = tokio::time::interval(Duration::from_millis(50));
978        let mut count = 0u32;
979
980        while !stop.load(Ordering::Relaxed) {
981            interval.tick().await;
982            count += 1;
983
984            tracing::event!(
985                target: "background_logger",
986                tracing::Level::INFO,
987                count = count,
988                "Background tick"
989            );
990
991            // Check if we're inside an unexpected span
992            let current = tracing::Span::current();
993            if !current.is_disabled() && !current.is_none() {
994                leak_count.fetch_add(1, Ordering::Relaxed);
995            }
996        }
997
998        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
999    }
1000
    /// Test that span context doesn't leak to concurrent tasks during streaming.
    ///
    /// This test verifies that using `.instrument()` instead of `span.enter()` in
    /// async_stream prevents thread-local span context from leaking to other tasks.
    ///
    /// Uses single-threaded runtime to force all tasks onto the same thread,
    /// making the span leak deterministic (it only occurs when tasks share a thread).
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        // Start background logger
        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Small delay to let background logger start
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Make streaming request WITHOUT an outer span so rig creates its own invoke_agent span
        // (rig reuses current span if one exists, so we need to ensure there's no current span)
        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        // Drain the stream on this (single) thread while the background
        // logger ticks concurrently.
        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        // Stop background logger
        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        // Zero observed spans in the background task means no leakage occurred.
        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
1070
1071    /// Test that FinalResponse contains the updated chat history when with_history is used.
1072    ///
1073    /// This verifies that:
1074    /// 1. FinalResponse.history() returns Some when with_history was called
1075    /// 2. The history contains both the user prompt and assistant response
1076    #[tokio::test]
1077    #[ignore = "This requires an API key"]
1078    async fn test_chat_history_in_final_response() {
1079        use crate::message::Message;
1080
1081        let client = anthropic::Client::from_env();
1082        let agent = client
1083            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
1084            .preamble("You are a helpful assistant. Keep responses brief.")
1085            .temperature(0.1)
1086            .max_tokens(50)
1087            .build();
1088
1089        // Send streaming request with history
1090        let mut stream = agent
1091            .stream_prompt("Say 'hello' and nothing else.")
1092            .with_history(vec![])
1093            .await;
1094
1095        // Consume the stream and collect FinalResponse
1096        let mut response_text = String::new();
1097        let mut final_history = None;
1098        while let Some(item) = stream.next().await {
1099            match item {
1100                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1101                    text,
1102                ))) => {
1103                    response_text.push_str(&text.text);
1104                }
1105                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1106                    final_history = res.history().map(|h| h.to_vec());
1107                    break;
1108                }
1109                Err(e) => {
1110                    panic!("Streaming error: {:?}", e);
1111                }
1112                _ => {}
1113            }
1114        }
1115
1116        let history =
1117            final_history.expect("FinalResponse should contain history when with_history is used");
1118
1119        // Should contain at least the user message
1120        assert!(
1121            history.iter().any(|m| matches!(m, Message::User { .. })),
1122            "History should contain the user message"
1123        );
1124
1125        // Should contain the assistant response
1126        assert!(
1127            history
1128                .iter()
1129                .any(|m| matches!(m, Message::Assistant { .. })),
1130            "History should contain the assistant response"
1131        );
1132
1133        tracing::info!(
1134            "History after streaming: {} messages, response: {:?}",
1135            history.len(),
1136            response_text
1137        );
1138    }
1139}