1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub enum HookType {
9 PreToolUse,
11 PostToolUse,
13 PreCompact,
15 SessionStart,
17 UserPromptSubmit,
19}
20
21impl HookType {
22 pub fn all() -> &'static [HookType] {
24 &[
25 HookType::PreToolUse,
26 HookType::PostToolUse,
27 HookType::PreCompact,
28 HookType::SessionStart,
29 HookType::UserPromptSubmit,
30 ]
31 }
32
33 pub fn label(&self) -> &'static str {
35 match self {
36 HookType::PreToolUse => "pre_tool_use",
37 HookType::PostToolUse => "post_tool_use",
38 HookType::PreCompact => "pre_compact",
39 HookType::SessionStart => "session_start",
40 HookType::UserPromptSubmit => "user_prompt_submit",
41 }
42 }
43}
44
45#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
49pub enum HookAction {
50 Allow,
52 Block { reason: String },
54 Redirect { to_tool: String },
56 InjectContext { content: String },
58 CaptureEvent { event_type: String, data: String },
60 BuildSnapshot,
62 RestoreSnapshot,
64 CaptureDecision { decision: String },
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Hook {
71 pub hook_type: HookType,
72 pub action: HookAction,
73 pub filter: Option<String>,
76}
77
/// Information about the current event, matched against each hook's filter.
///
/// All fields are optional/empty by default, so `HookContext::default()` is a
/// context that only unfiltered hooks apply to.
#[derive(Clone, Debug, Default)]
pub struct HookContext {
    /// Tool name; a hook filter matches it only by exact equality.
    pub tool_name: Option<String>,
    /// Command text; a hook filter matches it by substring.
    pub command: Option<String>,
    /// Free-form key/value data; not consulted by [`HookManager::fire`].
    pub metadata: HashMap<String, String>,
}
88
89pub struct HookManager {
93 hooks: HashMap<HookType, Vec<Hook>>,
94}
95
96impl HookManager {
97 pub fn new() -> Self {
98 Self {
99 hooks: HashMap::new(),
100 }
101 }
102
103 pub fn register(&mut self, hook: Hook) {
105 self.hooks
106 .entry(hook.hook_type)
107 .or_default()
108 .push(hook);
109 }
110
111 pub fn fire(&self, hook_type: HookType, context: &HookContext) -> HookAction {
114 let Some(hooks) = self.hooks.get(&hook_type) else {
115 return HookAction::Allow;
116 };
117
118 for hook in hooks {
119 if let Some(ref filter) = hook.filter {
120 let matches = context
122 .tool_name
123 .as_deref()
124 .map_or(false, |t| t == filter)
125 || context
126 .command
127 .as_deref()
128 .map_or(false, |c| c.contains(filter));
129 if !matches {
130 continue;
131 }
132 }
133 if hook.action != HookAction::Allow {
135 return hook.action.clone();
136 }
137 }
138
139 HookAction::Allow
140 }
141
142 pub fn hooks_for(&self, hook_type: HookType) -> &[Hook] {
144 self.hooks.get(&hook_type).map_or(&[], |v| v.as_slice())
145 }
146
147 pub fn len(&self) -> usize {
149 self.hooks.values().map(|v| v.len()).sum()
150 }
151
152 pub fn is_empty(&self) -> bool {
153 self.len() == 0
154 }
155}
156
157impl Default for HookManager {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
163
/// Every platform identifier accepted by [`generate_platform_config`], in the
/// order reported by [`known_platforms`].
const KNOWN_PLATFORMS: &[&str] = &[
    // Hook-capable platforms: receive the TOML hook config (level 2).
    "claude-code", "cursor", "kiro", "copilot", "windsurf", "cline",
    "gemini-cli", "codex", "opencode", "goose", "aider", "amp",
    // MCP-only platforms: receive the JSON server config (level 1).
    "continue", "zed", "amazon-q",
];
184
185pub fn generate_platform_config(platform: &str) -> Option<String> {
192 match platform {
193 "continue" | "zed" | "amazon-q" => Some(generate_level1_config(platform)),
195
196 "claude-code" | "cursor" | "kiro" | "copilot" | "windsurf" | "cline"
198 | "gemini-cli" | "codex" | "opencode" | "goose" | "aider" | "amp" => {
199 Some(generate_level2_config(platform))
200 }
201
202 _ => None,
203 }
204}
205
/// Returns every platform identifier for which [`generate_platform_config`]
/// produces a config, in a fixed order (hook-capable platforms first, then the
/// MCP-only ones).
pub fn known_platforms() -> &'static [&'static str] {
    KNOWN_PLATFORMS
}
210
/// Renders the JSON MCP-server config for a "level 1" platform
/// (`continue`, `zed`, `amazon-q`).
///
/// The platform's config file location is written to the `_path` key; any
/// unrecognised platform falls back to `mcp.json`. `platform` and the path are
/// interpolated without JSON escaping, which is fine here because the only
/// caller in this file passes fixed identifiers.
fn generate_level1_config(platform: &str) -> String {
    // (platform, config file location); unlisted platforms use the fallback.
    const CONFIG_PATHS: [(&str, &str); 3] = [
        ("continue", "~/.continue/config.json"),
        ("zed", "~/.config/zed/settings.json"),
        ("amazon-q", "~/.aws/amazonq/mcp.json"),
    ];

    let config_path = CONFIG_PATHS
        .iter()
        .find(|(name, _)| *name == platform)
        .map_or("mcp.json", |&(_, path)| path);

    format!(
        r#"{{
  "_comment": "sqz MCP config for {platform}",
  "_path": "{config_path}",
  "mcpServers": {{
    "sqz": {{
      "command": "sqz-mcp",
      "args": ["--transport", "stdio"],
      "env": {{}}
    }}
  }}
}}"#
    )
}
233
/// Renders the TOML hook config for a "level 2" (hook-capable) platform.
///
/// The output has one `[hooks.<label>]` section per hook type plus an `[mcp]`
/// section. The platform's MCP config location appears in the header comment
/// and in `mcp.config_path`. Platforms without a dedicated entry below
/// (e.g. `gemini-cli`, `codex`) fall back to `mcp.json`.
fn generate_level2_config(platform: &str) -> String {
    // (platform, MCP config file location); unlisted platforms use the fallback.
    const CONFIG_PATHS: [(&str, &str); 6] = [
        ("claude-code", ".claude/mcp_servers.json"),
        ("cursor", "~/.cursor/mcp.json"),
        ("kiro", ".kiro/settings/mcp.json"),
        ("copilot", ".github/copilot/mcp.json"),
        ("windsurf", "~/.windsurf/mcp.json"),
        ("cline", "~/.cline/mcp.json"),
    ];

    let config_path = CONFIG_PATHS
        .iter()
        .find(|(name, _)| *name == platform)
        .map_or("mcp.json", |&(_, path)| path);

    format!(
        r#"# sqz hook config for {platform}
# MCP config path: {config_path}

[hooks.pre_tool_use]
enabled = true
block_dangerous = true
sandbox_redirect = ["shell", "bash", "exec"]
inject_context = true

[hooks.post_tool_use]
enabled = true
capture_events = ["file_edit", "git_op", "task_update", "error"]

[hooks.pre_compact]
enabled = true
build_snapshot = true

[hooks.session_start]
enabled = true
restore_snapshot = true

[hooks.user_prompt_submit]
enabled = true
capture_decisions = true
capture_corrections = true

[mcp]
command = "sqz-mcp"
args = ["--transport", "stdio"]
config_path = "{config_path}"
"#
    )
}
279
#[cfg(test)]
mod tests {
    use super::*;

