1use crate::{
2 OneOrMany,
3 agent::completion::{DynamicContextStore, build_completion_request},
4 agent::prompt_request::{HookAction, hooks::PromptHook},
5 completion::{Document, GetTokenUsage},
6 json_utils,
7 message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
8 streaming::{StreamedAssistantContent, StreamedUserContent},
9 tool::server::ToolServerHandle,
10 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
11};
12use futures::{Stream, StreamExt};
13use serde::{Deserialize, Serialize};
14use std::{pin::Pin, sync::Arc};
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::ToolCallHookAction;
19use crate::{
20 agent::Agent,
21 completion::{CompletionError, CompletionModel, PromptError},
22 message::{Message, Text},
23 tool::ToolSetError,
24};
25
/// A pinned, boxed stream of multi-turn agent output items.
///
/// On non-WASM targets the stream must be `Send` so it can be driven by a
/// multi-threaded async runtime.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

/// A pinned, boxed stream of multi-turn agent output items.
///
/// WASM targets are single-threaded, so the `Send` bound is dropped.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;
33
/// A single item yielded by a multi-turn agent stream.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum MultiTurnStreamItem<R> {
    /// Content produced by the assistant (text, tool calls, reasoning, ...).
    StreamAssistantItem(StreamedAssistantContent<R>),
    /// Content produced on the user side (e.g. tool results fed back to the model).
    StreamUserItem(StreamedUserContent),
    /// The terminal item carrying the aggregated final response.
    FinalResponse(FinalResponse),
}
45
/// The aggregated result emitted as the last item of a multi-turn stream.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct FinalResponse {
    // The last textual response produced by the model.
    response: String,
    // Token usage summed over every turn of the conversation.
    aggregated_usage: crate::completion::Usage,
    // Present only when the request was started with an explicit history.
    #[serde(skip_serializing_if = "Option::is_none")]
    history: Option<Vec<Message>>,
}
54
55impl FinalResponse {
56 pub fn empty() -> Self {
57 Self {
58 response: String::new(),
59 aggregated_usage: crate::completion::Usage::new(),
60 history: None,
61 }
62 }
63
64 pub fn response(&self) -> &str {
65 &self.response
66 }
67
68 pub fn usage(&self) -> crate::completion::Usage {
69 self.aggregated_usage
70 }
71
72 pub fn history(&self) -> Option<&[Message]> {
73 self.history.as_deref()
74 }
75}
76
77impl<R> MultiTurnStreamItem<R> {
78 pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
79 Self::StreamAssistantItem(item)
80 }
81
82 pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
83 Self::FinalResponse(FinalResponse {
84 response: response.to_string(),
85 aggregated_usage,
86 history: None,
87 })
88 }
89
90 pub fn final_response_with_history(
91 response: &str,
92 aggregated_usage: crate::completion::Usage,
93 history: Option<Vec<Message>>,
94 ) -> Self {
95 Self::FinalResponse(FinalResponse {
96 response: response.to_string(),
97 aggregated_usage,
98 history,
99 })
100 }
101}
102
103fn merge_reasoning_blocks(
104 accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
105 incoming: &crate::message::Reasoning,
106) {
107 let ids_match = |existing: &crate::message::Reasoning| {
108 matches!(
109 (&existing.id, &incoming.id),
110 (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
111 )
112 };
113
114 if let Some(existing) = accumulated_reasoning
115 .iter_mut()
116 .rev()
117 .find(|existing| ids_match(existing))
118 {
119 existing.content.extend(incoming.content.clone());
120 } else {
121 accumulated_reasoning.push(incoming.clone());
122 }
123}
124
125async fn cancelled_prompt_error(chat_history: &Vec<Message>, reason: String) -> StreamingError {
126 StreamingError::Prompt(PromptError::prompt_cancelled(chat_history.to_owned(), reason).into())
127}
128
129fn tool_result_to_user_message(
130 id: String,
131 call_id: Option<String>,
132 tool_result: String,
133) -> Message {
134 let content = OneOrMany::one(ToolResultContent::text(tool_result));
135 let user_content = match call_id {
136 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
137 None => UserContent::tool_result(id, content),
138 };
139
140 Message::User {
141 content: OneOrMany::one(user_content),
142 }
143}
144
/// Errors that can surface while driving a streaming prompt to completion.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// The underlying completion provider failed.
    #[error("CompletionError: {0}")]
    Completion(#[from] CompletionError),
    /// The prompt was cancelled by a hook or failed at the prompt level.
    #[error("PromptError: {0}")]
    Prompt(#[from] Box<PromptError>),
    /// A tool invocation failed.
    #[error("ToolSetError: {0}")]
    Tool(#[from] ToolSetError),
}
154
/// Fallback display name used in tracing spans when the agent is unnamed.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
156
/// A builder-style request for streaming, multi-turn agent prompting.
///
/// Constructed via [`StreamingPromptRequest::new`] or
/// [`StreamingPromptRequest::from_agent`], configured with the `with_*` /
/// `multi_turn` methods, and consumed by awaiting it (see the `IntoFuture`
/// impl), which yields a [`StreamingResult`].
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    // The prompt message to send to the model on the first turn.
    prompt: Message,
    // Optional chat history; when present it is echoed back in the final response.
    chat_history: Option<Vec<Message>>,
    // Maximum number of model turns before the request errors out.
    max_turns: usize,

    // Completion settings copied from the agent at construction time.
    model: Arc<M>,
    agent_name: Option<String>,
    preamble: Option<String>,
    static_context: Vec<Document>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
    tool_server_handle: ToolServerHandle,
    dynamic_context: DynamicContextStore,
    tool_choice: Option<ToolChoice>,
    output_schema: Option<schemars::Schema>,
    // Optional lifecycle hook observing (and possibly cancelling) the run.
    hook: Option<P>,
}
203
204impl<M, P> StreamingPromptRequest<M, P>
205where
206 M: CompletionModel + 'static,
207 <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
208 P: PromptHook<M>,
209{
    /// Creates a request from a shared agent and a prompt, with no hook installed.
    ///
    /// All completion settings (preamble, context, temperature, tools, ...) are
    /// copied out of the agent so the request owns its configuration.
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            // Falls back to 0 turns when the agent declares no default.
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: None,
        }
    }
231
    /// Creates a request from a borrowed agent, carrying over the agent's hook.
    ///
    /// Unlike [`Self::new`], this preserves the agent's hook type `P2` and its
    /// installed hook (if any).
    pub fn from_agent<P2>(
        agent: &Agent<M, P2>,
        prompt: impl Into<Message>,
    ) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            // Falls back to 0 turns when the agent declares no default.
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: agent.hook.clone(),
        }
    }
