// sqz_engine/hook_manager.rs

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;

4/// The 5 hook types supported by the sqz hook system.
5///
6/// Requirements: 44.1
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub enum HookType {
9    /// Fires before a tool is executed. Can block, redirect, or inject context.
10    PreToolUse,
11    /// Fires after a tool completes. Captures structured events.
12    PostToolUse,
13    /// Fires before context compaction. Builds session snapshot.
14    PreCompact,
15    /// Fires on session start or resume. Restores from snapshot.
16    SessionStart,
17    /// Fires when the user submits a prompt. Captures decisions and corrections.
18    UserPromptSubmit,
19}
20
21impl HookType {
22    /// Returns all hook types in canonical order.
23    pub fn all() -> &'static [HookType] {
24        &[
25            HookType::PreToolUse,
26            HookType::PostToolUse,
27            HookType::PreCompact,
28            HookType::SessionStart,
29            HookType::UserPromptSubmit,
30        ]
31    }
32
33    /// Human-readable label for this hook type.
34    pub fn label(&self) -> &'static str {
35        match self {
36            HookType::PreToolUse => "pre_tool_use",
37            HookType::PostToolUse => "post_tool_use",
38            HookType::PreCompact => "pre_compact",
39            HookType::SessionStart => "session_start",
40            HookType::UserPromptSubmit => "user_prompt_submit",
41        }
42    }
43}
44
45/// Actions a hook can take when fired.
46///
47/// Requirements: 44.2, 44.3, 44.4, 44.5
48#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
49pub enum HookAction {
50    /// Allow the operation to proceed (no-op).
51    Allow,
52    /// Block the operation with a reason (PreToolUse).
53    Block { reason: String },
54    /// Redirect to a different tool (PreToolUse).
55    Redirect { to_tool: String },
56    /// Inject additional context into the operation (PreToolUse).
57    InjectContext { content: String },
58    /// Capture a structured event (PostToolUse).
59    CaptureEvent { event_type: String, data: String },
60    /// Build a session snapshot (PreCompact).
61    BuildSnapshot,
62    /// Restore session state from snapshot (SessionStart).
63    RestoreSnapshot,
64    /// Capture a user decision or correction (UserPromptSubmit).
65    CaptureDecision { decision: String },
66}
67
68/// A registered hook with its type, action, and optional filter.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Hook {
71    pub hook_type: HookType,
72    pub action: HookAction,
73    /// Optional filter — when set, the hook only fires if the context
74    /// matches this pattern (e.g. a tool name for PreToolUse).
75    pub filter: Option<String>,
76}
77
/// Context passed to hooks when they fire.
///
/// All fields start empty (`Default`), so callers can fill in only what they
/// know, e.g. `HookContext { tool_name: Some(..), ..Default::default() }`.
/// Derives `PartialEq`/`Eq` so contexts can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HookContext {
    /// The tool name (relevant for PreToolUse / PostToolUse); hook filters
    /// compare against it for exact equality.
    pub tool_name: Option<String>,
    /// The command being executed; hook filters match it by substring.
    pub command: Option<String>,
    /// Arbitrary key-value metadata. Not consulted by hook filtering.
    pub metadata: HashMap<String, String>,
}

89/// Manages hook registration and dispatch.
90///
91/// Requirements: 44.1–44.6
92pub struct HookManager {
93    hooks: HashMap<HookType, Vec<Hook>>,
94}
95
96impl HookManager {
97    pub fn new() -> Self {
98        Self {
99            hooks: HashMap::new(),
100        }
101    }
102
103    /// Register a hook.
104    pub fn register(&mut self, hook: Hook) {
105        self.hooks
106            .entry(hook.hook_type)
107            .or_default()
108            .push(hook);
109    }
110
111    /// Fire all hooks of the given type and return the first non-Allow action,
112    /// or `HookAction::Allow` if no hook matched.
113    pub fn fire(&self, hook_type: HookType, context: &HookContext) -> HookAction {
114        let Some(hooks) = self.hooks.get(&hook_type) else {
115            return HookAction::Allow;
116        };
117
118        for hook in hooks {
119            if let Some(ref filter) = hook.filter {
120                // Check filter against tool_name or command.
121                let matches = context
122                    .tool_name
123                    .as_deref()
124                    .map_or(false, |t| t == filter)
125                    || context
126                        .command
127                        .as_deref()
128                        .map_or(false, |c| c.contains(filter));
129                if !matches {
130                    continue;
131                }
132            }
133            // Return the first matching non-Allow action.
134            if hook.action != HookAction::Allow {
135                return hook.action.clone();
136            }
137        }
138
139        HookAction::Allow
140    }
141
142    /// Return all hooks registered for a given type.
143    pub fn hooks_for(&self, hook_type: HookType) -> &[Hook] {
144        self.hooks.get(&hook_type).map_or(&[], |v| v.as_slice())
145    }
146
147    /// Total number of registered hooks.
148    pub fn len(&self) -> usize {
149        self.hooks.values().map(|v| v.len()).sum()
150    }
151
152    pub fn is_empty(&self) -> bool {
153        self.len() == 0
154    }
155}
156
157impl Default for HookManager {
158    fn default() -> Self {
159        Self::new()
160    }
161}
162
163
// ── Platform config generation ────────────────────────────────────────────

