Skip to main content

lean_ctx/tools/
ctx_shell.rs

1use crate::core::patterns;
2use crate::core::protocol;
3use crate::core::symbol_map::{self, SymbolMap};
4use crate::core::tokens::count_tokens;
5use crate::tools::CrpMode;
6
/// Upper bound, in bytes, on the length of a command accepted by
/// `validate_command`; longer commands are rejected with a "too large" error.
const MAX_COMMAND_BYTES: usize = 8192;
8
9/// Validates a shell command before execution. Returns Some(error_message) if
10/// the command should be rejected, None if it's safe to run.
11pub fn validate_command(command: &str) -> Option<String> {
12    if command.len() > MAX_COMMAND_BYTES {
13        return Some(format!(
14            "ERROR: Command too large ({} bytes, limit {}). \
15             If you're writing file content, use the native Write/Edit tool instead. \
16             ctx_shell is for reading command output only (git, cargo, npm, etc.).",
17            command.len(),
18            MAX_COMMAND_BYTES
19        ));
20    }
21
22    if has_file_write_redirect(command) {
23        return Some(
24            "ERROR: ctx_shell detected a file-write command (shell redirect > or >>). \
25             Use the native Write tool to create/modify files. \
26             ctx_shell is ONLY for reading command output (git status, cargo test, npm run, etc.). \
27             File writes via shell cause MCP protocol corruption on large payloads."
28                .to_string(),
29        );
30    }
31
32    let cmd_lower = command.to_lowercase();
33
34    if cmd_lower.starts_with("tee ") || cmd_lower.contains("| tee ") {
35        return Some(
36            "ERROR: ctx_shell detected a file-write command (tee). \
37             Use the native Write tool to create/modify files. \
38             ctx_shell is ONLY for reading command output."
39                .to_string(),
40        );
41    }
42
43    if is_heredoc_file_write(command) {
44        return Some(
45            "ERROR: ctx_shell detected a heredoc writing to a file. \
46             Use the native Write tool to create/modify files. \
47             ctx_shell is ONLY for reading command output. \
48             Note: heredocs for input piping (e.g. psql <<EOF) are allowed."
49                .to_string(),
50        );
51    }
52
53    None
54}
55
56/// Returns true only for heredocs that redirect to files (the dangerous pattern).
57/// Legitimate heredoc uses (input piping, inline scripts) are allowed through.
58fn is_heredoc_file_write(command: &str) -> bool {
59    let has_heredoc = command.contains("<<");
60    if !has_heredoc {
61        return false;
62    }
63    // Only block: heredoc combined with file output redirect
64    // e.g. `cat <<EOF > file.txt` or `cat <<'EOF' >> output.log`
65    let cmd_lower = command.to_lowercase();
66    let heredoc_patterns = ["<<eof", "<<'eof'", "<<\"eof\"", "<<end", "<<'end'"];
67    let has_known_heredoc = heredoc_patterns.iter().any(|p| cmd_lower.contains(p));
68    if !has_known_heredoc {
69        return false;
70    }
71    has_file_write_redirect(command)
72}
73
/// Detects unquoted shell output redirections (`>` or `>>`) that write to a file.
///
/// The scan tracks single and double quotes (backslash escapes are not
/// interpreted) and ignores:
///
/// * `>` inside quotes,
/// * stderr redirects `2>` and `2>>`, but only when `2` is a standalone file
///   descriptor token (so `file2> out` still counts as a write),
/// * redirects to `/dev/null`,
/// * file-descriptor duplication or closing such as `>&2`, `1>&2` or `>&-`,
/// * a redirect operator with nothing after it.
///
/// Returns `true` as soon as one redirect with any other target is found.
fn has_file_write_redirect(command: &str) -> bool {
    let bytes = command.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut in_single_quote = false;
    let mut in_double_quote = false;

    while i < len {
        match bytes[i] {
            b'\'' if !in_double_quote => in_single_quote = !in_single_quote,
            b'"' if !in_single_quote => in_double_quote = !in_double_quote,
            b'>' if !in_single_quote && !in_double_quote => {
                // Treat `>>` as a single operator so both halves share one verdict.
                let op_end = if bytes.get(i + 1) == Some(&b'>') {
                    i + 2
                } else {
                    i + 1
                };
                if !is_stderr_redirect(bytes, i) {
                    // `op_end` follows an ASCII `>`, so it is a valid char boundary.
                    let target = redirect_target(&command[op_end..]);
                    let harmless =
                        target.is_empty() || target == "/dev/null" || is_fd_duplication(target);
                    if !harmless {
                        return true;
                    }
                }
                i = op_end;
                continue;
            }
            _ => {}
        }
        i += 1;
    }
    false
}

