1pub mod hooks;
2pub mod streaming;
3
4use super::{
5 Agent,
6 completion::{DynamicContextStore, build_prepared_completion_request},
7 run::{AgentRun, AgentRunStep, ModelTurn, ModelTurnOutcome, PendingToolCall},
8};
9use crate::{
10 OneOrMany,
11 completion::{CompletionModel, Document, Message, PromptError, Usage},
12 json_utils,
13 memory::ConversationMemory,
14 message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
15 tool::server::ToolServerHandle,
16 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
17};
18use futures::{StreamExt, stream};
19use hooks::{HookAction, InvalidToolCallHookAction, PromptHook, ToolCallHookAction};
20use serde::{Deserialize, Serialize};
21use std::{
22 future::IntoFuture,
23 marker::PhantomData,
24 sync::{
25 Arc,
26 atomic::{AtomicU64, Ordering},
27 },
28};
29use tracing::info_span;
30use tracing::{Instrument, span::Id};
31
/// Marker trait for the type-state parameter of [`PromptRequest`].
///
/// The implementing type selects which `IntoFuture` impl applies and therefore
/// what awaiting the request yields.
pub trait PromptType {}

/// Type-state marker: awaiting the request yields only the output `String`.
pub struct Standard;
impl PromptType for Standard {}

/// Type-state marker: awaiting the request yields a full [`PromptResponse`].
pub struct Extended;
impl PromptType for Extended {}
38
/// A pending prompt against an agent, configured builder-style and executed by
/// awaiting it (see the `IntoFuture` impls in this file).
///
/// * `S` is the type-state marker ([`Standard`] or [`Extended`]) selecting the
///   output type.
/// * `M` is the completion model.
/// * `P` is the hook type consulted while the request runs.
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The message that starts the agent run.
    prompt: Message,
    /// Explicit chat history. When set, conversation memory is not consulted.
    chat_history: Option<Vec<Message>>,
    /// Turn budget handed to the agent run.
    max_turns: usize,

    /// Model used to build each completion request.
    model: Arc<M>,
    /// Agent name recorded on tracing spans; a placeholder is used when `None`.
    agent_name: Option<String>,
    /// Preamble forwarded to the completion request and recorded on spans as
    /// the system instructions.
    preamble: Option<String>,
    /// Static context documents forwarded to every completion request.
    static_context: Vec<Document>,
    /// Sampling temperature forwarded to the completion request.
    temperature: Option<f64>,
    /// Token limit forwarded to the completion request.
    max_tokens: Option<u64>,
    /// Extra provider parameters forwarded to the completion request.
    additional_params: Option<serde_json::Value>,
    /// Handle used to execute tool calls requested by the model.
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store forwarded to the completion request builder.
    dynamic_context: DynamicContextStore,
    /// Tool choice forwarded to both the agent run and the completion request.
    tool_choice: Option<ToolChoice>,

    /// Zero-sized type-state marker; carries no data.
    state: PhantomData<S>,
    /// Optional hook invoked around completion calls and tool calls.
    hook: Option<P>,
    /// Retry budget for invalid tool calls, handed to the agent run.
    max_invalid_tool_call_retries: usize,
    /// Upper bound on tool calls of one turn that are driven concurrently.
    concurrency: usize,
    /// Optional JSON schema forwarded to the completion request builder.
    output_schema: Option<schemars::Schema>,
    /// Conversation memory backend; only used together with `conversation_id`.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Key under which history is loaded from and appended to `memory`.
    conversation_id: Option<String>,
}
97
98impl<M, P> PromptRequest<Standard, M, P>
99where
100 M: CompletionModel,
101 P: PromptHook<M>,
102{
103 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
105 PromptRequest {
106 prompt: prompt.into(),
107 chat_history: None,
108 max_turns: agent.default_max_turns.unwrap_or_default(),
109 model: agent.model.clone(),
110 agent_name: agent.name.clone(),
111 preamble: agent.preamble.clone(),
112 static_context: agent.static_context.clone(),
113 temperature: agent.temperature,
114 max_tokens: agent.max_tokens,
115 additional_params: agent.additional_params.clone(),
116 tool_server_handle: agent.tool_server_handle.clone(),
117 dynamic_context: agent.dynamic_context.clone(),
118 tool_choice: agent.tool_choice.clone(),
119 state: PhantomData,
120 hook: agent.hook.clone(),
121 max_invalid_tool_call_retries: 0,
122 concurrency: 1,
123 output_schema: agent.output_schema.clone(),
124 memory: agent.memory.clone(),
125 conversation_id: agent.default_conversation_id.clone(),
126 }
127 }
128}
129
130impl<S, M, P> PromptRequest<S, M, P>
131where
132 S: PromptType,
133 M: CompletionModel,
134 P: PromptHook<M>,
135{
136 pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
143 PromptRequest {
144 prompt: self.prompt,
145 chat_history: self.chat_history,
146 max_turns: self.max_turns,
147 model: self.model,
148 agent_name: self.agent_name,
149 preamble: self.preamble,
150 static_context: self.static_context,
151 temperature: self.temperature,
152 max_tokens: self.max_tokens,
153 additional_params: self.additional_params,
154 tool_server_handle: self.tool_server_handle,
155 dynamic_context: self.dynamic_context,
156 tool_choice: self.tool_choice,
157 state: PhantomData,
158 hook: self.hook,
159 max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
160 concurrency: self.concurrency,
161 output_schema: self.output_schema,
162 memory: self.memory,
163 conversation_id: self.conversation_id,
164 }
165 }
166
167 pub fn max_turns(mut self, depth: usize) -> Self {
170 self.max_turns = depth;
171 self
172 }
173
174 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
177 self.concurrency = concurrency;
178 self
179 }
180
181 pub fn with_history<H, T>(mut self, history: H) -> Self
183 where
184 H: IntoIterator<Item = T>,
185 T: Into<Message>,
186 {
187 self.chat_history = Some(history.into_iter().map(Into::into).collect());
188 self
189 }
190
191 pub fn conversation(mut self, id: impl Into<String>) -> Self {
196 self.conversation_id = Some(id.into());
197 self
198 }
199
200 pub fn without_memory(mut self) -> Self {
204 self.memory = None;
205 self.conversation_id = None;
206 self
207 }
208
209 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
212 where
213 P2: PromptHook<M>,
214 {
215 PromptRequest {
216 prompt: self.prompt,
217 chat_history: self.chat_history,
218 max_turns: self.max_turns,
219 model: self.model,
220 agent_name: self.agent_name,
221 preamble: self.preamble,
222 static_context: self.static_context,
223 temperature: self.temperature,
224 max_tokens: self.max_tokens,
225 additional_params: self.additional_params,
226 tool_server_handle: self.tool_server_handle,
227 dynamic_context: self.dynamic_context,
228 tool_choice: self.tool_choice,
229 state: PhantomData,
230 hook: Some(hook),
231 max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
232 concurrency: self.concurrency,
233 output_schema: self.output_schema,
234 memory: self.memory,
235 conversation_id: self.conversation_id,
236 }
237 }
238
239 pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
243 self.max_invalid_tool_call_retries = retries;
244 self
245 }
246}
247
248impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
252where
253 M: CompletionModel + 'static,
254 P: PromptHook<M> + 'static,
255{
256 type Output = Result<String, PromptError>;
257 type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
258
259 fn into_future(self) -> Self::IntoFuture {
260 Box::pin(self.send())
261 }
262}
263
264impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
265where
266 M: CompletionModel + 'static,
267 P: PromptHook<M> + 'static,
268{
269 type Output = Result<PromptResponse, PromptError>;
270 type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
271
272 fn into_future(self) -> Self::IntoFuture {
273 Box::pin(self.send())
274 }
275}
276
277impl<M, P> PromptRequest<Standard, M, P>
278where
279 M: CompletionModel,
280 P: PromptHook<M>,
281{
282 async fn send(self) -> Result<String, PromptError> {
283 self.extended_details().send().await.map(|resp| resp.output)
284 }
285}
286
287#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
289#[non_exhaustive]
290pub struct CompletionCall {
291 pub call_index: usize,
293 #[serde(default, deserialize_with = "usage_null_as_default")]
299 pub usage: Usage,
300}
301
302impl CompletionCall {
303 pub fn new(call_index: usize, usage: Usage) -> Self {
305 Self { call_index, usage }
306 }
307}
308
309fn usage_null_as_default<'de, D>(deserializer: D) -> Result<Usage, D::Error>
316where
317 D: serde::Deserializer<'de>,
318{
319 Ok(Option::<Usage>::deserialize(deserializer)?.unwrap_or_default())
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
323#[non_exhaustive]
324pub struct PromptResponse {
325 pub output: String,
326 pub usage: Usage,
327 #[serde(default, skip_serializing_if = "Vec::is_empty")]
334 pub completion_calls: Vec<CompletionCall>,
335 pub messages: Option<Vec<Message>>,
336}
337
338impl std::fmt::Display for PromptResponse {
339 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340 self.output.fmt(f)
341 }
342}
343
344impl PromptResponse {
345 pub fn new(output: impl Into<String>, usage: Usage) -> Self {
346 Self {
347 output: output.into(),
348 usage,
349 completion_calls: Vec::new(),
350 messages: None,
351 }
352 }
353
354 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
355 self.messages = Some(messages);
356 self
357 }
358
359 pub fn with_completion_calls(mut self, completion_calls: Vec<CompletionCall>) -> Self {
361 self.completion_calls = completion_calls;
362 self
363 }
364
365 pub fn completion_calls(&self) -> &[CompletionCall] {
370 &self.completion_calls
371 }
372
373 pub fn requests(&self) -> usize {
375 self.completion_calls.len()
376 }
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
380#[non_exhaustive]
381pub struct TypedPromptResponse<T> {
382 pub output: T,
383 pub usage: Usage,
384 #[serde(default, skip_serializing_if = "Vec::is_empty")]
391 pub completion_calls: Vec<CompletionCall>,
392}
393
394impl<T> TypedPromptResponse<T> {
395 pub fn new(output: T, usage: Usage) -> Self {
396 Self {
397 output,
398 usage,
399 completion_calls: Vec::new(),
400 }
401 }
402
403 pub fn with_completion_calls(mut self, completion_calls: Vec<CompletionCall>) -> Self {
405 self.completion_calls = completion_calls;
406 self
407 }
408
409 pub fn completion_calls(&self) -> &[CompletionCall] {
414 &self.completion_calls
415 }
416
417 pub fn requests(&self) -> usize {
419 self.completion_calls.len()
420 }
421}
422
/// Agent name recorded on tracing spans when the agent has no configured name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

