rig_core/agent/prompt_request/hooks.rs
//! Optional hooks for agent prompting.
//!
//! Hooks can be used to add custom behaviour such as logging, calling external services, or
//! conditionally skipping tool calls. They can also be used to terminate agent loops early.
4
5use crate::{
6 completion::CompletionModel,
7 message::{Message, ToolChoice},
8 wasm_compat::{WasmCompatSend, WasmCompatSync},
9};
10
11/// Context passed to [`PromptHook::on_invalid_tool_call`] when the model emits a tool call
12/// that Rig would reject before normal tool-call hooks or execution.
13#[derive(Debug, Clone)]
14pub struct InvalidToolCallContext {
15 /// Tool name emitted by the model.
16 pub tool_name: String,
17 /// Provider-supplied tool call ID, when available.
18 pub tool_call_id: Option<String>,
19 /// Internal Rig call ID, when available.
20 pub internal_call_id: Option<String>,
21 /// JSON arguments emitted for the tool call, when available.
22 pub args: Option<String>,
23 /// Executable Rig tools advertised to the provider for this turn.
24 pub available_tools: Vec<String>,
25 /// Tools allowed by the active [`ToolChoice`] for this turn.
26 pub allowed_tools: Vec<String>,
27 /// Active tool choice for this turn.
28 pub tool_choice: Option<ToolChoice>,
29 /// Diagnostic chat history including the rejected model output when available.
30 pub chat_history: Vec<Message>,
31 /// Whether the rejected call came from the streaming path.
32 pub is_streaming: bool,
33}
34
/// How the agent loop should recover from an invalid tool call.
///
/// Returned by [`PromptHook::on_invalid_tool_call`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InvalidToolCallHookAction {
    /// Keep Rig's default fail-fast behavior.
    Fail,
    /// Re-run the model turn, sending `feedback` as corrective guidance.
    Retry { feedback: String },
    /// Replace only the emitted tool name with `tool_name`.
    ///
    /// The replacement is revalidated against the registered tools and the current
    /// `ToolChoice` before it is used.
    Repair { tool_name: String },
    /// Mark an invalid structured tool call as skipped, reporting `reason` as its synthetic
    /// tool result. The invalid tool is never executed.
    Skip { reason: String },
}

impl InvalidToolCallHookAction {
    /// Shorthand for [`InvalidToolCallHookAction::Fail`] (Rig's fail-fast default).
    pub fn fail() -> Self {
        Self::Fail
    }

    /// Builds a [`InvalidToolCallHookAction::Retry`] carrying the given corrective feedback.
    pub fn retry(feedback: impl Into<String>) -> Self {
        let feedback = feedback.into();
        Self::Retry { feedback }
    }

    /// Builds a [`InvalidToolCallHookAction::Repair`] that renames the call to `tool_name`.
    pub fn repair(tool_name: impl Into<String>) -> Self {
        let tool_name = tool_name.into();
        Self::Repair { tool_name }
    }

