Skip to main content

rig_core/agent/run/
mod.rs

1//! A sans-IO, steppable, serializable state machine for the agent prompt loop.
2//!
3//! [`AgentRun`] owns every *decision* the agent loop makes — turn counting,
4//! tool-call validation, invalid tool-call recovery, chat-history threading,
5//! usage aggregation and final response construction — without performing any
6//! IO itself. A driver advances the machine by calling [`AgentRun::next_step`]
7//! and acting on the returned [`AgentRunStep`]:
8//!
9//! - [`AgentRunStep::CallModel`]: send a completion request to the model and
10//!   feed the result back via [`AgentRun::model_response`].
11//! - [`AgentRunStep::CallTools`]: execute the listed tool calls (with whatever
12//!   concurrency the driver chooses) and feed the results back via
13//!   [`AgentRun::tool_results`].
14//! - [`AgentRunStep::Done`]: the run is complete.
15//!
16//! Because the machine never awaits anything, it is runtime-agnostic and the
17//! whole run state is `Serialize + Deserialize`: a driver can serialize a run
18//! between steps (for example while tool calls are pending), persist it, and
19//! resume it later in another process. Note that serialized run state embeds
20//! the full conversation accumulated so far — persisting it inherits whatever
21//! sensitivity the conversation content has — and the serialization format
22//! carries no cross-version stability guarantee yet: resume with the same rig
23//! version that suspended the run.
24//!
25//! [`crate::completion::Prompt::prompt`] on [`crate::agent::Agent`] drives
26//! this machine internally; the same machine can be driven by hand for custom
27//! control flow:
28//!
29//! ```rust,no_run
30//! use rig_core::agent::run::{AgentRun, AgentRunStep, ModelTurn, ModelTurnOutcome};
31//!
32//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
33//! let mut run = AgentRun::new("What is 2+2?").max_turns(3);
34//! loop {
35//!     match run.next_step()? {
36//!         AgentRunStep::CallModel { prompt, history, .. } => {
37//!             // Send `prompt` + `history` to a model, then:
38//!             // run.model_response(ModelTurn { ... })?;
39//!             # let _ = (prompt, history);
40//!             # break;
41//!         }
42//!         AgentRunStep::CallTools { calls } => {
43//!             // Execute `calls`, then: run.tool_results(results)?;
44//!             # let _ = calls;
45//!         }
46//!         AgentRunStep::Done(response) => {
47//!             println!("{}", response.output);
48//!             break;
49//!         }
50//!     }
51//! }
52//! # Ok(())
53//! # }
54//! ```
55
56pub mod streamed;
57
58use std::collections::{BTreeMap, BTreeSet};
59
60use serde::{Deserialize, Serialize};
61
62use crate::{
63    OneOrMany,
64    agent::prompt_request::{
65        CompletionCall, PromptResponse, TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER,
66        assistant_text_from_choice, build_full_history, build_history_for_request,
67        hooks::{InvalidToolCallContext, InvalidToolCallHookAction},
68        invalid_tool_retry_user_message, is_empty_assistant_turn, tool_result_user_content,
69    },
70    completion::{Message, PromptError, Usage},
71    json_utils,
72    message::{AssistantContent, ToolCall, ToolChoice, ToolResult, ToolResultContent, UserContent},
73};
74
75pub use streamed::{
76    PartialStreamedTurn, StreamedInvalidToolCall, StreamedResolution, StreamedTurn,
77    StreamedTurnAssembler, StreamedTurnEvent,
78};
79
80/// What a driver must do next to advance an [`AgentRun`].
81///
82/// Deliberately exhaustive: a driver must handle every step, so adding a
83/// variant is a breaking change by design.
84#[derive(Debug, Clone)]
85pub enum AgentRunStep {
86    /// Send a completion request to the model and feed the result back via
87    /// [`AgentRun::model_response`].
88    CallModel {
89        /// The prompt message for this turn (the latest message in the run).
90        prompt: Message,
91        /// The chat history preceding `prompt`: the caller-provided input
92        /// history followed by messages accumulated by earlier turns.
93        history: Vec<Message>,
94        /// One-based index of this model call within the run.
95        turn: usize,
96    },
97    /// Execute these tool calls and feed the results back via
98    /// [`AgentRun::tool_results`].
99    CallTools {
100        /// The tool calls of the current assistant turn, in emission order.
101        calls: Vec<PendingToolCall>,
102    },
103    /// The run is complete.
104    Done(PromptResponse),
105}
106
107/// One tool call awaiting execution by the driver.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109#[non_exhaustive]
110pub struct PendingToolCall {
111    /// The tool call emitted by the model (with any repaired tool name applied).
112    pub tool_call: ToolCall,
113    /// Pre-resolved result for tool calls suppressed by invalid tool-call
114    /// recovery. When set, the driver must return this content as the tool
115    /// result without executing the tool or invoking tool hooks.
116    pub preresolved_result: Option<UserContent>,
117    /// Rig-generated identifier correlating this call's stream items, when
118    /// the call arrived via a streamed turn. Persisted with the run state so
119    /// a resumed process keeps emitting the IDs consumers already saw in
120    /// tool-call deltas. Drivers generate a fresh ID when absent.
121    #[serde(default)]
122    pub internal_call_id: Option<String>,
123}
124
125/// A completed model turn fed back to [`AgentRun::model_response`].
126#[derive(Debug, Clone, Serialize, Deserialize)]
127#[non_exhaustive]
128pub struct ModelTurn {
129    /// Provider-assigned assistant message ID, when available.
130    pub message_id: Option<String>,
131    /// The assistant content returned by the model.
132    pub choice: OneOrMany<AssistantContent>,
133    /// Token usage reported by the provider for this completion request.
134    pub usage: Usage,
135    /// Executable Rig tools advertised to the provider for this turn.
136    pub executable_tool_names: BTreeSet<String>,
137    /// Tools allowed by the active [`ToolChoice`] for this turn.
138    pub allowed_tool_names: BTreeSet<String>,
139}
140
141impl ModelTurn {
142    /// Create a model turn from response parts and the tool names advertised
143    /// for the turn.
144    pub fn new(
145        message_id: Option<String>,
146        choice: OneOrMany<AssistantContent>,
147        usage: Usage,
148        executable_tool_names: BTreeSet<String>,
149        allowed_tool_names: BTreeSet<String>,
150    ) -> Self {
151        Self {
152            message_id,
153            choice,
154            usage,
155            executable_tool_names,
156            allowed_tool_names,
157        }
158    }
159}
160
161/// Result of feeding a model turn (or an invalid tool-call resolution) into
162/// the machine.
163///
164/// Deliberately exhaustive: a driver must handle every outcome, so adding a
165/// variant is a breaking change by design.
166#[derive(Debug)]
167pub enum ModelTurnOutcome {
168    /// The turn was accepted. Unless `response_hook_suppressed` is set, the
169    /// driver should run its completion-response hook now, then call
170    /// [`AgentRun::next_step`].
171    ///
172    /// `response_hook_suppressed` is set when invalid tool-call recovery
173    /// (repair or skip) modified the turn, matching the agent loop's behavior
174    /// of not invoking `on_completion_response` for recovered turns.
175    Continue {
176        /// Whether the driver should suppress its completion-response hook.
177        response_hook_suppressed: bool,
178    },
179    /// The model emitted a tool call that is unknown or disallowed for this
180    /// turn. The driver must decide how to recover (typically by asking its
181    /// invalid tool-call hook) and answer via
182    /// [`AgentRun::resolve_invalid_tool_call`].
183    NeedsResolution(InvalidToolCallContext),
184    /// The turn was rolled back with corrective feedback appended to the
185    /// history. Call [`AgentRun::next_step`] to obtain the retry
186    /// [`AgentRunStep::CallModel`].
187    TurnRetried,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
191struct ResolvingState {
192    message_id: Option<String>,
193    /// The unmodified model output, used for diagnostic histories and retry
194    /// messages (repairs are never reflected in those).
195    original_choice: OneOrMany<AssistantContent>,
196    /// Working copy of the assistant content; repairs rename tool calls here.
197    items: Vec<AssistantContent>,
198    /// Index of the next item to validate.
199    next_index: usize,
200    executable_tool_names: BTreeSet<String>,
201    allowed_tool_names: BTreeSet<String>,
202    /// Synthetic tool results for skipped tool calls, keyed by tool call ID.
203    skipped: BTreeMap<String, UserContent>,
204    recovered: bool,
205    any_skipped: bool,
206    has_tool_calls: bool,
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
210struct TurnState {
211    message_id: Option<String>,
212    items: Vec<AssistantContent>,
213    has_tool_calls: bool,
214    skipped: BTreeMap<String, UserContent>,
215    /// `(tool_call_id, internal_call_id)` pairs for streamed turns, in
216    /// emission order; empty for non-streamed turns.
217    #[serde(default)]
218    internal_call_ids: Vec<(String, String)>,
219}
220
221#[derive(Debug, Clone, Serialize, Deserialize)]
222enum RunState {
223    /// Ready to emit [`AgentRunStep::CallModel`].
224    PreparingRequest,
225    /// Waiting for [`AgentRun::model_response`].
226    AwaitingModel,
227    /// Scanning the model turn's tool calls for validity; may be waiting for
228    /// [`AgentRun::resolve_invalid_tool_call`].
229    ResolvingToolCalls(Box<ResolvingState>),
230    /// The turn was accepted; ready to emit [`AgentRunStep::CallTools`] or
231    /// [`AgentRunStep::Done`].
232    AwaitingAdvance(Box<TurnState>),
233    /// Waiting for [`AgentRun::tool_results`] for these pending tool calls.
234    /// Carrying the calls in the state keeps a serialized run self-contained:
235    /// a resumed process re-obtains them from [`AgentRun::next_step`].
236    ExecutingTools(Vec<PendingToolCall>),
237    /// Terminal: the run completed successfully.
238    Done(Box<PromptResponse>),
239    /// Terminal: the run returned an error.
240    Failed,
241}
242
243/// The sans-IO agent loop state machine. See the [module docs](self) for the
244/// driving protocol.
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct AgentRun {
247    max_turns: usize,
248    max_invalid_tool_call_retries: usize,
249    tool_choice: Option<ToolChoice>,
250    chat_history: Option<Vec<Message>>,
251    new_messages: Vec<Message>,
252    current_turn: usize,
253    usage: Usage,
254    completion_calls: Vec<CompletionCall>,
255    completion_call_index: usize,
256    invalid_tool_call_retries: usize,
257    /// Set while a streamed turn rollback awaits its completion-call record;
258    /// see [`AgentRun::record_streamed_completion_call`].
259    #[serde(default)]
260    rollback_pending: bool,
261    /// Set once the current streamed model turn's completion call has been
262    /// recorded, rejecting duplicate records; reset when the next
263    /// [`AgentRunStep::CallModel`] is emitted.
264    #[serde(default)]
265    streamed_completion_call_recorded: bool,
266    state: RunState,
267}
268
269impl AgentRun {
270    /// Create a run for one prompt with no input history, no multi-turn depth
271    /// and no invalid tool-call retries.
272    pub fn new(prompt: impl Into<Message>) -> Self {
273        Self {
274            max_turns: 0,
275            max_invalid_tool_call_retries: 0,
276            tool_choice: None,
277            chat_history: None,
278            new_messages: vec![prompt.into()],
279            current_turn: 0,
280            usage: Usage::new(),
281            completion_calls: Vec::new(),
282            completion_call_index: 0,
283            invalid_tool_call_retries: 0,
284            rollback_pending: false,
285            streamed_completion_call_recorded: false,
286            state: RunState::PreparingRequest,
287        }
288    }
289
290    /// Set the input chat history preceding the prompt.
291    pub fn with_history(mut self, history: Vec<Message>) -> Self {
292        self.chat_history = Some(history);
293        self
294    }
295
296    /// Set the maximum multi-turn depth. Exceeding it makes
297    /// [`AgentRun::next_step`] return [`PromptError::MaxTurnsError`].
298    pub fn max_turns(mut self, max_turns: usize) -> Self {
299        self.max_turns = max_turns;
300        self
301    }
302
303    /// Set the retry budget for [`InvalidToolCallHookAction::Retry`]
304    /// resolutions. Invalid tool-call retries also consume multi-turn depth.
305    pub fn max_invalid_tool_call_retries(mut self, retries: usize) -> Self {
306        self.max_invalid_tool_call_retries = retries;
307        self
308    }
309
310    /// Set the tool choice active for this run. Used to reject
311    /// [`InvalidToolCallHookAction::Skip`] resolutions under
312    /// [`ToolChoice::None`] and reported in invalid tool-call contexts.
313    pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
314        self.tool_choice = Some(tool_choice);
315        self
316    }
317
318    /// Aggregated token usage across all completed model calls so far.
319    pub fn usage(&self) -> Usage {
320        self.usage
321    }
322
323    /// Number of model calls emitted so far (including retries).
324    pub fn turn(&self) -> usize {
325        self.current_turn
326    }
327
328    /// Details for each completed model call so far.
329    pub fn completion_calls(&self) -> &[CompletionCall] {
330        &self.completion_calls
331    }
332
333    /// Messages accumulated by this run (the prompt plus all assistant turns
334    /// and tool results), excluding the input history.
335    pub fn messages(&self) -> &[Message] {
336        &self.new_messages
337    }
338
339    /// The full conversation: input history followed by [`Self::messages`].
340    pub fn full_history(&self) -> Vec<Message> {
341        build_full_history(self.chat_history.as_deref(), self.new_messages.clone())
342    }
343
344    /// Whether the run reached [`AgentRunStep::Done`].
345    pub fn is_done(&self) -> bool {
346        matches!(self.state, RunState::Done(_))
347    }
348
349    /// The final response once the run is done, without cloning it.
350    /// [`AgentRun::next_step`] in the done state returns an owned clone
351    /// (including the full accumulated message history); prefer this when
352    /// only inspecting the result.
353    pub fn response(&self) -> Option<&PromptResponse> {
354        match &self.state {
355            RunState::Done(response) => Some(response),
356            _ => None,
357        }
358    }
359
360    /// Build the cancellation error a driver should return when one of its
361    /// hooks terminates the run, carrying the current full history.
362    pub fn cancel_error(&self, reason: impl Into<String>) -> PromptError {
363        PromptError::prompt_cancelled(self.full_history(), reason)
364    }
365
366    /// The invalid tool call currently awaiting
367    /// [`AgentRun::resolve_invalid_tool_call`], if any. Useful to re-derive
368    /// the resolution context after deserializing a suspended run.
369    pub fn pending_invalid_tool_call(&self) -> Option<InvalidToolCallContext> {
370        let RunState::ResolvingToolCalls(resolving) = &self.state else {
371            return None;
372        };
373        let AssistantContent::ToolCall(tool_call) = resolving.items.get(resolving.next_index)?
374        else {
375            return None;
376        };
377        if resolving
378            .allowed_tool_names
379            .contains(&tool_call.function.name)
380        {
381            return None;
382        }
383
384        Some(InvalidToolCallContext {
385            tool_name: tool_call.function.name.clone(),
386            tool_call_id: Some(tool_call.id.clone()),
387            internal_call_id: None,
388            args: Some(json_utils::value_to_json_string(
389                &tool_call.function.arguments,
390            )),
391            available_tools: resolving.executable_tool_names.iter().cloned().collect(),
392            allowed_tools: resolving.allowed_tool_names.iter().cloned().collect(),
393            tool_choice: self.tool_choice.clone(),
394            chat_history: self.diagnostic_history(resolving),
395            is_streaming: false,
396        })
397    }
398
399    /// Advance the machine and return the next action for the driver.
400    ///
401    /// # Errors
402    /// - [`PromptError::MaxTurnsError`] when the multi-turn depth is exhausted.
403    /// - [`PromptError::PromptCancelled`] when the machine is driven out of
404    ///   protocol (for example, calling this while a model response is
405    ///   pending).
406    pub fn next_step(&mut self) -> Result<AgentRunStep, PromptError> {
407        match std::mem::replace(&mut self.state, RunState::Failed) {
408            RunState::PreparingRequest => {
409                let Some((prompt_ref, history_for_turn)) = self.new_messages.split_last() else {
410                    return Err(PromptError::prompt_cancelled(
411                        self.full_history(),
412                        "prompt loop lost its pending prompt",
413                    ));
414                };
415                let prompt = prompt_ref.clone();
416
417                if self.current_turn > self.max_turns + 1 {
418                    return Err(PromptError::MaxTurnsError {
419                        max_turns: self.max_turns,
420                        chat_history: self.full_history().into(),
421                        prompt: prompt.into(),
422                    });
423                }
424
425                let history =
426                    build_history_for_request(self.chat_history.as_deref(), history_for_turn);
427                self.current_turn += 1;
428                self.rollback_pending = false;
429                self.streamed_completion_call_recorded = false;
430                self.state = RunState::AwaitingModel;
431                Ok(AgentRunStep::CallModel {
432                    prompt,
433                    history,
434                    turn: self.current_turn,
435                })
436            }
437            RunState::AwaitingAdvance(turn_state) => {
438                let TurnState {
439                    message_id,
440                    items,
441                    has_tool_calls,
442                    skipped,
443                    mut internal_call_ids,
444                } = *turn_state;
445                let Some(choice) = OneOrMany::from_iter_optional(items.clone()) else {
446                    return Err(PromptError::prompt_cancelled(
447                        self.full_history(),
448                        "model turn lost its assistant content",
449                    ));
450                };
451
452                if !is_empty_assistant_turn(&choice) {
453                    self.new_messages.push(Message::Assistant {
454                        id: message_id,
455                        content: choice.clone(),
456                    });
457                }
458
459                if has_tool_calls {
460                    let calls: Vec<PendingToolCall> = items
461                        .iter()
462                        .filter_map(|item| match item {
463                            AssistantContent::ToolCall(tool_call) => {
464                                // Consume pairs positionally so duplicate
465                                // provider IDs within one turn stay
466                                // distinguishable.
467                                let internal_call_id = internal_call_ids
468                                    .iter()
469                                    .position(|(id, _)| *id == tool_call.id)
470                                    .map(|index| internal_call_ids.remove(index).1);
471                                Some(PendingToolCall {
472                                    tool_call: tool_call.clone(),
473                                    preresolved_result: skipped.get(&tool_call.id).cloned(),
474                                    internal_call_id,
475                                })
476                            }
477                            _ => None,
478                        })
479                        .collect();
480                    self.state = RunState::ExecutingTools(calls.clone());
481                    Ok(AgentRunStep::CallTools { calls })
482                } else {
483                    let response =
484                        PromptResponse::new(assistant_text_from_choice(&choice), self.usage)
485                            .with_messages(self.new_messages.clone())
486                            .with_completion_calls(self.completion_calls.clone());
487                    self.state = RunState::Done(Box::new(response.clone()));
488                    Ok(AgentRunStep::Done(response))
489                }
490            }
491            RunState::ExecutingTools(calls) => {
492                // Idempotent, like Done: a process resuming a serialized run
493                // re-obtains the pending tool calls from the state itself.
494                let step = AgentRunStep::CallTools {
495                    calls: calls.clone(),
496                };
497                self.state = RunState::ExecutingTools(calls);
498                Ok(step)
499            }
500            RunState::Done(response) => {
501                let step = AgentRunStep::Done((*response).clone());
502                self.state = RunState::Done(response);
503                Ok(step)
504            }
505            state @ (RunState::AwaitingModel | RunState::ResolvingToolCalls(_)) => {
506                let reason = match &state {
507                    RunState::AwaitingModel => {
508                        "next_step called while a model response is pending; feed it via model_response first"
509                    }
510                    _ => {
511                        "next_step called while an invalid tool-call resolution is pending; answer it via resolve_invalid_tool_call first"
512                    }
513                };
514                self.state = state;
515                Err(self.protocol_violation(reason))
516            }
517            RunState::Failed => Err(self.protocol_violation(
518                "next_step called after the run already failed or was misdriven",
519            )),
520        }
521    }
522
523    /// Feed the model's response for the pending [`AgentRunStep::CallModel`].
524    ///
525    /// Records the completion call and aggregates usage, then validates the
526    /// turn's tool calls against the advertised tool names. See
527    /// [`ModelTurnOutcome`] for what the driver must do next.
528    pub fn model_response(&mut self, turn: ModelTurn) -> Result<ModelTurnOutcome, PromptError> {
529        if !matches!(self.state, RunState::AwaitingModel) {
530            return Err(
531                self.protocol_violation("model_response called without a pending CallModel step")
532            );
533        }
534        if self.streamed_completion_call_recorded {
535            return Err(self.protocol_violation(
536                "model_response called after record_streamed_completion_call for the same turn; feed streamed turns via streamed_turn",
537            ));
538        }
539
540        self.completion_calls
541            .push(CompletionCall::new(self.completion_call_index, turn.usage));
542        self.completion_call_index += 1;
543        self.usage += turn.usage;
544
545        let items: Vec<AssistantContent> = turn.choice.iter().cloned().collect();
546        let has_tool_calls = items
547            .iter()
548            .any(|item| matches!(item, AssistantContent::ToolCall(_)));
549
550        self.state = RunState::ResolvingToolCalls(Box::new(ResolvingState {
551            message_id: turn.message_id,
552            original_choice: turn.choice,
553            items,
554            next_index: 0,
555            executable_tool_names: turn.executable_tool_names,
556            allowed_tool_names: turn.allowed_tool_names,
557            skipped: BTreeMap::new(),
558            recovered: false,
559            any_skipped: false,
560            has_tool_calls,
561        }));
562
563        self.advance_resolution()
564    }
565
566    /// Answer a pending [`ModelTurnOutcome::NeedsResolution`].
567    ///
568    /// Applies the agent loop's recovery semantics:
569    /// - [`InvalidToolCallHookAction::Fail`] fails the run with
570    ///   [`PromptError::UnknownToolCall`].
571    /// - [`InvalidToolCallHookAction::Retry`] rolls the turn back with
572    ///   corrective feedback while budget remains, consuming multi-turn depth.
573    /// - [`InvalidToolCallHookAction::Repair`] renames the tool call; the
574    ///   repaired name is revalidated against the allowed tools.
575    /// - [`InvalidToolCallHookAction::Skip`] records a synthetic tool result
576    ///   and suppresses execution of every tool call in the turn. Rejected
577    ///   under [`ToolChoice::None`].
578    pub fn resolve_invalid_tool_call(
579        &mut self,
580        action: InvalidToolCallHookAction,
581    ) -> Result<ModelTurnOutcome, PromptError> {
582        // Take the resolving state; rejection paths below restore it so an
583        // out-of-protocol call does not corrupt a drivable run.
584        let mut resolving = match std::mem::replace(&mut self.state, RunState::Failed) {
585            RunState::ResolvingToolCalls(resolving) => resolving,
586            other => {
587                self.state = other;
588                return Err(self.protocol_violation(
589                    "resolve_invalid_tool_call called without a pending invalid tool call",
590                ));
591            }
592        };
593        let tool_call = match resolving.items.get(resolving.next_index) {
594            Some(AssistantContent::ToolCall(tool_call))
595                if !resolving
596                    .allowed_tool_names
597                    .contains(&tool_call.function.name) =>
598            {
599                tool_call.clone()
600            }
601            _ => {
602                self.state = RunState::ResolvingToolCalls(resolving);
603                return Err(self.protocol_violation(
604                    "resolve_invalid_tool_call called without a pending invalid tool call",
605                ));
606            }
607        };
608
609        let diagnostic_history = self.diagnostic_history(&resolving);
610        let executable_tool_names: Vec<String> =
611            resolving.executable_tool_names.iter().cloned().collect();
612        let allowed_tool_names: Vec<String> =
613            resolving.allowed_tool_names.iter().cloned().collect();
614
615        match action {
616            InvalidToolCallHookAction::Fail => Err(PromptError::UnknownToolCall {
617                tool_name: tool_call.function.name,
618                available_tools: executable_tool_names,
619                allowed_tools: allowed_tool_names,
620                chat_history: Box::new(diagnostic_history),
621            }),
622            InvalidToolCallHookAction::Retry { feedback } => {
623                if self.invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
624                    return Err(PromptError::UnknownToolCall {
625                        tool_name: tool_call.function.name,
626                        available_tools: executable_tool_names,
627                        allowed_tools: allowed_tool_names,
628                        chat_history: Box::new(diagnostic_history),
629                    });
630                }
631                self.invalid_tool_call_retries += 1;
632
633                self.new_messages.push(Message::Assistant {
634                    id: resolving.message_id.clone(),
635                    content: resolving.original_choice.clone(),
636                });
637                let Some(user_message) = invalid_tool_retry_user_message(
638                    &resolving.original_choice,
639                    &tool_call.id,
640                    feedback,
641                ) else {
642                    return Err(PromptError::prompt_cancelled(
643                        diagnostic_history,
644                        "invalid tool call retry produced no retry messages",
645                    ));
646                };
647                self.new_messages.push(user_message);
648                self.state = RunState::PreparingRequest;
649                Ok(ModelTurnOutcome::TurnRetried)
650            }
651            InvalidToolCallHookAction::Repair { tool_name } => {
652                if !allowed_tool_names.contains(&tool_name) {
653                    return Err(PromptError::UnknownToolCall {
654                        tool_name,
655                        available_tools: executable_tool_names,
656                        allowed_tools: allowed_tool_names,
657                        chat_history: Box::new(diagnostic_history),
658                    });
659                }
660                if let Some(AssistantContent::ToolCall(tool_call)) =
661                    resolving.items.get_mut(resolving.next_index)
662                {
663                    tool_call.function.name = tool_name;
664                }
665                resolving.recovered = true;
666                self.state = RunState::ResolvingToolCalls(resolving);
667                self.advance_resolution()
668            }
669            InvalidToolCallHookAction::Skip { reason } => {
670                if matches!(self.tool_choice, Some(ToolChoice::None)) {
671                    return Err(PromptError::UnknownToolCall {
672                        tool_name: tool_call.function.name,
673                        available_tools: executable_tool_names,
674                        allowed_tools: allowed_tool_names,
675                        chat_history: Box::new(diagnostic_history),
676                    });
677                }
678                let user_content = if let Some(call_id) = tool_call.call_id.clone() {
679                    UserContent::tool_result_with_call_id(
680                        tool_call.id.clone(),
681                        call_id,
682                        OneOrMany::one(reason.into()),
683                    )
684                } else {
685                    UserContent::tool_result(tool_call.id.clone(), OneOrMany::one(reason.into()))
686                };
687                resolving.skipped.insert(tool_call.id.clone(), user_content);
688                resolving.recovered = true;
689                resolving.any_skipped = true;
690                resolving.next_index += 1;
691                self.state = RunState::ResolvingToolCalls(resolving);
692                self.advance_resolution()
693            }
694        }
695    }
696
697    /// Feed the tool results for the pending [`AgentRunStep::CallTools`].
698    ///
699    /// Results may be in any order; they are appended as a single user
700    /// message, matching what providers expect for parallel tool calls. Each
701    /// result must be a tool result answering one of the pending calls, and
702    /// every pending call must be answered — exactly what providers require
703    /// to accept the next request.
704    pub fn tool_results(&mut self, results: Vec<UserContent>) -> Result<(), PromptError> {
705        let RunState::ExecutingTools(pending) = &self.state else {
706            return Err(
707                self.protocol_violation("tool_results called without a pending CallTools step")
708            );
709        };
710        // Match results against pending calls by tool call ID as a multiset,
711        // so duplicate provider IDs within one turn stay answerable.
712        let mut unanswered: Vec<String> = pending
713            .iter()
714            .map(|call| call.tool_call.id.clone())
715            .collect();
716
717        if results.is_empty() {
718            self.state = RunState::Failed;
719            return Err(PromptError::prompt_cancelled(
720                self.full_history(),
721                "tool execution produced no tool results",
722            ));
723        }
724        for result in &results {
725            let UserContent::ToolResult(tool_result) = result else {
726                return Err(self.protocol_violation(
727                    "tool_results received content that is not a tool result",
728                ));
729            };
730            let Some(index) = unanswered.iter().position(|id| *id == tool_result.id) else {
731                return Err(self.protocol_violation(&format!(
732                    "tool_results received a result for unknown or already-answered tool call id `{}`",
733                    tool_result.id
734                )));
735            };
736            unanswered.swap_remove(index);
737        }
738        if !unanswered.is_empty() {
739            return Err(self.protocol_violation(&format!(
740                "tool_results left pending tool call id(s) unanswered: {unanswered:?}"
741            )));
742        }
743
744        // `results` is non-empty (checked above), so construction succeeds.
745        let Some(content) = OneOrMany::from_iter_optional(results) else {
746            return Err(
747                self.protocol_violation("internal: tool results vanished during validation")
748            );
749        };
750
751        self.new_messages.push(Message::User { content });
752        self.state = RunState::PreparingRequest;
753        Ok(())
754    }
755
756    /// Scan forward for the next invalid tool call; finish the turn when the
757    /// scan completes.
758    fn advance_resolution(&mut self) -> Result<ModelTurnOutcome, PromptError> {
759        let mut resolving = match std::mem::replace(&mut self.state, RunState::Failed) {
760            RunState::ResolvingToolCalls(resolving) => resolving,
761            other => {
762                self.state = other;
763                return Err(self.protocol_violation(
764                    "internal: advance_resolution outside of tool-call resolution",
765                ));
766            }
767        };
768        while let Some(item) = resolving.items.get(resolving.next_index) {
769            match item {
770                AssistantContent::ToolCall(tool_call)
771                    if !resolving
772                        .allowed_tool_names
773                        .contains(&tool_call.function.name) =>
774                {
775                    break;
776                }
777                _ => resolving.next_index += 1,
778            }
779        }
780
781        if resolving.next_index < resolving.items.len() {
782            self.state = RunState::ResolvingToolCalls(resolving);
783            return match self.pending_invalid_tool_call() {
784                Some(context) => Ok(ModelTurnOutcome::NeedsResolution(context)),
785                None => Err(self.protocol_violation(
786                    "internal: pending invalid tool call could not be derived",
787                )),
788            };
789        }
790
791        let ResolvingState {
792            message_id,
793            items,
794            mut skipped,
795            recovered,
796            any_skipped,
797            has_tool_calls,
798            ..
799        } = *resolving;
800
801        // When any tool call was skipped, none of the turn's tool calls
802        // execute: peers get a synthetic "not executed" result.
803        if any_skipped {
804            for item in &items {
805                if let AssistantContent::ToolCall(tool_call) = item {
806                    skipped.entry(tool_call.id.clone()).or_insert_with(|| {
807                        tool_result_user_content(
808                            tool_call.id.clone(),
809                            tool_call.call_id.clone(),
810                            TOOL_NOT_EXECUTED_DUE_TO_INVALID_PEER.to_string(),
811                        )
812                    });
813                }
814            }
815        }
816
817        self.state = RunState::AwaitingAdvance(Box::new(TurnState {
818            message_id,
819            items,
820            has_tool_calls,
821            skipped,
822            internal_call_ids: Vec::new(),
823        }));
824        Ok(ModelTurnOutcome::Continue {
825            response_hook_suppressed: recovered,
826        })
827    }
828
829    // ── Streamed-turn entry points ──────────────────────────────────────
830    // Paired with [`streamed::StreamedTurnAssembler`]; see that module's
831    // docs for the full driving protocol.
832
833    /// Record one provider completion call for a streamed turn.
834    ///
835    /// Streamed turns learn usage from the provider's final stream event —
836    /// including for turns abandoned by invalid tool-call recovery, where the
837    /// stream is drained for usage after the rollback — so recording is
838    /// decoupled from turn ingestion. Valid while a model response is pending
839    /// or between a turn rollback and the next [`AgentRunStep::CallModel`];
840    /// aggregates `usage` into the run total. Zero-valued usage means the
841    /// provider reported no usage metrics.
842    pub fn record_streamed_completion_call(
843        &mut self,
844        usage: Usage,
845    ) -> Result<CompletionCall, PromptError> {
846        let recordable = matches!(self.state, RunState::AwaitingModel)
847            || (matches!(self.state, RunState::PreparingRequest) && self.rollback_pending);
848        if !recordable {
849            return Err(self.protocol_violation(
850                "record_streamed_completion_call called without a pending or rolled-back CallModel step",
851            ));
852        }
853        if self.streamed_completion_call_recorded {
854            return Err(self.protocol_violation(
855                "record_streamed_completion_call called twice for the same model turn",
856            ));
857        }
858        self.streamed_completion_call_recorded = true;
859
860        let call = CompletionCall::new(self.completion_call_index, usage);
861        self.completion_call_index += 1;
862        self.completion_calls.push(call);
863        self.usage += usage;
864        Ok(call)
865    }
866
867    /// The recovery-hook context for an invalid tool call surfaced
868    /// mid-stream by a [`streamed::StreamedTurnAssembler`].
869    pub fn streamed_invalid_tool_call_context(
870        &self,
871        partial: &PartialStreamedTurn,
872        invalid: &StreamedInvalidToolCall,
873    ) -> InvalidToolCallContext {
874        InvalidToolCallContext {
875            tool_name: invalid.tool_call.function.name.clone(),
876            tool_call_id: Some(invalid.tool_call.id.clone()),
877            internal_call_id: Some(invalid.internal_call_id.clone()),
878            args: invalid.args.clone(),
879            available_tools: invalid.executable_tool_names.iter().cloned().collect(),
880            allowed_tools: invalid.allowed_tool_names.iter().cloned().collect(),
881            tool_choice: self.tool_choice.clone(),
882            chat_history: self
883                .streamed_diagnostic_history(partial, Some(invalid.tool_call.clone())),
884            is_streaming: true,
885        }
886    }
887
888    /// Resolve an invalid tool call surfaced mid-stream.
889    ///
890    /// Applies the same recovery semantics as
891    /// [`AgentRun::resolve_invalid_tool_call`], but rollback messages are
892    /// assembled from the partial streamed turn — exactly what the model has
893    /// produced so far — and a successful retry or skip abandons the turn
894    /// (see [`StreamedResolution`]) instead of finishing it.
895    pub fn resolve_streamed_invalid_tool_call(
896        &mut self,
897        partial: &PartialStreamedTurn,
898        invalid: &StreamedInvalidToolCall,
899        action: InvalidToolCallHookAction,
900    ) -> Result<StreamedResolution, PromptError> {
901        if !matches!(self.state, RunState::AwaitingModel) {
902            return Err(self.protocol_violation(
903                "resolve_streamed_invalid_tool_call called without a pending CallModel step",
904            ));
905        }
906
907        let diagnostic_history =
908            self.streamed_diagnostic_history(partial, Some(invalid.tool_call.clone()));
909        let executable_tool_names: Vec<String> =
910            invalid.executable_tool_names.iter().cloned().collect();
911        let allowed_tool_names: Vec<String> = invalid.allowed_tool_names.iter().cloned().collect();
912
913        match action {
914            InvalidToolCallHookAction::Fail => {
915                self.state = RunState::Failed;
916                Err(PromptError::UnknownToolCall {
917                    tool_name: invalid.tool_call.function.name.clone(),
918                    available_tools: executable_tool_names,
919                    allowed_tools: allowed_tool_names,
920                    chat_history: Box::new(diagnostic_history),
921                })
922            }
923            InvalidToolCallHookAction::Retry { feedback } => {
924                if self.invalid_tool_call_retries >= self.max_invalid_tool_call_retries {
925                    self.state = RunState::Failed;
926                    return Err(PromptError::UnknownToolCall {
927                        tool_name: invalid.tool_call.function.name.clone(),
928                        available_tools: executable_tool_names,
929                        allowed_tools: allowed_tool_names,
930                        chat_history: Box::new(diagnostic_history),
931                    });
932                }
933                self.invalid_tool_call_retries += 1;
934
935                let Some((assistant_message, user_message)) =
936                    partial.rollback_messages(invalid.tool_call.clone(), feedback)
937                else {
938                    self.state = RunState::Failed;
939                    return Err(PromptError::prompt_cancelled(
940                        diagnostic_history,
941                        "invalid tool call retry produced no retry messages",
942                    ));
943                };
944                self.new_messages.push(assistant_message);
945                self.new_messages.push(user_message);
946                self.rollback_pending = true;
947                self.state = RunState::PreparingRequest;
948                Ok(StreamedResolution::TurnAbandoned {
949                    skipped_tool_result: None,
950                })
951            }
952            InvalidToolCallHookAction::Repair { tool_name } => {
953                if !invalid.allowed_tool_names.contains(&tool_name) {
954                    self.state = RunState::Failed;
955                    return Err(PromptError::UnknownToolCall {
956                        tool_name,
957                        available_tools: executable_tool_names,
958                        allowed_tools: allowed_tool_names,
959                        chat_history: Box::new(diagnostic_history),
960                    });
961                }
962                Ok(StreamedResolution::Repaired { tool_name })
963            }
964            InvalidToolCallHookAction::Skip { reason } => {
965                if matches!(self.tool_choice, Some(ToolChoice::None)) {
966                    self.state = RunState::Failed;
967                    return Err(PromptError::UnknownToolCall {
968                        tool_name: invalid.tool_call.function.name.clone(),
969                        available_tools: executable_tool_names,
970                        allowed_tools: allowed_tool_names,
971                        chat_history: Box::new(diagnostic_history),
972                    });
973                }
974
975                let skipped_tool_result = ToolResult {
976                    id: invalid.tool_call.id.clone(),
977                    call_id: invalid.tool_call.call_id.clone(),
978                    content: ToolResultContent::from_tool_output(reason.clone()),
979                };
980                let Some((assistant_message, user_message)) =
981                    partial.rollback_messages(invalid.tool_call.clone(), reason)
982                else {
983                    self.state = RunState::Failed;
984                    return Err(PromptError::prompt_cancelled(
985                        diagnostic_history,
986                        "invalid tool call skip produced no recovery messages",
987                    ));
988                };
989                self.new_messages.push(assistant_message);
990                self.new_messages.push(user_message);
991                self.rollback_pending = true;
992                self.state = RunState::PreparingRequest;
993                Ok(StreamedResolution::TurnAbandoned {
994                    skipped_tool_result: Some(skipped_tool_result),
995                })
996            }
997        }
998    }
999
1000    /// Feed the assembled streamed turn for the pending
1001    /// [`AgentRunStep::CallModel`].
1002    ///
1003    /// Remaining tool calls are validated fail-fast — mid-stream resolution
1004    /// already had recovery-hook access — and the turn then advances through
1005    /// [`AgentRun::next_step`] exactly like a non-streamed one.
1006    pub fn streamed_turn(&mut self, turn: StreamedTurn) -> Result<(), PromptError> {
1007        if !matches!(self.state, RunState::AwaitingModel) {
1008            return Err(
1009                self.protocol_violation("streamed_turn called without a pending CallModel step")
1010            );
1011        }
1012
1013        // Guarantee exactly one CompletionCall per model call: drivers that
1014        // never learned usage (no record before the turn completed) still get
1015        // the call recorded, with no reported usage.
1016        if !self.streamed_completion_call_recorded {
1017            self.completion_calls.push(CompletionCall::new(
1018                self.completion_call_index,
1019                Usage::new(),
1020            ));
1021            self.completion_call_index += 1;
1022            self.streamed_completion_call_recorded = true;
1023        }
1024
1025        let items: Vec<AssistantContent> = turn.choice.iter().cloned().collect();
1026        let has_tool_calls = items
1027            .iter()
1028            .any(|item| matches!(item, AssistantContent::ToolCall(_)));
1029
1030        for item in &items {
1031            let AssistantContent::ToolCall(tool_call) = item else {
1032                continue;
1033            };
1034            if !turn.allowed_tool_names.contains(&tool_call.function.name) {
1035                let mut diagnostic_messages = self.new_messages.clone();
1036                if !is_empty_assistant_turn(&turn.choice) {
1037                    diagnostic_messages.push(Message::Assistant {
1038                        id: turn.message_id.clone(),
1039                        content: turn.choice.clone(),
1040                    });
1041                }
1042                let diagnostic_history =
1043                    build_full_history(self.chat_history.as_deref(), diagnostic_messages);
1044                self.state = RunState::Failed;
1045                return Err(PromptError::UnknownToolCall {
1046                    tool_name: tool_call.function.name.clone(),
1047                    available_tools: turn.executable_tool_names.iter().cloned().collect(),
1048                    allowed_tools: turn.allowed_tool_names.iter().cloned().collect(),
1049                    chat_history: Box::new(diagnostic_history),
1050                });
1051            }
1052        }
1053
1054        self.state = RunState::AwaitingAdvance(Box::new(TurnState {
1055            message_id: turn.message_id,
1056            items,
1057            has_tool_calls,
1058            skipped: BTreeMap::new(),
1059            internal_call_ids: turn.internal_call_ids,
1060        }));
1061        Ok(())
1062    }
1063
1064    /// Diagnostic history for a streamed turn: the run's messages plus the
1065    /// partial assistant turn under inspection.
1066    fn streamed_diagnostic_history(
1067        &self,
1068        partial: &PartialStreamedTurn,
1069        current_tool_call: Option<ToolCall>,
1070    ) -> Vec<Message> {
1071        let mut messages = self.new_messages.clone();
1072        if let Some(assistant) = partial.assistant_message(current_tool_call) {
1073            messages.push(assistant);
1074        }
1075        build_full_history(self.chat_history.as_deref(), messages)
1076    }
1077
1078    /// History used for invalid tool-call diagnostics: the run's messages plus
1079    /// the unmodified assistant turn under inspection.
1080    fn diagnostic_history(&self, resolving: &ResolvingState) -> Vec<Message> {
1081        let mut diagnostic_messages = self.new_messages.clone();
1082        diagnostic_messages.push(Message::Assistant {
1083            id: resolving.message_id.clone(),
1084            content: resolving.original_choice.clone(),
1085        });
1086        build_full_history(self.chat_history.as_deref(), diagnostic_messages)
1087    }
1088
1089    fn protocol_violation(&self, reason: &str) -> PromptError {
1090        PromptError::prompt_cancelled(
1091            self.full_history(),
1092            format!("agent run driver protocol violation: {reason}"),
1093        )
1094    }
1095}
1096
1097#[cfg(test)]
1098mod tests {
1099    use super::*;
1100    use crate::message::{ToolFunction, ToolResultContent};
1101    use serde_json::json;
1102
1103    fn tool_names(names: &[&str]) -> BTreeSet<String> {
1104        names.iter().map(|name| (*name).to_string()).collect()
1105    }
1106
1107    fn usage(input_tokens: u64, output_tokens: u64) -> Usage {
1108        Usage {
1109            input_tokens,
1110            output_tokens,
1111            total_tokens: input_tokens + output_tokens,
1112            ..Usage::new()
1113        }
1114    }
1115
1116    fn text_turn(text: &str) -> ModelTurn {
1117        ModelTurn::new(
1118            None,
1119            OneOrMany::one(AssistantContent::text(text)),
1120            Usage::new(),
1121            tool_names(&["add"]),
1122            tool_names(&["add"]),
1123        )
1124    }
1125
1126    fn tool_call(id: &str, name: &str) -> AssistantContent {
1127        AssistantContent::ToolCall(ToolCall::new(
1128            id.to_string(),
1129            ToolFunction::new(name.to_string(), json!({"x": 1})),
1130        ))
1131    }
1132
1133    fn tool_call_turn(id: &str, name: &str) -> ModelTurn {
1134        ModelTurn::new(
1135            None,
1136            OneOrMany::one(tool_call(id, name)),
1137            Usage::new(),
1138            tool_names(&["add"]),
1139            tool_names(&["add"]),
1140        )
1141    }
1142
1143    fn tool_result(id: &str, output: &str) -> UserContent {
1144        UserContent::tool_result(
1145            id.to_string(),
1146            ToolResultContent::from_tool_output(output.to_string()),
1147        )
1148    }
1149
1150    fn expect_call_model(run: &mut AgentRun) -> (Message, Vec<Message>, usize) {
1151        match run.next_step().expect("next_step should succeed") {
1152            AgentRunStep::CallModel {
1153                prompt,
1154                history,
1155                turn,
1156            } => (prompt, history, turn),
1157            step => panic!("expected CallModel, got {step:?}"),
1158        }
1159    }
1160
1161    fn expect_call_tools(run: &mut AgentRun) -> Vec<PendingToolCall> {
1162        match run.next_step().expect("next_step should succeed") {
1163            AgentRunStep::CallTools { calls } => calls,
1164            step => panic!("expected CallTools, got {step:?}"),
1165        }
1166    }
1167
1168    fn expect_done(run: &mut AgentRun) -> PromptResponse {
1169        match run.next_step().expect("next_step should succeed") {
1170            AgentRunStep::Done(response) => response,
1171            step => panic!("expected Done, got {step:?}"),
1172        }
1173    }
1174
1175    fn expect_continue(outcome: ModelTurnOutcome) -> bool {
1176        match outcome {
1177            ModelTurnOutcome::Continue {
1178                response_hook_suppressed,
1179            } => response_hook_suppressed,
1180            outcome => panic!("expected Continue, got {outcome:?}"),
1181        }
1182    }
1183
1184    fn expect_needs_resolution(outcome: ModelTurnOutcome) -> InvalidToolCallContext {
1185        match outcome {
1186            ModelTurnOutcome::NeedsResolution(context) => context,
1187            outcome => panic!("expected NeedsResolution, got {outcome:?}"),
1188        }
1189    }
1190
1191    #[test]
1192    fn text_only_run_completes_in_one_turn() {
1193        let mut run = AgentRun::new("hello");
1194
1195        let (prompt, history, turn) = expect_call_model(&mut run);
1196        assert_eq!(prompt, Message::user("hello"));
1197        assert!(history.is_empty());
1198        assert_eq!(turn, 1);
1199
1200        let suppressed = expect_continue(
1201            run.model_response(text_turn("hi there"))
1202                .expect("model_response should succeed"),
1203        );
1204        assert!(!suppressed);
1205
1206        let response = expect_done(&mut run);
1207        assert_eq!(response.output, "hi there");
1208        let messages = response.messages.expect("messages should be recorded");
1209        assert_eq!(messages.len(), 2);
1210        assert!(run.is_done());
1211    }
1212
1213    #[test]
1214    fn input_history_prefixes_request_history() {
1215        let mut run = AgentRun::new("question")
1216            .with_history(vec![Message::user("earlier"), Message::assistant("reply")]);
1217
1218        let (_, history, _) = expect_call_model(&mut run);
1219        assert_eq!(
1220            history,
1221            vec![Message::user("earlier"), Message::assistant("reply")]
1222        );
1223
1224        expect_continue(
1225            run.model_response(text_turn("answer"))
1226                .expect("model_response should succeed"),
1227        );
1228        let response = expect_done(&mut run);
1229        // Returned messages exclude the input history.
1230        assert_eq!(
1231            response
1232                .messages
1233                .expect("messages should be recorded")
1234                .len(),
1235            2
1236        );
1237    }
1238
1239    #[test]
1240    fn tool_roundtrip_threads_history_and_usage() {
1241        let mut run = AgentRun::new("add things").max_turns(2);
1242
1243        expect_call_model(&mut run);
1244        expect_continue(
1245            run.model_response(tool_call_turn("call_1", "add").with_usage_for_test(usage(10, 5)))
1246                .expect("model_response should succeed"),
1247        );
1248
1249        let calls = expect_call_tools(&mut run);
1250        assert_eq!(calls.len(), 1);
1251        assert_eq!(calls[0].tool_call.function.name, "add");
1252        assert!(calls[0].preresolved_result.is_none());
1253
1254        run.tool_results(vec![tool_result("call_1", "2")])
1255            .expect("tool_results should succeed");
1256
1257        let (prompt, history, turn) = expect_call_model(&mut run);
1258        assert_eq!(turn, 2);
1259        // The tool-result user message becomes the new prompt; the assistant
1260        // turn is part of the history.
1261        assert!(matches!(prompt, Message::User { .. }));
1262        assert_eq!(history.len(), 2);
1263
1264        expect_continue(
1265            run.model_response(text_turn("the answer is 2").with_usage_for_test(usage(20, 7)))
1266                .expect("model_response should succeed"),
1267        );
1268
1269        let response = expect_done(&mut run);
1270        assert_eq!(response.output, "the answer is 2");
1271        assert_eq!(response.usage, usage(30, 12));
1272        assert_eq!(response.completion_calls.len(), 2);
1273        assert_eq!(response.completion_calls[0].call_index, 0);
1274        assert_eq!(response.completion_calls[0].usage, usage(10, 5));
1275        assert_eq!(response.completion_calls[1].usage, usage(20, 7));
1276        // prompt, assistant tool call, tool result, final assistant text
1277        assert_eq!(
1278            response
1279                .messages
1280                .expect("messages should be recorded")
1281                .len(),
1282            4
1283        );
1284    }
1285
1286    #[test]
1287    fn parallel_tool_calls_surface_in_emission_order() {
1288        let mut run = AgentRun::new("do both").max_turns(2);
1289
1290        expect_call_model(&mut run);
1291        let turn = ModelTurn::new(
1292            None,
1293            OneOrMany::many(vec![tool_call("call_1", "add"), tool_call("call_2", "add")])
1294                .expect("two items"),
1295            Usage::new(),
1296            tool_names(&["add"]),
1297            tool_names(&["add"]),
1298        );
1299        expect_continue(
1300            run.model_response(turn)
1301                .expect("model_response should succeed"),
1302        );
1303
1304        let calls = expect_call_tools(&mut run);
1305        assert_eq!(calls.len(), 2);
1306        assert_eq!(calls[0].tool_call.id, "call_1");
1307        assert_eq!(calls[1].tool_call.id, "call_2");
1308
1309        // Results fed out of order still land in one user message.
1310        run.tool_results(vec![tool_result("call_2", "b"), tool_result("call_1", "a")])
1311            .expect("tool_results should succeed");
1312        let messages = run.messages();
1313        assert!(matches!(
1314            messages.last(),
1315            Some(Message::User { content }) if content.len() == 2
1316        ));
1317    }
1318
1319    #[test]
1320    fn max_turns_exhaustion_returns_max_turns_error() {
1321        let mut run = AgentRun::new("loop forever");
1322
1323        for turn_id in ["call_1", "call_2"] {
1324            expect_call_model(&mut run);
1325            expect_continue(
1326                run.model_response(tool_call_turn(turn_id, "add"))
1327                    .expect("model_response should succeed"),
1328            );
1329            expect_call_tools(&mut run);
1330            run.tool_results(vec![tool_result(turn_id, "0")])
1331                .expect("tool_results should succeed");
1332        }
1333
1334        let err = run.next_step().expect_err("depth should be exhausted");
1335        assert!(matches!(
1336            err,
1337            PromptError::MaxTurnsError { max_turns: 0, .. }
1338        ));
1339    }
1340
1341    #[test]
1342    fn invalid_tool_call_fail_returns_unknown_tool_call() {
1343        let mut run = AgentRun::new("call something");
1344
1345        expect_call_model(&mut run);
1346        let context = expect_needs_resolution(
1347            run.model_response(tool_call_turn("call_1", "unknown"))
1348                .expect("model_response should succeed"),
1349        );
1350        assert_eq!(context.tool_name, "unknown");
1351        assert_eq!(context.available_tools, vec!["add".to_string()]);
1352        assert!(!context.is_streaming);
1353        // Diagnostic history includes the rejected assistant turn.
1354        assert_eq!(context.chat_history.len(), 2);
1355
1356        let err = run
1357            .resolve_invalid_tool_call(InvalidToolCallHookAction::fail())
1358            .expect_err("fail action should error");
1359        assert!(matches!(
1360            err,
1361            PromptError::UnknownToolCall { tool_name, .. } if tool_name == "unknown"
1362        ));
1363    }
1364
1365    #[test]
1366    fn invalid_tool_call_retry_rolls_back_with_feedback() {
1367        let mut run = AgentRun::new("call something")
1368            .max_turns(2)
1369            .max_invalid_tool_call_retries(1);
1370
1371        expect_call_model(&mut run);
1372        expect_needs_resolution(
1373            run.model_response(tool_call_turn("call_1", "unknown"))
1374                .expect("model_response should succeed"),
1375        );
1376        let outcome = run
1377            .resolve_invalid_tool_call(InvalidToolCallHookAction::retry("use add instead"))
1378            .expect("retry should be accepted");
1379        assert!(matches!(outcome, ModelTurnOutcome::TurnRetried));
1380
1381        // The rolled-back turn appended the assistant message and feedback.
1382        assert_eq!(run.messages().len(), 3);
1383        let (prompt, _, turn) = expect_call_model(&mut run);
1384        assert_eq!(turn, 2);
1385        assert!(matches!(
1386            prompt,
1387            Message::User { ref content }
1388                if matches!(content.first(), UserContent::ToolResult(_))
1389        ));
1390
1391        // Budget of one: a second retry fails with UnknownToolCall.
1392        expect_needs_resolution(
1393            run.model_response(tool_call_turn("call_2", "unknown"))
1394                .expect("model_response should succeed"),
1395        );
1396        let err = run
1397            .resolve_invalid_tool_call(InvalidToolCallHookAction::retry("again"))
1398            .expect_err("budget exhausted");
1399        assert!(matches!(err, PromptError::UnknownToolCall { .. }));
1400    }
1401
1402    #[test]
1403    fn invalid_tool_call_repair_renames_and_suppresses_response_hook() {
1404        let mut run = AgentRun::new("call something").max_turns(2);
1405
1406        expect_call_model(&mut run);
1407        expect_needs_resolution(
1408            run.model_response(tool_call_turn("call_1", "default_api"))
1409                .expect("model_response should succeed"),
1410        );
1411        let suppressed = expect_continue(
1412            run.resolve_invalid_tool_call(InvalidToolCallHookAction::repair("add"))
1413                .expect("repair should be accepted"),
1414        );
1415        assert!(suppressed);
1416
1417        let calls = expect_call_tools(&mut run);
1418        assert_eq!(calls[0].tool_call.function.name, "add");
1419        assert!(calls[0].preresolved_result.is_none());
1420    }
1421
1422    #[test]
1423    fn invalid_tool_call_repair_to_disallowed_name_fails() {
1424        let mut run = AgentRun::new("call something");
1425
1426        expect_call_model(&mut run);
1427        expect_needs_resolution(
1428            run.model_response(tool_call_turn("call_1", "unknown"))
1429                .expect("model_response should succeed"),
1430        );
1431        let err = run
1432            .resolve_invalid_tool_call(InvalidToolCallHookAction::repair("also_unknown"))
1433            .expect_err("repair to disallowed name should fail");
1434        assert!(matches!(
1435            err,
1436            PromptError::UnknownToolCall { tool_name, .. } if tool_name == "also_unknown"
1437        ));
1438    }
1439
1440    #[test]
1441    fn invalid_tool_call_skip_suppresses_all_peer_executions() {
1442        let mut run = AgentRun::new("call things").max_turns(2);
1443
1444        expect_call_model(&mut run);
1445        let turn = ModelTurn::new(
1446            None,
1447            OneOrMany::many(vec![
1448                tool_call("call_1", "unknown"),
1449                tool_call("call_2", "add"),
1450            ])
1451            .expect("two items"),
1452            Usage::new(),
1453            tool_names(&["add"]),
1454            tool_names(&["add"]),
1455        );
1456        expect_needs_resolution(
1457            run.model_response(turn)
1458                .expect("model_response should succeed"),
1459        );
1460        let suppressed = expect_continue(
1461            run.resolve_invalid_tool_call(InvalidToolCallHookAction::skip("not available"))
1462                .expect("skip should be accepted"),
1463        );
1464        assert!(suppressed);
1465
1466        let calls = expect_call_tools(&mut run);
1467        assert_eq!(calls.len(), 2);
1468        // Both the skipped call and its valid peer carry preresolved results.
1469        assert!(calls.iter().all(|call| call.preresolved_result.is_some()));
1470    }
1471
1472    #[test]
1473    fn skip_under_tool_choice_none_fails() {
1474        let mut run = AgentRun::new("call something").with_tool_choice(ToolChoice::None);
1475
1476        expect_call_model(&mut run);
1477        expect_needs_resolution(
1478            run.model_response(ModelTurn::new(
1479                None,
1480                OneOrMany::one(tool_call("call_1", "add")),
1481                Usage::new(),
1482                tool_names(&["add"]),
1483                BTreeSet::new(),
1484            ))
1485            .expect("model_response should succeed"),
1486        );
1487        let err = run
1488            .resolve_invalid_tool_call(InvalidToolCallHookAction::skip("nope"))
1489            .expect_err("skip under ToolChoice::None should fail");
1490        assert!(matches!(err, PromptError::UnknownToolCall { .. }));
1491    }
1492
1493    #[test]
1494    fn empty_tool_results_cancel_the_run() {
1495        let mut run = AgentRun::new("call something").max_turns(2);
1496
1497        expect_call_model(&mut run);
1498        expect_continue(
1499            run.model_response(tool_call_turn("call_1", "add"))
1500                .expect("model_response should succeed"),
1501        );
1502        expect_call_tools(&mut run);
1503
1504        let err = run
1505            .tool_results(Vec::new())
1506            .expect_err("empty results should cancel");
1507        assert!(matches!(
1508            err,
1509            PromptError::PromptCancelled { reason, .. }
1510                if reason.contains("tool execution produced no tool results")
1511        ));
1512    }
1513
1514    #[test]
1515    fn out_of_protocol_calls_are_rejected_without_corrupting_state() {
1516        let mut run = AgentRun::new("hello");
1517
1518        let err = run
1519            .tool_results(vec![tool_result("call_1", "x")])
1520            .expect_err("no CallTools pending");
1521        assert!(matches!(err, PromptError::PromptCancelled { .. }));
1522
1523        // The run is still drivable after a rejected out-of-protocol call.
1524        expect_call_model(&mut run);
1525        let err = run
1526            .next_step()
1527            .expect_err("model response is pending, next_step must be rejected");
1528        assert!(matches!(err, PromptError::PromptCancelled { .. }));
1529        expect_continue(
1530            run.model_response(text_turn("hi"))
1531                .expect("model_response should still succeed"),
1532        );
1533        assert_eq!(expect_done(&mut run).output, "hi");
1534    }
1535
1536    #[test]
1537    fn model_response_rejected_after_streamed_completion_call_record() {
1538        let mut run = AgentRun::new("hello");
1539        expect_call_model(&mut run);
1540        run.record_streamed_completion_call(Usage::new())
1541            .expect("record should succeed");
1542
1543        let err = run
1544            .model_response(text_turn("hi"))
1545            .expect_err("mixed streamed/non-streamed ingestion must be rejected");
1546        assert!(matches!(err, PromptError::PromptCancelled { .. }));
1547        // No duplicate completion call was appended.
1548        assert_eq!(run.completion_calls().len(), 1);
1549    }
1550
1551    #[test]
1552    fn done_step_is_idempotent() {
1553        let mut run = AgentRun::new("hello");
1554        expect_call_model(&mut run);
1555        expect_continue(
1556            run.model_response(text_turn("hi"))
1557                .expect("model_response should succeed"),
1558        );
1559        assert_eq!(expect_done(&mut run).output, "hi");
1560        assert_eq!(expect_done(&mut run).output, "hi");
1561    }
1562
1563    #[test]
1564    fn serialized_run_alone_carries_pending_tool_calls() {
1565        let mut run = AgentRun::new("add things").max_turns(2);
1566        expect_call_model(&mut run);
1567        expect_continue(
1568            run.model_response(tool_call_turn("call_1", "add"))
1569                .expect("model_response should succeed"),
1570        );
1571        expect_call_tools(&mut run);
1572
1573        // A fresh process receives only the serialized run: the pending tool
1574        // calls must be recoverable from the state itself.
1575        let serialized = serde_json::to_string(&run).expect("mid-run state should serialize");
1576        drop(run);
1577        let mut resumed: AgentRun =
1578            serde_json::from_str(&serialized).expect("mid-run state should deserialize");
1579
1580        let calls = expect_call_tools(&mut resumed);
1581        assert_eq!(calls.len(), 1);
1582        assert_eq!(calls[0].tool_call.function.name, "add");
1583        // Re-emission is idempotent while results are pending.
1584        let calls_again = expect_call_tools(&mut resumed);
1585        assert_eq!(calls_again[0].tool_call.id, calls[0].tool_call.id);
1586
1587        // Answer using only IDs learned from the re-emitted step.
1588        let results = calls
1589            .iter()
1590            .map(|call| tool_result(&call.tool_call.id, "2"))
1591            .collect::<Vec<_>>();
1592        resumed
1593            .tool_results(results)
1594            .expect("tool_results should succeed");
1595        expect_call_model(&mut resumed);
1596        expect_continue(
1597            resumed
1598                .model_response(text_turn("done"))
1599                .expect("model_response should succeed"),
1600        );
1601        assert_eq!(expect_done(&mut resumed).output, "done");
1602    }
1603
1604    #[test]
1605    fn tool_results_validates_against_pending_calls() {
1606        let drive_to_pending_tools = || {
1607            let mut run = AgentRun::new("add things").max_turns(2);
1608            expect_call_model(&mut run);
1609            expect_continue(
1610                run.model_response(tool_call_turn("call_1", "add"))
1611                    .expect("model_response should succeed"),
1612            );
1613            expect_call_tools(&mut run);
1614            run
1615        };
1616
1617        // A result for an unknown call ID is rejected without corrupting the run.
1618        let mut run = drive_to_pending_tools();
1619        let err = run
1620            .tool_results(vec![tool_result("call_unknown", "2")])
1621            .expect_err("unknown tool call id must be rejected");
1622        assert!(matches!(err, PromptError::PromptCancelled { .. }));
1623        run.tool_results(vec![tool_result("call_1", "2")])
1624            .expect("valid results should still be accepted after a rejection");
1625
1626        // Leaving a pending call unanswered is rejected.
1627        let mut run = drive_to_pending_tools();
1628        let err = run
1629            .tool_results(vec![tool_result("call_1", "2"), tool_result("call_1", "3")])
1630            .expect_err("answering one call twice must be rejected");
1631        assert!(matches!(err, PromptError::PromptCancelled { .. }));
1632
1633        // Non-tool-result content is rejected.
1634        let mut run = drive_to_pending_tools();
1635        let err = run
1636            .tool_results(vec![UserContent::text("not a tool result")])
1637            .expect_err("non-tool-result content must be rejected");
1638        assert!(matches!(err, PromptError::PromptCancelled { .. }));
1639    }
1640
1641    #[test]
1642    fn agent_run_deserializes_pre_monoid_suspended_state() {
1643        // Fixture captured from rig before CompletionCall.usage dropped its
1644        // Option encoding, suspended at ExecutingTools with a null-usage
1645        // completion call. It must deserialize and resume.
1646        let fixture = r#"{"max_turns":2,"max_invalid_tool_call_retries":0,"tool_choice":null,"chat_history":null,"new_messages":[{"role":"user","content":[{"type":"text","text":"add things"}]},{"role":"assistant","id":null,"content":[{"id":"call_1","call_id":null,"function":{"name":"add","arguments":{"x":1}},"signature":null,"additional_params":null}]}],"current_turn":1,"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15,"cached_input_tokens":0,"cache_creation_input_tokens":0,"tool_use_prompt_tokens":0,"reasoning_tokens":0},"completion_calls":[{"call_index":0,"usage":null}],"completion_call_index":1,"invalid_tool_call_retries":0,"rollback_pending":false,"streamed_completion_call_recorded":false,"state":{"ExecutingTools":[{"tool_call":{"id":"call_1","call_id":null,"function":{"name":"add","arguments":{"x":1}},"signature":null,"additional_params":null},"preresolved_result":null,"internal_call_id":null}]}}"#;
1647
1648        let mut restored: AgentRun =
1649            serde_json::from_str(fixture).expect("old-format suspended run should deserialize");
1650        assert_eq!(restored.completion_calls()[0].usage, Usage::new());
1651
1652        let calls = expect_call_tools(&mut restored);
1653        assert_eq!(calls.len(), 1);
1654        restored
1655            .tool_results(vec![tool_result("call_1", "2")])
1656            .expect("tool_results should succeed");
1657        expect_call_model(&mut restored);
1658    }
1659
1660    #[test]
1661    fn serde_round_trip_mid_run_resumes_identically() {
1662        let drive_to_pending_tools = || {
1663            let mut run = AgentRun::new("add things").max_turns(2);
1664            expect_call_model(&mut run);
1665            expect_continue(
1666                run.model_response(
1667                    tool_call_turn("call_1", "add").with_usage_for_test(usage(10, 5)),
1668                )
1669                .expect("model_response should succeed"),
1670            );
1671            expect_call_tools(&mut run);
1672            run
1673        };
1674
1675        let finish = |mut run: AgentRun| {
1676            run.tool_results(vec![tool_result("call_1", "2")])
1677                .expect("tool_results should succeed");
1678            expect_call_model(&mut run);
1679            expect_continue(
1680                run.model_response(text_turn("done").with_usage_for_test(usage(3, 4)))
1681                    .expect("model_response should succeed"),
1682            );
1683            expect_done(&mut run)
1684        };
1685
1686        let uninterrupted = finish(drive_to_pending_tools());
1687
1688        let suspended = drive_to_pending_tools();
1689        let serialized = serde_json::to_string(&suspended).expect("mid-run state should serialize");
1690        let restored: AgentRun =
1691            serde_json::from_str(&serialized).expect("mid-run state should deserialize");
1692        let resumed = finish(restored);
1693
1694        assert_eq!(resumed.output, uninterrupted.output);
1695        assert_eq!(resumed.usage, uninterrupted.usage);
1696        assert_eq!(resumed.completion_calls, uninterrupted.completion_calls);
1697        // Compare messages by their serialized form: deserializing a message
1698        // normalizes absent `additional_params` to an empty map, which is
1699        // semantically identical and serializes identically.
1700        assert_eq!(
1701            serde_json::to_value(&resumed.messages).expect("messages should serialize"),
1702            serde_json::to_value(&uninterrupted.messages).expect("messages should serialize"),
1703        );
1704    }
1705
1706    #[test]
1707    fn pending_invalid_tool_call_survives_serde_round_trip() {
1708        let mut run = AgentRun::new("call something");
1709        expect_call_model(&mut run);
1710        let context = expect_needs_resolution(
1711            run.model_response(tool_call_turn("call_1", "unknown"))
1712                .expect("model_response should succeed"),
1713        );
1714
1715        let serialized = serde_json::to_string(&run).expect("state should serialize");
1716        let restored: AgentRun =
1717            serde_json::from_str(&serialized).expect("state should deserialize");
1718        let restored_context = restored
1719            .pending_invalid_tool_call()
1720            .expect("pending resolution should survive serialization");
1721        assert_eq!(restored_context.tool_name, context.tool_name);
1722        assert_eq!(
1723            restored_context.chat_history.len(),
1724            context.chat_history.len()
1725        );
1726    }
1727
1728    impl ModelTurn {
1729        fn with_usage_for_test(mut self, usage: Usage) -> Self {
1730            self.usage = usage;
1731            self
1732        }
1733    }
1734}