1use crate::{
2 OneOrMany,
3 agent::completion::{DynamicContextStore, build_completion_request},
4 agent::prompt_request::{HookAction, hooks::PromptHook},
5 completion::{Document, GetTokenUsage},
6 json_utils,
7 message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
8 streaming::{StreamedAssistantContent, StreamedUserContent},
9 tool::server::ToolServerHandle,
10 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
11};
12use futures::{Stream, StreamExt};
13use serde::{Deserialize, Serialize};
14use std::{pin::Pin, sync::Arc};
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::ToolCallHookAction;
19use crate::{
20 agent::Agent,
21 completion::{CompletionError, CompletionModel, PromptError},
22 message::{Message, Text},
23 tool::ToolSetError,
24};
25
/// A pinned, boxed stream of multi-turn items. On native targets the stream
/// must be `Send` so it can be driven from multi-threaded executors.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// Same stream type without the `Send` bound for single-threaded wasm targets.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
33
/// A single item yielded by the multi-turn agent stream: streamed assistant
/// output, streamed user-side content (tool results), or the terminal
/// [`FinalResponse`].
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// Content produced by the assistant (text, tool calls, reasoning, ...).
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// Content attributed to the user side, e.g. the result of a tool call.
    StreamUserItem(StreamedUserContent),
    /// Terminal item, emitted once when the multi-turn loop finishes.
    FinalResponse(FinalResponse),
}
45
/// Summary emitted as the last item of a multi-turn stream.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    /// Final assistant text produced on the last turn.
    response: String,
    /// Token usage summed across every turn of the conversation.
    aggregated_usage: crate::completion::Usage,
    /// Updated chat history; only populated when the request was built with
    /// an explicit history (see `with_history`).
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}
54
55impl FinalResponse {
56 pub fn empty() -> Self {
57 Self {
58 response: String::new(),
59 aggregated_usage: crate::completion::Usage::new(),
60 history: None,
61 }
62 }
63
64 pub fn response(&self) -> &str {
65 &self.response
66 }
67
68 pub fn usage(&self) -> crate::completion::Usage {
69 self.aggregated_usage
70 }
71
72 pub fn history(&self) -> Option<&[Message]> {
73 self.history.as_deref()
74 }
75}
76
77impl<R> MultiTurnStreamItem<R> {
78 pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
79 Self::StreamAssistantItem(item)
80 }
81
82 pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
83 Self::FinalResponse(FinalResponse {
84 response: response.to_string(),
85 aggregated_usage,
86 history: None,
87 })
88 }
89
90 pub fn final_response_with_history(
91 response: &str,
92 aggregated_usage: crate::completion::Usage,
93 history: Option<Vec<Message>>,
94 ) -> Self {
95 Self::FinalResponse(FinalResponse {
96 response: response.to_string(),
97 aggregated_usage,
98 history,
99 })
100 }
101}
102
103fn merge_reasoning_blocks(
104 accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
105 incoming: &crate::message::Reasoning,
106) {
107 let ids_match = |existing: &crate::message::Reasoning| {
108 matches!(
109 (&existing.id, &incoming.id),
110 (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
111 )
112 };
113
114 if let Some(existing) = accumulated_reasoning
115 .iter_mut()
116 .rev()
117 .find(|existing| ids_match(existing))
118 {
119 existing.content.extend(incoming.content.clone());
120 } else {
121 accumulated_reasoning.push(incoming.clone());
122 }
123}
124
125fn build_full_history(
127 chat_history: Option<&[Message]>,
128 new_messages: Vec<Message>,
129) -> Vec<Message> {
130 let input = chat_history.unwrap_or(&[]);
131 input.iter().cloned().chain(new_messages).collect()
132}
133
134fn build_history_for_request(
136 chat_history: Option<&[Message]>,
137 new_messages: &[Message],
138) -> Vec<Message> {
139 let input = chat_history.unwrap_or(&[]);
140 input.iter().chain(new_messages.iter()).cloned().collect()
141}
142
143async fn cancelled_prompt_error(
144 chat_history: Option<&[Message]>,
145 new_messages: Vec<Message>,
146 reason: String,
147) -> StreamingError {
148 StreamingError::Prompt(
149 PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
150 .into(),
151 )
152}
153
154fn tool_result_to_user_message(
155 id: String,
156 call_id: Option<String>,
157 tool_result: String,
158) -> Message {
159 let content = OneOrMany::one(ToolResultContent::text(tool_result));
160 let user_content = match call_id {
161 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
162 None => UserContent::tool_result(id, content),
163 };
164
165 Message::User {
166 content: OneOrMany::one(user_content),
167 }
168}
169
/// Errors that can surface as items on the multi-turn stream.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying provider/completion call failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// The prompt loop was aborted (hook cancellation, max turns exceeded, ...).
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool invocation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
179
/// Fallback display name used in tracing spans when the agent has no name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
181
/// Builder for a streaming, multi-turn agent prompt.
///
/// Holds a snapshot of the agent's configuration (model, tools, context,
/// sampling parameters) so the request can run independently of the agent.
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// The prompt that kicks off the conversation.
    prompt: Message,
    /// Optional pre-existing chat history supplied by the caller.
    chat_history: Option<Vec<Message>>,
    /// Maximum number of turns (tool-call round trips) allowed.
    max_turns: usize,

    /// The completion model to stream from.
    model: Arc<M>,
    /// Human-readable agent name, used in tracing spans.
    agent_name: Option<String>,
    /// Optional system prompt / preamble.
    preamble: Option<String>,
    /// Documents always included as context with each request.
    static_context: Vec<Document>,
    /// Sampling temperature, if overridden.
    temperature: Option<f64>,
    /// Maximum completion tokens, if set.
    max_tokens: Option<u64>,
    /// Provider-specific extra parameters, forwarded verbatim.
    additional_params: Option<serde_json::Value>,
    /// Handle used to resolve and invoke tools during the loop.
    tool_server_handle: ToolServerHandle,
    /// Store of dynamically retrieved (RAG) context.
    dynamic_context: DynamicContextStore,
    /// Tool-choice constraint forwarded to the provider, if any.
    tool_choice: Option<ToolChoice>,
    /// Optional JSON schema constraining the final output.
    output_schema: Option<schemars::Schema>,
    /// Optional per-request hook for observing/cancelling individual steps.
    hook: Option<P>,
}
228
229impl<M, P> StreamingPromptRequest<M, P>
230where
231 M: CompletionModel + 'static,
232 <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
233 P: PromptHook<M>,
234{
235 pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
238 StreamingPromptRequest {
239 prompt: prompt.into(),
240 chat_history: None,
241 max_turns: agent.default_max_turns.unwrap_or_default(),
242 model: agent.model.clone(),
243 agent_name: agent.name.clone(),
244 preamble: agent.preamble.clone(),
245 static_context: agent.static_context.clone(),
246 temperature: agent.temperature,
247 max_tokens: agent.max_tokens,
248 additional_params: agent.additional_params.clone(),
249 tool_server_handle: agent.tool_server_handle.clone(),
250 dynamic_context: agent.dynamic_context.clone(),
251 tool_choice: agent.tool_choice.clone(),
252 output_schema: agent.output_schema.clone(),
253 hook: None,
254 }
255 }
256
257 pub fn from_agent<P2>(
259 agent: &Agent<M, P2>,
260 prompt: impl Into<Message>,
261 ) -> StreamingPromptRequest<M, P2>
262 where
263 P2: PromptHook<M>,
264 {
265 StreamingPromptRequest {
266 prompt: prompt.into(),
267 chat_history: None,
268 max_turns: agent.default_max_turns.unwrap_or_default(),
269 model: agent.model.clone(),
270 agent_name: agent.name.clone(),
271 preamble: agent.preamble.clone(),
272 static_context: agent.static_context.clone(),
273 temperature: agent.temperature,
274 max_tokens: agent.max_tokens,
275 additional_params: agent.additional_params.clone(),
276 tool_server_handle: agent.tool_server_handle.clone(),
277 dynamic_context: agent.dynamic_context.clone(),
278 tool_choice: agent.tool_choice.clone(),
279 output_schema: agent.output_schema.clone(),
280 hook: agent.hook.clone(),
281 }
282 }
283
284 fn agent_name(&self) -> &str {
285 self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
286 }
287
288 pub fn multi_turn(mut self, turns: usize) -> Self {
291 self.max_turns = turns;
292 self
293 }
294
295 pub fn with_history<I, T>(mut self, history: I) -> Self
308 where
309 I: IntoIterator<Item = T>,
310 T: Into<Message>,
311 {
312 self.chat_history = Some(history.into_iter().map(Into::into).collect());
313 self
314 }
315
316 pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
319 where
320 P2: PromptHook<M>,
321 {
322 StreamingPromptRequest {
323 prompt: self.prompt,
324 chat_history: self.chat_history,
325 max_turns: self.max_turns,
326 model: self.model,
327 agent_name: self.agent_name,
328 preamble: self.preamble,
329 static_context: self.static_context,
330 temperature: self.temperature,
331 max_tokens: self.max_tokens,
332 additional_params: self.additional_params,
333 tool_server_handle: self.tool_server_handle,
334 dynamic_context: self.dynamic_context,
335 tool_choice: self.tool_choice,
336 output_schema: self.output_schema,
337 hook: Some(hook),
338 }
339 }
340
341 async fn send(self) -> StreamingResult<M::StreamingResponse> {
342 let agent_span = if tracing::Span::current().is_disabled() {
343 info_span!(
344 "invoke_agent",
345 gen_ai.operation.name = "invoke_agent",
346 gen_ai.agent.name = self.agent_name(),
347 gen_ai.system_instructions = self.preamble,
348 gen_ai.prompt = tracing::field::Empty,
349 gen_ai.completion = tracing::field::Empty,
350 gen_ai.usage.input_tokens = tracing::field::Empty,
351 gen_ai.usage.output_tokens = tracing::field::Empty,
352 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
353 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
354 )
355 } else {
356 tracing::Span::current()
357 };
358
359 let prompt = self.prompt;
360 if let Some(text) = prompt.rag_text() {
361 agent_span.record("gen_ai.prompt", text);
362 }
363
364 let model = self.model.clone();
366 let preamble = self.preamble.clone();
367 let static_context = self.static_context.clone();
368 let temperature = self.temperature;
369 let max_tokens = self.max_tokens;
370 let additional_params = self.additional_params.clone();
371 let tool_server_handle = self.tool_server_handle.clone();
372 let dynamic_context = self.dynamic_context.clone();
373 let tool_choice = self.tool_choice.clone();
374 let agent_name = self.agent_name.clone();
375 let has_history = self.chat_history.is_some();
376 let chat_history = self.chat_history;
377 let mut new_messages: Vec<Message> = vec![prompt.clone()];
378
379 let mut current_max_turns = 0;
380 let mut last_prompt_error = String::new();
381
382 let mut last_text_response = String::new();
383 let mut is_text_response = false;
384 let mut max_turns_reached = false;
385 let output_schema = self.output_schema;
386
387 let mut aggregated_usage = crate::completion::Usage::new();
388
389 let stream = async_stream::stream! {
396 let mut current_prompt = prompt.clone();
397
398 'outer: loop {
399 if current_max_turns > self.max_turns + 1 {
400 last_prompt_error = current_prompt.rag_text().unwrap_or_default();
401 max_turns_reached = true;
402 break;
403 }
404
405 current_max_turns += 1;
406
407 if self.max_turns > 1 {
408 tracing::info!(
409 "Current conversation Turns: {}/{}",
410 current_max_turns,
411 self.max_turns
412 );
413 }
414
415 if let Some(ref hook) = self.hook {
416 let history_snapshot: Vec<Message> = build_history_for_request(chat_history.as_deref(), &new_messages[..new_messages.len().saturating_sub(1)]);
417 if let HookAction::Terminate { reason } = hook.on_completion_call(¤t_prompt, &history_snapshot)
418 .await {
419 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
420 break 'outer;
421 }
422 }
423
424 let chat_stream_span = info_span!(
425 target: "rig::agent_chat",
426 parent: tracing::Span::current(),
427 "chat_streaming",
428 gen_ai.operation.name = "chat",
429 gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
430 gen_ai.system_instructions = preamble,
431 gen_ai.provider.name = tracing::field::Empty,
432 gen_ai.request.model = tracing::field::Empty,
433 gen_ai.response.id = tracing::field::Empty,
434 gen_ai.response.model = tracing::field::Empty,
435 gen_ai.usage.output_tokens = tracing::field::Empty,
436 gen_ai.usage.input_tokens = tracing::field::Empty,
437 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
438 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
439 gen_ai.input.messages = tracing::field::Empty,
440 gen_ai.output.messages = tracing::field::Empty,
441 );
442
443 let history_snapshot: Vec<Message> = build_history_for_request(chat_history.as_deref(), &new_messages[..new_messages.len().saturating_sub(1)]);
444 let mut stream = tracing::Instrument::instrument(
445 build_completion_request(
446 &model,
447 current_prompt.clone(),
448 &history_snapshot,
449 preamble.as_deref(),
450 &static_context,
451 temperature,
452 max_tokens,
453 additional_params.as_ref(),
454 tool_choice.as_ref(),
455 &tool_server_handle,
456 &dynamic_context,
457 output_schema.as_ref(),
458 )
459 .await?
460 .stream(), chat_stream_span
461 )
462
463 .await?;
464
465 new_messages.push(current_prompt.clone());
466
467 let mut tool_calls = vec![];
468 let mut tool_results = vec![];
469 let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
470 let mut pending_reasoning_delta_text = String::new();
473 let mut pending_reasoning_delta_id: Option<String> = None;
474 let mut saw_tool_call_this_turn = false;
475
476 while let Some(content) = stream.next().await {
477 match content {
478 Ok(StreamedAssistantContent::Text(text)) => {
479 if !is_text_response {
480 last_text_response = String::new();
481 is_text_response = true;
482 }
483 last_text_response.push_str(&text.text);
484 if let Some(ref hook) = self.hook &&
485 let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &last_text_response).await {
486 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
487 break 'outer;
488 }
489
490 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
491 },
492 Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
493 let tool_span = info_span!(
494 parent: tracing::Span::current(),
495 "execute_tool",
496 gen_ai.operation.name = "execute_tool",
497 gen_ai.tool.type = "function",
498 gen_ai.tool.name = tracing::field::Empty,
499 gen_ai.tool.call.id = tracing::field::Empty,
500 gen_ai.tool.call.arguments = tracing::field::Empty,
501 gen_ai.tool.call.result = tracing::field::Empty
502 );
503
504 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));
505
506 let tc_result = async {
507 let tool_span = tracing::Span::current();
508 let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
509 if let Some(ref hook) = self.hook {
510 let action = hook
511 .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
512 .await;
513
514 if let ToolCallHookAction::Terminate { reason } = action {
515 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
516 }
517
518 if let ToolCallHookAction::Skip { reason } = action {
519 tracing::info!(
521 tool_name = tool_call.function.name.as_str(),
522 reason = reason,
523 "Tool call rejected"
524 );
525 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
526 tool_calls.push(tool_call_msg);
527 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
528 saw_tool_call_this_turn = true;
529 return Ok(reason);
530 }
531 }
532
533 tool_span.record("gen_ai.tool.name", &tool_call.function.name);
534 tool_span.record("gen_ai.tool.call.arguments", &tool_args);
535
536 let tool_result = match
537 tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
538 Ok(thing) => thing,
539 Err(e) => {
540 tracing::warn!("Error while calling tool: {e}");
541 e.to_string()
542 }
543 };
544
545 tool_span.record("gen_ai.tool.call.result", &tool_result);
546
547 if let Some(ref hook) = self.hook &&
548 let HookAction::Terminate { reason } =
549 hook.on_tool_result(
550 &tool_call.function.name,
551 tool_call.call_id.clone(),
552 &internal_call_id,
553 &tool_args,
554 &tool_result.to_string()
555 )
556 .await {
557 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
558 }
559
560 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
561
562 tool_calls.push(tool_call_msg);
563 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));
564
565 saw_tool_call_this_turn = true;
566 Ok(tool_result)
567 }.instrument(tool_span).await;
568
569 match tc_result {
570 Ok(text) => {
571 let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
572 yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
573 }
574 Err(e) => {
575 yield Err(e);
576 break 'outer;
577 }
578 }
579 },
580 Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
581 if let Some(ref hook) = self.hook {
582 let (name, delta) = match &content {
583 rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
584 rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
585 };
586
587 if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
588 .await {
589 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
590 break 'outer;
591 }
592 }
593 }
594 Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
595 merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
599 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
600 },
601 Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
602 pending_reasoning_delta_text.push_str(&reasoning);
606 if pending_reasoning_delta_id.is_none() {
607 pending_reasoning_delta_id = id.clone();
608 }
609 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
610 },
611 Ok(StreamedAssistantContent::Final(final_resp)) => {
612 if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
613 if is_text_response {
614 if let Some(ref hook) = self.hook &&
615 let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&prompt, &final_resp).await {
616 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
617 break 'outer;
618 }
619
620 tracing::Span::current().record("gen_ai.completion", &last_text_response);
621 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
622 is_text_response = false;
623 }
624 }
625 Err(e) => {
626 yield Err(e.into());
627 break 'outer;
628 }
629 }
630 }
631
632 if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
636 let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
637 if let Some(id) = pending_reasoning_delta_id.take() {
638 assembled = assembled.with_id(id);
639 }
640 accumulated_reasoning.push(assembled);
641 }
642
643 if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
646 let mut content_items: Vec<rig::message::AssistantContent> = vec![];
647
648 if !last_text_response.is_empty() {
650 content_items.push(rig::message::AssistantContent::text(&last_text_response));
651 last_text_response.clear();
652 }
653
654 for reasoning in accumulated_reasoning.drain(..) {
656 content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
657 }
658
659 content_items.extend(tool_calls.clone());
660
661 if !content_items.is_empty() {
662 new_messages.push(Message::Assistant {
663 id: stream.message_id.clone(),
664 content: OneOrMany::many(content_items).expect("Should have at least one item"),
665 });
666 }
667 }
668
669 for (id, call_id, tool_result) in tool_results {
670 new_messages.push(tool_result_to_user_message(id, call_id, tool_result));
671 }
672
673 current_prompt = match new_messages.pop() {
675 Some(prompt) => prompt,
676 None => unreachable!("New messages should never be empty at this point"),
677 };
678
679 if !saw_tool_call_this_turn {
680 new_messages.push(current_prompt.clone());
682 if !last_text_response.is_empty() {
683 new_messages.push(Message::assistant(&last_text_response));
684 }
685
686 let current_span = tracing::Span::current();
687 current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
688 current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
689 current_span.record("gen_ai.usage.cache_read.input_tokens", aggregated_usage.cached_input_tokens);
690 current_span.record("gen_ai.usage.cache_creation.input_tokens", aggregated_usage.cache_creation_input_tokens);
691 tracing::info!("Agent multi-turn stream finished");
692 let final_messages: Option<Vec<Message>> = if has_history {
693 Some(new_messages.clone())
694 } else {
695 None
696 };
697 yield Ok(MultiTurnStreamItem::final_response_with_history(
698 &last_text_response,
699 aggregated_usage,
700 final_messages,
701 ));
702 break;
703 }
704 }
705
706 if max_turns_reached {
707 yield Err(Box::new(PromptError::MaxTurnsError {
708 max_turns: self.max_turns,
709 chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
710 prompt: Box::new(last_prompt_error.clone().into()),
711 }).into());
712 }
713 };
714
715 Box::pin(stream.instrument(agent_span))
716 }
717}
718
719impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
720where
721 M: CompletionModel + 'static,
722 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
723 P: PromptHook<M> + 'static,
724{
725 type Output = StreamingResult<M::StreamingResponse>; type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
727
728 fn into_future(self) -> Self::IntoFuture {
729 Box::pin(async move { self.send().await })
731 }
732}
733
734pub async fn stream_to_stdout<R>(
736 stream: &mut StreamingResult<R>,
737) -> Result<FinalResponse, std::io::Error> {
738 let mut final_res = FinalResponse::empty();
739 print!("Response: ");
740 while let Some(content) = stream.next().await {
741 match content {
742 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
743 Text { text },
744 ))) => {
745 print!("{text}");
746 std::io::Write::flush(&mut std::io::stdout()).unwrap();
747 }
748 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
749 reasoning,
750 ))) => {
751 let reasoning = reasoning.display_text();
752 print!("{reasoning}");
753 std::io::Write::flush(&mut std::io::stdout()).unwrap();
754 }
755 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
756 final_res = res;
757 }
758 Err(err) => {
759 eprintln!("Error: {err}");
760 }
761 _ => {}
762 }
763 }
764
765 Ok(final_res)
766}
767
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentBuilder;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::completion::{
        CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
    };
    use crate::message::ReasoningContent;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
    use std::time::Duration;

    // Two blocks with the same id must merge into one, keeping content order
    // and per-chunk signatures intact.
    #[test]
    fn merge_reasoning_blocks_preserves_order_and_signatures() {
        let mut accumulated = Vec::new();
        let first = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: Some("sig-1".to_string()),
            }],
        };
        let second = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![
                ReasoningContent::Text {
                    text: "step-2".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                ReasoningContent::Summary("summary".to_string()),
            ],
        };

        merge_reasoning_blocks(&mut accumulated, &first);
        merge_reasoning_blocks(&mut accumulated, &second);

        assert_eq!(accumulated.len(), 1);
        let merged = accumulated.first().expect("expected accumulated reasoning");
        assert_eq!(merged.id.as_deref(), Some("rs_1"));
        assert_eq!(merged.content.len(), 3);
        assert!(matches!(
            merged.content.first(),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-1" && sig == "sig-1"
        ));
        assert!(matches!(
            merged.content.get(1),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-2" && sig == "sig-2"
        ));
    }

    // Blocks with different ids must never be merged together.
    #[test]
    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: Some("rs_a".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: Some("rs_b".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-2".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert_eq!(
            accumulated.first().and_then(|r| r.id.as_deref()),
            Some("rs_a")
        );
        assert_eq!(
            accumulated.get(1).and_then(|r| r.id.as_deref()),
            Some("rs_b")
        );
    }

    // Blocks with `None` ids carry no merge key and must stay separate.
    #[test]
    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "first".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "second".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert!(matches!(
            accumulated.first(),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "first"
            )
        ));
        assert!(matches!(
            accumulated.get(1),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "second"
            )
        ));
    }

    // Minimal streaming response carrying only token usage, for asserting
    // usage aggregation across turns.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct MockStreamingResponse {
        usage: crate::completion::Usage,
    }

    impl MockStreamingResponse {
        fn new(total_tokens: u64) -> Self {
            let mut usage = crate::completion::Usage::new();
            usage.total_tokens = total_tokens;
            Self { usage }
        }
    }

    impl crate::completion::GetTokenUsage for MockStreamingResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            Some(self.usage)
        }
    }

    // Mock model: first turn emits a tool call, every later turn emits plain
    // text, so the multi-turn loop has to run exactly twice.
    #[derive(Clone, Default)]
    struct MultiTurnMockModel {
        turn_counter: Arc<AtomicUsize>,
    }

    #[allow(refining_impl_trait)]
    impl CompletionModel for MultiTurnMockModel {
        type Response = ();
        type StreamingResponse = MockStreamingResponse;
        type Client = ();

        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self::default()
        }

        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "completion is unused in this streaming test".to_string(),
            ))
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
            let stream = async_stream::stream! {
                if turn == 0 {
                    // The tool name intentionally does not exist; the loop is
                    // expected to surface the error as the tool result and
                    // continue to the next turn.
                    yield Ok(RawStreamingChoice::ToolCall(
                        RawStreamingToolCall::new(
                            "tool_call_1".to_string(),
                            "missing_tool".to_string(),
                            serde_json::json!({"input": "value"}),
                        )
                        .with_call_id("call_1".to_string()),
                    ));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(4)));
                } else {
                    yield Ok(RawStreamingChoice::Message("done".to_string()));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(6)));
                }
            };

            let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
                Box::pin(stream);
            Ok(StreamingCompletionResponse::stream(pinned_stream))
        }
    }

    // End-to-end check that a tool-call turn is followed by a text turn and a
    // final response, i.e. the loop does not stop after the tool call.
    #[tokio::test]
    async fn stream_prompt_continues_after_tool_call_turn() {
        let model = MultiTurnMockModel::default();
        let turn_counter = model.turn_counter.clone();
        let agent = AgentBuilder::new(model).build();

        let mut stream = agent.stream_prompt("do tool work").multi_turn(3).await;
        let mut saw_tool_call = false;
        let mut saw_tool_result = false;
        let mut saw_final_response = false;
        let mut final_text = String::new();

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall { .. },
                )) => {
                    saw_tool_call = true;
                }
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    ..
                })) => {
                    saw_tool_result = true;
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    final_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    saw_final_response = true;
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        assert!(saw_tool_call);
        assert!(saw_tool_result);
        assert!(saw_final_response);
        assert_eq!(final_text, "done");
        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
    }

    // Ticks in the background and counts how many times it finds itself
    // inside an unexpected span — a nonzero count means span context leaked
    // out of the agent stream.
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Let the logger establish its baseline before streaming starts.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
            This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }

    // Passing a history (even an empty one) via `with_history` must cause the
    // FinalResponse to carry the updated conversation back to the caller.
    #[tokio::test]
    #[ignore = "This requires an API key"]
    async fn test_chat_history_in_final_response() {
        use crate::message::Message;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant. Keep responses brief.")
            .temperature(0.1)
            .max_tokens(50)
            .build();

        let empty_history: &[Message] = &[];
        let mut stream = agent
            .stream_prompt("Say 'hello' and nothing else.")
            .with_history(empty_history)
            .await;

        let mut response_text = String::new();
        let mut final_history = None;
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    response_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    final_history = res.history().map(|h| h.to_vec());
                    break;
                }
                Err(e) => {
                    panic!("Streaming error: {:?}", e);
                }
                _ => {}
            }
        }

        let history =
            final_history.expect("FinalResponse should contain history when with_history is used");

        assert!(
            history.iter().any(|m| matches!(m, Message::User { .. })),
            "History should contain the user message"
        );

        assert!(
            history
                .iter()
                .any(|m| matches!(m, Message::Assistant { .. })),
            "History should contain the assistant response"
        );

        tracing::info!(
            "History after streaming: {} messages, response: {:?}",
            history.len(),
            response_text
        );
    }
}