Skip to main content

rig_core/agent/prompt_request/
hooks.rs

1//! Optional hooks for agent prompting.
2//! Hooks can be used to create custom behaviour like logging, calling external services or conditionally skipping tool calls.
3//! Alternatively, you can also use them to terminate agent loops early.
4
5use crate::{
6    completion::CompletionModel,
7    message::{Message, ToolChoice},
8    wasm_compat::{WasmCompatSend, WasmCompatSync},
9};
10
/// Context passed to [`PromptHook::on_invalid_tool_call`] when the model emits a tool call
/// that Rig would reject before the normal tool-call hooks run or the tool is executed,
/// i.e. a call naming a tool that is unknown or disallowed by the current tool choice.
#[derive(Debug, Clone)]
pub struct InvalidToolCallContext {
    /// Tool name as emitted by the model (the name that was rejected).
    pub tool_name: String,
    /// Provider-supplied tool call ID, if the provider sent one.
    pub tool_call_id: Option<String>,
    /// Internal Rig call ID, when available.
    pub internal_call_id: Option<String>,
    /// JSON arguments emitted for the tool call, as a string, when available.
    pub args: Option<String>,
    /// Names of the executable Rig tools advertised to the provider for this turn.
    pub available_tools: Vec<String>,
    /// Names of the tools allowed by the active [`ToolChoice`] for this turn.
    pub allowed_tools: Vec<String>,
    /// Active tool choice for this turn; `None` when no tool choice is in effect.
    pub tool_choice: Option<ToolChoice>,
    /// Diagnostic chat history, including the rejected model output when available.
    pub chat_history: Vec<Message>,
    /// `true` when the rejected call came from the streaming path.
    pub is_streaming: bool,
}
34
/// Recovery decision returned by [`PromptHook::on_invalid_tool_call`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InvalidToolCallHookAction {
    /// Keep Rig's default fail-fast behavior.
    Fail,
    /// Re-run the model turn, sending `feedback` as a correction.
    Retry { feedback: String },
    /// Replace only the emitted tool name with `tool_name`. The replacement is
    /// revalidated against registered tools and the current `ToolChoice` before use.
    Repair { tool_name: String },
    /// Treat the invalid structured tool call as skipped: `reason` is returned as
    /// synthetic feedback in place of its tool result. The invalid tool never runs.
    Skip { reason: String },
}

impl InvalidToolCallHookAction {
    /// Builds [`InvalidToolCallHookAction::Fail`] (default fail-fast behavior).
    pub fn fail() -> Self {
        InvalidToolCallHookAction::Fail
    }

    /// Builds [`InvalidToolCallHookAction::Retry`] carrying corrective `feedback`.
    pub fn retry(feedback: impl Into<String>) -> Self {
        let feedback = feedback.into();
        InvalidToolCallHookAction::Retry { feedback }
    }

    /// Builds [`InvalidToolCallHookAction::Repair`] with the replacement `tool_name`.
    pub fn repair(tool_name: impl Into<String>) -> Self {
        let tool_name = tool_name.into();
        InvalidToolCallHookAction::Repair { tool_name }
    }

