1use crate::{
2 OneOrMany,
3 agent::completion::{DynamicContextStore, build_completion_request},
4 agent::prompt_request::{HookAction, hooks::PromptHook},
5 completion::{Document, GetTokenUsage},
6 json_utils,
7 message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
8 streaming::{StreamedAssistantContent, StreamedUserContent},
9 tool::server::ToolServerHandle,
10 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
11};
12use futures::{Stream, StreamExt};
13use serde::{Deserialize, Serialize};
14use std::{pin::Pin, sync::Arc};
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::ToolCallHookAction;
19use crate::{
20 agent::Agent,
21 completion::{CompletionError, CompletionModel, PromptError},
22 message::{Message, Text},
23 tool::ToolSetError,
24};
25
/// Boxed, pinned stream of multi-turn items produced by a streaming prompt request.
/// Non-wasm builds require `Send` so the stream can be polled from any thread.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// Boxed, pinned stream of multi-turn items produced by a streaming prompt request.
/// The wasm variant drops the `Send` bound (wasm targets are single-threaded).
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
33
/// A single item emitted on the multi-turn agent stream.
///
/// Serialized with a `type` tag in camelCase; `#[non_exhaustive]` so new
/// item kinds can be added without a breaking change.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// Content produced by the assistant (text, tool calls, reasoning, …).
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// Content produced on the user side of the exchange (e.g. tool results).
    StreamUserItem(StreamedUserContent),
    /// Terminal item carrying the aggregated result of the whole conversation.
    FinalResponse(FinalResponse),
}
45
/// Aggregated result of a completed multi-turn streaming conversation.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    // Final assistant text of the last turn.
    response: String,
    // Token usage summed across every turn.
    aggregated_usage: crate::completion::Usage,
    // Full message history; only populated when the request was created
    // with an explicit chat history (omitted from JSON when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}
