1pub mod hooks;
2pub mod streaming;
3
4use super::{
5 Agent,
6 completion::{DynamicContextStore, build_prepared_completion_request},
7};
8use crate::{
9 OneOrMany,
10 completion::{CompletionModel, Document, Message, PromptError, Usage},
11 json_utils,
12 memory::ConversationMemory,
13 message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
14 tool::server::ToolServerHandle,
15 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
16};
17use futures::{StreamExt, stream};
18use hooks::{
19 HookAction, InvalidToolCallContext, InvalidToolCallHookAction, PromptHook, ToolCallHookAction,
20};
21use serde::{Deserialize, Serialize};
22use std::{
23 collections::{BTreeMap, BTreeSet},
24 future::IntoFuture,
25 marker::PhantomData,
26 sync::{
27 Arc,
28 atomic::{AtomicU64, Ordering},
29 },
30};
31use tracing::info_span;
32use tracing::{Instrument, span::Id};
33
/// Marker trait selecting what awaiting a prompt request yields.
pub trait PromptType {}

/// Marker: awaiting the request yields only the final output text (`String`).
pub struct Standard;
impl PromptType for Standard {}

/// Marker: awaiting the request yields a detailed [`PromptResponse`].
pub struct Extended;
impl PromptType for Extended {}
40
41pub struct PromptRequest<S, M, P>
50where
51 S: PromptType,
52 M: CompletionModel,
53 P: PromptHook<M>,
54{
55 prompt: Message,
57 chat_history: Option<Vec<Message>>,
59 max_turns: usize,
61
62 model: Arc<M>,
65 agent_name: Option<String>,
67 preamble: Option<String>,
69 static_context: Vec<Document>,
71 temperature: Option<f64>,
73 max_tokens: Option<u64>,
75 additional_params: Option<serde_json::Value>,
77 tool_server_handle: ToolServerHandle,
79 dynamic_context: DynamicContextStore,
81 tool_choice: Option<ToolChoice>,
83
84 state: PhantomData<S>,
86 hook: Option<P>,
88 max_invalid_tool_call_retries: usize,
90 concurrency: usize,
92 output_schema: Option<schemars::Schema>,
94 memory: Option<Arc<dyn ConversationMemory>>,
96 conversation_id: Option<String>,
98}
99
100impl<M, P> PromptRequest<Standard, M, P>
101where
102 M: CompletionModel,
103 P: PromptHook<M>,
104{
105 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
107 PromptRequest {
108 prompt: prompt.into(),
109 chat_history: None,
110 max_turns: agent.default_max_turns.unwrap_or_default(),
111 model: agent.model.clone(),
112 agent_name: agent.name.clone(),
113 preamble: agent.preamble.clone(),
114 static_context: agent.static_context.clone(),
115 temperature: agent.temperature,
116 max_tokens: agent.max_tokens,
117 additional_params: agent.additional_params.clone(),
118 tool_server_handle: agent.tool_server_handle.clone(),
119 dynamic_context: agent.dynamic_context.clone(),
120 tool_choice: agent.tool_choice.clone(),
121 state: PhantomData,
122 hook: agent.hook.clone(),
123 max_invalid_tool_call_retries: 0,
124 concurrency: 1,
125 output_schema: agent.output_schema.clone(),
126 memory: agent.memory.clone(),
127 conversation_id: agent.default_conversation_id.clone(),
128 }
129 }
130}
131
132impl<S, M, P> PromptRequest<S, M, P>
133where
134 S: PromptType,
135 M: CompletionModel,
136 P: PromptHook<M>,
137{
138 pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
145 PromptRequest {
146 prompt: self.prompt,
147 chat_history: self.chat_history,
148 max_turns: self.max_turns,
149 model: self.model,
150 agent_name: self.agent_name,
151 preamble: self.preamble,
152 static_context: self.static_context,
153 temperature: self.temperature,
154 max_tokens: self.max_tokens,
155 additional_params: self.additional_params,
156 tool_server_handle: self.tool_server_handle,
157 dynamic_context: self.dynamic_context,
158 tool_choice: self.tool_choice,
159 state: PhantomData,
160 hook: self.hook,
161 max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
162 concurrency: self.concurrency,
163 output_schema: self.output_schema,
164 memory: self.memory,
165 conversation_id: self.conversation_id,
166 }
167 }
168
169 pub fn max_turns(mut self, depth: usize) -> Self {
172 self.max_turns = depth;
173 self
174 }
175
176 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
179 self.concurrency = concurrency;
180 self
181 }
182
183 pub fn with_history<H, T>(mut self, history: H) -> Self
185 where
186 H: IntoIterator<Item = T>,
187 T: Into<Message>,
188 {
189 self.chat_history = Some(history.into_iter().map(Into::into).collect());
190 self
191 }
192
193 pub fn conversation(mut self, id: impl Into<String>) -> Self {
198 self.conversation_id = Some(id.into());
199 self
200 }
201
202 pub fn without_memory(mut self) -> Self {
206 self.memory = None;
207 self.conversation_id = None;
208 self
209 }
210
211 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
214 where
215 P2: PromptHook<M>,
216 {
217 PromptRequest {
218 prompt: self.prompt,
219 chat_history: self.chat_history,
220 max_turns: self.max_turns,
221 model: self.model,
222 agent_name: self.agent_name,
223 preamble: self.preamble,
224 static_context: self.static_context,
225 temperature: self.temperature,
226 max_tokens: self.max_tokens,
227 additional_params: self.additional_params,
228 tool_server_handle: self.tool_server_handle,
229 dynamic_context: self.dynamic_context,
230 tool_choice: self.tool_choice,
231 state: PhantomData,
232 hook: Some(hook),
233 max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
234 concurrency: self.concurrency,
235 output_schema: self.output_schema,
236 memory: self.memory,
237 conversation_id: self.conversation_id,
238 }
239 }
240
241 pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
245 self.max_invalid_tool_call_retries = retries;
246 self
247 }
248}
249
250impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
254where
255 M: CompletionModel + 'static,
256 P: PromptHook<M> + 'static,
257{
258 type Output = Result<String, PromptError>;
259 type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
260
261 fn into_future(self) -> Self::IntoFuture {
262 Box::pin(self.send())
263 }
264}
265
266impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
267where
268 M: CompletionModel + 'static,
269 P: PromptHook<M> + 'static,
270{
271 type Output = Result<PromptResponse, PromptError>;
272 type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
273
274 fn into_future(self) -> Self::IntoFuture {
275 Box::pin(self.send())
276 }
277}
278
279impl<M, P> PromptRequest<Standard, M, P>
280where
281 M: CompletionModel,
282 P: PromptHook<M>,
283{
284 async fn send(self) -> Result<String, PromptError> {
285 self.extended_details().send().await.map(|resp| resp.output)
286 }
287}
288
289#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
291#[non_exhaustive]
292pub struct CompletionCall {
293 pub call_index: usize,
295 #[serde(default)]
300 pub usage: Option<Usage>,
301}
302
303impl CompletionCall {
304 pub fn new(call_index: usize, usage: Option<Usage>) -> Self {
306 Self { call_index, usage }
307 }
308
309 pub(crate) fn from_reported_usage(call_index: usize, usage: Usage) -> Self {
310 Self::new(call_index, reported_usage(usage))
311 }
312}
313
314pub(crate) fn reported_usage(usage: Usage) -> Option<Usage> {
315 (usage != Usage::new()).then_some(usage)
316}
317
318#[derive(Debug, Clone, Serialize, Deserialize)]
319#[non_exhaustive]
320pub struct PromptResponse {
321 pub output: String,
322 pub usage: Usage,
323 #[serde(default, skip_serializing_if = "Vec::is_empty")]
330 pub completion_calls: Vec<CompletionCall>,
331 pub messages: Option<Vec<Message>>,
332}
333
334impl std::fmt::Display for PromptResponse {
335 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336 self.output.fmt(f)
337 }
338}
339
340impl PromptResponse {
341 pub fn new(output: impl Into<String>, usage: Usage) -> Self {
342 Self {
343 output: output.into(),
344 usage,
345 completion_calls: Vec::new(),
346 messages: None,
347 }
348 }
349
350 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
351 self.messages = Some(messages);
352 self
353 }
354
355 pub fn with_completion_calls(mut self, completion_calls: Vec<CompletionCall>) -> Self {
357 self.completion_calls = completion_calls;
358 self
359 }
360
361 pub fn completion_calls(&self) -> &[CompletionCall] {
365 &self.completion_calls
366 }
367}
368
369#[derive(Debug, Clone, Serialize, Deserialize)]
370#[non_exhaustive]
371pub struct TypedPromptResponse<T> {
372 pub output: T,
373 pub usage: Usage,
374 #[serde(default, skip_serializing_if = "Vec::is_empty")]
381 pub completion_calls: Vec<CompletionCall>,
382}
383
384impl<T> TypedPromptResponse<T> {
385 pub fn new(output: T, usage: Usage) -> Self {
386 Self {
387 output,
388 usage,
389 completion_calls: Vec::new(),
390 }
391 }
392
393 pub fn with_completion_calls(mut self, completion_calls: Vec<CompletionCall>) -> Self {
395 self.completion_calls = completion_calls;
396 self
397 }
398
399 pub fn completion_calls(&self) -> &[CompletionCall] {
403 &self.completion_calls
404 }
405}
406
/// Agent name reported in tracing spans when the agent has no name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

/// Tool-result text returned for valid tool calls that were not executed because another
/// tool call in the same assistant turn was invalid.
pub(crate) const TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER: &str =
    "Tool not executed because another tool call in the same assistant turn was invalid.";
