// rig_core/agent/prompt_request/mod.rs

1pub mod hooks;
2pub mod streaming;
3
4use super::{
5    Agent,
6    completion::{DynamicContextStore, build_prepared_completion_request},
7};
8use crate::{
9    OneOrMany,
10    completion::{CompletionModel, Document, Message, PromptError, Usage},
11    json_utils,
12    memory::ConversationMemory,
13    message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
14    tool::server::ToolServerHandle,
15    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
16};
17use futures::{StreamExt, stream};
18use hooks::{
19    HookAction, InvalidToolCallContext, InvalidToolCallHookAction, PromptHook, ToolCallHookAction,
20};
21use serde::{Deserialize, Serialize};
22use std::{
23    collections::{BTreeMap, BTreeSet},
24    future::IntoFuture,
25    marker::PhantomData,
26    sync::{
27        Arc,
28        atomic::{AtomicU64, Ordering},
29    },
30};
31use tracing::info_span;
32use tracing::{Instrument, span::Id};
33
/// Marker trait for the type-level state of a [`PromptRequest`].
///
/// The state decides what awaiting the request yields; see [`Standard`] and [`Extended`].
pub trait PromptType {}
/// Request state whose future resolves to `Result<String, PromptError>` (the response text only).
pub struct Standard;
/// Request state whose future resolves to `Result<PromptResponse, PromptError>`.
///
/// Entered by calling [`PromptRequest::extended_details`].
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
40
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// Type parameters:
/// - `S`: the [`PromptType`] state ([`Standard`] or [`Extended`]) selecting the awaited output.
/// - `M`: the completion model the request is sent to.
/// - `P`: the hook type notified of events during the request.
///
/// If you expect to continuously call tools, you will want to ensure you use the `.max_turns()`
/// method to add more turns, as by default it is 0 unless the agent configures a default
/// (meaning only 1 tool round-trip). Otherwise, attempting to await (which will send the prompt
/// request) can potentially return
/// [`crate::completion::request::PromptError::MaxTurnsError`] if the agent decides to call tools
/// back to back.
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history provided by the caller.
    ///
    /// When this is `Some`, history is not loaded from `memory` for this request.
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging; a placeholder name is used in spans when this is `None`.
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,

    /// Phantom data to track the type of the request
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// Maximum number of invalid tool-call retries for this request (0 by default).
    max_invalid_tool_call_retries: usize,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
    /// Optional conversation memory backend cloned from the agent.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Optional conversation id used for loading and saving memory.
    ///
    /// Memory is only consulted when both `memory` and this id are set.
    conversation_id: Option<String>,
}
99
100impl<M, P> PromptRequest<Standard, M, P>
101where
102    M: CompletionModel,
103    P: PromptHook<M>,
104{
105    /// Create a new PromptRequest from an agent, cloning the agent's data and default hook.
106    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
107        PromptRequest {
108            prompt: prompt.into(),
109            chat_history: None,
110            max_turns: agent.default_max_turns.unwrap_or_default(),
111            model: agent.model.clone(),
112            agent_name: agent.name.clone(),
113            preamble: agent.preamble.clone(),
114            static_context: agent.static_context.clone(),
115            temperature: agent.temperature,
116            max_tokens: agent.max_tokens,
117            additional_params: agent.additional_params.clone(),
118            tool_server_handle: agent.tool_server_handle.clone(),
119            dynamic_context: agent.dynamic_context.clone(),
120            tool_choice: agent.tool_choice.clone(),
121            state: PhantomData,
122            hook: agent.hook.clone(),
123            max_invalid_tool_call_retries: 0,
124            concurrency: 1,
125            output_schema: agent.output_schema.clone(),
126            memory: agent.memory.clone(),
127            conversation_id: agent.default_conversation_id.clone(),
128        }
129    }
130}
131
132impl<S, M, P> PromptRequest<S, M, P>
133where
134    S: PromptType,
135    M: CompletionModel,
136    P: PromptHook<M>,
137{
138    /// Enable returning extended details for responses (includes aggregated token usage
139    /// and the full message history accumulated during the agent loop).
140    ///
141    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
142    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
143    /// of conversation and inspecting the full message exchange.
144    pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
145        PromptRequest {
146            prompt: self.prompt,
147            chat_history: self.chat_history,
148            max_turns: self.max_turns,
149            model: self.model,
150            agent_name: self.agent_name,
151            preamble: self.preamble,
152            static_context: self.static_context,
153            temperature: self.temperature,
154            max_tokens: self.max_tokens,
155            additional_params: self.additional_params,
156            tool_server_handle: self.tool_server_handle,
157            dynamic_context: self.dynamic_context,
158            tool_choice: self.tool_choice,
159            state: PhantomData,
160            hook: self.hook,
161            max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
162            concurrency: self.concurrency,
163            output_schema: self.output_schema,
164            memory: self.memory,
165            conversation_id: self.conversation_id,
166        }
167    }
168
169    /// Set the maximum number of turns for multi-turn conversations. A given agent may require multiple turns for tool-calling before giving an answer.
170    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
171    pub fn max_turns(mut self, depth: usize) -> Self {
172        self.max_turns = depth;
173        self
174    }
175
176    /// Add concurrency to the prompt request.
177    /// This will cause the agent to execute tools concurrently.
178    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
179        self.concurrency = concurrency;
180        self
181    }
182
183    /// Add chat history to the prompt request.
184    pub fn with_history<H, T>(mut self, history: H) -> Self
185    where
186        H: IntoIterator<Item = T>,
187        T: Into<Message>,
188    {
189        self.chat_history = Some(history.into_iter().map(Into::into).collect());
190        self
191    }
192
193    /// Set the conversation id used to load and persist memory for this request.
194    ///
195    /// Overrides any default conversation id set on the agent. If memory is not
196    /// configured on the agent, this has no effect.
197    pub fn conversation(mut self, id: impl Into<String>) -> Self {
198        self.conversation_id = Some(id.into());
199        self
200    }
201
202    /// Disable conversation memory for this request.
203    ///
204    /// History will neither be loaded from nor saved to the agent's memory backend.
205    pub fn without_memory(mut self) -> Self {
206        self.memory = None;
207        self.conversation_id = None;
208        self
209    }
210
211    /// Attach a per-request hook for tool call events.
212    /// This overrides any default hook set on the agent.
213    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
214    where
215        P2: PromptHook<M>,
216    {
217        PromptRequest {
218            prompt: self.prompt,
219            chat_history: self.chat_history,
220            max_turns: self.max_turns,
221            model: self.model,
222            agent_name: self.agent_name,
223            preamble: self.preamble,
224            static_context: self.static_context,
225            temperature: self.temperature,
226            max_tokens: self.max_tokens,
227            additional_params: self.additional_params,
228            tool_server_handle: self.tool_server_handle,
229            dynamic_context: self.dynamic_context,
230            tool_choice: self.tool_choice,
231            state: PhantomData,
232            hook: Some(hook),
233            max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
234            concurrency: self.concurrency,
235            output_schema: self.output_schema,
236            memory: self.memory,
237            conversation_id: self.conversation_id,
238        }
239    }
240
241    /// Set the retry budget for [`InvalidToolCallHookAction::Retry`].
242    ///
243    /// Invalid tool-call retries also consume normal multi-turn depth.
244    pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
245        self.max_invalid_tool_call_retries = retries;
246        self
247    }
248}
249
250/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
251///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
252///  directly via the associated type.
253impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
254where
255    M: CompletionModel + 'static,
256    P: PromptHook<M> + 'static,
257{
258    type Output = Result<String, PromptError>;
259    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
260
261    fn into_future(self) -> Self::IntoFuture {
262        Box::pin(self.send())
263    }
264}
265
266impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
267where
268    M: CompletionModel + 'static,
269    P: PromptHook<M> + 'static,
270{
271    type Output = Result<PromptResponse, PromptError>;
272    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
273
274    fn into_future(self) -> Self::IntoFuture {
275        Box::pin(self.send())
276    }
277}
278
279impl<M, P> PromptRequest<Standard, M, P>
280where
281    M: CompletionModel,
282    P: PromptHook<M>,
283{
284    async fn send(self) -> Result<String, PromptError> {
285        self.extended_details().send().await.map(|resp| resp.output)
286    }
287}
288
/// Details for one successfully completed completion request made by an agent run.
///
/// An agent run records one entry per completion request, in request order, so
/// `call_index` matches the entry's position in the recorded list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct CompletionCall {
    /// Zero-based index of the completion request within this agent run.
    pub call_index: usize,
    /// Token usage reported for this completion request, when the provider supplied it.
    ///
    /// Rig normalizes zero-valued [`Usage`] to `None` because zero-valued usage
    /// is the existing sentinel for missing provider usage metrics.
    #[serde(default)]
    pub usage: Option<Usage>,
}
302
303impl CompletionCall {
304    /// Create details for one completion request in an agent run.
305    pub fn new(call_index: usize, usage: Option<Usage>) -> Self {
306        Self { call_index, usage }
307    }
308
309    pub(crate) fn from_reported_usage(call_index: usize, usage: Usage) -> Self {
310        Self::new(call_index, reported_usage(usage))
311    }
312}
313
314pub(crate) fn reported_usage(usage: Usage) -> Option<Usage> {
315    (usage != Usage::new()).then_some(usage)
316}
317
/// Extended result of an agent run: the output text plus usage details and,
/// optionally, the messages exchanged.
///
/// This is what awaiting a request in the `Extended` state yields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct PromptResponse {
    /// Response text. This is also what the `Display` implementation prints.
    pub output: String,
    /// Token usage aggregated across the whole run.
    pub usage: Usage,
    /// Successfully completed completion requests made by this agent run, with token usage when available.
    ///
    /// `usage` remains the aggregate across the whole run. Use the last entry's
    /// usage, when present, to inspect the final completion request's
    /// prompt/context length.
    /// If a provider does not report token usage, its entry contains `None`.
    // Omitted from serialized output when empty and defaulted to empty when absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub completion_calls: Vec<CompletionCall>,
    /// Messages attached with `with_messages`; `None` when none were attached.
    pub messages: Option<Vec<Message>>,
}
333
334impl std::fmt::Display for PromptResponse {
335    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336        self.output.fmt(f)
337    }
338}
339
340impl PromptResponse {
341    pub fn new(output: impl Into<String>, usage: Usage) -> Self {
342        Self {
343            output: output.into(),
344            usage,
345            completion_calls: Vec::new(),
346            messages: None,
347        }
348    }
349
350    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
351        self.messages = Some(messages);
352        self
353    }
354
355    /// Attach completion call details to this response.
356    pub fn with_completion_calls(mut self, completion_calls: Vec<CompletionCall>) -> Self {
357        self.completion_calls = completion_calls;
358        self
359    }
360
361    /// Returns successfully completed completion requests made by this agent run, with token usage when available.
362    ///
363    /// If a provider does not report token usage, its entry contains `None`.
364    pub fn completion_calls(&self) -> &[CompletionCall] {
365        &self.completion_calls
366    }
367}
368
/// Response carrying an output value of type `T` together with usage details.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct TypedPromptResponse<T> {
    /// The output value.
    pub output: T,
    /// Token usage aggregated across the whole run.
    pub usage: Usage,
    /// Successfully completed completion requests made by this agent run, with token usage when available.
    ///
    /// `usage` remains the aggregate across the whole run. Use the last entry's
    /// usage, when present, to inspect the final completion request's
    /// prompt/context length.
    /// If a provider does not report token usage, its entry contains `None`.
    // Omitted from serialized output when empty and defaulted to empty when absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub completion_calls: Vec<CompletionCall>,
}
383
384impl<T> TypedPromptResponse<T> {
385    pub fn new(output: T, usage: Usage) -> Self {
386        Self {
387            output,
388            usage,
389            completion_calls: Vec::new(),
390        }
391    }
392
393    /// Attach completion call details to this response.
394    pub fn with_completion_calls(mut self, completion_calls: Vec<CompletionCall>) -> Self {
395        self.completion_calls = completion_calls;
396        self
397    }
398
399    /// Returns successfully completed completion requests made by this agent run, with token usage when available.
400    ///
401    /// If a provider does not report token usage, its entry contains `None`.
402    pub fn completion_calls(&self) -> &[CompletionCall] {
403        &self.completion_calls
404    }
405}
406
/// Name reported in tracing spans when the agent has no configured name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