56
57impl FinalResponse {
58 pub fn empty() -> Self {
59 Self {
60 response: String::new(),
61 aggregated_usage: crate::completion::Usage::new(),
62 history: None,
63 }
64 }
65
66 pub fn response(&self) -> &str {
68 &self.response
69 }
70
71 pub fn usage(&self) -> crate::completion::Usage {
72 self.aggregated_usage
73 }
74
75 pub fn history(&self) -> Option<&[Message]> {
76 self.history.as_deref()
77 }
78}
79
80impl<R> MultiTurnStreamItem<R> {
81 pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
82 Self::StreamAssistantItem(item)
83 }
84
85 pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
86 Self::FinalResponse(FinalResponse {
87 response: response.to_string(),
88 aggregated_usage,
89 history: None,
90 })
91 }
92
93 pub fn final_response_with_history(
94 response: &str,
95 aggregated_usage: crate::completion::Usage,
96 history: Option<Vec<Message>>,
97 ) -> Self {
98 Self::FinalResponse(FinalResponse {
99 response: response.to_string(),
100 aggregated_usage,
101 history,
102 })
103 }
104}
105
106fn merge_reasoning_blocks(
107 accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
108 incoming: &crate::message::Reasoning,
109) {
110 let ids_match = |existing: &crate::message::Reasoning| {
111 matches!(
112 (&existing.id, &incoming.id),
113 (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
114 )
115 };
116
117 if let Some(existing) = accumulated_reasoning
118 .iter_mut()
119 .rev()
120 .find(|existing| ids_match(existing))
121 {
122 existing.content.extend(incoming.content.clone());
123 } else {
124 accumulated_reasoning.push(incoming.clone());
125 }
126}
127
128fn build_full_history(
130 chat_history: Option<&[Message]>,
131 new_messages: Vec<Message>,
132) -> Vec<Message> {
133 let input = chat_history.unwrap_or(&[]);
134 input.iter().cloned().chain(new_messages).collect()
135}
136
137fn build_history_for_request(
139 chat_history: Option<&[Message]>,
140 new_messages: &[Message],
141) -> Vec<Message> {
142 let input = chat_history.unwrap_or(&[]);
143 input.iter().chain(new_messages.iter()).cloned().collect()
144}
145
146async fn cancelled_prompt_error(
147 chat_history: Option<&[Message]>,
148 new_messages: Vec<Message>,
149 reason: String,
150) -> StreamingError {
151 StreamingError::Prompt(
152 PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
153 .into(),
154 )
155}
156
157fn tool_result_to_user_message(
158 id: String,
159 call_id: Option<String>,
160 tool_result: String,
161) -> Message {
162 let content = OneOrMany::one(ToolResultContent::text(tool_result));
163 let user_content = match call_id {
164 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
165 None => UserContent::tool_result(id, content),
166 };
167
168 Message::User {
169 content: OneOrMany::one(user_content),
170 }
171}
172
173fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
174 choice
175 .iter()
176 .filter_map(|content| match content {
177 AssistantContent::Text(text) => Some(text.text.as_str()),
178 _ => None,
179 })
180 .collect()
181}
182
/// Errors that can surface on the multi-turn streaming result.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying completion/provider call failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// The prompt flow failed or was cancelled (boxed: `PromptError` is large).
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool invocation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
192
193const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
194
/// Builder-style request for a streaming, possibly multi-turn, agent prompt.
///
/// Captures a snapshot of the agent's configuration at construction time;
/// configure with [`Self::multi_turn`], [`Self::with_history`] and
/// [`Self::with_hook`], then `.await` it (via `IntoFuture`) to obtain the
/// item stream.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    // The user prompt that starts the conversation.
    prompt: Message,
    // Optional pre-existing chat history to prepend to each request.
    chat_history: Option<Vec<Message>>,
    // Maximum number of model turns before erroring with `MaxTurnsError`.
    max_turns: usize,

    // Completion model shared with the originating agent.
    model: Arc<M>,
    // Agent display name for tracing; falls back to `UNKNOWN_AGENT_NAME`.
    agent_name: Option<String>,
    // System instructions forwarded with every completion request.
    preamble: Option<String>,
    // Documents always attached to the request context.
    static_context: Vec<Document>,
    // Sampling temperature override, when set.
    temperature: Option<f64>,
    // Maximum completion tokens, when set.
    max_tokens: Option<u64>,
    // Provider-specific extra parameters.
    additional_params: Option<serde_json::Value>,
    // Handle used to resolve and invoke tools.
    tool_server_handle: ToolServerHandle,
    // Store resolving dynamic (RAG) context per request.
    dynamic_context: DynamicContextStore,
    // Tool-choice policy forwarded to the provider.
    tool_choice: Option<ToolChoice>,
    // Optional JSON schema constraining the model output.
    output_schema: Option<schemars::Schema>,
    // Optional lifecycle hook observing/controlling the prompt flow.
    hook: Option<P>,
}
241
242impl<M, P> StreamingPromptRequest<M, P>
243where
244 M: CompletionModel + 'static,
245 <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
246 P: PromptHook<M>,
247{
    /// Create a streaming request from an agent and a prompt, snapshotting
    /// the agent's configuration. The result carries no hook (`P = ()`);
    /// attach one with [`Self::with_hook`].
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            // Zero turns when the agent declares no default.
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: None,
        }
    }
269
    /// Create a streaming request from a borrowed agent, snapshotting its
    /// configuration. Unlike [`Self::new`], this also carries over the
    /// agent's own hook (of type `P2`).
    pub fn from_agent<P2>(
        agent: &Agent<M, P2>,
        prompt: impl Into<Message>,
    ) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            // Zero turns when the agent declares no default.
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: agent.hook.clone(),
        }
    }
296
297 fn agent_name(&self) -> &str {
298 self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
299 }
300
301 pub fn multi_turn(mut self, turns: usize) -> Self {
304 self.max_turns = turns;
305 self
306 }
307
308 pub fn with_history<I, T>(mut self, history: I) -> Self
321 where
322 I: IntoIterator<Item = T>,
323 T: Into<Message>,
324 {
325 self.chat_history = Some(history.into_iter().map(Into::into).collect());
326 self
327 }
328
    /// Attach a prompt hook, changing the request's hook type parameter.
    /// All other configuration is carried over unchanged.
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            output_schema: self.output_schema,
            hook: Some(hook),
        }
    }
