Skip to main content

halter_hooks/
engine.rs

1// pattern: Functional Core
2
3use std::collections::{BTreeMap, BTreeSet};
4use std::fmt;
5use std::path::PathBuf;
6use std::time::Duration;
7
8use anyhow::Context;
9use chrono::Utc;
10use halter_protocol::{HookHandlerType, HookRunStatus, HookRunSummary, PluginId};
11use serde_json::Value;
12
13use crate::config::{HookEventName, HookHandlerConfig as FileHookHandlerConfig, HooksFile};
14use crate::matcher::CompiledMatcher;
15use crate::merge::{HandlerPriority, HandlerPriorityGroup, HookMergedOutcome};
16use crate::sdk::{HookCallback, HookKind, RegisteredHook, RegisteredHookPriority};
17
/// Version number of the hook protocol implemented by this crate.
///
/// NOTE(review): nothing in this file reads it; presumably it is exchanged with hook handlers
/// or recorded alongside runs — confirm at the call sites before relying on it.
pub const HOOK_PROTOCOL_VERSION: u32 = 1;
19
20#[derive(Debug, Clone)]
21pub struct HookRegistrySource {
22    pub plugin_id: PluginId,
23    pub plugin_root: PathBuf,
24    pub source_path: PathBuf,
25    pub allowed_http_hosts: Vec<String>,
26    pub allowed_env_vars: Vec<String>,
27    pub file: HooksFile,
28}
29
30#[derive(Debug, Clone, Default)]
31pub struct Hooks {
32    handlers_by_event: BTreeMap<HookEventName, Vec<ConfiguredHandler>>,
33}
34
35impl Hooks {
36    #[must_use]
37    pub fn from_sources(sources: impl IntoIterator<Item = HookRegistrySource>) -> Self {
38        let mut handlers_by_event = BTreeMap::new();
39
40        for (plugin_index, source) in sources.into_iter().enumerate() {
41            for (event_index, (event_name, matcher_groups)) in source.file.hooks.iter().enumerate()
42            {
43                for (matcher_index, matcher_group) in matcher_groups.iter().enumerate() {
44                    for (hook_index, hook) in matcher_group.hooks.iter().enumerate() {
45                        let matcher = matcher_group.matcher.clone();
46                        let handler_id = format!(
47                            "{}:{}:{}:{}:{}",
48                            source.plugin_id,
49                            event_name.canonical_name(),
50                            event_index,
51                            matcher_index,
52                            hook_index
53                        );
54                        handlers_by_event
55                            .entry(*event_name)
56                            .or_insert_with(Vec::new)
57                            .push(ConfiguredHandler {
58                                handler_id,
59                                plugin_id: source.plugin_id.clone(),
60                                plugin_root: source.plugin_root.clone(),
61                                source_path: source.source_path.clone(),
62                                allowed_http_hosts: source.allowed_http_hosts.clone(),
63                                allowed_env_vars: source.allowed_env_vars.clone(),
64                                event_name: *event_name,
65                                handler_type: hook.handler_type,
66                                timeout: hook.timeout,
67                                status_message: hook.status_message.clone(),
68                                if_condition: hook.if_condition.clone(),
69                                once: hook.once,
70                                matcher,
71                                config: ConfiguredHandlerConfig::File(hook.config.clone()),
72                                priority: HandlerPriority {
73                                    group: HandlerPriorityGroup::PluginFiles,
74                                    plugin_load_order: plugin_index,
75                                    event_declaration_index: event_index,
76                                    matcher_group_index: matcher_index,
77                                    hook_index_within_group: hook_index,
78                                },
79                            });
80                    }
81                }
82            }
83        }
84
85        Self { handlers_by_event }
86    }
87
88    pub fn from_registered(
89        hooks: impl IntoIterator<Item = RegisteredHook>,
90    ) -> anyhow::Result<Self> {
91        let mut handlers_by_event = BTreeMap::new();
92
93        for (hook_index, registered) in hooks.into_iter().enumerate() {
94            let matcher = registered
95                .hook
96                .matcher
97                .as_deref()
98                .map(str::trim)
99                .filter(|value| !value.is_empty())
100                .map(|pattern| {
101                    if registered.hook.event.matcher_field().is_none() {
102                        anyhow::bail!(
103                            "hook event '{}' does not support matcher",
104                            registered.hook.event.canonical_name()
105                        );
106                    }
107                    Ok(CompiledMatcher::compile_regex(pattern)?)
108                })
109                .transpose()
110                .with_context(|| {
111                    format!(
112                        "failed to compile sdk hook matcher for plugin '{}' event '{}'",
113                        registered.plugin_id,
114                        registered.hook.event.canonical_name()
115                    )
116                })?;
117            let priority_group = match registered.priority {
118                RegisteredHookPriority::BeforePlugins => HandlerPriorityGroup::SdkBeforePlugins,
119                RegisteredHookPriority::AfterPlugins => HandlerPriorityGroup::SdkAfterPlugins,
120            };
121            let handler_type = registered.hook.kind.handler_type();
122            let config = match registered.hook.kind {
123                HookKind::Callback(callback) => ConfiguredHandlerConfig::Callback(callback),
124                HookKind::Function(factory) => ConfiguredHandlerConfig::Function(factory()),
125            };
126            handlers_by_event
127                .entry(registered.hook.event)
128                .or_insert_with(Vec::new)
129                .push(ConfiguredHandler {
130                    handler_id: format!(
131                        "{}:{}:sdk:{}",
132                        registered.plugin_id,
133                        registered.hook.event.canonical_name(),
134                        hook_index
135                    ),
136                    plugin_id: registered.plugin_id.clone(),
137                    plugin_root: registered.plugin_root.clone(),
138                    source_path: PathBuf::from(format!(
139                        "<sdk-hook:{}:{}>",
140                        registered.plugin_id, hook_index
141                    )),
142                    allowed_http_hosts: Vec::new(),
143                    allowed_env_vars: Vec::new(),
144                    event_name: registered.hook.event,
145                    handler_type,
146                    timeout: registered.hook.timeout,
147                    status_message: registered.hook.status_message.clone(),
148                    if_condition: registered.hook.if_condition.clone(),
149                    once: registered.hook.once,
150                    matcher,
151                    config,
152                    priority: HandlerPriority {
153                        group: priority_group,
154                        plugin_load_order: hook_index,
155                        event_declaration_index: 0,
156                        matcher_group_index: 0,
157                        hook_index_within_group: 0,
158                    },
159                });
160        }
161
162        Ok(Self { handlers_by_event })
163    }
164
165    #[must_use]
166    pub fn prepare(&self, request: HookDispatchRequest) -> PreparedHookDispatch {
167        Self::prepare_many([self], request)
168    }
169
170    #[must_use]
171    pub fn prepare_many<'a>(
172        hook_sets: impl IntoIterator<Item = &'a Hooks>,
173        request: HookDispatchRequest,
174    ) -> PreparedHookDispatch {
175        let mut matched_handlers = Vec::new();
176
177        for hooks in hook_sets {
178            for handler in hooks
179                .handlers_by_event
180                .get(&request.event_name)
181                .into_iter()
182                .flatten()
183            {
184                if handler.once && request.fired_hook_ids.contains(&handler.handler_id) {
185                    continue;
186                }
187                if !handler.matches(&request) {
188                    continue;
189                }
190
191                matched_handlers.push(handler.clone());
192            }
193        }
194
195        matched_handlers.sort_by(|left, right| left.priority.cmp(&right.priority));
196        let previews = matched_handlers.iter().map(build_preview_run).collect();
197
198        PreparedHookDispatch {
199            request,
200            previews,
201            matched_handlers,
202        }
203    }
204}
205
206#[derive(Debug, Clone)]
207pub struct HookDispatchRequest {
208    pub event_name: HookEventName,
209    pub matcher_value: Option<String>,
210    pub payload: Value,
211    pub fired_hook_ids: BTreeSet<String>,
212}
213
214#[derive(Debug, Clone)]
215pub struct PreparedHookDispatch {
216    request: HookDispatchRequest,
217    previews: Vec<HookRunSummary>,
218    matched_handlers: Vec<ConfiguredHandler>,
219}
220
221impl PreparedHookDispatch {
222    #[must_use]
223    pub fn request(&self) -> &HookDispatchRequest {
224        &self.request
225    }
226
227    #[must_use]
228    pub fn preview_runs(&self) -> &[HookRunSummary] {
229        &self.previews
230    }
231
232    #[must_use]
233    pub fn matched_handlers(&self) -> &[ConfiguredHandler] {
234        &self.matched_handlers
235    }
236}
237
238#[derive(Debug, Clone)]
239pub struct HookDispatchOutcome {
240    pub merged: HookMergedOutcome,
241    pub runs: Vec<HookRunSummary>,
242    pub fired_hook_ids: Vec<String>,
243}
244
245#[derive(Debug, Clone)]
246pub struct ConfiguredHandler {
247    pub handler_id: String,
248    pub plugin_id: PluginId,
249    pub plugin_root: PathBuf,
250    pub source_path: PathBuf,
251    pub allowed_http_hosts: Vec<String>,
252    pub allowed_env_vars: Vec<String>,
253    pub event_name: HookEventName,
254    pub handler_type: HookHandlerType,
255    pub timeout: Duration,
256    pub status_message: Option<String>,
257    pub if_condition: Option<String>,
258    pub once: bool,
259    pub matcher: Option<CompiledMatcher>,
260    pub config: ConfiguredHandlerConfig,
261    pub priority: HandlerPriority,
262}
263
264impl ConfiguredHandler {
265    /// Single-pass match: the regex matcher must hit (or be absent) and the
266    /// `if` expression must evaluate true (or be absent). Collapsed from two
267    /// methods to one chained expression (finding L19).
268    fn matches(&self, request: &HookDispatchRequest) -> bool {
269        let matcher_hit = match (&self.matcher, self.event_name.matcher_field()) {
270            (Some(matcher), Some(_)) => request
271                .matcher_value
272                .as_deref()
273                .is_some_and(|value| matcher.is_match(value)),
274            (Some(_), None) => false,
275            (None, _) => true,
276        };
277        matcher_hit
278            && self
279                .if_condition
280                .as_deref()
281                .is_none_or(|condition| matches_if_condition(condition, request))
282    }
283}
284
285#[derive(Clone)]
286pub enum ConfiguredHandlerConfig {
287    File(FileHookHandlerConfig),
288    Callback(HookCallback),
289    Function(HookCallback),
290}
291
292impl fmt::Debug for ConfiguredHandlerConfig {
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        match self {
295            Self::File(config) => f.debug_tuple("File").field(config).finish(),
296            Self::Callback(_) => f.write_str("Callback(..)"),
297            Self::Function(_) => f.write_str("Function(..)"),
298        }
299    }
300}
301
302fn build_preview_run(handler: &ConfiguredHandler) -> HookRunSummary {
303    let started_at = Utc::now();
304    HookRunSummary {
305        run_id: format!(
306            "{}:{}",
307            handler.handler_id,
308            started_at.timestamp_nanos_opt().unwrap_or_default()
309        ),
310        event_name: handler.event_name.canonical_name().to_owned(),
311        handler_type: handler.handler_type,
312        plugin_id: handler.plugin_id.clone(),
313        plugin_root: handler.plugin_root.clone(),
314        status: HookRunStatus::Running,
315        status_message: handler.status_message.clone(),
316        started_at,
317        completed_at: None,
318        duration_ms: None,
319        entries: Vec::new(),
320    }
321}
322
323fn matches_if_condition(condition: &str, request: &HookDispatchRequest) -> bool {
324    let trimmed = condition.trim();
325    if trimmed.is_empty() || trimmed == "*" {
326        return true;
327    }
328
329    let Some(tool_name) = request.payload.get("tool_name").and_then(Value::as_str) else {
330        return false;
331    };
332
333    if let Some((tool_pattern, input_pattern)) = parse_if_condition(trimmed) {
334        if !matches_text_pattern(tool_pattern, tool_name) {
335            return false;
336        }
337
338        let input_text = request
339            .payload
340            .get("tool_input")
341            .and_then(render_if_input_text)
342            .unwrap_or_default();
343        return matches_text_pattern(input_pattern, &input_text);
344    }
345
346    matches_text_pattern(trimmed, tool_name)
347}
348
/// Splits a `Tool(input)` condition into its trimmed tool and input patterns.
///
/// The tool part is everything before the first `(`. The input part is everything between
/// that `(` and the final character, which must be `)`. Returns `None` when the condition has
/// no `(` or does not end with `)` (e.g. `"Shell(git *) trailing"`).
///
/// NOTE(review): a condition that starts with `(` yields an empty tool pattern, which
/// `matches_text_pattern` treats as match-all. So a bare regex group such as `"(Read|Write)"`
/// is read as "any tool whose input matches `Read|Write`" — confirm that is intended.
fn parse_if_condition(condition: &str) -> Option<(&str, &str)> {
    // Removing the closing paren first guarantees that any `(` found next lies strictly
    // before it, so no separate bounds check is needed.
    let inner = condition.strip_suffix(')')?;
    let (tool, input) = inner.split_once('(')?;
    Some((tool.trim(), input.trim()))
}
360
361fn render_if_input_text(value: &Value) -> Option<String> {
362    match value {
363        Value::Object(map) => map
364            .get("command")
365            .and_then(Value::as_str)
366            .map(ToOwned::to_owned)
367            .or_else(|| Some(Value::Object(map.clone()).to_string())),
368        Value::String(text) => Some(text.clone()),
369        Value::Null => None,
370        other => Some(other.to_string()),
371    }
372}
373
374fn matches_text_pattern(pattern: &str, candidate: &str) -> bool {
375    let pattern = pattern.trim();
376    if pattern.is_empty() || pattern == "*" {
377        return true;
378    }
379
380    // Runtime match for `if_condition` patterns. These aren't validated at
381    // config-load time, so an invalid pattern fails closed (no match).
382    match CompiledMatcher::compile(pattern) {
383        Ok(matcher) => matcher.is_match(candidate),
384        Err(_) => false,
385    }
386}
387
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::config::{HookHandler, HookMatcherGroup, HooksFile, PromptHookConfig};

    /// Handler body shared by the fixtures; these tests never execute it.
    fn noop_prompt() -> FileHookHandlerConfig {
        FileHookHandlerConfig::Prompt(PromptHookConfig {
            prompt: "noop".to_owned(),
            model: None,
        })
    }

    /// Request for `event_name` with no previously fired hooks.
    fn request_for(
        event_name: HookEventName,
        matcher_value: Option<&str>,
        payload: Value,
    ) -> HookDispatchRequest {
        HookDispatchRequest {
            event_name,
            matcher_value: matcher_value.map(str::to_owned),
            payload,
            fired_hook_ids: BTreeSet::new(),
        }
    }

    /// `PreToolUse` request for a `Shell` tool call running `git status`.
    fn shell_git_status() -> HookDispatchRequest {
        request_for(
            HookEventName::PreToolUse,
            Some("Shell"),
            json!({
                "tool_name": "Shell",
                "tool_input": { "command": "git status" },
            }),
        )
    }

    #[test]
    fn wildcard_match_supports_globs() {
        for (pattern, candidate, expected) in [
            ("git *", "git status", true),
            ("shell", "Shell", true),
            ("git *", "cargo test", false),
        ] {
            assert_eq!(matches_text_pattern(pattern, candidate), expected);
        }
    }

    /// AC3.5: an invalid matcher cannot reach the engine. `HooksFile::from_raw`
    /// rejects the config at load, so `Hooks::from_sources` never sees a raw
    /// string matcher. Defense-in-depth via the type system (H22/H27).
    #[test]
    fn review_hook_runtime_ac3_5_engine_never_sees_invalid_matcher() {
        let error = HooksFile::from_json_bytes(
            br#"{
                "hooks": {
                    "PreToolUse": [
                        {
                            "matcher": "(",
                            "hooks": [
                                {
                                    "type": "prompt",
                                    "prompt": "never reached"
                                }
                            ]
                        }
                    ]
                }
            }"#,
        )
        .expect_err("invalid matcher must hard-fail at load");

        let rendered = format!("{error:#}");
        let mentions_compile_error = rendered.contains("invalid matcher regex")
            || rendered.contains("invalid regex pattern");
        assert!(mentions_compile_error, "expected compile error, got: {rendered}",);
    }

    #[test]
    fn if_condition_matches_tool_name_and_command() {
        let handler = ConfiguredHandler {
            handler_id: "hook".to_owned(),
            plugin_id: PluginId::from("plugin"),
            plugin_root: PathBuf::from("/tmp/plugin"),
            source_path: PathBuf::from("/tmp/plugin/hooks.json"),
            allowed_http_hosts: Vec::new(),
            allowed_env_vars: Vec::new(),
            event_name: HookEventName::PreToolUse,
            handler_type: HookHandlerType::Prompt,
            timeout: Duration::from_secs(1),
            status_message: None,
            if_condition: Some("Shell(git *)".to_owned()),
            once: false,
            matcher: None,
            config: ConfiguredHandlerConfig::File(noop_prompt()),
            priority: HandlerPriority {
                group: HandlerPriorityGroup::PluginFiles,
                plugin_load_order: 0,
                event_declaration_index: 0,
                matcher_group_index: 0,
                hook_index_within_group: 0,
            },
        };

        assert!(handler.matches(&shell_git_status()));
    }

    #[test]
    fn if_condition_matches_regex_patterns_and_string_inputs() {
        let request = request_for(
            HookEventName::PreToolUse,
            Some("Read"),
            json!({
                "tool_name": "Read",
                "tool_input": "src/lib.rs",
            }),
        );

        assert!(matches_if_condition("^Read$", &request));
        assert!(matches_if_condition("Read(^src/.*\\.rs$)", &request));
        assert!(!matches_if_condition("Write(src/.*)", &request));
    }

    #[test]
    fn if_condition_rejects_non_tool_payloads_and_unbalanced_groups() {
        let request = request_for(
            HookEventName::Notification,
            None,
            json!({
                "message": "hello"
            }),
        );

        assert!(!matches_if_condition("Shell(git *)", &request));
        assert!(!matches_if_condition("Shell(", &request));
    }

    #[test]
    fn if_condition_rejects_trailing_text_after_group() {
        assert!(!matches_if_condition(
            "Shell(git *) trailing",
            &shell_git_status()
        ));
    }

    #[test]
    fn prepare_filters_once_handlers() {
        let once_handler = HookHandler {
            handler_type: HookHandlerType::Prompt,
            timeout: Duration::from_secs(1),
            status_message: None,
            if_condition: None,
            once: true,
            config: noop_prompt(),
        };
        let file = HooksFile {
            hooks: [(
                HookEventName::UserPromptSubmit,
                vec![HookMatcherGroup {
                    matcher: None,
                    hooks: vec![once_handler],
                }],
            )]
            .into_iter()
            .collect(),
        };
        let hooks = Hooks::from_sources([HookRegistrySource {
            plugin_id: PluginId::from("plugin"),
            plugin_root: PathBuf::from("/tmp/plugin"),
            source_path: PathBuf::from("/tmp/plugin/hooks.json"),
            allowed_http_hosts: Vec::new(),
            allowed_env_vars: Vec::new(),
            file,
        }]);

        // The id the single handler above receives; marking it fired must filter it out.
        let already_fired: BTreeSet<String> = ["plugin:UserPromptSubmit:0:0:0".to_owned()]
            .into_iter()
            .collect();
        let prepared = hooks.prepare(HookDispatchRequest {
            fired_hook_ids: already_fired,
            ..request_for(HookEventName::UserPromptSubmit, None, json!({}))
        });

        assert!(prepared.matched_handlers().is_empty());
    }
}