/// Returns true when the `>` at `gt_index` is prefixed by a standalone `2`,
/// i.e. the `2` sits at the start of the command or directly after whitespace
/// or one of `;`, `|`, `&`, `(`.
fn is_stderr_redirect(bytes: &[u8], gt_index: usize) -> bool {
    if gt_index == 0 || bytes[gt_index - 1] != b'2' {
        return false;
    }
    match gt_index.checked_sub(2).map(|before| bytes[before]) {
        None => true,
        Some(prev) => prev.is_ascii_whitespace() || matches!(prev, b';' | b'|' | b'&' | b'('),
    }
}

/// Extracts the word following a redirect operator: leading whitespace is
/// skipped, the word ends at the next whitespace, and trailing `;` / `)`
/// separators are dropped.
fn redirect_target(rest: &str) -> &str {
    rest.trim_start()
        .split(char::is_whitespace)
        .next()
        .unwrap_or("")
        .trim_end_matches(|c| c == ';' || c == ')')
}

/// Returns true for targets of the form `&N` (digits) or `&-`, which duplicate
/// or close a file descriptor rather than naming a file.
fn is_fd_duplication(target: &str) -> bool {
    match target.strip_prefix('&') {
        Some("-") => true,
        Some(fd) => !fd.is_empty() && fd.bytes().all(|b| b.is_ascii_digit()),
        None => false,
    }
}
116
117/// On Windows cmd.exe, `;` is not a valid command separator.
118/// Convert `cmd1; cmd2` to `cmd1 && cmd2` when running under cmd.exe.
119pub fn normalize_command_for_shell(command: &str) -> String {
120    if !cfg!(windows) {
121        return command.to_string();
122    }
123    let (_, flag) = crate::shell::shell_and_flag();
124    if flag != "/C" {
125        return command.to_string();
126    }
127    let bytes = command.as_bytes();
128    let mut result = Vec::with_capacity(bytes.len() + 16);
129    let mut in_single = false;
130    let mut in_double = false;
131    for (i, &b) in bytes.iter().enumerate() {
132        if b == b'\'' && !in_double {
133            in_single = !in_single;
134        } else if b == b'"' && !in_single {
135            in_double = !in_double;
136        } else if b == b';' && !in_single && !in_double {
137            result.extend_from_slice(b" && ");
138            continue;
139        }
140        result.push(b);
141        let _ = i;
142    }
143    String::from_utf8(result).unwrap_or_else(|_| command.to_string())
144}
145
146pub fn handle(command: &str, output: &str, crp_mode: CrpMode) -> String {
147    let original_tokens = count_tokens(output);
148
149    if contains_auth_flow(output) {
150        let savings = protocol::format_savings(original_tokens, original_tokens);
151        return format!(
152            "{output}\n[lean-ctx: auth/device-code flow detected — output preserved uncompressed]\n{savings}"
153        );
154    }
155
156    let raw_compressed = match patterns::compress_output(command, output) {
157        Some(c) => crate::core::compressor::safeguard_ratio(output, &c),
158        None if is_search_command(command) => {
159            let stripped = crate::core::compressor::strip_ansi(output);
160            stripped.to_string()
161        }
162        None => generic_compress(output),
163    };
164    let compressed = crate::core::compressor::verbatim_compact(&raw_compressed);
165
166    if crp_mode.is_tdd() && looks_like_code(&compressed) {
167        let ext = detect_ext_from_command(command);
168        let mut sym = SymbolMap::new();
169        let idents = symbol_map::extract_identifiers(&compressed, ext);
170        for ident in &idents {
171            sym.register(ident);
172        }
173        if !sym.is_empty() {
174            let mapped = sym.apply(&compressed);
175            let sym_table = sym.format_table();
176            let result = format!("{mapped}{sym_table}");
177            let sent = count_tokens(&result);
178            let savings = protocol::format_savings(original_tokens, sent);
179            return format!("{result}\n{savings}");
180        }
181    }
182
183    let sent = count_tokens(&compressed);
184    let savings = protocol::format_savings(original_tokens, sent);
185
186    format!("{compressed}\n{savings}")
187}
188
/// Returns true when `command`, after leading whitespace is skipped, starts
/// with one of the known search tools followed by a space.
fn is_search_command(command: &str) -> bool {
    const SEARCH_PREFIXES: [&str; 6] = ["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

    let trimmed = command.trim_start();
    SEARCH_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(*prefix))
}
198
199fn generic_compress(output: &str) -> String {
200    let output = crate::core::compressor::strip_ansi(output);
201    let lines: Vec<&str> = output
202        .lines()
203        .filter(|l| {
204            let t = l.trim();
205            !t.is_empty()
206        })
207        .collect();
208
209    if lines.len() <= 20 {
210        return lines.join("\n");
211    }
212
213    let show_count = (lines.len() / 3).min(30);
214    let half = show_count / 2;
215    let first = &lines[..half];
216    let last = &lines[lines.len() - half..];
217    let omitted = lines.len() - (half * 2);
218    format!(
219        "{}\n[truncated: showing {}/{} lines, {} omitted. Use raw=true for full output.]\n{}",
220        first.join("\n"),
221        half * 2,
222        lines.len(),
223        omitted,
224        last.join("\n")
225    )
226}
227
/// Heuristically decides whether `text` is source code.
///
/// A line counts as code when it contains any of the marker substrings below.
/// Text with fewer than three lines is never considered code; otherwise the
/// share of code lines must be strictly greater than 15%.
fn looks_like_code(text: &str) -> bool {
    const CODE_MARKERS: &[&str] = &[
        "fn ",
        "pub ",
        "let ",
        "const ",
        "impl ",
        "struct ",
        "enum ",
        "function ",
        "class ",
        "import ",
        "export ",
        "def ",
        "async ",
        "=>",
        "->",
        "::",
        "self.",
        "this.",
    ];

    // Single pass: tally all lines and those carrying a marker.
    let mut total_lines = 0usize;
    let mut code_lines = 0usize;
    for line in text.lines() {
        total_lines += 1;
        if CODE_MARKERS.iter().any(|marker| line.contains(*marker)) {
            code_lines += 1;
        }
    }

    if total_lines < 3 {
        return false;
    }
    code_lines as f64 / total_lines as f64 > 0.15
}
259
/// Guesses a source-file extension from substrings of the lowercased command.
///
/// Rules are tried in order and the first match wins; `"rs"` is returned when
/// nothing matches.
fn detect_ext_from_command(command: &str) -> &str {
    // (substrings to look for, extension to report), in priority order.
    const RULES: &[(&[&str], &str)] = &[
        (&["cargo", ".rs"], "rs"),
        (&["npm", "node", ".ts", ".js"], "ts"),
        (&["python", "pip", ".py"], "py"),
        (&["go ", ".go"], "go"),
    ];

    let lowered = command.to_lowercase();
    for (needles, ext) in RULES {
        if needles.iter().any(|needle| lowered.contains(*needle)) {
            return *ext;
        }
    }
    "rs"
}
278
/// Detects OAuth device-code style login output that must not be compressed.
///
/// Matching is case-insensitive and two-tiered: any strong signal is enough on
/// its own, while a weak signal only counts when the output also contains an
/// `http://` or `https://` URL.
pub fn contains_auth_flow(output: &str) -> bool {
    // Phrases specific enough to device-code flows to match on their own.
    const STRONG_SIGNALS: &[&str] = &[
        "devicelogin",
        "deviceauth",
        "device_code",
        "device code",
        "device-code",
        "verification_uri",
        "user_code",
        "one-time code",
    ];

    // Generic login phrasing; only trusted next to a URL.
    const WEAK_SIGNALS: &[&str] = &[
        "enter the code",
        "enter this code",
        "enter code:",
        "use the code",
        "use a web browser to open",
        "open the page",
        "authenticate by visiting",
        "sign in with the code",
        "sign in using a code",
        "verification code",
        "authorize this device",
        "waiting for authentication",
        "waiting for login",
        "waiting for you to authenticate",
        "open your browser",
        "open in your browser",
    ];

    let haystack = output.to_lowercase();
    let mentions_any =
        |signals: &[&str]| signals.iter().any(|signal| haystack.contains(*signal));

    if mentions_any(STRONG_SIGNALS) {
        return true;
    }

    let has_url = haystack.contains("http://") || haystack.contains("https://");
    mentions_any(WEAK_SIGNALS) && has_url
}
326
// Unit tests for command validation, auth-flow detection and output
// compression. Everything is exercised through `super::*`.
#[cfg(test)]
mod tests {
    use super::*;