353
    /// Drive the multi-turn conversation, returning a stream of items.
    ///
    /// Each loop iteration performs one completion turn: it builds a request
    /// from the accumulated history, forwards streamed assistant content to
    /// the caller, executes tool calls, and feeds tool results back as the
    /// next turn's prompt. The loop ends with a `FinalResponse` when a turn
    /// produces no tool calls, or with an error when a hook terminates the
    /// flow or `max_turns` is exceeded.
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        // Reuse an ambient span if one is active; otherwise open a fresh
        // `invoke_agent` span for the whole conversation.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // Capture everything the generator needs by value, since the stream
        // may outlive this function call.
        let model = self.model.clone();
        let preamble = self.preamble.clone();
        let static_context = self.static_context.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let additional_params = self.additional_params.clone();
        let tool_server_handle = self.tool_server_handle.clone();
        let dynamic_context = self.dynamic_context.clone();
        let tool_choice = self.tool_choice.clone();
        let agent_name = self.agent_name.clone();
        let has_history = self.chat_history.is_some();
        let chat_history = self.chat_history;
        // Messages generated during this request; the last entry is always
        // the pending prompt for the next turn.
        let mut new_messages: Vec<Message> = vec![prompt.clone()];

        let mut current_max_turns = 0;
        let mut last_prompt_error = String::new();

        let mut text_delta_response = String::new();
        let mut saw_text_this_turn = false;
        let mut max_turns_reached = false;
        let output_schema = self.output_schema;

        let mut aggregated_usage = crate::completion::Usage::new();

        let stream = async_stream::stream! {
            'outer: loop {
                let current_prompt = new_messages
                    .last()
                    .cloned()
                    .expect("streaming loop should always have a pending prompt");

                // Stop once the turn budget is exhausted; the actual error is
                // yielded after the loop.
                if current_max_turns > self.max_turns + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_turns_reached = true;
                    break;
                }

                current_max_turns += 1;

                if self.max_turns > 1 {
                    tracing::info!(
                        "Current conversation Turns: {}/{}",
                        current_max_turns,
                        self.max_turns
                    );
                }

                // History excludes the final pending prompt, which is sent
                // separately as the request's prompt.
                let history_snapshot: Vec<Message> = build_history_for_request(
                    chat_history.as_deref(),
                    &new_messages[..new_messages.len().saturating_sub(1)],
                );

                // Give the hook a chance to cancel before the provider call.
                if let Some(ref hook) = self.hook
                    && let HookAction::Terminate { reason } =
                        hook.on_completion_call(&current_prompt, &history_snapshot).await
                {
                    yield Err(
                        cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason)
                            .await,
                    );
                    break 'outer;
                }

                // Per-turn span for the provider streaming call.
                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                    gen_ai.system_instructions = preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                    gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    build_completion_request(
                        &model,
                        current_prompt.clone(),
                        &history_snapshot,
                        preamble.as_deref(),
                        &static_context,
                        temperature,
                        max_tokens,
                        additional_params.as_ref(),
                        tool_choice.as_ref(),
                        &tool_server_handle,
                        &dynamic_context,
                        output_schema.as_ref(),
                    )
                    .await?
                    .stream(),
                    chat_stream_span,
                )
                .await?;

                // Per-turn accumulators.
                let mut tool_calls = vec![];
                let mut tool_results = vec![];
                let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
                let mut pending_reasoning_delta_text = String::new();
                let mut pending_reasoning_delta_id: Option<String> = None;
                let mut saw_tool_call_this_turn = false;

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            // First text of the turn resets the delta buffer.
                            if !saw_text_this_turn {
                                text_delta_response.clear();
                                saw_text_this_turn = true;
                            }
                            text_delta_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook &&
                                let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &text_delta_response).await {
                                yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                break 'outer;
                            }

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                        },
                        Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            // Surface the tool call to the caller before executing it.
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));

                            // Execute the tool inside its span; `Err` means a hook
                            // terminated the whole prompt, not a tool failure
                            // (tool failures are stringified into the result).
                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    let action = hook
                                        .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
                                        .await;

                                    if let ToolCallHookAction::Terminate { reason } = action {
                                        return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                    }

                                    // Skip: record the rejection reason as the tool result.
                                    if let ToolCallHookAction::Skip { reason } = action {
                                        tracing::info!(
                                            tool_name = tool_call.function.name.as_str(),
                                            reason = reason,
                                            "Tool call rejected"
                                        );
                                        let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
                                        tool_calls.push(tool_call_msg);
                                        tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
                                        saw_tool_call_this_turn = true;
                                        return Ok(reason);
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                // Tool errors are folded into the result text so the
                                // model can react to them on the next turn.
                                let tool_result = match
                                    tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
                                        Ok(thing) => thing,
                                        Err(e) => {
                                            tracing::warn!("Error while calling tool: {e}");
                                            e.to_string()
                                        }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook &&
                                    let HookAction::Terminate { reason } =
                                        hook.on_tool_result(
                                            &tool_call.function.name,
                                            tool_call.call_id.clone(),
                                            &internal_call_id,
                                            &tool_args,
                                            &tool_result.to_string()
                                        )
                                        .await {
                                    return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                saw_tool_call_this_turn = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
                                }
                                Err(e) => {
                                    yield Err(e);
                                    break 'outer;
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
                            // Deltas are only observed by the hook; they are not
                            // re-emitted on the multi-turn stream.
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
                                    rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
                                };

                                if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
                                    .await {
                                    yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                    break 'outer;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            // Buffer deltas; they are assembled into a reasoning
                            // block after the provider stream ends (first id wins).
                            pending_reasoning_delta_text.push_str(&reasoning);
                            if pending_reasoning_delta_id.is_none() {
                                pending_reasoning_delta_id = id.clone();
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                            // Only forward the provider's final marker when this
                            // turn actually produced text.
                            if saw_text_this_turn {
                                if let Some(ref hook) = self.hook &&
                                    let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&current_prompt, &final_resp).await {
                                    yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                    break 'outer;
                                }

                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                saw_text_this_turn = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // If the provider only sent reasoning deltas (no full block),
                // assemble them into a single reasoning entry.
                if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
                    let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
                    if let Some(id) = pending_reasoning_delta_id.take() {
                        assembled = assembled.with_id(id);
                    }
                    accumulated_reasoning.push(assembled);
                }

                let turn_text_response = assistant_text_from_choice(&stream.choice);
                tracing::Span::current().record("gen_ai.completion", &turn_text_response);

                // Record the assistant message (text + reasoning + tool calls)
                // for turns that produced tool calls or reasoning.
                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
                    let mut content_items: Vec<rig::message::AssistantContent> = vec![];

                    if !turn_text_response.is_empty() {
                        content_items.push(rig::message::AssistantContent::text(&turn_text_response));
                    }

                    for reasoning in accumulated_reasoning.drain(..) {
                        content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
                    }

                    content_items.extend(tool_calls.clone());

                    if !content_items.is_empty() {
                        new_messages.push(Message::Assistant {
                            id: stream.message_id.clone(),
                            content: OneOrMany::many(content_items).expect("Should have at least one item"),
                        });
                    }
                }

                // Tool results become the pending user prompt(s) of the next turn.
                for (id, call_id, tool_result) in tool_results {
                    new_messages.push(tool_result_to_user_message(id, call_id, tool_result));
                }

                // No tool calls this turn: the conversation is complete.
                if !saw_tool_call_this_turn {
                    if !turn_text_response.is_empty() {
                        new_messages.push(Message::assistant(&turn_text_response));
                    } else {
                        tracing::warn!(
                            agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                            message_id = ?stream.message_id,
                            "Streaming turn completed without assistant text; final response will be empty"
                        );
                    }

                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    current_span.record("gen_ai.usage.cache_read.input_tokens", aggregated_usage.cached_input_tokens);
                    current_span.record("gen_ai.usage.cache_creation.input_tokens", aggregated_usage.cache_creation_input_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    // History is only returned when the caller supplied one.
                    let final_messages: Option<Vec<Message>> = if has_history {
                        Some(new_messages.clone())
                    } else {
                        None
                    };
                    yield Ok(MultiTurnStreamItem::final_response_with_history(
                        &turn_text_response,
                        aggregated_usage,
                        final_messages,
                    ));
                    break;
                }
            }

            if max_turns_reached {
                yield Err(Box::new(PromptError::MaxTurnsError {
                    max_turns: self.max_turns,
                    chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
737}
738
/// Awaiting a `StreamingPromptRequest` kicks off the conversation and
/// resolves to the boxed multi-turn item stream.
impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend,
    P: PromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}
753
754pub async fn stream_to_stdout<R>(
756 stream: &mut StreamingResult<R>,
757) -> Result<FinalResponse, std::io::Error> {
758 let mut final_res = FinalResponse::empty();
759 print!("Response: ");
760 while let Some(content) = stream.next().await {
761 match content {
762 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
763 Text { text },
764 ))) => {
765 print!("{text}");
766 std::io::Write::flush(&mut std::io::stdout()).unwrap();
767 }
768 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
769 reasoning,
770 ))) => {
771 let reasoning = reasoning.display_text();
772 print!("{reasoning}");
773 std::io::Write::flush(&mut std::io::stdout()).unwrap();
774 }
775 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
776 final_res = res;
777 }
778 Err(err) => {
779 eprintln!("Error: {err}");
780 }
781 _ => {}
782 }
783 }
784
785 Ok(final_res)
786}
787
788#[cfg(test)]
789mod tests {
790 use super::*;
791 use crate::agent::AgentBuilder;
792 use crate::client::ProviderClient;
793 use crate::client::completion::CompletionClient;
794 use crate::completion::{
795 CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
796 };
797 use crate::message::{AssistantContent, Message, ReasoningContent, UserContent};
798 use crate::providers::anthropic;
799 use crate::streaming::StreamingPrompt;
800 use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
801 use futures::StreamExt;
802 use serde::{Deserialize, Serialize};
803 use std::sync::Arc;
804 use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
805 use std::time::Duration;
806
    /// Chunks sharing an id must merge into a single block, preserving
    /// content order and each chunk's signature.
    #[test]
    fn merge_reasoning_blocks_preserves_order_and_signatures() {
        let mut accumulated = Vec::new();
        let first = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: Some("sig-1".to_string()),
            }],
        };
        let second = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![
                ReasoningContent::Text {
                    text: "step-2".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                ReasoningContent::Summary("summary".to_string()),
            ],
        };

        merge_reasoning_blocks(&mut accumulated, &first);
        merge_reasoning_blocks(&mut accumulated, &second);

        // One merged block carrying all three content items, in arrival order.
        assert_eq!(accumulated.len(), 1);
        let merged = accumulated.first().expect("expected accumulated reasoning");
        assert_eq!(merged.id.as_deref(), Some("rs_1"));
        assert_eq!(merged.content.len(), 3);
        assert!(matches!(
            merged.content.first(),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-1" && sig == "sig-1"
        ));
        assert!(matches!(
            merged.content.get(1),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-2" && sig == "sig-2"
        ));
    }