/// Tool-result text given to the valid tool calls of an assistant turn that
/// also contained an invalid tool call.
pub(crate) const TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER: &str =
    "Tool not executed because another tool call in the same assistant turn was invalid.";
427
428pub(crate) fn build_history_for_request(
430 chat_history: Option<&[Message]>,
431 new_messages: &[Message],
432) -> Vec<Message> {
433 let input = chat_history.unwrap_or(&[]);
434 input.iter().chain(new_messages.iter()).cloned().collect()
435}
436
437pub(crate) fn build_full_history(
439 chat_history: Option<&[Message]>,
440 new_messages: Vec<Message>,
441) -> Vec<Message> {
442 let input = chat_history.unwrap_or(&[]);
443 input.iter().cloned().chain(new_messages).collect()
444}
445
446pub(crate) fn tool_result_user_content(
447 id: String,
448 call_id: Option<String>,
449 tool_result: String,
450) -> UserContent {
451 let content = ToolResultContent::from_tool_output(tool_result);
452 match call_id {
453 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
454 None => UserContent::tool_result(id, content),
455 }
456}
457
458pub(crate) fn invalid_tool_retry_user_message(
459 assistant_content: &OneOrMany<AssistantContent>,
460 invalid_tool_call_id: &str,
461 feedback: String,
462) -> Option<Message> {
463 let retry_results = assistant_content
464 .iter()
465 .filter_map(|content| match content {
466 AssistantContent::ToolCall(tool_call) if tool_call.id == invalid_tool_call_id => {
467 Some(tool_result_user_content(
468 tool_call.id.clone(),
469 tool_call.call_id.clone(),
470 feedback.clone(),
471 ))
472 }
473 AssistantContent::ToolCall(tool_call) => Some(tool_result_user_content(
474 tool_call.id.clone(),
475 tool_call.call_id.clone(),
476 TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
477 )),
478 _ => None,
479 })
480 .collect::<Vec<_>>();
481
482 Some(Message::User {
483 content: OneOrMany::from_iter_optional(retry_results)?,
484 })
485}
486
487pub(crate) fn is_empty_assistant_turn(choice: &OneOrMany<AssistantContent>) -> bool {
488 choice.len() == 1
489 && matches!(
490 choice.first(),
491 AssistantContent::Text(text) if text.text.is_empty() && text.additional_params.is_none()
492 )
493}
494
495pub(crate) fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
496 choice
497 .iter()
498 .filter_map(|content| match content {
499 AssistantContent::Text(text) => Some(text.text.as_str()),
500 _ => None,
501 })
502 .collect()
503}
504
impl<M, P> PromptRequest<Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Returns the configured agent name, or [`UNKNOWN_AGENT_NAME`] when none
    /// was set.
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Drives the agent run to completion and returns the final response.
    ///
    /// The method repeatedly asks the `AgentRun` state machine for its next
    /// step and performs it: calling the model, executing the requested tool
    /// calls, or finishing. Hooks, when configured, are consulted before and
    /// after completion calls, on invalid tool calls, and around every tool
    /// execution.
    ///
    /// # Errors
    ///
    /// Propagates errors from loading conversation memory, from the agent run
    /// state machine, and from building or sending a completion request.
    /// Returns a cancellation error when a hook asks to terminate. A failure
    /// to append to conversation memory is only logged; the response is still
    /// returned.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse the caller's span when one is active; otherwise open a
        // dedicated `invoke_agent` span whose empty fields are filled in later.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        // Cloned up front: fields are moved out of `self` below, after which
        // `self.agent_name()` can no longer be called.
        let agent_name_for_span = self.agent_name.clone();
        // An explicit history wins and disables memory for this request.
        // Otherwise history is loaded from memory, but only when both a
        // backend and a conversation id are present.
        let (chat_history, memory_handle) = match self.chat_history {
            Some(history) => (Some(history), None),
            None => match (self.memory, self.conversation_id) {
                (Some(memory), Some(id)) => {
                    let loaded = memory.load(&id).await?;
                    (Some(loaded), Some((memory, id)))
                }
                _ => (None, None),
            },
        };

        let mut run = AgentRun::new(self.prompt.clone())
            .max_turns(self.max_turns)
            .max_invalid_tool_call_retries(self.max_invalid_tool_call_retries);
        if let Some(history) = chat_history {
            run = run.with_history(history);
        }
        if let Some(tool_choice) = self.tool_choice.clone() {
            run = run.with_tool_choice(tool_choice);
        }

        // Id of the most recently created chat/tool span (0 = none yet), used
        // to link each new span to its predecessor via `follows_from`.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        loop {
            match run.next_step()? {
                AgentRunStep::CallModel {
                    prompt,
                    history,
                    turn,
                } => {
                    if self.max_turns > 1 {
                        tracing::info!("Current conversation depth: {}/{}", turn, self.max_turns);
                    }

                    // Give the hook a chance to cancel before the model is called.
                    if let Some(ref hook) = self.hook
                        && let HookAction::Terminate { reason } =
                            hook.on_completion_call(&prompt, &history).await
                    {
                        return Err(run.cancel_error(reason));
                    }

                    let span = tracing::Span::current();
                    let chat_span = info_span!(
                        target: "rig::agent_chat",
                        parent: &span,
                        "chat",
                        gen_ai.operation.name = "chat",
                        gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                        gen_ai.system_instructions = self.preamble,
                        gen_ai.provider.name = tracing::field::Empty,
                        gen_ai.request.model = tracing::field::Empty,
                        gen_ai.response.id = tracing::field::Empty,
                        gen_ai.response.model = tracing::field::Empty,
                        gen_ai.usage.output_tokens = tracing::field::Empty,
                        gen_ai.usage.input_tokens = tracing::field::Empty,
                        gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                        gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                        gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
                        gen_ai.usage.reasoning_tokens = tracing::field::Empty,
                        gen_ai.input.messages = tracing::field::Empty,
                        gen_ai.output.messages = tracing::field::Empty,
                    );

                    // Link this chat span to the previously recorded span, if any.
                    let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        chat_span.follows_from(id).to_owned()
                    } else {
                        chat_span
                    };

                    // Remember this span so the next one can follow from it.
                    if let Some(id) = chat_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    let prepared_request = build_prepared_completion_request(
                        &self.model,
                        prompt.clone(),
                        &history,
                        self.preamble.as_deref(),
                        &self.static_context,
                        self.temperature,
                        self.max_tokens,
                        self.additional_params.as_ref(),
                        self.tool_choice.as_ref(),
                        &self.tool_server_handle,
                        &self.dynamic_context,
                        self.output_schema.as_ref(),
                    )
                    .await?;

                    let resp = prepared_request
                        .builder
                        .send()
                        .instrument(chat_span.clone())
                        .await?;

                    // Hand the model's answer to the state machine along with
                    // the tool names from the prepared request.
                    let mut outcome = run.model_response(ModelTurn::new(
                        resp.message_id.clone(),
                        resp.choice.clone(),
                        resp.usage,
                        prepared_request.executable_tool_names,
                        prepared_request.allowed_tool_names,
                    ))?;

                    // Keep resolving invalid tool calls until the state machine
                    // reports that the turn was retried or may continue.
                    loop {
                        match outcome {
                            ModelTurnOutcome::NeedsResolution(context) => {
                                // Without a hook, an invalid tool call resolves to `fail()`.
                                let action = match self.hook.as_ref() {
                                    Some(hook) => hook.on_invalid_tool_call(&context).await,
                                    None => InvalidToolCallHookAction::fail(),
                                };
                                outcome = run.resolve_invalid_tool_call(action)?;
                            }
                            // Nothing more to do for this turn; ask the run for
                            // its next step.
                            ModelTurnOutcome::TurnRetried => break,
                            ModelTurnOutcome::Continue {
                                response_hook_suppressed,
                            } => {
                                // The response hook runs unless the state
                                // machine asked for it to be suppressed.
                                if !response_hook_suppressed
                                    && let Some(ref hook) = self.hook
                                    && let HookAction::Terminate { reason } =
                                        hook.on_completion_response(&prompt, &resp).await
                                {
                                    return Err(run.cancel_error(reason));
                                }
                                break;
                            }
                        }
                    }
                }
                AgentRunStep::CallTools { calls } => {
                    let hook = self.hook.clone();
                    let tool_server_handle = self.tool_server_handle.clone();

                    // Snapshot of the history attached to cancellation errors
                    // raised from inside the tool futures.
                    let full_history_for_errors = run.full_history();

                    let tool_content = stream::iter(calls)
                        .map(|pending| {
                            // One hook clone for the pre-call check and one
                            // for the post-call check.
                            let hook1 = hook.clone();
                            let hook2 = hook.clone();
                            let tool_server_handle = tool_server_handle.clone();

                            let tool_span = info_span!(
                                "execute_tool",
                                gen_ai.operation.name = "execute_tool",
                                gen_ai.tool.type = "function",
                                gen_ai.tool.name = tracing::field::Empty,
                                gen_ai.tool.call.id = tracing::field::Empty,
                                gen_ai.tool.call.arguments = tracing::field::Empty,
                                gen_ai.tool.call.result = tracing::field::Empty
                            );

                            // Link this tool span to the previously recorded span, if any.
                            let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                                tool_span.follows_from(id).to_owned()
                            } else {
                                tool_span
                            };

                            if let Some(id) = tool_span.id() {
                                current_span_id.store(id.into_u64(), Ordering::SeqCst);
                            };

                            let cloned_history_for_error = full_history_for_errors.clone();

                            async move {
                                let PendingToolCall {
                                    tool_call,
                                    preresolved_result,
                                    ..
                                } = pending;
                                let tool_name = &tool_call.function.name;
                                let args =
                                    json_utils::value_to_json_string(&tool_call.function.arguments);
                                let internal_call_id = nanoid::nanoid!();
                                // A call that already carries a result is
                                // returned as-is: no hooks, no tool execution.
                                if let Some(result) = preresolved_result {
                                    return Ok(result);
                                }
                                // This future is instrumented with the tool span
                                // below, so the current span here is that span.
                                let tool_span = tracing::Span::current();
                                tool_span.record("gen_ai.tool.name", tool_name);
                                tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                                tool_span.record("gen_ai.tool.call.arguments", &args);
                                if let Some(hook) = hook1 {
                                    let action = hook
                                        .on_tool_call(
                                            tool_name,
                                            tool_call.call_id.clone(),
                                            &internal_call_id,
                                            &args,
                                        )
                                        .await;

                                    // The hook may cancel the whole prompt...
                                    if let ToolCallHookAction::Terminate { reason } = action {
                                        return Err(PromptError::prompt_cancelled(
                                            cloned_history_for_error,
                                            reason,
                                        ));
                                    }

                                    // ...or skip only this tool call, in which
                                    // case the reason becomes the tool result.
                                    if let ToolCallHookAction::Skip { reason } = action {
                                        tracing::info!(
                                            tool_name = tool_name,
                                            reason = reason,
                                            "Tool call rejected"
                                        );
                                        if let Some(call_id) = tool_call.call_id.clone() {
                                            return Ok(UserContent::tool_result_with_call_id(
                                                tool_call.id.clone(),
                                                call_id,
                                                OneOrMany::one(reason.into()),
                                            ));
                                        } else {
                                            return Ok(UserContent::tool_result(
                                                tool_call.id.clone(),
                                                OneOrMany::one(reason.into()),
                                            ));
                                        }
                                    }
                                }
                                // A failing tool does not abort the run: the
                                // error text is used as the tool's output.
                                let output =
                                    match tool_server_handle.call_tool(tool_name, &args).await {
                                        Ok(res) => res,
                                        Err(e) => {
                                            tracing::warn!("Error while executing tool: {e}");
                                            e.to_string()
                                        }
                                    };
                                // The hook may cancel the prompt after seeing the result.
                                if let Some(hook) = hook2
                                    && let HookAction::Terminate { reason } = hook
                                        .on_tool_result(
                                            tool_name,
                                            tool_call.call_id.clone(),
                                            &internal_call_id,
                                            &args,
                                            &output.to_string(),
                                        )
                                        .await
                                {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_history_for_error,
                                        reason,
                                    ));
                                }

                                tool_span.record("gen_ai.tool.call.result", &output);
                                tracing::info!(
                                    "executed tool {tool_name} with args {args}. result: {output}"
                                );
                                if let Some(call_id) = tool_call.call_id.clone() {
                                    Ok(UserContent::tool_result_with_call_id(
                                        tool_call.id.clone(),
                                        call_id,
                                        ToolResultContent::from_tool_output(output),
                                    ))
                                } else {
                                    Ok(UserContent::tool_result(
                                        tool_call.id.clone(),
                                        ToolResultContent::from_tool_output(output),
                                    ))
                                }
                            }
                            .instrument(tool_span)
                        })
                        // At most `self.concurrency` tool futures are in flight;
                        // results arrive in completion order.
                        .buffer_unordered(self.concurrency)
                        .collect::<Vec<Result<UserContent, PromptError>>>()
                        .await
                        .into_iter()
                        // All futures are driven to completion first; then the
                        // first error in the collected list, if any, is returned.
                        .collect::<Result<Vec<_>, _>>()?;

                    run.tool_results(tool_content)?;
                }
                AgentRunStep::Done(response) => {
                    if self.max_turns > 1 {
                        tracing::info!("Depth reached: {}/{}", run.turn(), self.max_turns);
                    }

                    // Fill in the span fields that were declared empty at creation.
                    let usage = response.usage;
                    agent_span.record("gen_ai.completion", &response.output);
                    agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                    agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                    agent_span.record(
                        "gen_ai.usage.cache_read.input_tokens",
                        usage.cached_input_tokens,
                    );
                    agent_span.record(
                        "gen_ai.usage.cache_creation.input_tokens",
                        usage.cache_creation_input_tokens,
                    );
                    agent_span.record(
                        "gen_ai.usage.tool_use_prompt_tokens",
                        usage.tool_use_prompt_tokens,
                    );
                    agent_span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);

                    // Best effort: a failed memory append is logged and does
                    // not fail the request.
                    if let Some((memory, id)) = memory_handle.as_ref()
                        && let Err(err) = memory
                            .append(id, response.messages.clone().unwrap_or_default())
                            .await
                    {
                        tracing::warn!(
                            error = %err,
                            conversation_id = %id,
                            "conversation memory append failed; returning model response anyway"
                        );
                    }

                    return Ok(response);
                }
            }
        }
    }
}
857
858use crate::completion::StructuredOutputError;
863use schemars::{JsonSchema, schema_for};
864use serde::de::DeserializeOwned;
865
866pub struct TypedPromptRequest<T, S, M, P>
883where
884 T: JsonSchema + DeserializeOwned + WasmCompatSend,
885 S: PromptType,
886 M: CompletionModel,
887 P: PromptHook<M>,
888{
889 inner: PromptRequest<S, M, P>,
890 _phantom: std::marker::PhantomData<T>,
891}
892
893impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
894where
895 T: JsonSchema + DeserializeOwned + WasmCompatSend,
896 M: CompletionModel,
897 P: PromptHook<M>,
898{
899 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
903 let mut inner = PromptRequest::from_agent(agent, prompt);
904 inner.output_schema = Some(schema_for!(T));
906 Self {
907 inner,
908 _phantom: std::marker::PhantomData,
909 }
910 }
911}
912
913impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
914where
915 T: JsonSchema + DeserializeOwned + WasmCompatSend,
916 S: PromptType,
917 M: CompletionModel,
918 P: PromptHook<M>,
919{
920 pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
926 TypedPromptRequest {
927 inner: self.inner.extended_details(),
928 _phantom: std::marker::PhantomData,
929 }
930 }
931
932 pub fn max_turns(mut self, depth: usize) -> Self {
938 self.inner = self.inner.max_turns(depth);
939 self
940 }
941
942 pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
946 self.inner = self.inner.max_invalid_tool_call_retries(retries);
947 self
948 }
949
950 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
954 self.inner = self.inner.with_tool_concurrency(concurrency);
955 self
956 }
957
958 pub fn with_history<H, U>(mut self, history: H) -> Self
960 where
961 H: IntoIterator<Item = U>,
962 U: Into<Message>,
963 {
964 self.inner = self.inner.with_history(history);
965 self
966 }
967
968 pub fn conversation(mut self, id: impl Into<String>) -> Self {
973 self.inner = self.inner.conversation(id);
974 self
975 }
976
977 pub fn without_memory(mut self) -> Self {
981 self.inner = self.inner.without_memory();
982 self
983 }
984
985 pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
989 where
990 P2: PromptHook<M>,
991 {
992 TypedPromptRequest {
993 inner: self.inner.with_hook(hook),
994 _phantom: std::marker::PhantomData,
995 }
996 }
997}
998
999impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
1000where
1001 T: JsonSchema + DeserializeOwned + WasmCompatSend,
1002 M: CompletionModel,
1003 P: PromptHook<M>,
1004{
1005 async fn send(self) -> Result<T, StructuredOutputError> {
1007 let response = self.inner.send().await.map_err(Box::new)?;
1008
1009 if response.is_empty() {
1010 return Err(StructuredOutputError::EmptyResponse);
1011 }
1012
1013 let parsed: T = serde_json::from_str(&response)?;
1014 Ok(parsed)
1015 }
1016}
1017
1018impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
1019where
1020 T: JsonSchema + DeserializeOwned + WasmCompatSend,
1021 M: CompletionModel,
1022 P: PromptHook<M>,
1023{
1024 async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
1026 let response = self.inner.send().await.map_err(Box::new)?;
1027
1028 if response.output.is_empty() {
1029 return Err(StructuredOutputError::EmptyResponse);
1030 }
1031
1032 let parsed: T = serde_json::from_str(&response.output)?;
1033 Ok(TypedPromptResponse::new(parsed, response.usage)
1034 .with_completion_calls(response.completion_calls))
1035 }
1036}
1037
1038impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
1039where
1040 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
1041 M: CompletionModel + 'static,
1042 P: PromptHook<M> + 'static,
1043{
1044 type Output = Result<T, StructuredOutputError>;
1045 type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1046
1047 fn into_future(self) -> Self::IntoFuture {
1048 Box::pin(self.send())
1049 }
1050}
1051
1052impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
1053where
1054 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
1055 M: CompletionModel + 'static,
1056 P: PromptHook<M> + 'static,
1057{
1058 type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
1059 type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1060
1061 fn into_future(self) -> Self::IntoFuture {
1062 Box::pin(self.send())
1063 }
1064}
1065
1066#[cfg(test)]
1067mod tests {
1068 use super::{CompletionCall, PromptResponse, TypedPromptResponse};
1069 use crate::{
1070 agent::{
1071 AgentBuilder,
1072 prompt_request::hooks::{
1073 HookAction, InvalidToolCallContext, InvalidToolCallHookAction, PromptHook,
1074 ToolCallHookAction,
1075 },
1076 },
1077 completion::{
1078 AssistantContent, CompletionError, CompletionModel, CompletionRequest, Message, Prompt,
1079 PromptError, StructuredOutputError, ToolDefinition, TypedPrompt, Usage,
1080 },
1081 message::{Text, ToolCall, ToolChoice, ToolFunction, UserContent},
1082 test_utils::{
1083 AppendFailingMemory, CountingMemory, FailingMemory, MockAddTool, MockCompletionModel,
1084 MockOperationArgs, MockSubtractTool, MockToolError, MockTurn,
1085 },
1086 tool::Tool,
1087 };
1088 use schemars::JsonSchema;
1089 use serde::{Deserialize, Serialize};
1090 use serde_json::json;
1091 use std::sync::{
1092 Arc, Mutex,
1093 atomic::{AtomicU32, Ordering},
1094 };
1095
    /// Fixture that derives only `Serialize`, used as the output type of a
    /// typed response in the serialization test.
    #[derive(Serialize)]
    struct SerializeOnly {
        value: &'static str,
    }
1100
    /// Fixture that derives only `Deserialize`, used as the output type of a
    /// typed response in the deserialization test.
    #[derive(Deserialize)]
    struct DeserializeOnly {
        value: String,
    }
1105
    /// Structured-output fixture: derives `JsonSchema` and `Deserialize`, plus
    /// `Debug` and `PartialEq` so values can be compared in assertions.
    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
    struct TypedAnswer {
        value: String,
    }
1110
    /// Hook that panics if either the completion-response hook or the
    /// tool-call hook is invoked.
    #[derive(Clone)]
    struct PanicOnUnknownToolHook;

    impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
        /// Always panics: the response hook must not run for an unknown tool.
        async fn on_completion_response(
            &self,
            _prompt: &Message,
            _response: &crate::completion::CompletionResponse<
                <MockCompletionModel as CompletionModel>::Response,
            >,
        ) -> HookAction {
            panic!("unknown tool response should fail before response hooks run")
        }

        /// Always panics: the tool-call hook must not run for an unknown tool.
        async fn on_tool_call(
            &self,
            _tool_name: &str,
            _tool_call_id: Option<String>,
            _internal_call_id: &str,
            _args: &str,
        ) -> ToolCallHookAction {
            panic!("unknown tool call should fail before tool hooks run")
        }
    }
