1use crate::{
2 OneOrMany,
3 agent::completion::{DynamicContextStore, build_completion_request},
4 agent::prompt_request::{HookAction, hooks::PromptHook},
5 completion::{Document, GetTokenUsage},
6 json_utils,
7 message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
8 streaming::{StreamedAssistantContent, StreamedUserContent},
9 tool::server::ToolServerHandle,
10 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
11};
12use futures::{Stream, StreamExt};
13use serde::{Deserialize, Serialize};
14use std::{pin::Pin, sync::Arc};
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::ToolCallHookAction;
19use crate::{
20 agent::Agent,
21 completion::{CompletionError, CompletionModel, PromptError},
22 message::{Message, Text},
23 tool::ToolSetError,
24};
25
/// Pinned, boxed stream of multi-turn items produced by a streaming prompt
/// request. Off-wasm, the stream must be `Send` so it can cross task threads.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// Same as above for wasm32 targets, which are single-threaded: no `Send` bound.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
33
/// One item emitted by the multi-turn agent stream.
///
/// Marked `#[non_exhaustive]`: more variants may be added, so downstream
/// matches need a catch-all arm.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// Assistant-side content (text, tool calls, reasoning, final chunk).
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// User-side content fed back into the conversation (e.g. tool results).
    StreamUserItem(StreamedUserContent),
    /// Terminal item carrying the aggregated response for the whole request.
    FinalResponse(FinalResponse),
}
45
/// The aggregated outcome of a finished multi-turn streaming request.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// Concatenated assistant text from the final turn.
    response: String,
    /// Token usage summed across every turn of the request.
    aggregated_usage: crate::completion::Usage,
    /// Updated message history; only populated when the request was built
    /// with an explicit chat history.
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}
56
57impl FinalResponse {
58 pub fn empty() -> Self {
59 Self {
60 response: String::new(),
61 aggregated_usage: crate::completion::Usage::new(),
62 history: None,
63 }
64 }
65
66 pub fn response(&self) -> &str {
68 &self.response
69 }
70
71 pub fn usage(&self) -> crate::completion::Usage {
72 self.aggregated_usage
73 }
74
75 pub fn history(&self) -> Option<&[Message]> {
76 self.history.as_deref()
77 }
78}
79
80impl<R> MultiTurnStreamItem<R> {
81 pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
82 Self::StreamAssistantItem(item)
83 }
84
85 pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
86 Self::FinalResponse(FinalResponse {
87 response: response.to_string(),
88 aggregated_usage,
89 history: None,
90 })
91 }
92
93 pub fn final_response_with_history(
94 response: &str,
95 aggregated_usage: crate::completion::Usage,
96 history: Option<Vec<Message>>,
97 ) -> Self {
98 Self::FinalResponse(FinalResponse {
99 response: response.to_string(),
100 aggregated_usage,
101 history,
102 })
103 }
104}
105
106fn merge_reasoning_blocks(
107 accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
108 incoming: &crate::message::Reasoning,
109) {
110 let ids_match = |existing: &crate::message::Reasoning| {
111 matches!(
112 (&existing.id, &incoming.id),
113 (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
114 )
115 };
116
117 if let Some(existing) = accumulated_reasoning
118 .iter_mut()
119 .rev()
120 .find(|existing| ids_match(existing))
121 {
122 existing.content.extend(incoming.content.clone());
123 } else {
124 accumulated_reasoning.push(incoming.clone());
125 }
126}
127
128fn build_full_history(
130 chat_history: Option<&[Message]>,
131 new_messages: Vec<Message>,
132) -> Vec<Message> {
133 let input = chat_history.unwrap_or(&[]);
134 input.iter().cloned().chain(new_messages).collect()
135}
136
137fn build_history_for_request(
139 chat_history: Option<&[Message]>,
140 new_messages: &[Message],
141) -> Vec<Message> {
142 let input = chat_history.unwrap_or(&[]);
143 input.iter().chain(new_messages.iter()).cloned().collect()
144}
145
146async fn cancelled_prompt_error(
147 chat_history: Option<&[Message]>,
148 new_messages: Vec<Message>,
149 reason: String,
150) -> StreamingError {
151 StreamingError::Prompt(
152 PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
153 .into(),
154 )
155}
156
157fn tool_result_to_user_message(
158 id: String,
159 call_id: Option<String>,
160 tool_result: String,
161) -> Message {
162 let content = ToolResultContent::from_tool_output(tool_result);
163 let user_content = match call_id {
164 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
165 None => UserContent::tool_result(id, content),
166 };
167
168 Message::User {
169 content: OneOrMany::one(user_content),
170 }
171}
172
173fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
174 choice
175 .iter()
176 .filter_map(|content| match content {
177 AssistantContent::Text(text) => Some(text.text.as_str()),
178 _ => None,
179 })
180 .collect()
181}
182
/// Errors surfaced through the multi-turn stream.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying completion request failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// The prompt loop was cancelled or exceeded its turn budget.
    /// Boxed because `PromptError` carries the full chat history.
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool invocation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
192
/// Fallback display name used in spans and logs when the agent has no name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
194
/// Builder for a streaming, multi-turn prompt request against an [`Agent`].
///
/// Configure with [`Self::with_history`], [`Self::multi_turn`] and
/// [`Self::with_hook`], then `.await` it (via the `IntoFuture` impl) to get a
/// [`StreamingResult`]. Owns a snapshot of the agent's configuration so the
/// resulting stream is `'static`.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt message to send to the model.
    prompt: Message,
    /// Optional chat history preceding the prompt.
    chat_history: Option<Vec<Message>>,
    /// Maximum number of turns before the loop aborts with a max-turns error.
    max_turns: usize,

    /// The completion model to stream from.
    model: Arc<M>,
    /// Agent display name (falls back to `UNKNOWN_AGENT_NAME` in telemetry).
    agent_name: Option<String>,
    /// System preamble / instructions.
    preamble: Option<String>,
    /// Documents always attached to the request.
    static_context: Vec<Document>,
    /// Sampling temperature, if overridden.
    temperature: Option<f64>,
    /// Maximum tokens for each completion.
    max_tokens: Option<u64>,
    /// Provider-specific extra parameters.
    additional_params: Option<serde_json::Value>,
    /// Handle used to resolve and invoke tools.
    tool_server_handle: ToolServerHandle,
    /// Dynamic (RAG) context store queried per request.
    dynamic_context: DynamicContextStore,
    /// Tool-choice policy forwarded to the provider.
    tool_choice: Option<ToolChoice>,
    /// Optional JSON schema constraining the output.
    output_schema: Option<schemars::Schema>,
    /// Optional lifecycle hook invoked around completions and tool calls.
    hook: Option<P>,
}
241
242impl<M, P> StreamingPromptRequest<M, P>
243where
244 M: CompletionModel + 'static,
245 <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
246 P: PromptHook<M>,
247{
248 pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
251 StreamingPromptRequest {
252 prompt: prompt.into(),
253 chat_history: None,
254 max_turns: agent.default_max_turns.unwrap_or_default(),
255 model: agent.model.clone(),
256 agent_name: agent.name.clone(),
257 preamble: agent.preamble.clone(),
258 static_context: agent.static_context.clone(),
259 temperature: agent.temperature,
260 max_tokens: agent.max_tokens,
261 additional_params: agent.additional_params.clone(),
262 tool_server_handle: agent.tool_server_handle.clone(),
263 dynamic_context: agent.dynamic_context.clone(),
264 tool_choice: agent.tool_choice.clone(),
265 output_schema: agent.output_schema.clone(),
266 hook: None,
267 }
268 }
269
270 pub fn from_agent<P2>(
272 agent: &Agent<M, P2>,
273 prompt: impl Into<Message>,
274 ) -> StreamingPromptRequest<M, P2>
275 where
276 P2: PromptHook<M>,
277 {
278 StreamingPromptRequest {
279 prompt: prompt.into(),
280 chat_history: None,
281 max_turns: agent.default_max_turns.unwrap_or_default(),
282 model: agent.model.clone(),
283 agent_name: agent.name.clone(),
284 preamble: agent.preamble.clone(),
285 static_context: agent.static_context.clone(),
286 temperature: agent.temperature,
287 max_tokens: agent.max_tokens,
288 additional_params: agent.additional_params.clone(),
289 tool_server_handle: agent.tool_server_handle.clone(),
290 dynamic_context: agent.dynamic_context.clone(),
291 tool_choice: agent.tool_choice.clone(),
292 output_schema: agent.output_schema.clone(),
293 hook: agent.hook.clone(),
294 }
295 }
296
297 fn agent_name(&self) -> &str {
298 self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
299 }
300
301 pub fn multi_turn(mut self, turns: usize) -> Self {
304 self.max_turns = turns;
305 self
306 }
307
308 pub fn with_history<I, T>(mut self, history: I) -> Self
321 where
322 I: IntoIterator<Item = T>,
323 T: Into<Message>,
324 {
325 self.chat_history = Some(history.into_iter().map(Into::into).collect());
326 self
327 }
328
329 pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
332 where
333 P2: PromptHook<M>,
334 {
335 StreamingPromptRequest {
336 prompt: self.prompt,
337 chat_history: self.chat_history,
338 max_turns: self.max_turns,
339 model: self.model,
340 agent_name: self.agent_name,
341 preamble: self.preamble,
342 static_context: self.static_context,
343 temperature: self.temperature,
344 max_tokens: self.max_tokens,
345 additional_params: self.additional_params,
346 tool_server_handle: self.tool_server_handle,
347 dynamic_context: self.dynamic_context,
348 tool_choice: self.tool_choice,
349 output_schema: self.output_schema,
350 hook: Some(hook),
351 }
352 }
353
354 async fn send(self) -> StreamingResult<M::StreamingResponse> {
355 let agent_span = if tracing::Span::current().is_disabled() {
356 info_span!(
357 "invoke_agent",
358 gen_ai.operation.name = "invoke_agent",
359 gen_ai.agent.name = self.agent_name(),
360 gen_ai.system_instructions = self.preamble,
361 gen_ai.prompt = tracing::field::Empty,
362 gen_ai.completion = tracing::field::Empty,
363 gen_ai.usage.input_tokens = tracing::field::Empty,
364 gen_ai.usage.output_tokens = tracing::field::Empty,
365 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
366 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
367 )
368 } else {
369 tracing::Span::current()
370 };
371
372 let prompt = self.prompt;
373 if let Some(text) = prompt.rag_text() {
374 agent_span.record("gen_ai.prompt", text);
375 }
376
377 let model = self.model.clone();
379 let preamble = self.preamble.clone();
380 let static_context = self.static_context.clone();
381 let temperature = self.temperature;
382 let max_tokens = self.max_tokens;
383 let additional_params = self.additional_params.clone();
384 let tool_server_handle = self.tool_server_handle.clone();
385 let dynamic_context = self.dynamic_context.clone();
386 let tool_choice = self.tool_choice.clone();
387 let agent_name = self.agent_name.clone();
388 let has_history = self.chat_history.is_some();
389 let chat_history = self.chat_history;
390 let mut new_messages: Vec<Message> = vec![prompt.clone()];
391
392 let mut current_max_turns = 0;
393 let mut last_prompt_error = String::new();
394
395 let mut text_delta_response = String::new();
396 let mut saw_text_this_turn = false;
397 let mut max_turns_reached = false;
398 let output_schema = self.output_schema;
399
400 let mut aggregated_usage = crate::completion::Usage::new();
401
402 let stream = async_stream::stream! {
409 'outer: loop {
410 let Some((current_prompt_ref, previous_messages)) = new_messages.split_last() else {
411 yield Err(cancelled_prompt_error(
412 chat_history.as_deref(),
413 new_messages.clone(),
414 "streaming loop lost its pending prompt".to_string(),
415 ).await);
416 break 'outer;
417 };
418 let current_prompt = current_prompt_ref.clone();
419
420 if current_max_turns > self.max_turns + 1 {
421 last_prompt_error = current_prompt.rag_text().unwrap_or_default();
422 max_turns_reached = true;
423 break;
424 }
425
426 current_max_turns += 1;
427
428 if self.max_turns > 1 {
429 tracing::info!(
430 "Current conversation Turns: {}/{}",
431 current_max_turns,
432 self.max_turns
433 );
434 }
435
436 let history_snapshot: Vec<Message> = build_history_for_request(
437 chat_history.as_deref(),
438 previous_messages,
439 );
440
441 if let Some(ref hook) = self.hook
442 && let HookAction::Terminate { reason } =
443 hook.on_completion_call(¤t_prompt, &history_snapshot).await
444 {
445 yield Err(
446 cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason)
447 .await,
448 );
449 break 'outer;
450 }
451
452 let chat_stream_span = info_span!(
453 target: "rig::agent_chat",
454 parent: tracing::Span::current(),
455 "chat_streaming",
456 gen_ai.operation.name = "chat",
457 gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
458 gen_ai.system_instructions = preamble,
459 gen_ai.provider.name = tracing::field::Empty,
460 gen_ai.request.model = tracing::field::Empty,
461 gen_ai.response.id = tracing::field::Empty,
462 gen_ai.response.model = tracing::field::Empty,
463 gen_ai.usage.output_tokens = tracing::field::Empty,
464 gen_ai.usage.input_tokens = tracing::field::Empty,
465 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
466 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
467 gen_ai.input.messages = tracing::field::Empty,
468 gen_ai.output.messages = tracing::field::Empty,
469 );
470
471 let mut stream = tracing::Instrument::instrument(
472 build_completion_request(
473 &model,
474 current_prompt.clone(),
475 &history_snapshot,
476 preamble.as_deref(),
477 &static_context,
478 temperature,
479 max_tokens,
480 additional_params.as_ref(),
481 tool_choice.as_ref(),
482 &tool_server_handle,
483 &dynamic_context,
484 output_schema.as_ref(),
485 )
486 .await?
487 .stream(), chat_stream_span
488 )
489
490 .await?;
491
492 let mut tool_calls = vec![];
493 let mut tool_results = vec![];
494 let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
495 let mut pending_reasoning_delta_text = String::new();
498 let mut pending_reasoning_delta_id: Option<String> = None;
499 let mut saw_tool_call_this_turn = false;
500
501 while let Some(content) = stream.next().await {
502 match content {
503 Ok(StreamedAssistantContent::Text(text)) => {
504 if !saw_text_this_turn {
505 text_delta_response.clear();
506 saw_text_this_turn = true;
507 }
508 text_delta_response.push_str(&text.text);
509 if let Some(ref hook) = self.hook &&
510 let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &text_delta_response).await {
511 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
512 break 'outer;
513 }
514
515 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
516 },
517 Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
518 let tool_span = info_span!(
519 parent: tracing::Span::current(),
520 "execute_tool",
521 gen_ai.operation.name = "execute_tool",
522 gen_ai.tool.type = "function",
523 gen_ai.tool.name = tracing::field::Empty,
524 gen_ai.tool.call.id = tracing::field::Empty,
525 gen_ai.tool.call.arguments = tracing::field::Empty,
526 gen_ai.tool.call.result = tracing::field::Empty
527 );
528
529 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));
530
531 let tc_result = async {
532 let tool_span = tracing::Span::current();
533 let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
534 if let Some(ref hook) = self.hook {
535 let action = hook
536 .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
537 .await;
538
539 if let ToolCallHookAction::Terminate { reason } = action {
540 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
541 }
542
543 if let ToolCallHookAction::Skip { reason } = action {
544 tracing::info!(
546 tool_name = tool_call.function.name.as_str(),
547 reason = reason,
548 "Tool call rejected"
549 );
550 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
551 tool_calls.push(tool_call_msg);
552 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
553 saw_tool_call_this_turn = true;
554 return Ok(reason);
555 }
556 }
557
558 tool_span.record("gen_ai.tool.name", &tool_call.function.name);
559 tool_span.record("gen_ai.tool.call.arguments", &tool_args);
560
561 let tool_result = match
562 tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
563 Ok(thing) => thing,
564 Err(e) => {
565 tracing::warn!("Error while calling tool: {e}");
566 e.to_string()
567 }
568 };
569
570 tool_span.record("gen_ai.tool.call.result", &tool_result);
571
572 if let Some(ref hook) = self.hook &&
573 let HookAction::Terminate { reason } =
574 hook.on_tool_result(
575 &tool_call.function.name,
576 tool_call.call_id.clone(),
577 &internal_call_id,
578 &tool_args,
579 &tool_result.to_string()
580 )
581 .await {
582 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
583 }
584
585 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
586
587 tool_calls.push(tool_call_msg);
588 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));
589
590 saw_tool_call_this_turn = true;
591 Ok(tool_result)
592 }.instrument(tool_span).await;
593
594 match tc_result {
595 Ok(text) => {
596 let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
597 yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
598 }
599 Err(e) => {
600 yield Err(e);
601 break 'outer;
602 }
603 }
604 },
605 Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
606 if let Some(ref hook) = self.hook {
607 let (name, delta) = match &content {
608 rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
609 rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
610 };
611
612 if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
613 .await {
614 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
615 break 'outer;
616 }
617 }
618 }
619 Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
620 merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
624 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
625 },
626 Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
627 pending_reasoning_delta_text.push_str(&reasoning);
631 if pending_reasoning_delta_id.is_none() {
632 pending_reasoning_delta_id = id.clone();
633 }
634 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
635 },
636 Ok(StreamedAssistantContent::Final(final_resp)) => {
637 if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
638 if saw_text_this_turn {
639 if let Some(ref hook) = self.hook &&
640 let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(¤t_prompt, &final_resp).await {
641 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
642 break 'outer;
643 }
644
645 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
646 saw_text_this_turn = false;
647 }
648 }
649 Err(e) => {
650 yield Err(e.into());
651 break 'outer;
652 }
653 }
654 }
655
656 if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
660 let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
661 if let Some(id) = pending_reasoning_delta_id.take() {
662 assembled = assembled.with_id(id);
663 }
664 accumulated_reasoning.push(assembled);
665 }
666
667 let turn_text_response = assistant_text_from_choice(&stream.choice);
668 tracing::Span::current().record("gen_ai.completion", &turn_text_response);
669
670 if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
673 let mut content_items: Vec<rig::message::AssistantContent> = vec![];
674
675 if !turn_text_response.is_empty() {
677 content_items.push(rig::message::AssistantContent::text(&turn_text_response));
678 }
679
680 for reasoning in accumulated_reasoning.drain(..) {
682 content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
683 }
684
685 content_items.extend(tool_calls.clone());
686
687 if let Some(content) = OneOrMany::from_iter_optional(content_items) {
688 new_messages.push(Message::Assistant {
689 id: stream.message_id.clone(),
690 content,
691 });
692 }
693 }
694
695 for (id, call_id, tool_result) in tool_results {
696 new_messages.push(tool_result_to_user_message(id, call_id, tool_result));
697 }
698
699 if !saw_tool_call_this_turn {
700 if !turn_text_response.is_empty() {
702 new_messages.push(Message::assistant(&turn_text_response));
703 } else {
704 tracing::warn!(
705 agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
706 message_id = ?stream.message_id,
707 "Streaming turn completed without assistant text; final response will be empty"
708 );
709 }
710
711 let current_span = tracing::Span::current();
712 current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
713 current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
714 current_span.record("gen_ai.usage.cache_read.input_tokens", aggregated_usage.cached_input_tokens);
715 current_span.record("gen_ai.usage.cache_creation.input_tokens", aggregated_usage.cache_creation_input_tokens);
716 tracing::info!("Agent multi-turn stream finished");
717 let final_messages: Option<Vec<Message>> = if has_history {
718 Some(new_messages.clone())
719 } else {
720 None
721 };
722 yield Ok(MultiTurnStreamItem::final_response_with_history(
723 &turn_text_response,
724 aggregated_usage,
725 final_messages,
726 ));
727 break;
728 }
729 }
730
731 if max_turns_reached {
732 yield Err(Box::new(PromptError::MaxTurnsError {
733 max_turns: self.max_turns,
734 chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
735 prompt: Box::new(last_prompt_error.clone().into()),
736 }).into());
737 }
738 };
739
740 Box::pin(stream.instrument(agent_span))
741 }
742}
743
744impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
745where
746 M: CompletionModel + 'static,
747 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
748 P: PromptHook<M> + 'static,
749{
750 type Output = StreamingResult<M::StreamingResponse>; type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
752
753 fn into_future(self) -> Self::IntoFuture {
754 Box::pin(async move { self.send().await })
756 }
757}
758
759pub async fn stream_to_stdout<R>(
761 stream: &mut StreamingResult<R>,
762) -> Result<FinalResponse, std::io::Error> {
763 let mut final_res = FinalResponse::empty();
764 print!("Response: ");
765 while let Some(content) = stream.next().await {
766 match content {
767 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
768 Text { text },
769 ))) => {
770 print!("{text}");
771 std::io::Write::flush(&mut std::io::stdout())?;
772 }
773 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
774 reasoning,
775 ))) => {
776 let reasoning = reasoning.display_text();
777 print!("{reasoning}");
778 std::io::Write::flush(&mut std::io::stdout())?;
779 }
780 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
781 final_res = res;
782 }
783 Err(err) => {
784 eprintln!("Error: {err}");
785 }
786 _ => {}
787 }
788 }
789
790 Ok(final_res)
791}
792
793#[cfg(test)]
794mod tests {
795 use super::*;
796 use crate::agent::AgentBuilder;
797 use crate::client::ProviderClient;
798 use crate::client::completion::CompletionClient;
799 use crate::completion::{
800 CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
801 };
802 use crate::message::{
803 AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent,
804 ToolResultContent, UserContent,
805 };
806 use crate::providers::anthropic;
807 use crate::streaming::StreamingPrompt;
808 use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
809 use futures::StreamExt;
810 use serde::{Deserialize, Serialize};
811 use std::sync::Arc;
812 use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
813 use std::time::Duration;
814
815 #[test]
816 fn merge_reasoning_blocks_preserves_order_and_signatures() {
817 let mut accumulated = Vec::new();
818 let first = crate::message::Reasoning {
819 id: Some("rs_1".to_string()),
820 content: vec![ReasoningContent::Text {
821 text: "step-1".to_string(),
822 signature: Some("sig-1".to_string()),
823 }],
824 };
825 let second = crate::message::Reasoning {
826 id: Some("rs_1".to_string()),
827 content: vec![
828 ReasoningContent::Text {
829 text: "step-2".to_string(),
830 signature: Some("sig-2".to_string()),
831 },
832 ReasoningContent::Summary("summary".to_string()),
833 ],
834 };
835
836 merge_reasoning_blocks(&mut accumulated, &first);
837 merge_reasoning_blocks(&mut accumulated, &second);
838
839 assert_eq!(accumulated.len(), 1);
840 let merged = accumulated.first().expect("expected accumulated reasoning");
841 assert_eq!(merged.id.as_deref(), Some("rs_1"));
842 assert_eq!(merged.content.len(), 3);
843 assert!(matches!(
844 merged.content.first(),
845 Some(ReasoningContent::Text { text, signature: Some(sig) })
846 if text == "step-1" && sig == "sig-1"
847 ));
848 assert!(matches!(
849 merged.content.get(1),
850 Some(ReasoningContent::Text { text, signature: Some(sig) })
851 if text == "step-2" && sig == "sig-2"
852 ));
853 }
854
855 #[test]
856 fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
857 let mut accumulated = vec![crate::message::Reasoning {
858 id: Some("rs_a".to_string()),
859 content: vec![ReasoningContent::Text {
860 text: "step-1".to_string(),
861 signature: None,
862 }],
863 }];
864 let incoming = crate::message::Reasoning {
865 id: Some("rs_b".to_string()),
866 content: vec![ReasoningContent::Text {
867 text: "step-2".to_string(),
868 signature: None,
869 }],
870 };
871
872 merge_reasoning_blocks(&mut accumulated, &incoming);
873 assert_eq!(accumulated.len(), 2);
874 assert_eq!(
875 accumulated.first().and_then(|r| r.id.as_deref()),
876 Some("rs_a")
877 );
878 assert_eq!(
879 accumulated.get(1).and_then(|r| r.id.as_deref()),
880 Some("rs_b")
881 );
882 }
883
884 #[test]
885 fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
886 let mut accumulated = vec![crate::message::Reasoning {
887 id: None,
888 content: vec![ReasoningContent::Text {
889 text: "first".to_string(),
890 signature: None,
891 }],
892 }];
893 let incoming = crate::message::Reasoning {
894 id: None,
895 content: vec![ReasoningContent::Text {
896 text: "second".to_string(),
897 signature: None,
898 }],
899 };
900
901 merge_reasoning_blocks(&mut accumulated, &incoming);
902 assert_eq!(accumulated.len(), 2);
903 assert!(matches!(
904 accumulated.first(),
905 Some(crate::message::Reasoning {
906 id: None,
907 content
908 }) if matches!(
909 content.first(),
910 Some(ReasoningContent::Text { text, .. }) if text == "first"
911 )
912 ));
913 assert!(matches!(
914 accumulated.get(1),
915 Some(crate::message::Reasoning {
916 id: None,
917 content
918 }) if matches!(
919 content.first(),
920 Some(ReasoningContent::Text { text, .. }) if text == "second"
921 )
922 ));
923 }
924
925 #[test]
926 fn tool_result_to_user_message_preserves_multimodal_tool_output() {
927 let message = tool_result_to_user_message(
928 "tool_call_1".to_string(),
929 Some("call_1".to_string()),
930 serde_json::json!({
931 "response": {
932 "instruction": "Use the image part to answer."
933 },
934 "parts": [
935 {
936 "type": "image",
937 "data": "base64data==",
938 "mimeType": "image/png"
939 }
940 ]
941 })
942 .to_string(),
943 );
944
945 let tool_result = match message {
946 Message::User { content } => match content.first() {
947 UserContent::ToolResult(tool_result) => tool_result,
948 other => panic!("expected tool result content, got {other:?}"),
949 },
950 other => panic!("expected user message, got {other:?}"),
951 };
952
953 assert_eq!(tool_result.id, "tool_call_1");
954 assert_eq!(tool_result.call_id.as_deref(), Some("call_1"));
955 assert_eq!(tool_result.content.len(), 2);
956
957 let mut items = tool_result.content.iter();
958 match items.next() {
959 Some(ToolResultContent::Text(text)) => {
960 assert!(text.text.contains("Use the image part to answer."));
961 }
962 other => panic!("expected structured text payload first, got {other:?}"),
963 }
964
965 match items.next() {
966 Some(ToolResultContent::Image(image)) => {
967 assert_eq!(image.media_type, Some(ImageMediaType::PNG));
968 assert!(matches!(
969 image.data,
970 DocumentSourceKind::Base64(ref data) if data == "base64data=="
971 ));
972 }
973 other => panic!("expected image payload second, got {other:?}"),
974 }
975 }
976
977 #[derive(Clone, Debug, Deserialize, Serialize)]
978 struct MockStreamingResponse {
979 usage: crate::completion::Usage,
980 }
981
982 impl MockStreamingResponse {
983 fn new(total_tokens: u64) -> Self {
984 let mut usage = crate::completion::Usage::new();
985 usage.total_tokens = total_tokens;
986 Self { usage }
987 }
988 }
989
990 impl crate::completion::GetTokenUsage for MockStreamingResponse {
991 fn token_usage(&self) -> Option<crate::completion::Usage> {
992 Some(self.usage)
993 }
994 }
995
996 fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
997 let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
998 if history.len() != 3 {
999 return Err(format!(
1000 "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
1001 ));
1002 }
1003
1004 if !matches!(
1005 history.first(),
1006 Some(Message::User { content })
1007 if matches!(
1008 content.first(),
1009 UserContent::Text(text) if text.text == "do tool work"
1010 )
1011 ) {
1012 return Err(format!(
1013 "follow-up request should begin with the original user prompt: {history:?}"
1014 ));
1015 }
1016
1017 if !matches!(
1018 history.get(1),
1019 Some(Message::Assistant { content, .. })
1020 if matches!(
1021 content.first(),
1022 AssistantContent::ToolCall(tool_call)
1023 if tool_call.id == "tool_call_1"
1024 && tool_call.call_id.as_deref() == Some("call_1")
1025 )
1026 ) {
1027 return Err(format!(
1028 "follow-up request is missing the assistant tool call in position 2: {history:?}"
1029 ));
1030 }
1031
1032 if !matches!(
1033 history.get(2),
1034 Some(Message::User { content })
1035 if matches!(
1036 content.first(),
1037 UserContent::ToolResult(tool_result)
1038 if tool_result.id == "tool_call_1"
1039 && tool_result.call_id.as_deref() == Some("call_1")
1040 )
1041 ) {
1042 return Err(format!(
1043 "follow-up request should end with the user tool result: {history:?}"
1044 ));
1045 }
1046
1047 Ok(())
1048 }
1049
1050 #[derive(Clone, Default)]
1051 struct MultiTurnMockModel {
1052 turn_counter: Arc<AtomicUsize>,
1053 }
1054
1055 #[allow(refining_impl_trait)]
1056 impl CompletionModel for MultiTurnMockModel {
1057 type Response = ();
1058 type StreamingResponse = MockStreamingResponse;
1059 type Client = ();
1060
1061 fn make(_: &Self::Client, _: impl Into<String>) -> Self {
1062 Self::default()
1063 }
1064
1065 async fn completion(
1066 &self,
1067 _request: CompletionRequest,
1068 ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
1069 Err(CompletionError::ProviderError(
1070 "completion is unused in this streaming test".to_string(),
1071 ))
1072 }
1073
1074 async fn stream(
1075 &self,
1076 request: CompletionRequest,
1077 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
1078 let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
1079 let validation_error = if turn == 0 {
1080 None
1081 } else {
1082 validate_follow_up_tool_history(&request).err()
1083 };
1084 let stream = async_stream::stream! {
1085 if turn == 0 {
1086 yield Ok(RawStreamingChoice::ToolCall(
1087 RawStreamingToolCall::new(
1088 "tool_call_1".to_string(),
1089 "missing_tool".to_string(),
1090 serde_json::json!({"input": "value"}),
1091 )
1092 .with_call_id("call_1".to_string()),
1093 ));
1094 yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(4)));
1095 } else if let Some(error) = validation_error {
1096 yield Err(CompletionError::ProviderError(error));
1097 } else {
1098 yield Ok(RawStreamingChoice::Message("done".to_string()));
1099 yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(6)));
1100 }
1101 };
1102
1103 let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
1104 Box::pin(stream);
1105 Ok(StreamingCompletionResponse::stream(pinned_stream))
1106 }
1107 }
1108
    /// Drives a multi-turn stream: turn 1 yields a tool call, the tool result
    /// is fed back, and turn 2 yields the final "done" text. Verifies every
    /// stage is surfaced on the stream and recorded in the final history.
    #[tokio::test]
    async fn stream_prompt_continues_after_tool_call_turn() {
        let model = MultiTurnMockModel::default();
        // Keep a handle on the counter so the turn count can be asserted later.
        let turn_counter = model.turn_counter.clone();
        let agent = AgentBuilder::new(model).build();
        let empty_history: &[Message] = &[];

        let mut stream = agent
            .stream_prompt("do tool work")
            .with_history(empty_history)
            .multi_turn(3)
            .await;
        let mut saw_tool_call = false;
        let mut saw_tool_result = false;
        let mut saw_final_response = false;
        let mut final_text = String::new();
        let mut final_response_text = None;
        let mut final_history = None;

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall { .. },
                )) => {
                    saw_tool_call = true;
                }
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    ..
                })) => {
                    saw_tool_result = true;
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    final_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    saw_final_response = true;
                    final_response_text = Some(res.response().to_owned());
                    // Snapshot the history so it can be inspected after the loop.
                    final_history = res.history().map(|history| history.to_vec());
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        // Every stage of the tool round-trip must have been streamed.
        assert!(saw_tool_call);
        assert!(saw_tool_result);
        assert!(saw_final_response);
        assert_eq!(final_text, "done");
        assert_eq!(final_response_text.as_deref(), Some("done"));
        // The returned history must include the assistant's "done" reply.
        let history = final_history.expect("expected final response history");
        assert!(history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
            if content.iter().any(|item| matches!(
                item,
                AssistantContent::Text(text) if text.text == "done"
            ))
        )));
        // The model was invoked exactly twice: tool-call turn + text turn.
        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
    }