846
    /// Chunks with different ids must remain separate blocks.
    #[test]
    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: Some("rs_a".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: Some("rs_b".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-2".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert_eq!(
            accumulated.first().and_then(|r| r.id.as_deref()),
            Some("rs_a")
        );
        assert_eq!(
            accumulated.get(1).and_then(|r| r.id.as_deref()),
            Some("rs_b")
        );
    }
875
    /// Chunks without ids must never merge, even with each other.
    #[test]
    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "first".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "second".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert!(matches!(
            accumulated.first(),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "first"
            )
        ));
        assert!(matches!(
            accumulated.get(1),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "second"
            )
        ));
    }
916
    /// Minimal streaming-response payload that only carries token usage.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct MockStreamingResponse {
        usage: crate::completion::Usage,
    }
921
    impl MockStreamingResponse {
        /// Build a response whose usage reports the given total token count.
        fn new(total_tokens: u64) -> Self {
            let mut usage = crate::completion::Usage::new();
            usage.total_tokens = total_tokens;
            Self { usage }
        }
    }
929
    // Expose the stored usage so `send` can aggregate it per turn.
    impl crate::completion::GetTokenUsage for MockStreamingResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            Some(self.usage)
        }
    }
935
    /// Assert the second-turn request carries exactly
    /// [original user prompt, assistant tool call, user tool result],
    /// returning a descriptive error string on any mismatch.
    fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
        if history.len() != 3 {
            return Err(format!(
                "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
            ));
        }

        if !matches!(
            history.first(),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::Text(text) if text.text == "do tool work"
                )
        ) {
            return Err(format!(
                "follow-up request should begin with the original user prompt: {history:?}"
            ));
        }

        if !matches!(
            history.get(1),
            Some(Message::Assistant { content, .. })
                if matches!(
                    content.first(),
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.call_id.as_deref() == Some("call_1")
                )
        ) {
            return Err(format!(
                "follow-up request is missing the assistant tool call in position 2: {history:?}"
            ));
        }

        if !matches!(
            history.get(2),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::ToolResult(tool_result)
                        if tool_result.id == "tool_call_1"
                            && tool_result.call_id.as_deref() == Some("call_1")
                )
        ) {
            return Err(format!(
                "follow-up request should end with the user tool result: {history:?}"
            ));
        }

        Ok(())
    }
