lean_ctx/
hook_handlers.rs1use crate::compound_lexer;
2use crate::rewrite_registry;
3use std::io::Read;
4
5pub fn handle_rewrite() {
6 let binary = resolve_binary();
7 let mut input = String::new();
8 if std::io::stdin().read_to_string(&mut input).is_err() {
9 return;
10 }
11
12 let tool = extract_json_field(&input, "tool_name");
13 if !matches!(tool.as_deref(), Some("Bash" | "bash")) {
14 return;
15 }
16
17 let cmd = match extract_json_field(&input, "command") {
18 Some(c) => c,
19 None => return,
20 };
21
22 if cmd.starts_with("lean-ctx ") || cmd.starts_with(&format!("{binary} ")) {
23 return;
24 }
25
26 if let Some(rewritten) = build_rewrite_compound(&cmd, &binary) {
27 emit_rewrite(&rewritten);
28 return;
29 }
30
31 if is_rewritable(&cmd) {
32 let rewritten = wrap_single_command(&cmd, &binary);
33 emit_rewrite(&rewritten);
34 }
35}
36
37fn is_rewritable(cmd: &str) -> bool {
38 rewrite_registry::is_rewritable_command(cmd)
39}
40
/// Wraps one shell command as `<binary> -c "<command>"`.
///
/// The command ends up inside a double-quoted shell word. In POSIX shells a
/// backslash inside double quotes escapes exactly `\`, `"`, `$`, backtick
/// and newline, and `$`/backtick still trigger expansion there. All four
/// printable specials are therefore backslash-escaped. Without escaping `$`
/// and backtick, the *outer* shell would expand variables and command
/// substitutions before the wrapper ever ran — including ones the user had
/// deliberately protected with single quotes (e.g. `echo '$HOME'`). With
/// this escaping, `<binary> -c` receives the original command text
/// unchanged.
///
/// Newlines need no escaping: inside double quotes they are kept literally.
/// NOTE(review): `binary` is inserted unquoted, so a binary path containing
/// spaces or shell metacharacters is not protected here.
fn wrap_single_command(cmd: &str, binary: &str) -> String {
    let mut escaped = String::with_capacity(cmd.len() + 8);
    for ch in cmd.chars() {
        if matches!(ch, '\\' | '"' | '$' | '`') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    format!("{binary} -c \"{escaped}\"")
}
45
46fn build_rewrite_compound(cmd: &str, binary: &str) -> Option<String> {
47 compound_lexer::rewrite_compound(cmd, |segment| {
48 if segment.starts_with("lean-ctx ") || segment.starts_with(&format!("{binary} ")) {
49 return None;
50 }
51 if is_rewritable(segment) {
52 Some(wrap_single_command(segment, binary))
53 } else {
54 None
55 }
56 })
57}
58
/// Prints (without a trailing newline) the `PreToolUse` hook response that
/// allows the call and replaces its command with `rewritten`.
fn emit_rewrite(rewritten: &str) {
    print!("{}", rewrite_response_json(rewritten));
}

/// Builds the JSON hook response for `rewritten`, escaping the command as a
/// JSON string.
///
/// Besides `"` and `\`, every control character U+0000–U+001F is escaped,
/// because JSON forbids these raw inside strings. A multi-line command would
/// otherwise yield an invalid response document.
fn rewrite_response_json(rewritten: &str) -> String {
    let mut escaped = String::with_capacity(rewritten.len() + 2);
    for ch in rewritten.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if c < '\u{20}' => escaped.push_str(&format!("\\u{:04x}", c as u32)),
            c => escaped.push(c),
        }
    }
    format!(
        "{{\"hookSpecificOutput\":{{\"hookEventName\":\"PreToolUse\",\"permissionDecision\":\"allow\",\"updatedInput\":{{\"command\":\"{escaped}\"}}}}}}"
    )
}
65
/// Hook entry point for redirect handling.
///
/// The body is currently empty: it reads nothing from stdin and prints
/// nothing. NOTE(review): confirm how the calling hook runner treats an
/// empty response before relying on this as an "allow" decision.
pub fn handle_redirect() {}
72
73pub fn handle_copilot() {
77 let binary = resolve_binary();
78 let mut input = String::new();
79 if std::io::stdin().read_to_string(&mut input).is_err() {
80 return;
81 }
82
83 let tool = extract_json_field(&input, "tool_name");
84 let tool_name = match tool.as_deref() {
85 Some(name) => name,
86 None => return,
87 };
88
89 let is_shell_tool = matches!(
90 tool_name,
91 "Bash" | "bash" | "runInTerminal" | "run_in_terminal" | "terminal" | "shell"
92 );
93 if !is_shell_tool {
94 return;
95 }
96
97 let cmd = match extract_json_field(&input, "command") {
98 Some(c) => c,
99 None => return,
100 };
101
102 if cmd.starts_with("lean-ctx ") || cmd.starts_with(&format!("{binary} ")) {
103 return;
104 }
105
106 if let Some(rewritten) = build_rewrite_compound(&cmd, &binary) {
107 emit_rewrite(&rewritten);
108 return;
109 }
110
111 if is_rewritable(&cmd) {
112 let rewritten = wrap_single_command(&cmd, &binary);
113 emit_rewrite(&rewritten);
114 }
115}
116
117pub fn handle_rewrite_inline() {
121 let binary = resolve_binary();
122 let args: Vec<String> = std::env::args().collect();
123 if args.len() < 4 {
125 return;
126 }
127 let cmd = args[3..].join(" ");
128
129 if cmd.starts_with("lean-ctx ") || cmd.starts_with(&format!("{binary} ")) {
130 print!("{cmd}");
131 return;
132 }
133
134 if let Some(rewritten) = build_rewrite_compound(&cmd, &binary) {
135 print!("{rewritten}");
136 return;
137 }
138
139 if is_rewritable(&cmd) {
140 let rewritten = wrap_single_command(&cmd, &binary);
141 print!("{rewritten}");
142 return;
143 }
144
145 print!("{cmd}");
146}
147
/// Returns the path of the running executable as a string.
///
/// Non-UTF-8 path bytes are replaced lossily. Falls back to the bare name
/// `lean-ctx` when the executable path cannot be determined.
fn resolve_binary() -> String {
    match std::env::current_exe() {
        Ok(path) => path.to_string_lossy().into_owned(),
        Err(_) => String::from("lean-ctx"),
    }
}
153
/// Extracts the decoded value of the first string-valued key `field` found
/// anywhere in `input`, at any nesting depth.
///
/// This is a lightweight scanner, not a full JSON parser:
/// - It looks for `"field"` followed (after optional JSON whitespace) by `:`
///   and then a `"`-opened string.
/// - Occurrences whose value is not a string (e.g. `null`) are skipped and
///   the search continues.
/// - The value is decoded per the JSON escape rules: `\" \\ \/ \b \f \n \r
///   \t` and `\uXXXX`, including surrogate pairs. Unknown escapes are kept
///   literally; invalid code points become U+FFFD.
///
/// Returns `None` when no matching key exists, or when the first matching
/// string value is unterminated.
fn extract_json_field(input: &str, field: &str) -> Option<String> {
    let key = format!("\"{field}\"");
    let is_json_ws = |c: char| matches!(c, ' ' | '\t' | '\n' | '\r');
    let mut search_from = 0;
    while let Some(offset) = input[search_from..].find(&key) {
        // `key` ends in an ASCII quote, so this index is a char boundary.
        let after_key = search_from + offset + key.len();
        search_from = after_key;
        let rest = input[after_key..].trim_start_matches(is_json_ws);
        if let Some(rest) = rest.strip_prefix(':') {
            if let Some(body) = rest.trim_start_matches(is_json_ws).strip_prefix('"') {
                return decode_json_string(body);
            }
        }
    }
    None
}