258
259 fn agent_name(&self) -> &str {
260 self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
261 }
262
263 pub fn multi_turn(mut self, turns: usize) -> Self {
266 self.max_turns = turns;
267 self
268 }
269
270 pub fn with_history(mut self, history: Vec<Message>) -> Self {
283 self.chat_history = Some(history);
284 self
285 }
286
    /// Installs a hook that observes — and may cancel — completion calls, text
    /// deltas, tool calls and tool results.
    ///
    /// Consumes the request and returns one parameterized over the new hook
    /// type; every other field is carried over unchanged.
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            output_schema: self.output_schema,
            hook: Some(hook),
        }
    }
311
    /// Drives the multi-turn conversation and returns the resulting stream.
    ///
    /// Each loop iteration performs one model turn: build a completion request
    /// from the accumulated history, stream the model output, execute any tool
    /// calls, and feed the tool results back as the next prompt. The loop ends
    /// when a turn produces no tool calls (final answer) or when `max_turns`
    /// is exceeded.
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        // Reuse the caller's span when one is active; otherwise open a fresh
        // `invoke_agent` root span for the whole conversation.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // Clone everything the generator below needs so the stream owns its data.
        let model = self.model.clone();
        let preamble = self.preamble.clone();
        let static_context = self.static_context.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let additional_params = self.additional_params.clone();
        let tool_server_handle = self.tool_server_handle.clone();
        let dynamic_context = self.dynamic_context.clone();
        let tool_choice = self.tool_choice.clone();
        let agent_name = self.agent_name.clone();
        // Remember whether the caller supplied a history: only then is the
        // final history echoed back in the terminal stream item.
        let has_history = self.chat_history.is_some();
        let mut chat_history = self.chat_history.unwrap_or_default();

        let mut current_max_turns = 0;
        let mut last_prompt_error = String::new();

        let mut last_text_response = String::new();
        let mut is_text_response = false;
        let mut max_turns_reached = false;
        let output_schema = self.output_schema;

        // Token usage summed across all turns of the conversation.
        let mut aggregated_usage = crate::completion::Usage::new();

        let stream = async_stream::stream! {
            let mut current_prompt = prompt.clone();

            'outer: loop {
                // Stop once the turn budget is exhausted; the error itself is
                // emitted after the loop so the stream stays well-formed.
                if current_max_turns > self.max_turns + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_turns_reached = true;
                    break;
                }

                current_max_turns += 1;

                if self.max_turns > 1 {
                    tracing::info!(
                        "Current conversation Turns: {}/{}",
                        current_max_turns,
                        self.max_turns
                    );
                }

                // Give the hook a chance to veto the upcoming completion call.
                if let Some(ref hook) = self.hook {
                    let history_snapshot = chat_history.clone();
                    if let HookAction::Terminate { reason } = hook.on_completion_call(&current_prompt, &history_snapshot)
                        .await {
                        yield Err(cancelled_prompt_error(&chat_history, reason).await);
                        break 'outer;
                    }
                }

                // Per-turn span carrying request/response telemetry.
                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                    gen_ai.system_instructions = preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.usage.cached_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                // Build and start the provider stream for this turn, attached
                // to the per-turn span via `instrument` (no `enter()` inside
                // async code, which would leak the span to other tasks).
                let history_snapshot = chat_history.clone();
                let mut stream = tracing::Instrument::instrument(
                    build_completion_request(
                        &model,
                        current_prompt.clone(),
                        history_snapshot,
                        preamble.as_deref(),
                        &static_context,
                        temperature,
                        max_tokens,
                        additional_params.as_ref(),
                        tool_choice.as_ref(),
                        &tool_server_handle,
                        &dynamic_context,
                        output_schema.as_ref(),
                    )
                    .await?
                    .stream(), chat_stream_span
                )

                .await?;

                chat_history.push(current_prompt.clone());

                // Per-turn accumulators.
                let mut tool_calls = vec![];
                let mut tool_results = vec![];
                let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
                // Reasoning deltas are buffered and only assembled into a block
                // after the turn, if no complete reasoning block arrived.
                let mut pending_reasoning_delta_text = String::new();
                let mut pending_reasoning_delta_id: Option<String> = None;
                let mut saw_tool_call_this_turn = false;

                while let Some(content) = stream.next().await {
                    match content {
                        Ok(StreamedAssistantContent::Text(text)) => {
                            // A fresh text response resets the accumulated text.
                            if !is_text_response {
                                last_text_response = String::new();
                                is_text_response = true;
                            }
                            last_text_response.push_str(&text.text);
                            if let Some(ref hook) = self.hook &&
                                let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &last_text_response).await {
                                yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                break 'outer;
                            }

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
                        },
                        Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
                            // One span per tool execution.
                            let tool_span = info_span!(
                                parent: tracing::Span::current(),
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));

                            // Run hooks + the tool itself; `Err` means a hook
                            // terminated the whole prompt.
                            let tc_result = async {
                                let tool_span = tracing::Span::current();
                                let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
                                if let Some(ref hook) = self.hook {
                                    let action = hook
                                        .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
                                        .await;

                                    if let ToolCallHookAction::Terminate { reason } = action {
                                        return Err(cancelled_prompt_error(&chat_history, reason).await);
                                    }

                                    // A skipped call still records the rejection
                                    // reason as the tool result, so the model
                                    // sees why the call did not run.
                                    if let ToolCallHookAction::Skip { reason } = action {
                                        tracing::info!(
                                            tool_name = tool_call.function.name.as_str(),
                                            reason = reason,
                                            "Tool call rejected"
                                        );
                                        let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
                                        tool_calls.push(tool_call_msg);
                                        tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
                                        saw_tool_call_this_turn = true;
                                        return Ok(reason);
                                    }
                                }

                                tool_span.record("gen_ai.tool.name", &tool_call.function.name);
                                tool_span.record("gen_ai.tool.call.arguments", &tool_args);

                                // Tool errors are folded into the result string
                                // rather than aborting the conversation.
                                let tool_result = match
                                    tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
                                    Ok(thing) => thing,
                                    Err(e) => {
                                        tracing::warn!("Error while calling tool: {e}");
                                        e.to_string()
                                    }
                                };

                                tool_span.record("gen_ai.tool.call.result", &tool_result);

                                if let Some(ref hook) = self.hook &&
                                    let HookAction::Terminate { reason } =
                                        hook.on_tool_result(
                                            &tool_call.function.name,
                                            tool_call.call_id.clone(),
                                            &internal_call_id,
                                            &tool_args,
                                            &tool_result.to_string()
                                        )
                                        .await {
                                    return Err(cancelled_prompt_error(&chat_history, reason).await);
                                }

                                let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());

                                tool_calls.push(tool_call_msg);
                                tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));

                                saw_tool_call_this_turn = true;
                                Ok(tool_result)
                            }.instrument(tool_span).await;

                            match tc_result {
                                Ok(text) => {
                                    let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
                                    yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
                                }
                                Err(e) => {
                                    yield Err(e);
                                    break 'outer;
                                }
                            }
                        },
                        Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
                            // Deltas are only surfaced to the hook; the complete
                            // tool call arrives as its own item.
                            if let Some(ref hook) = self.hook {
                                let (name, delta) = match &content {
                                    rig::streaming::ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""),
                                    rig::streaming::ToolCallDeltaContent::Delta(d) => (None, d.as_str()),
                                };

                                if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
                                    .await {
                                    yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                    break 'outer;
                                }
                            }
                        }
                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                            // Merge same-id reasoning blocks so multi-part
                            // reasoning is stored as one block in history.
                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            pending_reasoning_delta_text.push_str(&reasoning);
                            // First delta's id wins for the assembled block.
                            if pending_reasoning_delta_id.is_none() {
                                pending_reasoning_delta_id = id.clone();
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
                            // Only forward the provider's final marker when this
                            // turn produced text (i.e. it may be the answer).
                            if is_text_response {
                                if let Some(ref hook) = self.hook &&
                                    let HookAction::Terminate { reason } = hook.on_stream_completion_response_finish(&prompt, &final_resp).await {
                                    yield Err(cancelled_prompt_error(&chat_history, reason).await);
                                    break 'outer;
                                }

                                tracing::Span::current().record("gen_ai.completion", &last_text_response);
                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                is_text_response = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                // If only deltas arrived, assemble them into one reasoning block.
                if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
                    let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
                    if let Some(id) = pending_reasoning_delta_id.take() {
                        assembled = assembled.with_id(id);
                    }
                    accumulated_reasoning.push(assembled);
                }

                // Record the assistant message (reasoning first, then tool calls).
                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
                    let mut content_items: Vec<rig::message::AssistantContent> = vec![];

                    for reasoning in accumulated_reasoning.drain(..) {
                        content_items.push(rig::message::AssistantContent::Reasoning(reasoning));
                    }

                    content_items.extend(tool_calls.clone());

                    if !content_items.is_empty() {
                        chat_history.push(Message::Assistant {
                            id: stream.message_id.clone(),
                            content: OneOrMany::many(content_items).expect("Should have at least one item"),
                        });
                    }
                }

                // Feed each tool result back as a user-side message.
                for (id, call_id, tool_result) in tool_results {
                    chat_history.push(tool_result_to_user_message(id, call_id, tool_result));
                }

                // The most recent message becomes the next turn's prompt.
                current_prompt = match chat_history.pop() {
                    Some(prompt) => prompt,
                    None => unreachable!("Chat history should never be empty at this point"),
                };

                // No tool calls this turn: the model has produced its final
                // answer, so emit the terminal item and stop.
                if !saw_tool_call_this_turn {
                    chat_history.push(current_prompt.clone());
                    if !last_text_response.is_empty() {
                        chat_history.push(Message::assistant(&last_text_response));
                    }

                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    current_span.record("gen_ai.usage.cached_tokens", aggregated_usage.cached_input_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    // History is only echoed back when the caller supplied one.
                    let history_snapshot = if has_history {
                        Some(chat_history.clone())
                    } else {
                        None
                    };
                    yield Ok(MultiTurnStreamItem::final_response_with_history(
                        &last_text_response,
                        aggregated_usage,
                        history_snapshot,
                    ));
                    break;
                }
            }

            if max_turns_reached {
                yield Err(Box::new(PromptError::MaxTurnsError {
                    max_turns: self.max_turns,
                    chat_history: Box::new(chat_history.clone()),
                    prompt: Box::new(last_prompt_error.clone().into()),
                }).into());
            }
        };

        // Instrument the whole generator with the agent span instead of
        // entering it, so unrelated tasks never observe this span.
        Box::pin(stream.instrument(agent_span))
    }