411
412fn build_history_for_request(
414 chat_history: Option<&[Message]>,
415 new_messages: &[Message],
416) -> Vec<Message> {
417 let input = chat_history.unwrap_or(&[]);
418 input.iter().chain(new_messages.iter()).cloned().collect()
419}
420
421fn build_full_history(
423 chat_history: Option<&[Message]>,
424 new_messages: Vec<Message>,
425) -> Vec<Message> {
426 let input = chat_history.unwrap_or(&[]);
427 input.iter().cloned().chain(new_messages).collect()
428}
429
430fn tool_result_user_content(
431 id: String,
432 call_id: Option<String>,
433 tool_result: String,
434) -> UserContent {
435 let content = ToolResultContent::from_tool_output(tool_result);
436 match call_id {
437 Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
438 None => UserContent::tool_result(id, content),
439 }
440}
441
442fn invalid_tool_retry_user_message(
443 assistant_content: &OneOrMany<AssistantContent>,
444 invalid_tool_call_id: &str,
445 feedback: String,
446) -> Option<Message> {
447 let retry_results = assistant_content
448 .iter()
449 .filter_map(|content| match content {
450 AssistantContent::ToolCall(tool_call) if tool_call.id == invalid_tool_call_id => {
451 Some(tool_result_user_content(
452 tool_call.id.clone(),
453 tool_call.call_id.clone(),
454 feedback.clone(),
455 ))
456 }
457 AssistantContent::ToolCall(tool_call) => Some(tool_result_user_content(
458 tool_call.id.clone(),
459 tool_call.call_id.clone(),
460 TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
461 )),
462 _ => None,
463 })
464 .collect::<Vec<_>>();
465
466 Some(Message::User {
467 content: OneOrMany::from_iter_optional(retry_results)?,
468 })
469}
470
471pub(crate) fn validate_tool_call_name(
472 tool_name: &str,
473 executable_tool_names: &BTreeSet<String>,
474 allowed_tool_names: &BTreeSet<String>,
475 chat_history: Vec<Message>,
476) -> Result<(), PromptError> {
477 if allowed_tool_names.contains(tool_name) {
478 return Ok(());
479 }
480
481 Err(PromptError::UnknownToolCall {
482 tool_name: tool_name.to_owned(),
483 available_tools: executable_tool_names.iter().cloned().collect(),
484 allowed_tools: allowed_tool_names.iter().cloned().collect(),
485 chat_history: Box::new(chat_history),
486 })
487}
488
489enum InvalidToolCallResolution {
490 Fail(PromptError),
491 Retry(String),
492 Repair(String),
493 Skip(String),
494}
495
496#[allow(clippy::too_many_arguments)]
497async fn resolve_invalid_tool_call<M, P>(
498 hook: Option<&P>,
499 tool_name: &str,
500 tool_call_id: Option<String>,
501 internal_call_id: Option<String>,
502 args: Option<String>,
503 executable_tool_names: &BTreeSet<String>,
504 allowed_tool_names: &BTreeSet<String>,
505 tool_choice: Option<&ToolChoice>,
506 chat_history: Vec<Message>,
507 is_streaming: bool,
508) -> InvalidToolCallResolution
509where
510 M: CompletionModel,
511 P: PromptHook<M>,
512{
513 let err = PromptError::UnknownToolCall {
514 tool_name: tool_name.to_owned(),
515 available_tools: executable_tool_names.iter().cloned().collect(),
516 allowed_tools: allowed_tool_names.iter().cloned().collect(),
517 chat_history: Box::new(chat_history.clone()),
518 };
519
520 let Some(hook) = hook else {
521 return InvalidToolCallResolution::Fail(err);
522 };
523
524 let context = InvalidToolCallContext {
525 tool_name: tool_name.to_owned(),
526 tool_call_id,
527 internal_call_id,
528 args,
529 available_tools: executable_tool_names.iter().cloned().collect(),
530 allowed_tools: allowed_tool_names.iter().cloned().collect(),
531 tool_choice: tool_choice.cloned(),
532 chat_history,
533 is_streaming,
534 };
535
536 match hook.on_invalid_tool_call(&context).await {
537 InvalidToolCallHookAction::Fail => InvalidToolCallResolution::Fail(err),
538 InvalidToolCallHookAction::Retry { feedback } => InvalidToolCallResolution::Retry(feedback),
539 InvalidToolCallHookAction::Repair { tool_name } => {
540 if allowed_tool_names.contains(&tool_name) {
541 InvalidToolCallResolution::Repair(tool_name)
542 } else {
543 InvalidToolCallResolution::Fail(PromptError::UnknownToolCall {
544 tool_name,
545 available_tools: executable_tool_names.iter().cloned().collect(),
546 allowed_tools: allowed_tool_names.iter().cloned().collect(),
547 chat_history: Box::new(context.chat_history),
548 })
549 }
550 }
551 InvalidToolCallHookAction::Skip { reason } => {
552 if matches!(tool_choice, Some(ToolChoice::None)) {
553 InvalidToolCallResolution::Fail(err)
554 } else {
555 InvalidToolCallResolution::Skip(reason)
556 }
557 }
558 }
559}
560
561fn is_empty_assistant_turn(choice: &OneOrMany<AssistantContent>) -> bool {
562 choice.len() == 1
563 && matches!(
564 choice.first(),
565 AssistantContent::Text(text) if text.text.is_empty() && text.additional_params.is_none()
566 )
567}
568
569fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
570 choice
571 .iter()
572 .filter_map(|content| match content {
573 AssistantContent::Text(text) => Some(text.text.as_str()),
574 _ => None,
575 })
576 .collect()
577}
578
/// Agent loop for requests that produce a detailed [`PromptResponse`].
impl<M, P> PromptRequest<Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Agent name for tracing, falling back to `UNKNOWN_AGENT_NAME` when unnamed.
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Runs the multi-turn agent loop and returns the final text with usage details.
    ///
    /// Prior history comes from `with_history` if set. Otherwise, when both a memory backend
    /// and a conversation id are configured, it is loaded from memory. Each turn then does
    /// the following:
    /// 1. Lets the hook veto the completion call.
    /// 2. Builds and sends a completion request.
    /// 3. Records usage.
    /// 4. Resolves tool calls with disallowed names via `resolve_invalid_tool_call`.
    ///
    /// A turn without tool calls ends the loop. Its text is returned, and the new messages are
    /// appended to memory on a best-effort basis. Otherwise the tool calls are executed and
    /// their results become the next prompt.
    ///
    /// # Errors
    /// - Errors from loading memory, building the request, or the completion call (via `?`).
    /// - `PromptError::prompt_cancelled` when a hook terminates the prompt.
    /// - `PromptError::UnknownToolCall` for unresolved invalid tool calls, or when the retry
    ///   budget is exhausted.
    /// - `PromptError::MaxTurnsError` when the turn budget runs out.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Create a dedicated `invoke_agent` span only when no enabled span is current;
        // otherwise the agent-level fields are recorded onto the caller's current span.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent_name_for_span = self.agent_name.clone();
        // Explicit history wins. Memory is consulted only when both a backend and a
        // conversation id exist. `memory_handle` is kept so the final turn can be appended.
        let (chat_history, memory_handle) = match self.chat_history {
            Some(history) => (Some(history), None),
            None => match (self.memory, self.conversation_id) {
                (Some(memory), Some(id)) => {
                    let loaded = memory.load(&id).await?;
                    (Some(loaded), Some((memory, id)))
                }
                _ => (None, None),
            },
        };
        // Messages produced by this request, starting with the prompt. The last entry is
        // always the message sent on the next turn.
        let mut new_messages: Vec<Message> = vec![self.prompt.clone()];

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        let mut completion_calls = Vec::new();
        let mut completion_call_index = 0;
        // Counts invalid-tool-call retries across all turns of this request.
        let mut invalid_tool_call_retries = 0;
        // Id of the most recently created chat/tool span, used to link the next span to it
        // via `follows_from` (0 means "none yet").
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        // The loop only `break`s when the turn budget is exhausted. Its value is the prompt
        // that was not sent, reported in `MaxTurnsError`.
        let last_prompt = 'prompt_loop: loop {
            // The pending prompt is the last new message; everything before it is history.
            let Some((prompt_ref, history_for_current_turn)) = new_messages.split_last() else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "prompt loop lost its pending prompt",
                ));
            };
            let prompt = prompt_ref.clone();

            // NOTE(review): this check runs before the increment, so up to `max_turns + 2`
            // completion calls (retries included) are made before `MaxTurnsError`. Confirm
            // that bound is intended.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            let history_for_hook =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            // Give the hook a chance to cancel before the completion call is made.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_call(&prompt, &history_for_hook).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this chat span to the previous chat/tool span, if any.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let history_for_request =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            let prepared_request = build_prepared_completion_request(
                &self.model,
                prompt.clone(),
                &history_for_request,
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?;
            let executable_tool_names = prepared_request.executable_tool_names.clone();
            let allowed_tool_names = prepared_request.allowed_tool_names.clone();

            let resp = prepared_request
                .builder
                .send()
                .instrument(chat_span.clone())
                .await?;

            // Record every completion call, including ones later retried.
            completion_calls.push(CompletionCall::from_reported_usage(
                completion_call_index,
                resp.usage,
            ));
            completion_call_index += 1;
            usage += resp.usage;

            // Working copy of the response; repaired tool names are written into it.
            let mut response_choice = resp.choice.clone();
            let has_tool_calls = response_choice
                .iter()
                .any(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Pre-computed tool results (keyed by tool call id) for calls that must not run.
            let mut skipped_tool_results = BTreeMap::new();
            let mut invalid_tool_call_recovered = false;
            let mut invalid_tool_call_skipped = false;
            if has_tool_calls {
                for choice in response_choice.iter_mut() {
                    let AssistantContent::ToolCall(tool_call) = choice else {
                        continue;
                    };

                    if allowed_tool_names.contains(&tool_call.function.name) {
                        continue;
                    }

                    // History including the unmodified assistant response, for diagnostics.
                    let mut diagnostic_messages = new_messages.clone();
                    diagnostic_messages.push(Message::Assistant {
                        id: resp.message_id.clone(),
                        content: resp.choice.clone(),
                    });
                    let diagnostic_history =
                        build_full_history(chat_history.as_deref(), diagnostic_messages);
                    let args = json_utils::value_to_json_string(&tool_call.function.arguments);
                    let emitted_tool_name = tool_call.function.name.clone();

                    match resolve_invalid_tool_call::<M, P>(
                        self.hook.as_ref(),
                        &emitted_tool_name,
                        Some(tool_call.id.clone()),
                        None,
                        Some(args),
                        &executable_tool_names,
                        &allowed_tool_names,
                        self.tool_choice.as_ref(),
                        diagnostic_history.clone(),
                        false,
                    )
                    .await
                    {
                        InvalidToolCallResolution::Fail(err) => return Err(err),
                        InvalidToolCallResolution::Retry(feedback) => {
                            if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
                                return Err(PromptError::UnknownToolCall {
                                    tool_name: emitted_tool_name,
                                    available_tools: executable_tool_names
                                        .iter()
                                        .cloned()
                                        .collect(),
                                    allowed_tools: allowed_tool_names.iter().cloned().collect(),
                                    chat_history: Box::new(diagnostic_history.clone()),
                                });
                            }

                            // Record the original assistant turn plus a user message that
                            // answers every tool call (feedback for the invalid one), then
                            // start a new turn.
                            invalid_tool_call_retries += 1;
                            new_messages.push(Message::Assistant {
                                id: resp.message_id.clone(),
                                content: resp.choice.clone(),
                            });
                            let Some(user_message) = invalid_tool_retry_user_message(
                                &resp.choice,
                                &tool_call.id,
                                feedback,
                            ) else {
                                return Err(PromptError::prompt_cancelled(
                                    diagnostic_history,
                                    "invalid tool call retry produced no retry messages",
                                ));
                            };
                            new_messages.push(user_message);
                            continue 'prompt_loop;
                        }
                        InvalidToolCallResolution::Repair(repaired_name) => {
                            tool_call.function.name = repaired_name;
                            invalid_tool_call_recovered = true;
                        }
                        InvalidToolCallResolution::Skip(reason) => {
                            // The skip reason becomes this call's tool result.
                            let user_content = if let Some(call_id) = tool_call.call_id.clone() {
                                UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    OneOrMany::one(reason.into()),
                                )
                            } else {
                                UserContent::tool_result(
                                    tool_call.id.clone(),
                                    OneOrMany::one(reason.into()),
                                )
                            };
                            skipped_tool_results.insert(tool_call.id.clone(), user_content);
                            invalid_tool_call_recovered = true;
                            invalid_tool_call_skipped = true;
                        }
                    }
                }
            }

            // Once any call in this turn was skipped, none of its peers are executed either.
            // Each peer gets the "not executed" result unless it already has one.
            if invalid_tool_call_skipped {
                for choice in response_choice.iter() {
                    let AssistantContent::ToolCall(tool_call) = choice else {
                        continue;
                    };

                    skipped_tool_results
                        .entry(tool_call.id.clone())
                        .or_insert_with(|| {
                            tool_result_user_content(
                                tool_call.id.clone(),
                                tool_call.call_id.clone(),
                                TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
                            )
                        });
                }
            }

            // An empty single-text turn is not recorded in the conversation.
            let assistant_response_message =
                (!is_empty_assistant_turn(&response_choice)).then(|| Message::Assistant {
                    id: resp.message_id.clone(),
                    content: response_choice.clone(),
                });

            // The completion-response hook is not invoked for turns in which an invalid tool
            // call was repaired or skipped.
            if !invalid_tool_call_recovered
                && let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            if let Some(message) = assistant_response_message {
                new_messages.push(message);
            }

            // No tool calls: this is the final answer.
            if !has_tool_calls {
                let merged_texts = assistant_text_from_choice(&response_choice);

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                agent_span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    usage.cached_input_tokens,
                );
                agent_span.record(
                    "gen_ai.usage.cache_creation.input_tokens",
                    usage.cache_creation_input_tokens,
                );
                agent_span.record(
                    "gen_ai.usage.tool_use_prompt_tokens",
                    usage.tool_use_prompt_tokens,
                );
                agent_span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);

                // Best-effort: a failed memory append is logged, not returned.
                if let Some((memory, id)) = memory_handle.as_ref()
                    && let Err(err) = memory.append(id, new_messages.clone()).await
                {
                    tracing::warn!(
                        error = %err,
                        conversation_id = %id,
                        "conversation memory append failed; returning model response anyway"
                    );
                }

                return Ok(PromptResponse::new(merged_texts, usage)
                    .with_messages(new_messages)
                    .with_completion_calls(completion_calls));
            }

            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();
            let skipped_tool_results = Arc::new(skipped_tool_results);

            let full_history_for_errors =
                build_full_history(chat_history.as_deref(), new_messages.clone());

            let tool_calls: Vec<AssistantContent> = response_choice
                .iter()
                .filter(|choice| matches!(choice, AssistantContent::ToolCall(_)))
                .cloned()
                .collect();
            // Run up to `self.concurrency` tool calls at once. `buffer_unordered` yields
            // results in completion order, not call order.
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();
                    let skipped_tool_results = skipped_tool_results.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    let cloned_history_for_error = full_history_for_errors.clone();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            let internal_call_id = nanoid::nanoid!();
                            // Calls resolved as skipped (or blocked by a skipped peer) return
                            // their prepared result without running the tool or hooks.
                            if let Some(result) = skipped_tool_results.get(&tool_call.id) {
                                return Ok(result.clone());
                            }
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_history_for_error,
                                        reason,
                                    ));
                                }

                                // A hook-level skip returns the reason as the tool result.
                                if let ToolCallHookAction::Skip { reason } = action {
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Tool execution errors do not abort the prompt; the error text
                            // becomes the tool result sent back to the model.
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_history_for_error,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            // Unreachable in practice: `tool_calls` above only holds
                            // `ToolCall` items.
                            Err(PromptError::prompt_cancelled(
                                Vec::new(),
                                "tool execution received non-tool assistant content",
                            ))
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                // All tool futures finish before the first error (in completion order) is
                // propagated.
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            let Some(content) = OneOrMany::from_iter_optional(tool_content) else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "tool execution produced no tool results",
                ));
            };

            // The tool results become the prompt for the next turn.
            new_messages.push(Message::User { content });
        };

        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: build_full_history(chat_history.as_deref(), new_messages).into(),
            prompt: last_prompt.into(),
        })
    }
}
1092
1093use crate::completion::StructuredOutputError;
1098use schemars::{JsonSchema, schema_for};
1099use serde::de::DeserializeOwned;
1100
1101pub struct TypedPromptRequest<T, S, M, P>
1118where
1119 T: JsonSchema + DeserializeOwned + WasmCompatSend,
1120 S: PromptType,
1121 M: CompletionModel,
1122 P: PromptHook<M>,
1123{
1124 inner: PromptRequest<S, M, P>,
1125 _phantom: std::marker::PhantomData<T>,
1126}
1127
1128impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
1129where
1130 T: JsonSchema + DeserializeOwned + WasmCompatSend,
1131 M: CompletionModel,
1132 P: PromptHook<M>,
1133{
1134 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
1138 let mut inner = PromptRequest::from_agent(agent, prompt);
1139 inner.output_schema = Some(schema_for!(T));
1141 Self {
1142 inner,
1143 _phantom: std::marker::PhantomData,
1144 }
1145 }
1146}
1147
1148impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
1149where
1150 T: JsonSchema + DeserializeOwned + WasmCompatSend,
1151 S: PromptType,
1152 M: CompletionModel,
1153 P: PromptHook<M>,
1154{
1155 pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
1161 TypedPromptRequest {
1162 inner: self.inner.extended_details(),
1163 _phantom: std::marker::PhantomData,
1164 }
1165 }
1166
1167 pub fn max_turns(mut self, depth: usize) -> Self {
1173 self.inner = self.inner.max_turns(depth);
1174 self
1175 }
1176
1177 pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
1181 self.inner = self.inner.max_invalid_tool_call_retries(retries);
1182 self
1183 }
1184
1185 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
1189 self.inner = self.inner.with_tool_concurrency(concurrency);
1190 self
1191 }
1192
1193 pub fn with_history<H, U>(mut self, history: H) -> Self
1195 where
1196 H: IntoIterator<Item = U>,
1197 U: Into<Message>,
1198 {
1199 self.inner = self.inner.with_history(history);
1200 self
1201 }
1202
1203 pub fn conversation(mut self, id: impl Into<String>) -> Self {
1208 self.inner = self.inner.conversation(id);
1209 self
1210 }
1211
1212 pub fn without_memory(mut self) -> Self {
1216 self.inner = self.inner.without_memory();
1217 self
1218 }
1219
1220 pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
1224 where
1225 P2: PromptHook<M>,
1226 {
1227 TypedPromptRequest {
1228 inner: self.inner.with_hook(hook),
1229 _phantom: std::marker::PhantomData,
1230 }
1231 }
1232}
1233
1234impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
1235where
1236 T: JsonSchema + DeserializeOwned + WasmCompatSend,
1237 M: CompletionModel,
1238 P: PromptHook<M>,
1239{
1240 async fn send(self) -> Result<T, StructuredOutputError> {
1242 let response = self.inner.send().await.map_err(Box::new)?;
1243
1244 if response.is_empty() {
1245 return Err(StructuredOutputError::EmptyResponse);
1246 }
1247
1248 let parsed: T = serde_json::from_str(&response)?;
1249 Ok(parsed)
1250 }
1251}
1252
1253impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
1254where
1255 T: JsonSchema + DeserializeOwned + WasmCompatSend,
1256 M: CompletionModel,
1257 P: PromptHook<M>,
1258{
1259 async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
1261 let response = self.inner.send().await.map_err(Box::new)?;
1262
1263 if response.output.is_empty() {
1264 return Err(StructuredOutputError::EmptyResponse);
1265 }
1266
1267 let parsed: T = serde_json::from_str(&response.output)?;
1268 Ok(TypedPromptResponse::new(parsed, response.usage)
1269 .with_completion_calls(response.completion_calls))
1270 }
1271}
1272
1273impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
1274where
1275 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
1276 M: CompletionModel + 'static,
1277 P: PromptHook<M> + 'static,
1278{
1279 type Output = Result<T, StructuredOutputError>;
1280 type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1281
1282 fn into_future(self) -> Self::IntoFuture {
1283 Box::pin(self.send())
1284 }
1285}
1286
1287impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
1288where
1289 T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
1290 M: CompletionModel + 'static,
1291 P: PromptHook<M> + 'static,
1292{
1293 type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
1294 type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1295
1296 fn into_future(self) -> Self::IntoFuture {
1297 Box::pin(self.send())
1298 }
1299}
1300
1301#[cfg(test)]
1302mod tests {
1303 use super::{CompletionCall, PromptResponse, TypedPromptResponse};
1304 use crate::{
1305 agent::{
1306 AgentBuilder,
1307 prompt_request::hooks::{
1308 HookAction, InvalidToolCallContext, InvalidToolCallHookAction, PromptHook,
1309 ToolCallHookAction,
1310 },
1311 },
1312 completion::{
1313 AssistantContent, CompletionError, CompletionModel, CompletionRequest, Message, Prompt,
1314 PromptError, StructuredOutputError, ToolDefinition, TypedPrompt, Usage,
1315 },
1316 message::{Text, ToolCall, ToolChoice, ToolFunction, UserContent},
1317 test_utils::{
1318 AppendFailingMemory, CountingMemory, FailingMemory, MockAddTool, MockCompletionModel,
1319 MockOperationArgs, MockSubtractTool, MockToolError, MockTurn,
1320 },
1321 tool::Tool,
1322 };
1323 use schemars::JsonSchema;
1324 use serde::{Deserialize, Serialize};
1325 use serde_json::json;
1326 use std::sync::{
1327 Arc, Mutex,
1328 atomic::{AtomicU32, Ordering},
1329 };
1330
    /// Fixture implementing only `Serialize`; used to check that a
    /// `TypedPromptResponse` wrapping it can be serialized.
    #[derive(Serialize)]
    struct SerializeOnly {
        value: &'static str,
    }

    /// Fixture implementing only `Deserialize`; used to check that a
    /// `TypedPromptResponse` wrapping it can be deserialized.
    #[derive(Deserialize)]
    struct DeserializeOnly {
        value: String,
    }

    /// Structured-output target for the `prompt_typed` tests; the mock model
    /// returns JSON such as `{"value":"ok"}`.
    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
    struct TypedAnswer {
        value: String,
    }
1345
1346 #[derive(Clone)]
1347 struct PanicOnUnknownToolHook;
1348
1349 impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
1350 async fn on_completion_response(
1351 &self,
1352 _prompt: &Message,
1353 _response: &crate::completion::CompletionResponse<
1354 <MockCompletionModel as CompletionModel>::Response,
1355 >,
1356 ) -> HookAction {
1357 panic!("unknown tool response should fail before response hooks run")
1358 }
1359
1360 async fn on_tool_call(
1361 &self,
1362 _tool_name: &str,
1363 _tool_call_id: Option<String>,
1364 _internal_call_id: &str,
1365 _args: &str,
1366 ) -> ToolCallHookAction {
1367 panic!("unknown tool call should fail before tool hooks run")
1368 }
1369 }
1370
1371 #[derive(Clone)]
1372 struct PanicOnToolCallHook;
1373
1374 impl PromptHook<MockCompletionModel> for PanicOnToolCallHook {
1375 async fn on_tool_call(
1376 &self,
1377 _tool_name: &str,
1378 _tool_call_id: Option<String>,
1379 _internal_call_id: &str,
1380 _args: &str,
1381 ) -> ToolCallHookAction {
1382 panic!("recovered invalid turn should not invoke normal tool hooks")
1383 }
1384 }
1385
1386 #[derive(Clone)]
1387 struct SkipDefaultApiAndPanicOnToolCallHook;
1388
1389 impl PromptHook<MockCompletionModel> for SkipDefaultApiAndPanicOnToolCallHook {
1390 async fn on_invalid_tool_call(
1391 &self,
1392 context: &InvalidToolCallContext,
1393 ) -> InvalidToolCallHookAction {
1394 SkipDefaultApiHook.on_invalid_tool_call(context).await
1395 }
1396
1397 async fn on_tool_call(
1398 &self,
1399 tool_name: &str,
1400 tool_call_id: Option<String>,
1401 internal_call_id: &str,
1402 args: &str,
1403 ) -> ToolCallHookAction {
1404 PanicOnToolCallHook
1405 .on_tool_call(tool_name, tool_call_id, internal_call_id, args)
1406 .await
1407 }
1408 }
1409
1410 #[derive(Clone)]
1411 struct RepairDefaultApiHook;
1412
1413 impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
1414 fn on_invalid_tool_call(
1415 &self,
1416 context: &InvalidToolCallContext,
1417 ) -> impl std::future::Future<Output = InvalidToolCallHookAction> + Send {
1418 let tool_name = context.tool_name.clone();
1419 async move {
1420 assert_eq!(tool_name, "default_api");
1421 InvalidToolCallHookAction::repair("add")
1422 }
1423 }
1424 }
1425
1426 #[derive(Clone)]
1427 struct RepairToSubtractHook;
1428
1429 impl PromptHook<MockCompletionModel> for RepairToSubtractHook {
1430 async fn on_invalid_tool_call(
1431 &self,
1432 _context: &InvalidToolCallContext,
1433 ) -> InvalidToolCallHookAction {
1434 InvalidToolCallHookAction::repair("subtract")
1435 }
1436 }
1437
1438 #[derive(Clone)]
1439 struct RetryDefaultApiHook;
1440
1441 impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
1442 fn on_invalid_tool_call(
1443 &self,
1444 context: &InvalidToolCallContext,
1445 ) -> impl std::future::Future<Output = InvalidToolCallHookAction> + Send {
1446 let allowed_tools = context.allowed_tools.clone();
1447 async move {
1448 InvalidToolCallHookAction::retry(format!(
1449 "Use one of these tools instead: {allowed_tools:?}"
1450 ))
1451 }
1452 }
1453 }
1454
1455 #[derive(Clone)]
1456 struct SkipDefaultApiHook;
1457
1458 impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
1459 async fn on_invalid_tool_call(
1460 &self,
1461 _context: &InvalidToolCallContext,
1462 ) -> InvalidToolCallHookAction {
1463 InvalidToolCallHookAction::skip("default_api is not available")
1464 }
1465 }
1466
1467 #[derive(Clone, Default)]
1468 struct RecordingInvalidToolCallHook {
1469 contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
1470 }
1471
1472 impl RecordingInvalidToolCallHook {
1473 fn observed(&self) -> Vec<InvalidToolCallContext> {
1474 self.contexts
1475 .lock()
1476 .expect("invalid tool context records mutex was poisoned")
1477 .clone()
1478 }
1479 }
1480
1481 impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
1482 async fn on_invalid_tool_call(
1483 &self,
1484 context: &InvalidToolCallContext,
1485 ) -> InvalidToolCallHookAction {
1486 self.contexts
1487 .lock()
1488 .expect("invalid tool context records mutex was poisoned")
1489 .push(context.clone());
1490 InvalidToolCallHookAction::fail()
1491 }
1492 }
1493
1494 #[derive(Clone)]
1495 struct CountingAddTool {
1496 calls: Arc<AtomicU32>,
1497 }
1498
1499 impl Tool for CountingAddTool {
1500 const NAME: &'static str = "add";
1501 type Error = MockToolError;
1502 type Args = MockOperationArgs;
1503 type Output = i32;
1504
1505 async fn definition(&self, _prompt: String) -> ToolDefinition {
1506 MockAddTool.definition(String::new()).await
1507 }
1508
1509 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
1510 self.calls.fetch_add(1, Ordering::SeqCst);
1511 Ok(0)
1512 }
1513 }
1514
1515 fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
1516 Usage {
1517 input_tokens,
1518 output_tokens,
1519 total_tokens: input_tokens + output_tokens,
1520 cached_input_tokens: 0,
1521 cache_creation_input_tokens: 0,
1522 tool_use_prompt_tokens: 0,
1523 reasoning_tokens: 0,
1524 }
1525 }
1526
1527 #[test]
1528 fn typed_prompt_response_serializes_with_serialize_only_output() {
1529 let response = TypedPromptResponse::new(
1530 SerializeOnly { value: "ok" },
1531 Usage {
1532 input_tokens: 1,
1533 output_tokens: 2,
1534 total_tokens: 3,
1535 cached_input_tokens: 0,
1536 cache_creation_input_tokens: 0,
1537 tool_use_prompt_tokens: 0,
1538 reasoning_tokens: 0,
1539 },
1540 );
1541
1542 let json = serde_json::to_string(&response).expect("serialize typed prompt response");
1543 assert!(json.contains("\"value\":\"ok\""));
1544 }
1545
1546 #[test]
1547 fn typed_prompt_response_deserializes_with_deserialize_only_output() {
1548 let response: TypedPromptResponse<DeserializeOnly> = serde_json::from_str(
1549 r#"{"output":{"value":"ok"},"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3,"cached_input_tokens":0,"cache_creation_input_tokens":0,"reasoning_tokens":0}}"#,
1550 )
1551 .expect("deserialize typed prompt response");
1552
1553 assert_eq!(response.output.value, "ok");
1554 assert_eq!(response.usage.input_tokens, 1);
1555 assert_eq!(response.usage.output_tokens, 2);
1556 assert_eq!(response.usage.total_tokens, 3);
1557 }
1558
    /// Completion-call records serialize with `"usage": null` for a call that
    /// reported no usage. The JSON also round-trips back into equal records.
    #[test]
    fn prompt_response_serializes_completion_calls_with_missing_usage() {
        let reported_usage = usage(3, 4);
        let response = PromptResponse::new("ok", reported_usage).with_completion_calls(vec![
            CompletionCall::new(0, None),
            CompletionCall::new(1, Some(reported_usage)),
        ]);

        let value = serde_json::to_value(&response).expect("serialize prompt response");

        // Call 0 reported nothing, so its `usage` must appear as an explicit null.
        assert_eq!(
            value.get("completion_calls"),
            Some(&json!([
                {
                    "call_index": 0,
                    "usage": null,
                },
                {
                    "call_index": 1,
                    "usage": {
                        "input_tokens": 3,
                        "output_tokens": 4,
                        "total_tokens": 7,
                        "cached_input_tokens": 0,
                        "cache_creation_input_tokens": 0,
                        "tool_use_prompt_tokens": 0,
                        "reasoning_tokens": 0,
                    }
                }
            ]))
        );

        // Deserializing the same JSON must reproduce the original records.
        let response: PromptResponse =
            serde_json::from_value(value).expect("deserialize prompt response");
        assert_eq!(
            response.completion_calls(),
            &[
                CompletionCall::new(0, None),
                CompletionCall::new(1, Some(reported_usage))
            ]
        );
    }