    /// Builds a [`InvalidToolCallHookAction::Skip`] whose synthetic tool result is `reason`.
    pub fn skip(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Skip { reason }
    }
}
77
78/// Trait for per-request hooks to observe tool call events.
79pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
80where
81 M: CompletionModel,
82{
83 /// Called before the prompt is sent to the model
84 fn on_completion_call(
85 &self,
86 _prompt: &Message,
87 _history: &[Message],
88 ) -> impl Future<Output = HookAction> + WasmCompatSend {
89 async { HookAction::cont() }
90 }
91
92 /// Called after the prompt is sent to the model and a response is received.
93 fn on_completion_response(
94 &self,
95 _prompt: &Message,
96 _response: &crate::completion::CompletionResponse<M::Response>,
97 ) -> impl Future<Output = HookAction> + WasmCompatSend {
98 async { HookAction::cont() }
99 }
100
101 /// Called when a model-emitted tool call is unknown or disallowed by the
102 /// current request's tool choice.
103 ///
104 /// The default behavior remains fail-fast. Override this method to opt into
105 /// retry, repair, or skip recovery for invalid tool calls.
106 fn on_invalid_tool_call(
107 &self,
108 _context: &InvalidToolCallContext,
109 ) -> impl Future<Output = InvalidToolCallHookAction> + WasmCompatSend {
110 async { InvalidToolCallHookAction::fail() }
111 }
112
113 /// Called before a tool is invoked.
114 ///
115 /// # Returns
116 /// - `ToolCallHookAction::Continue` - Allow tool execution to proceed
117 /// - `ToolCallHookAction::Skip { reason }` - Reject tool execution; `reason` will be returned to the LLM as the tool result
118 fn on_tool_call(
119 &self,
120 _tool_name: &str,
121 _tool_call_id: Option<String>,
122 _internal_call_id: &str,
123 _args: &str,
124 ) -> impl Future<Output = ToolCallHookAction> + WasmCompatSend {
125 async { ToolCallHookAction::cont() }
126 }
127
128 /// Called after a tool is invoked (and a result has been returned).
129 fn on_tool_result(
130 &self,
131 _tool_name: &str,
132 _tool_call_id: Option<String>,
133 _internal_call_id: &str,
134 _args: &str,
135 _result: &str,
136 ) -> impl Future<Output = HookAction> + WasmCompatSend {
137 async { HookAction::cont() }
138 }
139
140 /// Called when receiving a text delta (streaming responses only)
141 fn on_text_delta(
142 &self,
143 _text_delta: &str,
144 _aggregated_text: &str,
145 ) -> impl Future<Output = HookAction> + Send {
146 async { HookAction::cont() }
147 }
148
149 /// Called when receiving a tool call delta (streaming_responses_only).
150 /// `tool_name` is Some on the first delta for a tool call, None on subsequent deltas.
151 fn on_tool_call_delta(
152 &self,
153 _tool_call_id: &str,
154 _internal_call_id: &str,
155 _tool_name: Option<&str>,
156 _tool_call_delta: &str,
157 ) -> impl Future<Output = HookAction> + Send {
158 async { HookAction::cont() }
159 }
160
161 /// Called after the model provider has finished streaming a text response from their completion API to the client.
162 fn on_stream_completion_response_finish(
163 &self,
164 _prompt: &Message,
165 _response: &<M as CompletionModel>::StreamingResponse,
166 ) -> impl Future<Output = HookAction> + Send {
167 async { HookAction::cont() }
168 }
169}
170
171impl<M> PromptHook<M> for () where M: CompletionModel {}
172
/// Control-flow decision returned by [`PromptHook::on_tool_call`].
///
/// Compared with the general-purpose [`HookAction`], a tool-call hook can additionally skip
/// an individual tool execution for whatever reason it chooses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolCallHookAction {
    /// Run the tool as usual.
    Continue,
    /// Do not run the tool; `reason` is returned as the tool result instead.
    Skip { reason: String },
    /// Stop the agent loop early.
    Terminate { reason: String },
}

impl ToolCallHookAction {
    /// Shorthand for [`ToolCallHookAction::Continue`].
    pub fn cont() -> Self {
        Self::Continue
    }

    /// Builds a [`ToolCallHookAction::Skip`] carrying `reason`.
    pub fn skip(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Skip { reason }
    }

    /// Builds a [`ToolCallHookAction::Terminate`] carrying `reason`.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Terminate { reason }
    }
}
204
/// Control-flow decision returned by the general-purpose hook callbacks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Keep running the agent loop as normal.
    Continue,
    /// Stop the agent loop early.
    Terminate { reason: String },
}

impl HookAction {
    /// Shorthand for [`HookAction::Continue`].
    pub fn cont() -> Self {
        Self::Continue
    }

    /// Builds a [`HookAction::Terminate`] carrying `reason`.
    pub fn terminate(reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self::Terminate { reason }
    }
}
226}