1135
    /// Hook that panics if the tool-call hook is invoked.
    #[derive(Clone)]
    struct PanicOnToolCallHook;

    impl PromptHook<MockCompletionModel> for PanicOnToolCallHook {
        /// Always panics: a recovered invalid turn must not reach this hook.
        async fn on_tool_call(
            &self,
            _tool_name: &str,
            _tool_call_id: Option<String>,
            _internal_call_id: &str,
            _args: &str,
        ) -> ToolCallHookAction {
            panic!("recovered invalid turn should not invoke normal tool hooks")
        }
    }
1150
    /// Hook combining two others: invalid tool calls are handled like
    /// `SkipDefaultApiHook`, tool calls like `PanicOnToolCallHook`.
    #[derive(Clone)]
    struct SkipDefaultApiAndPanicOnToolCallHook;

    impl PromptHook<MockCompletionModel> for SkipDefaultApiAndPanicOnToolCallHook {
        /// Delegates to `SkipDefaultApiHook`.
        async fn on_invalid_tool_call(
            &self,
            context: &InvalidToolCallContext,
        ) -> InvalidToolCallHookAction {
            SkipDefaultApiHook.on_invalid_tool_call(context).await
        }

        /// Delegates to `PanicOnToolCallHook`.
        async fn on_tool_call(
            &self,
            tool_name: &str,
            tool_call_id: Option<String>,
            internal_call_id: &str,
            args: &str,
        ) -> ToolCallHookAction {
            PanicOnToolCallHook
                .on_tool_call(tool_name, tool_call_id, internal_call_id, args)
                .await
        }
    }