1601
1602 #[tokio::test]
1603 async fn prompt_response_records_completion_call_without_reported_usage() {
1604 let model = MockCompletionModel::new([MockTurn::text("ok")]);
1605 let agent = AgentBuilder::new(model).build();
1606
1607 let response = agent
1608 .prompt("say ok")
1609 .extended_details()
1610 .await
1611 .expect("prompt should succeed");
1612
1613 assert_eq!(response.output, "ok");
1614 assert_eq!(response.usage, Usage::new());
1615 assert_eq!(response.completion_calls(), &[CompletionCall::new(0, None)]);
1616 }
1617
1618 #[tokio::test]
1619 async fn typed_prompt_response_preserves_completion_calls() {
1620 let call_usage = Usage {
1621 input_tokens: 4,
1622 output_tokens: 6,
1623 total_tokens: 10,
1624 cached_input_tokens: 0,
1625 cache_creation_input_tokens: 0,
1626 tool_use_prompt_tokens: 0,
1627 reasoning_tokens: 0,
1628 };
1629 let model =
1630 MockCompletionModel::new([MockTurn::text(r#"{"value":"ok"}"#).with_usage(call_usage)]);
1631 let agent = AgentBuilder::new(model).build();
1632
1633 let response = agent
1634 .prompt_typed::<TypedAnswer>("return typed json")
1635 .extended_details()
1636 .await
1637 .expect("typed prompt should succeed");
1638
1639 assert_eq!(
1640 response.output,
1641 TypedAnswer {
1642 value: "ok".to_string()
1643 }
1644 );
1645 assert_eq!(response.usage, call_usage);
1646 assert_eq!(
1647 response.completion_calls(),
1648 &[CompletionCall::new(0, Some(call_usage))]
1649 );
1650 }
1651
    /// Asserts that a follow-up request's chat history has exactly three
    /// entries, in order:
    /// 1. the user prompt `"do tool work"`;
    /// 2. an assistant tool call with id `tool_call_1` and call id `call_1`;
    /// 3. a user tool result carrying the same two ids.
    fn validate_follow_up_tool_history(request: &CompletionRequest) {
        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
        assert_eq!(
            history.len(),
            3,
            "follow-up request should contain the prompt, assistant tool call, and user tool result: {history:?}"
        );

        // Entry 0: the original user prompt.
        assert!(matches!(
            history.first(),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::Text(text) if text.text == "do tool work"
                )
        ));

        // Entry 1: the assistant turn that issued the tool call.
        assert!(matches!(
            history.get(1),
            Some(Message::Assistant { content, .. })
                if matches!(
                    content.first(),
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.call_id.as_deref() == Some("call_1")
                )
        ));

        // Entry 2: the tool result, which must echo both ids of the call.
        assert!(matches!(
            history.get(2),
            Some(Message::User { content })
                if matches!(
                    content.first(),
                    UserContent::ToolResult(tool_result)
                        if tool_result.id == "tool_call_1"
                            && tool_result.call_id.as_deref() == Some("call_1")
                )
        ));
    }
1691
1692 fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1693 history.iter().any(|message| {
1694 matches!(
1695 message,
1696 Message::Assistant { content, .. }
1697 if content.iter().any(|item| matches!(
1698 item,
1699 AssistantContent::ToolCall(tool_call)
1700 if tool_call.function.name == tool_name
1701 ))
1702 )
1703 })
1704 }
1705
    /// An unregistered tool name (`default_api`) emitted by the model fails the
    /// non-streaming prompt with `PromptError::UnknownToolCall`. This happens
    /// before any completion-response or tool-call hook runs (the hook would
    /// panic) and before a second model request is made.
    #[tokio::test]
    async fn unknown_tool_call_fails_before_non_streaming_second_request() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 1, "y": 2})),
            MockTurn::text("should not be requested"),
        ]);
        // Keep a handle on the mock so the request count can be checked afterwards.
        let recorded = model.clone();
        let agent = AgentBuilder::new(model).tool(MockAddTool).build();

        let err = agent
            .prompt("use the tool")
            .with_hook(PanicOnUnknownToolHook)
            .max_turns(3)
            .await
            .expect_err("unknown model-emitted tool should fail");

        match err {
            PromptError::UnknownToolCall {
                tool_name,
                available_tools,
                allowed_tools,
                chat_history,
            } => {
                assert_eq!(tool_name, "default_api");
                assert_eq!(available_tools, vec!["add".to_string()]);
                assert_eq!(allowed_tools, vec!["add".to_string()]);
                // The rejected call is kept in the history returned with the error.
                assert!(history_contains_tool_call(&chat_history, "default_api"));
            }
            other => panic!("expected UnknownToolCall, got {other:?}"),
        }
        // Only the first turn was requested from the model.
        assert_eq!(recorded.request_count(), 1);
    }
