Skip to main content

agent_sdk_tools/
hooks.rs

//! Agent lifecycle hooks for customization.
//!
//! Hooks allow you to intercept and customize agent behavior at key points:
//!
//! - [`AgentHooks::pre_tool_use`] - Control tool execution permissions
//! - [`AgentHooks::post_tool_use`] - React to tool completion
//! - [`AgentHooks::on_event`] - Log or process events
//! - [`AgentHooks::on_error`] - Handle errors and decide recovery
//! - [`AgentHooks::on_context_compact`] - Supply a custom summary when context is compacted
//! - [`AgentHooks::pre_llm_request`] - Input guardrail on outbound model requests
//! - [`AgentHooks::on_llm_response`] - Output guardrail on model responses
//!
//! # Built-in Implementations
//!
//! - [`DefaultHooks`] - Tier-based permissions (default)
//! - [`AllowAllHooks`] - Allow all tools without confirmation
//! - [`LoggingHooks`] - Debug logging of tool calls, agent events, and errors
15
16use agent_sdk_foundation::events::AgentEvent;
17use agent_sdk_foundation::llm;
18use agent_sdk_foundation::types::{ToolInvocation, ToolResult, ToolTier};
19use async_trait::async_trait;
20
/// Decision returned by [`AgentHooks::pre_tool_use`].
///
/// Every variant carries only plain data, so the enum derives
/// [`PartialEq`]/[`Eq`] and decisions can be compared directly instead of
/// pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ToolDecision {
    /// Allow the tool to execute.
    Allow,
    /// Block the tool execution; the string explains why.
    Block(String),
    /// Pause and ask the user to approve the call. The string is the
    /// message describing what needs approval (the default policy produces
    /// `"Confirm <tool_name>?"`).
    RequiresConfirmation(String),
}
32
33/// Decision returned by [`AgentHooks::pre_llm_request`] — an input guardrail
34/// that runs before the outbound [`llm::ChatRequest`] is sent to the provider.
35///
36/// This is the place for prompt-injection scrubbing, PII gating, or
37/// system-prompt policy enforcement.
38#[derive(Debug, Clone)]
39#[non_exhaustive]
40pub enum RequestDecision {
41    /// Send the request unchanged.
42    Proceed,
43    /// Send a modified request instead of the original.
44    Modify(Box<llm::ChatRequest>),
45    /// Refuse to call the model; the string explains why.
46    Block(String),
47}
48
/// Decision returned by [`AgentHooks::on_llm_response`]: an output guardrail
/// that runs after the provider responds but before the response is persisted
/// and surfaced.
///
/// This is the place for output moderation or secret-leakage detection.
///
/// Every variant carries only plain data, so the enum derives
/// [`PartialEq`]/[`Eq`] for direct comparison.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ResponseDecision {
    /// Accept the response as-is.
    Accept,
    /// Reject the response; the string explains why.
    Block(String),
    /// Reject the response and feed the string back to the model so it can
    /// retry on the next turn.
    RetryWithFeedback(String),
}
65
66/// Lifecycle hooks for the agent loop.
67/// Implement this trait to customize agent behavior.
68#[async_trait]
69pub trait AgentHooks: Send + Sync {
70    /// Called before a tool is executed.
71    ///
72    /// Receives a structured [`ToolInvocation`] that bundles tool identity,
73    /// tier, requested input, effective input, and listen-context — everything
74    /// a server-side policy engine needs for an allow / block / confirm decision.
75    ///
76    /// Return [`ToolDecision::Allow`] to proceed, [`ToolDecision::Block`] to
77    /// reject, or [`ToolDecision::RequiresConfirmation`] to yield for user
78    /// approval.
79    async fn pre_tool_use(&self, invocation: &ToolInvocation) -> ToolDecision {
80        match invocation.tier {
81            ToolTier::Observe => ToolDecision::Allow,
82            ToolTier::Confirm => {
83                ToolDecision::RequiresConfirmation(format!("Confirm {}?", invocation.tool_name))
84            }
85        }
86    }
87
88    /// Called after a tool completes execution.
89    async fn post_tool_use(&self, _tool_name: &str, _result: &ToolResult) {
90        // Default: no-op
91    }
92
93    /// Called when the agent emits an event.
94    /// Can be used for logging, metrics, or custom handling.
95    async fn on_event(&self, _event: &AgentEvent) {
96        // Default: no-op
97    }
98
99    /// Called when an error occurs.
100    /// Return true to attempt recovery, false to abort.
101    async fn on_error(&self, _error: &anyhow::Error) -> bool {
102        // Default: don't recover
103        false
104    }
105
106    /// Called when context is about to be compacted due to length.
107    /// Return a summary to use, or None to use default summarization.
108    async fn on_context_compact(&self, _messages: &[llm::Message]) -> Option<String> {
109        // Default: use built-in summarization
110        None
111    }
112
113    /// Input guardrail: called with the outbound [`llm::ChatRequest`] before it
114    /// is sent to the provider.
115    ///
116    /// Return [`RequestDecision::Proceed`] to send it unchanged,
117    /// [`RequestDecision::Modify`] to substitute a sanitized request, or
118    /// [`RequestDecision::Block`] to refuse the call (e.g. prompt-injection or
119    /// PII policy). The default proceeds unchanged.
120    async fn pre_llm_request(&self, _request: &llm::ChatRequest) -> RequestDecision {
121        RequestDecision::Proceed
122    }
123
124    /// Output guardrail: called with the provider's [`llm::ChatResponse`] before
125    /// it is persisted or surfaced.
126    ///
127    /// Return [`ResponseDecision::Accept`] to keep it,
128    /// [`ResponseDecision::Block`] to reject it, or
129    /// [`ResponseDecision::RetryWithFeedback`] to reject it and steer a retry
130    /// (e.g. output moderation or secret-leakage detection). The default
131    /// accepts.
132    async fn on_llm_response(&self, _response: &llm::ChatResponse) -> ResponseDecision {
133        ResponseDecision::Accept
134    }
135}
136
137/// Default hooks implementation that uses tier-based decisions
138#[derive(Clone, Copy, Default)]
139pub struct DefaultHooks;
140
141#[async_trait]
142impl AgentHooks for DefaultHooks {}
143
144/// Hooks that allow all tools without confirmation
145#[derive(Clone, Copy, Default)]
146pub struct AllowAllHooks;
147
148#[async_trait]
149impl AgentHooks for AllowAllHooks {
150    async fn pre_tool_use(&self, _invocation: &ToolInvocation) -> ToolDecision {
151        ToolDecision::Allow
152    }
153}
154
155/// Hooks that log all events (useful for debugging)
156#[derive(Clone, Copy, Default)]
157pub struct LoggingHooks;
158
159#[async_trait]
160impl AgentHooks for LoggingHooks {
161    async fn pre_tool_use(&self, invocation: &ToolInvocation) -> ToolDecision {
162        log::debug!(
163            "Pre-tool use tool={} input={:?} tier={:?}",
164            invocation.tool_name,
165            invocation.requested_input,
166            invocation.tier,
167        );
168        DefaultHooks.pre_tool_use(invocation).await
169    }
170
171    async fn post_tool_use(&self, tool_name: &str, result: &ToolResult) {
172        log::debug!(
173            "Post-tool use tool={tool_name} success={} duration_ms={:?}",
174            result.success,
175            result.duration_ms
176        );
177    }
178
179    async fn on_event(&self, event: &AgentEvent) {
180        log::debug!("Agent event {event:?}");
181    }
182
183    async fn on_error(&self, error: &anyhow::Error) -> bool {
184        log::error!("Agent error {error:?}");
185        false
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal invocation of a tool named `danger` at the given tier.
    fn invocation(tier: ToolTier) -> ToolInvocation {
        ToolInvocation {
            tool_call_id: "call_1".to_string(),
            tool_name: "danger".to_string(),
            display_name: "Danger".to_string(),
            tier,
            requested_input: json!({}),
            effective_input: json!({}),
            listen_context: None,
        }
    }

    #[tokio::test]
    async fn default_hooks_gate_confirm_tier() {
        // A Confirm-tier tool must yield for confirmation under the default
        // policy; side-effecting tools never auto-run.
        let decision = DefaultHooks
            .pre_tool_use(&invocation(ToolTier::Confirm))
            .await;
        match decision {
            ToolDecision::RequiresConfirmation(prompt) => assert!(
                prompt.contains("danger"),
                "confirmation prompt should name the tool, got {prompt:?}"
            ),
            other => panic!("Confirm tier must require confirmation, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn default_hooks_auto_allow_observe_tier() {
        let decision = DefaultHooks
            .pre_tool_use(&invocation(ToolTier::Observe))
            .await;
        assert!(
            matches!(decision, ToolDecision::Allow),
            "Observe tier may auto-run, got {decision:?}"
        );
    }

    #[tokio::test]
    async fn allow_all_hooks_allow_every_tier() {
        for tier in [ToolTier::Observe, ToolTier::Confirm] {
            let decision = AllowAllHooks.pre_tool_use(&invocation(tier)).await;
            assert!(
                matches!(decision, ToolDecision::Allow),
                "AllowAllHooks must allow every tier, got {decision:?}"
            );
        }
    }

    #[tokio::test]
    async fn logging_hooks_use_default_tool_policy() {
        // LoggingHooks only adds logging; its decisions must match DefaultHooks.
        assert!(matches!(
            LoggingHooks
                .pre_tool_use(&invocation(ToolTier::Observe))
                .await,
            ToolDecision::Allow
        ));
        assert!(matches!(
            LoggingHooks
                .pre_tool_use(&invocation(ToolTier::Confirm))
                .await,
            ToolDecision::RequiresConfirmation(_)
        ));
    }

    #[tokio::test]
    async fn built_in_hooks_never_recover_from_errors() {
        let error = anyhow::anyhow!("boom");
        assert!(!DefaultHooks.on_error(&error).await);
        assert!(!AllowAllHooks.on_error(&error).await);
        assert!(!LoggingHooks.on_error(&error).await);
    }

    #[tokio::test]
    async fn default_hooks_defer_context_compaction() {
        let messages = vec![llm::Message::user("hi")];
        assert!(DefaultHooks.on_context_compact(&messages).await.is_none());
    }

    #[tokio::test]
    async fn default_hooks_llm_guardrails_are_permissive_noops() {
        let request = llm::ChatRequest::new("sys", vec![llm::Message::user("hi")]);
        assert!(matches!(
            DefaultHooks.pre_llm_request(&request).await,
            RequestDecision::Proceed
        ));

        let response = llm::ChatResponse {
            id: "resp_1".to_string(),
            content: Vec::new(),
            model: "test-model".to_string(),
            stop_reason: None,
            usage: llm::Usage {
                input_tokens: 0,
                output_tokens: 0,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        };
        assert!(matches!(
            DefaultHooks.on_llm_response(&response).await,
            ResponseDecision::Accept
        ));
    }
}