use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    json_utils,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

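/// A single item yielded by a multi-turn prompt stream: assistant content
/// (text deltas, tool calls, reasoning), user content (tool results), or the
/// final aggregated response once the turn loop completes.
///
/// A minimal consumption sketch; `stream` is assumed to be a
/// [`StreamingResult`] obtained from an agent:
///
/// ```ignore
/// while let Some(item) = stream.next().await {
///     match item? {
///         MultiTurnStreamItem::StreamAssistantItem(content) => { /* render delta */ }
///         MultiTurnStreamItem::StreamUserItem(tool_result) => { /* record tool output */ }
///         MultiTurnStreamItem::FinalResponse(res) => println!("{}", res.response()),
///         _ => {} // the enum is non-exhaustive
///     }
/// }
/// ```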
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    StreamAssistantItem(StreamedAssistantContent<R>),
    StreamUserItem(StreamedUserContent),
    FinalResponse(FinalResponse),
}

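/// The final result of a multi-turn prompt stream: the last text response
/// produced by the model together with the token usage aggregated across
/// every turn.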
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

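/// A builder for a streaming, potentially multi-turn prompt sent to an
/// [`Agent`]. Awaiting the request (via `IntoFuture`) returns a
/// [`StreamingResult`] of [`MultiTurnStreamItem`]s.
///
/// A minimal usage sketch, assuming `agent` is an already-built
/// `Arc<Agent<_>>`:
///
/// ```ignore
/// let mut stream = StreamingPromptRequest::new(agent, "What is 2 + 2?")
///     .multi_turn(3)
///     .await;
/// ```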
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    prompt: Message,
    chat_history: Option<Vec<Message>>,
    max_depth: usize,
    agent: Arc<Agent<M>>,
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
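    /// Creates a new streaming prompt request for the given agent and prompt,
    /// with no chat history, no hook, and a multi-turn depth of zero.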
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

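    /// Sets the maximum number of turns (model round-trips for tool calls)
    /// allowed; exceeding it ends the stream with [`PromptError::MaxDepthError`].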
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

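    /// Seeds the request with an existing chat history; subsequent messages
    /// from this conversation are appended to it.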
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

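    /// Attaches a [`StreamingPromptHook`] whose callbacks fire on completion
    /// calls, text deltas, tool calls, and tool results, replacing any
    /// previously attached hook.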
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        let stream = async_stream::stream! {
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    let prompt = reader
                        .last()
                        .cloned()
                        .expect("there should always be at least one message in the chat history");
                    let chat_history_except_last = reader[..reader.len() - 1].to_vec();

                    hook.on_completion_call(&prompt, &chat_history_except_last, cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(
                            PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into(),
                        ));
                        // A cancelled prompt terminates the stream.
                        return;
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(
                                        PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into(),
                                    ));
                                    return;
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall(tool_call.clone())));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, &tool_args, cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(
                                            PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into(),
                                        ));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_args)
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, &tool_args, &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(
                                            PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into(),
                                        ));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                did_call_tool = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult {
                                        id: tool_call.id,
                                        call_id: tool_call.call_id,
                                        content: OneOrMany::one(ToolResultContent::Text(Text { text })),
                                    };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tr)));
                                }
                                Err(e) => {
                                    // The only error here is a cancellation, so end the stream.
                                    yield Err(e);
                                    return;
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, delta }) => {
                            if let Some(ref hook) = self.hook {
                                hook.on_tool_call_delta(&id, &delta, cancel_signal.clone())
                                    .await;

                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(
                                        PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into(),
                                    ));
                                    return;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() {
                                aggregated_usage += usage;
                            }
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(
                                            PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into(),
                                        ));
                                        return;
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: last_prompt_error.clone().into(),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

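/// Drains a [`StreamingResult`], printing text and reasoning deltas to stdout
/// as they arrive, and returns the [`FinalResponse`] once the stream ends.
///
/// A usage sketch, assuming `agent` was built beforehand and `stream_prompt`
/// comes from the [`crate::streaming::StreamingPrompt`] trait:
///
/// ```ignore
/// let mut stream = agent.stream_prompt("Hello!").await;
/// let final_res = stream_to_stdout(&mut stream).await?;
/// println!("\nUsage: {:?}", final_res.usage());
/// ```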
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

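/// Hooks into the lifecycle of a streaming prompt. Every method has a no-op
/// default, so implementors only override the events they care about; each
/// callback receives a [`CancelSignal`] that can be used to abort the stream.
///
/// A minimal sketch of a hook that echoes text deltas (the struct name is
/// illustrative):
///
/// ```ignore
/// #[derive(Clone)]
/// struct PrintHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for PrintHook {
///     async fn on_text_delta(&self, delta: &str, _aggregated: &str, _sig: CancelSignal) {
///         print!("{delta}");
///     }
/// }
/// ```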
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use futures::StreamExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;

    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
}