1738
    /// Checks the context handed to `on_invalid_tool_call` on the non-streaming
    /// path. It must carry the tool name and `tool_call_id` `"tool_call_1"`,
    /// have no `internal_call_id`, and report `is_streaming == false`.
    ///
    /// NOTE(review): the test name mentions the provider id, and the mock turn
    /// sets `call_id` to `"provider_call_1"`. The assertion, however, expects
    /// the tool call's `id` (`"tool_call_1"`). Confirm which id is intended.
    #[tokio::test]
    async fn invalid_tool_call_context_uses_completed_tool_call_provider_id() {
        let invalid_hook = RecordingInvalidToolCallHook::default();
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 1, "y": 2}))
                .with_call_id("provider_call_1"),
            MockTurn::text("should not be requested"),
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model).tool(MockAddTool).build();

        // The hook is cloned so the test keeps access to the shared records.
        let err = agent
            .prompt("use the tool")
            .with_hook(invalid_hook.clone())
            .max_turns(3)
            .await
            .expect_err("invalid tool should fail");

        // The recording hook answers `fail()`, so the prompt errors out.
        assert!(matches!(err, PromptError::UnknownToolCall { .. }));
        assert_eq!(recorded.request_count(), 1);
        let contexts = invalid_hook.observed();
        assert_eq!(contexts.len(), 1);
        let context = &contexts[0];
        assert_eq!(context.tool_name, "default_api");
        assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
        assert_eq!(context.internal_call_id, None);
        assert!(!context.is_streaming);
    }