678}
679
680impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
681where
682 M: CompletionModel + 'static,
683 <M as CompletionModel>::StreamingResponse: WasmCompatSend,
684 P: PromptHook<M> + 'static,
685{
686 type Output = StreamingResult<M::StreamingResponse>; type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
688
689 fn into_future(self) -> Self::IntoFuture {
690 Box::pin(async move { self.send().await })
692 }
693}
694
695pub async fn stream_to_stdout<R>(
697 stream: &mut StreamingResult<R>,
698) -> Result<FinalResponse, std::io::Error> {
699 let mut final_res = FinalResponse::empty();
700 print!("Response: ");
701 while let Some(content) = stream.next().await {
702 match content {
703 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
704 Text { text },
705 ))) => {
706 print!("{text}");
707 std::io::Write::flush(&mut std::io::stdout()).unwrap();
708 }
709 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
710 reasoning,
711 ))) => {
712 let reasoning = reasoning.display_text();
713 print!("{reasoning}");
714 std::io::Write::flush(&mut std::io::stdout()).unwrap();
715 }
716 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
717 final_res = res;
718 }
719 Err(err) => {
720 eprintln!("Error: {err}");
721 }
722 _ => {}
723 }
724 }
725
726 Ok(final_res)
727}
728
#[cfg(test)]
mod tests {
    //! Unit tests for reasoning-block merging plus integration-style tests for
    //! the multi-turn streaming loop (some require a live Anthropic API key
    //! and are `#[ignore]`d).
    use super::*;
    use crate::agent::AgentBuilder;
    use crate::client::ProviderClient;
    use crate::client::completion::CompletionClient;
    use crate::completion::{
        CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
    };
    use crate::message::ReasoningContent;
    use crate::providers::anthropic;
    use crate::streaming::StreamingPrompt;
    use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
    use std::time::Duration;