989
    /// Mock model that issues a tool call on turn 0 and plain text on turn 1,
    /// counting invocations via the shared atomic.
    #[derive(Clone, Default)]
    struct MultiTurnMockModel {
        turn_counter: Arc<AtomicUsize>,
    }
994
    #[allow(refining_impl_trait)]
    impl CompletionModel for MultiTurnMockModel {
        type Response = ();
        type StreamingResponse = MockStreamingResponse;
        type Client = ();

        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self::default()
        }

        // Non-streaming completion is intentionally unsupported here.
        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "completion is unused in this streaming test".to_string(),
            ))
        }

        async fn stream(
            &self,
            request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
            // On follow-up turns, validate the history assembled by `send`;
            // failures are surfaced as provider errors on the stream.
            let validation_error = if turn == 0 {
                None
            } else {
                validate_follow_up_tool_history(&request).err()
            };
            let stream = async_stream::stream! {
                if turn == 0 {
                    // Turn 0: emit a tool call so the agent loops again.
                    yield Ok(RawStreamingChoice::ToolCall(
                        RawStreamingToolCall::new(
                            "tool_call_1".to_string(),
                            "missing_tool".to_string(),
                            serde_json::json!({"input": "value"}),
                        )
                        .with_call_id("call_1".to_string()),
                    ));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(4)));
                } else if let Some(error) = validation_error {
                    yield Err(CompletionError::ProviderError(error));
                } else {
                    // Turn 1: plain text terminates the conversation.
                    yield Ok(RawStreamingChoice::Message("done".to_string()));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(6)));
                }
            };

            let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
                Box::pin(stream);
            Ok(StreamingCompletionResponse::stream(pinned_stream))
        }
    }