1767
    /// A registered tool that is not listed in `ToolChoice::Specific`
    /// (`subtract`) is rejected as unknown. The error lists every registered
    /// tool as available but only `add` as allowed, and no second request is made.
    #[tokio::test]
    async fn disallowed_specific_tool_call_fails_before_non_streaming_second_request() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "subtract", json!({"x": 3, "y": 1})),
            MockTurn::text("should not be requested"),
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model)
            .tool(MockAddTool)
            .tool(MockSubtractTool)
            .tool_choice(ToolChoice::Specific {
                function_names: vec!["add".to_string()],
            })
            .build();

        let err = agent
            .prompt("use the allowed tool")
            .with_hook(PanicOnUnknownToolHook)
            .max_turns(3)
            .await
            .expect_err("disallowed model-emitted tool should fail");

        match err {
            PromptError::UnknownToolCall {
                tool_name,
                available_tools,
                allowed_tools,
                chat_history,
            } => {
                assert_eq!(tool_name, "subtract");
                // Available = everything registered; allowed = the tool_choice subset.
                assert_eq!(
                    available_tools,
                    vec!["add".to_string(), "subtract".to_string()]
                );
                assert_eq!(allowed_tools, vec!["add".to_string()]);
                assert!(history_contains_tool_call(&chat_history, "subtract"));
            }
            other => panic!("expected UnknownToolCall, got {other:?}"),
        }
        assert_eq!(recorded.request_count(), 1);
    }
1809
    /// With `ToolChoice::None`, even a call to a registered tool (`add`) is
    /// rejected as `UnknownToolCall` with an empty allowed list, after a single
    /// request.
    #[tokio::test]
    async fn tool_choice_none_rejects_non_streaming_tool_call() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})),
            MockTurn::text("should not be requested"),
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model)
            .tool(MockAddTool)
            .tool_choice(ToolChoice::None)
            .build();

        let err = agent
            .prompt("do not use tools")
            .with_hook(PanicOnUnknownToolHook)
            .max_turns(3)
            .await
            .expect_err("ToolChoice::None should reject returned tool calls");

        match err {
            PromptError::UnknownToolCall {
                tool_name,
                available_tools,
                allowed_tools,
                chat_history,
            } => {
                assert_eq!(tool_name, "add");
                assert_eq!(available_tools, vec!["add".to_string()]);
                // `ToolChoice::None` allows nothing, even though `add` is registered.
                assert!(allowed_tools.is_empty());
                assert!(history_contains_tool_call(&chat_history, "add"));
            }
            other => panic!("expected UnknownToolCall, got {other:?}"),
        }
        assert_eq!(recorded.request_count(), 1);
    }
1845
    /// Repairing `default_api` to `add` executes the add tool. The returned
    /// history must:
    /// - contain an `add` call and no `default_api` call;
    /// - include a tool result with text `"5"` for arguments x=2, y=3.
    #[tokio::test]
    async fn invalid_tool_call_hook_can_repair_non_streaming_tool_name() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
            MockTurn::text("done"),
        ]);
        let agent = AgentBuilder::new(model).tool(MockAddTool).build();

        let response = agent
            .prompt("add")
            .with_hook(RepairDefaultApiHook)
            .max_turns(3)
            .extended_details()
            .await
            .expect("repaired tool call should execute");

        assert_eq!(response.output, "done");
        let messages = response.messages.expect("messages should be present");
        // The repaired name replaces the original one in the recorded history.
        assert!(history_contains_tool_call(&messages, "add"));
        assert!(!history_contains_tool_call(&messages, "default_api"));
        // Some user message must carry the executed tool's text result "5".
        assert!(messages.iter().any(|message| {
            matches!(
                message,
                Message::User { content }
                    if content.iter().any(|content| {
                        matches!(
                            content,
                            UserContent::ToolResult(result)
                                if result.content.iter().any(|content| {
                                    matches!(
                                        content,
                                        crate::message::ToolResultContent::Text(text)
                                            if text.text == "5"
                                    )
                                })
                        )
                    })
            )
        }));
    }
