Skip to main content

dot/
extension.rs

1use std::collections::HashMap;
2use std::process::Command;
3use std::str::FromStr;
4
5use anyhow::{Context, Result, bail};
6use serde_json::Value;
7
8use crate::provider::ToolDefinition;
9use crate::tools::Tool;
10
11// ============================================================================
12// Lifecycle Events — mirrors pi's 30+ event system via config hooks
13// ============================================================================
14
/// A lifecycle event that config-defined hooks and compiled-in extensions can
/// subscribe to.
///
/// Every variant has a stable `snake_case` name, returned by [`Event::as_str`].
/// Parsing via [`FromStr`] accepts exactly those names (case-sensitive), so the
/// two directions are guaranteed to agree.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Event {
    /// `session_start`
    SessionStart,
    /// `session_end`
    SessionEnd,
    /// `before_prompt` (blocking)
    BeforePrompt,
    /// `after_prompt`
    AfterPrompt,
    /// `before_tool_call` (blocking)
    BeforeToolCall,
    /// `after_tool_call`
    AfterToolCall,
    /// `before_compact` (blocking)
    BeforeCompact,
    /// `after_compact`
    AfterCompact,
    /// `model_switch`
    ModelSwitch,
    /// `agent_switch`
    AgentSwitch,
    /// `on_error`
    OnError,
    /// `on_stream_start`
    OnStreamStart,
    /// `on_stream_end`
    OnStreamEnd,
    /// `on_resume`
    OnResume,
    /// `on_user_input`
    OnUserInput,
    /// `on_tool_error`
    OnToolError,
    /// `before_exit` (not blocking, despite the `before_` prefix)
    BeforeExit,
    /// `on_thinking_start`
    OnThinkingStart,
    /// `on_thinking_end`
    OnThinkingEnd,
    /// `on_title_generated`
    OnTitleGenerated,
    /// `before_permission_check` (blocking)
    BeforePermissionCheck,
    /// `on_context_load`
    OnContextLoad,
}

impl FromStr for Event {
    type Err = ();

    /// Parses an event from its `snake_case` config name.
    ///
    /// Driven by [`Event::ALL`] and [`Event::as_str`], so there is a single
    /// name table and parsing can never drift from `as_str`.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when `s` is not exactly one of the names returned by
    /// [`Event::as_str`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::ALL
            .iter()
            .find(|event| event.as_str() == s)
            .cloned()
            .ok_or(())
    }
}

impl Event {
    /// Every event, in declaration order.
    ///
    /// This is the table [`FromStr`] parsing is driven by: a new variant must
    /// be added here (and to [`Event::as_str`]) to become parseable from config.
    pub const ALL: [Event; 22] = [
        Event::SessionStart,
        Event::SessionEnd,
        Event::BeforePrompt,
        Event::AfterPrompt,
        Event::BeforeToolCall,
        Event::AfterToolCall,
        Event::BeforeCompact,
        Event::AfterCompact,
        Event::ModelSwitch,
        Event::AgentSwitch,
        Event::OnError,
        Event::OnStreamStart,
        Event::OnStreamEnd,
        Event::OnResume,
        Event::OnUserInput,
        Event::OnToolError,
        Event::BeforeExit,
        Event::OnThinkingStart,
        Event::OnThinkingEnd,
        Event::OnTitleGenerated,
        Event::BeforePermissionCheck,
        Event::OnContextLoad,
    ];