1172
    /// Selects which canned event sequence `FinalResponseMockModel` streams.
    #[derive(Clone, Copy)]
    enum FinalResponseScenario {
        // Two text chunks followed by the provider's final response.
        TextThenFinal,
        // Only the provider's final response, with no text chunks at all.
        FinalOnly,
    }
1178
    /// Mock completion model whose streamed output is fixed by a
    /// `FinalResponseScenario`; used to test final-response text aggregation.
    #[derive(Clone)]
    struct FinalResponseMockModel {
        // Determines which canned stream `stream` produces.
        scenario: FinalResponseScenario,
    }
1183
1184 #[allow(refining_impl_trait)]
1185 impl CompletionModel for FinalResponseMockModel {
1186 type Response = ();
1187 type StreamingResponse = MockStreamingResponse;
1188 type Client = ();
1189
1190 fn make(_: &Self::Client, _: impl Into<String>) -> Self {
1191 Self {
1192 scenario: FinalResponseScenario::TextThenFinal,
1193 }
1194 }
1195
1196 async fn completion(
1197 &self,
1198 _request: CompletionRequest,
1199 ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
1200 Err(CompletionError::ProviderError(
1201 "completion is unused in this streaming test".to_string(),
1202 ))
1203 }
1204
1205 async fn stream(
1206 &self,
1207 _request: CompletionRequest,
1208 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
1209 let scenario = self.scenario;
1210 let stream = async_stream::stream! {
1211 match scenario {
1212 FinalResponseScenario::TextThenFinal => {
1213 yield Ok(RawStreamingChoice::Message("hello".to_string()));
1214 yield Ok(RawStreamingChoice::Message(" world".to_string()));
1215 yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(3)));
1216 }
1217 FinalResponseScenario::FinalOnly => {
1218 yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(1)));
1219 }
1220 }
1221 };
1222
1223 let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
1224 Box::pin(stream);
1225 Ok(StreamingCompletionResponse::stream(pinned_stream))
1226 }
1227 }
1228
1229 #[tokio::test]
1230 async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
1231 let agent = AgentBuilder::new(FinalResponseMockModel {
1232 scenario: FinalResponseScenario::TextThenFinal,
1233 })
1234 .build();
1235
1236 let mut stream = agent.stream_prompt("say hello").await;
1237 let mut streamed_text = String::new();
1238 let mut final_response_text = None;
1239
1240 while let Some(item) = stream.next().await {
1241 match item {
1242 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1243 text,
1244 ))) => streamed_text.push_str(&text.text),
1245 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1246 final_response_text = Some(res.response().to_owned());
1247 break;
1248 }
1249 Ok(_) => {}
1250 Err(err) => panic!("unexpected streaming error: {err:?}"),
1251 }
1252 }
1253
1254 assert_eq!(streamed_text, "hello world");
1255 assert_eq!(final_response_text.as_deref(), Some("hello world"));
1256 }
1257
1258 #[tokio::test]
1259 async fn final_response_can_remain_empty_for_truly_textless_turns() {
1260 let agent = AgentBuilder::new(FinalResponseMockModel {
1261 scenario: FinalResponseScenario::FinalOnly,
1262 })
1263 .build();
1264
1265 let mut stream = agent.stream_prompt("say nothing").await;
1266 let mut streamed_text = String::new();
1267 let mut final_response_text = None;
1268
1269 while let Some(item) = stream.next().await {
1270 match item {
1271 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1272 text,
1273 ))) => streamed_text.push_str(&text.text),
1274 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1275 final_response_text = Some(res.response().to_owned());
1276 break;
1277 }
1278 Ok(_) => {}
1279 Err(err) => panic!("unexpected streaming error: {err:?}"),
1280 }
1281 }
1282
1283 assert!(streamed_text.is_empty());
1284 assert_eq!(final_response_text.as_deref(), Some(""));
1285 }
1286
1287 async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
1290 let mut interval = tokio::time::interval(Duration::from_millis(50));
1291 let mut count = 0u32;
1292
1293 while !stop.load(Ordering::Relaxed) {
1294 interval.tick().await;
1295 count += 1;
1296
1297 tracing::event!(
1298 target: "background_logger",
1299 tracing::Level::INFO,
1300 count = count,
1301 "Background tick"
1302 );
1303
1304 let current = tracing::Span::current();
1306 if !current.is_disabled() && !current.is_none() {
1307 leak_count.fetch_add(1, Ordering::Relaxed);
1308 }
1309 }
1310
1311 tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
1312 }
1313
    /// Integration test: while the agent streams a real completion, a
    /// background task that never enters a span keeps logging. If that task
    /// ever observes a current span, span context leaked across tasks —
    /// per the failure message, a sign of `span.enter()` being used inside
    /// `async_stream` instead of `.instrument()`.
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() -> anyhow::Result<()> {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        // Start the leak detector before any agent work begins.
        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Let the logger take a couple of ticks on its own first.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = anthropic::Client::from_env()?;
        let agent = client
            .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        // Drain the stream while the background logger runs concurrently.
        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        // Shut the logger down cleanly before reading the leak counter.
        stop.store(true, Ordering::Relaxed);
        bg_handle.await?;

        let leaks = leak_count.load(Ordering::Relaxed);
        anyhow::ensure!(
            leaks == 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );

        Ok(())
    }