/// Every agent identifier accepted by `sqz init --agent <platform>`.
///
/// Level 2 platforms (shell hook + MCP) come first, then Level 1 (MCP-only).
/// `known_platforms()` exposes the list in exactly this order.
const KNOWN_PLATFORMS: &[&str] = &[
    // Level 2: shell hook + MCP + hooks.
    "claude-code",
    "cursor",
    "kiro",
    "copilot",
    "windsurf",
    "cline",
    "gemini-cli",
    "codex",
    "opencode",
    "goose",
    "aider",
    "amp",
    // Level 1: MCP config only.
    "continue",
    "zed",
    "amazon-q",
];

185/// Generate a platform-specific hook configuration for `sqz init --agent <platform>`.
186///
187/// Returns a TOML string for Level 2 platforms (shell hook + MCP) and a JSON
188/// string for Level 1 platforms (MCP-only).
189///
190/// Requirements: 44.6
191pub fn generate_platform_config(platform: &str) -> Option<String> {
192    match platform {
193        // ── Level 1: MCP config only ──────────────────────────────────
194        "continue" | "zed" | "amazon-q" => Some(generate_level1_config(platform)),
195
196        // ── Level 2: Shell hook + MCP + hooks ─────────────────────────
197        "claude-code" | "cursor" | "kiro" | "copilot" | "windsurf" | "cline"
198        | "gemini-cli" | "codex" | "opencode" | "goose" | "aider" | "amp" => {
199            Some(generate_level2_config(platform))
200        }
201
202        _ => None,
203    }
204}
205
206/// Returns the list of known platform identifiers.
207pub fn known_platforms() -> &'static [&'static str] {
208    KNOWN_PLATFORMS
209}
210
/// Render the MCP-only (Level 1) JSON snippet for `platform`.
///
/// The snippet records where the platform keeps its settings in a `_path`
/// key; names without a known location fall back to a generic `mcp.json`.
fn generate_level1_config(platform: &str) -> String {
    const SETTINGS_FILES: [(&str, &str); 3] = [
        ("continue", "~/.continue/config.json"),
        ("zed", "~/.config/zed/settings.json"),
        ("amazon-q", "~/.aws/amazonq/mcp.json"),
    ];

    let settings_file = SETTINGS_FILES
        .iter()
        .find(|&&(name, _)| name == platform)
        .map_or("mcp.json", |&(_, path)| path);

    format!(
        r#"{{
  "_comment": "sqz MCP config for {name}",
  "_path": "{path}",
  "mcpServers": {{
    "sqz": {{
      "command": "sqz-mcp",
      "args": ["--transport", "stdio"],
      "env": {{}}
    }}
  }}
}}"#,
        name = platform,
        path = settings_file,
    )
}

