Skip to main content

lean_ctx/
hook_handlers.rs

1use crate::compound_lexer;
2use crate::rewrite_registry;
3use std::io::Read;
4
5pub fn handle_rewrite() {
6    let binary = resolve_binary();
7    let mut input = String::new();
8    if std::io::stdin().read_to_string(&mut input).is_err() {
9        return;
10    }
11
12    let tool = extract_json_field(&input, "tool_name");
13    if !matches!(tool.as_deref(), Some("Bash" | "bash")) {
14        return;
15    }
16
17    let cmd = match extract_json_field(&input, "command") {
18        Some(c) => c,
19        None => return,
20    };
21
22    if cmd.starts_with("lean-ctx ") || cmd.starts_with(&format!("{binary} ")) {
23        return;
24    }
25
26    if let Some(rewritten) = build_rewrite_compound(&cmd, &binary) {
27        emit_rewrite(&rewritten);
28        return;
29    }
30
31    if is_rewritable(&cmd) {
32        let rewritten = wrap_single_command(&cmd, &binary);
33        emit_rewrite(&rewritten);
34    }
35}
36
/// Returns `true` when `cmd` is a command lean-ctx should wrap.
///
/// Thin local alias for `rewrite_registry::is_rewritable_command`; the
/// registry module decides which commands qualify.
fn is_rewritable(cmd: &str) -> bool {
    rewrite_registry::is_rewritable_command(cmd)
}
40
/// Wraps a single shell command as `<binary> -c "<cmd>"`.
///
/// `cmd` is placed inside one double-quoted shell word. In a POSIX shell, four
/// characters stay special inside double quotes: `\`, `"`, `$` and `` ` ``.
/// Each of them is backslash-escaped, so the outer shell hands `cmd` to the
/// binary byte-for-byte. Without this, the outer shell would expand `$VAR` and
/// `$(...)` and run backtick substitutions itself, including ones the user had
/// protected with single quotes. It would also double-evaluate the result.
///
/// NOTE(review): this assumes the binary re-runs the string through a shell,
/// so expansion now happens once, in that inner shell. Shell variables that an
/// earlier segment of a compound command set without `export` are therefore
/// not visible to the wrapped command.
fn wrap_single_command(cmd: &str, binary: &str) -> String {
    let mut quoted = String::with_capacity(cmd.len() + 2);
    for c in cmd.chars() {
        if matches!(c, '\\' | '"' | '$' | '`') {
            quoted.push('\\');
        }
        quoted.push(c);
    }
    format!("{binary} -c \"{quoted}\"")
}
45
46fn build_rewrite_compound(cmd: &str, binary: &str) -> Option<String> {
47    compound_lexer::rewrite_compound(cmd, |segment| {
48        if segment.starts_with("lean-ctx ") || segment.starts_with(&format!("{binary} ")) {
49            return None;
50        }
51        if is_rewritable(segment) {
52            Some(wrap_single_command(segment, binary))
53        } else {
54            None
55        }
56    })
57}
58
/// Prints the `PreToolUse` hook response that allows the tool call and
/// replaces its `command` with `rewritten`. No trailing newline is printed.
///
/// The command is embedded as a JSON string and escaped per RFC 8259.
/// Quotes and backslashes are backslash-escaped. Control characters
/// (U+0000 to U+001F) are written as escapes, because raw control characters
/// are not allowed inside JSON strings; this matters for multi-line commands.
fn emit_rewrite(rewritten: &str) {
    let json_escaped = escape_json_string(rewritten);
    print!(
        "{{\"hookSpecificOutput\":{{\"hookEventName\":\"PreToolUse\",\"permissionDecision\":\"allow\",\"updatedInput\":{{\"command\":\"{json_escaped}\"}}}}}}"
    );
}