/// Decodes a JSON string body (text after the opening quote) up to its
/// closing quote; returns `None` if the closing quote is missing.
fn decode_json_string(body: &str) -> Option<String> {
    let mut out = String::with_capacity(body.len());
    let mut chars = body.chars();
    while let Some(ch) = chars.next() {
        match ch {
            '"' => return Some(out),
            // A trailing lone backslash means the string never terminates.
            '\\' => match chars.next()? {
                '"' => out.push('"'),
                '\\' => out.push('\\'),
                '/' => out.push('/'),
                'b' => out.push('\u{8}'),
                'f' => out.push('\u{c}'),
                'n' => out.push('\n'),
                'r' => out.push('\r'),
                't' => out.push('\t'),
                'u' => push_unicode_escape(&mut chars, &mut out),
                other => {
                    out.push('\\');
                    out.push(other);
                }
            },
            other => out.push(other),
        }
    }
    None
}

/// Decodes the hex digits following a `\u` escape (and a trailing low
/// surrogate escape, when the first unit is a high surrogate) into `out`.
fn push_unicode_escape(chars: &mut std::str::Chars<'_>, out: &mut String) {
    let high = match read_hex4(chars) {
        Some(unit) => unit,
        None => {
            // Malformed escape: keep it literally rather than dropping text.
            out.push_str("\\u");
            return;
        }
    };
    if matches!(high, 0xD800..=0xDBFF) {
        let mut look = chars.clone();
        if look.next() == Some('\\') && look.next() == Some('u') {
            if let Some(low) = read_hex4(&mut look) {
                if matches!(low, 0xDC00..=0xDFFF) {
                    *chars = look;
                    let code_point = 0x10000 + ((high - 0xD800) << 10) + (low - 0xDC00);
                    out.push(char::from_u32(code_point).unwrap_or('\u{FFFD}'));
                    return;
                }
            }
        }
    }
    // Lone surrogates are not valid `char`s and map to U+FFFD.
    out.push(char::from_u32(high).unwrap_or('\u{FFFD}'));
}