/// Render the full (Level 2) TOML hook configuration for `platform`.
///
/// Every one of the five hook sections is enabled, and an `[mcp]` section
/// points at the platform's MCP settings file. Platforms without a known
/// location fall back to a generic `mcp.json`.
fn generate_level2_config(platform: &str) -> String {
    const MCP_PATHS: [(&str, &str); 6] = [
        ("claude-code", ".claude/mcp_servers.json"),
        ("cursor", "~/.cursor/mcp.json"),
        ("kiro", ".kiro/settings/mcp.json"),
        ("copilot", ".github/copilot/mcp.json"),
        ("windsurf", "~/.windsurf/mcp.json"),
        ("cline", "~/.cline/mcp.json"),
    ];

    let mcp_path = MCP_PATHS
        .iter()
        .find(|&&(name, _)| name == platform)
        .map_or("mcp.json", |&(_, path)| path);

    format!(
        r#"# sqz hook config for {name}
# MCP config path: {path}

[hooks.pre_tool_use]
enabled = true
block_dangerous = true
sandbox_redirect = ["shell", "bash", "exec"]
inject_context = true

[hooks.post_tool_use]
enabled = true
capture_events = ["file_edit", "git_op", "task_update", "error"]

[hooks.pre_compact]
enabled = true
build_snapshot = true

[hooks.session_start]
enabled = true
restore_snapshot = true

[hooks.user_prompt_submit]
enabled = true
capture_decisions = true
capture_corrections = true

[mcp]
command = "sqz-mcp"
args = ["--transport", "stdio"]
config_path = "{path}"
"#,
        name = platform,
        path = mcp_path,
    )
}