1886
    /// A retry action sends the hook's feedback back to the model as a tool
    /// result and issues a second request; that turn's text becomes the output.
    #[tokio::test]
    async fn invalid_tool_call_hook_retry_adds_feedback_and_retries_non_streaming() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
            MockTurn::text("retried"),
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model).tool(MockAddTool).build();

        let response = agent
            .prompt("add")
            .with_hook(RetryDefaultApiHook)
            .max_invalid_tool_call_retries(1)
            .max_turns(3)
            .extended_details()
            .await
            .expect("retry should recover");

        assert_eq!(response.output, "retried");
        // The retry costs exactly one extra model request.
        assert_eq!(recorded.request_count(), 2);
        let messages = response.messages.expect("messages should be present");
        // The retry feedback from `RetryDefaultApiHook` appears as a tool result.
        assert!(messages.iter().any(|message| {
            matches!(
                message,
                Message::User { content }
                    if content.iter().any(|content| {
                        matches!(
                            content,
                            UserContent::ToolResult(result)
                                if result.content.iter().any(|content| {
                                    matches!(
                                        content,
                                        crate::message::ToolResultContent::Text(text)
                                            if text.text.contains("Use one of these tools instead")
                                    )
                                })
                        )
                    })
            )
        }));
    }
1928
    /// A single turn mixes a valid call (`add`) with an invalid one
    /// (`default_api`), and the hook asks for a retry. Neither call may execute.
    ///
    /// The retry request's history must hold the assistant message with both
    /// calls, followed by a user message with exactly two tool results:
    /// - the valid peer gets `TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER`;
    /// - the invalid call gets the retry feedback.
    #[tokio::test]
    async fn invalid_tool_call_hook_retries_mixed_non_streaming_turn_without_executing_valid_call()
    {
        let add_calls = Arc::new(AtomicU32::new(0));
        let mut valid_tool_call = ToolCall::new(
            "tool_call_1".to_string(),
            ToolFunction::new("add".to_string(), json!({"x": 2, "y": 3})),
        );
        valid_tool_call.call_id = Some("call_1".to_string());
        let mut invalid_tool_call = ToolCall::new(
            "tool_call_2".to_string(),
            ToolFunction::new("default_api".to_string(), json!({"x": 4, "y": 5})),
        );
        invalid_tool_call.call_id = Some("call_2".to_string());
        let model = MockCompletionModel::new([
            MockTurn::from_contents([
                AssistantContent::ToolCall(valid_tool_call),
                AssistantContent::ToolCall(invalid_tool_call),
            ])
            .expect("tool-call response should be non-empty"),
            MockTurn::text("retried"),
        ]);
        let recorded = model.clone();
        // The counting tool detects whether the valid `add` call was executed.
        let agent = AgentBuilder::new(model)
            .tool(CountingAddTool {
                calls: add_calls.clone(),
            })
            .build();

        let response = agent
            .prompt("add")
            .with_hook(RetryDefaultApiHook)
            .max_invalid_tool_call_retries(1)
            .max_turns(3)
            .extended_details()
            .await
            .expect("retry should recover");

        assert_eq!(response.output, "retried");
        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
        let requests = recorded.requests();
        assert_eq!(requests.len(), 2);
        // History sent with the retry: prompt, assistant tool calls, user tool results.
        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
        assert_eq!(retry_history.len(), 3);
        // Both original calls stay in the assistant turn.
        assert!(matches!(
            retry_history.get(1),
            Some(Message::Assistant { content, .. })
                if content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_1"
                            && tool_call.function.name == "add"
                ))
                && content.iter().any(|item| matches!(
                    item,
                    AssistantContent::ToolCall(tool_call)
                        if tool_call.id == "tool_call_2"
                            && tool_call.function.name == "default_api"
                ))
        ));
        // Every call gets exactly one result: a "not executed" marker for the
        // valid peer and the retry feedback for the invalid call.
        assert!(matches!(
            retry_history.get(2),
            Some(Message::User { content })
                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_1"
                            && result.call_id.as_deref() == Some("call_1")
                            && result.content.iter().any(|content| matches!(
                                content,
                                crate::message::ToolResultContent::Text(text)
                                    if text.text == super::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
                            ))
                ))
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_2"
                            && result.call_id.as_deref() == Some("call_2")
                            && result.content.iter().any(|content| matches!(
                                content,
                                crate::message::ToolResultContent::Text(text)
                                    if text.text.contains("Use one of these tools instead")
                            ))
                ))
        ));
    }
2017
    /// A single turn mixes a valid call (`add`) with an invalid one
    /// (`default_api`), and the hook answers skip. Neither call may execute,
    /// and no normal tool hook may run (the tool-call hook would panic).
    ///
    /// The user message with the tool results must contain exactly two results:
    /// - the valid peer gets `TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER`;
    /// - the invalid call gets the skip reason.
    #[tokio::test]
    async fn invalid_tool_call_hook_skips_mixed_non_streaming_turn_without_executing_valid_call() {
        let add_calls = Arc::new(AtomicU32::new(0));
        let mut valid_tool_call = ToolCall::new(
            "tool_call_1".to_string(),
            ToolFunction::new("add".to_string(), json!({"x": 2, "y": 3})),
        );
        valid_tool_call.call_id = Some("call_1".to_string());
        let mut invalid_tool_call = ToolCall::new(
            "tool_call_2".to_string(),
            ToolFunction::new("default_api".to_string(), json!({"x": 4, "y": 5})),
        );
        invalid_tool_call.call_id = Some("call_2".to_string());
        let model = MockCompletionModel::new([
            MockTurn::from_contents([
                AssistantContent::ToolCall(valid_tool_call),
                AssistantContent::ToolCall(invalid_tool_call),
            ])
            .expect("tool-call response should be non-empty"),
            MockTurn::text("skipped"),
        ]);
        let agent = AgentBuilder::new(model)
            .tool(CountingAddTool {
                calls: add_calls.clone(),
            })
            .build();

        let response = agent
            .prompt("add")
            .with_hook(SkipDefaultApiAndPanicOnToolCallHook)
            .max_turns(3)
            .extended_details()
            .await
            .expect("skip should recover without executing peer tools");

        assert_eq!(response.output, "skipped");
        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
        let messages = response.messages.expect("messages should be present");
        // Unlike repair, skip keeps both original calls in the history.
        assert!(history_contains_tool_call(&messages, "add"));
        assert!(history_contains_tool_call(&messages, "default_api"));
        // Index 2 must be the user message carrying the tool results (presumably
        // after the prompt at 0 and the assistant tool-call turn at 1).
        assert!(matches!(
            messages.get(2),
            Some(Message::User { content })
                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_1"
                            && result.call_id.as_deref() == Some("call_1")
                            && result.content.iter().any(|content| matches!(
                                content,
                                crate::message::ToolResultContent::Text(text)
                                    if text.text == super::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
                            ))
                ))
                && content.iter().any(|item| matches!(
                    item,
                    UserContent::ToolResult(result)
                        if result.id == "tool_call_2"
                            && result.call_id.as_deref() == Some("call_2")
                            && result.content.iter().any(|content| matches!(
                                content,
                                crate::message::ToolResultContent::Text(text)
                                    if text.text == "default_api is not available"
                            ))
                ))
        ));
    }
2086
    /// With `max_invalid_tool_call_retries(0)`, a retry action cannot be
    /// honored. The prompt fails with `UnknownToolCall` after a single request.
    #[tokio::test]
    async fn invalid_tool_call_hook_retry_budget_exhaustion_fails() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
            MockTurn::text("should not be requested"),
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model).tool(MockAddTool).build();

        let err = agent
            .prompt("add")
            .with_hook(RetryDefaultApiHook)
            .max_invalid_tool_call_retries(0)
            .max_turns(3)
            .await
            .expect_err("retry without budget should fail");

        match err {
            PromptError::UnknownToolCall {
                tool_name,
                chat_history,
                ..
            } => {
                assert_eq!(tool_name, "default_api");
                assert!(history_contains_tool_call(&chat_history, "default_api"));
            }
            other => panic!("expected UnknownToolCall, got {other:?}"),
        }
        // No retry request was issued.
        assert_eq!(recorded.request_count(), 1);
    }
2117
    /// A skip action records the hook's reason as a synthetic tool result for
    /// the `default_api` call. The loop then continues to the next turn, whose
    /// text becomes the output.
    #[tokio::test]
    async fn invalid_tool_call_hook_can_skip_structured_non_streaming_call() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
            MockTurn::text("skipped"),
        ]);
        let agent = AgentBuilder::new(model).tool(MockAddTool).build();

        let response = agent
            .prompt("add")
            .with_hook(SkipDefaultApiHook)
            .max_turns(3)
            .extended_details()
            .await
            .expect("skip should continue with synthetic tool result");

        assert_eq!(response.output, "skipped");
        let messages = response.messages.expect("messages should be present");
        // The skipped call itself stays in the history.
        assert!(history_contains_tool_call(&messages, "default_api"));
        // The skip reason is delivered as a tool result.
        assert!(messages.iter().any(|message| {
            matches!(
                message,
                Message::User { content }
                    if content.iter().any(|content| {
                        matches!(
                            content,
                            UserContent::ToolResult(result)
                                if result.content.iter().any(|content| {
                                    matches!(
                                        content,
                                        crate::message::ToolResultContent::Text(text)
                                            if text.text == "default_api is not available"
                                    )
                                })
                        )
                    })
            )
        }));
    }
2157
    /// Skipping is still permitted under `ToolChoice::Specific`. The synthetic
    /// result carrying the skip reason is attached to `tool_call_1`.
    #[tokio::test]
    async fn skip_under_specific_tool_choice_returns_synthetic_feedback() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
            MockTurn::text("skipped"),
        ]);
        let agent = AgentBuilder::new(model)
            .tool(MockAddTool)
            .tool_choice(ToolChoice::Specific {
                function_names: vec!["add".to_string()],
            })
            .build();

        let response = agent
            .prompt("add")
            .with_hook(SkipDefaultApiHook)
            .max_turns(3)
            .extended_details()
            .await
            .expect("skip should produce synthetic feedback under Specific");

        assert_eq!(response.output, "skipped");
        let messages = response.messages.expect("messages should be present");
        assert!(history_contains_tool_call(&messages, "default_api"));
        // The synthetic result must reference the skipped call's id.
        assert!(messages.iter().any(|message| {
            matches!(
                message,
                Message::User { content }
                    if content.iter().any(|content| {
                        matches!(
                            content,
                            UserContent::ToolResult(result)
                                if result.id == "tool_call_1"
                                    && result.content.iter().any(|content| {
                                        matches!(
                                            content,
                                            crate::message::ToolResultContent::Text(text)
                                                if text.text == "default_api is not available"
                                        )
                                    })
                        )
                    })
            )
        }));
    }