/// Escapes `s` for use inside a JSON string literal, without surrounding quotes.
fn escape_json_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 8);
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Remaining C0 controls have no short escape; use \u00XX.
            c if u32::from(c) < 0x20 => out.push_str(&format!("\\u{:04x}", u32::from(c))),
            c => out.push(c),
        }
    }
    out
}
65
/// Redirect hook: intentionally a no-op, so native tool calls (Read, Grep,
/// ListFiles) pass through untouched.
///
/// Blocking those tools would break Edit, which requires a native Read, and
/// adds needless friction. The MCP instructions already steer the assistant
/// toward ctx_read / ctx_search / ctx_tree instead.
pub fn handle_redirect() {}
72
73/// Copilot-specific PreToolUse handler.
74/// VS Code Copilot Chat uses the same hook format as Claude Code.
75/// Tool names differ: "runInTerminal" / "editFile" instead of "Bash" / "Read".
76pub fn handle_copilot() {
77    let binary = resolve_binary();
78    let mut input = String::new();
79    if std::io::stdin().read_to_string(&mut input).is_err() {
80        return;
81    }
82
83    let tool = extract_json_field(&input, "tool_name");
84    let tool_name = match tool.as_deref() {
85        Some(name) => name,
86        None => return,
87    };
88
89    let is_shell_tool = matches!(
90        tool_name,
91        "Bash" | "bash" | "runInTerminal" | "run_in_terminal" | "terminal" | "shell"
92    );
93    if !is_shell_tool {
94        return;
95    }
96
97    let cmd = match extract_json_field(&input, "command") {
98        Some(c) => c,
99        None => return,
100    };
101
102    if cmd.starts_with("lean-ctx ") || cmd.starts_with(&format!("{binary} ")) {
103        return;
104    }
105
106    if let Some(rewritten) = build_rewrite_compound(&cmd, &binary) {
107        emit_rewrite(&rewritten);
108        return;
109    }
110
111    if is_rewritable(&cmd) {
112        let rewritten = wrap_single_command(&cmd, &binary);
113        emit_rewrite(&rewritten);
114    }
115}
116
117/// Inline rewrite: takes a command as CLI args, prints the rewritten command to stdout.
118/// Used by the OpenCode TS plugin where the command is passed as an argument,
119/// not via stdin JSON.
120pub fn handle_rewrite_inline() {
121    let binary = resolve_binary();
122    let args: Vec<String> = std::env::args().collect();
123    // args: [binary, "hook", "rewrite-inline", ...command parts]
124    if args.len() < 4 {
125        return;
126    }
127    let cmd = args[3..].join(" ");
128
129    if cmd.starts_with("lean-ctx ") || cmd.starts_with(&format!("{binary} ")) {
130        print!("{cmd}");
131        return;
132    }
133
134    if let Some(rewritten) = build_rewrite_compound(&cmd, &binary) {
135        print!("{rewritten}");
136        return;
137    }
138
139    if is_rewritable(&cmd) {
140        let rewritten = wrap_single_command(&cmd, &binary);
141        print!("{rewritten}");
142        return;
143    }
144
145    print!("{cmd}");
146}
147
/// Returns the path of the running executable as a string, converted lossily
/// if it is not valid UTF-8. Falls back to the bare name `lean-ctx` when the
/// path cannot be determined.
///
/// NOTE(review): callers in this file splice this value unquoted into shell
/// commands. An install path containing spaces or backslashes (e.g. on
/// Windows) could therefore be mangled by the shell. Confirm whether quoting
/// is needed.
fn resolve_binary() -> String {
    match std::env::current_exe() {
        Ok(path) => path.to_string_lossy().into_owned(),
        Err(_) => String::from("lean-ctx"),
    }
}
153
/// Extracts and decodes the string value of the key `field` from a JSON text.
///
/// This is a small scanner, not a full JSON parser. It finds the first
/// occurrence of the quoted key that is followed by `:` and a string value,
/// allowing optional JSON whitespace around the colon. Occurrences of the key
/// whose value is not a string (e.g. `null`) are skipped. Nesting is not
/// tracked, so a key with the same name inside a nested object also matches.
///
/// All JSON string escapes are decoded: `\"`, `\\`, `\/`, `\b`, `\f`, `\n`,
/// `\r`, `\t` and `\uXXXX` (including surrogate pairs). A multi-line command
/// therefore comes back with real newlines.
///
/// Returns `None` when no such key/string pair exists, the string is
/// unterminated, or it contains a malformed escape sequence.
fn extract_json_field(input: &str, field: &str) -> Option<String> {
    let key = format!("\"{field}\"");
    let mut search_from = 0;
    while let Some(offset) = input[search_from..].find(&key) {
        let after_key = search_from + offset + key.len();
        let string_body = input[after_key..]
            .trim_start_matches(is_json_whitespace)
            .strip_prefix(':')
            .map(|rest| rest.trim_start_matches(is_json_whitespace))
            .and_then(|rest| rest.strip_prefix('"'));
        if let Some(body) = string_body {
            return decode_json_string_body(body);
        }
        // Not followed by `: "..."`; keep looking further on.
        search_from = after_key;
    }
    None
}