1048
    /// End-to-end check: a tool-call turn is followed by a text turn, the
    /// stream yields tool call, tool result and final response, and the
    /// returned history (present because a history was supplied) contains
    /// the closing assistant message.
    #[tokio::test]
    async fn stream_prompt_continues_after_tool_call_turn() {
        let model = MultiTurnMockModel::default();
        let turn_counter = model.turn_counter.clone();
        let agent = AgentBuilder::new(model).build();
        // Empty history still opts the request into returning history.
        let empty_history: &[Message] = &[];

        let mut stream = agent
            .stream_prompt("do tool work")
            .with_history(empty_history)
            .multi_turn(3)
            .await;
        let mut saw_tool_call = false;
        let mut saw_tool_result = false;
        let mut saw_final_response = false;
        let mut final_text = String::new();
        let mut final_response_text = None;
        let mut final_history = None;

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall { .. },
                )) => {
                    saw_tool_call = true;
                }
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    ..
                })) => {
                    saw_tool_result = true;
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    final_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    saw_final_response = true;
                    final_response_text = Some(res.response().to_owned());
                    final_history = res.history().map(|history| history.to_vec());
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        assert!(saw_tool_call);
        assert!(saw_tool_result);
        assert!(saw_final_response);
        assert_eq!(final_text, "done");
        assert_eq!(final_response_text.as_deref(), Some("done"));
        let history = final_history.expect("expected final response history");
        // The history must include the closing assistant "done" message.
        assert!(history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
                if content.iter().any(|item| matches!(
                    item,
                    AssistantContent::Text(text) if text.text == "done"
                ))
        )));
        // Exactly two provider calls: tool turn + text turn.
        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
    }
