use crate::{
    OneOrMany,
    agent::prompt_request::HookAction,
    completion::GetTokenUsage,
    json_utils,
    message::{AssistantContent, Reasoning, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent, StreamingCompletion},
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use tracing::info_span;
use tracing_futures::Instrument;

use super::ToolCallHookAction;
use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

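/// An item yielded by an agent's multi-turn completion stream: a streamed piece
/// of assistant content, a streamed user-side item (such as a tool result), or
/// the final aggregated response.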
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    StreamAssistantItem(StreamedAssistantContent<R>),
    StreamUserItem(StreamedUserContent),
    FinalResponse(FinalResponse),
}

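/// The final outcome of a multi-turn stream: the last text response produced by
/// the model together with the token usage aggregated across all turns.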
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
        })
    }
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

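/// A builder for a streaming, multi-turn prompt request to an [`Agent`]:
/// streams the model's output while transparently executing tool calls and
/// feeding their results back into the conversation.
///
/// Illustrative sketch of typical usage (not compiled here; assumes an
/// `agent: Arc<Agent<_>>` and a `history: Vec<Message>` built elsewhere):
///
/// ```rust,ignore
/// let mut stream = StreamingPromptRequest::new(agent, "What is 2 + 2?")
///     .multi_turn(5)         // allow up to 5 tool-calling turns
///     .with_history(history) // seed the conversation
///     .await;                // `IntoFuture` drives `send` and returns the stream
///
/// while let Some(item) = stream.next().await {
///     // handle `MultiTurnStreamItem` values as they arrive
/// }
/// ```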
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: StreamingPromptHook<M> + 'static,
{
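    /// The prompt message to send to the model.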
    prompt: Message,
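    /// Optional chat history to seed the conversation with.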
    chat_history: Option<Vec<Message>>,
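    /// The maximum number of turns before the stream aborts with a max-turns error.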
    max_turns: usize,
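    /// The agent that performs the completions.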
    agent: Arc<Agent<M>>,
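    /// An optional hook invoked at each stage of the stream lifecycle.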
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M>,
{
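    /// Creates a new streaming prompt request, defaulting the turn limit to the
    /// agent's `default_max_turns` (or 0 if unset).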
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> Self {
        Self {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            agent,
            hook: None,
        }
    }

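    /// Sets the maximum number of turns for multi-turn tool-calling conversations.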
    pub fn multi_turn(mut self, turns: usize) -> Self {
        self.max_turns = turns;
        self
    }

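    /// Seeds the request with an existing chat history.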
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

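    /// Attaches a hook that is called at each stage of the stream lifecycle,
    /// replacing any previously set hook.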
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: StreamingPromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            agent: self.agent,
            hook: Some(hook),
        }
    }

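    /// Builds the agent span and returns the stream that drives the multi-turn
    /// loop: prompt the model, surface streamed content, run any requested
    /// tools, and re-prompt with the results until the model answers in plain
    /// text or the turn limit is exceeded.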
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent.name(),
                gen_ai.system_instructions = self.agent.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent = self.agent;

        let chat_history = if let Some(history) = self.chat_history {
            Arc::new(RwLock::new(history))
        } else {
            Arc::new(RwLock::new(vec![]))
        };

        let mut current_max_turns = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_turns_reached = false;

        let mut aggregated_usage = crate::completion::Usage::new();

        let stream = async_stream::stream! {
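            // Multi-turn loop: prompt the model, stream its output, execute any
            // requested tools, push the results into the chat history, and
            // re-prompt with the newest message until the model answers with
            // plain text or the turn limit is exceeded.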
            let mut current_prompt = prompt.clone();
            let mut did_call_tool = false;

            'outer: loop {
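                // The loop permits `max_turns + 1` iterations; the extra turn
                // gives the model a chance to stream a closing text response
                // after its final tool call.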
                if current_max_turns > self.max_turns + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_turns_reached = true;
                    break;
                }

                current_max_turns += 1;

                if self.max_turns > 1 {
                    tracing::info!(
                        "Current conversation turns: {}/{}",
                        current_max_turns,
                        self.max_turns
                    );
                }

                if let Some(ref hook) = self.hook {
                    let reader = chat_history.read().await;
                    if let HookAction::Terminate { reason } = hook
                        .on_completion_call(&current_prompt, &reader.to_vec())
                        .await
                    {
                        yield Err(StreamingError::Prompt(
                            PromptError::prompt_cancelled(
                                chat_history.read().await.to_vec(),
                                reason,
                            )
                            .into(),
                        ));
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = &agent.name(),
                    gen_ai.system_instructions = &agent.preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    agent
                        .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone())
                        .await?
                        .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.write().await.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];
                let mut accumulated_reasoning: Option<Reasoning> = None;

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook
                                && let HookAction::Terminate { reason } =
                                    hook.on_text_delta(&text.text, &last_text_response).await
                            {
                                yield Err(StreamingError::Prompt(
                                    PromptError::prompt_cancelled(
                                        chat_history.read().await.to_vec(),
                                        reason,
                                    )
                                    .into(),
                                ));
                            }

                            yield Ok(MultiTurnStreamItem::stream_item(
                                StreamedAssistantContent::Text(text),
                            ));
                            did_call_tool = false;
                        }
                        Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(
                                StreamedAssistantContent::ToolCall {
                                    tool_call: tool_call.clone(),
                                    internal_call_id: internal_call_id.clone(),
                                },
                            ));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args =
                                    json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    let action = hook
                                        .on_tool_call(
                                            &tool_call.function.name,
                                            tool_call.call_id.clone(),
                                            &internal_call_id,
                                            &tool_args,
                                        )
                                        .await;

                                    if let ToolCallHookAction::Terminate { reason } = action {
                                        return Err(StreamingError::Prompt(
                                            PromptError::prompt_cancelled(
                                                chat_history.read().await.to_vec(),
                                                reason,
                                            )
                                            .into(),
                                        ));
                                    }

                                    if let ToolCallHookAction::Skip { reason } = action {
                                        tracing::info!(
                                            tool_name = tool_call.function.name.as_str(),
                                            reason = reason,
                                            "Tool call rejected"
                                        );
                                        let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
                                        tool_calls.push(tool_call_msg);
                                        tool_results.push((
                                            tool_call.id.clone(),
                                            tool_call.call_id.clone(),
                                            reason.clone(),
                                        ));
                                        did_call_tool = true;
                                        return Ok(reason);
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match agent
                                    .tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_args)
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook
                                    && let HookAction::Terminate { reason } = hook
                                        .on_tool_result(
                                            &tool_call.function.name,
                                            tool_call.call_id.clone(),
                                            &internal_call_id,
                                            &tool_args,
                                            &tool_result.to_string(),
                                        )
                                        .await
                                {
                                    return Err(StreamingError::Prompt(
                                        PromptError::prompt_cancelled(
                                            chat_history.read().await.to_vec(),
                                            reason,
                                        )
                                        .into(),
                                    ));
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((
                                    tool_call.id.clone(),
                                    tool_call.call_id.clone(),
                                    tool_result.clone(),
                                ));

                                did_call_tool = true;
                                Ok(tool_result)
                            }
                            .instrument(tool_span)
                            .await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult {
                                        id: tool_call.id,
                                        call_id: tool_call.call_id,
                                        content: ToolResultContent::from_tool_output(text),
                                    };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(
                                        StreamedUserContent::ToolResult {
                                            tool_result: tr,
                                            internal_call_id,
                                        },
                                    ));
                                }
                                Err(e) => {
                                    yield Err(e);
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    crate::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
                                    crate::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
                                };

                                if let HookAction::Terminate { reason } = hook
                                    .on_tool_call_delta(&id, &internal_call_id, name, delta)
                                    .await
                                {
                                    yield Err(StreamingError::Prompt(
                                        PromptError::prompt_cancelled(
                                            chat_history.read().await.to_vec(),
                                            reason,
                                        )
                                        .into(),
                                    ));
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature })) => {
                            if let Some(ref mut existing) = accumulated_reasoning {
                                existing.reasoning.extend(reasoning.clone());
                            } else {
                                accumulated_reasoning = Some(Reasoning {
                                    reasoning: reasoning.clone(),
                                    id: id.clone(),
                                    signature: signature.clone(),
                                });
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(
                                StreamedAssistantContent::Reasoning(Reasoning { reasoning, id, signature }),
                            ));
                            did_call_tool = false;
                        }
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            yield Ok(MultiTurnStreamItem::stream_item(
                                StreamedAssistantContent::ReasoningDelta { reasoning, id },
                            ));
                            did_call_tool = false;
                        }
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() {
                                aggregated_usage += usage;
                            }
                            if is_text_response {
                                if let Some(ref hook) = self.hook
                                    && let HookAction::Terminate { reason } = hook
                                        .on_stream_completion_response_finish(&prompt, &final_resp)
                                        .await
                                {
                                    yield Err(StreamingError::Prompt(
                                        PromptError::prompt_cancelled(
                                            chat_history.read().await.to_vec(),
                                            reason,
                                        )
                                        .into(),
                                    ));
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(
                                    StreamedAssistantContent::Final(final_resp),
                                ));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if !tool_calls.is_empty() || accumulated_reasoning.is_some() {
                    let mut content_items: Vec<AssistantContent> = vec![];

                    if let Some(reasoning) = accumulated_reasoning.take() {
                        content_items.push(AssistantContent::Reasoning(reasoning));
                    }

                    content_items.extend(tool_calls.clone());

                    if !content_items.is_empty() {
                        chat_history.write().await.push(Message::Assistant {
                            id: None,
                            content: OneOrMany::many(content_items)
                                .expect("Should have at least one item"),
                        });
                    }
                }

                for (id, call_id, tool_result) in tool_results {
                    if let Some(call_id) = call_id {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result_with_call_id(
                                &id,
                                call_id.clone(),
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    } else {
                        chat_history.write().await.push(Message::User {
                            content: OneOrMany::one(UserContent::tool_result(
                                &id,
                                OneOrMany::one(ToolResultContent::text(&tool_result)),
                            )),
                        });
                    }
                }

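                // Reuse the newest message as the next prompt: if tools ran, the
                // last tool result is popped and re-sent on the next turn;
                // otherwise the loop finishes below since no tool was called.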
                current_prompt = match chat_history.write().await.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                if !did_call_tool {
                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    yield Ok(MultiTurnStreamItem::final_response(
                        &last_text_response,
                        aggregated_usage,
                    ));
                    break;
                }
            }

            if max_turns_reached {
                yield Err(Box::new(PromptError::MaxTurnsError {
                    max_turns: self.max_turns,
                    chat_history: Box::new((*chat_history.read().await).clone()),
                    prompt: Box::new(last_prompt_error.clone().into()),
                })
                .into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: StreamingPromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

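/// Drains a [`StreamingResult`] to stdout, printing text and reasoning chunks as
/// they arrive, and returns the captured [`FinalResponse`] (empty if the stream
/// ended without one). Stream errors are printed to stderr rather than returned.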
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                Reasoning { reasoning, .. },
            ))) => {
                let reasoning = reasoning.join("\n");
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

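/// Lifecycle hooks for a [`StreamingPromptRequest`]. Every method has a no-op
/// default implementation, so implementors only override the stages they care
/// about; returning a terminate action from any hook cancels the prompt.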
pub trait StreamingPromptHook<M>: Clone + Send + Sync
where
    M: CompletionModel,
{
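    /// Called before each completion request is sent to the model.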
    fn on_completion_call(
        &self,
        _prompt: &Message,
        _history: &[Message],
    ) -> impl Future<Output = HookAction> + Send {
        async { HookAction::cont() }
    }

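    /// Called on each streamed text chunk, with both the delta and the text
    /// aggregated so far.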
    fn on_text_delta(
        &self,
        _text_delta: &str,
        _aggregated_text: &str,
    ) -> impl Future<Output = HookAction> + Send {
        async { HookAction::cont() }
    }

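    /// Called on each streamed tool-call delta; the tool name is only available
    /// on the chunk that carries it.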
    fn on_tool_call_delta(
        &self,
        _tool_call_id: &str,
        _internal_call_id: &str,
        _tool_name: Option<&str>,
        _tool_call_delta: &str,
    ) -> impl Future<Output = HookAction> + Send {
        async { HookAction::cont() }
    }

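    /// Called when the provider finishes streaming a single completion
    /// response, before the next turn begins.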
    fn on_stream_completion_response_finish(
        &self,
        _prompt: &Message,
        _response: &<M as CompletionModel>::StreamingResponse,
    ) -> impl Future<Output = HookAction> + Send {
        async { HookAction::cont() }
    }

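    /// Called before a tool is executed. Returning a skip action records the
    /// given reason as the tool result without running the tool; returning a
    /// terminate action cancels the prompt.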
    fn on_tool_call(
        &self,
        _tool_name: &str,
        _tool_call_id: Option<String>,
        _internal_call_id: &str,
        _args: &str,
    ) -> impl Future<Output = ToolCallHookAction> + Send {
        async { ToolCallHookAction::cont() }
    }

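    /// Called after a tool has been executed, with both the arguments and the
    /// result.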
    fn on_tool_result(
        &self,
        _tool_name: &str,
        _tool_call_id: Option<String>,
        _internal_call_id: &str,
        _args: &str,
        _result: &str,
    ) -> impl Future<Output = HookAction> + Send {
        async { HookAction::cont() }
    }
}

impl<M> StreamingPromptHook<M> for () where M: CompletionModel {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use futures::StreamExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;

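    /// Ticks every 50ms until `stop` is set, logging an event on each tick and
    /// incrementing `leak_count` whenever it observes an active span; this task
    /// is spawned outside any span, so any active span indicates a leak.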
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

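    /// Streams a prompt on a current-thread runtime while a background task
    /// keeps logging, then asserts that the background task never observed an
    /// active span. A nonzero count means span context leaked across tasks,
    /// e.g. from calling `span.enter()` inside the async stream instead of
    /// using `.instrument()`.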
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
}