use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    json_utils,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{future::Future, pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

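/// A single item yielded by a multi-turn streaming run: a streamed piece of
/// assistant output, a streamed user-side item (for example, a tool result),
/// or the final aggregated response once the run completes.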
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    StreamAssistantItem(StreamedAssistantContent<R>),
    StreamUserItem(StreamedUserContent),
    FinalResponse(FinalResponse),
}

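/// The final aggregated result of a multi-turn streaming run: the last text
/// response plus the token usage summed across every turn.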
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

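/// Errors that can surface while driving a streaming prompt to completion.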
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

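/// A builder for a streaming (and optionally multi-turn) prompt request to an [`Agent`].
///
/// The request implements [`IntoFuture`], so awaiting it yields the stream.
/// A minimal usage sketch (assuming an `agent: Arc<Agent<M>>` has already been built):
///
/// ```ignore
/// let mut stream = StreamingPromptRequest::new(agent, "What is 2 + 2?")
///     .multi_turn(3)
///     .await;
/// ```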
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    prompt: Message,
    chat_history: Option<Vec<Message>>,
    max_depth: usize,
    agent: Arc<Agent<M>>,
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

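    /// Enables multi-turn tool calling, allowing up to `depth` additional
    /// model turns before the run aborts with [`PromptError::MaxDepthError`].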
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

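    /// Seeds the request with an existing chat history.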
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

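    /// Attaches a hook whose callbacks observe (and may cancel) stream events.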
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

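    /// Drives the agent loop: streams a completion, executes any tool calls,
    /// and repeats until a turn produces no tool call or `max_depth` is exceeded.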
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        let stream = async_stream::stream! {
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    hook.on_completion_call(&current_prompt, &reader.to_vec(), cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall(tool_call.clone())));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &tool_args, cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_args)
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, tool_call.call_id.clone(), &tool_args, &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                did_call_tool = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult {
                                        id: tool_call.id,
                                        call_id: tool_call.call_id,
                                        content: OneOrMany::one(ToolResultContent::Text(Text { text })),
                                    };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tr)));
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, delta }) => {
                            if let Some(ref hook) = self.hook {
                                hook.on_tool_call_delta(&id, &delta, cancel_signal.clone())
                                    .await;

                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; }
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

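/// Drains a [`StreamingResult`], printing text and reasoning chunks to stdout
/// as they arrive and returning the [`FinalResponse`] once the stream ends.
///
/// A minimal usage sketch (assuming an `agent` has already been built):
///
/// ```ignore
/// let mut stream = agent.stream_prompt("Calculate 2 + 5").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("token usage: {:?}", final_response.usage());
/// ```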
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

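/// Hooks into the lifecycle of a streaming prompt request. Every method has a
/// no-op default, so implementors only override the events they care about.
/// Each callback receives a [`CancelSignal`] that can be used to abort the run.
///
/// A minimal sketch of a logging hook (`LoggingHook` is illustrative, not part
/// of this crate):
///
/// ```ignore
/// #[derive(Clone)]
/// struct LoggingHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for LoggingHook {
///     async fn on_tool_call(
///         &self,
///         tool_name: &str,
///         _tool_call_id: Option<String>,
///         args: &str,
///         _cancel_sig: CancelSignal,
///     ) {
///         println!("calling tool `{tool_name}` with args: {args}");
///     }
/// }
/// ```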
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
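    /// Called just before each completion request is sent to the model.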
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

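    /// Called for every streamed text delta, along with the text aggregated so far.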
    #[allow(unused_variables)]
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

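    /// Called for every streamed tool-call argument delta.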
    #[allow(unused_variables)]
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

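    /// Called when the provider signals that a streamed completion response has finished.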
    #[allow(unused_variables)]
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

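    /// Called just before a tool is executed, with its name and JSON-encoded arguments.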
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

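    /// Called after a tool finishes, with the stringified result.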
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use futures::StreamExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;

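    /// Ticks every 50ms, logging an event each time and counting a "leak"
    /// whenever the task unexpectedly finds itself inside an active span.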
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
}