1174
1175 #[derive(Clone)]
1176 struct RepairDefaultApiHook;
1177
1178 impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
1179 fn on_invalid_tool_call(
1180 &self,
1181 context: &InvalidToolCallContext,
1182 ) -> impl std::future::Future<Output = InvalidToolCallHookAction> + Send {
1183 let tool_name = context.tool_name.clone();
1184 async move {
1185 assert_eq!(tool_name, "default_api");
1186 InvalidToolCallHookAction::repair("add")
1187 }
1188 }
1189 }
1190
    /// Hook that answers every invalid tool call with a repair to `subtract`.
    #[derive(Clone)]
    struct RepairToSubtractHook;

    impl PromptHook<MockCompletionModel> for RepairToSubtractHook {
        /// Ignores the context and always requests a repair to `subtract`.
        async fn on_invalid_tool_call(
            &self,
            _context: &InvalidToolCallContext,
        ) -> InvalidToolCallHookAction {
            InvalidToolCallHookAction::repair("subtract")
        }
    }
1202
1203 #[derive(Clone)]
1204 struct RetryDefaultApiHook;
1205
1206 impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
1207 fn on_invalid_tool_call(
1208 &self,
1209 context: &InvalidToolCallContext,
1210 ) -> impl std::future::Future<Output = InvalidToolCallHookAction> + Send {
1211 let allowed_tools = context.allowed_tools.clone();
1212 async move {
1213 InvalidToolCallHookAction::retry(format!(
1214 "Use one of these tools instead: {allowed_tools:?}"
1215 ))
1216 }
1217 }
1218 }
1219
    /// Hook that answers every invalid tool call with a skip.
    #[derive(Clone)]
    struct SkipDefaultApiHook;

    impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
        /// Ignores the context and always skips with a fixed reason.
        async fn on_invalid_tool_call(
            &self,
            _context: &InvalidToolCallContext,
        ) -> InvalidToolCallHookAction {
            InvalidToolCallHookAction::skip("default_api is not available")
        }
    }
