use crate::{
    OneOrMany,
    agent::CancelSignal,
    completion::GetTokenUsage,
    json_utils,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
32#[derive(Deserialize, Serialize, Debug, Clone)]
33#[serde(tag = "type", rename_all = "camelCase")]
34#[non_exhaustive]
35pub enum MultiTurnStreamItem<R> {
36 StreamAssistantItem(StreamedAssistantContent<R>),
38 StreamUserItem(StreamedUserContent),
40 FinalResponse(FinalResponse),
42}
43
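/// The final response of a multi-turn prompt stream: the last text produced
/// by the agent, together with the token usage aggregated across all turns.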
44#[derive(Deserialize, Serialize, Debug, Clone)]
45#[serde(rename_all = "camelCase")]
46pub struct FinalResponse {
47 response: String,
48 aggregated_usage: crate::completion::Usage,
49}
50
51impl FinalResponse {
52 pub fn empty() -> Self {
53 Self {
54 response: String::new(),
55 aggregated_usage: crate::completion::Usage::new(),
56 }
57 }
58
59 pub fn response(&self) -> &str {
60 &self.response
61 }
62
63 pub fn usage(&self) -> crate::completion::Usage {
64 self.aggregated_usage
65 }
66}
67
impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

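/// Errors that can be yielded by a multi-turn prompt stream.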
81#[derive(Debug, thiserror::Error)]
82pub enum StreamingError {
83 #[error("CompletionError: {0}")]
84 Completion(#[from] CompletionError),
85 #[error("PromptError: {0}")]
86 Prompt(#[from] Box<PromptError>),
87 #[error("ToolSetError: {0}")]
88 Tool(#[from] ToolSetError),
89}
90
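/// A builder for a streaming, multi-turn prompt against an [`Agent`].
///
/// The request is lazy: it only runs once awaited (via [`IntoFuture`]), at
/// which point it yields a stream of [`MultiTurnStreamItem`]s. A minimal
/// usage sketch, assuming `agent` is an already-built agent (error handling
/// elided):
///
/// ```rust,ignore
/// use futures::StreamExt;
/// use std::sync::Arc;
///
/// let mut stream = StreamingPromptRequest::new(Arc::new(agent), "What is 2 + 2?")
///     .multi_turn(3)
///     .await;
///
/// while let Some(item) = stream.next().await {
///     // Handle text deltas, tool calls and the final response here.
/// }
/// ```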
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
    prompt: Message,
    chat_history: Option<Vec<Message>>,
    max_depth: usize,
    agent: Arc<Agent<M>>,
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_depth: agent.default_max_depth.unwrap_or_default(),
            agent,
            hook: None,
        }
    }

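    /// Set the maximum number of turns (model round-trips) the agent may take
    /// before the stream yields a [`PromptError::MaxDepthError`].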
    pub fn multi_turn(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

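    /// Seed the request with an existing chat history; messages produced
    /// during the conversation are appended to it.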
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

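    /// Attach a [`StreamingPromptHook`] that is invoked at each stage of the
    /// conversation (completion calls, text deltas, tool calls and results).
    /// A minimal sketch, assuming a hypothetical `LoggingHook` type that
    /// implements the trait:
    ///
    /// ```rust,ignore
    /// let stream = StreamingPromptRequest::new(agent, "Summarise this file")
    ///     .with_hook(LoggingHook)
    ///     .await;
    /// ```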
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_depth: self.max_depth,
            agent: self.agent,
            hook: Some(hook),
        }
    }

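    /// Drive the multi-turn conversation: repeatedly stream a completion,
    /// execute any tool calls, feed the results back as the next prompt, and
    /// yield every intermediate item until the model answers without calling
    /// a tool or the maximum depth is exceeded.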
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_depth = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_depth_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let cancel_sig = CancelSignal::new();

        let stream = async_stream::stream! {
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
                if current_max_depth > self.max_depth + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_depth_reached = true;
                    break;
                }

                current_max_depth += 1;

                if self.max_depth > 1 {
                    tracing::info!(
                        "Current conversation depth: {}/{}",
                        current_max_depth,
                        self.max_depth
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    hook.on_completion_call(&current_prompt, &reader.to_vec(), cancel_sig.clone())
                        .await;

                    if cancel_sig.is_cancelled() {
                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(
                            chat_history.read().await.to_vec(),
                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                        ).into()));
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = &agent.name(),
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook {
                                hook.on_text_delta(&text.text, &last_text_response, cancel_sig.clone()).await;
                                if cancel_sig.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(
                                        chat_history.read().await.to_vec(),
                                        cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                    ).into()));
                                }
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ToolCall(tool_call)) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall(tool_call.clone())));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &tool_args, cancel_sig.clone()).await;
                                    if cancel_sig.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(
                                            chat_history.read().await.to_vec(),
                                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                        ).into()));
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match agent.tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook {
                                    hook.on_tool_result(&tool_call.function.name, tool_call.call_id.clone(), &tool_args, &tool_result.to_string(), cancel_sig.clone())
                                        .await;

                                    if cancel_sig.is_cancelled() {
                                        return Err(StreamingError::Prompt(PromptError::prompt_cancelled(
                                            chat_history.read().await.to_vec(),
                                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                        ).into()));
                                    }
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                did_call_tool = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult {
                                        id: tool_call.id,
                                        call_id: tool_call.call_id,
                                        content: OneOrMany::one(ToolResultContent::Text(Text { text })),
                                    };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult(tr)));
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, content }) => {
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
                                    rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
                                };
                                hook.on_tool_call_delta(&id, name, delta, cancel_sig.clone())
                                    .await;

                                if cancel_sig.is_cancelled() {
                                    yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(
                                        chat_history.read().await.to_vec(),
                                        cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                    ).into()));
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id, signature })) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id, signature })));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                            did_call_tool = false;
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                            if is_text_response {
                                if let Some(ref hook) = self.hook {
                                    hook.on_stream_completion_response_finish(&prompt, &final_resp, cancel_sig.clone()).await;

                                    if cancel_sig.is_cancelled() {
                                        yield Err(StreamingError::Prompt(PromptError::prompt_cancelled(
                                            chat_history.read().await.to_vec(),
                                            cancel_sig.cancel_reason().unwrap_or("<no reason given>"),
                                        ).into()));
                                    }
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if !tool_calls.is_empty() {
                    chat_history.write().await.push(Message::Assistant {
                        id: None,
                        content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"),
                    });
                }

                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage));
                    break;
                }
            }

            if max_depth_reached {
                yield Err(Box::new(PromptError::MaxDepthError {
                    max_depth: self.max_depth,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

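/// Convenience helper that drains a [`StreamingResult`], printing text and
/// reasoning chunks to stdout as they arrive, and returns the
/// [`FinalResponse`] once the stream ends. A minimal sketch, assuming `agent`
/// exposes a streaming prompt that returns a [`StreamingResult`]:
///
/// ```rust,ignore
/// let mut stream = agent.stream_prompt("Tell me a joke.").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\nUsage: {:?}", final_response.usage());
/// ```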
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout()).unwrap();
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

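/// Hooks invoked at each stage of a streaming prompt: before every completion
/// call, on each text or tool-call delta, on tool calls and their results, and
/// when a streamed completion finishes. Every method defaults to a no-op, so
/// implementors only override the events they care about; each receives a
/// [`CancelSignal`] that can be used to cancel the run.
///
/// A minimal sketch of a hook that logs tool activity (the `PrintToolsHook`
/// name is illustrative, not part of the crate):
///
/// ```rust,ignore
/// #[derive(Clone)]
/// struct PrintToolsHook;
///
/// impl<M: CompletionModel> StreamingPromptHook<M> for PrintToolsHook {
///     fn on_tool_call(
///         &self,
///         tool_name: &str,
///         _tool_call_id: Option<String>,
///         args: &str,
///         _cancel_sig: CancelSignal,
///     ) -> impl Future<Output = ()> + Send {
///         println!("calling tool {tool_name} with args {args}");
///         async {}
///     }
/// }
/// ```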
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
    #[allow(unused_variables)]
    fn on_completion_call(
        &self,
        prompt: &Message,
        history: &[Message],
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_text_delta(
        &self,
        text_delta: &str,
        aggregated_text: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_tool_call_delta(
        &self,
        tool_call_id: &str,
        tool_name: Option<&str>,
        tool_call_delta: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_stream_completion_response_finish(
        &self,
        prompt: &Message,
        response: &<M as CompletionModel>::StreamingResponse,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_tool_call(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }

    #[allow(unused_variables)]
    fn on_tool_result(
        &self,
        tool_name: &str,
        tool_call_id: Option<String>,
        args: &str,
        result: &str,
        cancel_sig: CancelSignal,
    ) -> impl Future<Output = ()> + Send {
        async {}
    }
}

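/// The default no-op hook, used when no hook is attached to the request.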
impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use futures::StreamExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;

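    /// Ticks every 50ms in the background and counts how often it finds itself
    /// inside an active span; used to detect span context leaking out of the
    /// streaming implementation onto unrelated tasks.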
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

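    /// Runs a real streaming prompt while `background_logger` ticks on another
    /// task, then asserts the logger never observed an active span, i.e. that
    /// the stream propagates its spans with `.instrument()` instead of
    /// entering them across await points.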
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
}