Skip to main content

rig/agent/prompt_request/
streaming.rs

1use crate::{
2    OneOrMany,
3    agent::prompt_request::HookAction,
4    completion::GetTokenUsage,
5    json_utils,
6    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
7    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
8    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
9};
10use futures::{Stream, StreamExt};
11use serde::{Deserialize, Serialize};
12use std::{pin::Pin, sync::Arc};
13use tokio::sync::RwLock;
14use tracing::info_span;
15use tracing_futures::Instrument;
16
17use super::ToolCallHookAction;
18use crate::{
19    agent::Agent,
20    completion::{CompletionError, CompletionModel, PromptError},
21    message::{Message, Text},
22    tool::ToolSetError,
23};
24
25#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
26pub type StreamingResult<R> =
27    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;
28
29#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
30pub type StreamingResult<R> =
31    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
32
33#[derive(Deserialize, Serialize, Debug, Clone)]
34#[serde(tag = "type", rename_all = "camelCase")]
35#[non_exhaustive]
36pub enum MultiTurnStreamItem<R> {
37    /// A streamed assistant content item.
38    StreamAssistantItem(StreamedAssistantContent<R>),
39    /// A streamed user content item (mostly for tool results).
40    StreamUserItem(StreamedUserContent),
41    /// The final result from the stream.
42    FinalResponse(FinalResponse),
43}
44
45#[derive(Deserialize, Serialize, Debug, Clone)]
46#[serde(rename_all = "camelCase")]
47pub struct FinalResponse {
48    response: String,
49    aggregated_usage: crate::completion::Usage,
50}
51
52impl FinalResponse {
53    pub fn empty() -> Self {
54        Self {
55            response: String::new(),
56            aggregated_usage: crate::completion::Usage::new(),
57        }
58    }
59
60    pub fn response(&self) -> &str {
61        &self.response
62    }
63
64    pub fn usage(&self) -> crate::completion::Usage {
65        self.aggregated_usage
66    }
67}
68
69impl<R> MultiTurnStreamItem<R> {
70    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
71        Self::StreamAssistantItem(item)
72    }
73
74    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
75        Self::FinalResponse(FinalResponse {
76            response: response.to_string(),
77            aggregated_usage,
78        })
79    }
80}
81
82#[derive(Debug, thiserror::Error)]
83pub enum StreamingError {
84    #[error("CompletionError: {0}")]
85    Completion(#[from] CompletionError),
86    #[error("PromptError: {0}")]
87    Prompt(#[from] Box<PromptError>),
88    #[error("ToolSetError: {0}")]
89    Tool(#[from] ToolSetError),
90}
91
92/// A builder for creating prompt requests with customizable options.
93/// Uses generics to track which options have been set during the build process.
94///
95/// If you expect to continuously call tools, you will want to ensure you use the `.multi_turn()`
96/// argument to add more turns as by default, it is 0 (meaning only 1 tool round-trip). Otherwise,
97/// attempting to await (which will send the prompt request) can potentially return
98/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
99/// back to back.
100pub struct StreamingPromptRequest<M, P>
101where
102    M: CompletionModel,
103    P: StreamingPromptHook<M> + 'static,
104{
105    /// The prompt message to send to the model
106    prompt: Message,
107    /// Optional chat history to include with the prompt
108    /// Note: chat history needs to outlive the agent as it might be used with other agents
109    chat_history: Option<Vec<Message>>,
110    /// Maximum Turns for multi-turn conversations (0 means no multi-turn)
111    max_turns: usize,
112    /// The agent to use for execution
113    agent: Arc<Agent<M>>,
114    /// Optional per-request hook for events
115    hook: Option<P>,
116}
117
118impl<M, P> StreamingPromptRequest<M, P>
119where
120    M: CompletionModel + 'static,
121    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
122    P: StreamingPromptHook<M>,
123{
124    /// Create a new PromptRequest with the given prompt and model
125    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
126        Self {
127            prompt: prompt.into(),
128            chat_history: None,
129            max_turns: agent.default_max_turns.unwrap_or_default(),
130            agent,
131            hook: None,
132        }
133    }
134
135    /// Set the maximum Turns for multi-turn conversations (ie, the maximum number of turns an LLM can have calling tools before writing a text response).
136    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
137    pub fn multi_turn(mut self, turns: usize) -> Self {
138        self.max_turns = turns;
139        self
140    }
141
142    /// Add chat history to the prompt request
143    pub fn with_history(mut self, history: Vec<Message>) -> Self {
144        self.chat_history = Some(history);
145        self
146    }
147
148    /// Attach a per-request hook for tool call events
149    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
150    where
151        P2: StreamingPromptHook<M>,
152    {
153        StreamingPromptRequest {
154            prompt: self.prompt,
155            chat_history: self.chat_history,
156            max_turns: self.max_turns,
157            agent: self.agent,
158            hook: Some(hook),
159        }
160    }
161
162    async fn send(self) -> StreamingResult<M::StreamingResponse> {
163        let agent_span = if tracing::Span::current().is_disabled() {
164            info_span!(
165                "invoke_agent",
166                gen_ai.operation.name = "invoke_agent",
167                gen_ai.agent.name = self.agent.name(),
168                gen_ai.system_instructions = self.agent.preamble,
169                gen_ai.prompt = tracing::field::Empty,
170                gen_ai.completion = tracing::field::Empty,
171                gen_ai.usage.input_tokens = tracing::field::Empty,
172                gen_ai.usage.output_tokens = tracing::field::Empty,
173            )
174        } else {
175            tracing::Span::current()
176        };
177
178        let prompt = self.prompt;
179        if let Some(text) = prompt.rag_text() {
180            agent_span.record("gen_ai.prompt", text);
181        }
182
183        let agent = self.agent;
184
185        let chat_history = if let Some(history) = self.chat_history {
186            Arc::new(RwLock::new(history))
187        } else {
188            Arc::new(RwLock::new(vec![]))
189        };
190
191        let mut current_max_turns = 0;
192        let mut last_prompt_error = String::new();
193
194        let mut last_text_response = String::new();
195        let mut is_text_response = false;
196        let mut max_turns_reached = false;
197
198        let mut aggregated_usage = crate::completion::Usage::new();
199
200        // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
201        // span context leaking to other concurrent tasks. Using span.enter() inside
202        // async_stream::stream! holds the guard across yield points, which causes
203        // thread-local span context to leak when other tasks run on the same thread.
204        // See: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#in-asynchronous-code
205        // See also: https://github.com/rust-lang/rust-clippy/issues/8722
206        let stream = async_stream::stream! {
207            let mut current_prompt = prompt.clone();
208            let mut did_call_tool = false;
209
210            'outer: loop {
211                if current_max_turns > self.max_turns + 1 {
212                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
213                    max_turns_reached = true;
214                    break;
215                }
216
217                current_max_turns += 1;
218
219                if self.max_turns > 1 {
220                    tracing::info!(
221                        "Current conversation Turns: {}/{}",
222                        current_max_turns,
223                        self.max_turns
224                    );
225                }
226
227                if let Some(ref hook) = self.hook {
228                    let reader = chat_history.read().await;
229                    if let HookAction::Terminate { reason } = hook.on_completion_call(&current_prompt, &reader.to_vec())
230                        .await {
231
232                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
233                            reason
234                        ).into()));
235                    }
236                }
237
238                let chat_stream_span = info_span!(
239                    target: "rig::agent_chat",
240                    parent: tracing::Span::current(),
241                    "chat_streaming",
242                    gen_ai.operation.name = "chat",
243                    gen_ai.agent.name = &agent.name(),
244                    gen_ai.system_instructions = &agent.preamble,
245                    gen_ai.provider.name = tracing::field::Empty,
246                    gen_ai.request.model = tracing::field::Empty,
247                    gen_ai.response.id = tracing::field::Empty,
248                    gen_ai.response.model = tracing::field::Empty,
249                    gen_ai.usage.output_tokens = tracing::field::Empty,
250                    gen_ai.usage.input_tokens = tracing::field::Empty,
251                    gen_ai.input.messages = tracing::field::Empty,
252                    gen_ai.output.messages = tracing::field::Empty,
253                );
254
255                let mut stream = tracing::Instrument::instrument(
256                    agent
257                    .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
258                    .await?
259                    .stream(), chat_stream_span
260                )
261
262                .await?;
263
264                chat_history.write().await.push(current_prompt.clone());
265
266                let mut tool_calls = vec![];
267                let mut tool_results = vec![];
268                let mut accumulated_reasoning: Option<rig::message::Reasoning> = None;
269
270                while let Some(content) = stream.next().await {
271                    match content {
272                        Ok(StreamedAssistantContent::Text(text)) => {
273                            if !is_text_response {
274                                last_text_response = String::new();
275                                is_text_response = true;
276                            }
277                            last_text_response.push_str(&text.text);
278                            if let Some(ref hook) = self.hook &&
279                                let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &last_text_response).await {
280                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
281                                        reason
282                                    ).into()));
283                                }
284
285                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
286                            did_call_tool = false;
287                        },
288                        Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
289                            let tool_span = info_span!(
290                                parent: tracing::Span::current(),
291                                "execute_tool",
292                                gen_ai.operation.name = "execute_tool",
293                                gen_ai.tool.type = "function",
294                                gen_ai.tool.name = tracing::field::Empty,
295                                gen_ai.tool.call.id = tracing::field::Empty,
296                                gen_ai.tool.call.arguments = tracing::field::Empty,
297                                gen_ai.tool.call.result = tracing::field::Empty
298                            );
299
300                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));
301
302                            let tc_result = async {
303                                let tool_span = tracing::Span::current();
304                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
305                                if let Some(ref hook) = self.hook {
306                                    let action = hook
307                                        .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
308                                        .await;
309
310                                    if let ToolCallHookAction::Terminate { reason } = action {
311                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
312                                            reason
313                                        ).into()));
314                                    }
315
316                                    if let ToolCallHookAction::Skip { reason } = action {
317                                        // Tool execution rejected, return rejection message as tool result
318                                        tracing::info!(
319                                            tool_name = tool_call.function.name.as_str(),
320                                            reason = reason,
321                                            "Tool call rejected"
322                                        );
323                                        let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
324                                        tool_calls.push(tool_call_msg);
325                                        tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
326                                        did_call_tool = true;
327                                        return Ok(reason);
328                                    }
329                                }
330
331                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
332                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);
333
334                                let tool_result = match
335                                agent.tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
336                                    Ok(thing) => thing,
337                                    Err(e) => {
338                                        tracing::warn!("Error while calling tool: {e}");
339                                        e.to_string()
340                                    }
341                                };
342
343                                tool_span.record("gen_ai.tool.call.result", &tool_result);
344
345                                if let Some(ref hook) = self.hook &&
346                                    let HookAction::Terminate { reason } =
347                                    hook.on_tool_result(
348                                        &tool_call.function.name,
349                                        tool_call.call_id.clone(),
350                                        &internal_call_id,
351                                        &tool_args,
352                                        &tool_result.to_string()
353                                    )
354                                    .await {
355                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
356                                            reason
357                                        ).into()));
358                                    }
359
360                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
361
362                                tool_calls.push(tool_call_msg);
363                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));
364
365                                did_call_tool = true;
366                                Ok(tool_result)
367                            }.instrument(tool_span).await;
368
369                            match tc_result {
370                                Ok(text) => {
371                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
372                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
373                                }
374                                Err(e) => {
375                                    yield Err(e);
376                                }
377                            }
378                        },
379                        Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
380                            if let Some(ref hook) = self.hook {
381                                let (name, delta) = match &content {
382                                    rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
383                                    rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
384                                };
385
386                                if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
387                                .await {
388                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
389                                        reason
390                                    ).into()));
391                                }
392                            }
393                        }
394                        Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id, signature })) => {
395                            // Accumulate reasoning for inclusion in chat history with tool calls.
396                            // OpenAI Responses API requires reasoning items to be sent back
397                            // alongside function_call items in multi-turn conversations.
398                            if let Some(ref mut existing) = accumulated_reasoning {
399                                existing.reasoning.extend(reasoning.clone());
400                            } else {
401                                accumulated_reasoning = Some(rig::message::Reasoning {
402                                    reasoning: reasoning.clone(),
403                                    id: id.clone(),
404                                    signature: signature.clone(),
405                                });
406                            }
407                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id, signature })));
408                            did_call_tool = false;
409                        },
410                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
411                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
412                            did_call_tool = false;
413                        },
414                        Ok(StreamedAssistantContent::Final(final_resp)) => {
415                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
416                            if is_text_response {
417                                if let Some(ref hook) = self.hook &&
418                                     let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&prompt, &final_resp).await {
419                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
420                                            reason
421                                        ).into()));
422                                    }
423
424                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
425                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
426                                is_text_response = false;
427                            }
428                        }
429                        Err(e) => {
430                            yield Err(e.into());
431                            break 'outer;
432                        }
433                    }
434                }
435
436                // Add reasoning and tool calls to chat history.
437                // OpenAI Responses API requires reasoning items to precede function_call items.
438                if !tool_calls.is_empty() || accumulated_reasoning.is_some() {
439                    let mut content_items: Vec<rig::message::AssistantContent> = vec![];
440
441                    // Reasoning must come before tool calls (OpenAI requirement)
442                    if let Some(reasoning) = accumulated_reasoning.take() {
443                        content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
444                    }
445
446                    content_items.extend(tool_calls.clone());
447
448                    if !content_items.is_empty() {
449                        chat_history.write().await.push(Message::Assistant {
450                            id: None,
451                            content: OneOrMany::many(content_items).expect("Should have at least one item"),
452                        });
453                    }
454                }
455
456                // Add tool results to chat history
457                for (id, call_id, tool_result) in tool_results {
458                    if let Some(call_id) = call_id {
459                        chat_history.write().await.push(Message::User {
460                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
461                                &id,
462                                call_id.clone(),
463                                OneOrMany::one(ToolResultContent::text(&tool_result)),
464                            )),
465                        });
466                    } else {
467                        chat_history.write().await.push(Message::User {
468                            content: OneOrMany::one(UserContent::tool_result(
469                                &id,
470                                OneOrMany::one(ToolResultContent::text(&tool_result)),
471                            )),
472                        });
473                    }
474                }
475
476                // Set the current prompt to the last message in the chat history
477                current_prompt = match chat_history.write().await.pop() {
478                    Some(prompt) => prompt,
479                    None => unreachable!("Chat history should never be empty at this point"),
480                };
481
482                if !did_call_tool {
483                    let current_span = tracing::Span::current();
484                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
485                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
486                    tracing::info!("Agent multi-turn stream finished");
487                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
488                    break;
489                }
490            }
491
492            if max_turns_reached {
493                yield Err(Box::new(PromptError::MaxTurnsError {
494                    max_turns: self.max_turns,
495                    chat_history: Box::new((*chat_history.read().await).clone()),
496                    prompt: Box::new(last_prompt_error.clone().into()),
497                }).into());
498            }
499        };
500
501        Box::pin(stream.instrument(agent_span))
502    }
503}
504
505impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
506where
507    M: CompletionModel + 'static,
508    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
509    P: StreamingPromptHook<M> + 'static,
510{
511    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
512    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
513
514    fn into_future(self) -> Self::IntoFuture {
515        // Wrap send() in a future, because send() returns a stream immediately
516        Box::pin(async move { self.send().await })
517    }
518}
519
520/// Helper function to stream a completion request to stdout.
521pub async fn stream_to_stdout<R>(
522    stream: &mut StreamingResult<R>,
523) -> Result<FinalResponse, std::io::Error> {
524    let mut final_res = FinalResponse::empty();
525    print!("Response: ");
526    while let Some(content) = stream.next().await {
527        match content {
528            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
529                Text { text },
530            ))) => {
531                print!("{text}");
532                std::io::Write::flush(&mut std::io::stdout()).unwrap();
533            }
534            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
535                Reasoning { reasoning, .. },
536            ))) => {
537                let reasoning = reasoning.join("\n");
538                print!("{reasoning}");
539                std::io::Write::flush(&mut std::io::stdout()).unwrap();
540            }
541            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
542                final_res = res;
543            }
544            Err(err) => {
545                eprintln!("Error: {err}");
546            }
547            _ => {}
548        }
549    }
550
551    Ok(final_res)
552}
553
554// dead code allowed because of functions being left empty to allow for users to not have to implement every single function
555/// Trait for per-request hooks to observe tool call events.
556pub trait StreamingPromptHook<M>: Clone + Send + Sync
557where
558    M: CompletionModel,
559{
560    /// Called before the prompt is sent to the model
561    fn on_completion_call(
562        &self,
563        _prompt: &Message,
564        _history: &[Message],
565    ) -> impl Future<Output = HookAction> + Send {
566        async { HookAction::cont() }
567    }
568
569    /// Called when receiving a text delta
570    fn on_text_delta(
571        &self,
572        _text_delta: &str,
573        _aggregated_text: &str,
574    ) -> impl Future<Output = HookAction> + Send {
575        async { HookAction::cont() }
576    }
577
578    /// Called when receiving a tool call delta.
579    /// `tool_name` is Some on the first delta for a tool call, None on subsequent deltas.
580    fn on_tool_call_delta(
581        &self,
582        _tool_call_id: &str,
583        _internal_call_id: &str,
584        _tool_name: Option<&str>,
585        _tool_call_delta: &str,
586    ) -> impl Future<Output = HookAction> + Send {
587        async { HookAction::cont() }
588    }
589
590    /// Called after the model provider has finished streaming a text response from their completion API to the client.
591    fn on_stream_completion_response_finish(
592        &self,
593        _prompt: &Message,
594        _response: &<M as CompletionModel>::StreamingResponse,
595    ) -> impl Future<Output = HookAction> + Send {
596        async { HookAction::cont() }
597    }
598
599    /// Called before a tool is invoked.
600    ///
601    /// # Returns
602    /// - `ToolCallHookAction::Continue` - Allow tool execution to proceed
603    /// - `ToolCallHookAction::Skip { reason }` - Reject tool execution; `reason` will be returned to the LLM as the tool result
604    fn on_tool_call(
605        &self,
606        _tool_name: &str,
607        _tool_call_id: Option<String>,
608        _internal_call_id: &str,
609        _args: &str,
610    ) -> impl Future<Output = ToolCallHookAction> + Send {
611        async { ToolCallHookAction::cont() }
612    }
613
614    /// Called after a tool is invoked (and a result has been returned).
615    fn on_tool_result(
616        &self,
617        _tool_name: &str,
618        _tool_call_id: Option<String>,
619        _internal_call_id: &str,
620        _args: &str,
621        _result: &str,
622    ) -> impl Future<Output = HookAction> + Send {
623        async { HookAction::cont() }
624    }
625}
626
627impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}
628
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use futures::StreamExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;

    /// Background task that logs periodically to detect span leakage.
    /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
    ///
    /// Ticks every 50 ms until `stop` is set; `leak_count` is incremented for every tick
    /// observed while a span is current.
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            // Check if we're inside an unexpected span
            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    /// Test that span context doesn't leak to concurrent tasks during streaming.
    ///
    /// This test verifies that using `.instrument()` instead of `span.enter()` in
    /// async_stream prevents thread-local span context from leaking to other tasks.
    ///
    /// Uses single-threaded runtime to force all tasks onto the same thread,
    /// making the span leak deterministic (it only occurs when tasks share a thread).
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        // Start background logger
        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Small delay to let background logger start
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Make streaming request WITHOUT an outer span so rig creates its own invoke_agent span
        // (rig reuses current span if one exists, so we need to ensure there's no current span)
        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        // Collect text until the final response (or the first error) arrives.
        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        // Stop background logger
        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
}