    // Off Windows `normalize_command_for_shell` returns its input untouched;
    // on Windows the result depends on the configured shell, so skip there.
    #[test]
    fn normalize_cmd_no_change_on_unix() {
        if cfg!(windows) {
            return;
        }
        assert_eq!(
            normalize_command_for_shell("cd /tmp; ls -la"),
            "cd /tmp; ls -la"
        );
    }

    // --- Command validation ---

    #[test]
    fn validate_allows_safe_commands() {
        assert!(validate_command("git status").is_none());
        assert!(validate_command("cargo test").is_none());
        assert!(validate_command("npm run build").is_none());
        assert!(validate_command("ls -la").is_none());
    }

    // The `>` operators sit outside the quoted payloads, so they count as redirects.
    #[test]
    fn validate_blocks_file_writes() {
        assert!(validate_command("echo 'data' > output.txt").is_some());
        assert!(validate_command("tee /tmp/file.txt").is_some());
        assert!(validate_command("printf 'hello' > test.txt").is_some());
    }

    // Only rejection is asserted here, not which error message is returned.
    #[test]
    fn validate_blocks_heredoc_with_file_redirect() {
        assert!(validate_command("cat > file.py <<'EOF'\nprint('hi')\nEOF").is_some());
        assert!(validate_command("cat <<EOF > output.txt\nhello\nEOF").is_some());
        assert!(validate_command("cat <<'END' >> logfile.txt\ndata\nEND").is_some());
    }

