use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    json_utils,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

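/// A single item yielded by a multi-turn agent stream: streamed assistant
/// content, streamed user content (tool results), or the aggregated final
/// response emitted once the turn loop finishes.
///
/// A minimal consumption sketch (assumes `stream` is a [`StreamingResult`]
/// obtained by awaiting a [`StreamingPromptRequest`]; not a doctest):
///
/// ```rust,ignore
/// while let Some(item) = stream.next().await {
///     match item? {
///         MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text)) => {
///             print!("{}", text.text);
///         }
///         MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(_result)) => {
///             println!("(received a tool result)");
///         }
///         MultiTurnStreamItem::FinalResponse(final_response) => {
///             println!("\nusage: {:?}", final_response.usage());
///         }
///         // The enum is #[non_exhaustive], so always keep a catch-all arm.
///         _ => {}
///     }
/// }
/// ```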
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// A streamed piece of assistant content: text, a tool call, reasoning, or the provider's final response.
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// A streamed piece of user content (currently tool results).
    StreamUserItem(StreamedUserContent),
    /// The aggregated final response, emitted once the multi-turn loop finishes.
    FinalResponse(FinalResponse),
}

#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

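/// A builder-style request for streaming a prompt through an [`Agent`],
/// optionally across multiple turns of tool calling.
///
/// A rough usage sketch (assumes `agent` is an already-built agent whose
/// `stream_prompt` method returns one of these requests, as in the test module
/// below; not a doctest):
///
/// ```rust,ignore
/// let mut stream = agent
///     .stream_prompt("What is 2 + 2?") // returns a StreamingPromptRequest
///     .multi_turn(3)                   // allow up to 3 turns of tool use
///     .await;                          // IntoFuture turns the request into a stream
///
/// while let Some(item) = stream.next().await {
///     // handle `MultiTurnStreamItem`s here
/// }
/// ```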
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    /// The prompt message to send to the model.
    prompt: Message,
    /// Optional chat history to seed the conversation with.
    chat_history: Option<Vec<Message>>,
    /// The maximum number of turns allowed for multi-turn tool use.
    max_depth: usize,
    /// The agent responsible for executing the request.
    agent: Arc<Agent<M>>,
    /// An optional hook for receiving streaming lifecycle events.
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    /// Create a new streaming prompt request for the given agent and prompt.
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: 0,
            agent,
            hook: None,
        }
    }

    /// Set the maximum number of turns allowed when the model calls tools (multi-turn).
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Seed the request with an existing chat history.
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

    /// Attach a hook that receives streaming lifecycle events (text deltas, tool calls, tool results).
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_signal = CancelSignal::new();

        let stream = async_stream::stream! {
            // The stream drives the multi-turn loop: it forwards streamed assistant
            // content to the caller, executes tool calls between turns, and finally
            // yields an aggregated `FinalResponse`.
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    hook.on_completion_call(&current_prompt, &reader.to_vec(), cancel_signal.clone())
                        .await;

                    if cancel_signal.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_signal.clone()).await;
                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall(tool_call.clone())));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &tool_args, cancel_signal.clone()).await;
                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_args)
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, tool_call.call_id.clone(), &tool_args, &tool_result.to_string(), cancel_signal.clone())
                                        .await;

                                    if cancel_signal.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                did_call_tool = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult {
                                        id: tool_call.id,
                                        call_id: tool_call.call_id,
                                        content: OneOrMany::one(ToolResultContent::Text(Text { text })),
                                    };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tr)));
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, content }) => {
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    crate::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
                                    crate::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
                                };
                                hook.on_tool_call_delta(&id, name, delta, cancel_signal.clone())
                                    .await;

                                if cancel_signal.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(crate::message::Reasoning { reasoning, id, signature })) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(crate::message::Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; }
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_signal.clone()).await;

                                    if cancel_signal.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.read().await.to_vec()).into()));
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

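/// Convenience helper that drains a [`StreamingResult`], printing text and
/// reasoning chunks to stdout as they arrive and returning the final response.
///
/// A minimal sketch of how it is typically driven (assumes `agent` is any
/// built agent, as in the test module below; not a doctest):
///
/// ```rust,ignore
/// let mut stream = agent.stream_prompt("Tell me a short joke.").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\nUsage: {:?}", final_response.usage());
/// ```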
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

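/// A hook for observing (and optionally cancelling, via [`CancelSignal`]) the
/// lifecycle of a streaming prompt: completion calls, text deltas, tool call
/// deltas, tool calls and tool results. Every method has a no-op default, so
/// implementors only override the callbacks they care about.
///
/// A minimal sketch of an implementation (the `PrintHook` type is illustrative,
/// not part of the crate; not a doctest):
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct PrintHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for PrintHook {
///     async fn on_text_delta(&self, text_delta: &str, _aggregated_text: &str, _cancel_sig: CancelSignal) {
///         print!("{text_delta}");
///     }
///
///     async fn on_tool_call(&self, tool_name: &str, _tool_call_id: Option<String>, args: &str, _cancel_sig: CancelSignal) {
///         println!("calling tool `{tool_name}` with args {args}");
///     }
/// }
///
/// // Attach it when building the request, e.g.:
/// // let mut stream = agent.stream_prompt(prompt).with_hook(PrintHook).await;
/// ```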
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    /// Called before each completion request is sent to the model.
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called for every text delta received from the model.
    #[allow(unused_variables)]
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called for every tool call delta (a partial tool name or argument fragment).
    #[allow(unused_variables)]
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called when a streamed completion response has finished.
    #[allow(unused_variables)]
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called just before a tool is invoked.
    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    /// Called after a tool has returned a result (or an error message).
    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use futures::StreamExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;

    /// A background task that logs periodically and counts how often it finds
    /// itself inside an unexpected (leaked) tracing span.
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            // If the current span is anything other than "disabled" or "none",
            // a span from another task has leaked into this one.
            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    /// Verifies that spans created inside the streaming agent loop do not leak
    /// into unrelated tasks running on the same (current-thread) runtime.
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        // Spawn the background logger on the same runtime.
        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Give the background logger time to start ticking.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Run a streaming prompt while the background logger is running.
        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        // Stop the background logger and check for span leaks.
        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
}
690}