// ── Tests ─────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── HookType ──────────────────────────────────────────────────────

    #[test]
    fn test_hook_type_all_returns_5_variants() {
        assert_eq!(HookType::all().len(), 5);
    }

    #[test]
    fn test_hook_type_all_is_in_canonical_order() {
        assert_eq!(
            HookType::all(),
            &[
                HookType::PreToolUse,
                HookType::PostToolUse,
                HookType::PreCompact,
                HookType::SessionStart,
                HookType::UserPromptSubmit,
            ]
        );
    }

    #[test]
    fn test_hook_type_labels_are_unique() {
        let labels: Vec<&str> = HookType::all().iter().map(|h| h.label()).collect();
        let mut deduped = labels.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(labels.len(), deduped.len());
    }

    #[test]
    fn test_hook_type_labels_match_level2_config_sections() {
        // Each label doubles as a `[hooks.<label>]` section name in the config.
        let config = generate_platform_config("claude-code").unwrap();
        for hook_type in HookType::all() {
            let section = format!("[hooks.{}]", hook_type.label());
            assert!(config.contains(&section), "missing {section}");
        }
    }

    // ── HookManager basics ────────────────────────────────────────────

    #[test]
    fn test_new_manager_is_empty() {
        let mgr = HookManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
    }

    #[test]
    fn test_register_and_count() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "dangerous".into(),
            },
            filter: None,
        });
        assert_eq!(mgr.len(), 1);
        assert!(!mgr.is_empty());
    }

    #[test]
    fn test_hooks_for_returns_registered_hooks() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PostToolUse,
            action: HookAction::CaptureEvent {
                event_type: "file_edit".into(),
                data: "{}".into(),
            },
            filter: None,
        });
        assert_eq!(mgr.hooks_for(HookType::PostToolUse).len(), 1);
        assert_eq!(mgr.hooks_for(HookType::PreToolUse).len(), 0);
    }

    #[test]
    fn test_hooks_for_preserves_registration_order() {
        let mut mgr = HookManager::new();
        for reason in ["first", "second", "third"] {
            mgr.register(Hook {
                hook_type: HookType::PreToolUse,
                action: HookAction::Block {
                    reason: reason.into(),
                },
                filter: None,
            });
        }
        let reasons: Vec<&str> = mgr
            .hooks_for(HookType::PreToolUse)
            .iter()
            .map(|h| match &h.action {
                HookAction::Block { reason } => reason.as_str(),
                other => panic!("unexpected action {other:?}"),
            })
            .collect();
        assert_eq!(reasons, ["first", "second", "third"]);
    }

    // ── fire() dispatch ───────────────────────────────────────────────

    #[test]
    fn test_fire_returns_allow_when_no_hooks() {
        let mgr = HookManager::new();
        let ctx = HookContext::default();
        assert_eq!(mgr.fire(HookType::PreToolUse, &ctx), HookAction::Allow);
    }

    #[test]
    fn test_fire_returns_first_matching_action() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "blocked".into(),
            },
            filter: None,
        });
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Redirect {
                to_tool: "sandbox".into(),
            },
            filter: None,
        });

        let ctx = HookContext::default();
        // First non-Allow wins.
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &ctx),
            HookAction::Block {
                reason: "blocked".into()
            }
        );
    }

    #[test]
    fn test_fire_skips_allow_hooks() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Allow,
            filter: None,
        });
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "blocked".into(),
            },
            filter: None,
        });
        // An earlier Allow hook does not short-circuit later hooks.
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &HookContext::default()),
            HookAction::Block {
                reason: "blocked".into()
            }
        );
    }

    #[test]
    fn test_fire_only_allow_hooks_returns_allow() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Allow,
            filter: None,
        });
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &HookContext::default()),
            HookAction::Allow
        );
    }

    #[test]
    fn test_fire_ignores_hooks_of_other_types() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PostToolUse,
            action: HookAction::CaptureEvent {
                event_type: "file_edit".into(),
                data: "{}".into(),
            },
            filter: None,
        });
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &HookContext::default()),
            HookAction::Allow
        );
    }

    #[test]
    fn test_fire_with_filter_matches_tool_name() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Redirect {
                to_tool: "sandbox".into(),
            },
            filter: Some("exec_shell".into()),
        });

        // No match → Allow.
        let ctx_miss = HookContext {
            tool_name: Some("read_file".into()),
            ..Default::default()
        };
        assert_eq!(mgr.fire(HookType::PreToolUse, &ctx_miss), HookAction::Allow);

        // Match → Redirect.
        let ctx_hit = HookContext {
            tool_name: Some("exec_shell".into()),
            ..Default::default()
        };
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &ctx_hit),
            HookAction::Redirect {
                to_tool: "sandbox".into()
            }
        );
    }

    #[test]
    fn test_fire_filter_requires_exact_tool_name() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "exec blocked".into(),
            },
            filter: Some("exec".into()),
        });
        // Tool names are compared exactly, not by substring.
        let ctx = HookContext {
            tool_name: Some("exec_shell".into()),
            ..Default::default()
        };
        assert_eq!(mgr.fire(HookType::PreToolUse, &ctx), HookAction::Allow);
    }

    #[test]
    fn test_fire_filtered_hook_ignored_for_empty_context() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "shell blocked".into(),
            },
            filter: Some("exec_shell".into()),
        });
        // Neither tool name nor command is present, so the filter cannot match.
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &HookContext::default()),
            HookAction::Allow
        );
    }

    #[test]
    fn test_fire_with_filter_matches_command_substring() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "rm blocked".into(),
            },
            filter: Some("rm -rf".into()),
        });

        let ctx = HookContext {
            command: Some("rm -rf /tmp/stuff".into()),
            ..Default::default()
        };
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &ctx),
            HookAction::Block {
                reason: "rm blocked".into()
            }
        );
    }

    // ── PreToolUse actions ────────────────────────────────────────────

    #[test]
    fn test_pre_tool_use_block() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "dangerous command".into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::PreToolUse, &HookContext::default());
        assert!(matches!(action, HookAction::Block { .. }));
    }

    #[test]
    fn test_pre_tool_use_redirect() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Redirect {
                to_tool: "sandbox_exec".into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::PreToolUse, &HookContext::default());
        assert!(matches!(action, HookAction::Redirect { .. }));
    }

    #[test]
    fn test_pre_tool_use_inject_context() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::InjectContext {
                content: "extra context".into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::PreToolUse, &HookContext::default());
        assert!(matches!(action, HookAction::InjectContext { .. }));
    }

    // ── PostToolUse capture ───────────────────────────────────────────

    #[test]
    fn test_post_tool_use_capture_event() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PostToolUse,
            action: HookAction::CaptureEvent {
                event_type: "file_edit".into(),
                data: r#"{"path":"src/main.rs"}"#.into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::PostToolUse, &HookContext::default());
        assert!(matches!(action, HookAction::CaptureEvent { .. }));
    }

    // ── PreCompact snapshot ───────────────────────────────────────────

    #[test]
    fn test_pre_compact_build_snapshot() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreCompact,
            action: HookAction::BuildSnapshot,
            filter: None,
        });
        let action = mgr.fire(HookType::PreCompact, &HookContext::default());
        assert_eq!(action, HookAction::BuildSnapshot);
    }

    // ── SessionStart restore ──────────────────────────────────────────

    #[test]
    fn test_session_start_restore_snapshot() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::SessionStart,
            action: HookAction::RestoreSnapshot,
            filter: None,
        });
        let action = mgr.fire(HookType::SessionStart, &HookContext::default());
        assert_eq!(action, HookAction::RestoreSnapshot);
    }

    // ── UserPromptSubmit capture ──────────────────────────────────────

    #[test]
    fn test_user_prompt_submit_capture_decision() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::UserPromptSubmit,
            action: HookAction::CaptureDecision {
                decision: "use async/await".into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::UserPromptSubmit, &HookContext::default());
        assert!(matches!(action, HookAction::CaptureDecision { .. }));
    }

    // ── Platform config generation ────────────────────────────────────

    #[test]
    fn test_generate_config_unknown_platform_returns_none() {
        assert!(generate_platform_config("unknown-platform").is_none());
    }

    #[test]
    fn test_generate_config_level1_platforms_produce_json() {
        for platform in &["continue", "zed", "amazon-q"] {
            let config = generate_platform_config(platform).unwrap();
            assert!(config.contains("mcpServers"), "missing mcpServers for {platform}");
            assert!(config.contains("sqz-mcp"), "missing sqz-mcp for {platform}");
        }
    }

    #[test]
    fn test_generate_config_level1_has_platform_path() {
        let config = generate_platform_config("zed").unwrap();
        assert!(config.contains("~/.config/zed/settings.json"));
        assert!(config.contains("sqz MCP config for zed"));
    }

    #[test]
    fn test_generate_config_level2_platforms_produce_toml() {
        for platform in &[
            "claude-code", "cursor", "kiro", "copilot", "windsurf", "cline",
            "gemini-cli", "codex", "opencode", "goose", "aider", "amp",
        ] {
            let config = generate_platform_config(platform).unwrap();
            assert!(
                config.contains("[hooks.pre_tool_use]"),
                "missing pre_tool_use section for {platform}"
            );
            assert!(
                config.contains("[hooks.session_start]"),
                "missing session_start section for {platform}"
            );
            assert!(
                config.contains("sqz-mcp"),
                "missing sqz-mcp for {platform}"
            );
        }
    }

    #[test]
    fn test_generate_config_claude_code_has_correct_path() {
        let config = generate_platform_config("claude-code").unwrap();
        assert!(config.contains(".claude/mcp_servers.json"));
    }

    #[test]
    fn test_generate_config_kiro_has_correct_path() {
        let config = generate_platform_config("kiro").unwrap();
        assert!(config.contains(".kiro/settings/mcp.json"));
    }

    #[test]
    fn test_generate_config_cursor_has_correct_path() {
        let config = generate_platform_config("cursor").unwrap();
        assert!(config.contains("~/.cursor/mcp.json"));
    }

    #[test]
    fn test_generate_config_level2_unmapped_platform_falls_back_to_mcp_json() {
        let config = generate_platform_config("amp").unwrap();
        assert!(config.contains("config_path = \"mcp.json\""));
    }

    #[test]
    fn test_known_platforms_covers_all() {
        assert_eq!(known_platforms().len(), 15);
        // Every known platform should produce a config.
        for p in known_platforms() {
            assert!(
                generate_platform_config(p).is_some(),
                "no config for known platform: {p}"
            );
        }
    }

    #[test]
    fn test_level2_config_contains_all_5_hook_sections() {
        let config = generate_platform_config("claude-code").unwrap();
        assert!(config.contains("[hooks.pre_tool_use]"));
        assert!(config.contains("[hooks.post_tool_use]"));
        assert!(config.contains("[hooks.pre_compact]"));
        assert!(config.contains("[hooks.session_start]"));
        assert!(config.contains("[hooks.user_prompt_submit]"));
    }

    // ── Multiple hooks per type ───────────────────────────────────────

    #[test]
    fn test_multiple_hooks_same_type_different_filters() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "shell blocked".into(),
            },
            filter: Some("exec_shell".into()),
        });
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Redirect {
                to_tool: "sandbox".into(),
            },
            filter: Some("run_code".into()),
        });

        assert_eq!(mgr.len(), 2);

        let ctx_shell = HookContext {
            tool_name: Some("exec_shell".into()),
            ..Default::default()
        };
        assert!(matches!(
            mgr.fire(HookType::PreToolUse, &ctx_shell),
            HookAction::Block { .. }
        ));

        let ctx_code = HookContext {
            tool_name: Some("run_code".into()),
            ..Default::default()
        };
        assert!(matches!(
            mgr.fire(HookType::PreToolUse, &ctx_code),
            HookAction::Redirect { .. }
        ));
    }
}