Skip to main content

rig_core/agent/prompt_request/
mod.rs

1pub mod hooks;
2pub mod streaming;
3
4use super::{
5    Agent,
6    completion::{DynamicContextStore, build_prepared_completion_request},
7    run::{AgentRun, AgentRunStep, ModelTurn, ModelTurnOutcome, PendingToolCall},
8};
9use crate::{
10    OneOrMany,
11    completion::{CompletionModel, Document, Message, PromptError, Usage},
12    json_utils,
13    memory::ConversationMemory,
14    message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
15    tool::server::ToolServerHandle,
16    wasm_compat::{WasmBoxedFuture, WasmCompatSend},
17};
18use futures::{StreamExt, stream};
19use hooks::{HookAction, InvalidToolCallHookAction, PromptHook, ToolCallHookAction};
20use serde::{Deserialize, Serialize};
21use std::{
22    future::IntoFuture,
23    marker::PhantomData,
24    sync::{
25        Arc,
26        atomic::{AtomicU64, Ordering},
27    },
28};
29use tracing::info_span;
30use tracing::{Instrument, span::Id};
31
/// Marker trait for the type-state parameter of [`PromptRequest`], selecting what the request
/// resolves to when awaited.
pub trait PromptType {}

/// Type-state marker: awaiting the request yields the response text as a `String`.
pub struct Standard;
impl PromptType for Standard {}

/// Type-state marker: awaiting the request yields a full [`PromptResponse`].
pub struct Extended;
impl PromptType for Extended {}
38
/// A builder for creating prompt requests with customizable options.
/// Uses generics to track which options have been set during the build process.
///
/// If you expect the agent to call tools continuously, raise the turn budget with
/// [`PromptRequest::max_turns`]. Unless the agent configures a default, the budget is 0
/// (meaning only 1 tool round-trip). Otherwise, awaiting the request (which sends the prompt)
/// can potentially return [`crate::completion::request::PromptError::MaxTurnsError`] if the
/// agent decides to call tools back to back.
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// The prompt message to send to the model
    prompt: Message,
    /// Optional chat history provided by the caller.
    chat_history: Option<Vec<Message>>,
    /// Maximum depth for multi-turn conversations (0 means no multi-turn)
    max_turns: usize,

    // Agent data (cloned from agent to allow hook type transitions):
    /// The completion model
    model: Arc<M>,
    /// Agent name for logging
    agent_name: Option<String>,
    /// System prompt
    preamble: Option<String>,
    /// Static context documents
    static_context: Vec<Document>,
    /// Temperature setting
    temperature: Option<f64>,
    /// Max tokens setting
    max_tokens: Option<u64>,
    /// Additional model parameters
    additional_params: Option<serde_json::Value>,
    /// Tool server handle for tool execution
    tool_server_handle: ToolServerHandle,
    /// Dynamic context store
    dynamic_context: DynamicContextStore,
    /// Tool choice setting
    tool_choice: Option<ToolChoice>,

    /// Phantom data to track the type of the request
    /// ([`Standard`] or [`Extended`]); carries no runtime data.
    state: PhantomData<S>,
    /// Optional per-request hook for events
    hook: Option<P>,
    /// Maximum number of invalid tool-call retries for this request
    /// (0 when built via `from_agent`).
    max_invalid_tool_call_retries: usize,
    /// How many tools should be executed at the same time (1 by default).
    concurrency: usize,
    /// Optional JSON Schema for structured output
    output_schema: Option<schemars::Schema>,
    /// Optional conversation memory backend cloned from the agent.
    memory: Option<Arc<dyn ConversationMemory>>,
    /// Optional conversation id used for loading and saving memory.
    conversation_id: Option<String>,
}
97
98impl<M, P> PromptRequest<Standard, M, P>
99where
100    M: CompletionModel,
101    P: PromptHook<M>,
102{
103    /// Create a new PromptRequest from an agent, cloning the agent's data and default hook.
104    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
105        PromptRequest {
106            prompt: prompt.into(),
107            chat_history: None,
108            max_turns: agent.default_max_turns.unwrap_or_default(),
109            model: agent.model.clone(),
110            agent_name: agent.name.clone(),
111            preamble: agent.preamble.clone(),
112            static_context: agent.static_context.clone(),
113            temperature: agent.temperature,
114            max_tokens: agent.max_tokens,
115            additional_params: agent.additional_params.clone(),
116            tool_server_handle: agent.tool_server_handle.clone(),
117            dynamic_context: agent.dynamic_context.clone(),
118            tool_choice: agent.tool_choice.clone(),
119            state: PhantomData,
120            hook: agent.hook.clone(),
121            max_invalid_tool_call_retries: 0,
122            concurrency: 1,
123            output_schema: agent.output_schema.clone(),
124            memory: agent.memory.clone(),
125            conversation_id: agent.default_conversation_id.clone(),
126        }
127    }
128}
129
130impl<S, M, P> PromptRequest<S, M, P>
131where
132    S: PromptType,
133    M: CompletionModel,
134    P: PromptHook<M>,
135{
136    /// Enable returning extended details for responses (includes aggregated token usage
137    /// and the full message history accumulated during the agent loop).
138    ///
139    /// Note: This changes the type of the response from `.send` to return a `PromptResponse` struct
140    /// instead of a simple `String`. This is useful for tracking token usage across multiple turns
141    /// of conversation and inspecting the full message exchange.
142    pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
143        PromptRequest {
144            prompt: self.prompt,
145            chat_history: self.chat_history,
146            max_turns: self.max_turns,
147            model: self.model,
148            agent_name: self.agent_name,
149            preamble: self.preamble,
150            static_context: self.static_context,
151            temperature: self.temperature,
152            max_tokens: self.max_tokens,
153            additional_params: self.additional_params,
154            tool_server_handle: self.tool_server_handle,
155            dynamic_context: self.dynamic_context,
156            tool_choice: self.tool_choice,
157            state: PhantomData,
158            hook: self.hook,
159            max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
160            concurrency: self.concurrency,
161            output_schema: self.output_schema,
162            memory: self.memory,
163            conversation_id: self.conversation_id,
164        }
165    }
166
167    /// Set the maximum number of turns for multi-turn conversations. A given agent may require multiple turns for tool-calling before giving an answer.
168    /// If the maximum turn number is exceeded, it will return a [`crate::completion::request::PromptError::MaxTurnsError`].
169    pub fn max_turns(mut self, depth: usize) -> Self {
170        self.max_turns = depth;
171        self
172    }
173
174    /// Add concurrency to the prompt request.
175    /// This will cause the agent to execute tools concurrently.
176    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
177        self.concurrency = concurrency;
178        self
179    }
180
181    /// Add chat history to the prompt request.
182    pub fn with_history<H, T>(mut self, history: H) -> Self
183    where
184        H: IntoIterator<Item = T>,
185        T: Into<Message>,
186    {
187        self.chat_history = Some(history.into_iter().map(Into::into).collect());
188        self
189    }
190
191    /// Set the conversation id used to load and persist memory for this request.
192    ///
193    /// Overrides any default conversation id set on the agent. If memory is not
194    /// configured on the agent, this has no effect.
195    pub fn conversation(mut self, id: impl Into<String>) -> Self {
196        self.conversation_id = Some(id.into());
197        self
198    }
199
200    /// Disable conversation memory for this request.
201    ///
202    /// History will neither be loaded from nor saved to the agent's memory backend.
203    pub fn without_memory(mut self) -> Self {
204        self.memory = None;
205        self.conversation_id = None;
206        self
207    }
208
209    /// Attach a per-request hook for tool call events.
210    /// This overrides any default hook set on the agent.
211    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
212    where
213        P2: PromptHook<M>,
214    {
215        PromptRequest {
216            prompt: self.prompt,
217            chat_history: self.chat_history,
218            max_turns: self.max_turns,
219            model: self.model,
220            agent_name: self.agent_name,
221            preamble: self.preamble,
222            static_context: self.static_context,
223            temperature: self.temperature,
224            max_tokens: self.max_tokens,
225            additional_params: self.additional_params,
226            tool_server_handle: self.tool_server_handle,
227            dynamic_context: self.dynamic_context,
228            tool_choice: self.tool_choice,
229            state: PhantomData,
230            hook: Some(hook),
231            max_invalid_tool_call_retries: self.max_invalid_tool_call_retries,
232            concurrency: self.concurrency,
233            output_schema: self.output_schema,
234            memory: self.memory,
235            conversation_id: self.conversation_id,
236        }
237    }
238
239    /// Set the retry budget for [`InvalidToolCallHookAction::Retry`].
240    ///
241    /// Invalid tool-call retries also consume normal multi-turn depth.
242    pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
243        self.max_invalid_tool_call_retries = retries;
244        self
245    }
246}
247
248/// Due to: [RFC 2515](https://github.com/rust-lang/rust/issues/63063), we have to use a `BoxFuture`
249///  for the `IntoFuture` implementation. In the future, we should be able to use `impl Future<...>`
250///  directly via the associated type.
251impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
252where
253    M: CompletionModel + 'static,
254    P: PromptHook<M> + 'static,
255{
256    type Output = Result<String, PromptError>;
257    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
258
259    fn into_future(self) -> Self::IntoFuture {
260        Box::pin(self.send())
261    }
262}
263
264impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
265where
266    M: CompletionModel + 'static,
267    P: PromptHook<M> + 'static,
268{
269    type Output = Result<PromptResponse, PromptError>;
270    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
271
272    fn into_future(self) -> Self::IntoFuture {
273        Box::pin(self.send())
274    }
275}
276
277impl<M, P> PromptRequest<Standard, M, P>
278where
279    M: CompletionModel,
280    P: PromptHook<M>,
281{
282    async fn send(self) -> Result<String, PromptError> {
283        self.extended_details().send().await.map(|resp| resp.output)
284    }
285}
286
287/// Details for one successfully completed completion request made by an agent run.
288#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
289#[non_exhaustive]
290pub struct CompletionCall {
291    /// Zero-based index of the completion request within this agent run.
292    pub call_index: usize,
293    /// Token usage reported for this completion request.
294    ///
295    /// Zero-valued usage is [`Usage`]'s documented sentinel for missing
296    /// provider usage metrics; rig does not distinguish "reported all zeros"
297    /// from "unreported".
298    #[serde(default, deserialize_with = "usage_null_as_default")]
299    pub usage: Usage,
300}
301
302impl CompletionCall {
303    /// Create details for one completion request in an agent run.
304    pub fn new(call_index: usize, usage: Usage) -> Self {
305        Self { call_index, usage }
306    }
307}
308
309/// Tolerate `null` usage from data serialized before rig dropped the
310/// `Option<Usage>` encoding of missing provider usage metrics.
311///
312/// This tolerance requires a self-describing format such as JSON; data
313/// serialized with non-self-describing formats (e.g. bincode) from before the
314/// change cannot round-trip.
315fn usage_null_as_default<'de, D>(deserializer: D) -> Result<Usage, D::Error>
316where
317    D: serde::Deserializer<'de>,
318{
319    Ok(Option::<Usage>::deserialize(deserializer)?.unwrap_or_default())
320}
321
322#[derive(Debug, Clone, Serialize, Deserialize)]
323#[non_exhaustive]
324pub struct PromptResponse {
325    pub output: String,
326    pub usage: Usage,
327    /// Successfully completed completion requests made by this agent run.
328    ///
329    /// `usage` remains the aggregate across the whole run. Use the last
330    /// entry's usage to inspect the final completion request's prompt/context
331    /// length. Zero-valued entry usage means the provider reported no usage
332    /// metrics for that request.
333    #[serde(default, skip_serializing_if = "Vec::is_empty")]
334    pub completion_calls: Vec<CompletionCall>,
335    pub messages: Option<Vec<Message>>,
336}
337
338impl std::fmt::Display for PromptResponse {
339    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340        self.output.fmt(f)
341    }
342}
343
344impl PromptResponse {
345    pub fn new(output: impl Into<String>, usage: Usage) -> Self {
346        Self {
347            output: output.into(),
348            usage,
349            completion_calls: Vec::new(),
350            messages: None,
351        }
352    }
353
354    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
355        self.messages = Some(messages);
356        self
357    }
358
359    /// Attach completion call details to this response.
360    pub fn with_completion_calls(mut self, completion_calls: Vec<CompletionCall>) -> Self {
361        self.completion_calls = completion_calls;
362        self
363    }
364
365    /// Returns successfully completed completion requests made by this agent run.
366    ///
367    /// Zero-valued entry usage means the provider reported no usage metrics
368    /// for that request.
369    pub fn completion_calls(&self) -> &[CompletionCall] {
370        &self.completion_calls
371    }
372
373    /// Number of completion requests this agent run made.
374    pub fn requests(&self) -> usize {
375        self.completion_calls.len()
376    }
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
380#[non_exhaustive]
381pub struct TypedPromptResponse<T> {
382    pub output: T,
383    pub usage: Usage,
384    /// Successfully completed completion requests made by this agent run.
385    ///
386    /// `usage` remains the aggregate across the whole run. Use the last
387    /// entry's usage to inspect the final completion request's prompt/context
388    /// length. Zero-valued entry usage means the provider reported no usage
389    /// metrics for that request.
390    #[serde(default, skip_serializing_if = "Vec::is_empty")]
391    pub completion_calls: Vec<CompletionCall>,
392}
393
394impl<T> TypedPromptResponse<T> {
395    pub fn new(output: T, usage: Usage) -> Self {
396        Self {
397            output,
398            usage,
399            completion_calls: Vec::new(),
400        }
401    }
402
403    /// Attach completion call details to this response.
404    pub fn with_completion_calls(mut self, completion_calls: Vec<CompletionCall>) -> Self {
405        self.completion_calls = completion_calls;
406        self
407    }
408
409    /// Returns successfully completed completion requests made by this agent run.
410    ///
411    /// Zero-valued entry usage means the provider reported no usage metrics
412    /// for that request.
413    pub fn completion_calls(&self) -> &[CompletionCall] {
414        &self.completion_calls
415    }
416
417    /// Number of completion requests this agent run made.
418    pub fn requests(&self) -> usize {
419        self.completion_calls.len()
420    }
421}
422
/// Fallback agent name used in tracing spans when the agent has no configured name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

