rig/agent/prompt_request/streaming.rs

use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    json_utils,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed assistant content item.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// A streamed user content item (mostly for tool results).
    StreamUserItem(StreamedUserContent),
    /// The final result from the stream.
    FinalResponse(FinalResponse),
}

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

/// A builder for creating streaming prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools repeatedly, make sure to use `.multi_turn()` to allow
/// additional turns: the default is 0, meaning only a single tool round-trip. Otherwise,
/// awaiting the request (which sends the prompt) may return
/// [`crate::completion::request::PromptError::MaxDepthError`] if the agent decides to call tools
/// back to back.
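///
/// A minimal usage sketch (assuming the agent's `stream_prompt` helper, which builds this
/// request with the no-op `()` hook, and that `futures::StreamExt` is in scope for `next()`):
///
/// ```ignore
/// // Allow up to 5 tool round-trips before a final text answer is required.
/// let mut stream = agent
///     .stream_prompt("What tools do you have access to?")
///     .multi_turn(5)
///     .await;
///
/// while let Some(item) = stream.next().await {
///     // Each item is a Result<MultiTurnStreamItem<_>, StreamingError>.
/// }
/// ```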
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history to include with the prompt
    /// Note: chat history needs to outlive the agent as it might be used with other agents
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_depth: usize,
    /// The agent to use for execution
    agent: Arc<Agent<M>>,
    /// Optional per-request hook for events
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    /// Create a new StreamingPromptRequest with the given agent and prompt
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: agent.default_max_depth.unwrap_or_default(),
            agent,
            hook: None,
        }
    }

    /// Set the maximum depth for multi-turn conversations (i.e., the maximum number of turns the LLM can spend calling tools before it must write a text response).
    /// If the maximum depth is exceeded, the stream yields a [`crate::completion::request::PromptError::MaxDepthError`].
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Add chat history to the prompt request
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attach a per-request hook for streaming events (completion calls, text deltas, tool calls and results)
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_sig = CancelSignal::new();

        // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
        // span context leaking to other concurrent tasks. Using span.enter() inside
        // async_stream::stream! holds the guard across yield points, which causes
        // thread-local span context to leak when other tasks run on the same thread.
        // See: https://docs.rs/tracing/latest/tracing/span/struct.Span.html#in-asynchronous-code
        // See also: https://github.com/rust-lang/rust-clippy/issues/8722
        let stream = async_stream::stream! {
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    hook.on_completion_call(&current_prompt, &reader.to_vec(), cancel_sig.clone())
                        .await;

                    if cancel_sig.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                        ).into()));
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = &agent.name(),
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_sig.clone()).await;
                                if cancel_sig.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
                                        cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                    ).into()));
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall(tool_call.clone())));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &tool_args, cancel_sig.clone()).await;
                                    if cancel_sig.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
                                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                        ).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_args)
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, tool_call.call_id.clone(), &tool_args, &tool_result.to_string(), cancel_sig.clone())
                                        .await;

                                    if cancel_sig.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
                                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                        ).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                did_call_tool = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult {
                                        id: tool_call.id,
                                        call_id: tool_call.call_id,
                                        content: OneOrMany::one(ToolResultContent::Text(Text { text })),
                                    };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tr)));
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, content }) => {
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
                                    rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
                                };
                                hook.on_tool_call_delta(&id, name, delta, cancel_sig.clone())
                                    .await;

                                if cancel_sig.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
                                        cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                    ).into()));
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id, signature })) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() {
                                aggregated_usage += usage;
                            }
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_sig.clone()).await;

                                    if cancel_sig.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec(),
                                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                        ).into()));
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // Add (parallel) tool calls to chat history
                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                // Add tool results to chat history
                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                // Set the current prompt to the last message in the chat history
                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Wrap send() in a future, because send() returns a stream immediately
        Box::pin(async move { self.send().await })
    }
}

/// Helper function to stream a completion request to stdout.
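///
/// A usage sketch (assuming an already-built `agent` and a caller that can propagate
/// `std::io::Error`):
///
/// ```ignore
/// let mut stream = agent.stream_prompt("Say hello.").await;
/// // Prints text and reasoning chunks as they arrive, then returns the final response.
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\ntoken usage: {:?}", final_response.usage());
/// ```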
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

// Dead code is allowed here because the default method bodies are intentionally left empty so
// that users don't have to implement every single callback.
/// Trait for per-request hooks to observe streaming events (completion calls, text deltas, tool calls and results).
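///
/// A sketch of a custom hook (the `PrintDeltas` type here is hypothetical, not part of this
/// module); only the callbacks you care about need to be overridden, the rest fall back to the
/// no-op defaults, and the hook is attached per request via `with_hook`:
///
/// ```ignore
/// #[derive(Clone)]
/// struct PrintDeltas;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for PrintDeltas {
///     fn on_text_delta(
///         &self,
///         text_delta: &str,
///         _aggregated_text: &str,
///         _cancel_sig: CancelSignal,
///     ) -> impl Future<Output = ()> + Send {
///         let text = text_delta.to_string();
///         async move { print!("{text}") }
///     }
/// }
///
/// let mut stream = agent
///     .stream_prompt("Hello!")
///     .with_hook(PrintDeltas)
///     .await;
/// ```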
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    #[allow(unused_variables)]
    /// Called before the prompt is sent to the model
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called when receiving a text delta
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called when receiving a tool call delta.
    /// `tool_name` is Some on the first delta for a tool call, None on subsequent deltas.
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after the model provider has finished streaming a text response from its completion API to the client.
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called before a tool is invoked.
    fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    /// Called after a tool is invoked (and a result has been returned).
    fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use futures::StreamExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;

    /// Background task that logs periodically to detect span leakage.
    /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            // Check if we're inside an unexpected span
            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    /// Test that span context doesn't leak to concurrent tasks during streaming.
    ///
    /// This test verifies that using `.instrument()` instead of `span.enter()` in
    /// async_stream prevents thread-local span context from leaking to other tasks.
    ///
    /// Uses single-threaded runtime to force all tasks onto the same thread,
    /// making the span leak deterministic (it only occurs when tasks share a thread).
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        // Start background logger
        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Small delay to let background logger start
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Make streaming request WITHOUT an outer span so rig creates its own invoke_agent span
        // (rig reuses current span if one exists, so we need to ensure there's no current span)
        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        // Stop background logger
        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
}