/// Reads exactly four hex digits, advancing `chars` only on success.
fn read_hex4(chars: &mut std::str::Chars<'_>) -> Option<u32> {
    let mut look = chars.clone();
    let mut value = 0u32;
    for _ in 0..4 {
        value = value * 16 + look.next()?.to_digit(16)?;
    }
    *chars = look;
    Some(value)
}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected `lean-ctx -c "<inner>"` text; `inner` must already be in
    /// its shell-escaped form.
    fn wrapped(inner: &str) -> String {
        format!("lean-ctx -c \"{inner}\"")
    }

    /// Runs the compound rewrite with the default `lean-ctx` binary name.
    fn compound(cmd: &str) -> Option<String> {
        build_rewrite_compound(cmd, "lean-ctx")
    }

    #[test]
    fn is_rewritable_basic() {
        for cmd in ["git status", "cargo test --lib", "npm run build"].iter() {
            assert!(is_rewritable(cmd), "expected rewritable: {}", cmd);
        }
        for cmd in ["echo hello", "cd src"].iter() {
            assert!(!is_rewritable(cmd), "expected passthrough: {}", cmd);
        }
    }

    #[test]
    fn wrap_single() {
        assert_eq!(wrap_single_command("git status", "lean-ctx"), wrapped("git status"));
    }

    #[test]
    fn wrap_with_quotes() {
        let actual = wrap_single_command(r#"curl -H "Auth" https://api.com"#, "lean-ctx");
        assert_eq!(actual, wrapped(r#"curl -H \"Auth\" https://api.com"#));
    }

    #[test]
    fn compound_rewrite_and_chain() {
        let expected = format!("cd src && {} && echo done", wrapped("git status"));
        assert_eq!(compound("cd src && git status && echo done"), Some(expected));
    }

    #[test]
    fn compound_rewrite_pipe() {
        let expected = format!("{} | head -5", wrapped("git log --oneline"));
        assert_eq!(compound("git log --oneline | head -5"), Some(expected));
    }

    #[test]
    fn compound_rewrite_no_match() {
        assert_eq!(compound("cd src && echo done"), None);
    }

    #[test]
    fn compound_rewrite_multiple_rewritable() {
        let expected = format!(
            "{} && {} && {}",
            wrapped("git add ."),
            wrapped("cargo test"),
            wrapped("npm run lint")
        );
        assert_eq!(compound("git add . && cargo test && npm run lint"), Some(expected));
    }

    #[test]
    fn compound_rewrite_semicolons() {
        let expected = format!("{} ; {}", wrapped("git add ."), wrapped("git commit -m 'fix'"));
        assert_eq!(compound("git add .; git commit -m 'fix'"), Some(expected));
    }

    #[test]
    fn compound_rewrite_or_chain() {
        let expected = format!("{} || echo failed", wrapped("git pull"));
        assert_eq!(compound("git pull || echo failed"), Some(expected));
    }

    #[test]
    fn compound_skips_already_rewritten() {
        let expected = format!("lean-ctx -c git status && {}", wrapped("git diff"));
        assert_eq!(compound("lean-ctx -c git status && git diff"), Some(expected));
    }

    #[test]
    fn single_command_not_compound() {
        assert_eq!(compound("git status"), None);
    }

    #[test]
    fn extract_field_works() {
        let input = r#"{"tool_name":"Bash","command":"git status"}"#;
        assert_eq!(extract_json_field(input, "tool_name").as_deref(), Some("Bash"));
        assert_eq!(extract_json_field(input, "command").as_deref(), Some("git status"));
    }

    #[test]
    fn extract_field_handles_escaped_quotes() {
        let input = r#"{"tool_name":"Bash","command":"grep -r \"TODO\" src/"}"#;
        assert_eq!(
            extract_json_field(input, "command").as_deref(),
            Some(r#"grep -r "TODO" src/"#)
        );
    }

    #[test]
    fn extract_field_handles_escaped_backslash() {
        let input = r#"{"tool_name":"Bash","command":"echo \\\"hello\\\""}"#;
        assert_eq!(
            extract_json_field(input, "command").as_deref(),
            Some(r#"echo \"hello\""#)
        );
    }

    #[test]
    fn extract_field_handles_complex_curl() {
        let input = r#"{"tool_name":"Bash","command":"curl -H \"Authorization: Bearer token\" https://api.com"}"#;
        assert_eq!(
            extract_json_field(input, "command").as_deref(),
            Some(r#"curl -H "Authorization: Bearer token" https://api.com"#)
        );
    }
}