/// Tool-result text reported for the other tool calls of an assistant turn when one call in
/// that turn was invalid (see `invalid_tool_retry_user_message`).
pub(crate) const TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER: &str =
    "Tool not executed because another tool call in the same assistant turn was invalid.";
427
428/// Combine input history with new messages for building completion requests.
429pub(crate) fn build_history_for_request(
430    chat_history: Option<&[Message]>,
431    new_messages: &[Message],
432) -> Vec<Message> {
433    let input = chat_history.unwrap_or(&[]);
434    input.iter().chain(new_messages.iter()).cloned().collect()
435}
436
437/// Build the full history for error reporting (input + new messages).
438pub(crate) fn build_full_history(
439    chat_history: Option<&[Message]>,
440    new_messages: Vec<Message>,
441) -> Vec<Message> {
442    let input = chat_history.unwrap_or(&[]);
443    input.iter().cloned().chain(new_messages).collect()
444}
445
446pub(crate) fn tool_result_user_content(
447    id: String,
448    call_id: Option<String>,
449    tool_result: String,
450) -> UserContent {
451    let content = ToolResultContent::from_tool_output(tool_result);
452    match call_id {
453        Some(call_id) => UserContent::tool_result_with_call_id(id, call_id, content),
454        None => UserContent::tool_result(id, content),
455    }
456}
457
458pub(crate) fn invalid_tool_retry_user_message(
459    assistant_content: &OneOrMany<AssistantContent>,
460    invalid_tool_call_id: &str,
461    feedback: String,
462) -> Option<Message> {
463    let retry_results = assistant_content
464        .iter()
465        .filter_map(|content| match content {
466            AssistantContent::ToolCall(tool_call) if tool_call.id == invalid_tool_call_id => {
467                Some(tool_result_user_content(
468                    tool_call.id.clone(),
469                    tool_call.call_id.clone(),
470                    feedback.clone(),
471                ))
472            }
473            AssistantContent::ToolCall(tool_call) => Some(tool_result_user_content(
474                tool_call.id.clone(),
475                tool_call.call_id.clone(),
476                TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
477            )),
478            _ => None,
479        })
480        .collect::<Vec<_>>();
481
482    Some(Message::User {
483        content: OneOrMany::from_iter_optional(retry_results)?,
484    })
485}
486
487pub(crate) fn is_empty_assistant_turn(choice: &OneOrMany<AssistantContent>) -> bool {
488    choice.len() == 1
489        && matches!(
490            choice.first(),
491            AssistantContent::Text(text) if text.text.is_empty() && text.additional_params.is_none()
492        )
493}
494
495pub(crate) fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
496    choice
497        .iter()
498        .filter_map(|content| match content {
499            AssistantContent::Text(text) => Some(text.text.as_str()),
500            _ => None,
501        })
502        .collect()
503}
504
505impl<M, P> PromptRequest<Extended, M, P>
506where
507    M: CompletionModel,
508    P: PromptHook<M>,
509{
510    fn agent_name(&self) -> &str {
511        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
512    }
513
514    async fn send(self) -> Result<PromptResponse, PromptError> {
515        let agent_span = if tracing::Span::current().is_disabled() {
516            info_span!(
517                "invoke_agent",
518                gen_ai.operation.name = "invoke_agent",
519                gen_ai.agent.name = self.agent_name(),
520                gen_ai.system_instructions = self.preamble,
521                gen_ai.prompt = tracing::field::Empty,
522                gen_ai.completion = tracing::field::Empty,
523                gen_ai.usage.input_tokens = tracing::field::Empty,
524                gen_ai.usage.output_tokens = tracing::field::Empty,
525                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
526                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
527                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
528                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
529            )
530        } else {
531            tracing::Span::current()
532        };
533
534        if let Some(text) = self.prompt.rag_text() {
535            agent_span.record("gen_ai.prompt", text);
536        }
537
538        let agent_name_for_span = self.agent_name.clone();
539        // When the caller passes explicit history, memory is fully bypassed for this
540        // request (no load AND no save). Otherwise, if a memory backend and
541        // conversation id are both configured, load prior history; if either is
542        // missing, behave as if no memory is configured.
543        let (chat_history, memory_handle) = match self.chat_history {
544            Some(history) => (Some(history), None),
545            None => match (self.memory, self.conversation_id) {
546                (Some(memory), Some(id)) => {
547                    let loaded = memory.load(&id).await?;
548                    (Some(loaded), Some((memory, id)))
549                }
550                _ => (None, None),
551            },
552        };
553
554        let mut run = AgentRun::new(self.prompt.clone())
555            .max_turns(self.max_turns)
556            .max_invalid_tool_call_retries(self.max_invalid_tool_call_retries);
557        if let Some(history) = chat_history {
558            run = run.with_history(history);
559        }
560        if let Some(tool_choice) = self.tool_choice.clone() {
561            run = run.with_tool_choice(tool_choice);
562        }
563
564        let current_span_id: AtomicU64 = AtomicU64::new(0);
565
566        loop {
567            match run.next_step()? {
568                AgentRunStep::CallModel {
569                    prompt,
570                    history,
571                    turn,
572                } => {
573                    if self.max_turns > 1 {
574                        tracing::info!("Current conversation depth: {}/{}", turn, self.max_turns);
575                    }
576
577                    if let Some(ref hook) = self.hook
578                        && let HookAction::Terminate { reason } =
579                            hook.on_completion_call(&prompt, &history).await
580                    {
581                        return Err(run.cancel_error(reason));
582                    }
583
584                    let span = tracing::Span::current();
585                    let chat_span = info_span!(
586                        target: "rig::agent_chat",
587                        parent: &span,
588                        "chat",
589                        gen_ai.operation.name = "chat",
590                        gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
591                        gen_ai.system_instructions = self.preamble,
592                        gen_ai.provider.name = tracing::field::Empty,
593                        gen_ai.request.model = tracing::field::Empty,
594                        gen_ai.response.id = tracing::field::Empty,
595                        gen_ai.response.model = tracing::field::Empty,
596                        gen_ai.usage.output_tokens = tracing::field::Empty,
597                        gen_ai.usage.input_tokens = tracing::field::Empty,
598                        gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
599                        gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
600                        gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
601                        gen_ai.usage.reasoning_tokens = tracing::field::Empty,
602                        gen_ai.input.messages = tracing::field::Empty,
603                        gen_ai.output.messages = tracing::field::Empty,
604                    );
605
606                    let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
607                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
608                        chat_span.follows_from(id).to_owned()
609                    } else {
610                        chat_span
611                    };
612
613                    if let Some(id) = chat_span.id() {
614                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
615                    };
616
617                    let prepared_request = build_prepared_completion_request(
618                        &self.model,
619                        prompt.clone(),
620                        &history,
621                        self.preamble.as_deref(),
622                        &self.static_context,
623                        self.temperature,
624                        self.max_tokens,
625                        self.additional_params.as_ref(),
626                        self.tool_choice.as_ref(),
627                        &self.tool_server_handle,
628                        &self.dynamic_context,
629                        self.output_schema.as_ref(),
630                    )
631                    .await?;
632
633                    let resp = prepared_request
634                        .builder
635                        .send()
636                        .instrument(chat_span.clone())
637                        .await?;
638
639                    let mut outcome = run.model_response(ModelTurn::new(
640                        resp.message_id.clone(),
641                        resp.choice.clone(),
642                        resp.usage,
643                        prepared_request.executable_tool_names,
644                        prepared_request.allowed_tool_names,
645                    ))?;
646
647                    loop {
648                        match outcome {
649                            ModelTurnOutcome::NeedsResolution(context) => {
650                                let action = match self.hook.as_ref() {
651                                    Some(hook) => hook.on_invalid_tool_call(&context).await,
652                                    None => InvalidToolCallHookAction::fail(),
653                                };
654                                outcome = run.resolve_invalid_tool_call(action)?;
655                            }
656                            ModelTurnOutcome::TurnRetried => break,
657                            ModelTurnOutcome::Continue {
658                                response_hook_suppressed,
659                            } => {
660                                if !response_hook_suppressed
661                                    && let Some(ref hook) = self.hook
662                                    && let HookAction::Terminate { reason } =
663                                        hook.on_completion_response(&prompt, &resp).await
664                                {
665                                    return Err(run.cancel_error(reason));
666                                }
667                                break;
668                            }
669                        }
670                    }
671                }
672                AgentRunStep::CallTools { calls } => {
673                    let hook = self.hook.clone();
674                    let tool_server_handle = self.tool_server_handle.clone();
675
676                    // For error handling in concurrent tool execution, we need to build full history
677                    let full_history_for_errors = run.full_history();
678
679                    let tool_content = stream::iter(calls)
680                        .map(|pending| {
681                            let hook1 = hook.clone();
682                            let hook2 = hook.clone();
683                            let tool_server_handle = tool_server_handle.clone();
684
685                            let tool_span = info_span!(
686                                "execute_tool",
687                                gen_ai.operation.name = "execute_tool",
688                                gen_ai.tool.type = "function",
689                                gen_ai.tool.name = tracing::field::Empty,
690                                gen_ai.tool.call.id = tracing::field::Empty,
691                                gen_ai.tool.call.arguments = tracing::field::Empty,
692                                gen_ai.tool.call.result = tracing::field::Empty
693                            );
694
695                            let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
696                                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
697                                tool_span.follows_from(id).to_owned()
698                            } else {
699                                tool_span
700                            };
701
702                            if let Some(id) = tool_span.id() {
703                                current_span_id.store(id.into_u64(), Ordering::SeqCst);
704                            };
705
706                            // Clone full history for error reporting in concurrent tool execution
707                            let cloned_history_for_error = full_history_for_errors.clone();
708
709                            async move {
710                                let PendingToolCall {
711                                    tool_call,
712                                    preresolved_result,
713                                    ..
714                                } = pending;
715                                let tool_name = &tool_call.function.name;
716                                let args =
717                                    json_utils::value_to_json_string(&tool_call.function.arguments);
718                                let internal_call_id = nanoid::nanoid!();
719                                if let Some(result) = preresolved_result {
720                                    return Ok(result);
721                                }
722                                let tool_span = tracing::Span::current();
723                                tool_span.record("gen_ai.tool.name", tool_name);
724                                tool_span.record("gen_ai.tool.call.id", &tool_call.id);
725                                tool_span.record("gen_ai.tool.call.arguments", &args);
726                                if let Some(hook) = hook1 {
727                                    let action = hook
728                                        .on_tool_call(
729                                            tool_name,
730                                            tool_call.call_id.clone(),
731                                            &internal_call_id,
732                                            &args,
733                                        )
734                                        .await;
735
736                                    if let ToolCallHookAction::Terminate { reason } = action {
737                                        return Err(PromptError::prompt_cancelled(
738                                            cloned_history_for_error,
739                                            reason,
740                                        ));
741                                    }
742
743                                    if let ToolCallHookAction::Skip { reason } = action {
744                                        // Tool execution rejected, return rejection message as tool result
745                                        tracing::info!(
746                                            tool_name = tool_name,
747                                            reason = reason,
748                                            "Tool call rejected"
749                                        );
750                                        if let Some(call_id) = tool_call.call_id.clone() {
751                                            return Ok(UserContent::tool_result_with_call_id(
752                                                tool_call.id.clone(),
753                                                call_id,
754                                                OneOrMany::one(reason.into()),
755                                            ));
756                                        } else {
757                                            return Ok(UserContent::tool_result(
758                                                tool_call.id.clone(),
759                                                OneOrMany::one(reason.into()),
760                                            ));
761                                        }
762                                    }
763                                }
764                                let output =
765                                    match tool_server_handle.call_tool(tool_name, &args).await {
766                                        Ok(res) => res,
767                                        Err(e) => {
768                                            tracing::warn!("Error while executing tool: {e}");
769                                            e.to_string()
770                                        }
771                                    };
772                                if let Some(hook) = hook2
773                                    && let HookAction::Terminate { reason } = hook
774                                        .on_tool_result(
775                                            tool_name,
776                                            tool_call.call_id.clone(),
777                                            &internal_call_id,
778                                            &args,
779                                            &output.to_string(),
780                                        )
781                                        .await
782                                {
783                                    return Err(PromptError::prompt_cancelled(
784                                        cloned_history_for_error,
785                                        reason,
786                                    ));
787                                }
788
789                                tool_span.record("gen_ai.tool.call.result", &output);
790                                tracing::info!(
791                                    "executed tool {tool_name} with args {args}. result: {output}"
792                                );
793                                if let Some(call_id) = tool_call.call_id.clone() {
794                                    Ok(UserContent::tool_result_with_call_id(
795                                        tool_call.id.clone(),
796                                        call_id,
797                                        ToolResultContent::from_tool_output(output),
798                                    ))
799                                } else {
800                                    Ok(UserContent::tool_result(
801                                        tool_call.id.clone(),
802                                        ToolResultContent::from_tool_output(output),
803                                    ))
804                                }
805                            }
806                            .instrument(tool_span)
807                        })
808                        .buffer_unordered(self.concurrency)
809                        .collect::<Vec<Result<UserContent, PromptError>>>()
810                        .await
811                        .into_iter()
812                        .collect::<Result<Vec<_>, _>>()?;
813
814                    run.tool_results(tool_content)?;
815                }
816                AgentRunStep::Done(response) => {
817                    if self.max_turns > 1 {
818                        tracing::info!("Depth reached: {}/{}", run.turn(), self.max_turns);
819                    }
820
821                    let usage = response.usage;
822                    agent_span.record("gen_ai.completion", &response.output);
823                    agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
824                    agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
825                    agent_span.record(
826                        "gen_ai.usage.cache_read.input_tokens",
827                        usage.cached_input_tokens,
828                    );
829                    agent_span.record(
830                        "gen_ai.usage.cache_creation.input_tokens",
831                        usage.cache_creation_input_tokens,
832                    );
833                    agent_span.record(
834                        "gen_ai.usage.tool_use_prompt_tokens",
835                        usage.tool_use_prompt_tokens,
836                    );
837                    agent_span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);
838
839                    if let Some((memory, id)) = memory_handle.as_ref()
840                        && let Err(err) = memory
841                            .append(id, response.messages.clone().unwrap_or_default())
842                            .await
843                    {
844                        tracing::warn!(
845                            error = %err,
846                            conversation_id = %id,
847                            "conversation memory append failed; returning model response anyway"
848                        );
849                    }
850
851                    return Ok(response);
852                }
853            }
854        }
855    }
856}
857
858// ================================================================
859// TypedPromptRequest - for structured output with automatic deserialization
860// ================================================================
861
862use crate::completion::StructuredOutputError;
863use schemars::{JsonSchema, schema_for};
864use serde::de::DeserializeOwned;
865
866/// A builder for creating typed prompt requests that return deserialized structured output.
867///
868/// This struct wraps a standard `PromptRequest` and adds:
869/// - Automatic JSON schema generation from the target type `T`
870/// - Automatic deserialization of the response into `T`
871///
872/// The type parameter `S` represents the state of the request (Standard or Extended).
873/// Use `.extended_details()` to transition to Extended state for usage tracking.
874///
875/// # Example
876/// ```rust,ignore
877/// let forecast: WeatherForecast = agent
878///     .prompt_typed("What's the weather in NYC?")
879///     .max_turns(3)
880///     .await?;
881/// ```
882pub struct TypedPromptRequest<T, S, M, P>
883where
884    T: JsonSchema + DeserializeOwned + WasmCompatSend,
885    S: PromptType,
886    M: CompletionModel,
887    P: PromptHook<M>,
888{
889    inner: PromptRequest<S, M, P>,
890    _phantom: std::marker::PhantomData<T>,
891}
892
893impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
894where
895    T: JsonSchema + DeserializeOwned + WasmCompatSend,
896    M: CompletionModel,
897    P: PromptHook<M>,
898{
899    /// Create a new TypedPromptRequest from an agent.
900    ///
901    /// This automatically sets the output schema based on the type parameter `T`.
902    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
903        let mut inner = PromptRequest::from_agent(agent, prompt);
904        // Override the output schema with the schema for T
905        inner.output_schema = Some(schema_for!(T));
906        Self {
907            inner,
908            _phantom: std::marker::PhantomData,
909        }
910    }
911}
912
913impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
914where
915    T: JsonSchema + DeserializeOwned + WasmCompatSend,
916    S: PromptType,
917    M: CompletionModel,
918    P: PromptHook<M>,
919{
920    /// Enable returning extended details for responses (includes aggregated token usage).
921    ///
922    /// Note: This changes the type of the response from `.send()` to return a `TypedPromptResponse<T>` struct
923    /// instead of just `T`. This is useful for tracking token usage across multiple turns
924    /// of conversation.
925    pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
926        TypedPromptRequest {
927            inner: self.inner.extended_details(),
928            _phantom: std::marker::PhantomData,
929        }
930    }
931
932    /// Set the maximum number of turns for multi-turn conversations.
933    ///
934    /// A given agent may require multiple turns for tool-calling before giving an answer.
935    /// If the maximum turn number is exceeded, it will return a
936    /// [`StructuredOutputError::PromptError`] wrapping a `MaxTurnsError`.
937    pub fn max_turns(mut self, depth: usize) -> Self {
938        self.inner = self.inner.max_turns(depth);
939        self
940    }
941
942    /// Set the retry budget for invalid tool-call recovery.
943    ///
944    /// Invalid tool-call retries also consume normal multi-turn depth.
945    pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
946        self.inner = self.inner.max_invalid_tool_call_retries(retries);
947        self
948    }
949
950    /// Add concurrency to the prompt request.
951    ///
952    /// This will cause the agent to execute tools concurrently.
953    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
954        self.inner = self.inner.with_tool_concurrency(concurrency);
955        self
956    }
957
958    /// Add chat history to the prompt request.
959    pub fn with_history<H, U>(mut self, history: H) -> Self
960    where
961        H: IntoIterator<Item = U>,
962        U: Into<Message>,
963    {
964        self.inner = self.inner.with_history(history);
965        self
966    }
967
968    /// Set the conversation id used to load and persist memory for this request.
969    ///
970    /// Overrides any default conversation id set on the agent. If memory is not
971    /// configured on the agent, this has no effect.
972    pub fn conversation(mut self, id: impl Into<String>) -> Self {
973        self.inner = self.inner.conversation(id);
974        self
975    }
976
977    /// Disable conversation memory for this request.
978    ///
979    /// History will neither be loaded from nor saved to the agent's memory backend.
980    pub fn without_memory(mut self) -> Self {
981        self.inner = self.inner.without_memory();
982        self
983    }
984
985    /// Attach a per-request hook for tool call events.
986    ///
987    /// This overrides any default hook set on the agent.
988    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
989    where
990        P2: PromptHook<M>,
991    {
992        TypedPromptRequest {
993            inner: self.inner.with_hook(hook),
994            _phantom: std::marker::PhantomData,
995        }
996    }
997}
998
999impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
1000where
1001    T: JsonSchema + DeserializeOwned + WasmCompatSend,
1002    M: CompletionModel,
1003    P: PromptHook<M>,
1004{
1005    /// Send the typed prompt request and deserialize the response.
1006    async fn send(self) -> Result<T, StructuredOutputError> {
1007        let response = self.inner.send().await.map_err(Box::new)?;
1008
1009        if response.is_empty() {
1010            return Err(StructuredOutputError::EmptyResponse);
1011        }
1012
1013        let parsed: T = serde_json::from_str(&response)?;
1014        Ok(parsed)
1015    }
1016}
1017
1018impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
1019where
1020    T: JsonSchema + DeserializeOwned + WasmCompatSend,
1021    M: CompletionModel,
1022    P: PromptHook<M>,
1023{
1024    /// Send the typed prompt request with extended details and deserialize the response.
1025    async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
1026        let response = self.inner.send().await.map_err(Box::new)?;
1027
1028        if response.output.is_empty() {
1029            return Err(StructuredOutputError::EmptyResponse);
1030        }
1031
1032        let parsed: T = serde_json::from_str(&response.output)?;
1033        Ok(TypedPromptResponse::new(parsed, response.usage)
1034            .with_completion_calls(response.completion_calls))
1035    }
1036}
1037
1038impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
1039where
1040    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
1041    M: CompletionModel + 'static,
1042    P: PromptHook<M> + 'static,
1043{
1044    type Output = Result<T, StructuredOutputError>;
1045    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1046
1047    fn into_future(self) -> Self::IntoFuture {
1048        Box::pin(self.send())
1049    }
1050}
1051
1052impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
1053where
1054    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
1055    M: CompletionModel + 'static,
1056    P: PromptHook<M> + 'static,
1057{
1058    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
1059    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;
1060
1061    fn into_future(self) -> Self::IntoFuture {
1062        Box::pin(self.send())
1063    }
1064}
1065
1066#[cfg(test)]
1067mod tests {
1068    use super::{CompletionCall, PromptResponse, TypedPromptResponse};
1069    use crate::{
1070        agent::{
1071            AgentBuilder,
1072            prompt_request::hooks::{
1073                HookAction, InvalidToolCallContext, InvalidToolCallHookAction, PromptHook,
1074                ToolCallHookAction,
1075            },
1076        },
1077        completion::{
1078            AssistantContent, CompletionError, CompletionModel, CompletionRequest, Message, Prompt,
1079            PromptError, StructuredOutputError, ToolDefinition, TypedPrompt, Usage,
1080        },
1081        message::{Text, ToolCall, ToolChoice, ToolFunction, UserContent},
1082        test_utils::{
1083            AppendFailingMemory, CountingMemory, FailingMemory, MockAddTool, MockCompletionModel,
1084            MockOperationArgs, MockSubtractTool, MockToolError, MockTurn,
1085        },
1086        tool::Tool,
1087    };
1088    use schemars::JsonSchema;
1089    use serde::{Deserialize, Serialize};
1090    use serde_json::json;
1091    use std::sync::{
1092        Arc, Mutex,
1093        atomic::{AtomicU32, Ordering},
1094    };
1095
1096    #[derive(Serialize)]
1097    struct SerializeOnly {
1098        value: &'static str,
1099    }
1100
1101    #[derive(Deserialize)]
1102    struct DeserializeOnly {
1103        value: String,
1104    }
1105
1106    #[derive(Debug, Deserialize, JsonSchema, PartialEq)]
1107    struct TypedAnswer {
1108        value: String,
1109    }
1110
1111    #[derive(Clone)]
1112    struct PanicOnUnknownToolHook;
1113
1114    impl PromptHook<MockCompletionModel> for PanicOnUnknownToolHook {
1115        async fn on_completion_response(
1116            &self,
1117            _prompt: &Message,
1118            _response: &crate::completion::CompletionResponse<
1119                <MockCompletionModel as CompletionModel>::Response,
1120            >,
1121        ) -> HookAction {
1122            panic!("unknown tool response should fail before response hooks run")
1123        }
1124
1125        async fn on_tool_call(
1126            &self,
1127            _tool_name: &str,
1128            _tool_call_id: Option<String>,
1129            _internal_call_id: &str,
1130            _args: &str,
1131        ) -> ToolCallHookAction {
1132            panic!("unknown tool call should fail before tool hooks run")
1133        }
1134    }
1135
1136    #[derive(Clone)]
1137    struct PanicOnToolCallHook;
1138
1139    impl PromptHook<MockCompletionModel> for PanicOnToolCallHook {
1140        async fn on_tool_call(
1141            &self,
1142            _tool_name: &str,
1143            _tool_call_id: Option<String>,
1144            _internal_call_id: &str,
1145            _args: &str,
1146        ) -> ToolCallHookAction {
1147            panic!("recovered invalid turn should not invoke normal tool hooks")
1148        }
1149    }
1150
1151    #[derive(Clone)]
1152    struct SkipDefaultApiAndPanicOnToolCallHook;
1153
1154    impl PromptHook<MockCompletionModel> for SkipDefaultApiAndPanicOnToolCallHook {
1155        async fn on_invalid_tool_call(
1156            &self,
1157            context: &InvalidToolCallContext,
1158        ) -> InvalidToolCallHookAction {
1159            SkipDefaultApiHook.on_invalid_tool_call(context).await
1160        }
1161
1162        async fn on_tool_call(
1163            &self,
1164            tool_name: &str,
1165            tool_call_id: Option<String>,
1166            internal_call_id: &str,
1167            args: &str,
1168        ) -> ToolCallHookAction {
1169            PanicOnToolCallHook
1170                .on_tool_call(tool_name, tool_call_id, internal_call_id, args)
1171                .await
1172        }
1173    }
1174
1175    #[derive(Clone)]
1176    struct RepairDefaultApiHook;
1177
1178    impl PromptHook<MockCompletionModel> for RepairDefaultApiHook {
1179        fn on_invalid_tool_call(
1180            &self,
1181            context: &InvalidToolCallContext,
1182        ) -> impl std::future::Future<Output = InvalidToolCallHookAction> + Send {
1183            let tool_name = context.tool_name.clone();
1184            async move {
1185                assert_eq!(tool_name, "default_api");
1186                InvalidToolCallHookAction::repair("add")
1187            }
1188        }
1189    }
1190
1191    #[derive(Clone)]
1192    struct RepairToSubtractHook;
1193
1194    impl PromptHook<MockCompletionModel> for RepairToSubtractHook {
1195        async fn on_invalid_tool_call(
1196            &self,
1197            _context: &InvalidToolCallContext,
1198        ) -> InvalidToolCallHookAction {
1199            InvalidToolCallHookAction::repair("subtract")
1200        }
1201    }
1202
1203    #[derive(Clone)]
1204    struct RetryDefaultApiHook;
1205
1206    impl PromptHook<MockCompletionModel> for RetryDefaultApiHook {
1207        fn on_invalid_tool_call(
1208            &self,
1209            context: &InvalidToolCallContext,
1210        ) -> impl std::future::Future<Output = InvalidToolCallHookAction> + Send {
1211            let allowed_tools = context.allowed_tools.clone();
1212            async move {
1213                InvalidToolCallHookAction::retry(format!(
1214                    "Use one of these tools instead: {allowed_tools:?}"
1215                ))
1216            }
1217        }
1218    }
1219
1220    #[derive(Clone)]
1221    struct SkipDefaultApiHook;
1222
1223    impl PromptHook<MockCompletionModel> for SkipDefaultApiHook {
1224        async fn on_invalid_tool_call(
1225            &self,
1226            _context: &InvalidToolCallContext,
1227        ) -> InvalidToolCallHookAction {
1228            InvalidToolCallHookAction::skip("default_api is not available")
1229        }
1230    }
1231
1232    #[derive(Clone, Default)]
1233    struct RecordingInvalidToolCallHook {
1234        contexts: Arc<Mutex<Vec<InvalidToolCallContext>>>,
1235    }
1236
1237    impl RecordingInvalidToolCallHook {
1238        fn observed(&self) -> Vec<InvalidToolCallContext> {
1239            self.contexts
1240                .lock()
1241                .expect("invalid tool context records mutex was poisoned")
1242                .clone()
1243        }
1244    }
1245
1246    impl PromptHook<MockCompletionModel> for RecordingInvalidToolCallHook {
1247        async fn on_invalid_tool_call(
1248            &self,
1249            context: &InvalidToolCallContext,
1250        ) -> InvalidToolCallHookAction {
1251            self.contexts
1252                .lock()
1253                .expect("invalid tool context records mutex was poisoned")
1254                .push(context.clone());
1255            InvalidToolCallHookAction::fail()
1256        }
1257    }
1258
1259    #[derive(Clone)]
1260    struct CountingAddTool {
1261        calls: Arc<AtomicU32>,
1262    }
1263
1264    impl Tool for CountingAddTool {
1265        const NAME: &'static str = "add";
1266        type Error = MockToolError;
1267        type Args = MockOperationArgs;
1268        type Output = i32;
1269
1270        async fn definition(&self, _prompt: String) -> ToolDefinition {
1271            MockAddTool.definition(String::new()).await
1272        }
1273
1274        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
1275            self.calls.fetch_add(1, Ordering::SeqCst);
1276            Ok(0)
1277        }
1278    }
1279
1280    fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
1281        Usage {
1282            input_tokens,
1283            output_tokens,
1284            total_tokens: input_tokens + output_tokens,
1285            cached_input_tokens: 0,
1286            cache_creation_input_tokens: 0,
1287            tool_use_prompt_tokens: 0,
1288            reasoning_tokens: 0,
1289        }
1290    }
1291
1292    #[test]
1293    fn typed_prompt_response_serializes_with_serialize_only_output() {
1294        let response = TypedPromptResponse::new(
1295            SerializeOnly { value: "ok" },
1296            Usage {
1297                input_tokens: 1,
1298                output_tokens: 2,
1299                total_tokens: 3,
1300                cached_input_tokens: 0,
1301                cache_creation_input_tokens: 0,
1302                tool_use_prompt_tokens: 0,
1303                reasoning_tokens: 0,
1304            },
1305        );
1306
1307        let json = serde_json::to_string(&response).expect("serialize typed prompt response");
1308        assert!(json.contains("\"value\":\"ok\""));
1309    }
1310
1311    #[test]
1312    fn typed_prompt_response_deserializes_with_deserialize_only_output() {
1313        let response: TypedPromptResponse<DeserializeOnly> = serde_json::from_str(
1314            r#"{"output":{"value":"ok"},"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3,"cached_input_tokens":0,"cache_creation_input_tokens":0,"reasoning_tokens":0}}"#,
1315        )
1316        .expect("deserialize typed prompt response");
1317
1318        assert_eq!(response.requests(), 0);
1319        assert_eq!(response.output.value, "ok");
1320        assert_eq!(response.usage.input_tokens, 1);
1321        assert_eq!(response.usage.output_tokens, 2);
1322        assert_eq!(response.usage.total_tokens, 3);
1323    }
1324
1325    #[test]
1326    fn prompt_response_serializes_completion_calls_with_missing_usage() {
1327        let reported_usage = usage(3, 4);
1328        let response = PromptResponse::new("ok", reported_usage).with_completion_calls(vec![
1329            CompletionCall::new(0, Usage::new()),
1330            CompletionCall::new(1, reported_usage),
1331        ]);
1332
1333        let value = serde_json::to_value(&response).expect("serialize prompt response");
1334
1335        // Unreported usage serializes as a plain zero-valued object: zero is
1336        // Usage's documented sentinel for missing provider metrics, so there
1337        // is no null encoding to keep in sync.
1338        assert_eq!(
1339            value.get("completion_calls"),
1340            Some(&json!([
1341                {
1342                    "call_index": 0,
1343                    "usage": {
1344                        "input_tokens": 0,
1345                        "output_tokens": 0,
1346                        "total_tokens": 0,
1347                        "cached_input_tokens": 0,
1348                        "cache_creation_input_tokens": 0,
1349                        "tool_use_prompt_tokens": 0,
1350                        "reasoning_tokens": 0,
1351                    }
1352                },
1353                {
1354                    "call_index": 1,
1355                    "usage": {
1356                        "input_tokens": 3,
1357                        "output_tokens": 4,
1358                        "total_tokens": 7,
1359                        "cached_input_tokens": 0,
1360                        "cache_creation_input_tokens": 0,
1361                        "tool_use_prompt_tokens": 0,
1362                        "reasoning_tokens": 0,
1363                    }
1364                }
1365            ]))
1366        );
1367
1368        let response: PromptResponse =
1369            serde_json::from_value(value).expect("deserialize prompt response");
1370        assert_eq!(
1371            response.completion_calls(),
1372            &[
1373                CompletionCall::new(0, Usage::new()),
1374                CompletionCall::new(1, reported_usage)
1375            ]
1376        );
1377        assert_eq!(response.requests(), 2);
1378    }
1379
1380    #[test]
1381    fn prompt_response_deserializes_pre_monoid_null_usage_format() {
1382        // Fixture captured from rig before CompletionCall.usage dropped its
1383        // Option encoding; `"usage": null` must map to zero-valued usage.
1384        let fixture = r#"{"output":"ok","usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"cached_input_tokens":0,"cache_creation_input_tokens":0,"tool_use_prompt_tokens":0,"reasoning_tokens":0},"completion_calls":[{"call_index":0,"usage":null},{"call_index":1,"usage":{"input_tokens":3,"output_tokens":4,"total_tokens":7,"cached_input_tokens":0,"cache_creation_input_tokens":0,"tool_use_prompt_tokens":0,"reasoning_tokens":0}}],"messages":[{"role":"user","content":[{"type":"text","text":"add things"}]}]}"#;
1385
1386        let response: PromptResponse =
1387            serde_json::from_str(fixture).expect("old-format response should deserialize");
1388        assert_eq!(
1389            response.completion_calls(),
1390            &[
1391                CompletionCall::new(0, Usage::new()),
1392                CompletionCall::new(1, usage(3, 4))
1393            ]
1394        );
1395    }
1396
1397    #[tokio::test]
1398    async fn prompt_response_records_completion_call_without_reported_usage() {
1399        let model = MockCompletionModel::new([MockTurn::text("ok")]);
1400        let agent = AgentBuilder::new(model).build();
1401
1402        let response = agent
1403            .prompt("say ok")
1404            .extended_details()
1405            .await
1406            .expect("prompt should succeed");
1407
1408        assert_eq!(response.output, "ok");
1409        assert_eq!(response.usage, Usage::new());
1410        assert_eq!(
1411            response.completion_calls(),
1412            &[CompletionCall::new(0, Usage::new())]
1413        );
1414    }
1415
1416    #[tokio::test]
1417    async fn typed_prompt_response_preserves_completion_calls() {
1418        let call_usage = Usage {
1419            input_tokens: 4,
1420            output_tokens: 6,
1421            total_tokens: 10,
1422            cached_input_tokens: 0,
1423            cache_creation_input_tokens: 0,
1424            tool_use_prompt_tokens: 0,
1425            reasoning_tokens: 0,
1426        };
1427        let model =
1428            MockCompletionModel::new([MockTurn::text(r#"{"value":"ok"}"#).with_usage(call_usage)]);
1429        let agent = AgentBuilder::new(model).build();
1430
1431        let response = agent
1432            .prompt_typed::<TypedAnswer>("return typed json")
1433            .extended_details()
1434            .await
1435            .expect("typed prompt should succeed");
1436
1437        assert_eq!(
1438            response.output,
1439            TypedAnswer {
1440                value: "ok".to_string()
1441            }
1442        );
1443        assert_eq!(response.usage, call_usage);
1444        assert_eq!(
1445            response.completion_calls(),
1446            &[CompletionCall::new(0, call_usage)]
1447        );
1448    }
1449
1450    fn validate_follow_up_tool_history(request: &CompletionRequest) {
1451        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
1452        assert_eq!(
1453            history.len(),
1454            3,
1455            "follow-up request should contain the prompt, assistant tool call, and user tool result: {history:?}"
1456        );
1457
1458        assert!(matches!(
1459            history.first(),
1460            Some(Message::User { content })
1461                if matches!(
1462                    content.first(),
1463                    UserContent::Text(text) if text.text == "do tool work"
1464                )
1465        ));
1466
1467        assert!(matches!(
1468            history.get(1),
1469            Some(Message::Assistant { content, .. })
1470                if matches!(
1471                    content.first(),
1472                    AssistantContent::ToolCall(tool_call)
1473                        if tool_call.id == "tool_call_1"
1474                            && tool_call.call_id.as_deref() == Some("call_1")
1475                )
1476        ));
1477
1478        assert!(matches!(
1479            history.get(2),
1480            Some(Message::User { content })
1481                if matches!(
1482                    content.first(),
1483                    UserContent::ToolResult(tool_result)
1484                        if tool_result.id == "tool_call_1"
1485                            && tool_result.call_id.as_deref() == Some("call_1")
1486                )
1487        ));
1488    }
1489
1490    fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool {
1491        history.iter().any(|message| {
1492            matches!(
1493                message,
1494                Message::Assistant { content, .. }
1495                    if content.iter().any(|item| matches!(
1496                        item,
1497                        AssistantContent::ToolCall(tool_call)
1498                            if tool_call.function.name == tool_name
1499                    ))
1500            )
1501        })
1502    }
1503
1504    #[tokio::test]
1505    async fn unknown_tool_call_fails_before_non_streaming_second_request() {
1506        let model = MockCompletionModel::new([
1507            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 1, "y": 2})),
1508            MockTurn::text("should not be requested"),
1509        ]);
1510        let recorded = model.clone();
1511        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1512
1513        let err = agent
1514            .prompt("use the tool")
1515            .with_hook(PanicOnUnknownToolHook)
1516            .max_turns(3)
1517            .await
1518            .expect_err("unknown model-emitted tool should fail");
1519
1520        match err {
1521            PromptError::UnknownToolCall {
1522                tool_name,
1523                available_tools,
1524                allowed_tools,
1525                chat_history,
1526            } => {
1527                assert_eq!(tool_name, "default_api");
1528                assert_eq!(available_tools, vec!["add".to_string()]);
1529                assert_eq!(allowed_tools, vec!["add".to_string()]);
1530                assert!(history_contains_tool_call(&chat_history, "default_api"));
1531            }
1532            other => panic!("expected UnknownToolCall, got {other:?}"),
1533        }
1534        assert_eq!(recorded.request_count(), 1);
1535    }
1536
1537    #[tokio::test]
1538    async fn invalid_tool_call_context_uses_completed_tool_call_provider_id() {
1539        let invalid_hook = RecordingInvalidToolCallHook::default();
1540        let model = MockCompletionModel::new([
1541            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 1, "y": 2}))
1542                .with_call_id("provider_call_1"),
1543            MockTurn::text("should not be requested"),
1544        ]);
1545        let recorded = model.clone();
1546        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1547
1548        let err = agent
1549            .prompt("use the tool")
1550            .with_hook(invalid_hook.clone())
1551            .max_turns(3)
1552            .await
1553            .expect_err("invalid tool should fail");
1554
1555        assert!(matches!(err, PromptError::UnknownToolCall { .. }));
1556        assert_eq!(recorded.request_count(), 1);
1557        let contexts = invalid_hook.observed();
1558        assert_eq!(contexts.len(), 1);
1559        let context = &contexts[0];
1560        assert_eq!(context.tool_name, "default_api");
1561        assert_eq!(context.tool_call_id.as_deref(), Some("tool_call_1"));
1562        assert_eq!(context.internal_call_id, None);
1563        assert!(!context.is_streaming);
1564    }
1565
1566    #[tokio::test]
1567    async fn disallowed_specific_tool_call_fails_before_non_streaming_second_request() {
1568        let model = MockCompletionModel::new([
1569            MockTurn::tool_call("tool_call_1", "subtract", json!({"x": 3, "y": 1})),
1570            MockTurn::text("should not be requested"),
1571        ]);
1572        let recorded = model.clone();
1573        let agent = AgentBuilder::new(model)
1574            .tool(MockAddTool)
1575            .tool(MockSubtractTool)
1576            .tool_choice(ToolChoice::Specific {
1577                function_names: vec!["add".to_string()],
1578            })
1579            .build();
1580
1581        let err = agent
1582            .prompt("use the allowed tool")
1583            .with_hook(PanicOnUnknownToolHook)
1584            .max_turns(3)
1585            .await
1586            .expect_err("disallowed model-emitted tool should fail");
1587
1588        match err {
1589            PromptError::UnknownToolCall {
1590                tool_name,
1591                available_tools,
1592                allowed_tools,
1593                chat_history,
1594            } => {
1595                assert_eq!(tool_name, "subtract");
1596                assert_eq!(
1597                    available_tools,
1598                    vec!["add".to_string(), "subtract".to_string()]
1599                );
1600                assert_eq!(allowed_tools, vec!["add".to_string()]);
1601                assert!(history_contains_tool_call(&chat_history, "subtract"));
1602            }
1603            other => panic!("expected UnknownToolCall, got {other:?}"),
1604        }
1605        assert_eq!(recorded.request_count(), 1);
1606    }
1607
1608    #[tokio::test]
1609    async fn tool_choice_none_rejects_non_streaming_tool_call() {
1610        let model = MockCompletionModel::new([
1611            MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})),
1612            MockTurn::text("should not be requested"),
1613        ]);
1614        let recorded = model.clone();
1615        let agent = AgentBuilder::new(model)
1616            .tool(MockAddTool)
1617            .tool_choice(ToolChoice::None)
1618            .build();
1619
1620        let err = agent
1621            .prompt("do not use tools")
1622            .with_hook(PanicOnUnknownToolHook)
1623            .max_turns(3)
1624            .await
1625            .expect_err("ToolChoice::None should reject returned tool calls");
1626
1627        match err {
1628            PromptError::UnknownToolCall {
1629                tool_name,
1630                available_tools,
1631                allowed_tools,
1632                chat_history,
1633            } => {
1634                assert_eq!(tool_name, "add");
1635                assert_eq!(available_tools, vec!["add".to_string()]);
1636                assert!(allowed_tools.is_empty());
1637                assert!(history_contains_tool_call(&chat_history, "add"));
1638            }
1639            other => panic!("expected UnknownToolCall, got {other:?}"),
1640        }
1641        assert_eq!(recorded.request_count(), 1);
1642    }
1643
1644    #[tokio::test]
1645    async fn invalid_tool_call_hook_can_repair_non_streaming_tool_name() {
1646        let model = MockCompletionModel::new([
1647            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1648            MockTurn::text("done"),
1649        ]);
1650        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1651
1652        let response = agent
1653            .prompt("add")
1654            .with_hook(RepairDefaultApiHook)
1655            .max_turns(3)
1656            .extended_details()
1657            .await
1658            .expect("repaired tool call should execute");
1659
1660        assert_eq!(response.output, "done");
1661        let messages = response.messages.expect("messages should be present");
1662        assert!(history_contains_tool_call(&messages, "add"));
1663        assert!(!history_contains_tool_call(&messages, "default_api"));
1664        assert!(messages.iter().any(|message| {
1665            matches!(
1666                message,
1667                Message::User { content }
1668                    if content.iter().any(|content| {
1669                        matches!(
1670                            content,
1671                            UserContent::ToolResult(result)
1672                                if result.content.iter().any(|content| {
1673                                    matches!(
1674                                        content,
1675                                        crate::message::ToolResultContent::Text(text)
1676                                            if text.text == "5"
1677                                    )
1678                                })
1679                        )
1680                    })
1681            )
1682        }));
1683    }
1684
1685    #[tokio::test]
1686    async fn invalid_tool_call_hook_retry_adds_feedback_and_retries_non_streaming() {
1687        let model = MockCompletionModel::new([
1688            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1689            MockTurn::text("retried"),
1690        ]);
1691        let recorded = model.clone();
1692        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1693
1694        let response = agent
1695            .prompt("add")
1696            .with_hook(RetryDefaultApiHook)
1697            .max_invalid_tool_call_retries(1)
1698            .max_turns(3)
1699            .extended_details()
1700            .await
1701            .expect("retry should recover");
1702
1703        assert_eq!(response.output, "retried");
1704        assert_eq!(recorded.request_count(), 2);
1705        let messages = response.messages.expect("messages should be present");
1706        assert!(messages.iter().any(|message| {
1707            matches!(
1708                message,
1709                Message::User { content }
1710                    if content.iter().any(|content| {
1711                        matches!(
1712                            content,
1713                            UserContent::ToolResult(result)
1714                                if result.content.iter().any(|content| {
1715                                    matches!(
1716                                        content,
1717                                        crate::message::ToolResultContent::Text(text)
1718                                            if text.text.contains("Use one of these tools instead")
1719                                    )
1720                                })
1721                        )
1722                    })
1723            )
1724        }));
1725    }
1726
1727    #[tokio::test]
1728    async fn invalid_tool_call_hook_retries_mixed_non_streaming_turn_without_executing_valid_call()
1729    {
1730        let add_calls = Arc::new(AtomicU32::new(0));
1731        let mut valid_tool_call = ToolCall::new(
1732            "tool_call_1".to_string(),
1733            ToolFunction::new("add".to_string(), json!({"x": 2, "y": 3})),
1734        );
1735        valid_tool_call.call_id = Some("call_1".to_string());
1736        let mut invalid_tool_call = ToolCall::new(
1737            "tool_call_2".to_string(),
1738            ToolFunction::new("default_api".to_string(), json!({"x": 4, "y": 5})),
1739        );
1740        invalid_tool_call.call_id = Some("call_2".to_string());
1741        let model = MockCompletionModel::new([
1742            MockTurn::from_contents([
1743                AssistantContent::ToolCall(valid_tool_call),
1744                AssistantContent::ToolCall(invalid_tool_call),
1745            ])
1746            .expect("tool-call response should be non-empty"),
1747            MockTurn::text("retried"),
1748        ]);
1749        let recorded = model.clone();
1750        let agent = AgentBuilder::new(model)
1751            .tool(CountingAddTool {
1752                calls: add_calls.clone(),
1753            })
1754            .build();
1755
1756        let response = agent
1757            .prompt("add")
1758            .with_hook(RetryDefaultApiHook)
1759            .max_invalid_tool_call_retries(1)
1760            .max_turns(3)
1761            .extended_details()
1762            .await
1763            .expect("retry should recover");
1764
1765        assert_eq!(response.output, "retried");
1766        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
1767        let requests = recorded.requests();
1768        assert_eq!(requests.len(), 2);
1769        let retry_history = requests[1].chat_history.iter().cloned().collect::<Vec<_>>();
1770        assert_eq!(retry_history.len(), 3);
1771        assert!(matches!(
1772            retry_history.get(1),
1773            Some(Message::Assistant { content, .. })
1774                if content.iter().any(|item| matches!(
1775                    item,
1776                    AssistantContent::ToolCall(tool_call)
1777                        if tool_call.id == "tool_call_1"
1778                            && tool_call.function.name == "add"
1779                ))
1780                    && content.iter().any(|item| matches!(
1781                        item,
1782                        AssistantContent::ToolCall(tool_call)
1783                            if tool_call.id == "tool_call_2"
1784                                && tool_call.function.name == "default_api"
1785                    ))
1786        ));
1787        assert!(matches!(
1788            retry_history.get(2),
1789            Some(Message::User { content })
1790                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
1791                    && content.iter().any(|item| matches!(
1792                        item,
1793                        UserContent::ToolResult(result)
1794                            if result.id == "tool_call_1"
1795                                && result.call_id.as_deref() == Some("call_1")
1796                                && result.content.iter().any(|content| matches!(
1797                                    content,
1798                                    crate::message::ToolResultContent::Text(text)
1799                                        if text.text == super::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
1800                                ))
1801                    ))
1802                    && content.iter().any(|item| matches!(
1803                        item,
1804                        UserContent::ToolResult(result)
1805                            if result.id == "tool_call_2"
1806                                && result.call_id.as_deref() == Some("call_2")
1807                                && result.content.iter().any(|content| matches!(
1808                                    content,
1809                                    crate::message::ToolResultContent::Text(text)
1810                                        if text.text.contains("Use one of these tools instead")
1811                                ))
1812            ))
1813        ));
1814    }
1815
1816    #[tokio::test]
1817    async fn invalid_tool_call_hook_skips_mixed_non_streaming_turn_without_executing_valid_call() {
1818        let add_calls = Arc::new(AtomicU32::new(0));
1819        let mut valid_tool_call = ToolCall::new(
1820            "tool_call_1".to_string(),
1821            ToolFunction::new("add".to_string(), json!({"x": 2, "y": 3})),
1822        );
1823        valid_tool_call.call_id = Some("call_1".to_string());
1824        let mut invalid_tool_call = ToolCall::new(
1825            "tool_call_2".to_string(),
1826            ToolFunction::new("default_api".to_string(), json!({"x": 4, "y": 5})),
1827        );
1828        invalid_tool_call.call_id = Some("call_2".to_string());
1829        let model = MockCompletionModel::new([
1830            MockTurn::from_contents([
1831                AssistantContent::ToolCall(valid_tool_call),
1832                AssistantContent::ToolCall(invalid_tool_call),
1833            ])
1834            .expect("tool-call response should be non-empty"),
1835            MockTurn::text("skipped"),
1836        ]);
1837        let agent = AgentBuilder::new(model)
1838            .tool(CountingAddTool {
1839                calls: add_calls.clone(),
1840            })
1841            .build();
1842
1843        let response = agent
1844            .prompt("add")
1845            .with_hook(SkipDefaultApiAndPanicOnToolCallHook)
1846            .max_turns(3)
1847            .extended_details()
1848            .await
1849            .expect("skip should recover without executing peer tools");
1850
1851        assert_eq!(response.output, "skipped");
1852        assert_eq!(add_calls.load(Ordering::SeqCst), 0);
1853        let messages = response.messages.expect("messages should be present");
1854        assert!(history_contains_tool_call(&messages, "add"));
1855        assert!(history_contains_tool_call(&messages, "default_api"));
1856        assert!(matches!(
1857            messages.get(2),
1858            Some(Message::User { content })
1859                if content.iter().filter(|item| matches!(item, UserContent::ToolResult(_))).count() == 2
1860                    && content.iter().any(|item| matches!(
1861                        item,
1862                        UserContent::ToolResult(result)
1863                            if result.id == "tool_call_1"
1864                                && result.call_id.as_deref() == Some("call_1")
1865                                && result.content.iter().any(|content| matches!(
1866                                    content,
1867                                    crate::message::ToolResultContent::Text(text)
1868                                        if text.text == super::TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER
1869                                ))
1870                    ))
1871                    && content.iter().any(|item| matches!(
1872                        item,
1873                        UserContent::ToolResult(result)
1874                            if result.id == "tool_call_2"
1875                                && result.call_id.as_deref() == Some("call_2")
1876                                && result.content.iter().any(|content| matches!(
1877                                    content,
1878                                    crate::message::ToolResultContent::Text(text)
1879                                        if text.text == "default_api is not available"
1880                                ))
1881                    ))
1882        ));
1883    }
1884
1885    #[tokio::test]
1886    async fn invalid_tool_call_hook_retry_budget_exhaustion_fails() {
1887        let model = MockCompletionModel::new([
1888            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1889            MockTurn::text("should not be requested"),
1890        ]);
1891        let recorded = model.clone();
1892        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1893
1894        let err = agent
1895            .prompt("add")
1896            .with_hook(RetryDefaultApiHook)
1897            .max_invalid_tool_call_retries(0)
1898            .max_turns(3)
1899            .await
1900            .expect_err("retry without budget should fail");
1901
1902        match err {
1903            PromptError::UnknownToolCall {
1904                tool_name,
1905                chat_history,
1906                ..
1907            } => {
1908                assert_eq!(tool_name, "default_api");
1909                assert!(history_contains_tool_call(&chat_history, "default_api"));
1910            }
1911            other => panic!("expected UnknownToolCall, got {other:?}"),
1912        }
1913        assert_eq!(recorded.request_count(), 1);
1914    }
1915
1916    #[tokio::test]
1917    async fn invalid_tool_call_hook_can_skip_structured_non_streaming_call() {
1918        let model = MockCompletionModel::new([
1919            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1920            MockTurn::text("skipped"),
1921        ]);
1922        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
1923
1924        let response = agent
1925            .prompt("add")
1926            .with_hook(SkipDefaultApiHook)
1927            .max_turns(3)
1928            .extended_details()
1929            .await
1930            .expect("skip should continue with synthetic tool result");
1931
1932        assert_eq!(response.output, "skipped");
1933        let messages = response.messages.expect("messages should be present");
1934        assert!(history_contains_tool_call(&messages, "default_api"));
1935        assert!(messages.iter().any(|message| {
1936            matches!(
1937                message,
1938                Message::User { content }
1939                    if content.iter().any(|content| {
1940                        matches!(
1941                            content,
1942                            UserContent::ToolResult(result)
1943                                if result.content.iter().any(|content| {
1944                                    matches!(
1945                                        content,
1946                                        crate::message::ToolResultContent::Text(text)
1947                                            if text.text == "default_api is not available"
1948                                    )
1949                                })
1950                        )
1951                    })
1952            )
1953        }));
1954    }
1955
1956    #[tokio::test]
1957    async fn skip_under_specific_tool_choice_returns_synthetic_feedback() {
1958        let model = MockCompletionModel::new([
1959            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
1960            MockTurn::text("skipped"),
1961        ]);
1962        let agent = AgentBuilder::new(model)
1963            .tool(MockAddTool)
1964            .tool_choice(ToolChoice::Specific {
1965                function_names: vec!["add".to_string()],
1966            })
1967            .build();
1968
1969        let response = agent
1970            .prompt("add")
1971            .with_hook(SkipDefaultApiHook)
1972            .max_turns(3)
1973            .extended_details()
1974            .await
1975            .expect("skip should produce synthetic feedback under Specific");
1976
1977        assert_eq!(response.output, "skipped");
1978        let messages = response.messages.expect("messages should be present");
1979        assert!(history_contains_tool_call(&messages, "default_api"));
1980        assert!(messages.iter().any(|message| {
1981            matches!(
1982                message,
1983                Message::User { content }
1984                    if content.iter().any(|content| {
1985                        matches!(
1986                            content,
1987                            UserContent::ToolResult(result)
1988                                if result.id == "tool_call_1"
1989                                    && result.content.iter().any(|content| {
1990                                        matches!(
1991                                            content,
1992                                            crate::message::ToolResultContent::Text(text)
1993                                                if text.text == "default_api is not available"
1994                                        )
1995                                    })
1996                        )
1997                    })
1998            )
1999        }));
2000    }
2001
2002    #[tokio::test]
2003    async fn repair_to_disallowed_specific_tool_fails() {
2004        let model = MockCompletionModel::new([
2005            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2006            MockTurn::text("should not be requested"),
2007        ]);
2008        let recorded = model.clone();
2009        let agent = AgentBuilder::new(model)
2010            .tool(MockAddTool)
2011            .tool(MockSubtractTool)
2012            .tool_choice(ToolChoice::Specific {
2013                function_names: vec!["add".to_string()],
2014            })
2015            .build();
2016
2017        let err = agent
2018            .prompt("add")
2019            .with_hook(RepairToSubtractHook)
2020            .max_turns(3)
2021            .await
2022            .expect_err("repair to a disallowed tool should fail");
2023
2024        match err {
2025            PromptError::UnknownToolCall { tool_name, .. } => {
2026                assert_eq!(tool_name, "subtract");
2027            }
2028            other => panic!("expected UnknownToolCall, got {other:?}"),
2029        }
2030        assert_eq!(recorded.request_count(), 1);
2031    }
2032
2033    #[tokio::test]
2034    async fn repair_under_tool_choice_none_fails() {
2035        let model = MockCompletionModel::new([
2036            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2037            MockTurn::text("should not be requested"),
2038        ]);
2039        let recorded = model.clone();
2040        let agent = AgentBuilder::new(model)
2041            .tool(MockAddTool)
2042            .tool_choice(ToolChoice::None)
2043            .build();
2044
2045        let err = agent
2046            .prompt("do not use tools")
2047            .with_hook(RepairDefaultApiHook)
2048            .max_turns(3)
2049            .await
2050            .expect_err("ToolChoice::None should reject repaired tool calls");
2051
2052        match err {
2053            PromptError::UnknownToolCall { tool_name, .. } => {
2054                assert_eq!(tool_name, "add");
2055            }
2056            other => panic!("expected UnknownToolCall, got {other:?}"),
2057        }
2058        assert_eq!(recorded.request_count(), 1);
2059    }
2060
2061    #[tokio::test]
2062    async fn skip_under_tool_choice_none_fails() {
2063        let model = MockCompletionModel::new([
2064            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2065            MockTurn::text("should not be requested"),
2066        ]);
2067        let recorded = model.clone();
2068        let agent = AgentBuilder::new(model)
2069            .tool(MockAddTool)
2070            .tool_choice(ToolChoice::None)
2071            .build();
2072
2073        let err = agent
2074            .prompt("do not use tools")
2075            .with_hook(SkipDefaultApiHook)
2076            .max_turns(3)
2077            .await
2078            .expect_err("ToolChoice::None should reject skipped tool calls");
2079
2080        match err {
2081            PromptError::UnknownToolCall { tool_name, .. } => {
2082                assert_eq!(tool_name, "default_api");
2083            }
2084            other => panic!("expected UnknownToolCall, got {other:?}"),
2085        }
2086        assert_eq!(recorded.request_count(), 1);
2087    }
2088
2089    #[tokio::test]
2090    async fn typed_prompt_default_invalid_tool_call_fails_fast() {
2091        let model = MockCompletionModel::new([
2092            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2093            MockTurn::text(r#"{"value":"should not be requested"}"#),
2094        ]);
2095        let recorded = model.clone();
2096        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2097
2098        let err = agent
2099            .prompt_typed::<TypedAnswer>("return typed json")
2100            .with_hook(PanicOnUnknownToolHook)
2101            .max_turns(3)
2102            .await
2103            .expect_err("typed prompt should preserve fail-fast default");
2104
2105        match err {
2106            StructuredOutputError::PromptError(err) => match *err {
2107                PromptError::UnknownToolCall { tool_name, .. } => {
2108                    assert_eq!(tool_name, "default_api");
2109                }
2110                other => panic!("expected UnknownToolCall, got {other:?}"),
2111            },
2112            other => panic!("expected prompt error, got {other:?}"),
2113        }
2114        assert_eq!(recorded.request_count(), 1);
2115    }
2116
2117    #[tokio::test]
2118    async fn typed_prompt_invalid_tool_call_hook_can_repair_tool_name() {
2119        let model = MockCompletionModel::new([
2120            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2121            MockTurn::text(r#"{"value":"repaired"}"#),
2122        ]);
2123        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2124
2125        let response = agent
2126            .prompt_typed::<TypedAnswer>("return typed json")
2127            .with_hook(RepairDefaultApiHook)
2128            .max_turns(3)
2129            .await
2130            .expect("typed prompt should repair invalid tool call");
2131
2132        assert_eq!(
2133            response,
2134            TypedAnswer {
2135                value: "repaired".to_string()
2136            }
2137        );
2138    }
2139
2140    #[tokio::test]
2141    async fn typed_prompt_invalid_tool_call_hook_can_retry_and_parse_response() {
2142        let model = MockCompletionModel::new([
2143            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2144            MockTurn::text(r#"{"value":"retried"}"#),
2145        ]);
2146        let recorded = model.clone();
2147        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2148
2149        let response = agent
2150            .prompt_typed::<TypedAnswer>("return typed json")
2151            .with_hook(RetryDefaultApiHook)
2152            .max_invalid_tool_call_retries(1)
2153            .max_turns(3)
2154            .await
2155            .expect("typed prompt should retry invalid tool call");
2156
2157        assert_eq!(
2158            response,
2159            TypedAnswer {
2160                value: "retried".to_string()
2161            }
2162        );
2163        assert_eq!(recorded.request_count(), 2);
2164    }
2165
2166    #[tokio::test]
2167    async fn typed_prompt_invalid_tool_call_retry_budget_exhaustion_fails() {
2168        let model = MockCompletionModel::new([
2169            MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 2, "y": 3})),
2170            MockTurn::text(r#"{"value":"should not be requested"}"#),
2171        ]);
2172        let recorded = model.clone();
2173        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2174
2175        let err = agent
2176            .prompt_typed::<TypedAnswer>("return typed json")
2177            .with_hook(RetryDefaultApiHook)
2178            .max_invalid_tool_call_retries(0)
2179            .max_turns(3)
2180            .await
2181            .expect_err("typed prompt should fail when retry budget is exhausted");
2182
2183        match err {
2184            StructuredOutputError::PromptError(err) => match *err {
2185                PromptError::UnknownToolCall { tool_name, .. } => {
2186                    assert_eq!(tool_name, "default_api");
2187                }
2188                other => panic!("expected UnknownToolCall, got {other:?}"),
2189            },
2190            other => panic!("expected prompt error, got {other:?}"),
2191        }
2192        assert_eq!(recorded.request_count(), 1);
2193    }
2194
2195    #[tokio::test]
2196    async fn invalid_specific_tool_choice_fails_before_non_streaming_provider_request() {
2197        let model = MockCompletionModel::text("should not be requested");
2198        let recorded = model.clone();
2199        let agent = AgentBuilder::new(model)
2200            .tool(MockAddTool)
2201            .tool_choice(ToolChoice::Specific {
2202                function_names: vec!["missing".to_string()],
2203            })
2204            .build();
2205
2206        let err = agent
2207            .prompt("use the missing tool")
2208            .await
2209            .expect_err("invalid ToolChoice::Specific should fail before provider request");
2210
2211        match err {
2212            PromptError::CompletionError(CompletionError::RequestError(err)) => {
2213                let msg = err.to_string();
2214                assert!(msg.contains("missing"), "got: {msg}");
2215                assert!(msg.contains("add"), "got: {msg}");
2216            }
2217            other => panic!("expected CompletionError::RequestError, got {other:?}"),
2218        }
2219        assert_eq!(recorded.request_count(), 0);
2220    }
2221
2222    #[tokio::test]
2223    async fn allowed_specific_tool_call_executes_normally() {
2224        let model = MockCompletionModel::new([
2225            MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})),
2226            MockTurn::text("done"),
2227        ]);
2228        let recorded = model.clone();
2229        let agent = AgentBuilder::new(model)
2230            .tool(MockAddTool)
2231            .tool_choice(ToolChoice::Specific {
2232                function_names: vec!["add".to_string()],
2233            })
2234            .build();
2235
2236        let response = agent
2237            .prompt("use the allowed tool")
2238            .max_turns(3)
2239            .await
2240            .expect("allowed specific tool should execute");
2241
2242        assert_eq!(response, "done");
2243        assert_eq!(recorded.request_count(), 2);
2244    }
2245
2246    #[tokio::test]
2247    async fn prompt_request_stops_cleanly_on_empty_terminal_turn() {
2248        let first_call_usage = Usage {
2249            input_tokens: 1,
2250            output_tokens: 1,
2251            total_tokens: 2,
2252            cached_input_tokens: 0,
2253            cache_creation_input_tokens: 0,
2254            tool_use_prompt_tokens: 0,
2255            reasoning_tokens: 0,
2256        };
2257        let second_call_usage = Usage {
2258            input_tokens: 1,
2259            output_tokens: 1,
2260            total_tokens: 2,
2261            cached_input_tokens: 0,
2262            cache_creation_input_tokens: 0,
2263            tool_use_prompt_tokens: 0,
2264            reasoning_tokens: 0,
2265        };
2266        let model = MockCompletionModel::new([
2267            MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2}))
2268                .with_call_id("call_1")
2269                .with_usage(first_call_usage),
2270            MockTurn::text("").with_usage(second_call_usage),
2271        ]);
2272        let agent = AgentBuilder::new(model).tool(MockAddTool).build();
2273
2274        let response = agent
2275            .prompt("do tool work")
2276            .max_turns(3)
2277            .extended_details()
2278            .await
2279            .expect("empty terminal turn should not error");
2280
2281        assert!(response.output.is_empty());
2282        assert_eq!(
2283            response.usage,
2284            Usage {
2285                input_tokens: 2,
2286                output_tokens: 2,
2287                total_tokens: 4,
2288                cached_input_tokens: 0,
2289                cache_creation_input_tokens: 0,
2290                tool_use_prompt_tokens: 0,
2291                reasoning_tokens: 0,
2292            }
2293        );
2294        assert_eq!(
2295            response.completion_calls(),
2296            &[
2297                CompletionCall::new(0, first_call_usage),
2298                CompletionCall::new(1, second_call_usage)
2299            ]
2300        );
2301
2302        let history = response
2303            .messages
2304            .expect("extended response should include history");
2305        assert_eq!(history.len(), 3);
2306        assert!(matches!(
2307            history.first(),
2308            Some(Message::User { content })
2309                if matches!(
2310                    content.first(),
2311                    UserContent::Text(text) if text.text == "do tool work"
2312                )
2313        ));
2314        assert!(history.iter().any(|message| matches!(
2315            message,
2316            Message::Assistant { content, .. }
2317                if matches!(
2318                    content.first(),
2319                    AssistantContent::ToolCall(tool_call)
2320                        if tool_call.id == "tool_call_1"
2321                            && tool_call.call_id.as_deref() == Some("call_1")
2322                )
2323        )));
2324        assert!(history.iter().any(|message| matches!(
2325            message,
2326            Message::User { content }
2327                if matches!(
2328                    content.first(),
2329                    UserContent::ToolResult(tool_result)
2330                        if tool_result.id == "tool_call_1"
2331                            && tool_result.call_id.as_deref() == Some("call_1")
2332                )
2333        )));
2334        assert!(!history.iter().any(|message| matches!(
2335            message,
2336            Message::Assistant { content, .. }
2337                if content.iter().any(|item| matches!(
2338                    item,
2339                    AssistantContent::Text(text) if text.text.is_empty()
2340                ))
2341        )));
2342        let requests = agent.model.requests();
2343        assert_eq!(requests.len(), 2);
2344        validate_follow_up_tool_history(&requests[1]);
2345    }
2346
2347    #[tokio::test]
2348    async fn prompt_request_concatenates_text_blocks_without_inserted_newlines() {
2349        let model = MockCompletionModel::new([MockTurn::from_contents([
2350            AssistantContent::Text(Text::new("According to the document, ")),
2351            AssistantContent::Text(Text::new("the grass is green")),
2352            AssistantContent::Text(Text::new(" and the sky is blue.")),
2353        ])
2354        .expect("mock response should contain text blocks")]);
2355        let agent = AgentBuilder::new(model).build();
2356
2357        let response = agent
2358            .prompt("answer with cited spans")
2359            .await
2360            .expect("prompt should succeed");
2361
2362        assert_eq!(
2363            response,
2364            "According to the document, the grass is green and the sky is blue."
2365        );
2366    }
2367
2368    #[tokio::test]
2369    async fn prompt_request_preserves_metadata_only_text_turn_in_history() {
2370        let metadata = json!({
2371            "citations": [{
2372                "type": "web_search_result_location",
2373                "cited_text": "Claude Shannon was born in 1916.",
2374                "url": "https://example.com/shannon",
2375                "title": null,
2376                "encrypted_index": "encrypted-reference"
2377            }]
2378        });
2379        let model =
2380            MockCompletionModel::new([MockTurn::from_content(AssistantContent::Text(Text {
2381                text: String::new(),
2382                additional_params: Some(metadata.clone()),
2383            }))]);
2384        let agent = AgentBuilder::new(model).build();
2385
2386        let response = agent
2387            .prompt("answer with cited metadata")
2388            .extended_details()
2389            .await
2390            .expect("metadata-only text turn should succeed");
2391
2392        assert!(response.output.is_empty());
2393        let history = response
2394            .messages
2395            .expect("extended response should include history");
2396        assert!(history.iter().any(|message| matches!(
2397            message,
2398            Message::Assistant { content, .. }
2399                if matches!(
2400                    content.first(),
2401                    AssistantContent::Text(text)
2402                        if text.text.is_empty()
2403                            && text.additional_params.as_ref() == Some(&metadata)
2404                )
2405        )));
2406    }
2407
2408    // ----- Conversation memory integration tests -----
2409
2410    use crate::memory::{ConversationMemory, InMemoryConversationMemory};
2411
2412    #[tokio::test]
2413    async fn memory_loads_into_request_history() {
2414        let memory = InMemoryConversationMemory::new();
2415        memory
2416            .append(
2417                "thread-1",
2418                vec![Message::user("hello"), Message::assistant("hi there")],
2419            )
2420            .await
2421            .unwrap();
2422
2423        let model = MockCompletionModel::text("ack");
2424        let recorded = model.clone();
2425
2426        let agent = AgentBuilder::new(model).memory(memory).build();
2427        let _ = agent
2428            .prompt("ping")
2429            .conversation("thread-1")
2430            .await
2431            .expect("prompt should succeed");
2432
2433        let received = recorded.requests()[0]
2434            .chat_history
2435            .iter()
2436            .cloned()
2437            .collect::<Vec<_>>();
2438        assert_eq!(
2439            received.len(),
2440            3,
2441            "loaded memory (2) + current prompt should appear in request: {received:?}"
2442        );
2443    }
2444
2445    #[tokio::test]
2446    async fn memory_appends_full_turn_after_success() {
2447        let memory = InMemoryConversationMemory::new();
2448        let model = MockCompletionModel::text("ack");
2449        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2450
2451        let _ = agent
2452            .prompt("hello")
2453            .conversation("t1")
2454            .await
2455            .expect("prompt should succeed");
2456
2457        let stored = memory.load("t1").await.unwrap();
2458        assert_eq!(stored.len(), 2, "user prompt + assistant response saved");
2459    }
2460
2461    #[tokio::test]
2462    async fn explicit_with_history_overrides_memory() {
2463        let memory = CountingMemory::default();
2464        memory
2465            .inner()
2466            .append("t1", vec![Message::user("from-memory")])
2467            .await
2468            .unwrap();
2469
2470        let model = MockCompletionModel::text("ack");
2471        let recorded = model.clone();
2472
2473        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2474        let _ = agent
2475            .prompt("hello")
2476            .conversation("t1")
2477            .with_history(vec![Message::user("from-caller")])
2478            .await
2479            .expect("prompt should succeed");
2480
2481        assert_eq!(memory.load_count(), 0, "load skipped");
2482        let appends = memory.append_count();
2483        assert_eq!(appends, 0, "append skipped");
2484
2485        let received = recorded.requests()[0]
2486            .chat_history
2487            .iter()
2488            .cloned()
2489            .collect::<Vec<_>>();
2490        assert_eq!(received.len(), 2, "caller history (1) + current prompt");
2491        assert!(matches!(
2492            received.first(),
2493            Some(Message::User { content })
2494                if matches!(content.first(), UserContent::Text(t) if t.text == "from-caller")
2495        ));
2496    }
2497
2498    #[tokio::test]
2499    async fn memory_unchanged_on_provider_error() {
2500        let memory = InMemoryConversationMemory::new();
2501        let model = MockCompletionModel::new([MockTurn::error("boom")]);
2502
2503        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2504        let result = agent.prompt("hello").conversation("t1").await;
2505        assert!(result.is_err());
2506
2507        let stored = memory.load("t1").await.unwrap();
2508        assert!(stored.is_empty(), "no append on error");
2509    }
2510
2511    #[tokio::test]
2512    async fn missing_conversation_id_behaves_as_no_memory() {
2513        let memory = CountingMemory::default();
2514        let model = MockCompletionModel::text("ack");
2515        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
2516
2517        let _ = agent.prompt("hello").await.expect("prompt should succeed");
2518
2519        assert_eq!(memory.load_count(), 0);
2520        assert_eq!(memory.append_count(), 0);
2521    }
2522
2523    #[tokio::test]
2524    async fn default_conversation_id_is_used_when_none_per_request() {
2525        let memory = InMemoryConversationMemory::new();
2526        let model = MockCompletionModel::text("ack");
2527        let agent = AgentBuilder::new(model)
2528            .memory(memory.clone())
2529            .conversation_id("default-thread")
2530            .build();
2531
2532        let _ = agent.prompt("hello").await.expect("prompt should succeed");
2533        let stored = memory.load("default-thread").await.unwrap();
2534        assert_eq!(stored.len(), 2);
2535    }
2536
2537    #[tokio::test]
2538    async fn with_filter_truncates_loaded_history() {
2539        let memory = InMemoryConversationMemory::new()
2540            .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
2541        memory
2542            .append(
2543                "t1",
2544                vec![
2545                    Message::user("1"),
2546                    Message::assistant("2"),
2547                    Message::user("3"),
2548                    Message::assistant("4"),
2549                ],
2550            )
2551            .await
2552            .unwrap();
2553
2554        let model = MockCompletionModel::text("ack");
2555        let recorded = model.clone();
2556        let agent = AgentBuilder::new(model).memory(memory).build();
2557
2558        let _ = agent
2559            .prompt("ping")
2560            .conversation("t1")
2561            .await
2562            .expect("prompt should succeed");
2563
2564        let received = recorded.requests()[0]
2565            .chat_history
2566            .iter()
2567            .cloned()
2568            .collect::<Vec<_>>();
2569        assert_eq!(
2570            received.len(),
2571            3,
2572            "window-truncated history (2) + current prompt"
2573        );
2574    }
2575
2576    #[tokio::test]
2577    async fn without_memory_disables_for_request() {
2578        let memory = CountingMemory::default();
2579        let model = MockCompletionModel::text("ack");
2580        let agent = AgentBuilder::new(model)
2581            .memory(memory.clone())
2582            .conversation_id("t1")
2583            .build();
2584
2585        let _ = agent
2586            .prompt("hello")
2587            .without_memory()
2588            .await
2589            .expect("prompt should succeed");
2590
2591        assert_eq!(memory.load_count(), 0);
2592        assert_eq!(memory.append_count(), 0);
2593    }
2594
2595    #[tokio::test]
2596    async fn memory_load_error_surfaces_as_prompt_error() {
2597        let model = MockCompletionModel::text("ack");
2598        let agent = AgentBuilder::new(model)
2599            .memory(FailingMemory::default())
2600            .build();
2601        let result = agent.prompt("hello").conversation("t1").await;
2602
2603        match result {
2604            Err(PromptError::CompletionError(CompletionError::RequestError(err))) => {
2605                let msg = format!("{err}");
2606                assert!(msg.contains("load boom"), "got: {msg}");
2607            }
2608            other => panic!("expected PromptError::CompletionError(RequestError), got {other:?}"),
2609        }
2610    }
2611
2612    #[tokio::test]
2613    async fn memory_append_error_does_not_drop_response() {
2614        let model = MockCompletionModel::text("ack");
2615        let agent = AgentBuilder::new(model)
2616            .memory(AppendFailingMemory::default())
2617            .build();
2618        let response: String = agent
2619            .prompt("hello")
2620            .conversation("t1")
2621            .await
2622            .expect("append failure must not block successful completion");
2623
2624        assert!(!response.is_empty());
2625    }
2626}