1112
    /// Selects which canned event sequence `FinalResponseMockModel` streams.
    #[derive(Clone, Copy)]
    enum FinalResponseScenario {
        /// Two text chunks ("hello", " world") followed by a provider final response.
        TextThenFinal,
        /// Only a provider final response; no text chunks at all.
        FinalOnly,
    }
1118
    /// Mock completion model whose streaming output is fixed by a
    /// `FinalResponseScenario`; used to test how `FinalResponse` aggregates
    /// streamed text when the provider's final item carries no text itself.
    #[derive(Clone)]
    struct FinalResponseMockModel {
        /// Which event sequence `stream` will yield.
        scenario: FinalResponseScenario,
    }
1123
    #[allow(refining_impl_trait)]
    impl CompletionModel for FinalResponseMockModel {
        type Response = ();
        type StreamingResponse = MockStreamingResponse;
        type Client = ();

        // Required by the trait; tests construct the model directly, so this
        // just picks a default scenario.
        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self {
                scenario: FinalResponseScenario::TextThenFinal,
            }
        }

        /// Non-streaming completion is intentionally unsupported — these
        /// tests only exercise the streaming path, so any call is a bug.
        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "completion is unused in this streaming test".to_string(),
            ))
        }

        /// Yields the canned event sequence selected by `self.scenario`:
        /// either two text chunks then a final response, or a final response
        /// alone with no preceding text.
        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            // Copy the scenario out so the generator closure does not borrow `self`.
            let scenario = self.scenario;
            let stream = async_stream::stream! {
                match scenario {
                    FinalResponseScenario::TextThenFinal => {
                        yield Ok(RawStreamingChoice::Message("hello".to_string()));
                        yield Ok(RawStreamingChoice::Message(" world".to_string()));
                        yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(3)));
                    }
                    FinalResponseScenario::FinalOnly => {
                        yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(1)));
                    }
                }
            };

            let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
                Box::pin(stream);
            Ok(StreamingCompletionResponse::stream(pinned_stream))
        }
    }