    // ── HookType ────────────────────────────────────────────────────────────

    #[test]
    fn test_hook_type_all_returns_5_variants() {
        assert_eq!(HookType::all().len(), 5);
    }

    #[test]
    fn test_hook_type_labels_are_unique() {
        let labels: Vec<&str> = HookType::all().iter().map(|h| h.label()).collect();
        let mut deduped = labels.clone();
        deduped.sort();
        deduped.dedup();
        assert_eq!(labels.len(), deduped.len());
    }

    #[test]
    fn test_hook_type_labels_match_expected_strings() {
        // The labels double as config section names, so pin them exactly.
        let labels: Vec<&str> = HookType::all().iter().map(|h| h.label()).collect();
        assert_eq!(
            labels,
            [
                "pre_tool_use",
                "post_tool_use",
                "pre_compact",
                "session_start",
                "user_prompt_submit",
            ]
        );
    }

    // ── HookManager: registration ───────────────────────────────────────────

    #[test]
    fn test_new_manager_is_empty() {
        let mgr = HookManager::new();
        assert!(mgr.is_empty());
        assert_eq!(mgr.len(), 0);
    }

    #[test]
    fn test_register_and_count() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "dangerous".into(),
            },
            filter: None,
        });
        assert_eq!(mgr.len(), 1);
        assert!(!mgr.is_empty());
    }

    #[test]
    fn test_hooks_for_returns_registered_hooks() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PostToolUse,
            action: HookAction::CaptureEvent {
                event_type: "file_edit".into(),
                data: "{}".into(),
            },
            filter: None,
        });
        assert_eq!(mgr.hooks_for(HookType::PostToolUse).len(), 1);
        assert_eq!(mgr.hooks_for(HookType::PreToolUse).len(), 0);
    }

    // ── HookManager: fire ───────────────────────────────────────────────────

    #[test]
    fn test_fire_returns_allow_when_no_hooks() {
        let mgr = HookManager::new();
        let ctx = HookContext::default();
        assert_eq!(mgr.fire(HookType::PreToolUse, &ctx), HookAction::Allow);
    }

    #[test]
    fn test_fire_returns_first_matching_action() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "blocked".into(),
            },
            filter: None,
        });
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Redirect {
                to_tool: "sandbox".into(),
            },
            filter: None,
        });

        let ctx = HookContext::default();
        // Registration order decides: the Block hook was registered first.
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &ctx),
            HookAction::Block {
                reason: "blocked".into()
            }
        );
    }

    #[test]
    fn test_fire_allow_hook_does_not_short_circuit() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Allow,
            filter: None,
        });
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "later".into(),
            },
            filter: None,
        });
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &HookContext::default()),
            HookAction::Block {
                reason: "later".into()
            }
        );
    }

    #[test]
    fn test_fire_ignores_hooks_of_other_types() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreCompact,
            action: HookAction::BuildSnapshot,
            filter: None,
        });
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &HookContext::default()),
            HookAction::Allow
        );
    }

    #[test]
    fn test_fire_with_filter_matches_tool_name() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Redirect {
                to_tool: "sandbox".into(),
            },
            filter: Some("exec_shell".into()),
        });

        // Different tool name: the filtered hook does not apply.
        let ctx_miss = HookContext {
            tool_name: Some("read_file".into()),
            ..Default::default()
        };
        assert_eq!(mgr.fire(HookType::PreToolUse, &ctx_miss), HookAction::Allow);

        // Exact tool name: the filtered hook applies.
        let ctx_hit = HookContext {
            tool_name: Some("exec_shell".into()),
            ..Default::default()
        };
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &ctx_hit),
            HookAction::Redirect {
                to_tool: "sandbox".into()
            }
        );
    }

    #[test]
    fn test_fire_tool_filter_requires_exact_match() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "exec".into(),
            },
            filter: Some("exec".into()),
        });
        // Tool names are compared for equality, not substring.
        let ctx = HookContext {
            tool_name: Some("exec_shell".into()),
            ..Default::default()
        };
        assert_eq!(mgr.fire(HookType::PreToolUse, &ctx), HookAction::Allow);
    }

    #[test]
    fn test_fire_with_filter_matches_command_substring() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "rm blocked".into(),
            },
            filter: Some("rm -rf".into()),
        });

        let ctx = HookContext {
            command: Some("rm -rf /tmp/stuff".into()),
            ..Default::default()
        };
        assert_eq!(
            mgr.fire(HookType::PreToolUse, &ctx),
            HookAction::Block {
                reason: "rm blocked".into()
            }
        );
    }

    // ── One test per hook type / action pairing ─────────────────────────────

    #[test]
    fn test_pre_tool_use_block() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "dangerous command".into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::PreToolUse, &HookContext::default());
        assert!(matches!(action, HookAction::Block { .. }));
    }

    #[test]
    fn test_pre_tool_use_redirect() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Redirect {
                to_tool: "sandbox_exec".into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::PreToolUse, &HookContext::default());
        assert!(matches!(action, HookAction::Redirect { .. }));
    }

    #[test]
    fn test_pre_tool_use_inject_context() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::InjectContext {
                content: "extra context".into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::PreToolUse, &HookContext::default());
        assert!(matches!(action, HookAction::InjectContext { .. }));
    }

    #[test]
    fn test_post_tool_use_capture_event() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PostToolUse,
            action: HookAction::CaptureEvent {
                event_type: "file_edit".into(),
                data: r#"{"path":"src/main.rs"}"#.into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::PostToolUse, &HookContext::default());
        assert!(matches!(action, HookAction::CaptureEvent { .. }));
    }

    #[test]
    fn test_pre_compact_build_snapshot() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreCompact,
            action: HookAction::BuildSnapshot,
            filter: None,
        });
        let action = mgr.fire(HookType::PreCompact, &HookContext::default());
        assert_eq!(action, HookAction::BuildSnapshot);
    }

    #[test]
    fn test_session_start_restore_snapshot() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::SessionStart,
            action: HookAction::RestoreSnapshot,
            filter: None,
        });
        let action = mgr.fire(HookType::SessionStart, &HookContext::default());
        assert_eq!(action, HookAction::RestoreSnapshot);
    }

    #[test]
    fn test_user_prompt_submit_capture_decision() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::UserPromptSubmit,
            action: HookAction::CaptureDecision {
                decision: "use async/await".into(),
            },
            filter: None,
        });
        let action = mgr.fire(HookType::UserPromptSubmit, &HookContext::default());
        assert!(matches!(action, HookAction::CaptureDecision { .. }));
    }

    // ── Platform config generation ──────────────────────────────────────────

    #[test]
    fn test_generate_config_unknown_platform_returns_none() {
        assert!(generate_platform_config("unknown-platform").is_none());
    }

    #[test]
    fn test_generate_config_level1_platforms_produce_json() {
        for platform in &["continue", "zed", "amazon-q"] {
            let config = generate_platform_config(platform).unwrap();
            assert!(config.contains("mcpServers"), "missing mcpServers for {platform}");
            assert!(config.contains("sqz-mcp"), "missing sqz-mcp for {platform}");
        }
    }

    #[test]
    fn test_generate_config_level1_paths() {
        for (platform, path) in [
            ("continue", "~/.continue/config.json"),
            ("zed", "~/.config/zed/settings.json"),
            ("amazon-q", "~/.aws/amazonq/mcp.json"),
        ] {
            let config = generate_platform_config(platform).unwrap();
            assert!(config.contains(path), "missing path {path} for {platform}");
        }
    }

    #[test]
    fn test_generate_config_level2_platforms_produce_toml() {
        for platform in &[
            "claude-code", "cursor", "kiro", "copilot", "windsurf", "cline",
            "gemini-cli", "codex", "opencode", "goose", "aider", "amp",
        ] {
            let config = generate_platform_config(platform).unwrap();
            assert!(
                config.contains("[hooks.pre_tool_use]"),
                "missing pre_tool_use section for {platform}"
            );
            assert!(
                config.contains("[hooks.session_start]"),
                "missing session_start section for {platform}"
            );
            assert!(
                config.contains("sqz-mcp"),
                "missing sqz-mcp for {platform}"
            );
        }
    }

    #[test]
    fn test_generate_config_claude_code_has_correct_path() {
        let config = generate_platform_config("claude-code").unwrap();
        assert!(config.contains(".claude/mcp_servers.json"));
    }

    #[test]
    fn test_generate_config_kiro_has_correct_path() {
        let config = generate_platform_config("kiro").unwrap();
        assert!(config.contains(".kiro/settings/mcp.json"));
    }

    #[test]
    fn test_generate_config_cursor_has_correct_path() {
        let config = generate_platform_config("cursor").unwrap();
        assert!(config.contains("~/.cursor/mcp.json"));
    }

    #[test]
    fn test_known_platforms_covers_all() {
        assert_eq!(known_platforms().len(), 15);
        // Every advertised platform must actually produce a config.
        for p in known_platforms() {
            assert!(
                generate_platform_config(p).is_some(),
                "no config for known platform: {p}"
            );
        }
    }

    #[test]
    fn test_known_platforms_are_unique() {
        let mut platforms = known_platforms().to_vec();
        platforms.sort_unstable();
        platforms.dedup();
        assert_eq!(platforms.len(), known_platforms().len());
    }

    #[test]
    fn test_level2_config_contains_all_5_hook_sections() {
        let config = generate_platform_config("claude-code").unwrap();
        assert!(config.contains("[hooks.pre_tool_use]"));
        assert!(config.contains("[hooks.post_tool_use]"));
        assert!(config.contains("[hooks.pre_compact]"));
        assert!(config.contains("[hooks.session_start]"));
        assert!(config.contains("[hooks.user_prompt_submit]"));
    }

    #[test]
    fn test_level2_config_has_section_for_every_hook_label() {
        // Keeps HookType::label and the generated section names in sync.
        let config = generate_platform_config("cursor").unwrap();
        for hook_type in HookType::all() {
            let header = format!("[hooks.{}]", hook_type.label());
            assert!(config.contains(&header), "missing {header}");
        }
    }

    // ── Mixed filters ───────────────────────────────────────────────────────

    #[test]
    fn test_multiple_hooks_same_type_different_filters() {
        let mut mgr = HookManager::new();
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Block {
                reason: "shell blocked".into(),
            },
            filter: Some("exec_shell".into()),
        });
        mgr.register(Hook {
            hook_type: HookType::PreToolUse,
            action: HookAction::Redirect {
                to_tool: "sandbox".into(),
            },
            filter: Some("run_code".into()),
        });

        assert_eq!(mgr.len(), 2);

        let ctx_shell = HookContext {
            tool_name: Some("exec_shell".into()),
            ..Default::default()
        };
        assert!(matches!(
            mgr.fire(HookType::PreToolUse, &ctx_shell),
            HookAction::Block { .. }
        ));

        let ctx_code = HookContext {
            tool_name: Some("run_code".into()),
            ..Default::default()
        };
        assert!(matches!(
            mgr.fire(HookType::PreToolUse, &ctx_code),
            HookAction::Redirect { .. }
        ));
    }
}