rig/agent/prompt_request/streaming.rs

use crate::{
    OneOrMany,
    agent::completion::{DynamicContextStore, build_completion_request},
    agent::prompt_request::{HookAction, hooks::PromptHook},
    completion::{Document, GetTokenUsage},
    json_utils,
    message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent},
    tool::server::ToolServerHandle,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tracing::info_span;
use tracing_futures::Instrument;

use super::ToolCallHookAction;
use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed assistant content item.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// A streamed user content item (mostly for tool results).
    StreamUserItem(StreamedUserContent),
    /// The final result from the stream.
    FinalResponse(FinalResponse),
}

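/// The final aggregated result yielded once the multi-turn stream finishes: the last
/// text response, the token usage aggregated across all turns and, when chat history
/// was supplied via `with_history`, the updated history.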
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
            history: None,
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }

    pub fn history(&self) -> Option<&[Message]> {
        self.history.as_deref()
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
            history: None,
        })
    }

    pub fn final_response_with_history(
        response: &str,
        aggregated_usage: crate::completion::Usage,
        history: Option<Vec<Message>>,
    ) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
            history,
        })
    }
}

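/// Merges an incoming reasoning block into the accumulated list. If a previously
/// accumulated block has the same (non-`None`) id, the incoming content is appended
/// to it; otherwise the incoming block is pushed as a separate entry.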
fn merge_reasoning_blocks(
    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
    incoming: &crate::message::Reasoning,
) {
    let ids_match = |existing: &crate::message::Reasoning| {
        matches!(
            (&existing.id, &incoming.id),
            (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
        )
    };

    if let Some(existing) = accumulated_reasoning
        .iter_mut()
        .rev()
        .find(|existing| ids_match(existing))
    {
        existing.content.extend(incoming.content.clone());
    } else {
        accumulated_reasoning.push(incoming.clone());
    }
}

async fn cancelled_prompt_error(chat_history: &Vec<Message>, reason: String) -> StreamingError {
    StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.to_owned(), reason).into())
}

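/// Wraps a tool result in a user message so it can be appended to the chat history
/// for the next turn.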
fn tool_result_to_user_message(
    id: String,
    call_id: Option<String>,
    tool_result: String,
) -> Message {
    let content = OneOrMany::one(ToolResultContent::text(tool_result));
    let user_content = match call_id {
        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
        None => UserContent::tool_result(id, content),
    };

    Message::User {
        content: OneOrMany::one(user_content),
    }
}

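/// Errors that can occur while driving a streaming prompt request.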
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

/// A builder for creating streaming prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, make sure to use `.multi_turn()`
/// to allow more turns; the default is 0 (i.e. only a single tool round-trip). Otherwise,
/// awaiting the request (which sends the prompt) can return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
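///
/// A minimal usage sketch (assuming `agent` is an already-built [`Agent`]; the prompt
/// text and turn count are illustrative):
/// ```ignore
/// let mut stream = agent
///     .stream_prompt("What tools do you have?")
///     .multi_turn(5)
///     .await;
///
/// while let Some(item) = stream.next().await {
///     match item {
///         Ok(MultiTurnStreamItem::StreamAssistantItem(content)) => { /* streamed content */ }
///         Ok(MultiTurnStreamItem::FinalResponse(final_response)) => { /* usage, history */ }
///         Ok(_) => {}
///         Err(e) => { /* handle StreamingError */ }
///     }
/// }
/// ```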
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt.
    chat_history: Option<Vec<Message>>,
    /// Maximum Turns for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
    /// Optional per-request hook for events
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: PromptHook<M>,
{
    /// Create a new StreamingPromptRequest for the given agent and prompt.
    /// Note: This creates a request without an agent hook. Use `from_agent` to include the agent's hook.
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: None,
        }
    }

    /// Create a new StreamingPromptRequest from an agent, cloning the agent's data and default hook.
    pub fn from_agent<P2>(
        agent: &Agent<M, P2>,
        prompt: impl Into<Message>,
    ) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: agent.hook.clone(),
        }
    }

    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Set the maximum number of turns for multi-turn conversations (i.e. how many turns the LLM can spend calling tools before it writes a text response).
    /// If the maximum number of turns is exceeded, the request will return a [`crate::completion::request::PromptError::MaxTurnsError`].
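    ///
    /// A minimal sketch (assuming `agent` is an already-built [`Agent`]):
    /// ```ignore
    /// // Allow up to 5 tool-calling turns before a text answer is required.
    /// let mut stream = agent
    ///     .stream_prompt("Look this up with your tools")
    ///     .multi_turn(5)
    ///     .await;
    /// ```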
    pub fn multi_turn(mut self, turns: usize) -> Self {
        self.max_turns = turns;
        self
    }

    /// Add chat history to the prompt request.
    ///
    /// When history is provided, the final [`FinalResponse`] will include the
    /// updated chat history (original messages + new user prompt + assistant response).
    /// ```ignore
    /// let mut stream = agent
    ///     .stream_prompt("Hello")
    ///     .with_history(vec![])
    ///     .await;
    /// // ... consume stream ...
    /// // Access updated history from FinalResponse::history()
    /// ```
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attach a per-request hook for tool call events.
    /// This overrides any default hook set on the agent.
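    ///
    /// A minimal sketch (`MyHook` is an illustrative user-defined type implementing
    /// [`PromptHook`]):
    /// ```ignore
    /// let mut stream = agent
    ///     .stream_prompt("Hello")
    ///     .with_hook(MyHook::default())
    ///     .await;
    /// ```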
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            output_schema: self.output_schema,
            hook: Some(hook),
        }
    }

    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // Clone fields needed inside the stream
        let model = self.model.clone();
        let preamble = self.preamble.clone();
        let static_context = self.static_context.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let additional_params = self.additional_params.clone();
        let tool_server_handle = self.tool_server_handle.clone();
        let dynamic_context = self.dynamic_context.clone();
        let tool_choice = self.tool_choice.clone();
        let agent_name = self.agent_name.clone();
        let has_history = self.chat_history.is_some();
        let mut chat_history = self.chat_history.unwrap_or_default();

        let mut current_max_turns = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_turns_reached = false;
        let output_schema = self.output_schema;

        let mut aggregated_usage = crate::completion::Usage::new();

        // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
        // span context leaking to other concurrent tasks. Using span.enter() inside
        // async_stream::stream! holds the guard across yield points, which causes
        // thread-local span context to leak when other tasks run on the same thread.
        // See: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#in-asynchronous-code
        // See also: https://github.com/rust-lang/rust-clippy/issues/8722
        let stream = async_stream::stream! {
            let mut current_prompt = prompt.clone();

            'outer: loop {
                if current_max_turns > self.max_turns + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_turns_reached = true;
                    break;
                }

                current_max_turns += 1;

                if self.max_turns > 1 {
                    tracing::info!(
                        "Current conversation Turns: {}/{}",
                        current_max_turns,
                        self.max_turns
                    );
                }

                if let Some(ref hook) = self.hook {
                    let history_snapshot = chat_history.clone();
                    if let HookAction::Terminate { reason } = hook.on_completion_call(&current_prompt, &history_snapshot)
                        .await {
                        yield Err(cancelled_prompt_error(&chat_history, reason).await);
                        break 'outer;
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                    gen_ai.system_instructions = preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let history_snapshot = chat_history.clone();
                let mut stream = tracing::Instrument::instrument(
                    build_completion_request(
                        &model,
                        current_prompt.clone(),
                        history_snapshot,
                        preamble.as_deref(),
                        &static_context,
                        temperature,
                        max_tokens,
                        additional_params.as_ref(),
                        tool_choice.as_ref(),
                        &tool_server_handle,
                        &dynamic_context,
                        output_schema.as_ref(),
                    )
                    .await?
                    .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];
                let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
                // Kept separate from accumulated_reasoning so providers requiring
                // signatures (e.g. Anthropic) never see unsigned blocks.
                let mut pending_reasoning_delta_text = String::new();
                let mut pending_reasoning_delta_id: Option<String> = None;
                let mut saw_tool_call_this_turn = false;

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook &&
                                let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &last_text_response).await {
                                    yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                    break 'outer;
                            }

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                        },
                        Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    let action = hook
                                        .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
                                        .await;

                                    if let ToolCallHookAction::Terminate { reason } = action {
                                        return Err(cancelled_prompt_error(&chat_history, reason).await);
                                    }

                                    if let ToolCallHookAction::Skip { reason } = action {
                                        // Tool execution rejected, return rejection message as tool result
                                        tracing::info!(
                                            tool_name = tool_call.function.name.as_str(),
                                            reason = reason,
                                            "Tool call rejected"
                                        );
                                        let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
                                        tool_calls.push(tool_call_msg);
                                        tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
                                        saw_tool_call_this_turn = true;
                                        return Ok(reason);
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_args)
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook &&
                                    let HookAction::Terminate { reason } =
                                    hook.on_tool_result(
                                        &tool_call.function.name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &tool_args,
                                        &tool_result.to_string()
                                    )
                                    .await {
                                        return Err(cancelled_prompt_error(&chat_history, reason).await);
                                    }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                saw_tool_call_this_turn = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
                                }
                                Err(e) => {
                                    yield Err(e);
                                    break 'outer;
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
                                    rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
                                };

                                if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
                                .await {
                                    yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                    break 'outer;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                            // Accumulate reasoning for inclusion in chat history with tool calls.
                            // OpenAI Responses API requires reasoning items to be sent back
                            // alongside function_call items in multi-turn conversations.
                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            // Deltas lack signatures/encrypted content that full
                            // blocks carry; mixing them into accumulated_reasoning
                            // causes Anthropic to reject with "signature required".
                            pending_reasoning_delta_text.push_str(&reasoning);
                            if pending_reasoning_delta_id.is_none() {
                                pending_reasoning_delta_id = id.clone();
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                            if is_text_response {
                                if let Some(ref hook) = self.hook &&
                                    let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&prompt, &final_resp).await {
                                        yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                        break 'outer;
                                    }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // Providers like Gemini emit thinking as incremental deltas
                // without signatures; assemble into a single block so
                // reasoning survives into the next turn's chat history.
                if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
                    let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
                    if let Some(id) = pending_reasoning_delta_id.take() {
                        assembled = assembled.with_id(id);
                    }
                    accumulated_reasoning.push(assembled);
                }

                // Add reasoning and tool calls to chat history.
                // OpenAI Responses API requires reasoning items to precede function_call items.
                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
                    let mut content_items: Vec<rig::message::AssistantContent> = vec![];

                    // Reasoning must come before tool calls (OpenAI requirement)
                    for reasoning in accumulated_reasoning.drain(..) {
                        content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
                    }

                    content_items.extend(tool_calls.clone());

                    if !content_items.is_empty() {
                        chat_history.push(Message::Assistant {
                            id: stream.message_id.clone(),
                            content: OneOrMany::many(content_items).expect("Should have at least one item"),
                        });
                    }
                }

                for (id, call_id, tool_result) in tool_results {
                    chat_history.push(tool_result_to_user_message(id, call_id, tool_result));
                }

                // Set the current prompt to the last message in the chat history
                current_prompt = match chat_history.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !saw_tool_call_this_turn {
                    // Add user message and assistant response to history before finishing
                    chat_history.push(current_prompt.clone());
                    if !last_text_response.is_empty() {
                        chat_history.push(Message::assistant(&last_text_response));
                    }

                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    let history_snapshot = if has_history {
                        Some(chat_history.clone())
                    } else {
                        None
                    };
                    yield Ok(MultiTurnStreamItem::final_response_with_history(
                        &last_text_response,
                        aggregated_usage,
                        history_snapshot,
                    ));
                    break;
                }
            }

            if max_turns_reached {
                yield Err(Box::new(PromptError::MaxTurnsError {
                    max_turns: self.max_turns,
                    chat_history: Box::new(chat_history.clone()),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    P: PromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Wrap send() in a future, because send() returns a stream immediately
        Box::pin(async move { self.send().await })
    }
}

/// Helper function to stream a completion request to stdout.
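///
/// A minimal sketch (assuming `agent` is an already-built [`Agent`]):
/// ```ignore
/// let mut stream = agent.stream_prompt("Hello").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\nUsage: {:?}", final_response.usage());
/// ```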
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                reasoning,
            ))) => {
                let reasoning = reasoning.display_text();
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentBuilder;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::completion::{
        CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
    };
    use crate::message::ReasoningContent;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
    use std::time::Duration;

    #[test]
    fn merge_reasoning_blocks_preserves_order_and_signatures() {
        let mut accumulated = Vec::new();
        let first = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: Some("sig-1".to_string()),
            }],
        };
        let second = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![
                ReasoningContent::Text {
                    text: "step-2".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                ReasoningContent::Summary("summary".to_string()),
            ],
        };

        merge_reasoning_blocks(&mut accumulated, &first);
        merge_reasoning_blocks(&mut accumulated, &second);

        assert_eq!(accumulated.len(), 1);
        let merged = accumulated.first().expect("expected accumulated reasoning");
        assert_eq!(merged.id.as_deref(), Some("rs_1"));
        assert_eq!(merged.content.len(), 3);
        assert!(matches!(
            merged.content.first(),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-1" && sig == "sig-1"
        ));
        assert!(matches!(
            merged.content.get(1),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-2" && sig == "sig-2"
        ));
    }

    #[test]
    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: Some("rs_a".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: Some("rs_b".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-2".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert_eq!(
            accumulated.first().and_then(|r| r.id.as_deref()),
            Some("rs_a")
        );
        assert_eq!(
            accumulated.get(1).and_then(|r| r.id.as_deref()),
            Some("rs_b")
        );
    }

    #[test]
    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "first".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "second".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert!(matches!(
            accumulated.first(),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "first"
            )
        ));
        assert!(matches!(
            accumulated.get(1),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "second"
            )
        ));
    }

    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct MockStreamingResponse {
        usage: crate::completion::Usage,
    }

    impl MockStreamingResponse {
        fn new(total_tokens: u64) -> Self {
            let mut usage = crate::completion::Usage::new();
            usage.total_tokens = total_tokens;
            Self { usage }
        }
    }

    impl crate::completion::GetTokenUsage for MockStreamingResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            Some(self.usage)
        }
    }

    #[derive(Clone, Default)]
    struct MultiTurnMockModel {
        turn_counter: Arc<AtomicUsize>,
    }

    #[allow(refining_impl_trait)]
    impl CompletionModel for MultiTurnMockModel {
        type Response = ();
        type StreamingResponse = MockStreamingResponse;
        type Client = ();

        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self::default()
        }

        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "completion is unused in this streaming test".to_string(),
            ))
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
            let stream = async_stream::stream! {
                if turn == 0 {
                    yield Ok(RawStreamingChoice::ToolCall(
                        RawStreamingToolCall::new(
                            "tool_call_1".to_string(),
                            "missing_tool".to_string(),
                            serde_json::json!({"input": "value"}),
                        )
                        .with_call_id("call_1".to_string()),
                    ));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(4)));
                } else {
                    yield Ok(RawStreamingChoice::Message("done".to_string()));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(6)));
                }
            };

            let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
                Box::pin(stream);
            Ok(StreamingCompletionResponse::stream(pinned_stream))
        }
    }

    #[tokio::test]
    async fn stream_prompt_continues_after_tool_call_turn() {
        let model = MultiTurnMockModel::default();
        let turn_counter = model.turn_counter.clone();
        let agent = AgentBuilder::new(model).build();

        let mut stream = agent.stream_prompt("do tool work").multi_turn(3).await;
        let mut saw_tool_call = false;
        let mut saw_tool_result = false;
        let mut saw_final_response = false;
        let mut final_text = String::new();

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall { .. },
                )) => {
                    saw_tool_call = true;
                }
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    ..
                })) => {
                    saw_tool_result = true;
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    final_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    saw_final_response = true;
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        assert!(saw_tool_call);
        assert!(saw_tool_result);
        assert!(saw_final_response);
        assert_eq!(final_text, "done");
        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
    }

    /// Background task that logs periodically to detect span leakage.
    /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            // Check if we're inside an unexpected span
            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    /// Test that span context doesn't leak to concurrent tasks during streaming.
    ///
    /// This test verifies that using `.instrument()` instead of `span.enter()` in
    /// async_stream prevents thread-local span context from leaking to other tasks.
    ///
    /// Uses single-threaded runtime to force all tasks onto the same thread,
    /// making the span leak deterministic (it only occurs when tasks share a thread).
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        // Start background logger
        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Small delay to let background logger start
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Make streaming request WITHOUT an outer span so rig creates its own invoke_agent span
        // (rig reuses current span if one exists, so we need to ensure there's no current span)
        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        // Stop background logger
        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }

    /// Test that FinalResponse contains the updated chat history when with_history is used.
    ///
    /// This verifies that:
    /// 1. FinalResponse.history() returns Some when with_history was called
    /// 2. The history contains both the user prompt and assistant response
    #[tokio::test]
    #[ignore = "This requires an API key"]
    async fn test_chat_history_in_final_response() {
        use crate::message::Message;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant. Keep responses brief.")
            .temperature(0.1)
            .max_tokens(50)
            .build();

        // Send streaming request with history
        let mut stream = agent
            .stream_prompt("Say 'hello' and nothing else.")
            .with_history(vec![])
            .await;

        // Consume the stream and collect FinalResponse
        let mut response_text = String::new();
        let mut final_history = None;
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    response_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    final_history = res.history().map(|h| h.to_vec());
                    break;
                }
                Err(e) => {
                    panic!("Streaming error: {:?}", e);
                }
                _ => {}
            }
        }

        let history =
            final_history.expect("FinalResponse should contain history when with_history is used");

        // Should contain at least the user message
        assert!(
            history.iter().any(|m| matches!(m, Message::User { .. })),
            "History should contain the user message"
        );

        // Should contain the assistant response
        assert!(
            history
                .iter()
                .any(|m| matches!(m, Message::Assistant { .. })),
            "History should contain the assistant response"
        );

        tracing::info!(
            "History after streaming: {} messages, response: {:?}",
            history.len(),
            response_text
        );
    }
}