1use std::collections::HashMap;
2use std::process::Command;
3use std::str::FromStr;
4
5use anyhow::{Context, Result, bail};
6use serde_json::Value;
7
8use crate::provider::ToolDefinition;
9use crate::tools::Tool;
10
/// A lifecycle point at which hooks and extensions can be notified.
///
/// Each variant has a snake_case configuration name (see `Event::as_str` and
/// the `FromStr` impl), e.g. `BeforeToolCall` <-> `"before_tool_call"`.
/// The variants reported by `Event::is_blocking` are the ones whose hooks may
/// block or rewrite the pending action.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub enum Event {
    // Session lifecycle.
    SessionStart,
    SessionEnd,
    // Prompt handling.
    BeforePrompt,
    AfterPrompt,
    // Tool execution.
    BeforeToolCall,
    AfterToolCall,
    // Context compaction.
    BeforeCompact,
    AfterCompact,
    // Switching.
    ModelSwitch,
    AgentSwitch,
    // Errors, streaming and miscellaneous notifications.
    OnError,
    OnStreamStart,
    OnStreamEnd,
    OnResume,
    OnUserInput,
    OnToolError,
    BeforeExit,
    OnThinkingStart,
    OnThinkingEnd,
    OnTitleGenerated,
    BeforePermissionCheck,
    OnContextLoad,
}
40
41impl FromStr for Event {
42 type Err = ();
43
44 fn from_str(s: &str) -> Result<Self, Self::Err> {
45 match s {
46 "session_start" => Ok(Self::SessionStart),
47 "session_end" => Ok(Self::SessionEnd),
48 "before_prompt" => Ok(Self::BeforePrompt),
49 "after_prompt" => Ok(Self::AfterPrompt),
50 "before_tool_call" => Ok(Self::BeforeToolCall),
51 "after_tool_call" => Ok(Self::AfterToolCall),
52 "before_compact" => Ok(Self::BeforeCompact),
53 "after_compact" => Ok(Self::AfterCompact),
54 "model_switch" => Ok(Self::ModelSwitch),
55 "agent_switch" => Ok(Self::AgentSwitch),
56 "on_error" => Ok(Self::OnError),
57 "on_stream_start" => Ok(Self::OnStreamStart),
58 "on_stream_end" => Ok(Self::OnStreamEnd),
59 "on_resume" => Ok(Self::OnResume),
60 "on_user_input" => Ok(Self::OnUserInput),
61 "on_tool_error" => Ok(Self::OnToolError),
62 "before_exit" => Ok(Self::BeforeExit),
63 "on_thinking_start" => Ok(Self::OnThinkingStart),
64 "on_thinking_end" => Ok(Self::OnThinkingEnd),
65 "on_title_generated" => Ok(Self::OnTitleGenerated),
66 "before_permission_check" => Ok(Self::BeforePermissionCheck),
67 "on_context_load" => Ok(Self::OnContextLoad),
68 _ => Err(()),
69 }
70 }
71}
72
73impl Event {
74 pub fn as_str(&self) -> &'static str {
75 match self {
76 Self::SessionStart => "session_start",
77 Self::SessionEnd => "session_end",
78 Self::BeforePrompt => "before_prompt",
79 Self::AfterPrompt => "after_prompt",
80 Self::BeforeToolCall => "before_tool_call",
81 Self::AfterToolCall => "after_tool_call",
82 Self::BeforeCompact => "before_compact",
83 Self::AfterCompact => "after_compact",
84 Self::ModelSwitch => "model_switch",
85 Self::AgentSwitch => "agent_switch",
86 Self::OnError => "on_error",
87 Self::OnStreamStart => "on_stream_start",
88 Self::OnStreamEnd => "on_stream_end",
89 Self::OnResume => "on_resume",
90 Self::OnUserInput => "on_user_input",
91 Self::OnToolError => "on_tool_error",
92 Self::BeforeExit => "before_exit",
93 Self::OnThinkingStart => "on_thinking_start",
94 Self::OnThinkingEnd => "on_thinking_end",
95 Self::OnTitleGenerated => "on_title_generated",
96 Self::BeforePermissionCheck => "before_permission_check",
97 Self::OnContextLoad => "on_context_load",
98 }
99 }
100
101 pub fn is_blocking(&self) -> bool {
102 matches!(
103 self,
104 Self::BeforePrompt
105 | Self::BeforeToolCall
106 | Self::BeforeCompact
107 | Self::BeforePermissionCheck
108 )
109 }
110}
111
/// Data describing the current event, handed to hooks and extensions.
///
/// `Hook::execute` exports each field to the hook command as a `DOT_*`
/// environment variable. Optional fields are exported only when set.
#[derive(Debug, Clone)]
#[derive(Default)]
pub struct EventContext {
    /// Event name, exported as `DOT_EVENT`.
    pub event: String,
    /// Active model, exported as `DOT_MODEL`.
    pub model: String,
    /// Active provider, exported as `DOT_PROVIDER`.
    pub provider: String,
    /// Working directory, exported as `DOT_CWD`.
    pub cwd: String,
    /// Session identifier, exported as `DOT_SESSION_ID`.
    pub session_id: String,
    /// Tool being called, exported as `DOT_TOOL_NAME`.
    pub tool_name: Option<String>,
    /// Tool input, exported as `DOT_TOOL_INPUT`.
    pub tool_input: Option<String>,
    /// Tool output, exported as `DOT_TOOL_OUTPUT`.
    pub tool_output: Option<String>,
    /// Prompt text, exported as `DOT_PROMPT`.
    pub prompt: Option<String>,
    /// Error text, exported as `DOT_ERROR`.
    pub error: Option<String>,
    /// Session title, exported as `DOT_TITLE`.
    pub title: Option<String>,
    /// Agent name, exported as `DOT_AGENT`.
    pub agent_name: Option<String>,
}
131
/// Outcome of running a hook.
#[derive(Debug, Clone)]
pub enum HookResult {
    /// Proceed unchanged.
    Allow,
    /// Stop the pending action; carries a human-readable reason.
    Block(String),
    /// Proceed, replacing the payload with the carried data (a hook's stdout).
    Modify(String),
}
145
146#[derive(Debug, Clone)]
151pub struct Hook {
152 pub event: Event,
153 pub command: String,
154 pub timeout: u64,
155}
156
157impl Hook {
158 pub fn execute(&self, ctx: &EventContext) -> Result<HookResult> {
159 let mut cmd = Command::new("/bin/sh");
160 cmd.arg("-c").arg(&self.command);
161 cmd.env("DOT_EVENT", &ctx.event);
162 cmd.env("DOT_MODEL", &ctx.model);
163 cmd.env("DOT_PROVIDER", &ctx.provider);
164 cmd.env("DOT_CWD", &ctx.cwd);
165 cmd.env("DOT_SESSION_ID", &ctx.session_id);
166 if let Some(ref name) = ctx.tool_name {
167 cmd.env("DOT_TOOL_NAME", name);
168 }
169 if let Some(ref input) = ctx.tool_input {
170 cmd.env("DOT_TOOL_INPUT", input);
171 }
172 if let Some(ref output) = ctx.tool_output {
173 cmd.env("DOT_TOOL_OUTPUT", output);
174 }
175 if let Some(ref prompt) = ctx.prompt {
176 cmd.env("DOT_PROMPT", prompt);
177 }
178 if let Some(ref error) = ctx.error {
179 cmd.env("DOT_ERROR", error);
180 }
181 if let Some(ref title) = ctx.title {
182 cmd.env("DOT_TITLE", title);
183 }
184 if let Some(ref agent) = ctx.agent_name {
185 cmd.env("DOT_AGENT", agent);
186 }
187 let output = cmd
188 .output()
189 .with_context(|| format!("hook '{}' failed to execute", self.command))?;
190
191 if !output.status.success() {
192 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
193 return Ok(HookResult::Block(stderr));
194 }
195
196 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
197 if stdout.trim().is_empty() {
198 return Ok(HookResult::Allow);
199 }
200
201 if self.event.is_blocking() {
202 Ok(HookResult::Modify(stdout))
203 } else {
204 Ok(HookResult::Allow)
205 }
206 }
207}
208
209pub trait Extension: Send + Sync {
214 fn name(&self) -> &str;
215
216 fn description(&self) -> &str {
217 ""
218 }
219
220 fn tools(&self) -> Vec<Box<dyn Tool>> {
221 Vec::new()
222 }
223
224 fn tool_definitions(&self) -> Vec<ToolDefinition> {
225 Vec::new()
226 }
227
228 fn on_event(&self, _event: &Event, _ctx: &EventContext) -> Result<Option<String>> {
229 Ok(None)
230 }
231
232 fn on_tool_call(&self, _name: &str, _input: Value) -> Result<String> {
233 bail!("tool not implemented")
234 }
235}
236
237pub struct ScriptTool {
242 tool_name: String,
243 tool_description: String,
244 schema: Value,
245 command: String,
246 _timeout: u64,
247}
248
249impl ScriptTool {
250 pub fn new(
251 name: String,
252 description: String,
253 schema: Value,
254 command: String,
255 timeout: u64,
256 ) -> Self {
257 ScriptTool {
258 tool_name: name,
259 tool_description: description,
260 schema,
261 command,
262 _timeout: timeout,
263 }
264 }
265}
266
267impl Tool for ScriptTool {
268 fn name(&self) -> &str {
269 &self.tool_name
270 }
271
272 fn description(&self) -> &str {
273 &self.tool_description
274 }
275
276 fn input_schema(&self) -> Value {
277 self.schema.clone()
278 }
279
280 fn execute(&self, input: Value) -> Result<String> {
281 let input_json = serde_json::to_string(&input)?;
282 let mut cmd = Command::new("/bin/sh");
283 cmd.arg("-c").arg(&self.command);
284 cmd.env("DOT_TOOL_INPUT", &input_json);
285
286 if let Some(obj) = input.as_object() {
287 for (key, val) in obj {
288 let env_key = format!("DOT_ARG_{}", key.to_uppercase());
289 let env_val = match val {
290 Value::String(s) => s.clone(),
291 other => other.to_string(),
292 };
293 cmd.env(env_key, env_val);
294 }
295 }
296
297 let output = cmd
298 .output()
299 .with_context(|| format!("script tool '{}' failed", self.tool_name))?;
300
301 let stdout = String::from_utf8_lossy(&output.stdout);
302 let stderr = String::from_utf8_lossy(&output.stderr);
303
304 if !output.status.success() {
305 bail!(
306 "script tool '{}' exited with {}: {}",
307 self.tool_name,
308 output.status,
309 stderr
310 );
311 }
312
313 Ok(stdout.to_string())
314 }
315}
316
317pub struct HookRegistry {
322 hooks: HashMap<Event, Vec<Hook>>,
323}
324
325impl HookRegistry {
326 pub fn new() -> Self {
327 HookRegistry {
328 hooks: HashMap::new(),
329 }
330 }
331
332 pub fn register(&mut self, hook: Hook) {
333 self.hooks.entry(hook.event.clone()).or_default().push(hook);
334 }
335
336 pub fn emit(&self, event: &Event, ctx: &EventContext) {
338 if let Some(hooks) = self.hooks.get(event) {
339 for hook in hooks {
340 match hook.execute(ctx) {
341 Ok(_) => {}
342 Err(e) => {
343 tracing::warn!("hook for '{}' failed: {}", event.as_str(), e);
344 }
345 }
346 }
347 }
348 }
349
350 pub fn emit_blocking(&self, event: &Event, ctx: &EventContext) -> HookResult {
353 if let Some(hooks) = self.hooks.get(event) {
354 let mut last_modify: Option<String> = None;
355 for hook in hooks {
356 match hook.execute(ctx) {
357 Ok(HookResult::Block(reason)) => {
358 tracing::info!("hook blocked '{}': {}", event.as_str(), reason.trim());
359 return HookResult::Block(reason);
360 }
361 Ok(HookResult::Modify(data)) => {
362 last_modify = Some(data);
363 }
364 Ok(HookResult::Allow) => {}
365 Err(e) => {
366 tracing::warn!("hook for '{}' failed: {}", event.as_str(), e);
367 }
368 }
369 }
370 if let Some(data) = last_modify {
371 return HookResult::Modify(data);
372 }
373 }
374 HookResult::Allow
375 }
376
377 pub fn has_hooks(&self, event: &Event) -> bool {
378 self.hooks.get(event).is_some_and(|h| !h.is_empty())
379 }
380}
381
382impl Default for HookRegistry {
383 fn default() -> Self {
384 Self::new()
385 }
386}
387
388pub struct ExtensionRegistry {
393 extensions: Vec<Box<dyn Extension>>,
394}
395
396impl ExtensionRegistry {
397 pub fn new() -> Self {
398 ExtensionRegistry {
399 extensions: Vec::new(),
400 }
401 }
402
403 pub fn register(&mut self, ext: Box<dyn Extension>) {
404 tracing::info!("Registered extension: {}", ext.name());
405 self.extensions.push(ext);
406 }
407
408 pub fn tools(&self) -> Vec<Box<dyn Tool>> {
409 self.extensions.iter().flat_map(|e| e.tools()).collect()
410 }
411
412 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
413 self.extensions
414 .iter()
415 .flat_map(|e| e.tool_definitions())
416 .collect()
417 }
418
419 pub fn emit(&self, event: &Event, ctx: &EventContext) {
420 for ext in &self.extensions {
421 if let Err(e) = ext.on_event(event, ctx) {
422 tracing::warn!(
423 "extension '{}' error on '{}': {}",
424 ext.name(),
425 event.as_str(),
426 e
427 );
428 }
429 }
430 }
431
432 pub fn handle_tool_call(&self, name: &str, input: Value) -> Option<Result<String>> {
433 for ext in &self.extensions {
434 let defs = ext.tool_definitions();
435 if defs.iter().any(|d| d.name == name) {
436 return Some(ext.on_tool_call(name, input));
437 }
438 }
439 None
440 }
441
442 pub fn is_empty(&self) -> bool {
443 self.extensions.is_empty()
444 }
445}
446
447impl Default for ExtensionRegistry {
448 fn default() -> Self {
449 Self::new()
450 }
451}