use crate::{
    OneOrMany,
    agent::completion::{DynamicContextStore, build_completion_request},
    agent::prompt_request::{HookAction, hooks::PromptHook},
    completion::{Document, GetTokenUsage},
    json_utils,
    memory::ConversationMemory,
    message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
    streaming::{StreamedAssistantContent, StreamedUserContent},
    tool::server::ToolServerHandle,
    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::{pin::Pin, sync::Arc};
use tracing::info_span;
use tracing_futures::Instrument;

use super::ToolCallHookAction;
use crate::{
    agent::Agent,
    completion::{CompletionError, CompletionModel, PromptError},
    message::{Message, Text},
    tool::ToolSetError,
};

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
    Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

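/// A single item yielded by a multi-turn prompt stream: streamed assistant
/// content (text, tool calls, reasoning), streamed user content (tool
/// results), or the final aggregated response.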
35#[derive(Deserialize, Serialize, Debug, Clone)]
36#[serde(tag = "type", rename_all = "camelCase")]
37#[non_exhaustive]
38pub enum MultiTurnStreamItem<R> {
39 StreamAssistantItem(StreamedAssistantContent<R>),
41 StreamUserItem(StreamedUserContent),
43 FinalResponse(FinalResponse),
45}
46
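/// The last item yielded by a multi-turn prompt stream: the aggregated
/// response text, the total token usage across all turns, and, when history
/// tracking applies, the full message history produced by the run.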
47#[derive(Deserialize, Serialize, Debug, Clone)]
48#[serde(rename_all = "camelCase")]
49pub struct FinalResponse {
50 response: String,
53 aggregated_usage: crate::completion::Usage,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 history: Option<Vec<Message>>,
56}
57
58impl FinalResponse {
59 pub fn empty() -> Self {
60 Self {
61 response: String::new(),
62 aggregated_usage: crate::completion::Usage::new(),
63 history: None,
64 }
65 }
66
67 pub fn response(&self) -> &str {
69 &self.response
70 }
71
72 pub fn usage(&self) -> crate::completion::Usage {
73 self.aggregated_usage
74 }
75
76 pub fn history(&self) -> Option<&[Message]> {
77 self.history.as_deref()
78 }
79}
80
81impl<R> MultiTurnStreamItem<R> {
82 pub(crate) fn stream_item(item: StreamedAssistantContent<R>) -> Self {
83 Self::StreamAssistantItem(item)
84 }
85
86 pub fn final_response(response: &str, aggregated_usage: crate::completion::Usage) -> Self {
87 Self::FinalResponse(FinalResponse {
88 response: response.to_string(),
89 aggregated_usage,
90 history: None,
91 })
92 }
93
94 pub fn final_response_with_history(
95 response: &str,
96 aggregated_usage: crate::completion::Usage,
97 history: Option<Vec<Message>>,
98 ) -> Self {
99 Self::FinalResponse(FinalResponse {
100 response: response.to_string(),
101 aggregated_usage,
102 history,
103 })
104 }
105}
106
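/// Folds an incoming reasoning block into the accumulated list. A block whose
/// ID matches an existing entry is appended to that entry's content; blocks
/// with no ID or an unseen ID are pushed as new entries.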
fn merge_reasoning_blocks(
    accumulated_reasoning: &mut Vec<crate::message::Reasoning>,
    incoming: &crate::message::Reasoning,
) {
    let ids_match = |existing: &crate::message::Reasoning| {
        matches!(
            (&existing.id, &incoming.id),
            (Some(existing_id), Some(incoming_id)) if existing_id == incoming_id
        )
    };

    if let Some(existing) = accumulated_reasoning
        .iter_mut()
        .rev()
        .find(|existing| ids_match(existing))
    {
        existing.content.extend(incoming.content.clone());
    } else {
        accumulated_reasoning.push(incoming.clone());
    }
}

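/// Concatenates the caller-supplied chat history (if any) with the messages
/// produced so far in this request, consuming the new messages.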
fn build_full_history(
    chat_history: Option<&[Message]>,
    new_messages: Vec<Message>,
) -> Vec<Message> {
    let input = chat_history.unwrap_or(&[]);
    input.iter().cloned().chain(new_messages).collect()
}

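/// Like [`build_full_history`], but borrows the new messages; used to snapshot
/// the history sent with each completion request.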
fn build_history_for_request(
    chat_history: Option<&[Message]>,
    new_messages: &[Message],
) -> Vec<Message> {
    let input = chat_history.unwrap_or(&[]);
    input.iter().chain(new_messages.iter()).cloned().collect()
}

async fn cancelled_prompt_error(
    chat_history: Option<&[Message]>,
    new_messages: Vec<Message>,
    reason: String,
) -> StreamingError {
    StreamingError::Prompt(
        PromptError::prompt_cancelled(build_full_history(chat_history, new_messages), reason)
            .into(),
    )
}

fn tool_result_to_user_message(
    id: String,
    call_id: Option<String>,
    tool_result: String,
) -> Message {
    let content = ToolResultContent::from_tool_output(tool_result);
    let user_content = match call_id {
        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
        None => UserContent::tool_result(id, content),
    };

    Message::User {
        content: OneOrMany::one(user_content),
    }
}

fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
    choice
        .iter()
        .filter_map(|content| match content {
            AssistantContent::Text(text) => Some(text.text.as_str()),
            _ => None,
        })
        .collect()
}

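/// Errors that can surface while driving a multi-turn prompt stream.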
184#[derive(Debug, thiserror::Error)]
185pub enum StreamingError {
186 #[error("CompletionError: {0}")]
187 Completion(#[from] CompletionError),
188 #[error("PromptError: {0}")]
189 Prompt(#[from] Box<PromptError>),
190 #[error("ToolSetError: {0}")]
191 Tool(#[from] ToolSetError),
192}
193
194impl From<crate::memory::MemoryError> for StreamingError {
198 fn from(err: crate::memory::MemoryError) -> Self {
199 Self::Completion(CompletionError::RequestError(Box::new(err)))
200 }
201}
202
203const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
204
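/// A builder for multi-turn streaming prompt requests.
///
/// Carries everything needed to stream a completion from an [`Agent`]: the
/// prompt, optional chat history, model settings, tool access, an optional
/// [`PromptHook`], and optional conversation memory.
///
/// # Example
///
/// A minimal sketch of the intended flow (assumes an `agent` built elsewhere
/// and a Tokio runtime):
///
/// ```ignore
/// let mut stream = agent
///     .stream_prompt("What is 2 + 2?")
///     .multi_turn(3)
///     .await;
///
/// while let Some(item) = stream.next().await {
///     if let Ok(MultiTurnStreamItem::FinalResponse(res)) = item {
///         println!("{}", res.response());
///     }
/// }
/// ```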
pub struct StreamingPromptRequest<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    prompt: Message,
    chat_history: Option<Vec<Message>>,
    max_turns: usize,

    model: Arc<M>,
    agent_name: Option<String>,
    preamble: Option<String>,
    static_context: Vec<Document>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<serde_json::Value>,
    tool_server_handle: ToolServerHandle,
    dynamic_context: DynamicContextStore,
    tool_choice: Option<ToolChoice>,
    output_schema: Option<schemars::Schema>,
    hook: Option<P>,
    memory: Option<Arc<dyn ConversationMemory>>,
    conversation_id: Option<String>,
}

impl<M, P> StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: PromptHook<M>,
{
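    /// Creates a new streaming request for `agent` with no hook attached,
    /// inheriting the agent's defaults (max turns, preamble, tools, memory).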
    pub fn new(agent: Arc<Agent<M>>, prompt: impl Into<Message>) -> StreamingPromptRequest<M, ()> {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: None,
            memory: agent.memory.clone(),
            conversation_id: agent.default_conversation_id.clone(),
        }
    }

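    /// Creates a streaming request from an agent reference, inheriting the
    /// agent's configured hook along with its other defaults.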
    pub fn from_agent<P2>(
        agent: &Agent<M, P2>,
        prompt: impl Into<Message>,
    ) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            output_schema: agent.output_schema.clone(),
            hook: agent.hook.clone(),
            memory: agent.memory.clone(),
            conversation_id: agent.default_conversation_id.clone(),
        }
    }

    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

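    /// Sets the maximum number of turns (model round-trips) allowed before the
    /// stream fails with [`PromptError::MaxTurnsError`].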
    pub fn multi_turn(mut self, turns: usize) -> Self {
        self.max_turns = turns;
        self
    }

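    /// Supplies the chat history for this request. Providing a history (even
    /// an empty one) bypasses conversation memory entirely: nothing is loaded
    /// from or appended to memory.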
    pub fn with_history<I, T>(mut self, history: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        self.chat_history = Some(history.into_iter().map(Into::into).collect());
        self
    }

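    /// Attaches a [`PromptHook`] that can observe, and optionally cancel or
    /// skip, completion calls, text deltas, tool calls, and tool results.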
    pub fn with_hook<P2>(self, hook: P2) -> StreamingPromptRequest<M, P2>
    where
        P2: PromptHook<M>,
    {
        StreamingPromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            output_schema: self.output_schema,
            hook: Some(hook),
            memory: self.memory,
            conversation_id: self.conversation_id,
        }
    }

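    /// Sets the conversation ID used to load history from, and append results
    /// to, the agent's conversation memory.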
    pub fn conversation(mut self, id: impl Into<String>) -> Self {
        self.conversation_id = Some(id.into());
        self
    }

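    /// Disables conversation memory for this request only: no history is
    /// loaded and nothing is appended.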
    pub fn without_memory(mut self) -> Self {
        self.memory = None;
        self.conversation_id = None;
        self
    }

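    /// Drives the multi-turn loop: repeatedly streams a completion, executes
    /// any tool calls, feeds the results back to the model, and finally yields
    /// a [`FinalResponse`] (or an error) on the returned stream.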
    async fn send(self) -> StreamingResult<M::StreamingResponse> {
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let prompt = self.prompt;
        if let Some(text) = prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let model = self.model.clone();
        let preamble = self.preamble.clone();
        let static_context = self.static_context.clone();
        let temperature = self.temperature;
        let max_tokens = self.max_tokens;
        let additional_params = self.additional_params.clone();
        let tool_server_handle = self.tool_server_handle.clone();
        let dynamic_context = self.dynamic_context.clone();
        let tool_choice = self.tool_choice.clone();
        let agent_name = self.agent_name.clone();
        let (chat_history, memory_handle) = match self.chat_history {
            Some(history) => (Some(history), None),
            None => match (self.memory, self.conversation_id) {
                (Some(memory), Some(id)) => match memory.load(&id).await {
                    Ok(loaded) => (Some(loaded), Some((memory, id))),
                    Err(err) => {
                        let stream = async_stream::stream! {
                            yield Err(StreamingError::from(err));
                        };
                        return Box::pin(stream);
                    }
                },
                _ => (None, None),
            },
        };
        let has_history = chat_history.is_some();
        let mut new_messages: Vec<Message> = vec![prompt.clone()];

        let mut current_max_turns = 0;
        let mut last_prompt_error = String::new();

        let mut text_delta_response = String::new();
        let mut saw_text_this_turn = false;
        let mut max_turns_reached = false;
        let output_schema = self.output_schema;

        let mut aggregated_usage = crate::completion::Usage::new();

        let stream = async_stream::stream! {
            'outer: loop {
                let Some((current_prompt_ref, previous_messages)) = new_messages.split_last() else {
                    yield Err(cancelled_prompt_error(
                        chat_history.as_deref(),
                        new_messages.clone(),
                        "streaming loop lost its pending prompt".to_string(),
                    ).await);
                    break 'outer;
                };
                let current_prompt = current_prompt_ref.clone();

                if current_max_turns > self.max_turns + 1 {
                    last_prompt_error = current_prompt.rag_text().unwrap_or_default();
                    max_turns_reached = true;
                    break;
                }

                current_max_turns += 1;

                if self.max_turns > 1 {
                    tracing::info!(
                        "Current conversation turns: {}/{}",
                        current_max_turns,
                        self.max_turns
                    );
                }

                let history_snapshot: Vec<Message> = build_history_for_request(
                    chat_history.as_deref(),
                    previous_messages,
                );

                if let Some(ref hook) = self.hook
                    && let HookAction::Terminate { reason } =
                        hook.on_completion_call(&current_prompt, &history_snapshot).await
                {
                    yield Err(
                        cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason)
                            .await,
                    );
                    break 'outer;
                }

                let chat_stream_span = info_span!(
                    target: "rig::agent_chat",
                    parent: tracing::Span::current(),
                    "chat_streaming",
                    gen_ai.operation.name = "chat",
                    gen_ai.agent.name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                    gen_ai.system_instructions = preamble,
                    gen_ai.provider.name = tracing::field::Empty,
                    gen_ai.request.model = tracing::field::Empty,
                    gen_ai.response.id = tracing::field::Empty,
                    gen_ai.response.model = tracing::field::Empty,
                    gen_ai.usage.output_tokens = tracing::field::Empty,
                    gen_ai.usage.input_tokens = tracing::field::Empty,
                    gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                    gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                    gen_ai.usage.reasoning_tokens = tracing::field::Empty,
                    gen_ai.input.messages = tracing::field::Empty,
                    gen_ai.output.messages = tracing::field::Empty,
                );

                let mut stream = tracing::Instrument::instrument(
                    build_completion_request(
                        &model,
                        current_prompt.clone(),
                        &history_snapshot,
                        preamble.as_deref(),
                        &static_context,
                        temperature,
                        max_tokens,
                        additional_params.as_ref(),
                        tool_choice.as_ref(),
                        &tool_server_handle,
                        &dynamic_context,
                        output_schema.as_ref(),
                    )
                    .await?
                    .stream(),
                    chat_stream_span,
                )
                .await?;
549
550 let mut tool_calls = vec![];
551 let mut tool_results = vec![];
552 let mut accumulated_reasoning: Vec<rig::message::Reasoning> = vec![];
553 let mut pending_reasoning_delta_text = String::new();
556 let mut pending_reasoning_delta_id: Option<String> = None;
557 let mut saw_tool_call_this_turn = false;
558
559 while let Some(content) = stream.next().await {
560 match content {
561 Ok(StreamedAssistantContent::Text(text)) => {
562 if !saw_text_this_turn {
563 text_delta_response.clear();
564 saw_text_this_turn = true;
565 }
566 text_delta_response.push_str(&text.text);
567 if let Some(ref hook) = self.hook &&
568 let HookAction::Terminate { reason } = hook.on_text_delta(&text.text, &text_delta_response).await {
569 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
570 break 'outer;
571 }
572
573 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text)));
574 },
575 Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => {
576 let tool_span = info_span!(
577 parent: tracing::Span::current(),
578 "execute_tool",
579 gen_ai.operation.name = "execute_tool",
580 gen_ai.tool.type = "function",
581 gen_ai.tool.name = tracing::field::Empty,
582 gen_ai.tool.call.id = tracing::field::Empty,
583 gen_ai.tool.call.arguments = tracing::field::Empty,
584 gen_ai.tool.call.result = tracing::field::Empty
585 );
586
587 yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() }));
588
589 let tc_result = async {
590 let tool_span = tracing::Span::current();
591 let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments);
592 if let Some(ref hook) = self.hook {
593 let action = hook
594 .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args)
595 .await;
596
597 if let ToolCallHookAction::Terminate { reason } = action {
598 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
599 }
600
601 if let ToolCallHookAction::Skip { reason } = action {
602 tracing::info!(
604 tool_name = tool_call.function.name.as_str(),
605 reason = reason,
606 "Tool call rejected"
607 );
608 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
609 tool_calls.push(tool_call_msg);
610 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone()));
611 saw_tool_call_this_turn = true;
612 return Ok(reason);
613 }
614 }
615
616 tool_span.record("gen_ai.tool.name", &tool_call.function.name);
617 tool_span.record("gen_ai.tool.call.arguments", &tool_args);
618
619 let tool_result = match
620 tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
621 Ok(thing) => thing,
622 Err(e) => {
623 tracing::warn!("Error while calling tool: {e}");
624 e.to_string()
625 }
626 };
627
628 tool_span.record("gen_ai.tool.call.result", &tool_result);
629
630 if let Some(ref hook) = self.hook &&
631 let HookAction::Terminate { reason } =
632 hook.on_tool_result(
633 &tool_call.function.name,
634 tool_call.call_id.clone(),
635 &internal_call_id,
636 &tool_args,
637 &tool_result.to_string()
638 )
639 .await {
640 return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
641 }
642
643 let tool_call_msg = AssistantContent::ToolCall(tool_call.clone());
644
645 tool_calls.push(tool_call_msg);
646 tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone()));
647
648 saw_tool_call_this_turn = true;
649 Ok(tool_result)
650 }.instrument(tool_span).await;
651
652 match tc_result {
653 Ok(text) => {
654 let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) };
655 yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id }));
656 }
657 Err(e) => {
658 yield Err(e);
659 break 'outer;
660 }
661 }
662 },
663 Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => {
664 if let Some(ref hook) = self.hook {
665 let (name, delta) = match &content {
666 rig::streaming::ToolCallDeltaContent::Name(n) => {
667 (Some(n.as_str()), "")
668 }
669 rig::streaming::ToolCallDeltaContent::Delta(d) => {
670 (None, d.as_str())
671 }
672 };
673
674 if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta)
675 .await {
676 yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
677 break 'outer;
678 }
679 }
680 }
                        Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
                            merge_reasoning_blocks(&mut accumulated_reasoning, &reasoning);
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(reasoning)));
                        },
                        Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
                            pending_reasoning_delta_text.push_str(&reasoning);
                            if pending_reasoning_delta_id.is_none() {
                                pending_reasoning_delta_id = id.clone();
                            }
                            yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
                        },
                        Ok(StreamedAssistantContent::Final(final_resp)) => {
                            if let Some(usage) = final_resp.token_usage() {
                                aggregated_usage += usage;
                            }
                            if saw_text_this_turn {
                                if let Some(ref hook) = self.hook
                                    && let HookAction::Terminate { reason } =
                                        hook.on_stream_completion_response_finish(&current_prompt, &final_resp).await
                                {
                                    yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await);
                                    break 'outer;
                                }

                                yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp)));
                                saw_text_this_turn = false;
                            }
                        }
                        Err(e) => {
                            yield Err(e.into());
                            break 'outer;
                        }
                    }
                }

                if accumulated_reasoning.is_empty() && !pending_reasoning_delta_text.is_empty() {
                    let mut assembled = crate::message::Reasoning::new(&pending_reasoning_delta_text);
                    if let Some(id) = pending_reasoning_delta_id.take() {
                        assembled = assembled.with_id(id);
                    }
                    accumulated_reasoning.push(assembled);
                }

                let turn_text_response = assistant_text_from_choice(&stream.choice);
                tracing::Span::current().record("gen_ai.completion", &turn_text_response);

                if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() {
                    let mut content_items: Vec<crate::message::AssistantContent> = vec![];

                    if !turn_text_response.is_empty() {
                        content_items.push(crate::message::AssistantContent::text(&turn_text_response));
                    }

                    for reasoning in accumulated_reasoning.drain(..) {
                        content_items.push(crate::message::AssistantContent::Reasoning(reasoning));
                    }

                    content_items.extend(tool_calls.clone());

                    if let Some(content) = OneOrMany::from_iter_optional(content_items) {
                        new_messages.push(Message::Assistant {
                            id: stream.message_id.clone(),
                            content,
                        });
                    }
                }

                for (id, call_id, tool_result) in tool_results {
                    new_messages.push(tool_result_to_user_message(id, call_id, tool_result));
                }

                if !saw_tool_call_this_turn {
                    if !turn_text_response.is_empty() {
                        new_messages.push(Message::assistant(&turn_text_response));
                    } else {
                        tracing::warn!(
                            agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                            message_id = ?stream.message_id,
                            "Streaming turn completed without assistant text; final response will be empty"
                        );
                    }

                    let current_span = tracing::Span::current();
                    current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens);
                    current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens);
                    current_span.record("gen_ai.usage.cache_read.input_tokens", aggregated_usage.cached_input_tokens);
                    current_span.record("gen_ai.usage.cache_creation.input_tokens", aggregated_usage.cache_creation_input_tokens);
                    current_span.record("gen_ai.usage.reasoning_tokens", aggregated_usage.reasoning_tokens);
                    tracing::info!("Agent multi-turn stream finished");
                    if let Some((memory, id)) = memory_handle.as_ref()
                        && let Err(err) = memory.append(id, new_messages.clone()).await
                    {
                        tracing::warn!(
                            error = %err,
                            conversation_id = %id,
                            "conversation memory append failed; yielding final response anyway"
                        );
                    }
                    let final_messages: Option<Vec<Message>> = if has_history {
                        Some(new_messages.clone())
                    } else {
                        None
                    };
                    yield Ok(MultiTurnStreamItem::final_response_with_history(
                        &turn_text_response,
                        aggregated_usage,
                        final_messages,
                    ));
                    break;
                }
            }

            if max_turns_reached {
                yield Err(Box::new(PromptError::MaxTurnsError {
                    max_turns: self.max_turns,
                    chat_history: build_full_history(chat_history.as_deref(), new_messages.clone()).into(),
                    prompt: Box::new(last_prompt_error.clone().into()),
                })
                .into());
            }
        };

        Box::pin(stream.instrument(agent_span))
    }
}

impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
    M: CompletionModel + 'static,
    <M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Output = StreamingResult<M::StreamingResponse>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(async move { self.send().await })
    }
}

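/// Drains a [`StreamingResult`], printing text and reasoning chunks to stdout
/// as they arrive, and returns the [`FinalResponse`]. Streaming errors are
/// printed to stderr and do not stop the loop.
///
/// # Example
///
/// A minimal sketch (assumes an `agent` built elsewhere):
///
/// ```ignore
/// let mut stream = agent.stream_prompt("Tell me a joke").await;
/// let final_response = stream_to_stdout(&mut stream).await?;
/// println!("\nUsage: {:?}", final_response.usage());
/// ```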
pub async fn stream_to_stdout<R>(
    stream: &mut StreamingResult<R>,
) -> Result<FinalResponse, std::io::Error> {
    let mut final_res = FinalResponse::empty();
    print!("Response: ");
    while let Some(content) = stream.next().await {
        match content {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
                Text { text },
            ))) => {
                print!("{text}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(
                reasoning,
            ))) => {
                let reasoning = reasoning.display_text();
                print!("{reasoning}");
                std::io::Write::flush(&mut std::io::stdout())?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(res)) => {
                final_res = res;
            }
            Err(err) => {
                eprintln!("Error: {err}");
            }
            _ => {}
        }
    }

    Ok(final_res)
}

865#[cfg(test)]
866mod tests {
867 use super::*;
868 use crate::agent::AgentBuilder;
869 use crate::client::ProviderClient;
870 use crate::client::completion::CompletionClient;
871 use crate::completion::CompletionRequest;
872 use crate::message::{
873 AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent,
874 ToolResultContent, UserContent,
875 };
876 use crate::providers::anthropic;
877 use crate::streaming::StreamingPrompt;
878 use crate::test_utils::{
879 AppendFailingMemory, FailingMemory, MockCompletionModel, MockStreamEvent,
880 };
881 use futures::StreamExt;
882 use std::sync::Arc;
883 use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
884 use std::time::Duration;
885
886 #[test]
887 fn merge_reasoning_blocks_preserves_order_and_signatures() {
888 let mut accumulated = Vec::new();
889 let first = crate::message::Reasoning {
890 id: Some("rs_1".to_string()),
891 content: vec![ReasoningContent::Text {
892 text: "step-1".to_string(),
893 signature: Some("sig-1".to_string()),
894 }],
895 };
896 let second = crate::message::Reasoning {
897 id: Some("rs_1".to_string()),
898 content: vec![
899 ReasoningContent::Text {
900 text: "step-2".to_string(),
901 signature: Some("sig-2".to_string()),
902 },
903 ReasoningContent::Summary("summary".to_string()),
904 ],
905 };
906
907 merge_reasoning_blocks(&mut accumulated, &first);
908 merge_reasoning_blocks(&mut accumulated, &second);
909
910 assert_eq!(accumulated.len(), 1);
911 let merged = accumulated.first().expect("expected accumulated reasoning");
912 assert_eq!(merged.id.as_deref(), Some("rs_1"));
913 assert_eq!(merged.content.len(), 3);
914 assert!(matches!(
915 merged.content.first(),
916 Some(ReasoningContent::Text { text, signature: Some(sig) })
917 if text == "step-1" && sig == "sig-1"
918 ));
919 assert!(matches!(
920 merged.content.get(1),
921 Some(ReasoningContent::Text { text, signature: Some(sig) })
922 if text == "step-2" && sig == "sig-2"
923 ));
924 }
925
926 #[test]
927 fn merge_reasoning_blocks_keeps_distinct_ids_as_separate_items() {
928 let mut accumulated = vec![crate::message::Reasoning {
929 id: Some("rs_a".to_string()),
930 content: vec![ReasoningContent::Text {
931 text: "step-1".to_string(),
932 signature: None,
933 }],
934 }];
935 let incoming = crate::message::Reasoning {
936 id: Some("rs_b".to_string()),
937 content: vec![ReasoningContent::Text {
938 text: "step-2".to_string(),
939 signature: None,
940 }],
941 };
942
943 merge_reasoning_blocks(&mut accumulated, &incoming);
944 assert_eq!(accumulated.len(), 2);
945 assert_eq!(
946 accumulated.first().and_then(|r| r.id.as_deref()),
947 Some("rs_a")
948 );
949 assert_eq!(
950 accumulated.get(1).and_then(|r| r.id.as_deref()),
951 Some("rs_b")
952 );
953 }
954
955 #[test]
956 fn merge_reasoning_blocks_keeps_none_ids_separate_items() {
957 let mut accumulated = vec![crate::message::Reasoning {
958 id: None,
959 content: vec![ReasoningContent::Text {
960 text: "first".to_string(),
961 signature: None,
962 }],
963 }];
964 let incoming = crate::message::Reasoning {
965 id: None,
966 content: vec![ReasoningContent::Text {
967 text: "second".to_string(),
968 signature: None,
969 }],
970 };
971
972 merge_reasoning_blocks(&mut accumulated, &incoming);
973 assert_eq!(accumulated.len(), 2);
974 assert!(matches!(
975 accumulated.first(),
976 Some(crate::message::Reasoning {
977 id: None,
978 content
979 }) if matches!(
980 content.first(),
981 Some(ReasoningContent::Text { text, .. }) if text == "first"
982 )
983 ));
984 assert!(matches!(
985 accumulated.get(1),
986 Some(crate::message::Reasoning {
987 id: None,
988 content
989 }) if matches!(
990 content.first(),
991 Some(ReasoningContent::Text { text, .. }) if text == "second"
992 )
993 ));
994 }
995
996 #[test]
997 fn tool_result_to_user_message_preserves_multimodal_tool_output() {
998 let message = tool_result_to_user_message(
999 "tool_call_1".to_string(),
1000 Some("call_1".to_string()),
1001 serde_json::json!({
1002 "response": {
1003 "instruction": "Use the image part to answer."
1004 },
1005 "parts": [
1006 {
1007 "type": "image",
1008 "data": "base64data==",
1009 "mimeType": "image/png"
1010 }
1011 ]
1012 })
1013 .to_string(),
1014 );
1015
1016 let tool_result = match message {
1017 Message::User { content } => match content.first() {
1018 UserContent::ToolResult(tool_result) => tool_result,
1019 other => panic!("expected tool result content, got {other:?}"),
1020 },
1021 other => panic!("expected user message, got {other:?}"),
1022 };
1023
1024 assert_eq!(tool_result.id, "tool_call_1");
1025 assert_eq!(tool_result.call_id.as_deref(), Some("call_1"));
1026 assert_eq!(tool_result.content.len(), 2);
1027
1028 let mut items = tool_result.content.iter();
1029 match items.next() {
1030 Some(ToolResultContent::Text(text)) => {
1031 assert!(text.text.contains("Use the image part to answer."));
1032 }
1033 other => panic!("expected structured text payload first, got {other:?}"),
1034 }
1035
1036 match items.next() {
1037 Some(ToolResultContent::Image(image)) => {
1038 assert_eq!(image.media_type, Some(ImageMediaType::PNG));
1039 assert!(matches!(
1040 image.data,
1041 DocumentSourceKind::Base64(ref data) if data == "base64data=="
1042 ));
1043 }
1044 other => panic!("expected image payload second, got {other:?}"),
1045 }
1046 }
1047
1048 fn validate_follow_up_tool_history(request: &CompletionRequest) -> Result<(), String> {
1049 let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1050 if history.len() != 3 {
1051 return Err(format!(
1052 "follow-up request should contain [original user prompt, assistant tool call, user tool result]: {history:?}"
1053 ));
1054 }
1055
1056 if !matches!(
1057 history.first(),
1058 Some(Message::User { content })
1059 if matches!(
1060 content.first(),
1061 UserContent::Text(text) if text.text == "do tool work"
1062 )
1063 ) {
1064 return Err(format!(
1065 "follow-up request should begin with the original user prompt: {history:?}"
1066 ));
1067 }
1068
1069 if !matches!(
1070 history.get(1),
1071 Some(Message::Assistant { content, .. })
1072 if matches!(
1073 content.first(),
1074 AssistantContent::ToolCall(tool_call)
1075 if tool_call.id == "tool_call_1"
1076 && tool_call.call_id.as_deref() == Some("call_1")
1077 )
1078 ) {
1079 return Err(format!(
1080 "follow-up request is missing the assistant tool call in position 2: {history:?}"
1081 ));
1082 }
1083
1084 if !matches!(
1085 history.get(2),
1086 Some(Message::User { content })
1087 if matches!(
1088 content.first(),
1089 UserContent::ToolResult(tool_result)
1090 if tool_result.id == "tool_call_1"
1091 && tool_result.call_id.as_deref() == Some("call_1")
1092 )
1093 ) {
1094 return Err(format!(
1095 "follow-up request should end with the user tool result: {history:?}"
1096 ));
1097 }
1098
1099 Ok(())
1100 }
1101
1102 fn streaming_tool_then_text_model() -> MockCompletionModel {
1103 MockCompletionModel::from_stream_turns([
1104 vec![
1105 MockStreamEvent::tool_call(
1106 "tool_call_1",
1107 "missing_tool",
1108 serde_json::json!({"input": "value"}),
1109 )
1110 .with_call_id("call_1"),
1111 MockStreamEvent::final_response_with_total_tokens(4),
1112 ],
1113 vec![
1114 MockStreamEvent::text("done"),
1115 MockStreamEvent::final_response_with_total_tokens(6),
1116 ],
1117 ])
1118 }
1119
1120 fn streaming_text_then_final_model() -> MockCompletionModel {
1121 MockCompletionModel::from_stream_turns([[
1122 MockStreamEvent::text("hello"),
1123 MockStreamEvent::text(" world"),
1124 MockStreamEvent::final_response_with_total_tokens(3),
1125 ]])
1126 }
1127
1128 fn streaming_final_only_model() -> MockCompletionModel {
1129 MockCompletionModel::from_stream_turns([[
1130 MockStreamEvent::final_response_with_total_tokens(1),
1131 ]])
1132 }
1133
1134 #[tokio::test]
1135 async fn stream_prompt_continues_after_tool_call_turn() {
1136 let model = streaming_tool_then_text_model();
1137 let recorded = model.clone();
1138 let agent = AgentBuilder::new(model).build();
1139 let empty_history: &[Message] = &[];
1140
1141 let mut stream = agent
1142 .stream_prompt("do tool work")
1143 .with_history(empty_history)
1144 .multi_turn(3)
1145 .await;
1146 let mut saw_tool_call = false;
1147 let mut saw_tool_result = false;
1148 let mut saw_final_response = false;
1149 let mut final_text = String::new();
1150 let mut final_response_text = None;
1151 let mut final_history = None;
1152
1153 while let Some(item) = stream.next().await {
1154 match item {
1155 Ok(MultiTurnStreamItem::StreamAssistantItem(
1156 StreamedAssistantContent::ToolCall { .. },
1157 )) => {
1158 saw_tool_call = true;
1159 }
1160 Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
1161 ..
1162 })) => {
1163 saw_tool_result = true;
1164 }
1165 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1166 text,
1167 ))) => {
1168 final_text.push_str(&text.text);
1169 }
1170 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1171 saw_final_response = true;
1172 final_response_text = Some(res.response().to_owned());
1173 final_history = res.history().map(|history| history.to_vec());
1174 break;
1175 }
1176 Ok(_) => {}
1177 Err(err) => panic!("unexpected streaming error: {err:?}"),
1178 }
1179 }
1180
1181 assert!(saw_tool_call);
1182 assert!(saw_tool_result);
1183 assert!(saw_final_response);
1184 assert_eq!(final_text, "done");
1185 assert_eq!(final_response_text.as_deref(), Some("done"));
1186 let history = final_history.expect("expected final response history");
1187 assert!(history.iter().any(|message| matches!(
1188 message,
1189 Message::Assistant { content, .. }
1190 if content.iter().any(|item| matches!(
1191 item,
1192 AssistantContent::Text(text) if text.text == "done"
1193 ))
1194 )));
1195 let requests = recorded.requests();
1196 assert_eq!(requests.len(), 2);
1197 assert!(validate_follow_up_tool_history(&requests[1]).is_ok());
1198 }
1199
1200 #[tokio::test]
1201 async fn final_response_matches_streamed_text_when_provider_final_is_textless() {
1202 let agent = AgentBuilder::new(streaming_text_then_final_model()).build();
1203
1204 let mut stream = agent.stream_prompt("say hello").await;
1205 let mut streamed_text = String::new();
1206 let mut final_response_text = None;
1207
1208 while let Some(item) = stream.next().await {
1209 match item {
1210 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1211 text,
1212 ))) => streamed_text.push_str(&text.text),
1213 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1214 final_response_text = Some(res.response().to_owned());
1215 break;
1216 }
1217 Ok(_) => {}
1218 Err(err) => panic!("unexpected streaming error: {err:?}"),
1219 }
1220 }
1221
1222 assert_eq!(streamed_text, "hello world");
1223 assert_eq!(final_response_text.as_deref(), Some("hello world"));
1224 }
1225
1226 #[tokio::test]
1227 async fn final_response_can_remain_empty_for_truly_textless_turns() {
1228 let agent = AgentBuilder::new(streaming_final_only_model()).build();
1229
1230 let mut stream = agent.stream_prompt("say nothing").await;
1231 let mut streamed_text = String::new();
1232 let mut final_response_text = None;
1233
1234 while let Some(item) = stream.next().await {
1235 match item {
1236 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1237 text,
1238 ))) => streamed_text.push_str(&text.text),
1239 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1240 final_response_text = Some(res.response().to_owned());
1241 break;
1242 }
1243 Ok(_) => {}
1244 Err(err) => panic!("unexpected streaming error: {err:?}"),
1245 }
1246 }
1247
1248 assert!(streamed_text.is_empty());
1249 assert_eq!(final_response_text.as_deref(), Some(""));
1250 }
1251
1252 async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {
1255 let mut interval = tokio::time::interval(Duration::from_millis(50));
1256 let mut count = 0u32;
1257
1258 while !stop.load(Ordering::Relaxed) {
1259 interval.tick().await;
1260 count += 1;
1261
1262 tracing::event!(
1263 target: "background_logger",
1264 tracing::Level::INFO,
1265 count = count,
1266 "Background tick"
1267 );
1268
1269 let current = tracing::Span::current();
1271 if !current.is_disabled() && !current.is_none() {
1272 leak_count.fetch_add(1, Ordering::Relaxed);
1273 }
1274 }
1275
1276 tracing::info!(target: "background_logger", total_ticks = count, "Background logger stopped");
1277 }
1278
1279 #[tokio::test(flavor = "current_thread")]
1287 #[ignore = "This requires an API key"]
1288 async fn test_span_context_isolation() -> anyhow::Result<()> {
1289 let stop = Arc::new(AtomicBool::new(false));
1290 let leak_count = Arc::new(AtomicU32::new(0));
1291
1292 let bg_stop = stop.clone();
1294 let bg_leak = leak_count.clone();
1295 let bg_handle = tokio::spawn(async move {
1296 background_logger(bg_stop, bg_leak).await;
1297 });
1298
1299 tokio::time::sleep(Duration::from_millis(100)).await;
1301
1302 let client = anthropic::Client::from_env()?;
1305 let agent = client
1306 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
1307 .preamble("You are a helpful assistant.")
1308 .temperature(0.1)
1309 .max_tokens(100)
1310 .build();
1311
1312 let mut stream = agent
1313 .stream_prompt("Say 'hello world' and nothing else.")
1314 .await;
1315
1316 let mut full_content = String::new();
1317 while let Some(item) = stream.next().await {
1318 match item {
1319 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1320 text,
1321 ))) => {
1322 full_content.push_str(&text.text);
1323 }
1324 Ok(MultiTurnStreamItem::FinalResponse(_)) => {
1325 break;
1326 }
1327 Err(e) => {
1328 tracing::warn!("Error: {:?}", e);
1329 break;
1330 }
1331 _ => {}
1332 }
1333 }
1334
1335 tracing::info!("Got response: {:?}", full_content);
1336
1337 stop.store(true, Ordering::Relaxed);
1339 bg_handle.await?;
1340
1341 let leaks = leak_count.load(Ordering::Relaxed);
1342 anyhow::ensure!(
1343 leaks == 0,
1344 "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \
1345 This indicates that span.enter() is being used inside async_stream instead of .instrument()"
1346 );
1347
1348 Ok(())
1349 }
1350
1351 #[tokio::test]
1357 #[ignore = "This requires an API key"]
1358 async fn test_chat_history_in_final_response() -> anyhow::Result<()> {
1359 use crate::message::Message;
1360
1361 let client = anthropic::Client::from_env()?;
1362 let agent = client
1363 .agent(anthropic::completion::CLAUDE_HAIKU_4_5)
1364 .preamble("You are a helpful assistant. Keep responses brief.")
1365 .temperature(0.1)
1366 .max_tokens(50)
1367 .build();
1368
1369 let empty_history: &[Message] = &[];
1371 let mut stream = agent
1372 .stream_prompt("Say 'hello' and nothing else.")
1373 .with_history(empty_history)
1374 .await;
1375
1376 let mut response_text = String::new();
1378 let mut final_history = None;
1379 while let Some(item) = stream.next().await {
1380 match item {
1381 Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
1382 text,
1383 ))) => {
1384 response_text.push_str(&text.text);
1385 }
1386 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1387 final_history = res.history().map(|h| h.to_vec());
1388 break;
1389 }
1390 Err(e) => {
1391 return Err(e.into());
1392 }
1393 _ => {}
1394 }
1395 }
1396
1397 let history = final_history
1398 .ok_or_else(|| anyhow::anyhow!("final response should include history"))?;
1399
1400 anyhow::ensure!(
1402 history.iter().any(|m| matches!(m, Message::User { .. })),
1403 "History should contain the user message"
1404 );
1405
1406 anyhow::ensure!(
1408 history
1409 .iter()
1410 .any(|m| matches!(m, Message::Assistant { .. })),
1411 "History should contain the assistant response"
1412 );
1413
1414 tracing::info!(
1415 "History after streaming: {} messages, response: {:?}",
1416 history.len(),
1417 response_text
1418 );
1419
1420 Ok(())
1421 }
1422
1423 #[tokio::test]
1424 async fn streaming_appends_to_memory_after_final_response() {
1425 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
1426
1427 let memory = InMemoryConversationMemory::new();
1428 let agent = AgentBuilder::new(streaming_text_then_final_model())
1429 .memory(memory.clone())
1430 .build();
1431
1432 let mut stream = agent
1433 .stream_prompt("hi there")
1434 .conversation("stream-thread")
1435 .await;
1436
1437 let mut history_in_final = None;
1438 while let Some(item) = stream.next().await {
1439 match item {
1440 Ok(MultiTurnStreamItem::FinalResponse(res)) => {
1441 history_in_final = res.history().map(|h| h.to_vec());
1442 break;
1443 }
1444 Ok(_) => {}
1445 Err(err) => panic!("unexpected streaming error: {err:?}"),
1446 }
1447 }
1448
1449 let final_history = history_in_final
1450 .expect("FinalResponse.history should be populated when memory is configured");
1451 assert_eq!(
1452 final_history.len(),
1453 2,
1454 "user prompt + assistant response in final history: {final_history:?}"
1455 );
1456
1457 let stored = memory.load("stream-thread").await.unwrap();
1458 assert_eq!(stored.len(), 2, "memory should contain user + assistant");
1459 }
1460
1461 #[tokio::test]
1462 async fn streaming_with_history_overrides_memory() {
1463 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
1464
1465 let memory = InMemoryConversationMemory::new();
1466 memory
1467 .append("t1", vec![Message::user("from-memory")])
1468 .await
1469 .unwrap();
1470
1471 let agent = AgentBuilder::new(streaming_text_then_final_model())
1472 .memory(memory.clone())
1473 .build();
1474
1475 let mut stream = agent
1476 .stream_prompt("hi")
1477 .conversation("t1")
1478 .with_history(vec![Message::user("from-caller")])
1479 .await;
1480
1481 while let Some(item) = stream.next().await {
1482 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
1483 break;
1484 }
1485 }
1486
1487 let stored = memory.load("t1").await.unwrap();
1488 assert_eq!(
1489 stored.len(),
1490 1,
1491 "with_history bypasses memory; only the pre-seeded entry remains: {stored:?}"
1492 );
1493 }
1494
1495 #[tokio::test]
1496 async fn streaming_without_memory_disables_for_request() {
1497 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
1498
1499 let memory = InMemoryConversationMemory::new();
1500 let agent = AgentBuilder::new(streaming_text_then_final_model())
1501 .memory(memory.clone())
1502 .conversation_id("default")
1503 .build();
1504
1505 let mut stream = agent.stream_prompt("hi").without_memory().await;
1506
1507 while let Some(item) = stream.next().await {
1508 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
1509 break;
1510 }
1511 }
1512
1513 let stored = memory.load("default").await.unwrap();
1514 assert!(stored.is_empty(), "without_memory disables save");
1515 }
1516
1517 #[tokio::test]
1518 async fn streaming_load_error_yields_memory_error() {
1519 let agent = AgentBuilder::new(streaming_text_then_final_model())
1520 .memory(FailingMemory::default())
1521 .build();
1522
1523 let mut stream = agent.stream_prompt("hi").conversation("t1").await;
1524
1525 let first = stream.next().await.expect("at least one item");
1526 match first {
1527 Err(err) => {
1528 let msg = format!("{err:?}");
1529 assert!(
1530 msg.contains("Memory") || msg.contains("memory") || msg.contains("load boom"),
1531 "expected memory error, got: {msg}"
1532 );
1533 }
1534 Ok(other) => panic!("expected memory error, got {other:?}"),
1535 }
1536 }
1537
1538 #[tokio::test]
1539 async fn streaming_with_filter_shapes_loaded_history() {
1540 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
1541
1542 let memory = InMemoryConversationMemory::new()
1543 .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
1544 memory
1545 .append(
1546 "t1",
1547 vec![
1548 Message::user("1"),
1549 Message::assistant("2"),
1550 Message::user("3"),
1551 Message::assistant("4"),
1552 ],
1553 )
1554 .await
1555 .unwrap();
1556
1557 let model = MockCompletionModel::from_stream_turns([[
1558 MockStreamEvent::text("ok"),
1559 MockStreamEvent::final_response_with_total_tokens(1),
1560 ]]);
1561 let recorded = model.clone();
1562 let agent = AgentBuilder::new(model).memory(memory).build();
1563
1564 let mut stream = agent.stream_prompt("ping").conversation("t1").await;
1565 while let Some(item) = stream.next().await {
1566 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
1567 break;
1568 }
1569 }
1570
1571 let received = recorded.requests()[0]
1572 .chat_history
1573 .iter()
1574 .cloned()
1575 .collect::<Vec<_>>();
1576 assert_eq!(
1577 received.len(),
1578 3,
1579 "window-truncated history (2) + current prompt: {received:?}"
1580 );
1581 }
1582
1583 #[tokio::test]
1584 async fn streaming_append_error_does_not_suppress_final_response() {
1585 let agent = AgentBuilder::new(streaming_text_then_final_model())
1586 .memory(AppendFailingMemory::default())
1587 .build();
1588
1589 let mut stream = agent.stream_prompt("hi").conversation("t1").await;
1590
1591 let mut saw_final = false;
1592 while let Some(item) = stream.next().await {
1593 if let Ok(MultiTurnStreamItem::FinalResponse(_)) = item {
1594 saw_final = true;
1595 break;
1596 }
1597 }
1598 assert!(
1599 saw_final,
1600 "FinalResponse must be yielded even when memory.append fails"
1601 );
1602 }
1603}