use crate::{
    OneOrMany,
    agent::completion::{DynamicContextStore, build_completion_request},
    agent::prompt_request::{HookAction, hooks::PromptHook},
    completion::{Document, GetTokenUsage},
    json_utils,
    message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent},
    tool::server::ToolServerHandle,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tracing::info_span;
use tracing_futures::Instrument;

use super::ToolCallHookAction;
use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

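/// A pinned, boxed stream of [`MultiTurnStreamItem`]s (or [`StreamingError`]s),
/// as produced by awaiting a [`StreamingPromptRequest`]. The `Send` bound is
/// dropped on `wasm32` targets.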
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

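/// A single item yielded while streaming a (potentially multi-turn) prompt:
/// assistant content (text, tool calls, reasoning), user-side content (tool
/// results), or the [`FinalResponse`] that closes the stream.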
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    StreamAssistantItem(StreamedAssistantContent<R>),
    StreamUserItem(StreamedUserContent),
    FinalResponse(FinalResponse),
}

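/// The terminal item of a prompt stream: the aggregated text response, token
/// usage summed across all turns, and, when the request was built with
/// [`StreamingPromptRequest::with_history`], the updated chat history.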
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    response: String,
    aggregated_usage: crate::completion::Usage,
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}

impl FinalResponse {
    pub fn empty() -> Self {
        Self {
            response: String::new(),
            aggregated_usage: crate::completion::Usage::new(),
            history: None,
        }
    }

    pub fn response(&self) -> &str {
        &self.response
    }

    pub fn usage(&self) -> crate::completion::Usage {
        self.aggregated_usage
    }

    pub fn history(&self) -> Option<&[Message]> {
        self.history.as_deref()
    }
}

impl<R> MultiTurnStreamItem<R> {
    pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
        Self::StreamAssistantItem(item)
    }

    pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
            history: None,
        })
    }

    pub fn final_response_with_history(
        response: &str,
        aggregated_usage: crate::completion::Usage,
        history: Option<Vec<Message>>,
    ) -> Self {
        Self::FinalResponse(FinalResponse {
            response: response.to_string(),
            aggregated_usage,
            history,
        })
    }
}

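/// Merges an incoming reasoning block into the accumulated list: when an
/// existing block has the same (non-`None`) id, the incoming content is
/// appended to it; otherwise the incoming block is pushed as a new entry.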
fn merge_reasoning_blocks(
    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
    incoming: &crate::message::Reasoning,
) {
    let ids_match = |existing: &crate::message::Reasoning| {
        matches!(
            (&existing.id, &incoming.id),
            (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
        )
    };

    if let Some(existing) = accumulated_reasoning
        .iter_mut()
        .rev()
        .find(|existing| ids_match(existing))
    {
        existing.content.extend(incoming.content.clone());
    } else {
        accumulated_reasoning.push(incoming.clone());
    }
}

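/// Builds the [`StreamingError`] yielded when a hook cancels the prompt.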
async fn cancelled_prompt_error(chat_history: &[Message], reason: String) -> StreamingError {
    StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.to_owned(), reason).into())
}

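/// Wraps a tool's textual output in a user message so it can be appended to
/// the chat history for the next turn.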
fn tool_result_to_user_message(
    id: String,
    call_id: Option<String>,
    tool_result: String,
) -> Message {
    let content = OneOrMany::one(ToolResultContent::text(tool_result));
    let user_content = match call_id {
        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
        None => UserContent::tool_result(id, content),
    };

    Message::User {
        content: OneOrMany::one(user_content),
    }
}

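/// Errors that can surface while driving a streaming prompt to completion.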
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

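/// A builder for a streaming, multi-turn prompt against an [`Agent`]. Awaiting
/// the request (it implements [`IntoFuture`]) starts the stream and yields a
/// [`StreamingResult`].
///
/// Illustrative sketch only; the Anthropic client and model constant mirror
/// the ones used in this module's tests, and any other provider works the
/// same way:
///
/// ```ignore
/// use futures::StreamExt;
/// use rig::client::{ProviderClient, completion::CompletionClient};
/// use rig::providers::anthropic;
/// use rig::streaming::StreamingPrompt;
///
/// let agent = anthropic::Client::from_env()
///     .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
///     .preamble("You are a helpful assistant.")
///     .build();
///
/// // `stream_prompt` builds a `StreamingPromptRequest`; `multi_turn` caps tool-call turns.
/// let mut stream = agent.stream_prompt("Say hello").multi_turn(3).await;
/// while let Some(item) = stream.next().await {
///     // handle `MultiTurnStreamItem`s here
/// }
/// ```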
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    prompt: Message,
    chat_history: Option<Vec<Message>>,
    max_turns: usize,

    model: Arc<M>,
    agent_name: Option<String>,
    preamble: Option<String>,
    static_context: Vec<Document>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
    tool_server_handle: ToolServerHandle,
    dynamic_context: DynamicContextStore,
    tool_choice: Option<ToolChoice>,
    output_schema: Option<schemars::Schema>,
    hook: Option<P>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: PromptHook<M>,
{
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: None,
        }
    }

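    /// Like [`StreamingPromptRequest::new`], but borrows the agent and keeps
    /// its configured hook type `P2`, cloning the hook if one is set.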
    pub fn from_agent<P2>(
        agent: &Agent<M, P2>,
        prompt: impl Into<Message>,
    ) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: agent.hook.clone(),
        }
    }

    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

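    /// Sets the maximum number of turns (model round-trips, e.g. for tool
    /// calls) the stream may take before it errors with
    /// [`PromptError::MaxTurnsError`].
    ///
    /// ```ignore
    /// // As in this module's tests: allow up to three turns of tool calling.
    /// let mut stream = agent.stream_prompt("do tool work").multi_turn(3).await;
    /// ```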
    pub fn multi_turn(mut self, turns: usize) -> Self {
        self.max_turns = turns;
        self
    }

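    /// Seeds the request with prior chat history. When history is provided,
    /// the closing [`FinalResponse`] also carries the updated history back to
    /// the caller (see [`FinalResponse::history`]).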
    pub fn with_history(mut self, history: Vec<Message>) -> Self {
        self.chat_history = Some(history);
        self
    }

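    /// Attaches a [`PromptHook`] that can observe (and cancel) completion
    /// calls, text deltas, tool calls, and tool results, rebinding the
    /// request's hook type to `P2`.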
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            output_schema: self.output_schema,
            hook: Some(hook),
        }
    }

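    /// Drives the multi-turn conversation: streams each model turn, executes
    /// any requested tool calls between turns, and finishes the stream with a
    /// [`FinalResponse`] (or a `MaxTurnsError` once the turn budget is spent).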
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let model = self.model.clone();
        let preamble = self.preamble.clone();
        let static_context = self.static_context.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let additional_params = self.additional_params.clone();
        let tool_server_handle = self.tool_server_handle.clone();
        let dynamic_context = self.dynamic_context.clone();
        let tool_choice = self.tool_choice.clone();
        let agent_name = self.agent_name.clone();
        let has_history = self.chat_history.is_some();
        let mut chat_history = self.chat_history.unwrap_or_default();

        let mut current_turn = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_turns_reached = false;
        let output_schema = self.output_schema;

        let mut aggregated_usage = crate::completion::Usage::new();

        let stream = async_stream::stream! {
            let mut current_prompt = prompt.clone();

            'outer: loop {
                if current_turn > self.max_turns + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_turns_reached = true;
                    break;
                }

                current_turn += 1;

                if self.max_turns > 1 {
                    tracing::info!(
                        "Current conversation turns: {}/{}",
                        current_turn,
                        self.max_turns
                    );
                }

                if let Some(ref hook) = self.hook {
                    let history_snapshot = chat_history.clone();
                    if let HookAction::Terminate { reason } = hook
                        .on_completion_call(&current_prompt, &history_snapshot)
                        .await
                    {
                        yield Err(cancelled_prompt_error(&chat_history, reason).await);
                        break 'outer;
                    }
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                    gen_ai.system_instructions = preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let history_snapshot = chat_history.clone();
                let mut stream = tracing::Instrument::instrument(
                    build_completion_request(
                        &model,
                        current_prompt.clone(),
                        history_snapshot,
                        preamble.as_deref(),
                        &static_context,
                        temperature,
                        max_tokens,
                        additional_params.as_ref(),
                        tool_choice.as_ref(),
                        &tool_server_handle,
                        &dynamic_context,
                        output_schema.as_ref(),
                    )
                    .await?
                    .stream(),
                    chat_stream_span,
                )
                .await?;

                chat_history.push(current_prompt.clone());

                let mut tool_calls = vec![];
                let mut tool_results = vec![];
                let mut accumulated_reasoning: Vec<crate::message::Reasoning> = vec![];
                let mut pending_reasoning_delta_text = String::new();
                let mut pending_reasoning_delta_id: Option<String> = None;
                let mut saw_tool_call_this_turn = false;

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook
                                && let HookAction::Terminate { reason } =
                                    hook.on_text_delta(&text.text, &last_text_response).await
                            {
                                yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                break 'outer;
                            }

                            yield Ok(MultiTurnStreamItem::stream_item(
                                StreamedAssistantContent::Text(text),
                            ));
                        },
                        Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(
                                StreamedAssistantContent::ToolCall {
                                    tool_call: tool_call.clone(),
                                    internal_call_id: internal_call_id.clone(),
                                },
                            ));

                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args =
                                    json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    let action = hook
                                        .on_tool_call(
                                            &tool_call.function.name,
                                            tool_call.call_id.clone(),
                                            &internal_call_id,
                                            &tool_args,
                                        )
                                        .await;

                                    if let ToolCallHookAction::Terminate { reason } = action {
                                        return Err(cancelled_prompt_error(&chat_history, reason).await);
                                    }

                                    if let ToolCallHookAction::Skip { reason } = action {
                                        tracing::info!(
                                            tool_name = tool_call.function.name.as_str(),
                                            reason = reason,
                                            "Tool call rejected"
                                        );
                                        let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
                                        tool_calls.push(tool_call_msg);
                                        tool_results.push((
                                            tool_call.id.clone(),
                                            tool_call.call_id.clone(),
                                            reason.clone(),
                                        ));
                                        saw_tool_call_this_turn = true;
                                        return Ok(reason);
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                let tool_result = match tool_server_handle
                                    .call_tool(&tool_call.function.name, &tool_args)
                                    .await
                                {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook
                                    && let HookAction::Terminate { reason } = hook
                                        .on_tool_result(
                                            &tool_call.function.name,
                                            tool_call.call_id.clone(),
                                            &internal_call_id,
                                            &tool_args,
                                            &tool_result,
                                        )
                                        .await
                                {
                                    return Err(cancelled_prompt_error(&chat_history, reason).await);
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((
                                    tool_call.id.clone(),
                                    tool_call.call_id.clone(),
                                    tool_result.clone(),
                                ));

                                saw_tool_call_this_turn = true;
                                Ok(tool_result)
                            }
                            .instrument(tool_span)
                            .await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult {
                                        id: tool_call.id,
                                        call_id: tool_call.call_id,
                                        content: ToolResultContent::from_tool_output(text),
                                    };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(
                                        StreamedUserContent::ToolResult {
                                            tool_result: tr,
                                            internal_call_id,
                                        },
                                    ));
                                }
                                Err(e) => {
                                    yield Err(e);
                                    break 'outer;
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    crate::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
                                    crate::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
                                };

                                if let HookAction::Terminate { reason } = hook
                                    .on_tool_call_delta(&id, &internal_call_id, name, delta)
                                    .await
                                {
                                    yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                    break 'outer;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
                            yield Ok(MultiTurnStreamItem::stream_item(
                                StreamedAssistantContent::Reasoning(reasoning),
                            ));
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            pending_reasoning_delta_text.push_str(&reasoning);
                            if pending_reasoning_delta_id.is_none() {
                                pending_reasoning_delta_id = id.clone();
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(
                                StreamedAssistantContent::ReasoningDelta { reasoning, id },
                            ));
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() {
                                aggregated_usage += usage;
                            }
                            if is_text_response {
                                if let Some(ref hook) = self.hook
                                    && let HookAction::Terminate { reason } = hook
                                        .on_stream_completion_response_finish(&prompt, &final_resp)
                                        .await
                                {
                                    yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                    break 'outer;
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(
                                    StreamedAssistantContent::Final(final_resp),
                                ));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
                    let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
                    if let Some(id) = pending_reasoning_delta_id.take() {
                        assembled = assembled.with_id(id);
                    }
                    accumulated_reasoning.push(assembled);
                }

                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
                    let mut content_items: Vec<AssistantContent> = vec![];

                    for reasoning in accumulated_reasoning.drain(..) {
                        content_items.push(AssistantContent::Reasoning(reasoning));
                    }

                    content_items.extend(tool_calls.clone());

                    if !content_items.is_empty() {
                        chat_history.push(Message::Assistant {
                            id: stream.message_id.clone(),
                            content: OneOrMany::many(content_items)
                                .expect("Should have at least one item"),
                        });
                    }
                }

                for (id, call_id, tool_result) in tool_results {
                    chat_history.push(tool_result_to_user_message(id, call_id, tool_result));
                }

                current_prompt = chat_history
                    .pop()
                    .expect("Chat history should never be empty at this point");

                if !saw_tool_call_this_turn {
                    chat_history.push(current_prompt.clone());
                    if !last_text_response.is_empty() {
                        chat_history.push(Message::assistant(&last_text_response));
                    }

                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    let history_snapshot = if has_history {
                        Some(chat_history.clone())
                    } else {
                        None
                    };
                    yield Ok(MultiTurnStreamItem::final_response_with_history(
                        &last_text_response,
                        aggregated_usage,
                        history_snapshot,
                    ));
                    break;
                }
            }

            if max_turns_reached {
                yield Err(Box::new(PromptError::MaxTurnsError {
                    max_turns: self.max_turns,
                    chat_history: Box::new(chat_history.clone()),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

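/// Convenience helper that drains a stream, printing text and reasoning
/// content to stdout as it arrives, and returns the captured
/// [`FinalResponse`]. Errors from the stream are printed to stderr.
///
/// Minimal sketch, assuming an `agent` built as in this module's tests:
///
/// ```ignore
/// let mut stream = agent.stream_prompt("Say hello").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\nToken usage: {:?}", final_response.usage());
/// ```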
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                reasoning,
            ))) => {
                let reasoning = reasoning.display_text();
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentBuilder;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::completion::{
        CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
    };
    use crate::message::ReasoningContent;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
    use std::time::Duration;

    #[test]
    fn merge_reasoning_blocks_preserves_order_and_signatures() {
        let mut accumulated = Vec::new();
        let first = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: Some("sig-1".to_string()),
            }],
        };
        let second = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![
                ReasoningContent::Text {
                    text: "step-2".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                ReasoningContent::Summary("summary".to_string()),
            ],
        };

        merge_reasoning_blocks(&mut accumulated, &first);
        merge_reasoning_blocks(&mut accumulated, &second);

        assert_eq!(accumulated.len(), 1);
        let merged = accumulated.first().expect("expected accumulated reasoning");
        assert_eq!(merged.id.as_deref(), Some("rs_1"));
        assert_eq!(merged.content.len(), 3);
        assert!(matches!(
            merged.content.first(),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-1" && sig == "sig-1"
        ));
        assert!(matches!(
            merged.content.get(1),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-2" && sig == "sig-2"
        ));
    }

    #[test]
    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: Some("rs_a".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: Some("rs_b".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-2".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert_eq!(
            accumulated.first().and_then(|r| r.id.as_deref()),
            Some("rs_a")
        );
        assert_eq!(
            accumulated.get(1).and_then(|r| r.id.as_deref()),
            Some("rs_b")
        );
    }

    #[test]
    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "first".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "second".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert!(matches!(
            accumulated.first(),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "first"
            )
        ));
        assert!(matches!(
            accumulated.get(1),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "second"
            )
        ));
    }

    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct MockStreamingResponse {
        usage: crate::completion::Usage,
    }

    impl MockStreamingResponse {
        fn new(total_tokens: u64) -> Self {
            let mut usage = crate::completion::Usage::new();
            usage.total_tokens = total_tokens;
            Self { usage }
        }
    }

    impl crate::completion::GetTokenUsage for MockStreamingResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            Some(self.usage)
        }
    }

    #[derive(Clone, Default)]
    struct MultiTurnMockModel {
        turn_counter: Arc<AtomicUsize>,
    }

    #[allow(refining_impl_trait)]
    impl CompletionModel for MultiTurnMockModel {
        type Response = ();
        type StreamingResponse = MockStreamingResponse;
        type Client = ();

        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self::default()
        }

        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "completion is unused in this streaming test".to_string(),
            ))
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
            let stream = async_stream::stream! {
                if turn == 0 {
                    yield Ok(RawStreamingChoice::ToolCall(
                        RawStreamingToolCall::new(
                            "tool_call_1".to_string(),
                            "missing_tool".to_string(),
                            serde_json::json!({"input": "value"}),
                        )
                        .with_call_id("call_1".to_string()),
                    ));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(4)));
                } else {
                    yield Ok(RawStreamingChoice::Message("done".to_string()));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(6)));
                }
            };

            let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
                Box::pin(stream);
            Ok(StreamingCompletionResponse::stream(pinned_stream))
        }
    }

    #[tokio::test]
    async fn stream_prompt_continues_after_tool_call_turn() {
        let model = MultiTurnMockModel::default();
        let turn_counter = model.turn_counter.clone();
        let agent = AgentBuilder::new(model).build();

        let mut stream = agent.stream_prompt("do tool work").multi_turn(3).await;
        let mut saw_tool_call = false;
        let mut saw_tool_result = false;
        let mut saw_final_response = false;
        let mut final_text = String::new();

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall { .. },
                )) => {
                    saw_tool_call = true;
                }
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    ..
                })) => {
                    saw_tool_result = true;
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    final_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    saw_final_response = true;
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        assert!(saw_tool_call);
        assert!(saw_tool_result);
        assert!(saw_final_response);
        assert_eq!(final_text, "done");
        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
    }

    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }

    #[tokio::test]
    #[ignore = "This requires an API key"]
    async fn test_chat_history_in_final_response() {
        use crate::message::Message;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant. Keep responses brief.")
            .temperature(0.1)
            .max_tokens(50)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello' and nothing else.")
            .with_history(vec![])
            .await;

        let mut response_text = String::new();
        let mut final_history = None;
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    response_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    final_history = res.history().map(|h| h.to_vec());
                    break;
                }
                Err(e) => {
                    panic!("Streaming error: {:?}", e);
                }
                _ => {}
            }
        }

        let history =
            final_history.expect("FinalResponse should contain history when with_history is used");

        assert!(
            history.iter().any(|m| matches!(m, Message::User { .. })),
            "History should contain the user message"
        );

        assert!(
            history
                .iter()
                .any(|m| matches!(m, Message::Assistant { .. })),
            "History should contain the assistant response"
        );

        tracing::info!(
            "History after streaming: {} messages, response: {:?}",
            history.len(),
            response_text
        );
    }
}