1385
1386 #[tokio::test]
1392 #[ignore = "This requires an API key"]
1393 async fn test_chat_history_in_final_response() -> anyhow::Result<()> {
1394 use crate::message::Message;
1395
1396 let client = anthropic::Client::from_env()?;
1397 let agent = client
1398 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
1399 .preamble("You are a helpful assistant. Keep responses brief.")
1400 .temperature(0.1)
1401 .max_tokens(50)
1402 .build();
1403
1404 let empty_history: &[Message] = &[];
1406 let mut stream = agent
1407 .stream_prompt("Say 'hello' and nothing else.")
1408 .with_history(empty_history)
1409 .await;
1410
1411 let mut response_text = String::new();
1413 let mut final_history = None;
1414 while let Some(item) = stream.next().await {
1415 match item {
1416 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1417 text,
1418 ))) => {
1419 response_text.push_str(&text.text);
1420 }
1421 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1422 final_history = res.history().map(|h| h.to_vec());
1423 break;
1424 }
1425 Err(e) => {
1426 return Err(e.into());
1427 }
1428 _ => {}
1429 }
1430 }
1431
1432 let history = final_history
1433 .ok_or_else(|| anyhow::anyhow!("final response should include history"))?;
1434
1435 anyhow::ensure!(
1437 history.iter().any(|m| matches!(m, Message::User { .. })),
1438 "History should contain the user message"
1439 );
1440
1441 anyhow::ensure!(
1443 history
1444 .iter()
1445 .any(|m| matches!(m, Message::Assistant { .. })),
1446 "History should contain the assistant response"
1447 );
1448
1449 tracing::info!(
1450 "History after streaming: {} messages, response: {:?}",
1451 history.len(),
1452 response_text
1453 );
1454
1455 Ok(())
1456 }
1457}