/// Tool result text sent back for the valid tool calls of an assistant turn when another
/// tool call in that same turn was invalid and the turn is being retried.
pub(crate) const TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER: &str =
    "Tool not executed because another tool call in the same assistant turn was invalid.";
411
412/// Combine input history with new messages for building completion requests.
413fn build_history_for_request(
414    chat_history: Option<&[Message]>,
415    new_messages: &[Message],
416) -> Vec<Message> {
417    let input = chat_history.unwrap_or(&[]);
418    input.iter().chain(new_messages.iter()).cloned().collect()
419}
420
421/// Build the full history for error reporting (input + new messages).
422fn build_full_history(
423    chat_history: Option<&[Message]>,
424    new_messages: Vec<Message>,
425) -> Vec<Message> {
426    let input = chat_history.unwrap_or(&[]);
427    input.iter().cloned().chain(new_messages).collect()
428}
429
430fn tool_result_user_content(
431    id: String,
432    call_id: Option<String>,
433    tool_result: String,
434) -> UserContent {
435    let content = ToolResultContent::from_tool_output(tool_result);
436    match call_id {
437        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
438        None => UserContent::tool_result(id, content),
439    }
440}
441
442fn invalid_tool_retry_user_message(
443    assistant_content: &OneOrMany<AssistantContent>,
444    invalid_tool_call_id: &str,
445    feedback: String,
446) -> Option<Message> {
447    let retry_results = assistant_content
448        .iter()
449        .filter_map(|content| match content {
450            AssistantContent::ToolCall(tool_call) if tool_call.id == invalid_tool_call_id => {
451                Some(tool_result_user_content(
452                    tool_call.id.clone(),
453                    tool_call.call_id.clone(),
454                    feedback.clone(),
455                ))
456            }
457            AssistantContent::ToolCall(tool_call) => Some(tool_result_user_content(
458                tool_call.id.clone(),
459                tool_call.call_id.clone(),
460                TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
461            )),
462            _ => None,
463        })
464        .collect::<Vec<_>>();
465
466    Some(Message::User {
467        content: OneOrMany::from_iter_optional(retry_results)?,
468    })
469}
470
471pub(crate) fn validate_tool_call_name(
472    tool_name: &str,
473    executable_tool_names: &BTreeSet<String>,
474    allowed_tool_names: &BTreeSet<String>,
475    chat_history: Vec<Message>,
476) -> Result<(), PromptError> {
477    if allowed_tool_names.contains(tool_name) {
478        return Ok(());
479    }
480
481    Err(PromptError::UnknownToolCall {
482        tool_name: tool_name.to_owned(),
483        available_tools: executable_tool_names.iter().cloned().collect(),
484        allowed_tools: allowed_tool_names.iter().cloned().collect(),
485        chat_history: Box::new(chat_history),
486    })
487}
488
/// Outcome of handling a tool call whose name is not in the allowed tool set.
enum InvalidToolCallResolution {
    /// Abort the request with the contained error.
    Fail(PromptError),
    /// Ask the model to try again; carries the feedback text supplied by the hook.
    Retry(String),
    /// Redirect the call to another tool; carries a replacement name that has been checked
    /// against the allowed tool names.
    Repair(String),
    /// Do not execute the call; carries the reason supplied by the hook.
    Skip(String),
}
495
496#[allow(clippy::too_many_arguments)]
497async fn resolve_invalid_tool_call<M, P>(
498    hook: Option<&P>,
499    tool_name: &str,
500    tool_call_id: Option<String>,
501    internal_call_id: Option<String>,
502    args: Option<String>,
503    executable_tool_names: &BTreeSet<String>,
504    allowed_tool_names: &BTreeSet<String>,
505    tool_choice: Option<&ToolChoice>,
506    chat_history: Vec<Message>,
507    is_streaming: bool,
508) -> InvalidToolCallResolution
509where
510    M: CompletionModel,
511    P: PromptHook<M>,
512{
513    let err = PromptError::UnknownToolCall {
514        tool_name: tool_name.to_owned(),
515        available_tools: executable_tool_names.iter().cloned().collect(),
516        allowed_tools: allowed_tool_names.iter().cloned().collect(),
517        chat_history: Box::new(chat_history.clone()),
518    };
519
520    let Some(hook) = hook else {
521        return InvalidToolCallResolution::Fail(err);
522    };
523
524    let context = InvalidToolCallContext {
525        tool_name: tool_name.to_owned(),
526        tool_call_id,
527        internal_call_id,
528        args,
529        available_tools: executable_tool_names.iter().cloned().collect(),
530        allowed_tools: allowed_tool_names.iter().cloned().collect(),
531        tool_choice: tool_choice.cloned(),
532        chat_history,
533        is_streaming,
534    };
535
536    match hook.on_invalid_tool_call(&context).await {
537        InvalidToolCallHookAction::Fail => InvalidToolCallResolution::Fail(err),
538        InvalidToolCallHookAction::Retry { feedback } => InvalidToolCallResolution::Retry(feedback),
539        InvalidToolCallHookAction::Repair { tool_name } => {
540            if allowed_tool_names.contains(&tool_name) {
541                InvalidToolCallResolution::Repair(tool_name)
542            } else {
543                InvalidToolCallResolution::Fail(PromptError::UnknownToolCall {
544                    tool_name,
545                    available_tools: executable_tool_names.iter().cloned().collect(),
546                    allowed_tools: allowed_tool_names.iter().cloned().collect(),
547                    chat_history: Box::new(context.chat_history),
548                })
549            }
550        }
551        InvalidToolCallHookAction::Skip { reason } => {
552            if matches!(tool_choice, Some(ToolChoice::None)) {
553                InvalidToolCallResolution::Fail(err)
554            } else {
555                InvalidToolCallResolution::Skip(reason)
556            }
557        }
558    }
559}
560
561fn is_empty_assistant_turn(choice: &OneOrMany<AssistantContent>) -> bool {
562    choice.len() == 1
563        && matches!(
564            choice.first(),
565            AssistantContent::Text(text) if text.text.is_empty() && text.additional_params.is_none()
566        )
567}
568
569fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
570    choice
571        .iter()
572        .filter_map(|content| match content {
573            AssistantContent::Text(text) => Some(text.text.as_str()),
574            _ => None,
575        })
576        .collect()
577}
578
579impl<M, P> PromptRequest<Extended, M, P>
580where
581    M: CompletionModel,
582    P: PromptHook<M>,
583{
584    fn agent_name(&self) -> &str {
585        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
586    }
587
588    async fn send(self) -> Result<PromptResponse, PromptError> {
589        let agent_span = if tracing::Span::current().is_disabled() {
590            info_span!(
591                "invoke_agent",
592                gen_ai.operation.name = "invoke_agent",
593                gen_ai.agent.name = self.agent_name(),
594                gen_ai.system_instructions = self.preamble,
595                gen_ai.prompt = tracing::field::Empty,
596                gen_ai.completion = tracing::field::Empty,
597                gen_ai.usage.input_tokens = tracing::field::Empty,
598                gen_ai.usage.output_tokens = tracing::field::Empty,
599                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
600                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
601                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
602                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
603            )
604        } else {
605            tracing::Span::current()
606        };
607
608        if let Some(text) = self.prompt.rag_text() {
609            agent_span.record("gen_ai.prompt", text);
610        }
611
612        let agent_name_for_span = self.agent_name.clone();
613        // When the caller passes explicit history, memory is fully bypassed for this
614        // request (no load AND no save). Otherwise, if a memory backend and
615        // conversation id are both configured, load prior history; if either is
616        // missing, behave as if no memory is configured.
617        let (chat_history, memory_handle) = match self.chat_history {
618            Some(history) => (Some(history), None),
619            None => match (self.memory, self.conversation_id) {
620                (Some(memory), Some(id)) => {
621                    let loaded = memory.load(&id).await?;
622                    (Some(loaded), Some((memory, id)))
623                }
624                _ => (None, None),
625            },
626        };
627        let mut new_messages: Vec<Message> = vec![self.prompt.clone()];
628
629        let mut current_max_turns = 0;
630        let mut usage = Usage::new();
631        let mut completion_calls = Vec::new();
632        let mut completion_call_index = 0;
633        let mut invalid_tool_call_retries = 0;
634        let current_span_id: AtomicU64 = AtomicU64::new(0);
635
636        // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
637        let last_prompt = 'prompt_loop: loop {
638            // Get the last message (the current prompt)
639            let Some((prompt_ref, history_for_current_turn)) = new_messages.split_last() else {
640                return Err(PromptError::prompt_cancelled(
641                    build_full_history(chat_history.as_deref(), new_messages),
642                    "prompt loop lost its pending prompt",
643                ));
644            };
645            let prompt = prompt_ref.clone();
646
647            if current_max_turns > self.max_turns + 1 {
648                break prompt;
649            }
650
651            current_max_turns += 1;
652
653            if self.max_turns > 1 {
654                tracing::info!(
655                    "Current conversation depth: {}/{}",
656                    current_max_turns,
657                    self.max_turns
658                );
659            }
660
661            // Build history for hook callback (input + new messages except last)
662            let history_for_hook =
663                build_history_for_request(chat_history.as_deref(), history_for_current_turn);
664
665            if let Some(ref hook) = self.hook
666                && let HookAction::Terminate { reason } =
667                    hook.on_completion_call(&prompt, &history_for_hook).await
668            {
669                return Err(PromptError::prompt_cancelled(
670                    build_full_history(chat_history.as_deref(), new_messages),
671                    reason,
672                ));
673            }
674
675            let span = tracing::Span::current();
676            let chat_span = info_span!(
677                target: "rig::agent_chat",
678                parent: &span,
679                "chat",
680                gen_ai.operation.name = "chat",
681                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
682                gen_ai.system_instructions = self.preamble,
683                gen_ai.provider.name = tracing::field::Empty,
684                gen_ai.request.model = tracing::field::Empty,
685                gen_ai.response.id = tracing::field::Empty,
686                gen_ai.response.model = tracing::field::Empty,
687                gen_ai.usage.output_tokens = tracing::field::Empty,
688                gen_ai.usage.input_tokens = tracing::field::Empty,
689                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
690                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
691                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
692                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
693                gen_ai.input.messages = tracing::field::Empty,
694                gen_ai.output.messages = tracing::field::Empty,
695            );
696
697            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
698                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
699                chat_span.follows_from(id).to_owned()
700            } else {
701                chat_span
702            };
703
704            if let Some(id) = chat_span.id() {
705                current_span_id.store(id.into_u64(), Ordering::SeqCst);
706            };
707
708            // Build history for completion request (input + new messages except last)
709            let history_for_request =
710                build_history_for_request(chat_history.as_deref(), history_for_current_turn);
711
712            let prepared_request = build_prepared_completion_request(
713                &self.model,
714                prompt.clone(),
715                &history_for_request,
716                self.preamble.as_deref(),
717                &self.static_context,
718                self.temperature,
719                self.max_tokens,
720                self.additional_params.as_ref(),
721                self.tool_choice.as_ref(),
722                &self.tool_server_handle,
723                &self.dynamic_context,
724                self.output_schema.as_ref(),
725            )
726            .await?;
727            let executable_tool_names = prepared_request.executable_tool_names.clone();
728            let allowed_tool_names = prepared_request.allowed_tool_names.clone();
729
730            let resp = prepared_request
731                .builder
732                .send()
733                .instrument(chat_span.clone())
734                .await?;
735
736            completion_calls.push(CompletionCall::from_reported_usage(
737                completion_call_index,
738                resp.usage,
739            ));
740            completion_call_index += 1;
741            usage += resp.usage;
742
743            let mut response_choice = resp.choice.clone();
744            let has_tool_calls = response_choice
745                .iter()
746                .any(|choice| matches!(choice, AssistantContent::ToolCall(_)));
747
748            // Some providers normalize textless terminal turns into a single empty text item
749            // because the generic completion response cannot represent an empty choice. Treat
750            // that sentinel as "no assistant output" so it does not pollute returned history.
751            let mut skipped_tool_results = BTreeMap::new();
752            let mut invalid_tool_call_recovered = false;
753            let mut invalid_tool_call_skipped = false;
754            if has_tool_calls {
755                for choice in response_choice.iter_mut() {
756                    let AssistantContent::ToolCall(tool_call) = choice else {
757                        continue;
758                    };
759
760                    if allowed_tool_names.contains(&tool_call.function.name) {
761                        continue;
762                    }
763
764                    let mut diagnostic_messages = new_messages.clone();
765                    diagnostic_messages.push(Message::Assistant {
766                        id: resp.message_id.clone(),
767                        content: resp.choice.clone(),
768                    });
769                    let diagnostic_history =
770                        build_full_history(chat_history.as_deref(), diagnostic_messages);
771                    let args = json_utils::value_to_json_string(&tool_call.function.arguments);
772                    let emitted_tool_name = tool_call.function.name.clone();
773
774                    match resolve_invalid_tool_call::<M, P>(
775                        self.hook.as_ref(),
776                        &emitted_tool_name,
777                        Some(tool_call.id.clone()),
778                        None,
779                        Some(args),
780                        &executable_tool_names,
781                        &allowed_tool_names,
782                        self.tool_choice.as_ref(),
783                        diagnostic_history.clone(),
784                        false,
785                    )
786                    .await
787                    {
788                        InvalidToolCallResolution::Fail(err) => return Err(err),
789                        InvalidToolCallResolution::Retry(feedback) => {
790                            if invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
791                                return Err(PromptError::UnknownToolCall {
792                                    tool_name: emitted_tool_name,
793                                    available_tools: executable_tool_names
794                                        .iter()
795                                        .cloned()
796                                        .collect(),
797                                    allowed_tools: allowed_tool_names.iter().cloned().collect(),
798                                    chat_history: Box::new(diagnostic_history.clone()),
799                                });
800                            }
801
802                            invalid_tool_call_retries += 1;
803                            new_messages.push(Message::Assistant {
804                                id: resp.message_id.clone(),
805                                content: resp.choice.clone(),
806                            });
807                            let Some(user_message) = invalid_tool_retry_user_message(
808                                &resp.choice,
809                                &tool_call.id,
810                                feedback,
811                            ) else {
812                                return Err(PromptError::prompt_cancelled(
813                                    diagnostic_history,
814                                    "invalid tool call retry produced no retry messages",
815                                ));
816                            };
817                            new_messages.push(user_message);
818                            continue 'prompt_loop;
819                        }
820                        InvalidToolCallResolution::Repair(repaired_name) => {
821                            tool_call.function.name = repaired_name;
822                            invalid_tool_call_recovered = true;
823                        }
824                        InvalidToolCallResolution::Skip(reason) => {
825                            let user_content = if let Some(call_id) = tool_call.call_id.clone() {
826                                UserContent::tool_result_with_call_id(
827                                    tool_call.id.clone(),
828                                    call_id,
829                                    OneOrMany::one(reason.into()),
830                                )
831                            } else {
832                                UserContent::tool_result(
833                                    tool_call.id.clone(),
834                                    OneOrMany::one(reason.into()),
835                                )
836                            };
837                            skipped_tool_results.insert(tool_call.id.clone(), user_content);
838                            invalid_tool_call_recovered = true;
839                            invalid_tool_call_skipped = true;
840                        }
841                    }
842                }
843            }
844
845            if invalid_tool_call_skipped {
846                for choice in response_choice.iter() {
847                    let AssistantContent::ToolCall(tool_call) = choice else {
848                        continue;
849                    };
850
851                    skipped_tool_results
852                        .entry(tool_call.id.clone())
853                        .or_insert_with(|| {
854                            tool_result_user_content(
855                                tool_call.id.clone(),
856                                tool_call.call_id.clone(),
857                                TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
858                            )
859                        });
860                }
861            }
862
863            let assistant_response_message =
864                (!is_empty_assistant_turn(&response_choice)).then(|| Message::Assistant {
865                    id: resp.message_id.clone(),
866                    content: response_choice.clone(),
867                });
868
869            if !invalid_tool_call_recovered
870                && let Some(ref hook) = self.hook
871                && let HookAction::Terminate { reason } =
872                    hook.on_completion_response(&prompt, &resp).await
873            {
874                return Err(PromptError::prompt_cancelled(
875                    build_full_history(chat_history.as_deref(), new_messages),
876                    reason,
877                ));
878            }
879
880            if let Some(message) = assistant_response_message {
881                new_messages.push(message);
882            }
883
884            if !has_tool_calls {
885                let merged_texts = assistant_text_from_choice(&response_choice);
886
887                if self.max_turns > 1 {
888                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
889                }
890
891                agent_span.record("gen_ai.completion", &merged_texts);
892                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
893                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
894                agent_span.record(
895                    "gen_ai.usage.cache_read.input_tokens",
896                    usage.cached_input_tokens,
897                );
898                agent_span.record(
899                    "gen_ai.usage.cache_creation.input_tokens",
900                    usage.cache_creation_input_tokens,
901                );
902                agent_span.record(
903                    "gen_ai.usage.tool_use_prompt_tokens",
904                    usage.tool_use_prompt_tokens,
905                );
906                agent_span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);
907
908                if let Some((memory, id)) = memory_handle.as_ref()
909                    && let Err(err) = memory.append(id, new_messages.clone()).await
910                {
911                    tracing::warn!(
912                        error = %err,
913                        conversation_id = %id,
914                        "conversation memory append failed; returning model response anyway"
915                    );
916                }
917
918                return Ok(PromptResponse::new(merged_texts, usage)
919                    .with_messages(new_messages)
920                    .with_completion_calls(completion_calls));
921            }
922
923            let hook = self.hook.clone();
924            let tool_server_handle = self.tool_server_handle.clone();
925            let skipped_tool_results = Arc::new(skipped_tool_results);
926
927            // For error handling in concurrent tool execution, we need to build full history
928            let full_history_for_errors =
929                build_full_history(chat_history.as_deref(), new_messages.clone());
930
931            let tool_calls: Vec<AssistantContent> = response_choice
932                .iter()
933                .filter(|choice| matches!(choice, AssistantContent::ToolCall(_)))
934                .cloned()
935                .collect();
936            let tool_content = stream::iter(tool_calls)
937                .map(|choice| {
938                    let hook1 = hook.clone();
939                    let hook2 = hook.clone();
940                    let tool_server_handle = tool_server_handle.clone();
941                    let skipped_tool_results = skipped_tool_results.clone();
942
943                    let tool_span = info_span!(
944                        "execute_tool",
945                        gen_ai.operation.name = "execute_tool",
946                        gen_ai.tool.type = "function",
947                        gen_ai.tool.name = tracing::field::Empty,
948                        gen_ai.tool.call.id = tracing::field::Empty,
949                        gen_ai.tool.call.arguments = tracing::field::Empty,
950                        gen_ai.tool.call.result = tracing::field::Empty
951                    );
952
953                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
954                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
955                        tool_span.follows_from(id).to_owned()
956                    } else {
957                        tool_span
958                    };
959
960                    if let Some(id) = tool_span.id() {
961                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
962                    };
963
964                    // Clone full history for error reporting in concurrent tool execution
965                    let cloned_history_for_error = full_history_for_errors.clone();
966
967                    async move {
968                        if let AssistantContent::ToolCall(tool_call) = choice {
969                            let tool_name = &tool_call.function.name;
970                            let args =
971                                json_utils::value_to_json_string(&tool_call.function.arguments);
972                            let internal_call_id = nanoid::nanoid!();
973                            if let Some(result) = skipped_tool_results.get(&tool_call.id) {
974                                return Ok(result.clone());
975                            }
976                            let tool_span = tracing::Span::current();
977                            tool_span.record("gen_ai.tool.name", tool_name);
978                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
979                            tool_span.record("gen_ai.tool.call.arguments", &args);
980                            if let Some(hook) = hook1 {
981                                let action = hook
982                                    .on_tool_call(
983                                        tool_name,
984                                        tool_call.call_id.clone(),
985                                        &internal_call_id,
986                                        &args,
987                                    )
988                                    .await;
989
990                                if let ToolCallHookAction::Terminate { reason } = action {
991                                    return Err(PromptError::prompt_cancelled(
992                                        cloned_history_for_error,
993                                        reason,
994                                    ));
995                                }
996
997                                if let ToolCallHookAction::Skip { reason } = action {
998                                    // Tool execution rejected, return rejection message as tool result
999                                    tracing::info!(
1000                                        tool_name = tool_name,
1001                                        reason = reason,
1002                                        "Tool call rejected"
1003                                    );
1004                                    if let Some(call_id) = tool_call.call_id.clone() {
1005                                        return Ok(UserContent::tool_result_with_call_id(
1006                                            tool_call.id.clone(),
1007                                            call_id,
1008                                            OneOrMany::one(reason.into()),
1009                                        ));
1010                                    } else {
1011                                        return Ok(UserContent::tool_result(
1012                                            tool_call.id.clone(),
1013                                            OneOrMany::one(reason.into()),
1014                                        ));
1015                                    }
1016                                }
1017                            }
1018                            let output = match tool_server_handle.call_tool(tool_name, &args).await
1019                            {
1020                                Ok(res) => res,
1021                                Err(e) => {
1022                                    tracing::warn!("Error while executing tool: {e}");
1023                                    e.to_string()
1024                                }
1025                            };
1026                            if let Some(hook) = hook2
1027                                && let HookAction::Terminate { reason } = hook
1028                                    .on_tool_result(
1029                                        tool_name,
1030                                        tool_call.call_id.clone(),
1031                                        &internal_call_id,
1032                                        &args,
1033                                        &output.to_string(),
1034                                    )
1035                                    .await
1036                            {
1037                                return Err(PromptError::prompt_cancelled(
1038                                    cloned_history_for_error,
1039                                    reason,
1040                                ));
1041                            }
1042
1043                            tool_span.record("gen_ai.tool.call.result", &output);
1044                            tracing::info!(
1045                                "executed tool {tool_name} with args {args}. result: {output}"
1046                            );
1047                            if let Some(call_id) = tool_call.call_id.clone() {
1048                                Ok(UserContent::tool_result_with_call_id(
1049                                    tool_call.id.clone(),
1050                                    call_id,
1051                                    ToolResultContent::from_tool_output(output),
1052                                ))
1053                            } else {
1054                                Ok(UserContent::tool_result(
1055                                    tool_call.id.clone(),
1056                                    ToolResultContent::from_tool_output(output),
1057                                ))
1058                            }
1059                        } else {
1060                            Err(PromptError::prompt_cancelled(
1061                                Vec::new(),
1062                                "tool execution received non-tool assistant content",
1063                            ))
1064                        }
1065                    }
1066                    .instrument(tool_span)
1067                })
1068                .buffer_unordered(self.concurrency)
1069                .collect::<Vec<Result<UserContent, PromptError>>>()
1070                .await
1071                .into_iter()
1072                .collect::<Result<Vec<_>, _>>()?;
1073
1074            let Some(content) = OneOrMany::from_iter_optional(tool_content) else {
1075                return Err(PromptError::prompt_cancelled(
1076                    build_full_history(chat_history.as_deref(), new_messages),
1077                    "tool execution produced no tool results",
1078                ));
1079            };
1080
1081            new_messages.push(Message::User { content });
1082        };
1083
1084        // If we reach here, we exceeded max turns without a final response
1085        Err(PromptError::MaxTurnsError {
1086            max_turns: self.max_turns,
1087            chat_history: build_full_history(chat_history.as_deref(), new_messages).into(),
1088            prompt: last_prompt.into(),
1089        })
1090    }
1091}
1092
1093// ================================================================
1094// TypedPromptRequest - for structured output with automatic deserialization
1095// ================================================================
1096
1097use crate::completion::StructuredOutputError;
1098use schemars::{JsonSchema, schema_for};
1099use serde::de::DeserializeOwned;
1100
1101/// A builder for creating typed prompt requests that return deserialized structured output.
1102///
1103/// This struct wraps a standard `PromptRequest` and adds:
1104/// - Automatic JSON schema generation from the target type `T`
1105/// - Automatic deserialization of the response into `T`
1106///
1107/// The type parameter `S` represents the state of the request (Standard or Extended).
1108/// Use `.extended_details()` to transition to Extended state for usage tracking.
1109///
1110/// # Example
1111/// ```rust,ignore
1112/// let forecast: WeatherForecast = agent
1113///     .prompt_typed("What's the weather in NYC?")
1114///     .max_turns(3)
1115///     .await?;
1116/// ```
1117pub struct TypedPromptRequest<T, S, M, P>
1118where
1119    T: JsonSchema + DeserializeOwned + WasmCompatSend,
1120    S: PromptType,
1121    M: CompletionModel,
1122    P: PromptHook<M>,
1123{
1124    inner: PromptRequest<S, M, P>,
1125    _phantom: std::marker::PhantomData<T>,
1126}
1127
1128impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
1129where
1130    T: JsonSchema + DeserializeOwned + WasmCompatSend,
1131    M: CompletionModel,
1132    P: PromptHook<M>,
1133{
1134    /// Create a new TypedPromptRequest from an agent.
1135    ///
1136    /// This automatically sets the output schema based on the type parameter `T`.
1137    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
1138        let mut inner = PromptRequest::from_agent(agent, prompt);
1139        // Override the output schema with the schema for T
1140        inner.output_schema = Some(schema_for!(T));
1141        Self {
1142            inner,
1143            _phantom: std::marker::PhantomData,
1144        }
1145    }
1146}
1147
1148impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
1149where
1150    T: JsonSchema + DeserializeOwned + WasmCompatSend,
1151    S: PromptType,
1152    M: CompletionModel,
1153    P: PromptHook<M>,
1154{
1155    /// Enable returning extended details for responses (includes aggregated token usage).
1156    ///
1157    /// Note: This changes the type of the response from `.send()` to return a `TypedPromptResponse<T>` struct
1158    /// instead of just `T`. This is useful for tracking token usage across multiple turns
1159    /// of conversation.
1160    pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
1161        TypedPromptRequest {
1162            inner: self.inner.extended_details(),
1163            _phantom: std::marker::PhantomData,
1164        }
1165    }
1166
1167    /// Set the maximum number of turns for multi-turn conversations.
1168    ///
1169    /// A given agent may require multiple turns for tool-calling before giving an answer.
1170    /// If the maximum turn number is exceeded, it will return a
1171    /// [`StructuredOutputError::PromptError`] wrapping a `MaxTurnsError`.
1172    pub fn max_turns(mut self, depth: usize) -> Self {
1173        self.inner = self.inner.max_turns(depth);
1174        self
1175    }
1176
1177    /// Set the retry budget for invalid tool-call recovery.
1178    ///
1179    /// Invalid tool-call retries also consume normal multi-turn depth.
1180    pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
1181        self.inner = self.inner.max_invalid_tool_call_retries(retries);
1182        self
1183    }
1184
1185    /// Add concurrency to the prompt request.
1186    ///
1187    /// This will cause the agent to execute tools concurrently.
1188    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
1189        self.inner = self.inner.with_tool_concurrency(concurrency);
1190        self
1191    }
1192
1193    /// Add chat history to the prompt request.
1194    pub fn with_history<H, U>(mut self, history: H) -> Self
1195    where
1196        H: IntoIterator<Item = U>,
1197        U: Into<Message>,
1198    {
1199        self.inner = self.inner.with_history(history);
1200        self
1201    }
1202
1203    /// Set the conversation id used to load and persist memory for this request.
1204    ///
1205    /// Overrides any default conversation id set on the agent. If memory is not
1206    /// configured on the agent, this has no effect.
1207    pub fn conversation(mut self, id: impl Into<String>) -> Self {
1208        self.inner = self.inner.conversation(id);
1209        self
1210    }
1211
1212    /// Disable conversation memory for this request.
1213    ///
1214    /// History will neither be loaded from nor saved to the agent's memory backend.
1215    pub fn without_memory(mut self) -> Self {
1216        self.inner = self.inner.without_memory();
1217        self
1218    }
1219
1220    /// Attach a per-request hook for tool call events.
1221    ///
1222    /// This overrides any default hook set on the agent.
1223    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
1224    where
1225        P2: PromptHook<M>,
1226    {
1227        TypedPromptRequest {
1228            inner: self.inner.with_hook(hook),
1229            _phantom: std::marker::PhantomData,
1230        }
1231    }
1232}
1233
1234impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
1235where
1236    T: JsonSchema + DeserializeOwned + WasmCompatSend,
1237    M: CompletionModel,
1238    P: PromptHook<M>,
1239{
1240    /// Send the typed prompt request and deserialize the response.
1241    async fn send(self) -> Result<T, StructuredOutputError> {
1242        let response = self.inner.send().await.map_err(Box::new)?;
1243
1244        if response.is_empty() {
1245            return Err(StructuredOutputError::EmptyResponse);
1246        }
1247
1248        let parsed: T = serde_json::from_str(&response)?;
1249        Ok(parsed)
1250    }
1251}
1252
1253impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
1254where
1255    T: JsonSchema + DeserializeOwned + WasmCompatSend,
1256    M: CompletionModel,
1257    P: PromptHook<M>,
1258{
1259    /// Send the typed prompt request with extended details and deserialize the response.
1260    async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
1261        let response = self.inner.send().await.map_err(Box::new)?;
1262
1263        if response.output.is_empty() {
1264            return Err(StructuredOutputError::EmptyResponse);
1265        }
1266
1267        let parsed: T = serde_json::from_str(&response.output)?;
1268        Ok(TypedPromptResponse::new(parsed, response.usage)
1269            .with_completion_calls(response.completion_calls))
1270    }
1271}
1272
1273impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
1274where
1275    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
1276    M: CompletionModel + 'static,
1277    P: PromptHook<M> + 'static,
1278{
1279    type Output = Result<T, StructuredOutputError>;
1280    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1281
1282    fn into_future(self) -> Self::IntoFuture {
1283        Box::pin(self.send())
1284    }
1285}
1286
1287impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
1288where
1289    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
1290    M: CompletionModel + 'static,
1291    P: PromptHook<M> + 'static,
1292{
1293    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
1294    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1295
1296    fn into_future(self) -> Self::IntoFuture {
1297        Box::pin(self.send())
1298    }
1299}
1300
1301#[cfg(test)]
1302mod tests {
1303    use super::{CompletionCall, PromptResponse, TypedPromptResponse};
1304    use crate::{
1305        agent::{
1306            AgentBuilder,
1307            prompt_request::hooks::{
1308                HookAction, InvalidToolCallContext, InvalidToolCallHookAction, PromptHook,
1309                ToolCallHookAction,
1310            },
1311        },
1312        completion::{
1313            AssistantContent, CompletionError, CompletionModel, CompletionRequest, Message, Prompt,
1314            PromptError, StructuredOutputError, ToolDefinition, TypedPrompt, Usage,
1315        },
1316        message::{Text, ToolCall, ToolChoice, ToolFunction, UserContent},
1317        test_utils::{
1318            AppendFailingMemory, CountingMemory, FailingMemory, MockAddTool, MockCompletionModel,
1319            MockOperationArgs, MockSubtractTool, MockToolError, MockTurn,
1320        },
1321        tool::Tool,
1322    };
1323    use schemars::JsonSchema;
1324    use serde::{Deserialize, Serialize};
1325    use serde_json::json;
1326    use std::sync::{
1327        Arc, Mutex,
1328        atomic::{AtomicU32, Ordering},
1329    };
1330
1331    #[derive(Serialize)]
1332    struct SerializeOnly {
1333        value: &'static str,
1334    }
1335
1336    #[derive(Deserialize)]
1337    struct DeserializeOnly {
1338        value: String,
1339    }
1340
1341    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
1342    struct TypedAnswer {
1343        value: String,
1344    }
1345
1346    #[derive(Clone)]
1347    struct PanicOnUnknownToolHook;
1348
1349    impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
1350        async fn on_completion_response(
1351            &self,
1352            _prompt: &Message,
1353            _response: &crate::completion::CompletionResponse<
1354                <MockCompletionModel as CompletionModel>::Response,
1355            >,
1356        ) -> HookAction {
1357            panic!("unknown tool response should fail before response hooks run")
1358        }
1359
1360        async fn on_tool_call(
1361            &self,
1362            _tool_name: &str,
1363            _tool_call_id: Option<String>,
1364            _internal_call_id: &str,
1365            _args: &str,
1366        ) -> ToolCallHookAction {
1367            panic!("unknown tool call should fail before tool hooks run")
1368        }
1369    }
1370
1371    #[derive(Clone)]
1372    struct PanicOnToolCallHook;
1373
1374    impl PromptHook<MockCompletionModel> for PanicOnToolCallHook {
1375        async fn on_tool_call(
1376            &self,
1377            _tool_name: &str,
1378            _tool_call_id: Option<String>,
1379            _internal_call_id: &str,
1380            _args: &str,
1381        ) -> ToolCallHookAction {
1382            panic!("recovered invalid turn should not invoke normal tool hooks")
1383        }
1384    }
1385
1386    #[derive(Clone)]
1387    struct SkipDefaultApiAndPanicOnToolCallHook;
1388
1389    impl PromptHook<MockCompletionModel> for SkipDefaultApiAndPanicOnToolCallHook {
1390        async fn on_invalid_tool_call(
1391            &self,
1392            context: &InvalidToolCallContext,
1393        ) -> InvalidToolCallHookAction {
1394            SkipDefaultApiHook.on_invalid_tool_call(context).await
1395        }
1396
1397        async fn on_tool_call(
1398            &self,
1399            tool_name: &str,
1400            tool_call_id: Option<String>,
1401            internal_call_id: &str,
1402            args: &str,
1403        ) -> ToolCallHookAction {
1404            PanicOnToolCallHook
1405                .on_tool_call(tool_name, tool_call_id, internal_call_id, args)
1406                .await
1407        }
1408    }
1409
1410    #[derive(Clone)]
1411    struct RepairDefaultApiHook;
1412
1413    impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
1414        fn on_invalid_tool_call(
1415            &self,
1416            context: &InvalidToolCallContext,
1417        ) -> impl std::future::Future<Output = InvalidToolCallHookAction> + Send {
1418            let tool_name = context.tool_name.clone();
1419            async move {
1420                assert_eq!(tool_name, "default_api");
1421                InvalidToolCallHookAction::repair("add")
1422            }
1423        }
1424    }
1425
1426    #[derive(Clone)]
1427    struct RepairToSubtractHook;
1428
1429    impl PromptHook<MockCompletionModel> for RepairToSubtractHook {
1430        async fn on_invalid_tool_call(
1431            &self,
1432            _context: &InvalidToolCallContext,
1433        ) -> InvalidToolCallHookAction {
1434            InvalidToolCallHookAction::repair("subtract")
1435        }
1436    }
1437
1438    #[derive(Clone)]
1439    struct RetryDefaultApiHook;
1440
1441    impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
1442        fn on_invalid_tool_call(
1443            &self,
1444            context: &InvalidToolCallContext,
1445        ) -> impl std::future::Future<Output = InvalidToolCallHookAction> + Send {
1446            let allowed_tools = context.allowed_tools.clone();
1447            async move {
1448                InvalidToolCallHookAction::retry(format!(
1449                    "Use one of these tools instead: {allowed_tools:?}"
1450                ))
1451            }
1452        }
1453    }
1454
1455    #[derive(Clone)]
1456    struct SkipDefaultApiHook;
1457
1458    impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
1459        async fn on_invalid_tool_call(
1460            &self,
1461            _context: &InvalidToolCallContext,
1462        ) -> InvalidToolCallHookAction {
1463            InvalidToolCallHookAction::skip("default_api is not available")
1464        }
1465    }
1466
1467    #[derive(Clone, Default)]
1468    struct RecordingInvalidToolCallHook {
1469        contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
1470    }
1471
1472    impl RecordingInvalidToolCallHook {
1473        fn observed(&self) -> Vec<InvalidToolCallContext> {
1474            self.contexts
1475                .lock()
1476                .expect("invalid tool context records mutex was poisoned")
1477                .clone()
1478        }
1479    }
1480
1481    impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
1482        async fn on_invalid_tool_call(
1483            &self,
1484            context: &InvalidToolCallContext,
1485        ) -> InvalidToolCallHookAction {
1486            self.contexts
1487                .lock()
1488                .expect("invalid tool context records mutex was poisoned")
1489                .push(context.clone());
1490            InvalidToolCallHookAction::fail()
1491        }
1492    }
1493
1494    #[derive(Clone)]
1495    struct CountingAddTool {
1496        calls: Arc<AtomicU32>,
1497    }
1498
1499    impl Tool for CountingAddTool {
1500        const NAME: &'static str = "add";
1501        type Error = MockToolError;
1502        type Args = MockOperationArgs;
1503        type Output = i32;
1504
1505        async fn definition(&self, _prompt: String) -> ToolDefinition {
1506            MockAddTool.definition(String::new()).await
1507        }
1508
1509        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
1510            self.calls.fetch_add(1, Ordering::SeqCst);
1511            Ok(0)
1512        }
1513    }
1514
1515    fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
1516        Usage {
1517            input_tokens,
1518            output_tokens,
1519            total_tokens: input_tokens + output_tokens,
1520            cached_input_tokens: 0,
1521            cache_creation_input_tokens: 0,
1522            tool_use_prompt_tokens: 0,
1523            reasoning_tokens: 0,
1524        }
1525    }
1526
1527    #[test]
1528    fn typed_prompt_response_serializes_with_serialize_only_output() {
1529        let response = TypedPromptResponse::new(
1530            SerializeOnly { value: "ok" },
1531            Usage {
1532                input_tokens: 1,
1533                output_tokens: 2,
1534                total_tokens: 3,
1535                cached_input_tokens: 0,
1536                cache_creation_input_tokens: 0,
1537                tool_use_prompt_tokens: 0,
1538                reasoning_tokens: 0,
1539            },
1540        );
1541
1542        let json = serde_json::to_string(&response).expect("serialize typed prompt response");
1543        assert!(json.contains("\"value\":\"ok\""));
1544    }
1545
1546    #[test]
1547    fn typed_prompt_response_deserializes_with_deserialize_only_output() {
1548        let response: TypedPromptResponse<DeserializeOnly> = serde_json::from_str(
1549            r#"{"output":{"value":"ok"},"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3,"cached_input_tokens":0,"cache_creation_input_tokens":0,"reasoning_tokens":0}}"#,
1550        )
1551        .expect("deserialize typed prompt response");
1552
1553        assert_eq!(response.output.value, "ok");
1554        assert_eq!(response.usage.input_tokens, 1);
1555        assert_eq!(response.usage.output_tokens, 2);
1556        assert_eq!(response.usage.total_tokens, 3);
1557    }
1558
1559    #[test]
1560    fn prompt_response_serializes_completion_calls_with_missing_usage() {
1561        let reported_usage = usage(3, 4);
1562        let response = PromptResponse::new("ok", reported_usage).with_completion_calls(vec![
1563            CompletionCall::new(0, None),
1564            CompletionCall::new(1, Some(reported_usage)),
1565        ]);
1566
1567        let value = serde_json::to_value(&response).expect("serialize prompt response");
1568
1569        assert_eq!(
1570            value.get("completion_calls"),
1571            Some(&json!([
1572                {
1573                    "call_index": 0,
1574                    "usage": null,
1575                },
1576                {
1577                    "call_index": 1,
1578                    "usage": {
1579                        "input_tokens": 3,
1580                        "output_tokens": 4,
1581                        "total_tokens": 7,
1582                        "cached_input_tokens": 0,
1583                        "cache_creation_input_tokens": 0,
1584                        "tool_use_prompt_tokens": 0,
1585                        "reasoning_tokens": 0,
1586                    }
1587                }
1588            ]))
1589        );
1590
1591        let response: PromptResponse =
1592            serde_json::from_value(value).expect("deserialize prompt response");
1593        assert_eq!(
1594            response.completion_calls(),
1595            &[
1596                CompletionCall::new(0, None),
1597                CompletionCall::new(1, Some(reported_usage))
1598            ]
1599        );
1600    }
1601
1602    #[tokio::test]
1603    async fn prompt_response_records_completion_call_without_reported_usage() {
1604        let model = MockCompletionModel::new([MockTurn::text("ok")]);
1605        let agent = AgentBuilder::new(model).build();
1606
1607        let response = agent
1608            .prompt("say ok")
1609            .extended_details()
1610            .await
1611            .expect("prompt should succeed");
1612
1613        assert_eq!(response.output, "ok");
1614        assert_eq!(response.usage, Usage::new());
1615        assert_eq!(response.completion_calls(), &[CompletionCall::new(0, None)]);
1616    }
1617
1618    #[tokio::test]
1619    async fn typed_prompt_response_preserves_completion_calls() {
1620        let call_usage = Usage {
1621            input_tokens: 4,
1622            output_tokens: 6,
1623            total_tokens: 10,
1624            cached_input_tokens: 0,
1625            cache_creation_input_tokens: 0,
1626            tool_use_prompt_tokens: 0,
1627            reasoning_tokens: 0,
1628        };
1629        let model =
1630            MockCompletionModel::new([MockTurn::text(r#"{"value":"ok"}"#).with_usage(call_usage)]);
1631        let agent = AgentBuilder::new(model).build();
1632
1633        let response = agent
1634            .prompt_typed::<TypedAnswer>("return typed json")
1635            .extended_details()
1636            .await
1637            .expect("typed prompt should succeed");
1638
1639        assert_eq!(
1640            response.output,
1641            TypedAnswer {
1642                value: "ok".to_string()
1643            }
1644        );
1645        assert_eq!(response.usage, call_usage);
1646        assert_eq!(
1647            response.completion_calls(),
1648            &[CompletionCall::new(0, Some(call_usage))]
1649        );
1650    }
1651
1652    fn validate_follow_up_tool_history(request: &CompletionRequest) {
1653        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1654        assert_eq!(
1655            history.len(),
1656            3,
1657            "follow-up request should contain the prompt, assistant tool call, and user tool result: {history:?}"
1658        );
1659
1660        assert!(matches!(
1661            history.first(),
1662            Some(Message::User { content })
1663                if matches!(
1664                    content.first(),
1665                    UserContent::Text(text) if text.text == "do tool work"
1666                )
1667        ));
1668
1669        assert!(matches!(
1670            history.get(1),
1671            Some(Message::Assistant { content, .. })
1672                if matches!(
1673                    content.first(),
1674                    AssistantContent::ToolCall(tool_call)
1675                        if tool_call.id == "tool_call_1"
1676                            && tool_call.call_id.as_deref() == Some("call_1")
1677                )
1678        ));
1679
1680        assert!(matches!(
1681            history.get(2),
1682            Some(Message::User { content })
1683                if matches!(
1684                    content.first(),
1685                    UserContent::ToolResult(tool_result)
1686                        if tool_result.id == "tool_call_1"
1687                            && tool_result.call_id.as_deref() == Some("call_1")
1688                )
1689        ));
1690    }
1691
1692    fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1693        history.iter().any(|message| {
1694            matches!(
1695                message,
1696                Message::Assistant { content, .. }
1697                    if content.iter().any(|item| matches!(
1698                        item,
1699                        AssistantContent::ToolCall(tool_call)
1700                            if tool_call.function.name == tool_name
1701                    ))
1702            )
1703        })
1704    }
1705
1706    #[tokio::test]
1707    async fn unknown_tool_call_fails_before_non_streaming_second_request() {
1708        let model = MockCompletionModel::new([
1709            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 1, "y": 2})),
1710            MockTurn::text("should not be requested"),
1711        ]);
1712        let recorded = model.clone();
1713        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1714
1715        let err = agent
1716            .prompt("use the tool")
1717            .with_hook(PanicOnUnknownToolHook)
1718            .max_turns(3)
1719            .await
1720            .expect_err("unknown model-emitted tool should fail");
1721
1722        match err {
1723            PromptError::UnknownToolCall {
1724                tool_name,
1725                available_tools,
1726                allowed_tools,
1727                chat_history,
1728            } => {
1729                assert_eq!(tool_name, "default_api");
1730                assert_eq!(available_tools, vec!["add".to_string()]);
1731                assert_eq!(allowed_tools, vec!["add".to_string()]);
1732                assert!(history_contains_tool_call(&chat_history, "default_api"));
1733            }
1734            other => panic!("expected UnknownToolCall, got {other:?}"),
1735        }
1736        assert_eq!(recorded.request_count(), 1);
1737    }
1738
1739    #[tokio::test]
1740    async fn invalid_tool_call_context_uses_completed_tool_call_provider_id() {
1741        let invalid_hook = RecordingInvalidToolCallHook::default();
1742        let model = MockCompletionModel::new([
1743            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 1, "y": 2}))
1744                .with_call_id("provider_call_1"),
1745            MockTurn::text("should not be requested"),
1746        ]);
1747        let recorded = model.clone();
1748        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1749
1750        let err = agent
1751            .prompt("use the tool")
1752            .with_hook(invalid_hook.clone())
1753            .max_turns(3)
1754            .await
1755            .expect_err("invalid tool should fail");
1756
1757        assert!(matches!(err, PromptError::UnknownToolCall { .. }));
1758        assert_eq!(recorded.request_count(), 1);
1759        let contexts = invalid_hook.observed();
1760        assert_eq!(contexts.len(), 1);
1761        let context = &contexts[0];
1762        assert_eq!(context.tool_name, "default_api");
1763        assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
1764        assert_eq!(context.internal_call_id, None);
1765        assert!(!context.is_streaming);
1766    }
1767
1768    #[tokio::test]
1769    async fn disallowed_specific_tool_call_fails_before_non_streaming_second_request() {
1770        let model = MockCompletionModel::new([
1771            MockTurn::tool_call("tool_call_1", "subtract", json!({"x": 3, "y": 1})),
1772            MockTurn::text("should not be requested"),
1773        ]);
1774        let recorded = model.clone();
1775        let agent = AgentBuilder::new(model)
1776            .tool(MockAddTool)
1777            .tool(MockSubtractTool)
1778            .tool_choice(ToolChoice::Specific {
1779                function_names: vec!["add".to_string()],
1780            })
1781            .build();
1782
1783        let err = agent
1784            .prompt("use the allowed tool")
1785            .with_hook(PanicOnUnknownToolHook)
1786            .max_turns(3)
1787            .await
1788            .expect_err("disallowed model-emitted tool should fail");
1789
1790        match err {
1791            PromptError::UnknownToolCall {
1792                tool_name,
1793                available_tools,
1794                allowed_tools,
1795                chat_history,
1796            } => {
1797                assert_eq!(tool_name, "subtract");
1798                assert_eq!(
1799                    available_tools,
1800                    vec!["add".to_string(), "subtract".to_string()]
1801                );
1802                assert_eq!(allowed_tools, vec!["add".to_string()]);
1803                assert!(history_contains_tool_call(&chat_history, "subtract"));
1804            }
1805            other => panic!("expected UnknownToolCall, got {other:?}"),
1806        }
1807        assert_eq!(recorded.request_count(), 1);
1808    }
1809
1810    #[tokio::test]
1811    async fn tool_choice_none_rejects_non_streaming_tool_call() {
1812        let model = MockCompletionModel::new([
1813            MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})),
1814            MockTurn::text("should not be requested"),
1815        ]);
1816        let recorded = model.clone();
1817        let agent = AgentBuilder::new(model)
1818            .tool(MockAddTool)
1819            .tool_choice(ToolChoice::None)
1820            .build();
1821
1822        let err = agent
1823            .prompt("do not use tools")
1824            .with_hook(PanicOnUnknownToolHook)
1825            .max_turns(3)
1826            .await
1827            .expect_err("ToolChoice::None should reject returned tool calls");
1828
1829        match err {
1830            PromptError::UnknownToolCall {
1831                tool_name,
1832                available_tools,
1833                allowed_tools,
1834                chat_history,
1835            } => {
1836                assert_eq!(tool_name, "add");
1837                assert_eq!(available_tools, vec!["add".to_string()]);
1838                assert!(allowed_tools.is_empty());
1839                assert!(history_contains_tool_call(&chat_history, "add"));
1840            }
1841            other => panic!("expected UnknownToolCall, got {other:?}"),
1842        }
1843        assert_eq!(recorded.request_count(), 1);
1844    }
1845
1846    #[tokio::test]
1847    async fn invalid_tool_call_hook_can_repair_non_streaming_tool_name() {
1848        let model = MockCompletionModel::new([
1849            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1850            MockTurn::text("done"),
1851        ]);
1852        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1853
1854        let response = agent
1855            .prompt("add")
1856            .with_hook(RepairDefaultApiHook)
1857            .max_turns(3)
1858            .extended_details()
1859            .await
1860            .expect("repaired tool call should execute");
1861
1862        assert_eq!(response.output, "done");
1863        let messages = response.messages.expect("messages should be present");
1864        assert!(history_contains_tool_call(&messages, "add"));
1865        assert!(!history_contains_tool_call(&messages, "default_api"));
1866        assert!(messages.iter().any(|message| {
1867            matches!(
1868                message,
1869                Message::User { content }
1870                    if content.iter().any(|content| {
1871                        matches!(
1872                            content,
1873                            UserContent::ToolResult(result)
1874                                if result.content.iter().any(|content| {
1875                                    matches!(
1876                                        content,
1877                                        crate::message::ToolResultContent::Text(text)
1878                                            if text.text == "5"
1879                                    )
1880                                })
1881                        )
1882                    })
1883            )
1884        }));
1885    }
1886
1887    #[tokio::test]
1888    async fn invalid_tool_call_hook_retry_adds_feedback_and_retries_non_streaming() {
1889        let model = MockCompletionModel::new([
1890            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1891            MockTurn::text("retried"),
1892        ]);
1893        let recorded = model.clone();
1894        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1895
1896        let response = agent
1897            .prompt("add")
1898            .with_hook(RetryDefaultApiHook)
1899            .max_invalid_tool_call_retries(1)
1900            .max_turns(3)
1901            .extended_details()
1902            .await
1903            .expect("retry should recover");
1904
1905        assert_eq!(response.output, "retried");
1906        assert_eq!(recorded.request_count(), 2);
1907        let messages = response.messages.expect("messages should be present");
1908        assert!(messages.iter().any(|message| {
1909            matches!(
1910                message,
1911                Message::User { content }
1912                    if content.iter().any(|content| {
1913                        matches!(
1914                            content,
1915                            UserContent::ToolResult(result)
1916                                if result.content.iter().any(|content| {
1917                                    matches!(
1918                                        content,
1919                                        crate::message::ToolResultContent::Text(text)
1920                                            if text.text.contains("Use one of these tools instead")
1921                                    )
1922                                })
1923                        )
1924                    })
1925            )
1926        }));
1927    }
1928
1929    #[tokio::test]
1930    async fn invalid_tool_call_hook_retries_mixed_non_streaming_turn_without_executing_valid_call()
1931    {
1932        let add_calls = Arc::new(AtomicU32::new(0));
1933        let mut valid_tool_call = ToolCall::new(
1934            "tool_call_1".to_string(),
1935            ToolFunction::new("add".to_string(), json!({"x": 2, "y": 3})),
1936        );
1937        valid_tool_call.call_id = Some("call_1".to_string());
1938        let mut invalid_tool_call = ToolCall::new(
1939            "tool_call_2".to_string(),
1940            ToolFunction::new("default_api".to_string(), json!({"x": 4, "y": 5})),
1941        );
1942        invalid_tool_call.call_id = Some("call_2".to_string());
1943        let model = MockCompletionModel::new([
1944            MockTurn::from_contents([
1945                AssistantContent::ToolCall(valid_tool_call),
1946                AssistantContent::ToolCall(invalid_tool_call),
1947            ])
1948            .expect("tool-call response should be non-empty"),
1949            MockTurn::text("retried"),
1950        ]);
1951        let recorded = model.clone();
1952        let agent = AgentBuilder::new(model)
1953            .tool(CountingAddTool {
1954                calls: add_calls.clone(),
1955            })
1956            .build();
1957
1958        let response = agent
1959            .prompt("add")
1960            .with_hook(RetryDefaultApiHook)
1961            .max_invalid_tool_call_retries(1)
1962            .max_turns(3)
1963            .extended_details()
1964            .await
1965            .expect("retry should recover");
1966
1967        assert_eq!(response.output, "retried");
1968        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
1969        let requests = recorded.requests();
1970        assert_eq!(requests.len(), 2);
1971        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
1972        assert_eq!(retry_history.len(), 3);
1973        assert!(matches!(
1974            retry_history.get(1),
1975            Some(Message::Assistant { content, .. })
1976                if content.iter().any(|item| matches!(
1977                    item,
1978                    AssistantContent::ToolCall(tool_call)
1979                        if tool_call.id == "tool_call_1"
1980                            && tool_call.function.name == "add"
1981                ))
1982                    && content.iter().any(|item| matches!(
1983                        item,
1984                        AssistantContent::ToolCall(tool_call)
1985                            if tool_call.id == "tool_call_2"
1986                                && tool_call.function.name == "default_api"
1987                    ))
1988        ));
1989        assert!(matches!(
1990            retry_history.get(2),
1991            Some(Message::User { content })
1992                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
1993                    && content.iter().any(|item| matches!(
1994                        item,
1995                        UserContent::ToolResult(result)
1996                            if result.id == "tool_call_1"
1997                                && result.call_id.as_deref() == Some("call_1")
1998                                && result.content.iter().any(|content| matches!(
1999                                    content,
2000                                    crate::message::ToolResultContent::Text(text)
2001                                        if text.text == super::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
2002                                ))
2003                    ))
2004                    && content.iter().any(|item| matches!(
2005                        item,
2006                        UserContent::ToolResult(result)
2007                            if result.id == "tool_call_2"
2008                                && result.call_id.as_deref() == Some("call_2")
2009                                && result.content.iter().any(|content| matches!(
2010                                    content,
2011                                    crate::message::ToolResultContent::Text(text)
2012                                        if text.text.contains("Use one of these tools instead")
2013                                ))
2014            ))
2015        ));
2016    }
2017
2018    #[tokio::test]
2019    async fn invalid_tool_call_hook_skips_mixed_non_streaming_turn_without_executing_valid_call() {
2020        let add_calls = Arc::new(AtomicU32::new(0));
2021        let mut valid_tool_call = ToolCall::new(
2022            "tool_call_1".to_string(),
2023            ToolFunction::new("add".to_string(), json!({"x": 2, "y": 3})),
2024        );
2025        valid_tool_call.call_id = Some("call_1".to_string());
2026        let mut invalid_tool_call = ToolCall::new(
2027            "tool_call_2".to_string(),
2028            ToolFunction::new("default_api".to_string(), json!({"x": 4, "y": 5})),
2029        );
2030        invalid_tool_call.call_id = Some("call_2".to_string());
2031        let model = MockCompletionModel::new([
2032            MockTurn::from_contents([
2033                AssistantContent::ToolCall(valid_tool_call),
2034                AssistantContent::ToolCall(invalid_tool_call),
2035            ])
2036            .expect("tool-call response should be non-empty"),
2037            MockTurn::text("skipped"),
2038        ]);
2039        let agent = AgentBuilder::new(model)
2040            .tool(CountingAddTool {
2041                calls: add_calls.clone(),
2042            })
2043            .build();
2044
2045        let response = agent
2046            .prompt("add")
2047            .with_hook(SkipDefaultApiAndPanicOnToolCallHook)
2048            .max_turns(3)
2049            .extended_details()
2050            .await
2051            .expect("skip should recover without executing peer tools");
2052
2053        assert_eq!(response.output, "skipped");
2054        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
2055        let messages = response.messages.expect("messages should be present");
2056        assert!(history_contains_tool_call(&messages, "add"));
2057        assert!(history_contains_tool_call(&messages, "default_api"));
2058        assert!(matches!(
2059            messages.get(2),
2060            Some(Message::User { content })
2061                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
2062                    && content.iter().any(|item| matches!(
2063                        item,
2064                        UserContent::ToolResult(result)
2065                            if result.id == "tool_call_1"
2066                                && result.call_id.as_deref() == Some("call_1")
2067                                && result.content.iter().any(|content| matches!(
2068                                    content,
2069                                    crate::message::ToolResultContent::Text(text)
2070                                        if text.text == super::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
2071                                ))
2072                    ))
2073                    && content.iter().any(|item| matches!(
2074                        item,
2075                        UserContent::ToolResult(result)
2076                            if result.id == "tool_call_2"
2077                                && result.call_id.as_deref() == Some("call_2")
2078                                && result.content.iter().any(|content| matches!(
2079                                    content,
2080                                    crate::message::ToolResultContent::Text(text)
2081                                        if text.text == "default_api is not available"
2082                                ))
2083                    ))
2084        ));
2085    }
2086
2087    #[tokio::test]
2088    async fn invalid_tool_call_hook_retry_budget_exhaustion_fails() {
2089        let model = MockCompletionModel::new([
2090            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2091            MockTurn::text("should not be requested"),
2092        ]);
2093        let recorded = model.clone();
2094        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2095
2096        let err = agent
2097            .prompt("add")
2098            .with_hook(RetryDefaultApiHook)
2099            .max_invalid_tool_call_retries(0)
2100            .max_turns(3)
2101            .await
2102            .expect_err("retry without budget should fail");
2103
2104        match err {
2105            PromptError::UnknownToolCall {
2106                tool_name,
2107                chat_history,
2108                ..
2109            } => {
2110                assert_eq!(tool_name, "default_api");
2111                assert!(history_contains_tool_call(&chat_history, "default_api"));
2112            }
2113            other => panic!("expected UnknownToolCall, got {other:?}"),
2114        }
2115        assert_eq!(recorded.request_count(), 1);
2116    }
2117
2118    #[tokio::test]
2119    async fn invalid_tool_call_hook_can_skip_structured_non_streaming_call() {
2120        let model = MockCompletionModel::new([
2121            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2122            MockTurn::text("skipped"),
2123        ]);
2124        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2125
2126        let response = agent
2127            .prompt("add")
2128            .with_hook(SkipDefaultApiHook)
2129            .max_turns(3)
2130            .extended_details()
2131            .await
2132            .expect("skip should continue with synthetic tool result");
2133
2134        assert_eq!(response.output, "skipped");
2135        let messages = response.messages.expect("messages should be present");
2136        assert!(history_contains_tool_call(&messages, "default_api"));
2137        assert!(messages.iter().any(|message| {
2138            matches!(
2139                message,
2140                Message::User { content }
2141                    if content.iter().any(|content| {
2142                        matches!(
2143                            content,
2144                            UserContent::ToolResult(result)
2145                                if result.content.iter().any(|content| {
2146                                    matches!(
2147                                        content,
2148                                        crate::message::ToolResultContent::Text(text)
2149                                            if text.text == "default_api is not available"
2150                                    )
2151                                })
2152                        )
2153                    })
2154            )
2155        }));
2156    }
2157
2158    #[tokio::test]
2159    async fn skip_under_specific_tool_choice_returns_synthetic_feedback() {
2160        let model = MockCompletionModel::new([
2161            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2162            MockTurn::text("skipped"),
2163        ]);
2164        let agent = AgentBuilder::new(model)
2165            .tool(MockAddTool)
2166            .tool_choice(ToolChoice::Specific {
2167                function_names: vec!["add".to_string()],
2168            })
2169            .build();
2170
2171        let response = agent
2172            .prompt("add")
2173            .with_hook(SkipDefaultApiHook)
2174            .max_turns(3)
2175            .extended_details()
2176            .await
2177            .expect("skip should produce synthetic feedback under Specific");
2178
2179        assert_eq!(response.output, "skipped");
2180        let messages = response.messages.expect("messages should be present");
2181        assert!(history_contains_tool_call(&messages, "default_api"));
2182        assert!(messages.iter().any(|message| {
2183            matches!(
2184                message,
2185                Message::User { content }
2186                    if content.iter().any(|content| {
2187                        matches!(
2188                            content,
2189                            UserContent::ToolResult(result)
2190                                if result.id == "tool_call_1"
2191                                    && result.content.iter().any(|content| {
2192                                        matches!(
2193                                            content,
2194                                            crate::message::ToolResultContent::Text(text)
2195                                                if text.text == "default_api is not available"
2196                                        )
2197                                    })
2198                        )
2199                    })
2200            )
2201        }));
2202    }
2203
2204    #[tokio::test]
2205    async fn repair_to_disallowed_specific_tool_fails() {
2206        let model = MockCompletionModel::new([
2207            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2208            MockTurn::text("should not be requested"),
2209        ]);
2210        let recorded = model.clone();
2211        let agent = AgentBuilder::new(model)
2212            .tool(MockAddTool)
2213            .tool(MockSubtractTool)
2214            .tool_choice(ToolChoice::Specific {
2215                function_names: vec!["add".to_string()],
2216            })
2217            .build();
2218
2219        let err = agent
2220            .prompt("add")
2221            .with_hook(RepairToSubtractHook)
2222            .max_turns(3)
2223            .await
2224            .expect_err("repair to a disallowed tool should fail");
2225
2226        match err {
2227            PromptError::UnknownToolCall { tool_name, .. } => {
2228                assert_eq!(tool_name, "subtract");
2229            }
2230            other => panic!("expected UnknownToolCall, got {other:?}"),
2231        }
2232        assert_eq!(recorded.request_count(), 1);
2233    }
2234
2235    #[tokio::test]
2236    async fn repair_under_tool_choice_none_fails() {
2237        let model = MockCompletionModel::new([
2238            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2239            MockTurn::text("should not be requested"),
2240        ]);
2241        let recorded = model.clone();
2242        let agent = AgentBuilder::new(model)
2243            .tool(MockAddTool)
2244            .tool_choice(ToolChoice::None)
2245            .build();
2246
2247        let err = agent
2248            .prompt("do not use tools")
2249            .with_hook(RepairDefaultApiHook)
2250            .max_turns(3)
2251            .await
2252            .expect_err("ToolChoice::None should reject repaired tool calls");
2253
2254        match err {
2255            PromptError::UnknownToolCall { tool_name, .. } => {
2256                assert_eq!(tool_name, "add");
2257            }
2258            other => panic!("expected UnknownToolCall, got {other:?}"),
2259        }
2260        assert_eq!(recorded.request_count(), 1);
2261    }
2262
2263    #[tokio::test]
2264    async fn skip_under_tool_choice_none_fails() {
2265        let model = MockCompletionModel::new([
2266            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2267            MockTurn::text("should not be requested"),
2268        ]);
2269        let recorded = model.clone();
2270        let agent = AgentBuilder::new(model)
2271            .tool(MockAddTool)
2272            .tool_choice(ToolChoice::None)
2273            .build();
2274
2275        let err = agent
2276            .prompt("do not use tools")
2277            .with_hook(SkipDefaultApiHook)
2278            .max_turns(3)
2279            .await
2280            .expect_err("ToolChoice::None should reject skipped tool calls");
2281
2282        match err {
2283            PromptError::UnknownToolCall { tool_name, .. } => {
2284                assert_eq!(tool_name, "default_api");
2285            }
2286            other => panic!("expected UnknownToolCall, got {other:?}"),
2287        }
2288        assert_eq!(recorded.request_count(), 1);
2289    }
2290
2291    #[tokio::test]
2292    async fn typed_prompt_default_invalid_tool_call_fails_fast() {
2293        let model = MockCompletionModel::new([
2294            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2295            MockTurn::text(r#"{"value":"should not be requested"}"#),
2296        ]);
2297        let recorded = model.clone();
2298        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2299
2300        let err = agent
2301            .prompt_typed::<TypedAnswer>("return typed json")
2302            .with_hook(PanicOnUnknownToolHook)
2303            .max_turns(3)
2304            .await
2305            .expect_err("typed prompt should preserve fail-fast default");
2306
2307        match err {
2308            StructuredOutputError::PromptError(err) => match *err {
2309                PromptError::UnknownToolCall { tool_name, .. } => {
2310                    assert_eq!(tool_name, "default_api");
2311                }
2312                other => panic!("expected UnknownToolCall, got {other:?}"),
2313            },
2314            other => panic!("expected prompt error, got {other:?}"),
2315        }
2316        assert_eq!(recorded.request_count(), 1);
2317    }
2318
2319    #[tokio::test]
2320    async fn typed_prompt_invalid_tool_call_hook_can_repair_tool_name() {
2321        let model = MockCompletionModel::new([
2322            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2323            MockTurn::text(r#"{"value":"repaired"}"#),
2324        ]);
2325        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2326
2327        let response = agent
2328            .prompt_typed::<TypedAnswer>("return typed json")
2329            .with_hook(RepairDefaultApiHook)
2330            .max_turns(3)
2331            .await
2332            .expect("typed prompt should repair invalid tool call");
2333
2334        assert_eq!(
2335            response,
2336            TypedAnswer {
2337                value: "repaired".to_string()
2338            }
2339        );
2340    }
2341
2342    #[tokio::test]
2343    async fn typed_prompt_invalid_tool_call_hook_can_retry_and_parse_response() {
2344        let model = MockCompletionModel::new([
2345            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2346            MockTurn::text(r#"{"value":"retried"}"#),
2347        ]);
2348        let recorded = model.clone();
2349        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2350
2351        let response = agent
2352            .prompt_typed::<TypedAnswer>("return typed json")
2353            .with_hook(RetryDefaultApiHook)
2354            .max_invalid_tool_call_retries(1)
2355            .max_turns(3)
2356            .await
2357            .expect("typed prompt should retry invalid tool call");
2358
2359        assert_eq!(
2360            response,
2361            TypedAnswer {
2362                value: "retried".to_string()
2363            }
2364        );
2365        assert_eq!(recorded.request_count(), 2);
2366    }
2367
2368    #[tokio::test]
2369    async fn typed_prompt_invalid_tool_call_retry_budget_exhaustion_fails() {
2370        let model = MockCompletionModel::new([
2371            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2372            MockTurn::text(r#"{"value":"should not be requested"}"#),
2373        ]);
2374        let recorded = model.clone();
2375        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2376
2377        let err = agent
2378            .prompt_typed::<TypedAnswer>("return typed json")
2379            .with_hook(RetryDefaultApiHook)
2380            .max_invalid_tool_call_retries(0)
2381            .max_turns(3)
2382            .await
2383            .expect_err("typed prompt should fail when retry budget is exhausted");
2384
2385        match err {
2386            StructuredOutputError::PromptError(err) => match *err {
2387                PromptError::UnknownToolCall { tool_name, .. } => {
2388                    assert_eq!(tool_name, "default_api");
2389                }
2390                other => panic!("expected UnknownToolCall, got {other:?}"),
2391            },
2392            other => panic!("expected prompt error, got {other:?}"),
2393        }
2394        assert_eq!(recorded.request_count(), 1);
2395    }
2396
2397    #[tokio::test]
2398    async fn invalid_specific_tool_choice_fails_before_non_streaming_provider_request() {
2399        let model = MockCompletionModel::text("should not be requested");
2400        let recorded = model.clone();
2401        let agent = AgentBuilder::new(model)
2402            .tool(MockAddTool)
2403            .tool_choice(ToolChoice::Specific {
2404                function_names: vec!["missing".to_string()],
2405            })
2406            .build();
2407
2408        let err = agent
2409            .prompt("use the missing tool")
2410            .await
2411            .expect_err("invalid ToolChoice::Specific should fail before provider request");
2412
2413        match err {
2414            PromptError::CompletionError(CompletionError::RequestError(err)) => {
2415                let msg = err.to_string();
2416                assert!(msg.contains("missing"), "got: {msg}");
2417                assert!(msg.contains("add"), "got: {msg}");
2418            }
2419            other => panic!("expected CompletionError::RequestError, got {other:?}"),
2420        }
2421        assert_eq!(recorded.request_count(), 0);
2422    }
2423
2424    #[tokio::test]
2425    async fn allowed_specific_tool_call_executes_normally() {
2426        let model = MockCompletionModel::new([
2427            MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})),
2428            MockTurn::text("done"),
2429        ]);
2430        let recorded = model.clone();
2431        let agent = AgentBuilder::new(model)
2432            .tool(MockAddTool)
2433            .tool_choice(ToolChoice::Specific {
2434                function_names: vec!["add".to_string()],
2435            })
2436            .build();
2437
2438        let response = agent
2439            .prompt("use the allowed tool")
2440            .max_turns(3)
2441            .await
2442            .expect("allowed specific tool should execute");
2443
2444        assert_eq!(response, "done");
2445        assert_eq!(recorded.request_count(), 2);
2446    }
2447
2448    #[tokio::test]
2449    async fn prompt_request_stops_cleanly_on_empty_terminal_turn() {
2450        let first_call_usage = Usage {
2451            input_tokens: 1,
2452            output_tokens: 1,
2453            total_tokens: 2,
2454            cached_input_tokens: 0,
2455            cache_creation_input_tokens: 0,
2456            tool_use_prompt_tokens: 0,
2457            reasoning_tokens: 0,
2458        };
2459        let second_call_usage = Usage {
2460            input_tokens: 1,
2461            output_tokens: 1,
2462            total_tokens: 2,
2463            cached_input_tokens: 0,
2464            cache_creation_input_tokens: 0,
2465            tool_use_prompt_tokens: 0,
2466            reasoning_tokens: 0,
2467        };
2468        let model = MockCompletionModel::new([
2469            MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2}))
2470                .with_call_id("call_1")
2471                .with_usage(first_call_usage),
2472            MockTurn::text("").with_usage(second_call_usage),
2473        ]);
2474        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2475
2476        let response = agent
2477            .prompt("do tool work")
2478            .max_turns(3)
2479            .extended_details()
2480            .await
2481            .expect("empty terminal turn should not error");
2482
2483        assert!(response.output.is_empty());
2484        assert_eq!(
2485            response.usage,
2486            Usage {
2487                input_tokens: 2,
2488                output_tokens: 2,
2489                total_tokens: 4,
2490                cached_input_tokens: 0,
2491                cache_creation_input_tokens: 0,
2492                tool_use_prompt_tokens: 0,
2493                reasoning_tokens: 0,
2494            }
2495        );
2496        assert_eq!(
2497            response.completion_calls(),
2498            &[
2499                CompletionCall::new(0, Some(first_call_usage)),
2500                CompletionCall::new(1, Some(second_call_usage))
2501            ]
2502        );
2503
2504        let history = response
2505            .messages
2506            .expect("extended response should include history");
2507        assert_eq!(history.len(), 3);
2508        assert!(matches!(
2509            history.first(),
2510            Some(Message::User { content })
2511                if matches!(
2512                    content.first(),
2513                    UserContent::Text(text) if text.text == "do tool work"
2514                )
2515        ));
2516        assert!(history.iter().any(|message| matches!(
2517            message,
2518            Message::Assistant { content, .. }
2519                if matches!(
2520                    content.first(),
2521                    AssistantContent::ToolCall(tool_call)
2522                        if tool_call.id == "tool_call_1"
2523                            && tool_call.call_id.as_deref() == Some("call_1")
2524                )
2525        )));
2526        assert!(history.iter().any(|message| matches!(
2527            message,
2528            Message::User { content }
2529                if matches!(
2530                    content.first(),
2531                    UserContent::ToolResult(tool_result)
2532                        if tool_result.id == "tool_call_1"
2533                            && tool_result.call_id.as_deref() == Some("call_1")
2534                )
2535        )));
2536        assert!(!history.iter().any(|message| matches!(
2537            message,
2538            Message::Assistant { content, .. }
2539                if content.iter().any(|item| matches!(
2540                    item,
2541                    AssistantContent::Text(text) if text.text.is_empty()
2542                ))
2543        )));
2544        let requests = agent.model.requests();
2545        assert_eq!(requests.len(), 2);
2546        validate_follow_up_tool_history(&requests[1]);
2547    }
2548
2549    #[tokio::test]
2550    async fn prompt_request_concatenates_text_blocks_without_inserted_newlines() {
2551        let model = MockCompletionModel::new([MockTurn::from_contents([
2552            AssistantContent::Text(Text::new("According to the document, ")),
2553            AssistantContent::Text(Text::new("the grass is green")),
2554            AssistantContent::Text(Text::new(" and the sky is blue.")),
2555        ])
2556        .expect("mock response should contain text blocks")]);
2557        let agent = AgentBuilder::new(model).build();
2558
2559        let response = agent
2560            .prompt("answer with cited spans")
2561            .await
2562            .expect("prompt should succeed");
2563
2564        assert_eq!(
2565            response,
2566            "According to the document, the grass is green and the sky is blue."
2567        );
2568    }
2569
2570    #[tokio::test]
2571    async fn prompt_request_preserves_metadata_only_text_turn_in_history() {
2572        let metadata = json!({
2573            "citations": [{
2574                "type": "web_search_result_location",
2575                "cited_text": "Claude Shannon was born in 1916.",
2576                "url": "https://example.com/shannon",
2577                "title": null,
2578                "encrypted_index": "encrypted-reference"
2579            }]
2580        });
2581        let model =
2582            MockCompletionModel::new([MockTurn::from_content(AssistantContent::Text(Text {
2583                text: String::new(),
2584                additional_params: Some(metadata.clone()),
2585            }))]);
2586        let agent = AgentBuilder::new(model).build();
2587
2588        let response = agent
2589            .prompt("answer with cited metadata")
2590            .extended_details()
2591            .await
2592            .expect("metadata-only text turn should succeed");
2593
2594        assert!(response.output.is_empty());
2595        let history = response
2596            .messages
2597            .expect("extended response should include history");
2598        assert!(history.iter().any(|message| matches!(
2599            message,
2600            Message::Assistant { content, .. }
2601                if matches!(
2602                    content.first(),
2603                    AssistantContent::Text(text)
2604                        if text.text.is_empty()
2605                            && text.additional_params.as_ref() == Some(&metadata)
2606                )
2607        )));
2608    }
2609
2610    // ----- Conversation memory integration tests -----
2611
2612    use crate::memory::{ConversationMemory, InMemoryConversationMemory};
2613
2614    #[tokio::test]
2615    async fn memory_loads_into_request_history() {
2616        let memory = InMemoryConversationMemory::new();
2617        memory
2618            .append(
2619                "thread-1",
2620                vec![Message::user("hello"), Message::assistant("hi there")],
2621            )
2622            .await
2623            .unwrap();
2624
2625        let model = MockCompletionModel::text("ack");
2626        let recorded = model.clone();
2627
2628        let agent = AgentBuilder::new(model).memory(memory).build();
2629        let _ = agent
2630            .prompt("ping")
2631            .conversation("thread-1")
2632            .await
2633            .expect("prompt should succeed");
2634
2635        let received = recorded.requests()[0]
2636            .chat_history
2637            .iter()
2638            .cloned()
2639            .collect::<Vec<_>>();
2640        assert_eq!(
2641            received.len(),
2642            3,
2643            "loaded memory (2) + current prompt should appear in request: {received:?}"
2644        );
2645    }
2646
2647    #[tokio::test]
2648    async fn memory_appends_full_turn_after_success() {
2649        let memory = InMemoryConversationMemory::new();
2650        let model = MockCompletionModel::text("ack");
2651        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2652
2653        let _ = agent
2654            .prompt("hello")
2655            .conversation("t1")
2656            .await
2657            .expect("prompt should succeed");
2658
2659        let stored = memory.load("t1").await.unwrap();
2660        assert_eq!(stored.len(), 2, "user prompt + assistant response saved");
2661    }
2662
2663    #[tokio::test]
2664    async fn explicit_with_history_overrides_memory() {
2665        let memory = CountingMemory::default();
2666        memory
2667            .inner()
2668            .append("t1", vec![Message::user("from-memory")])
2669            .await
2670            .unwrap();
2671
2672        let model = MockCompletionModel::text("ack");
2673        let recorded = model.clone();
2674
2675        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2676        let _ = agent
2677            .prompt("hello")
2678            .conversation("t1")
2679            .with_history(vec![Message::user("from-caller")])
2680            .await
2681            .expect("prompt should succeed");
2682
2683        assert_eq!(memory.load_count(), 0, "load skipped");
2684        let appends = memory.append_count();
2685        assert_eq!(appends, 0, "append skipped");
2686
2687        let received = recorded.requests()[0]
2688            .chat_history
2689            .iter()
2690            .cloned()
2691            .collect::<Vec<_>>();
2692        assert_eq!(received.len(), 2, "caller history (1) + current prompt");
2693        assert!(matches!(
2694            received.first(),
2695            Some(Message::User { content })
2696                if matches!(content.first(), UserContent::Text(t) if t.text == "from-caller")
2697        ));
2698    }
2699
2700    #[tokio::test]
2701    async fn memory_unchanged_on_provider_error() {
2702        let memory = InMemoryConversationMemory::new();
2703        let model = MockCompletionModel::new([MockTurn::error("boom")]);
2704
2705        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2706        let result = agent.prompt("hello").conversation("t1").await;
2707        assert!(result.is_err());
2708
2709        let stored = memory.load("t1").await.unwrap();
2710        assert!(stored.is_empty(), "no append on error");
2711    }
2712
2713    #[tokio::test]
2714    async fn missing_conversation_id_behaves_as_no_memory() {
2715        let memory = CountingMemory::default();
2716        let model = MockCompletionModel::text("ack");
2717        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2718
2719        let _ = agent.prompt("hello").await.expect("prompt should succeed");
2720
2721        assert_eq!(memory.load_count(), 0);
2722        assert_eq!(memory.append_count(), 0);
2723    }
2724
2725    #[tokio::test]
2726    async fn default_conversation_id_is_used_when_none_per_request() {
2727        let memory = InMemoryConversationMemory::new();
2728        let model = MockCompletionModel::text("ack");
2729        let agent = AgentBuilder::new(model)
2730            .memory(memory.clone())
2731            .conversation_id("default-thread")
2732            .build();
2733
2734        let _ = agent.prompt("hello").await.expect("prompt should succeed");
2735        let stored = memory.load("default-thread").await.unwrap();
2736        assert_eq!(stored.len(), 2);
2737    }
2738
2739    #[tokio::test]
2740    async fn with_filter_truncates_loaded_history() {
2741        let memory = InMemoryConversationMemory::new()
2742            .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
2743        memory
2744            .append(
2745                "t1",
2746                vec![
2747                    Message::user("1"),
2748                    Message::assistant("2"),
2749                    Message::user("3"),
2750                    Message::assistant("4"),
2751                ],
2752            )
2753            .await
2754            .unwrap();
2755
2756        let model = MockCompletionModel::text("ack");
2757        let recorded = model.clone();
2758        let agent = AgentBuilder::new(model).memory(memory).build();
2759
2760        let _ = agent
2761            .prompt("ping")
2762            .conversation("t1")
2763            .await
2764            .expect("prompt should succeed");
2765
2766        let received = recorded.requests()[0]
2767            .chat_history
2768            .iter()
2769            .cloned()
2770            .collect::<Vec<_>>();
2771        assert_eq!(
2772            received.len(),
2773            3,
2774            "window-truncated history (2) + current prompt"
2775        );
2776    }
2777
2778    #[tokio::test]
2779    async fn without_memory_disables_for_request() {
2780        let memory = CountingMemory::default();
2781        let model = MockCompletionModel::text("ack");
2782        let agent = AgentBuilder::new(model)
2783            .memory(memory.clone())
2784            .conversation_id("t1")
2785            .build();
2786
2787        let _ = agent
2788            .prompt("hello")
2789            .without_memory()
2790            .await
2791            .expect("prompt should succeed");
2792
2793        assert_eq!(memory.load_count(), 0);
2794        assert_eq!(memory.append_count(), 0);
2795    }
2796
2797    #[tokio::test]
2798    async fn memory_load_error_surfaces_as_prompt_error() {
2799        let model = MockCompletionModel::text("ack");
2800        let agent = AgentBuilder::new(model)
2801            .memory(FailingMemory::default())
2802            .build();
2803        let result = agent.prompt("hello").conversation("t1").await;
2804
2805        match result {
2806            Err(PromptError::CompletionError(CompletionError::RequestError(err))) => {
2807                let msg = format!("{err}");
2808                assert!(msg.contains("load boom"), "got: {msg}");
2809            }
2810            other => panic!("expected PromptError::CompletionError(RequestError), got {other:?}"),
2811        }
2812    }
2813
2814    #[tokio::test]
2815    async fn memory_append_error_does_not_drop_response() {
2816        let model = MockCompletionModel::text("ack");
2817        let agent = AgentBuilder::new(model)
2818            .memory(AppendFailingMemory::default())
2819            .build();
2820        let response: String = agent
2821            .prompt("hello")
2822            .conversation("t1")
2823            .await
2824            .expect("append failure must not block successful completion");
2825
2826        assert!(!response.is_empty());
2827    }
2828}