/// True for the four whitespace characters JSON allows between tokens.
fn is_json_whitespace(c: char) -> bool {
    matches!(c, ' ' | '\t' | '\n' | '\r')
}

/// Decodes a JSON string body (the text right after the opening quote) up to
/// its closing quote. Returns `None` if it is unterminated or an escape is
/// malformed.
fn decode_json_string_body(body: &str) -> Option<String> {
    let mut out = String::with_capacity(body.len());
    let mut chars = body.chars();
    while let Some(c) = chars.next() {
        match c {
            '"' => return Some(out),
            '\\' => {
                let decoded = match chars.next()? {
                    '"' => '"',
                    '\\' => '\\',
                    '/' => '/',
                    'b' => '\u{8}',
                    'f' => '\u{c}',
                    'n' => '\n',
                    'r' => '\r',
                    't' => '\t',
                    'u' => decode_unicode_escape(&mut chars)?,
                    _ => return None,
                };
                out.push(decoded);
            }
            other => out.push(other),
        }
    }
    None
}

/// Decodes a `\uXXXX` escape whose leading `\u` was already consumed.
/// A high surrogate must be followed by a `\uXXXX` low surrogate; the pair is
/// combined into one char. Lone surrogates yield `None`.
fn decode_unicode_escape(chars: &mut std::str::Chars<'_>) -> Option<char> {
    let first = read_hex4(chars)?;
    let code = if (0xD800..0xDC00).contains(&first) {
        if chars.next()? != '\\' || chars.next()? != 'u' {
            return None;
        }
        let second = read_hex4(chars)?;
        if !(0xDC00..0xE000).contains(&second) {
            return None;
        }
        0x10000 + ((first - 0xD800) << 10) + (second - 0xDC00)
    } else {
        first
    };
    // Rejects lone low surrogates and other invalid scalar values.
    char::from_u32(code)
}