1168
1169 #[tokio::test]
1170 async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
1171 let agent = AgentBuilder::new(FinalResponseMockModel {
1172 scenario: FinalResponseScenario::TextThenFinal,
1173 })
1174 .build();
1175
1176 let mut stream = agent.stream_prompt("say hello").await;
1177 let mut streamed_text = String::new();
1178 let mut final_response_text = None;
1179
1180 while let Some(item) = stream.next().await {
1181 match item {
1182 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1183 text,
1184 ))) => streamed_text.push_str(&text.text),
1185 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1186 final_response_text = Some(res.response().to_owned());
1187 break;
1188 }
1189 Ok(_) => {}
1190 Err(err) => panic!("unexpected streaming error: {err:?}"),
1191 }
1192 }
1193
1194 assert_eq!(streamed_text, "hello world");
1195 assert_eq!(final_response_text.as_deref(), Some("hello world"));
1196 }
1197
1198 #[tokio::test]
1199 async fn final_response_can_remain_empty_for_truly_textless_turns() {
1200 let agent = AgentBuilder::new(FinalResponseMockModel {
1201 scenario: FinalResponseScenario::FinalOnly,
1202 })
1203 .build();
1204
1205 let mut stream = agent.stream_prompt("say nothing").await;
1206 let mut streamed_text = String::new();
1207 let mut final_response_text = None;
1208
1209 while let Some(item) = stream.next().await {
1210 match item {
1211 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1212 text,
1213 ))) => streamed_text.push_str(&text.text),
1214 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1215 final_response_text = Some(res.response().to_owned());
1216 break;
1217 }
1218 Ok(_) => {}
1219 Err(err) => panic!("unexpected streaming error: {err:?}"),
1220 }
1221 }
1222
1223 assert!(streamed_text.is_empty());
1224 assert_eq!(final_response_text.as_deref(), Some(""));
1225 }
1226
1227 async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
1230 let mut interval = tokio::time::interval(Duration::from_millis(50));
1231 let mut count = 0u32;
1232
1233 while !stop.load(Ordering::Relaxed) {
1234 interval.tick().await;
1235 count += 1;
1236
1237 tracing::event!(
1238 target: "background_logger",
1239 tracing::Level::INFO,
1240 count = count,
1241 "Background tick"
1242 );
1243
1244 let current = tracing::Span::current();
1246 if !current.is_disabled() && !current.is_none() {
1247 leak_count.fetch_add(1, Ordering::Relaxed);
1248 }
1249 }
1250
1251 tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
1252 }
1253
    /// Regression test: a spawned background task must never observe a span
    /// created by the agent's streaming pipeline. A non-zero leak count means
    /// the pipeline holds `span.enter()` guards across `.await` points instead
    /// of attaching spans with `.instrument()`.
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        // Spawn the watchdog that samples the current span every 50ms.
        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Let the watchdog start ticking before any agent spans exist.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    // Provider errors shouldn't fail the span-isolation check itself.
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        // Stop the watchdog and wait for it to observe the flag and exit.
        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
            This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }
1323
1324 #[tokio::test]
1330 #[ignore = "This requires an API key"]
1331 async fn test_chat_history_in_final_response() {
1332 use crate::message::Message;
1333
1334 let client = anthropic::Client::from_env();
1335 let agent = client
1336 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
1337 .preamble("You are a helpful assistant. Keep responses brief.")
1338 .temperature(0.1)
1339 .max_tokens(50)
1340 .build();
1341
1342 let empty_history: &[Message] = &[];
1344 let mut stream = agent
1345 .stream_prompt("Say 'hello' and nothing else.")
1346 .with_history(empty_history)
1347 .await;
1348
1349 let mut response_text = String::new();
1351 let mut final_history = None;
1352 while let Some(item) = stream.next().await {
1353 match item {
1354 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1355 text,
1356 ))) => {
1357 response_text.push_str(&text.text);
1358 }
1359 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1360 final_history = res.history().map(|h| h.to_vec());
1361 break;
1362 }
1363 Err(e) => {
1364 panic!("Streaming error: {:?}", e);
1365 }
1366 _ => {}
1367 }
1368 }
1369
1370 let history =
1371 final_history.expect("FinalResponse should contain history when with_history is used");
1372
1373 assert!(
1375 history.iter().any(|m| matches!(m, Message::User { .. })),
1376 "History should contain the user message"
1377 );
1378
1379 assert!(
1381 history
1382 .iter()
1383 .any(|m| matches!(m, Message::Assistant { .. })),
1384 "History should contain the assistant response"
1385 );
1386
1387 tracing::info!(
1388 "History after streaming: {} messages, response: {:?}",
1389 history.len(),
1390 response_text
1391 );
1392 }
1393}