    // Heredocs used purely as input (no `>` / `>>`) must pass validation.
    #[test]
    fn validate_allows_heredoc_without_file_redirect() {
        assert!(validate_command("cat <<EOF\nhello world\nEOF").is_none());
        assert!(validate_command("psql -d mydb <<EOF\nSELECT 1;\nEOF").is_none());
        assert!(
            validate_command("git commit -m \"$(cat <<'EOF'\nfix: something\nEOF\n)\"").is_none()
        );
        assert!(validate_command("grep pattern <<EOF\nfoo\nbar\nEOF").is_none());
    }

    // One byte over the limit is enough to trigger the size check.
    #[test]
    fn validate_blocks_oversized_commands() {
        let huge = "x".repeat(MAX_COMMAND_BYTES + 1);
        let result = validate_command(&huge);
        assert!(result.is_some());
        assert!(result.unwrap().contains("too large"));
    }

    #[test]
    fn validate_allows_cat_without_redirect() {
        assert!(validate_command("cat file.txt").is_none());
    }

    // --- Auth flow detection: strong signals (no URL needed) ---

    #[test]
    fn auth_flow_detects_azure_device_code() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.";
        assert!(contains_auth_flow(output));
    }

    // "github.com" appears without a scheme, so this relies on the strong
    // "one-time code" signal alone.
    #[test]
    fn auth_flow_detects_gh_auth_one_time_code() {
        let output =
            "! First copy your one-time code: ABCD-1234\n- Press Enter to open github.com in your browser...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_device_code_json() {
        let output = r#"{"device_code":"abc123","user_code":"ABCD-1234","verification_uri":"https://example.com/activate"}"#;
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_verification_uri_field() {
        let output =
            r#"{"verification_uri": "https://login.microsoftonline.com/common/oauth2/deviceauth"}"#;
        assert!(contains_auth_flow(output));
    }

    // No URL in this payload: "user_code" is sufficient by itself.
    #[test]
    fn auth_flow_detects_user_code_field() {
        let output = r#"{"user_code": "FGHJK-LMNOP", "expires_in": 900}"#;
        assert!(contains_auth_flow(output));
    }

    // --- Auth flow detection: weak signals (require URL) ---

    #[test]
    fn auth_flow_detects_gcloud_with_url() {
        let output = "Go to the following link in your browser:\n\n    https://accounts.google.com/o/oauth2/auth?response_type=code\n\nEnter verification code: ";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_aws_sso_with_url() {
        let output = "If the browser does not open, open the following URL:\nhttps://device.sso.us-east-1.amazonaws.com/\n\nThen enter the code:\nABCD-EFGH";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_firebase_with_url() {
        let output = "Visit this URL on this device to log in:\nhttps://accounts.google.com/o/oauth2/auth?...\n\nWaiting for authentication...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_generic_browser_open_with_url() {
        let output =
            "Open your browser to https://login.example.com/device and enter the code XYZW-1234";
        assert!(contains_auth_flow(output));
    }

    // --- False positive protection ---

    #[test]
    fn auth_flow_ignores_normal_build_output() {
        let output = "Compiling lean-ctx v2.21.9\nFinished release profile\n";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_git_output() {
        let output = "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean";
        assert!(!contains_auth_flow(output));
    }

    // A URL alone, with no weak or strong signal, must not trigger detection.
    #[test]
    fn auth_flow_ignores_npm_install_output() {
        let output = "added 150 packages in 3s\n\n24 packages are looking for funding\n  run `npm fund` for details\nhttps://npmjs.com/package/lean-ctx";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_docs_mentioning_auth() {
        let output = "The authorization code grant type is the most common OAuth flow.\nSee https://oauth.net/2/grant-types/ for details.";
        assert!(!contains_auth_flow(output));
    }

    // Weak signal ("enter the code") present, but no http(s) URL.
    #[test]
    fn auth_flow_weak_signal_requires_url() {
        let output = "Please enter the code ABC123 in the terminal";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_weak_signal_without_url_is_ignored() {
        let output = "Waiting for authentication to complete... done!";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_virtualenv_activate() {
        let output = "Created virtualenv at .venv\nRun: source .venv/bin/activate";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_api_response_with_code_field() {
        let output = r#"{"status": "ok", "code": 200, "message": "success"}"#;
        assert!(!contains_auth_flow(output));
    }

    // --- Integration: handle() preserves auth flow ---

    // The auth branch of `handle` returns the raw output before any compression runs.
    #[test]
    fn handle_preserves_auth_flow_output_fully() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.\nWaiting for you...\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12\nLine 13";
        let result = handle("az login --use-device-code", output, CrpMode::Off);
        assert!(result.contains("ABCD1234"), "auth code must be preserved");
        assert!(result.contains("devicelogin"), "URL must be preserved");
        assert!(
            result.contains("auth/device-code flow detected"),
            "detection note must be present"
        );
        assert!(
            result.contains("Line 13"),
            "all lines must be preserved (no truncation)"
        );
    }

    // The 100-byte slack leaves room for the appended savings footer.
    #[test]
    fn handle_compresses_normal_output_not_auth() {
        let lines: Vec<String> = (1..=20).map(|i| format!("Line {i} of output")).collect();
        let output = lines.join("\n");
        let result = handle("some-tool check", &output, CrpMode::Off);
        assert!(
            !result.contains("auth/device-code flow detected"),
            "normal output must not trigger auth detection"
        );
        assert!(
            result.len() < output.len() + 100,
            "normal output should be compressed, not inflated"
        );
    }

    // --- Search command detection ---

    #[test]
    fn is_search_command_detects_grep() {
        assert!(is_search_command("grep -r pattern src/"));
        assert!(is_search_command("rg pattern src/"));
        assert!(is_search_command("find . -name '*.rs'"));
        assert!(is_search_command("fd pattern"));
        assert!(is_search_command("ag pattern src/"));
        assert!(is_search_command("ack pattern"));
    }

    #[test]
    fn is_search_command_rejects_non_search() {
        assert!(!is_search_command("cargo build"));
        assert!(!is_search_command("git status"));
        assert!(!is_search_command("npm install"));
        assert!(!is_search_command("cat file.rs"));
    }

    // --- Generic compression ---

    // 20 non-blank lines is exactly the `<= 20` cut-off, so nothing is truncated.
    #[test]
    fn generic_compress_preserves_short_output() {
        let lines: Vec<String> = (1..=20).map(|i| format!("Line {i}")).collect();
        let output = lines.join("\n");
        let result = generic_compress(&output);
        assert_eq!(result, output);
    }

    // 60 lines give a budget of min(60 / 3, 30) = 20, i.e. 10 head + 10 tail
    // lines plus the truncation marker.
    #[test]
    fn generic_compress_scales_with_length() {
        let lines: Vec<String> = (1..=60).map(|i| format!("Line {i}")).collect();
        let output = lines.join("\n");
        let result = generic_compress(&output);
        assert!(result.contains("truncated"));
        let shown_count = result.lines().count();
        assert!(
            shown_count > 10,
            "should show more than old 6-line limit, got {shown_count}"
        );
        assert!(shown_count < 60, "should be truncated, not full output");
    }

    // 30 lines would be truncated by `generic_compress`; search commands must
    // keep every hit. NOTE(review): presumably no pattern compressor drops
    // lines for `rg` output either — confirm against `patterns::compress_output`.
    #[test]
    fn handle_preserves_search_results() {
        let lines: Vec<String> = (1..=30)
            .map(|i| format!("src/file{i}.rs:42: fn search_result()"))
            .collect();
        let output = lines.join("\n");
        let result = handle("rg search_result src/", &output, CrpMode::Off);
        for i in 1..=30 {
            assert!(
                result.contains(&format!("file{i}")),
                "search result file{i} should be preserved in output"
            );
        }
    }
}