1231
1232 #[derive(Clone, Default)]
1233 struct RecordingInvalidToolCallHook {
1234 contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
1235 }
1236
1237 impl RecordingInvalidToolCallHook {
1238 fn observed(&self) -> Vec<InvalidToolCallContext> {
1239 self.contexts
1240 .lock()
1241 .expect("invalid tool context records mutex was poisoned")
1242 .clone()
1243 }
1244 }
1245
1246 impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
1247 async fn on_invalid_tool_call(
1248 &self,
1249 context: &InvalidToolCallContext,
1250 ) -> InvalidToolCallHookAction {
1251 self.contexts
1252 .lock()
1253 .expect("invalid tool context records mutex was poisoned")
1254 .push(context.clone());
1255 InvalidToolCallHookAction::fail()
1256 }
1257 }
1258
1259 #[derive(Clone)]
1260 struct CountingAddTool {
1261 calls: Arc<AtomicU32>,
1262 }
1263
1264 impl Tool for CountingAddTool {
1265 const NAME: &'static str = "add";
1266 type Error = MockToolError;
1267 type Args = MockOperationArgs;
1268 type Output = i32;
1269
1270 async fn definition(&self, _prompt: String) -> ToolDefinition {
1271 MockAddTool.definition(String::new()).await
1272 }
1273
1274 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
1275 self.calls.fetch_add(1, Ordering::SeqCst);
1276 Ok(0)
1277 }
1278 }
1279
1280 fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
1281 Usage {
1282 input_tokens,
1283 output_tokens,
1284 total_tokens: input_tokens + output_tokens,
1285 cached_input_tokens: 0,
1286 cache_creation_input_tokens: 0,
1287 tool_use_prompt_tokens: 0,
1288 reasoning_tokens: 0,
1289 }
1290 }
1291
/// A `TypedPromptResponse` whose output type is `SerializeOnly` must serialize
/// to JSON containing the output's `value` field.
#[test]
fn typed_prompt_response_serializes_with_serialize_only_output() {
    let output = SerializeOnly { value: "ok" };
    // Same counters as spelling out the struct: 1 input, 2 output, 3 total, rest zero.
    let response = TypedPromptResponse::new(output, usage(1, 2));

    let json = serde_json::to_string(&response).expect("serialize typed prompt response");
    assert!(json.contains("\"value\":\"ok\""));
}
1310
/// A `TypedPromptResponse` whose output type is `DeserializeOnly` must
/// deserialize from JSON and expose the output and usage counters.
#[test]
fn typed_prompt_response_deserializes_with_deserialize_only_output() {
    // NOTE(review): the payload omits `tool_use_prompt_tokens` and `completion_calls`;
    // presumably this exercises serde defaults — confirm against the type definitions.
    let payload = r#"{"output":{"value":"ok"},"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3,"cached_input_tokens":0,"cache_creation_input_tokens":0,"reasoning_tokens":0}}"#;

    let response: TypedPromptResponse<DeserializeOnly> =
        serde_json::from_str(payload).expect("deserialize typed prompt response");

    assert_eq!(response.requests(), 0);
    assert_eq!(response.output.value, "ok");
    assert_eq!(response.usage.input_tokens, 1);
    assert_eq!(response.usage.output_tokens, 2);
    assert_eq!(response.usage.total_tokens, 3);
}
1324
/// Completion calls (one with an all-zero `Usage::new()`, one with reported
/// usage) must serialize to the expected JSON and survive a round trip.
#[test]
fn prompt_response_serializes_completion_calls_with_missing_usage() {
    let reported_usage = usage(3, 4);
    let calls = vec![
        CompletionCall::new(0, Usage::new()),
        CompletionCall::new(1, reported_usage),
    ];
    let response = PromptResponse::new("ok", reported_usage).with_completion_calls(calls);

    let value = serde_json::to_value(&response).expect("serialize prompt response");

    let expected_calls_json = json!([
        {
            "call_index": 0,
            "usage": {
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
                "cached_input_tokens": 0,
                "cache_creation_input_tokens": 0,
                "tool_use_prompt_tokens": 0,
                "reasoning_tokens": 0,
            }
        },
        {
            "call_index": 1,
            "usage": {
                "input_tokens": 3,
                "output_tokens": 4,
                "total_tokens": 7,
                "cached_input_tokens": 0,
                "cache_creation_input_tokens": 0,
                "tool_use_prompt_tokens": 0,
                "reasoning_tokens": 0,
            }
        }
    ]);
    assert_eq!(value.get("completion_calls"), Some(&expected_calls_json));

    // Round trip: the deserialized response must report the same calls.
    let round_tripped: PromptResponse =
        serde_json::from_value(value).expect("deserialize prompt response");
    let expected_calls = [
        CompletionCall::new(0, Usage::new()),
        CompletionCall::new(1, reported_usage),
    ];
    assert_eq!(round_tripped.completion_calls(), &expected_calls);
    assert_eq!(round_tripped.requests(), 2);
}
1379
/// A serialized response whose first completion call has `"usage": null`
/// must deserialize, with that call's usage equal to `Usage::new()`.
#[test]
fn prompt_response_deserializes_pre_monoid_null_usage_format() {
    let fixture = r#"{"output":"ok","usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"cached_input_tokens":0,"cache_creation_input_tokens":0,"tool_use_prompt_tokens":0,"reasoning_tokens":0},"completion_calls":[{"call_index":0,"usage":null},{"call_index":1,"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"cached_input_tokens":0,"cache_creation_input_tokens":0,"tool_use_prompt_tokens":0,"reasoning_tokens":0}}],"messages":[{"role":"user","content":[{"type":"text","text":"add things"}]}]}"#;

    let response: PromptResponse =
        serde_json::from_str(fixture).expect("old-format response should deserialize");

    let expected_calls = [
        CompletionCall::new(0, Usage::new()),
        CompletionCall::new(1, usage(3, 4)),
    ];
    assert_eq!(response.completion_calls(), &expected_calls);
}
1396
1397 #[tokio::test]
1398 async fn prompt_response_records_completion_call_without_reported_usage() {
1399 let model = MockCompletionModel::new([MockTurn::text("ok")]);
1400 let agent = AgentBuilder::new(model).build();
1401
1402 let response = agent
1403 .prompt("say ok")
1404 .extended_details()
1405 .await
1406 .expect("prompt should succeed");
1407
1408 assert_eq!(response.output, "ok");
1409 assert_eq!(response.usage, Usage::new());
1410 assert_eq!(
1411 response.completion_calls(),
1412 &[CompletionCall::new(0, Usage::new())]
1413 );
1414 }
1415
1416 #[tokio::test]
1417 async fn typed_prompt_response_preserves_completion_calls() {
1418 let call_usage = Usage {
1419 input_tokens: 4,
1420 output_tokens: 6,
1421 total_tokens: 10,
1422 cached_input_tokens: 0,
1423 cache_creation_input_tokens: 0,
1424 tool_use_prompt_tokens: 0,
1425 reasoning_tokens: 0,
1426 };
1427 let model =
1428 MockCompletionModel::new([MockTurn::text(r#"{"value":"ok"}"#).with_usage(call_usage)]);
1429 let agent = AgentBuilder::new(model).build();
1430
1431 let response = agent
1432 .prompt_typed::<TypedAnswer>("return typed json")
1433 .extended_details()
1434 .await
1435 .expect("typed prompt should succeed");
1436
1437 assert_eq!(
1438 response.output,
1439 TypedAnswer {
1440 value: "ok".to_string()
1441 }
1442 );
1443 assert_eq!(response.usage, call_usage);
1444 assert_eq!(
1445 response.completion_calls(),
1446 &[CompletionCall::new(0, call_usage)]
1447 );
1448 }
1449
1450 fn validate_follow_up_tool_history(request: &CompletionRequest) {
1451 let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1452 assert_eq!(
1453 history.len(),
1454 3,
1455 "follow-up request should contain the prompt, assistant tool call, and user tool result: {history:?}"
1456 );
1457
1458 assert!(matches!(
1459 history.first(),
1460 Some(Message::User { content })
1461 if matches!(
1462 content.first(),
1463 UserContent::Text(text) if text.text == "do tool work"
1464 )
1465 ));
1466
1467 assert!(matches!(
1468 history.get(1),
1469 Some(Message::Assistant { content, .. })
1470 if matches!(
1471 content.first(),
1472 AssistantContent::ToolCall(tool_call)
1473 if tool_call.id == "tool_call_1"
1474 && tool_call.call_id.as_deref() == Some("call_1")
1475 )
1476 ));
1477
1478 assert!(matches!(
1479 history.get(2),
1480 Some(Message::User { content })
1481 if matches!(
1482 content.first(),
1483 UserContent::ToolResult(tool_result)
1484 if tool_result.id == "tool_call_1"
1485 && tool_result.call_id.as_deref() == Some("call_1")
1486 )
1487 ));
1488 }
1489
1490 fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1491 history.iter().any(|message| {
1492 matches!(
1493 message,
1494 Message::Assistant { content, .. }
1495 if content.iter().any(|item| matches!(
1496 item,
1497 AssistantContent::ToolCall(tool_call)
1498 if tool_call.function.name == tool_name
1499 ))
1500 )
1501 })
1502 }
1503
1504 #[tokio::test]
1505 async fn unknown_tool_call_fails_before_non_streaming_second_request() {
1506 let model = MockCompletionModel::new([
1507 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 1, "y": 2})),
1508 MockTurn::text("should not be requested"),
1509 ]);
1510 let recorded = model.clone();
1511 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1512
1513 let err = agent
1514 .prompt("use the tool")
1515 .with_hook(PanicOnUnknownToolHook)
1516 .max_turns(3)
1517 .await
1518 .expect_err("unknown model-emitted tool should fail");
1519
1520 match err {
1521 PromptError::UnknownToolCall {
1522 tool_name,
1523 available_tools,
1524 allowed_tools,
1525 chat_history,
1526 } => {
1527 assert_eq!(tool_name, "default_api");
1528 assert_eq!(available_tools, vec!["add".to_string()]);
1529 assert_eq!(allowed_tools, vec!["add".to_string()]);
1530 assert!(history_contains_tool_call(&chat_history, "default_api"));
1531 }
1532 other => panic!("expected UnknownToolCall, got {other:?}"),
1533 }
1534 assert_eq!(recorded.request_count(), 1);
1535 }
1536
1537 #[tokio::test]
1538 async fn invalid_tool_call_context_uses_completed_tool_call_provider_id() {
1539 let invalid_hook = RecordingInvalidToolCallHook::default();
1540 let model = MockCompletionModel::new([
1541 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 1, "y": 2}))
1542 .with_call_id("provider_call_1"),
1543 MockTurn::text("should not be requested"),
1544 ]);
1545 let recorded = model.clone();
1546 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1547
1548 let err = agent
1549 .prompt("use the tool")
1550 .with_hook(invalid_hook.clone())
1551 .max_turns(3)
1552 .await
1553 .expect_err("invalid tool should fail");
1554
1555 assert!(matches!(err, PromptError::UnknownToolCall { .. }));
1556 assert_eq!(recorded.request_count(), 1);
1557 let contexts = invalid_hook.observed();
1558 assert_eq!(contexts.len(), 1);
1559 let context = &contexts[0];
1560 assert_eq!(context.tool_name, "default_api");
1561 assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
1562 assert_eq!(context.internal_call_id, None);
1563 assert!(!context.is_streaming);
1564 }
1565
1566 #[tokio::test]
1567 async fn disallowed_specific_tool_call_fails_before_non_streaming_second_request() {
1568 let model = MockCompletionModel::new([
1569 MockTurn::tool_call("tool_call_1", "subtract", json!({"x": 3, "y": 1})),
1570 MockTurn::text("should not be requested"),
1571 ]);
1572 let recorded = model.clone();
1573 let agent = AgentBuilder::new(model)
1574 .tool(MockAddTool)
1575 .tool(MockSubtractTool)
1576 .tool_choice(ToolChoice::Specific {
1577 function_names: vec!["add".to_string()],
1578 })
1579 .build();
1580
1581 let err = agent
1582 .prompt("use the allowed tool")
1583 .with_hook(PanicOnUnknownToolHook)
1584 .max_turns(3)
1585 .await
1586 .expect_err("disallowed model-emitted tool should fail");
1587
1588 match err {
1589 PromptError::UnknownToolCall {
1590 tool_name,
1591 available_tools,
1592 allowed_tools,
1593 chat_history,
1594 } => {
1595 assert_eq!(tool_name, "subtract");
1596 assert_eq!(
1597 available_tools,
1598 vec!["add".to_string(), "subtract".to_string()]
1599 );
1600 assert_eq!(allowed_tools, vec!["add".to_string()]);
1601 assert!(history_contains_tool_call(&chat_history, "subtract"));
1602 }
1603 other => panic!("expected UnknownToolCall, got {other:?}"),
1604 }
1605 assert_eq!(recorded.request_count(), 1);
1606 }
1607
1608 #[tokio::test]
1609 async fn tool_choice_none_rejects_non_streaming_tool_call() {
1610 let model = MockCompletionModel::new([
1611 MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})),
1612 MockTurn::text("should not be requested"),
1613 ]);
1614 let recorded = model.clone();
1615 let agent = AgentBuilder::new(model)
1616 .tool(MockAddTool)
1617 .tool_choice(ToolChoice::None)
1618 .build();
1619
1620 let err = agent
1621 .prompt("do not use tools")
1622 .with_hook(PanicOnUnknownToolHook)
1623 .max_turns(3)
1624 .await
1625 .expect_err("ToolChoice::None should reject returned tool calls");
1626
1627 match err {
1628 PromptError::UnknownToolCall {
1629 tool_name,
1630 available_tools,
1631 allowed_tools,
1632 chat_history,
1633 } => {
1634 assert_eq!(tool_name, "add");
1635 assert_eq!(available_tools, vec!["add".to_string()]);
1636 assert!(allowed_tools.is_empty());
1637 assert!(history_contains_tool_call(&chat_history, "add"));
1638 }
1639 other => panic!("expected UnknownToolCall, got {other:?}"),
1640 }
1641 assert_eq!(recorded.request_count(), 1);
1642 }
1643
1644 #[tokio::test]
1645 async fn invalid_tool_call_hook_can_repair_non_streaming_tool_name() {
1646 let model = MockCompletionModel::new([
1647 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1648 MockTurn::text("done"),
1649 ]);
1650 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1651
1652 let response = agent
1653 .prompt("add")
1654 .with_hook(RepairDefaultApiHook)
1655 .max_turns(3)
1656 .extended_details()
1657 .await
1658 .expect("repaired tool call should execute");
1659
1660 assert_eq!(response.output, "done");
1661 let messages = response.messages.expect("messages should be present");
1662 assert!(history_contains_tool_call(&messages, "add"));
1663 assert!(!history_contains_tool_call(&messages, "default_api"));
1664 assert!(messages.iter().any(|message| {
1665 matches!(
1666 message,
1667 Message::User { content }
1668 if content.iter().any(|content| {
1669 matches!(
1670 content,
1671 UserContent::ToolResult(result)
1672 if result.content.iter().any(|content| {
1673 matches!(
1674 content,
1675 crate::message::ToolResultContent::Text(text)
1676 if text.text == "5"
1677 )
1678 })
1679 )
1680 })
1681 )
1682 }));
1683 }
1684
1685 #[tokio::test]
1686 async fn invalid_tool_call_hook_retry_adds_feedback_and_retries_non_streaming() {
1687 let model = MockCompletionModel::new([
1688 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1689 MockTurn::text("retried"),
1690 ]);
1691 let recorded = model.clone();
1692 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1693
1694 let response = agent
1695 .prompt("add")
1696 .with_hook(RetryDefaultApiHook)
1697 .max_invalid_tool_call_retries(1)
1698 .max_turns(3)
1699 .extended_details()
1700 .await
1701 .expect("retry should recover");
1702
1703 assert_eq!(response.output, "retried");
1704 assert_eq!(recorded.request_count(), 2);
1705 let messages = response.messages.expect("messages should be present");
1706 assert!(messages.iter().any(|message| {
1707 matches!(
1708 message,
1709 Message::User { content }
1710 if content.iter().any(|content| {
1711 matches!(
1712 content,
1713 UserContent::ToolResult(result)
1714 if result.content.iter().any(|content| {
1715 matches!(
1716 content,
1717 crate::message::ToolResultContent::Text(text)
1718 if text.text.contains("Use one of these tools instead")
1719 )
1720 })
1721 )
1722 })
1723 )
1724 }));
1725 }
1726
1727 #[tokio::test]
1728 async fn invalid_tool_call_hook_retries_mixed_non_streaming_turn_without_executing_valid_call()
1729 {
1730 let add_calls = Arc::new(AtomicU32::new(0));
1731 let mut valid_tool_call = ToolCall::new(
1732 "tool_call_1".to_string(),
1733 ToolFunction::new("add".to_string(), json!({"x": 2, "y": 3})),
1734 );
1735 valid_tool_call.call_id = Some("call_1".to_string());
1736 let mut invalid_tool_call = ToolCall::new(
1737 "tool_call_2".to_string(),
1738 ToolFunction::new("default_api".to_string(), json!({"x": 4, "y": 5})),
1739 );
1740 invalid_tool_call.call_id = Some("call_2".to_string());
1741 let model = MockCompletionModel::new([
1742 MockTurn::from_contents([
1743 AssistantContent::ToolCall(valid_tool_call),
1744 AssistantContent::ToolCall(invalid_tool_call),
1745 ])
1746 .expect("tool-call response should be non-empty"),
1747 MockTurn::text("retried"),
1748 ]);
1749 let recorded = model.clone();
1750 let agent = AgentBuilder::new(model)
1751 .tool(CountingAddTool {
1752 calls: add_calls.clone(),
1753 })
1754 .build();
1755
1756 let response = agent
1757 .prompt("add")
1758 .with_hook(RetryDefaultApiHook)
1759 .max_invalid_tool_call_retries(1)
1760 .max_turns(3)
1761 .extended_details()
1762 .await
1763 .expect("retry should recover");
1764
1765 assert_eq!(response.output, "retried");
1766 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
1767 let requests = recorded.requests();
1768 assert_eq!(requests.len(), 2);
1769 let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
1770 assert_eq!(retry_history.len(), 3);
1771 assert!(matches!(
1772 retry_history.get(1),
1773 Some(Message::Assistant { content, .. })
1774 if content.iter().any(|item| matches!(
1775 item,
1776 AssistantContent::ToolCall(tool_call)
1777 if tool_call.id == "tool_call_1"
1778 && tool_call.function.name == "add"
1779 ))
1780 && content.iter().any(|item| matches!(
1781 item,
1782 AssistantContent::ToolCall(tool_call)
1783 if tool_call.id == "tool_call_2"
1784 && tool_call.function.name == "default_api"
1785 ))
1786 ));
1787 assert!(matches!(
1788 retry_history.get(2),
1789 Some(Message::User { content })
1790 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
1791 && content.iter().any(|item| matches!(
1792 item,
1793 UserContent::ToolResult(result)
1794 if result.id == "tool_call_1"
1795 && result.call_id.as_deref() == Some("call_1")
1796 && result.content.iter().any(|content| matches!(
1797 content,
1798 crate::message::ToolResultContent::Text(text)
1799 if text.text == super::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
1800 ))
1801 ))
1802 && content.iter().any(|item| matches!(
1803 item,
1804 UserContent::ToolResult(result)
1805 if result.id == "tool_call_2"
1806 && result.call_id.as_deref() == Some("call_2")
1807 && result.content.iter().any(|content| matches!(
1808 content,
1809 crate::message::ToolResultContent::Text(text)
1810 if text.text.contains("Use one of these tools instead")
1811 ))
1812 ))
1813 ));
1814 }
1815
1816 #[tokio::test]
1817 async fn invalid_tool_call_hook_skips_mixed_non_streaming_turn_without_executing_valid_call() {
1818 let add_calls = Arc::new(AtomicU32::new(0));
1819 let mut valid_tool_call = ToolCall::new(
1820 "tool_call_1".to_string(),
1821 ToolFunction::new("add".to_string(), json!({"x": 2, "y": 3})),
1822 );
1823 valid_tool_call.call_id = Some("call_1".to_string());
1824 let mut invalid_tool_call = ToolCall::new(
1825 "tool_call_2".to_string(),
1826 ToolFunction::new("default_api".to_string(), json!({"x": 4, "y": 5})),
1827 );
1828 invalid_tool_call.call_id = Some("call_2".to_string());
1829 let model = MockCompletionModel::new([
1830 MockTurn::from_contents([
1831 AssistantContent::ToolCall(valid_tool_call),
1832 AssistantContent::ToolCall(invalid_tool_call),
1833 ])
1834 .expect("tool-call response should be non-empty"),
1835 MockTurn::text("skipped"),
1836 ]);
1837 let agent = AgentBuilder::new(model)
1838 .tool(CountingAddTool {
1839 calls: add_calls.clone(),
1840 })
1841 .build();
1842
1843 let response = agent
1844 .prompt("add")
1845 .with_hook(SkipDefaultApiAndPanicOnToolCallHook)
1846 .max_turns(3)
1847 .extended_details()
1848 .await
1849 .expect("skip should recover without executing peer tools");
1850
1851 assert_eq!(response.output, "skipped");
1852 assert_eq!(add_calls.load(Ordering::SeqCst), 0);
1853 let messages = response.messages.expect("messages should be present");
1854 assert!(history_contains_tool_call(&messages, "add"));
1855 assert!(history_contains_tool_call(&messages, "default_api"));
1856 assert!(matches!(
1857 messages.get(2),
1858 Some(Message::User { content })
1859 if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
1860 && content.iter().any(|item| matches!(
1861 item,
1862 UserContent::ToolResult(result)
1863 if result.id == "tool_call_1"
1864 && result.call_id.as_deref() == Some("call_1")
1865 && result.content.iter().any(|content| matches!(
1866 content,
1867 crate::message::ToolResultContent::Text(text)
1868 if text.text == super::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
1869 ))
1870 ))
1871 && content.iter().any(|item| matches!(
1872 item,
1873 UserContent::ToolResult(result)
1874 if result.id == "tool_call_2"
1875 && result.call_id.as_deref() == Some("call_2")
1876 && result.content.iter().any(|content| matches!(
1877 content,
1878 crate::message::ToolResultContent::Text(text)
1879 if text.text == "default_api is not available"
1880 ))
1881 ))
1882 ));
1883 }
1884
1885 #[tokio::test]
1886 async fn invalid_tool_call_hook_retry_budget_exhaustion_fails() {
1887 let model = MockCompletionModel::new([
1888 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1889 MockTurn::text("should not be requested"),
1890 ]);
1891 let recorded = model.clone();
1892 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1893
1894 let err = agent
1895 .prompt("add")
1896 .with_hook(RetryDefaultApiHook)
1897 .max_invalid_tool_call_retries(0)
1898 .max_turns(3)
1899 .await
1900 .expect_err("retry without budget should fail");
1901
1902 match err {
1903 PromptError::UnknownToolCall {
1904 tool_name,
1905 chat_history,
1906 ..
1907 } => {
1908 assert_eq!(tool_name, "default_api");
1909 assert!(history_contains_tool_call(&chat_history, "default_api"));
1910 }
1911 other => panic!("expected UnknownToolCall, got {other:?}"),
1912 }
1913 assert_eq!(recorded.request_count(), 1);
1914 }
1915
1916 #[tokio::test]
1917 async fn invalid_tool_call_hook_can_skip_structured_non_streaming_call() {
1918 let model = MockCompletionModel::new([
1919 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1920 MockTurn::text("skipped"),
1921 ]);
1922 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1923
1924 let response = agent
1925 .prompt("add")
1926 .with_hook(SkipDefaultApiHook)
1927 .max_turns(3)
1928 .extended_details()
1929 .await
1930 .expect("skip should continue with synthetic tool result");
1931
1932 assert_eq!(response.output, "skipped");
1933 let messages = response.messages.expect("messages should be present");
1934 assert!(history_contains_tool_call(&messages, "default_api"));
1935 assert!(messages.iter().any(|message| {
1936 matches!(
1937 message,
1938 Message::User { content }
1939 if content.iter().any(|content| {
1940 matches!(
1941 content,
1942 UserContent::ToolResult(result)
1943 if result.content.iter().any(|content| {
1944 matches!(
1945 content,
1946 crate::message::ToolResultContent::Text(text)
1947 if text.text == "default_api is not available"
1948 )
1949 })
1950 )
1951 })
1952 )
1953 }));
1954 }
1955
1956 #[tokio::test]
1957 async fn skip_under_specific_tool_choice_returns_synthetic_feedback() {
1958 let model = MockCompletionModel::new([
1959 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1960 MockTurn::text("skipped"),
1961 ]);
1962 let agent = AgentBuilder::new(model)
1963 .tool(MockAddTool)
1964 .tool_choice(ToolChoice::Specific {
1965 function_names: vec!["add".to_string()],
1966 })
1967 .build();
1968
1969 let response = agent
1970 .prompt("add")
1971 .with_hook(SkipDefaultApiHook)
1972 .max_turns(3)
1973 .extended_details()
1974 .await
1975 .expect("skip should produce synthetic feedback under Specific");
1976
1977 assert_eq!(response.output, "skipped");
1978 let messages = response.messages.expect("messages should be present");
1979 assert!(history_contains_tool_call(&messages, "default_api"));
1980 assert!(messages.iter().any(|message| {
1981 matches!(
1982 message,
1983 Message::User { content }
1984 if content.iter().any(|content| {
1985 matches!(
1986 content,
1987 UserContent::ToolResult(result)
1988 if result.id == "tool_call_1"
1989 && result.content.iter().any(|content| {
1990 matches!(
1991 content,
1992 crate::message::ToolResultContent::Text(text)
1993 if text.text == "default_api is not available"
1994 )
1995 })
1996 )
1997 })
1998 )
1999 }));
2000 }
2001
2002 #[tokio::test]
2003 async fn repair_to_disallowed_specific_tool_fails() {
2004 let model = MockCompletionModel::new([
2005 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2006 MockTurn::text("should not be requested"),
2007 ]);
2008 let recorded = model.clone();
2009 let agent = AgentBuilder::new(model)
2010 .tool(MockAddTool)
2011 .tool(MockSubtractTool)
2012 .tool_choice(ToolChoice::Specific {
2013 function_names: vec!["add".to_string()],
2014 })
2015 .build();
2016
2017 let err = agent
2018 .prompt("add")
2019 .with_hook(RepairToSubtractHook)
2020 .max_turns(3)
2021 .await
2022 .expect_err("repair to a disallowed tool should fail");
2023
2024 match err {
2025 PromptError::UnknownToolCall { tool_name, .. } => {
2026 assert_eq!(tool_name, "subtract");
2027 }
2028 other => panic!("expected UnknownToolCall, got {other:?}"),
2029 }
2030 assert_eq!(recorded.request_count(), 1);
2031 }
2032
2033 #[tokio::test]
2034 async fn repair_under_tool_choice_none_fails() {
2035 let model = MockCompletionModel::new([
2036 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2037 MockTurn::text("should not be requested"),
2038 ]);
2039 let recorded = model.clone();
2040 let agent = AgentBuilder::new(model)
2041 .tool(MockAddTool)
2042 .tool_choice(ToolChoice::None)
2043 .build();
2044
2045 let err = agent
2046 .prompt("do not use tools")
2047 .with_hook(RepairDefaultApiHook)
2048 .max_turns(3)
2049 .await
2050 .expect_err("ToolChoice::None should reject repaired tool calls");
2051
2052 match err {
2053 PromptError::UnknownToolCall { tool_name, .. } => {
2054 assert_eq!(tool_name, "add");
2055 }
2056 other => panic!("expected UnknownToolCall, got {other:?}"),
2057 }
2058 assert_eq!(recorded.request_count(), 1);
2059 }
2060
2061 #[tokio::test]
2062 async fn skip_under_tool_choice_none_fails() {
2063 let model = MockCompletionModel::new([
2064 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2065 MockTurn::text("should not be requested"),
2066 ]);
2067 let recorded = model.clone();
2068 let agent = AgentBuilder::new(model)
2069 .tool(MockAddTool)
2070 .tool_choice(ToolChoice::None)
2071 .build();
2072
2073 let err = agent
2074 .prompt("do not use tools")
2075 .with_hook(SkipDefaultApiHook)
2076 .max_turns(3)
2077 .await
2078 .expect_err("ToolChoice::None should reject skipped tool calls");
2079
2080 match err {
2081 PromptError::UnknownToolCall { tool_name, .. } => {
2082 assert_eq!(tool_name, "default_api");
2083 }
2084 other => panic!("expected UnknownToolCall, got {other:?}"),
2085 }
2086 assert_eq!(recorded.request_count(), 1);
2087 }
2088
2089 #[tokio::test]
2090 async fn typed_prompt_default_invalid_tool_call_fails_fast() {
2091 let model = MockCompletionModel::new([
2092 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2093 MockTurn::text(r#"{"value":"should not be requested"}"#),
2094 ]);
2095 let recorded = model.clone();
2096 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2097
2098 let err = agent
2099 .prompt_typed::<TypedAnswer>("return typed json")
2100 .with_hook(PanicOnUnknownToolHook)
2101 .max_turns(3)
2102 .await
2103 .expect_err("typed prompt should preserve fail-fast default");
2104
2105 match err {
2106 StructuredOutputError::PromptError(err) => match *err {
2107 PromptError::UnknownToolCall { tool_name, .. } => {
2108 assert_eq!(tool_name, "default_api");
2109 }
2110 other => panic!("expected UnknownToolCall, got {other:?}"),
2111 },
2112 other => panic!("expected prompt error, got {other:?}"),
2113 }
2114 assert_eq!(recorded.request_count(), 1);
2115 }
2116
2117 #[tokio::test]
2118 async fn typed_prompt_invalid_tool_call_hook_can_repair_tool_name() {
2119 let model = MockCompletionModel::new([
2120 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2121 MockTurn::text(r#"{"value":"repaired"}"#),
2122 ]);
2123 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2124
2125 let response = agent
2126 .prompt_typed::<TypedAnswer>("return typed json")
2127 .with_hook(RepairDefaultApiHook)
2128 .max_turns(3)
2129 .await
2130 .expect("typed prompt should repair invalid tool call");
2131
2132 assert_eq!(
2133 response,
2134 TypedAnswer {
2135 value: "repaired".to_string()
2136 }
2137 );
2138 }
2139
2140 #[tokio::test]
2141 async fn typed_prompt_invalid_tool_call_hook_can_retry_and_parse_response() {
2142 let model = MockCompletionModel::new([
2143 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2144 MockTurn::text(r#"{"value":"retried"}"#),
2145 ]);
2146 let recorded = model.clone();
2147 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2148
2149 let response = agent
2150 .prompt_typed::<TypedAnswer>("return typed json")
2151 .with_hook(RetryDefaultApiHook)
2152 .max_invalid_tool_call_retries(1)
2153 .max_turns(3)
2154 .await
2155 .expect("typed prompt should retry invalid tool call");
2156
2157 assert_eq!(
2158 response,
2159 TypedAnswer {
2160 value: "retried".to_string()
2161 }
2162 );
2163 assert_eq!(recorded.request_count(), 2);
2164 }
2165
2166 #[tokio::test]
2167 async fn typed_prompt_invalid_tool_call_retry_budget_exhaustion_fails() {
2168 let model = MockCompletionModel::new([
2169 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2170 MockTurn::text(r#"{"value":"should not be requested"}"#),
2171 ]);
2172 let recorded = model.clone();
2173 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2174
2175 let err = agent
2176 .prompt_typed::<TypedAnswer>("return typed json")
2177 .with_hook(RetryDefaultApiHook)
2178 .max_invalid_tool_call_retries(0)
2179 .max_turns(3)
2180 .await
2181 .expect_err("typed prompt should fail when retry budget is exhausted");
2182
2183 match err {
2184 StructuredOutputError::PromptError(err) => match *err {
2185 PromptError::UnknownToolCall { tool_name, .. } => {
2186 assert_eq!(tool_name, "default_api");
2187 }
2188 other => panic!("expected UnknownToolCall, got {other:?}"),
2189 },
2190 other => panic!("expected prompt error, got {other:?}"),
2191 }
2192 assert_eq!(recorded.request_count(), 1);
2193 }
2194
2195 #[tokio::test]
2196 async fn invalid_specific_tool_choice_fails_before_non_streaming_provider_request() {
2197 let model = MockCompletionModel::text("should not be requested");
2198 let recorded = model.clone();
2199 let agent = AgentBuilder::new(model)
2200 .tool(MockAddTool)
2201 .tool_choice(ToolChoice::Specific {
2202 function_names: vec!["missing".to_string()],
2203 })
2204 .build();
2205
2206 let err = agent
2207 .prompt("use the missing tool")
2208 .await
2209 .expect_err("invalid ToolChoice::Specific should fail before provider request");
2210
2211 match err {
2212 PromptError::CompletionError(CompletionError::RequestError(err)) => {
2213 let msg = err.to_string();
2214 assert!(msg.contains("missing"), "got: {msg}");
2215 assert!(msg.contains("add"), "got: {msg}");
2216 }
2217 other => panic!("expected CompletionError::RequestError, got {other:?}"),
2218 }
2219 assert_eq!(recorded.request_count(), 0);
2220 }
2221
2222 #[tokio::test]
2223 async fn allowed_specific_tool_call_executes_normally() {
2224 let model = MockCompletionModel::new([
2225 MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})),
2226 MockTurn::text("done"),
2227 ]);
2228 let recorded = model.clone();
2229 let agent = AgentBuilder::new(model)
2230 .tool(MockAddTool)
2231 .tool_choice(ToolChoice::Specific {
2232 function_names: vec!["add".to_string()],
2233 })
2234 .build();
2235
2236 let response = agent
2237 .prompt("use the allowed tool")
2238 .max_turns(3)
2239 .await
2240 .expect("allowed specific tool should execute");
2241
2242 assert_eq!(response, "done");
2243 assert_eq!(recorded.request_count(), 2);
2244 }
2245
2246 #[tokio::test]
2247 async fn prompt_request_stops_cleanly_on_empty_terminal_turn() {
2248 let first_call_usage = Usage {
2249 input_tokens: 1,
2250 output_tokens: 1,
2251 total_tokens: 2,
2252 cached_input_tokens: 0,
2253 cache_creation_input_tokens: 0,
2254 tool_use_prompt_tokens: 0,
2255 reasoning_tokens: 0,
2256 };
2257 let second_call_usage = Usage {
2258 input_tokens: 1,
2259 output_tokens: 1,
2260 total_tokens: 2,
2261 cached_input_tokens: 0,
2262 cache_creation_input_tokens: 0,
2263 tool_use_prompt_tokens: 0,
2264 reasoning_tokens: 0,
2265 };
2266 let model = MockCompletionModel::new([
2267 MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2}))
2268 .with_call_id("call_1")
2269 .with_usage(first_call_usage),
2270 MockTurn::text("").with_usage(second_call_usage),
2271 ]);
2272 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2273
2274 let response = agent
2275 .prompt("do tool work")
2276 .max_turns(3)
2277 .extended_details()
2278 .await
2279 .expect("empty terminal turn should not error");
2280
2281 assert!(response.output.is_empty());
2282 assert_eq!(
2283 response.usage,
2284 Usage {
2285 input_tokens: 2,
2286 output_tokens: 2,
2287 total_tokens: 4,
2288 cached_input_tokens: 0,
2289 cache_creation_input_tokens: 0,
2290 tool_use_prompt_tokens: 0,
2291 reasoning_tokens: 0,
2292 }
2293 );
2294 assert_eq!(
2295 response.completion_calls(),
2296 &[
2297 CompletionCall::new(0, first_call_usage),
2298 CompletionCall::new(1, second_call_usage)
2299 ]
2300 );
2301
2302 let history = response
2303 .messages
2304 .expect("extended response should include history");
2305 assert_eq!(history.len(), 3);
2306 assert!(matches!(
2307 history.first(),
2308 Some(Message::User { content })
2309 if matches!(
2310 content.first(),
2311 UserContent::Text(text) if text.text == "do tool work"
2312 )
2313 ));
2314 assert!(history.iter().any(|message| matches!(
2315 message,
2316 Message::Assistant { content, .. }
2317 if matches!(
2318 content.first(),
2319 AssistantContent::ToolCall(tool_call)
2320 if tool_call.id == "tool_call_1"
2321 && tool_call.call_id.as_deref() == Some("call_1")
2322 )
2323 )));
2324 assert!(history.iter().any(|message| matches!(
2325 message,
2326 Message::User { content }
2327 if matches!(
2328 content.first(),
2329 UserContent::ToolResult(tool_result)
2330 if tool_result.id == "tool_call_1"
2331 && tool_result.call_id.as_deref() == Some("call_1")
2332 )
2333 )));
2334 assert!(!history.iter().any(|message| matches!(
2335 message,
2336 Message::Assistant { content, .. }
2337 if content.iter().any(|item| matches!(
2338 item,
2339 AssistantContent::Text(text) if text.text.is_empty()
2340 ))
2341 )));
2342 let requests = agent.model.requests();
2343 assert_eq!(requests.len(), 2);
2344 validate_follow_up_tool_history(&requests[1]);
2345 }
2346
2347 #[tokio::test]
2348 async fn prompt_request_concatenates_text_blocks_without_inserted_newlines() {
2349 let model = MockCompletionModel::new([MockTurn::from_contents([
2350 AssistantContent::Text(Text::new("According to the document, ")),
2351 AssistantContent::Text(Text::new("the grass is green")),
2352 AssistantContent::Text(Text::new(" and the sky is blue.")),
2353 ])
2354 .expect("mock response should contain text blocks")]);
2355 let agent = AgentBuilder::new(model).build();
2356
2357 let response = agent
2358 .prompt("answer with cited spans")
2359 .await
2360 .expect("prompt should succeed");
2361
2362 assert_eq!(
2363 response,
2364 "According to the document, the grass is green and the sky is blue."
2365 );
2366 }
2367
2368 #[tokio::test]
2369 async fn prompt_request_preserves_metadata_only_text_turn_in_history() {
2370 let metadata = json!({
2371 "citations": [{
2372 "type": "web_search_result_location",
2373 "cited_text": "Claude Shannon was born in 1916.",
2374 "url": "https://example.com/shannon",
2375 "title": null,
2376 "encrypted_index": "encrypted-reference"
2377 }]
2378 });
2379 let model =
2380 MockCompletionModel::new([MockTurn::from_content(AssistantContent::Text(Text {
2381 text: String::new(),
2382 additional_params: Some(metadata.clone()),
2383 }))]);
2384 let agent = AgentBuilder::new(model).build();
2385
2386 let response = agent
2387 .prompt("answer with cited metadata")
2388 .extended_details()
2389 .await
2390 .expect("metadata-only text turn should succeed");
2391
2392 assert!(response.output.is_empty());
2393 let history = response
2394 .messages
2395 .expect("extended response should include history");
2396 assert!(history.iter().any(|message| matches!(
2397 message,
2398 Message::Assistant { content, .. }
2399 if matches!(
2400 content.first(),
2401 AssistantContent::Text(text)
2402 if text.text.is_empty()
2403 && text.additional_params.as_ref() == Some(&metadata)
2404 )
2405 )));
2406 }
2407
2408 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
2411
2412 #[tokio::test]
2413 async fn memory_loads_into_request_history() {
2414 let memory = InMemoryConversationMemory::new();
2415 memory
2416 .append(
2417 "thread-1",
2418 vec![Message::user("hello"), Message::assistant("hi there")],
2419 )
2420 .await
2421 .unwrap();
2422
2423 let model = MockCompletionModel::text("ack");
2424 let recorded = model.clone();
2425
2426 let agent = AgentBuilder::new(model).memory(memory).build();
2427 let _ = agent
2428 .prompt("ping")
2429 .conversation("thread-1")
2430 .await
2431 .expect("prompt should succeed");
2432
2433 let received = recorded.requests()[0]
2434 .chat_history
2435 .iter()
2436 .cloned()
2437 .collect::<Vec<_>>();
2438 assert_eq!(
2439 received.len(),
2440 3,
2441 "loaded memory (2) + current prompt should appear in request: {received:?}"
2442 );
2443 }
2444
2445 #[tokio::test]
2446 async fn memory_appends_full_turn_after_success() {
2447 let memory = InMemoryConversationMemory::new();
2448 let model = MockCompletionModel::text("ack");
2449 let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2450
2451 let _ = agent
2452 .prompt("hello")
2453 .conversation("t1")
2454 .await
2455 .expect("prompt should succeed");
2456
2457 let stored = memory.load("t1").await.unwrap();
2458 assert_eq!(stored.len(), 2, "user prompt + assistant response saved");
2459 }
2460
2461 #[tokio::test]
2462 async fn explicit_with_history_overrides_memory() {
2463 let memory = CountingMemory::default();
2464 memory
2465 .inner()
2466 .append("t1", vec![Message::user("from-memory")])
2467 .await
2468 .unwrap();
2469
2470 let model = MockCompletionModel::text("ack");
2471 let recorded = model.clone();
2472
2473 let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2474 let _ = agent
2475 .prompt("hello")
2476 .conversation("t1")
2477 .with_history(vec![Message::user("from-caller")])
2478 .await
2479 .expect("prompt should succeed");
2480
2481 assert_eq!(memory.load_count(), 0, "load skipped");
2482 let appends = memory.append_count();
2483 assert_eq!(appends, 0, "append skipped");
2484
2485 let received = recorded.requests()[0]
2486 .chat_history
2487 .iter()
2488 .cloned()
2489 .collect::<Vec<_>>();
2490 assert_eq!(received.len(), 2, "caller history (1) + current prompt");
2491 assert!(matches!(
2492 received.first(),
2493 Some(Message::User { content })
2494 if matches!(content.first(), UserContent::Text(t) if t.text == "from-caller")
2495 ));
2496 }
2497
2498 #[tokio::test]
2499 async fn memory_unchanged_on_provider_error() {
2500 let memory = InMemoryConversationMemory::new();
2501 let model = MockCompletionModel::new([MockTurn::error("boom")]);
2502
2503 let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2504 let result = agent.prompt("hello").conversation("t1").await;
2505 assert!(result.is_err());
2506
2507 let stored = memory.load("t1").await.unwrap();
2508 assert!(stored.is_empty(), "no append on error");
2509 }
2510
2511 #[tokio::test]
2512 async fn missing_conversation_id_behaves_as_no_memory() {
2513 let memory = CountingMemory::default();
2514 let model = MockCompletionModel::text("ack");
2515 let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2516
2517 let _ = agent.prompt("hello").await.expect("prompt should succeed");
2518
2519 assert_eq!(memory.load_count(), 0);
2520 assert_eq!(memory.append_count(), 0);
2521 }
2522
2523 #[tokio::test]
2524 async fn default_conversation_id_is_used_when_none_per_request() {
2525 let memory = InMemoryConversationMemory::new();
2526 let model = MockCompletionModel::text("ack");
2527 let agent = AgentBuilder::new(model)
2528 .memory(memory.clone())
2529 .conversation_id("default-thread")
2530 .build();
2531
2532 let _ = agent.prompt("hello").await.expect("prompt should succeed");
2533 let stored = memory.load("default-thread").await.unwrap();
2534 assert_eq!(stored.len(), 2);
2535 }
2536
2537 #[tokio::test]
2538 async fn with_filter_truncates_loaded_history() {
2539 let memory = InMemoryConversationMemory::new()
2540 .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
2541 memory
2542 .append(
2543 "t1",
2544 vec![
2545 Message::user("1"),
2546 Message::assistant("2"),
2547 Message::user("3"),
2548 Message::assistant("4"),
2549 ],
2550 )
2551 .await
2552 .unwrap();
2553
2554 let model = MockCompletionModel::text("ack");
2555 let recorded = model.clone();
2556 let agent = AgentBuilder::new(model).memory(memory).build();
2557
2558 let _ = agent
2559 .prompt("ping")
2560 .conversation("t1")
2561 .await
2562 .expect("prompt should succeed");
2563
2564 let received = recorded.requests()[0]
2565 .chat_history
2566 .iter()
2567 .cloned()
2568 .collect::<Vec<_>>();
2569 assert_eq!(
2570 received.len(),
2571 3,
2572 "window-truncated history (2) + current prompt"
2573 );
2574 }
2575
2576 #[tokio::test]
2577 async fn without_memory_disables_for_request() {
2578 let memory = CountingMemory::default();
2579 let model = MockCompletionModel::text("ack");
2580 let agent = AgentBuilder::new(model)
2581 .memory(memory.clone())
2582 .conversation_id("t1")
2583 .build();
2584
2585 let _ = agent
2586 .prompt("hello")
2587 .without_memory()
2588 .await
2589 .expect("prompt should succeed");
2590
2591 assert_eq!(memory.load_count(), 0);
2592 assert_eq!(memory.append_count(), 0);
2593 }
2594
2595 #[tokio::test]
2596 async fn memory_load_error_surfaces_as_prompt_error() {
2597 let model = MockCompletionModel::text("ack");
2598 let agent = AgentBuilder::new(model)
2599 .memory(FailingMemory::default())
2600 .build();
2601 let result = agent.prompt("hello").conversation("t1").await;
2602
2603 match result {
2604 Err(PromptError::CompletionError(CompletionError::RequestError(err))) => {
2605 let msg = format!("{err}");
2606 assert!(msg.contains("load boom"), "got: {msg}");
2607 }
2608 other => panic!("expected PromptError::CompletionError(RequestError), got {other:?}"),
2609 }
2610 }
2611
2612 #[tokio::test]
2613 async fn memory_append_error_does_not_drop_response() {
2614 let model = MockCompletionModel::text("ack");
2615 let agent = AgentBuilder::new(model)
2616 .memory(AppendFailingMemory::default())
2617 .build();
2618 let response: String = agent
2619 .prompt("hello")
2620 .conversation("t1")
2621 .await
2622 .expect("append failure must not block successful completion");
2623
2624 assert!(!response.is_empty());
2625 }
2626}