2203
    /// Repairing to a registered tool that `ToolChoice::Specific` does not
    /// allow (`subtract`) fails. The `UnknownToolCall` error names the repaired
    /// tool, and no second request is made.
    #[tokio::test]
    async fn repair_to_disallowed_specific_tool_fails() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
            MockTurn::text("should not be requested"),
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model)
            .tool(MockAddTool)
            .tool(MockSubtractTool)
            .tool_choice(ToolChoice::Specific {
                function_names: vec!["add".to_string()],
            })
            .build();

        let err = agent
            .prompt("add")
            .with_hook(RepairToSubtractHook)
            .max_turns(3)
            .await
            .expect_err("repair to a disallowed tool should fail");

        match err {
            // The error reports the repaired name, not the original `default_api`.
            PromptError::UnknownToolCall { tool_name, .. } => {
                assert_eq!(tool_name, "subtract");
            }
            other => panic!("expected UnknownToolCall, got {other:?}"),
        }
        assert_eq!(recorded.request_count(), 1);
    }
2234
    /// Under `ToolChoice::None`, a repair to the registered `add` tool is still
    /// rejected. The error names the repaired tool, and no second request is made.
    #[tokio::test]
    async fn repair_under_tool_choice_none_fails() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
            MockTurn::text("should not be requested"),
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model)
            .tool(MockAddTool)
            .tool_choice(ToolChoice::None)
            .build();

        let err = agent
            .prompt("do not use tools")
            .with_hook(RepairDefaultApiHook)
            .max_turns(3)
            .await
            .expect_err("ToolChoice::None should reject repaired tool calls");

        match err {
            PromptError::UnknownToolCall { tool_name, .. } => {
                assert_eq!(tool_name, "add");
            }
            other => panic!("expected UnknownToolCall, got {other:?}"),
        }
        assert_eq!(recorded.request_count(), 1);
    }
2262
    /// Under `ToolChoice::None`, a skip action is rejected too. The error names
    /// the original `default_api` tool, and no second request is made.
    #[tokio::test]
    async fn skip_under_tool_choice_none_fails() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
            MockTurn::text("should not be requested"),
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model)
            .tool(MockAddTool)
            .tool_choice(ToolChoice::None)
            .build();

        let err = agent
            .prompt("do not use tools")
            .with_hook(SkipDefaultApiHook)
            .max_turns(3)
            .await
            .expect_err("ToolChoice::None should reject skipped tool calls");

        match err {
            PromptError::UnknownToolCall { tool_name, .. } => {
                assert_eq!(tool_name, "default_api");
            }
            other => panic!("expected UnknownToolCall, got {other:?}"),
        }
        assert_eq!(recorded.request_count(), 1);
    }
2290
    /// `prompt_typed` keeps the fail-fast default for unknown tool calls. The
    /// failure surfaces as `StructuredOutputError::PromptError` wrapping a boxed
    /// `PromptError::UnknownToolCall`, after a single request.
    #[tokio::test]
    async fn typed_prompt_default_invalid_tool_call_fails_fast() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
            MockTurn::text(r#"{"value":"should not be requested"}"#),
        ]);
        let recorded = model.clone();
        let agent = AgentBuilder::new(model).tool(MockAddTool).build();

        let err = agent
            .prompt_typed::<TypedAnswer>("return typed json")
            .with_hook(PanicOnUnknownToolHook)
            .max_turns(3)
            .await
            .expect_err("typed prompt should preserve fail-fast default");

        match err {
            // The inner prompt error is boxed, so it is dereferenced before matching.
            StructuredOutputError::PromptError(err) => match *err {
                PromptError::UnknownToolCall { tool_name, .. } => {
                    assert_eq!(tool_name, "default_api");
                }
                other => panic!("expected UnknownToolCall, got {other:?}"),
            },
            other => panic!("expected prompt error, got {other:?}"),
        }
        assert_eq!(recorded.request_count(), 1);
    }
2318
2319 #[tokio::test]
2320 async fn typed_prompt_invalid_tool_call_hook_can_repair_tool_name() {
2321 let model = MockCompletionModel::new([
2322 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2323 MockTurn::text(r#"{"value":"repaired"}"#),
2324 ]);
2325 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2326
2327 let response = agent
2328 .prompt_typed::<TypedAnswer>("return typed json")
2329 .with_hook(RepairDefaultApiHook)
2330 .max_turns(3)
2331 .await
2332 .expect("typed prompt should repair invalid tool call");
2333
2334 assert_eq!(
2335 response,
2336 TypedAnswer {
2337 value: "repaired".to_string()
2338 }
2339 );
2340 }
2341
2342 #[tokio::test]
2343 async fn typed_prompt_invalid_tool_call_hook_can_retry_and_parse_response() {
2344 let model = MockCompletionModel::new([
2345 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2346 MockTurn::text(r#"{"value":"retried"}"#),
2347 ]);
2348 let recorded = model.clone();
2349 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2350
2351 let response = agent
2352 .prompt_typed::<TypedAnswer>("return typed json")
2353 .with_hook(RetryDefaultApiHook)
2354 .max_invalid_tool_call_retries(1)
2355 .max_turns(3)
2356 .await
2357 .expect("typed prompt should retry invalid tool call");
2358
2359 assert_eq!(
2360 response,
2361 TypedAnswer {
2362 value: "retried".to_string()
2363 }
2364 );
2365 assert_eq!(recorded.request_count(), 2);
2366 }
2367
2368 #[tokio::test]
2369 async fn typed_prompt_invalid_tool_call_retry_budget_exhaustion_fails() {
2370 let model = MockCompletionModel::new([
2371 MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2372 MockTurn::text(r#"{"value":"should not be requested"}"#),
2373 ]);
2374 let recorded = model.clone();
2375 let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2376
2377 let err = agent
2378 .prompt_typed::<TypedAnswer>("return typed json")
2379 .with_hook(RetryDefaultApiHook)
2380 .max_invalid_tool_call_retries(0)
2381 .max_turns(3)
2382 .await
2383 .expect_err("typed prompt should fail when retry budget is exhausted");
2384
2385 match err {
2386 StructuredOutputError::PromptError(err) => match *err {
2387 PromptError::UnknownToolCall { tool_name, .. } => {
2388 assert_eq!(tool_name, "default_api");
2389 }
2390 other => panic!("expected UnknownToolCall, got {other:?}"),
2391 },
2392 other => panic!("expected prompt error, got {other:?}"),
2393 }
2394 assert_eq!(recorded.request_count(), 1);
2395 }
2396
2397 #[tokio::test]
2398 async fn invalid_specific_tool_choice_fails_before_non_streaming_provider_request() {
2399 let model = MockCompletionModel::text("should not be requested");
2400 let recorded = model.clone();
2401 let agent = AgentBuilder::new(model)
2402 .tool(MockAddTool)
2403 .tool_choice(ToolChoice::Specific {
2404 function_names: vec!["missing".to_string()],
2405 })
2406 .build();
2407
2408 let err = agent
2409 .prompt("use the missing tool")
2410 .await
2411 .expect_err("invalid ToolChoice::Specific should fail before provider request");
2412
2413 match err {
2414 PromptError::CompletionError(CompletionError::RequestError(err)) => {
2415 let msg = err.to_string();
2416 assert!(msg.contains("missing"), "got: {msg}");
2417 assert!(msg.contains("add"), "got: {msg}");
2418 }
2419 other => panic!("expected CompletionError::RequestError, got {other:?}"),
2420 }
2421 assert_eq!(recorded.request_count(), 0);
2422 }
2423
2424 #[tokio::test]
2425 async fn allowed_specific_tool_call_executes_normally() {
2426 let model = MockCompletionModel::new([
2427 MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})),
2428 MockTurn::text("done"),
2429 ]);
2430 let recorded = model.clone();
2431 let agent = AgentBuilder::new(model)
2432 .tool(MockAddTool)
2433 .tool_choice(ToolChoice::Specific {
2434 function_names: vec!["add".to_string()],
2435 })
2436 .build();
2437
2438 let response = agent
2439 .prompt("use the allowed tool")
2440 .max_turns(3)
2441 .await
2442 .expect("allowed specific tool should execute");
2443
2444 assert_eq!(response, "done");
2445 assert_eq!(recorded.request_count(), 2);
2446 }
2447
    /// A tool-call turn followed by an empty text turn must finish successfully. The
    /// result has empty output, usage summed across both calls, and a history that
    /// omits the empty assistant text.
    #[tokio::test]
    async fn prompt_request_stops_cleanly_on_empty_terminal_turn() {
        // Usage reported by the mock for the first provider call (the tool-call turn).
        let first_call_usage = Usage {
            input_tokens: 1,
            output_tokens: 1,
            total_tokens: 2,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
            tool_use_prompt_tokens: 0,
            reasoning_tokens: 0,
        };
        // Usage reported by the mock for the second provider call (the empty text turn).
        let second_call_usage = Usage {
            input_tokens: 1,
            output_tokens: 1,
            total_tokens: 2,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
            tool_use_prompt_tokens: 0,
            reasoning_tokens: 0,
        };
        // Turn 1 calls the registered `add` tool with a provider call id; turn 2 is empty text.
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2}))
                .with_call_id("call_1")
                .with_usage(first_call_usage),
            MockTurn::text("").with_usage(second_call_usage),
        ]);
        let agent = AgentBuilder::new(model).tool(MockAddTool).build();

        let response = agent
            .prompt("do tool work")
            .max_turns(3)
            .extended_details()
            .await
            .expect("empty terminal turn should not error");

        assert!(response.output.is_empty());
        // Aggregate usage is the field-wise sum of the two per-call usages above.
        assert_eq!(
            response.usage,
            Usage {
                input_tokens: 2,
                output_tokens: 2,
                total_tokens: 4,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
                tool_use_prompt_tokens: 0,
                reasoning_tokens: 0,
            }
        );
        // One record per provider call, indexed from 0, each carrying that call's usage.
        assert_eq!(
            response.completion_calls(),
            &[
                CompletionCall::new(0, Some(first_call_usage)),
                CompletionCall::new(1, Some(second_call_usage))
            ]
        );

        let history = response
            .messages
            .expect("extended response should include history");
        // Three messages: the user prompt, the assistant tool call and the user tool result.
        assert_eq!(history.len(), 3);
        // The original prompt opens the history.
        assert!(matches!(
            history.first(),
            Some(Message::User { content })
            if matches!(
                content.first(),
                UserContent::Text(text) if text.text == "do tool work"
            )
        ));
        // The assistant tool call keeps both the tool id and the provider call id.
        assert!(history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
            if matches!(
                content.first(),
                AssistantContent::ToolCall(tool_call)
                if tool_call.id == "tool_call_1"
                    && tool_call.call_id.as_deref() == Some("call_1")
            )
        )));
        // The tool result echoes the same tool id and provider call id.
        assert!(history.iter().any(|message| matches!(
            message,
            Message::User { content }
            if matches!(
                content.first(),
                UserContent::ToolResult(tool_result)
                if tool_result.id == "tool_call_1"
                    && tool_result.call_id.as_deref() == Some("call_1")
            )
        )));
        // The empty terminal text must not be stored as an assistant message.
        assert!(!history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
            if content.iter().any(|item| matches!(
                item,
                AssistantContent::Text(text) if text.text.is_empty()
            ))
        )));
        let requests = agent.model.requests();
        assert_eq!(requests.len(), 2);
        // NOTE(review): helper defined elsewhere in this module; presumably it checks that
        // the second request carries the tool call/result pair. Confirm against its body.
        validate_follow_up_tool_history(&requests[1]);
    }