    // Same-id blocks merge in arrival order, keeping per-part signatures.
    #[test]
    fn merge_reasoning_blocks_preserves_order_and_signatures() {
        let mut accumulated = Vec::new();
        let first = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: Some("sig-1".to_string()),
            }],
        };
        let second = crate::message::Reasoning {
            id: Some("rs_1".to_string()),
            content: vec![
                ReasoningContent::Text {
                    text: "step-2".to_string(),
                    signature: Some("sig-2".to_string()),
                },
                ReasoningContent::Summary("summary".to_string()),
            ],
        };

        merge_reasoning_blocks(&mut accumulated, &first);
        merge_reasoning_blocks(&mut accumulated, &second);

        assert_eq!(accumulated.len(), 1);
        let merged = accumulated.first().expect("expected accumulated reasoning");
        assert_eq!(merged.id.as_deref(), Some("rs_1"));
        assert_eq!(merged.content.len(), 3);
        assert!(matches!(
            merged.content.first(),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-1" && sig == "sig-1"
        ));
        assert!(matches!(
            merged.content.get(1),
            Some(ReasoningContent::Text { text, signature: Some(sig) })
                if text == "step-2" && sig == "sig-2"
        ));
    }

    // Blocks with different ids must never be merged together.
    #[test]
    fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: Some("rs_a".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-1".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: Some("rs_b".to_string()),
            content: vec![ReasoningContent::Text {
                text: "step-2".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert_eq!(
            accumulated.first().and_then(|r| r.id.as_deref()),
            Some("rs_a")
        );
        assert_eq!(
            accumulated.get(1).and_then(|r| r.id.as_deref()),
            Some("rs_b")
        );
    }

    // Two `None`-id blocks are not considered a match: no merging.
    #[test]
    fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
        let mut accumulated = vec![crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "first".to_string(),
                signature: None,
            }],
        }];
        let incoming = crate::message::Reasoning {
            id: None,
            content: vec![ReasoningContent::Text {
                text: "second".to_string(),
                signature: None,
            }],
        };

        merge_reasoning_blocks(&mut accumulated, &incoming);
        assert_eq!(accumulated.len(), 2);
        assert!(matches!(
            accumulated.first(),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "first"
            )
        ));
        assert!(matches!(
            accumulated.get(1),
            Some(crate::message::Reasoning {
                id: None,
                content
            }) if matches!(
                content.first(),
                Some(ReasoningContent::Text { text, .. }) if text == "second"
            )
        ));
    }

    // Minimal streaming-response payload carrying only token usage.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct MockStreamingResponse {
        usage: crate::completion::Usage,
    }

    impl MockStreamingResponse {
        // Builds a response reporting `total_tokens` used.
        fn new(total_tokens: u64) -> Self {
            let mut usage = crate::completion::Usage::new();
            usage.total_tokens = total_tokens;
            Self { usage }
        }
    }

    impl crate::completion::GetTokenUsage for MockStreamingResponse {
        fn token_usage(&self) -> Option<crate::completion::Usage> {
            Some(self.usage)
        }
    }

    // Mock model: first turn emits a tool call, subsequent turns emit "done".
    #[derive(Clone, Default)]
    struct MultiTurnMockModel {
        // Counts how many turns (stream calls) have been made.
        turn_counter: Arc<AtomicUsize>,
    }

    #[allow(refining_impl_trait)]
    impl CompletionModel for MultiTurnMockModel {
        type Response = ();
        type StreamingResponse = MockStreamingResponse;
        type Client = ();

        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self::default()
        }

        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "completion is unused in this streaming test".to_string(),
            ))
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
            let stream = async_stream::stream! {
                if turn == 0 {
                    // Turn 1: a tool call (the tool does not exist, so the
                    // error text becomes the tool result).
                    yield Ok(RawStreamingChoice::ToolCall(
                        RawStreamingToolCall::new(
                            "tool_call_1".to_string(),
                            "missing_tool".to_string(),
                            serde_json::json!({"input": "value"}),
                        )
                        .with_call_id("call_1".to_string()),
                    ));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(4)));
                } else {
                    // Later turns: a plain text answer ends the conversation.
                    yield Ok(RawStreamingChoice::Message("done".to_string()));
                    yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(6)));
                }
            };

            let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
                Box::pin(stream);
            Ok(StreamingCompletionResponse::stream(pinned_stream))
        }
    }

    // After a tool-call turn, the loop must run a second turn and finish.
    #[tokio::test]
    async fn stream_prompt_continues_after_tool_call_turn() {
        let model = MultiTurnMockModel::default();
        let turn_counter = model.turn_counter.clone();
        let agent = AgentBuilder::new(model).build();

        let mut stream = agent.stream_prompt("do tool work").multi_turn(3).await;
        let mut saw_tool_call = false;
        let mut saw_tool_result = false;
        let mut saw_final_response = false;
        let mut final_text = String::new();

        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(
                    StreamedAssistantContent::ToolCall { .. },
                )) => {
                    saw_tool_call = true;
                }
                Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
                    ..
                })) => {
                    saw_tool_result = true;
                }
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    final_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    saw_final_response = true;
                    break;
                }
                Ok(_) => {}
                Err(err) => panic!("unexpected streaming error: {err:?}"),
            }
        }

        assert!(saw_tool_call);
        assert!(saw_tool_result);
        assert!(saw_final_response);
        assert_eq!(final_text, "done");
        // Exactly two provider turns: tool call, then the final answer.
        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
    }

    // Ticks in the background and counts every tick that observes an active
    // span — i.e. a span leaked from the agent task to this unrelated task.
    async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
        let mut interval = tokio::time::interval(Duration::from_millis(50));
        let mut count = 0u32;

        while !stop.load(Ordering::Relaxed) {
            interval.tick().await;
            count += 1;

            tracing::event!(
                target: "background_logger",
                tracing::Level::INFO,
                count = count,
                "Background tick"
            );

            // Any active span here would have leaked from another task.
            let current = tracing::Span::current();
            if !current.is_disabled() && !current.is_none() {
                leak_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
    }

    // Verifies the agent stream never leaks its spans into unrelated tasks
    // (guards against `span.enter()` being used inside async code).
    #[tokio::test(flavor = "current_thread")]
    #[ignore = "This requires an API key"]
    async fn test_span_context_isolation() {
        let stop = Arc::new(AtomicBool::new(false));
        let leak_count = Arc::new(AtomicU32::new(0));

        let bg_stop = stop.clone();
        let bg_leak = leak_count.clone();
        let bg_handle = tokio::spawn(async move {
            background_logger(bg_stop, bg_leak).await;
        });

        // Let the background logger start ticking first.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant.")
            .temperature(0.1)
            .max_tokens(100)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello world' and nothing else.")
            .await;

        let mut full_content = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    full_content.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(_)) => {
                    break;
                }
                Err(e) => {
                    tracing::warn!("Error: {:?}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("Got response: {:?}", full_content);

        stop.store(true, Ordering::Relaxed);
        bg_handle.await.unwrap();

        let leaks = leak_count.load(Ordering::Relaxed);
        assert_eq!(
            leaks, 0,
            "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
             This indicates that span.enter() is being used inside async_stream instead of .instrument()"
        );
    }

    // `with_history` must cause the final response to carry the updated
    // history, including both the user prompt and the assistant reply.
    #[tokio::test]
    #[ignore = "This requires an API key"]
    async fn test_chat_history_in_final_response() {
        use crate::message::Message;

        let client = anthropic::Client::from_env();
        let agent = client
            .agent(anthropic::completion::CLAUDE_3_5_HAIKU)
            .preamble("You are a helpful assistant. Keep responses brief.")
            .temperature(0.1)
            .max_tokens(50)
            .build();

        let mut stream = agent
            .stream_prompt("Say 'hello' and nothing else.")
            .with_history(vec![])
            .await;

        let mut response_text = String::new();
        let mut final_history = None;
        while let Some(item) = stream.next().await {
            match item {
                Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                    text,
                ))) => {
                    response_text.push_str(&text.text);
                }
                Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                    final_history = res.history().map(|h| h.to_vec());
                    break;
                }
                Err(e) => {
                    panic!("Streaming error: {:?}", e);
                }
                _ => {}
            }
        }

        let history =
            final_history.expect("FinalResponse should contain history when with_history is used");

        assert!(
            history.iter().any(|m| matches!(m, Message::User { .. })),
            "History should contain the user message"
        );

        assert!(
            history
                .iter()
                .any(|m| matches!(m, Message::Assistant { .. })),
            "History should contain the assistant response"
        );

        tracing::info!(
            "History after streaming: {} messages, response: {:?}",
            history.len(),
            response_text
        );
    }
}