    /// Returns the event's stable `snake_case` config name.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::SessionStart => "session_start",
            Self::SessionEnd => "session_end",
            Self::BeforePrompt => "before_prompt",
            Self::AfterPrompt => "after_prompt",
            Self::BeforeToolCall => "before_tool_call",
            Self::AfterToolCall => "after_tool_call",
            Self::BeforeCompact => "before_compact",
            Self::AfterCompact => "after_compact",
            Self::ModelSwitch => "model_switch",
            Self::AgentSwitch => "agent_switch",
            Self::OnError => "on_error",
            Self::OnStreamStart => "on_stream_start",
            Self::OnStreamEnd => "on_stream_end",
            Self::OnResume => "on_resume",
            Self::OnUserInput => "on_user_input",
            Self::OnToolError => "on_tool_error",
            Self::BeforeExit => "before_exit",
            Self::OnThinkingStart => "on_thinking_start",
            Self::OnThinkingEnd => "on_thinking_end",
            Self::OnTitleGenerated => "on_title_generated",
            Self::BeforePermissionCheck => "before_permission_check",
            Self::OnContextLoad => "on_context_load",
        }
    }

    /// Whether hooks on this event may alter the pending action.
    ///
    /// `Hook::execute` only yields `HookResult::Modify` for events where this
    /// returns `true`. Note that [`Event::BeforeExit`] is deliberately absent
    /// even though its name starts with `before_`.
    pub fn is_blocking(&self) -> bool {
        matches!(
            self,
            Self::BeforePrompt
                | Self::BeforeToolCall
                | Self::BeforeCompact
                | Self::BeforePermissionCheck
        )
    }
}
111
112// ============================================================================
113// Event Context — data passed to hook handlers
114// ============================================================================
115
/// Snapshot of the circumstances surrounding an [`Event`], passed to hooks and
/// extensions.
///
/// `Hook::execute` hands each field to the hook process as a `DOT_*`
/// environment variable; the `Option` fields are only exported when `Some`.
#[derive(Clone, Debug, Default)]
pub struct EventContext {
    /// Event name; exported as `DOT_EVENT`.
    pub event: String,
    /// Exported as `DOT_MODEL`.
    pub model: String,
    /// Exported as `DOT_PROVIDER`.
    pub provider: String,
    /// Exported as `DOT_CWD`.
    pub cwd: String,
    /// Exported as `DOT_SESSION_ID`.
    pub session_id: String,
    /// Exported as `DOT_TOOL_NAME` when set.
    pub tool_name: Option<String>,
    /// Exported as `DOT_TOOL_INPUT` when set.
    pub tool_input: Option<String>,
    /// Exported as `DOT_TOOL_OUTPUT` when set.
    pub tool_output: Option<String>,
    /// Exported as `DOT_PROMPT` when set.
    pub prompt: Option<String>,
    /// Exported as `DOT_ERROR` when set.
    pub error: Option<String>,
    /// Exported as `DOT_TITLE` when set.
    pub title: Option<String>,
    /// Exported as `DOT_AGENT` when set.
    pub agent_name: Option<String>,
}
131
132// ============================================================================
133// HookResult — what a hook returns (allow, block, or modify)
134// ============================================================================
135
/// Outcome of running a single hook.
#[derive(Clone, Debug)]
pub enum HookResult {
    /// Carry on as normal; the hook neither objected nor changed anything.
    Allow,
    /// Stop the pending action; the payload explains why. Only honoured where
    /// the action can actually be stopped (`HookRegistry::emit_blocking`).
    Block(String),
    /// Replace the pending data with the payload (the hook's stdout).
    Modify(String),
}
145
146// ============================================================================
147// Hook — a shell command triggered on a lifecycle event
148// ============================================================================
149
150#[derive(Debug, Clone)]
151pub struct Hook {
152    pub event: Event,
153    pub command: String,
154    pub timeout: u64,
155}
156
157impl Hook {
158    pub fn execute(&self, ctx: &EventContext) -> Result<HookResult> {
159        let mut cmd = Command::new("/bin/sh");
160        cmd.arg("-c").arg(&self.command);
161        cmd.env("DOT_EVENT", &ctx.event);
162        cmd.env("DOT_MODEL", &ctx.model);
163        cmd.env("DOT_PROVIDER", &ctx.provider);
164        cmd.env("DOT_CWD", &ctx.cwd);
165        cmd.env("DOT_SESSION_ID", &ctx.session_id);
166        if let Some(ref name) = ctx.tool_name {
167            cmd.env("DOT_TOOL_NAME", name);
168        }
169        if let Some(ref input) = ctx.tool_input {
170            cmd.env("DOT_TOOL_INPUT", input);
171        }
172        if let Some(ref output) = ctx.tool_output {
173            cmd.env("DOT_TOOL_OUTPUT", output);
174        }
175        if let Some(ref prompt) = ctx.prompt {
176            cmd.env("DOT_PROMPT", prompt);
177        }
178        if let Some(ref error) = ctx.error {
179            cmd.env("DOT_ERROR", error);
180        }
181        if let Some(ref title) = ctx.title {
182            cmd.env("DOT_TITLE", title);
183        }
184        if let Some(ref agent) = ctx.agent_name {
185            cmd.env("DOT_AGENT", agent);
186        }
187        let output = cmd
188            .output()
189            .with_context(|| format!("hook '{}' failed to execute", self.command))?;
190
191        if !output.status.success() {
192            let stderr = String::from_utf8_lossy(&output.stderr).to_string();
193            return Ok(HookResult::Block(stderr));
194        }
195
196        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
197        if stdout.trim().is_empty() {
198            return Ok(HookResult::Allow);
199        }
200
201        if self.event.is_blocking() {
202            Ok(HookResult::Modify(stdout))
203        } else {
204            Ok(HookResult::Allow)
205        }
206    }
207}
208
209// ============================================================================
210// Extension Trait — for compiled-in Rust extensions
211// ============================================================================
212
213pub trait Extension: Send + Sync {
214    fn name(&self) -> &str;
215
216    fn description(&self) -> &str {
217        ""
218    }
219
220    fn tools(&self) -> Vec<Box<dyn Tool>> {
221        Vec::new()
222    }
223
224    fn tool_definitions(&self) -> Vec<ToolDefinition> {
225        Vec::new()
226    }
227
228    fn on_event(&self, _event: &Event, _ctx: &EventContext) -> Result<Option<String>> {
229        Ok(None)
230    }
231
232    fn on_tool_call(&self, _name: &str, _input: Value) -> Result<String> {
233        bail!("tool not implemented")
234    }
235}
236
237// ============================================================================
238// ScriptTool — a tool defined in config backed by a shell command
239// ============================================================================
240
241pub struct ScriptTool {
242    tool_name: String,
243    tool_description: String,
244    schema: Value,
245    command: String,
246    _timeout: u64,
247}
248
249impl ScriptTool {
250    pub fn new(
251        name: String,
252        description: String,
253        schema: Value,
254        command: String,
255        timeout: u64,
256    ) -> Self {
257        ScriptTool {
258            tool_name: name,
259            tool_description: description,
260            schema,
261            command,
262            _timeout: timeout,
263        }
264    }
265}
266
267impl Tool for ScriptTool {
268    fn name(&self) -> &str {
269        &self.tool_name
270    }
271
272    fn description(&self) -> &str {
273        &self.tool_description
274    }
275
276    fn input_schema(&self) -> Value {
277        self.schema.clone()
278    }
279
280    fn execute(&self, input: Value) -> Result<String> {
281        let input_json = serde_json::to_string(&input)?;
282        let mut cmd = Command::new("/bin/sh");
283        cmd.arg("-c").arg(&self.command);
284        cmd.env("DOT_TOOL_INPUT", &input_json);
285
286        if let Some(obj) = input.as_object() {
287            for (key, val) in obj {
288                let env_key = format!("DOT_ARG_{}", key.to_uppercase());
289                let env_val = match val {
290                    Value::String(s) => s.clone(),
291                    other => other.to_string(),
292                };
293                cmd.env(env_key, env_val);
294            }
295        }
296
297        let output = cmd
298            .output()
299            .with_context(|| format!("script tool '{}' failed", self.tool_name))?;
300
301        let stdout = String::from_utf8_lossy(&output.stdout);
302        let stderr = String::from_utf8_lossy(&output.stderr);
303
304        if !output.status.success() {
305            bail!(
306                "script tool '{}' exited with {}: {}",
307                self.tool_name,
308                output.status,
309                stderr
310            );
311        }
312
313        Ok(stdout.to_string())
314    }
315}
316
317// ============================================================================
318// HookRegistry — manages lifecycle hooks from config
319// ============================================================================
320
321pub struct HookRegistry {
322    hooks: HashMap<Event, Vec<Hook>>,
323}
324
325impl HookRegistry {
326    pub fn new() -> Self {
327        HookRegistry {
328            hooks: HashMap::new(),
329        }
330    }
331
332    pub fn register(&mut self, hook: Hook) {
333        self.hooks.entry(hook.event.clone()).or_default().push(hook);
334    }
335
336    /// Fire-and-forget emit for non-blocking events.
337    pub fn emit(&self, event: &Event, ctx: &EventContext) {
338        if let Some(hooks) = self.hooks.get(event) {
339            for hook in hooks {
340                match hook.execute(ctx) {
341                    Ok(_) => {}
342                    Err(e) => {
343                        tracing::warn!("hook for '{}' failed: {}", event.as_str(), e);
344                    }
345                }
346            }
347        }
348    }
349
350    /// Blocking emit for before_* events. Returns Block if any hook blocks,
351    /// Modify with the last modifier's output, or Allow.
352    pub fn emit_blocking(&self, event: &Event, ctx: &EventContext) -> HookResult {
353        if let Some(hooks) = self.hooks.get(event) {
354            let mut last_modify: Option<String> = None;
355            for hook in hooks {
356                match hook.execute(ctx) {
357                    Ok(HookResult::Block(reason)) => {
358                        tracing::info!("hook blocked '{}': {}", event.as_str(), reason.trim());
359                        return HookResult::Block(reason);
360                    }
361                    Ok(HookResult::Modify(data)) => {
362                        last_modify = Some(data);
363                    }
364                    Ok(HookResult::Allow) => {}
365                    Err(e) => {
366                        tracing::warn!("hook for '{}' failed: {}", event.as_str(), e);
367                    }
368                }
369            }
370            if let Some(data) = last_modify {
371                return HookResult::Modify(data);
372            }
373        }
374        HookResult::Allow
375    }
376
377    pub fn has_hooks(&self, event: &Event) -> bool {
378        self.hooks.get(event).is_some_and(|h| !h.is_empty())
379    }
380}
381
382impl Default for HookRegistry {
383    fn default() -> Self {
384        Self::new()
385    }
386}
387
388// ============================================================================
389// ExtensionRegistry — manages compiled extensions
390// ============================================================================
391
392pub struct ExtensionRegistry {
393    extensions: Vec<Box<dyn Extension>>,
394}
395
396impl ExtensionRegistry {
397    pub fn new() -> Self {
398        ExtensionRegistry {
399            extensions: Vec::new(),
400        }
401    }
402
403    pub fn register(&mut self, ext: Box<dyn Extension>) {
404        tracing::info!("Registered extension: {}", ext.name());
405        self.extensions.push(ext);
406    }
407
408    pub fn tools(&self) -> Vec<Box<dyn Tool>> {
409        self.extensions.iter().flat_map(|e| e.tools()).collect()
410    }
411
412    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
413        self.extensions
414            .iter()
415            .flat_map(|e| e.tool_definitions())
416            .collect()
417    }
418
419    pub fn emit(&self, event: &Event, ctx: &EventContext) {
420        for ext in &self.extensions {
421            if let Err(e) = ext.on_event(event, ctx) {
422                tracing::warn!(
423                    "extension '{}' error on '{}': {}",
424                    ext.name(),
425                    event.as_str(),
426                    e
427                );
428            }
429        }
430    }
431
432    pub fn handle_tool_call(&self, name: &str, input: Value) -> Option<Result<String>> {
433        for ext in &self.extensions {
434            let defs = ext.tool_definitions();
435            if defs.iter().any(|d| d.name == name) {
436                return Some(ext.on_tool_call(name, input));
437            }
438        }
439        None
440    }
441
442    pub fn is_empty(&self) -> bool {
443        self.extensions.is_empty()
444    }
445}
446
447impl Default for ExtensionRegistry {
448    fn default() -> Self {
449        Self::new()
450    }
451}