2548
2549 #[tokio::test]
2550 async fn prompt_request_concatenates_text_blocks_without_inserted_newlines() {
2551 let model = MockCompletionModel::new([MockTurn::from_contents([
2552 AssistantContent::Text(Text::new("According to the document, ")),
2553 AssistantContent::Text(Text::new("the grass is green")),
2554 AssistantContent::Text(Text::new(" and the sky is blue.")),
2555 ])
2556 .expect("mock response should contain text blocks")]);
2557 let agent = AgentBuilder::new(model).build();
2558
2559 let response = agent
2560 .prompt("answer with cited spans")
2561 .await
2562 .expect("prompt should succeed");
2563
2564 assert_eq!(
2565 response,
2566 "According to the document, the grass is green and the sky is blue."
2567 );
2568 }
2569
2570 #[tokio::test]
2571 async fn prompt_request_preserves_metadata_only_text_turn_in_history() {
2572 let metadata = json!({
2573 "citations": [{
2574 "type": "web_search_result_location",
2575 "cited_text": "Claude Shannon was born in 1916.",
2576 "url": "https://example.com/shannon",
2577 "title": null,
2578 "encrypted_index": "encrypted-reference"
2579 }]
2580 });
2581 let model =
2582 MockCompletionModel::new([MockTurn::from_content(AssistantContent::Text(Text {
2583 text: String::new(),
2584 additional_params: Some(metadata.clone()),
2585 }))]);
2586 let agent = AgentBuilder::new(model).build();
2587
2588 let response = agent
2589 .prompt("answer with cited metadata")
2590 .extended_details()
2591 .await
2592 .expect("metadata-only text turn should succeed");
2593
2594 assert!(response.output.is_empty());
2595 let history = response
2596 .messages
2597 .expect("extended response should include history");
2598 assert!(history.iter().any(|message| matches!(
2599 message,
2600 Message::Assistant { content, .. }
2601 if matches!(
2602 content.first(),
2603 AssistantContent::Text(text)
2604 if text.text.is_empty()
2605 && text.additional_params.as_ref() == Some(&metadata)
2606 )
2607 )));
2608 }
2609
2610 use crate::memory::{ConversationMemory, InMemoryConversationMemory};
2613
2614 #[tokio::test]
2615 async fn memory_loads_into_request_history() {
2616 let memory = InMemoryConversationMemory::new();
2617 memory
2618 .append(
2619 "thread-1",
2620 vec![Message::user("hello"), Message::assistant("hi there")],
2621 )
2622 .await
2623 .unwrap();
2624
2625 let model = MockCompletionModel::text("ack");
2626 let recorded = model.clone();
2627
2628 let agent = AgentBuilder::new(model).memory(memory).build();
2629 let _ = agent
2630 .prompt("ping")
2631 .conversation("thread-1")
2632 .await
2633 .expect("prompt should succeed");
2634
2635 let received = recorded.requests()[0]
2636 .chat_history
2637 .iter()
2638 .cloned()
2639 .collect::<Vec<_>>();
2640 assert_eq!(
2641 received.len(),
2642 3,
2643 "loaded memory (2) + current prompt should appear in request: {received:?}"
2644 );
2645 }
2646
2647 #[tokio::test]
2648 async fn memory_appends_full_turn_after_success() {
2649 let memory = InMemoryConversationMemory::new();
2650 let model = MockCompletionModel::text("ack");
2651 let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2652
2653 let _ = agent
2654 .prompt("hello")
2655 .conversation("t1")
2656 .await
2657 .expect("prompt should succeed");
2658
2659 let stored = memory.load("t1").await.unwrap();
2660 assert_eq!(stored.len(), 2, "user prompt + assistant response saved");
2661 }
2662
2663 #[tokio::test]
2664 async fn explicit_with_history_overrides_memory() {
2665 let memory = CountingMemory::default();
2666 memory
2667 .inner()
2668 .append("t1", vec![Message::user("from-memory")])
2669 .await
2670 .unwrap();
2671
2672 let model = MockCompletionModel::text("ack");
2673 let recorded = model.clone();
2674
2675 let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2676 let _ = agent
2677 .prompt("hello")
2678 .conversation("t1")
2679 .with_history(vec![Message::user("from-caller")])
2680 .await
2681 .expect("prompt should succeed");
2682
2683 assert_eq!(memory.load_count(), 0, "load skipped");
2684 let appends = memory.append_count();
2685 assert_eq!(appends, 0, "append skipped");
2686
2687 let received = recorded.requests()[0]
2688 .chat_history
2689 .iter()
2690 .cloned()
2691 .collect::<Vec<_>>();
2692 assert_eq!(received.len(), 2, "caller history (1) + current prompt");
2693 assert!(matches!(
2694 received.first(),
2695 Some(Message::User { content })
2696 if matches!(content.first(), UserContent::Text(t) if t.text == "from-caller")
2697 ));
2698 }
2699
2700 #[tokio::test]
2701 async fn memory_unchanged_on_provider_error() {
2702 let memory = InMemoryConversationMemory::new();
2703 let model = MockCompletionModel::new([MockTurn::error("boom")]);
2704
2705 let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2706 let result = agent.prompt("hello").conversation("t1").await;
2707 assert!(result.is_err());
2708
2709 let stored = memory.load("t1").await.unwrap();
2710 assert!(stored.is_empty(), "no append on error");
2711 }
2712
2713 #[tokio::test]
2714 async fn missing_conversation_id_behaves_as_no_memory() {
2715 let memory = CountingMemory::default();
2716 let model = MockCompletionModel::text("ack");
2717 let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2718
2719 let _ = agent.prompt("hello").await.expect("prompt should succeed");
2720
2721 assert_eq!(memory.load_count(), 0);
2722 assert_eq!(memory.append_count(), 0);
2723 }
2724
2725 #[tokio::test]
2726 async fn default_conversation_id_is_used_when_none_per_request() {
2727 let memory = InMemoryConversationMemory::new();
2728 let model = MockCompletionModel::text("ack");
2729 let agent = AgentBuilder::new(model)
2730 .memory(memory.clone())
2731 .conversation_id("default-thread")
2732 .build();
2733
2734 let _ = agent.prompt("hello").await.expect("prompt should succeed");
2735 let stored = memory.load("default-thread").await.unwrap();
2736 assert_eq!(stored.len(), 2);
2737 }
2738
2739 #[tokio::test]
2740 async fn with_filter_truncates_loaded_history() {
2741 let memory = InMemoryConversationMemory::new()
2742 .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
2743 memory
2744 .append(
2745 "t1",
2746 vec![
2747 Message::user("1"),
2748 Message::assistant("2"),
2749 Message::user("3"),
2750 Message::assistant("4"),
2751 ],
2752 )
2753 .await
2754 .unwrap();
2755
2756 let model = MockCompletionModel::text("ack");
2757 let recorded = model.clone();
2758 let agent = AgentBuilder::new(model).memory(memory).build();
2759
2760 let _ = agent
2761 .prompt("ping")
2762 .conversation("t1")
2763 .await
2764 .expect("prompt should succeed");
2765
2766 let received = recorded.requests()[0]
2767 .chat_history
2768 .iter()
2769 .cloned()
2770 .collect::<Vec<_>>();
2771 assert_eq!(
2772 received.len(),
2773 3,
2774 "window-truncated history (2) + current prompt"
2775 );
2776 }
2777
2778 #[tokio::test]
2779 async fn without_memory_disables_for_request() {
2780 let memory = CountingMemory::default();
2781 let model = MockCompletionModel::text("ack");
2782 let agent = AgentBuilder::new(model)
2783 .memory(memory.clone())
2784 .conversation_id("t1")
2785 .build();
2786
2787 let _ = agent
2788 .prompt("hello")
2789 .without_memory()
2790 .await
2791 .expect("prompt should succeed");
2792
2793 assert_eq!(memory.load_count(), 0);
2794 assert_eq!(memory.append_count(), 0);
2795 }
2796
2797 #[tokio::test]
2798 async fn memory_load_error_surfaces_as_prompt_error() {
2799 let model = MockCompletionModel::text("ack");
2800 let agent = AgentBuilder::new(model)
2801 .memory(FailingMemory::default())
2802 .build();
2803 let result = agent.prompt("hello").conversation("t1").await;
2804
2805 match result {
2806 Err(PromptError::CompletionError(CompletionError::RequestError(err))) => {
2807 let msg = format!("{err}");
2808 assert!(msg.contains("load boom"), "got: {msg}");
2809 }
2810 other => panic!("expected PromptError::CompletionError(RequestError), got {other:?}"),
2811 }
2812 }
2813
2814 #[tokio::test]
2815 async fn memory_append_error_does_not_drop_response() {
2816 let model = MockCompletionModel::text("ack");
2817 let agent = AgentBuilder::new(model)
2818 .memory(AppendFailingMemory::default())
2819 .build();
2820 let response: String = agent
2821 .prompt("hello")
2822 .conversation("t1")
2823 .await
2824 .expect("append failure must not block successful completion");
2825
2826 assert!(!response.is_empty());
2827 }
2828}