/// Reads exactly four hexadecimal digits and returns their value.
fn read_hex4(chars: &mut std::str::Chars<'_>) -> Option<u32> {
    (0..4).try_fold(0u32, |acc, _| Some(acc * 16 + chars.next()?.to_digit(16)?))
}
176
// Unit tests for the string helpers in this file. Several cases also go
// through `compound_lexer::rewrite_compound` and
// `rewrite_registry::is_rewritable_command`, so they double as integration
// checks for those modules. The stdin/stdout hook entry points are not
// exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    // Registry classification: build/VCS tools qualify, while `echo` and `cd`
    // do not.
    #[test]
    fn is_rewritable_basic() {
        assert!(is_rewritable("git status"));
        assert!(is_rewritable("cargo test --lib"));
        assert!(is_rewritable("npm run build"));
        assert!(!is_rewritable("echo hello"));
        assert!(!is_rewritable("cd src"));
    }

    // A single command is wrapped as `<binary> -c "<cmd>"`.
    #[test]
    fn wrap_single() {
        let r = wrap_single_command("git status", "lean-ctx");
        assert_eq!(r, r#"lean-ctx -c "git status""#);
    }

    // Embedded double quotes are backslash-escaped inside the wrapper.
    #[test]
    fn wrap_with_quotes() {
        let r = wrap_single_command(r#"curl -H "Auth" https://api.com"#, "lean-ctx");
        assert_eq!(r, r#"lean-ctx -c "curl -H \"Auth\" https://api.com""#);
    }

    // Only the rewritable segment of an `&&` chain is wrapped. Operators and
    // the other segments are kept as-is.
    #[test]
    fn compound_rewrite_and_chain() {
        let result = build_rewrite_compound("cd src && git status && echo done", "lean-ctx");
        assert_eq!(
            result,
            Some(r#"cd src && lean-ctx -c "git status" && echo done"#.into())
        );
    }

    // The producer side of a pipe is wrapped; the downstream command is not.
    #[test]
    fn compound_rewrite_pipe() {
        let result = build_rewrite_compound("git log --oneline | head -5", "lean-ctx");
        assert_eq!(
            result,
            Some(r#"lean-ctx -c "git log --oneline" | head -5"#.into())
        );
    }

    // With no rewritable segment there is no rewrite at all (`None`).
    #[test]
    fn compound_rewrite_no_match() {
        let result = build_rewrite_compound("cd src && echo done", "lean-ctx");
        assert_eq!(result, None);
    }

    // Every rewritable segment is wrapped independently.
    #[test]
    fn compound_rewrite_multiple_rewritable() {
        let result = build_rewrite_compound("git add . && cargo test && npm run lint", "lean-ctx");
        assert_eq!(
            result,
            Some(
                r#"lean-ctx -c "git add ." && lean-ctx -c "cargo test" && lean-ctx -c "npm run lint""#
                    .into()
            )
        );
    }

    // `;` comes back surrounded by single spaces, and single quotes inside a
    // segment are kept verbatim.
    #[test]
    fn compound_rewrite_semicolons() {
        let result = build_rewrite_compound("git add .; git commit -m 'fix'", "lean-ctx");
        assert_eq!(
            result,
            Some(r#"lean-ctx -c "git add ." ; lean-ctx -c "git commit -m 'fix'""#.into())
        );
    }

    // The non-rewritable fallback after `||` is kept as-is.
    #[test]
    fn compound_rewrite_or_chain() {
        let result = build_rewrite_compound("git pull || echo failed", "lean-ctx");
        assert_eq!(
            result,
            Some(r#"lean-ctx -c "git pull" || echo failed"#.into())
        );
    }

    // A segment that already starts with `lean-ctx ` is not wrapped again,
    // but the remaining rewritable segment still is.
    #[test]
    fn compound_skips_already_rewritten() {
        let result = build_rewrite_compound("lean-ctx -c git status && git diff", "lean-ctx");
        assert_eq!(
            result,
            Some(r#"lean-ctx -c git status && lean-ctx -c "git diff""#.into())
        );
    }

    // A plain single command is not treated as compound. The hook handlers
    // then fall back to `wrap_single_command`.
    #[test]
    fn single_command_not_compound() {
        let result = build_rewrite_compound("git status", "lean-ctx");
        assert_eq!(result, None);
    }

    // Basic extraction of string-valued fields.
    #[test]
    fn extract_field_works() {
        let input = r#"{"tool_name":"Bash","command":"git status"}"#;
        assert_eq!(
            extract_json_field(input, "tool_name"),
            Some("Bash".to_string())
        );
        assert_eq!(
            extract_json_field(input, "command"),
            Some("git status".to_string())
        );
    }

    // `\"` inside the JSON value decodes to a literal double quote.
    #[test]
    fn extract_field_handles_escaped_quotes() {
        let input = r#"{"tool_name":"Bash","command":"grep -r \"TODO\" src/"}"#;
        assert_eq!(
            extract_json_field(input, "command"),
            Some(r#"grep -r "TODO" src/"#.to_string())
        );
    }

    // `\\` followed by `\"` decodes to a literal backslash, then a quote.
    #[test]
    fn extract_field_handles_escaped_backslash() {
        let input = r#"{"tool_name":"Bash","command":"echo \\\"hello\\\""}"#;
        assert_eq!(
            extract_json_field(input, "command"),
            Some(r#"echo \"hello\""#.to_string())
        );
    }

    // A realistic curl command with a quoted header value.
    #[test]
    fn extract_field_handles_complex_curl() {
        let input = r#"{"tool_name":"Bash","command":"curl -H \"Authorization: Bearer token\" https://api.com"}"#;
        assert_eq!(
            extract_json_field(input, "command"),
            Some(r#"curl -H "Authorization: Bearer token" https://api.com"#.to_string())
        );
    }
}
310}