    /// Builds [`InvalidToolCallHookAction::Skip`] with the synthetic-result `reason`.
    pub fn skip(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        InvalidToolCallHookAction::Skip { reason }
    }
}
77
78/// Trait for per-request hooks to observe tool call events.
79pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
80where
81    M: CompletionModel,
82{
83    /// Called before the prompt is sent to the model
84    fn on_completion_call(
85        &self,
86        _prompt: &Message,
87        _history: &[Message],
88    ) -> impl Future<Output = HookAction> + WasmCompatSend {
89        async { HookAction::cont() }
90    }
91
92    /// Called after the prompt is sent to the model and a response is received.
93    fn on_completion_response(
94        &self,
95        _prompt: &Message,
96        _response: &crate::completion::CompletionResponse<M::Response>,
97    ) -> impl Future<Output = HookAction> + WasmCompatSend {
98        async { HookAction::cont() }
99    }
100
101    /// Called when a model-emitted tool call is unknown or disallowed by the
102    /// current request's tool choice.
103    ///
104    /// The default behavior remains fail-fast. Override this method to opt into
105    /// retry, repair, or skip recovery for invalid tool calls.
106    fn on_invalid_tool_call(
107        &self,
108        _context: &InvalidToolCallContext,
109    ) -> impl Future<Output = InvalidToolCallHookAction> + WasmCompatSend {
110        async { InvalidToolCallHookAction::fail() }
111    }
112
113    /// Called before a tool is invoked.
114    ///
115    /// # Returns
116    /// - `ToolCallHookAction::Continue` - Allow tool execution to proceed
117    /// - `ToolCallHookAction::Skip { reason }` - Reject tool execution; `reason` will be returned to the LLM as the tool result
118    fn on_tool_call(
119        &self,
120        _tool_name: &str,
121        _tool_call_id: Option<String>,
122        _internal_call_id: &str,
123        _args: &str,
124    ) -> impl Future<Output = ToolCallHookAction> + WasmCompatSend {
125        async { ToolCallHookAction::cont() }
126    }
127
128    /// Called after a tool is invoked (and a result has been returned).
129    fn on_tool_result(
130        &self,
131        _tool_name: &str,
132        _tool_call_id: Option<String>,
133        _internal_call_id: &str,
134        _args: &str,
135        _result: &str,
136    ) -> impl Future<Output = HookAction> + WasmCompatSend {
137        async { HookAction::cont() }
138    }
139
140    /// Called when receiving a text delta (streaming responses only)
141    fn on_text_delta(
142        &self,
143        _text_delta: &str,
144        _aggregated_text: &str,
145    ) -> impl Future<Output = HookAction> + Send {
146        async { HookAction::cont() }
147    }
148
149    /// Called when receiving a tool call delta (streaming_responses_only).
150    /// `tool_name` is Some on the first delta for a tool call, None on subsequent deltas.
151    fn on_tool_call_delta(
152        &self,
153        _tool_call_id: &str,
154        _internal_call_id: &str,
155        _tool_name: Option<&str>,
156        _tool_call_delta: &str,
157    ) -> impl Future<Output = HookAction> + Send {
158        async { HookAction::cont() }
159    }
160
161    /// Called after the model provider has finished streaming a text response from their completion API to the client.
162    fn on_stream_completion_response_finish(
163        &self,
164        _prompt: &Message,
165        _response: &<M as CompletionModel>::StreamingResponse,
166    ) -> impl Future<Output = HookAction> + Send {
167        async { HookAction::cont() }
168    }
169}
170
171impl<M> PromptHook<M> for () where M: CompletionModel {}
172
/// Control-flow decision returned by [`PromptHook::on_tool_call`].
///
/// Unlike [`HookAction`], this type can also skip an individual tool call, in addition
/// to continuing or terminating.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallHookAction {
    /// Execute the tool as normal.
    Continue,
    /// Do not execute the tool; `reason` is returned as the tool result instead.
    Skip { reason: String },
    /// Stop the agent loop early.
    Terminate { reason: String },
}

impl ToolCallHookAction {
    /// Builds [`ToolCallHookAction::Continue`]: let the agent loop proceed as normal.
    pub fn cont() -> Self {
        ToolCallHookAction::Continue
    }

    /// Builds [`ToolCallHookAction::Skip`] with the given `reason`.
    pub fn skip(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        ToolCallHookAction::Skip { reason }
    }

    /// Builds [`ToolCallHookAction::Terminate`], ending the agent loop entirely.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        ToolCallHookAction::Terminate { reason }
    }
}
204
/// Control-flow decision returned by the general-purpose [`PromptHook`] callbacks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Keep running the agent loop as normal.
    Continue,
    /// Stop the agent loop early for the given `reason`.
    Terminate { reason: String },
}

impl HookAction {
    /// Builds [`HookAction::Continue`]: let the agent loop proceed as normal.
    pub fn cont() -> Self {
        HookAction::Continue
    }

    /// Builds [`HookAction::Terminate`], ending the agent loop entirely.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        HookAction::Terminate { reason }
    }
}