Skip to main content

lean_ctx/tools/
ctx_shell.rs

1use crate::core::patterns;
2use crate::core::protocol;
3use crate::core::symbol_map::{self, SymbolMap};
4use crate::core::tokens::count_tokens;
5use crate::tools::CrpMode;
6
7const MAX_COMMAND_BYTES: usize = 8192;
8
9/// Validates a shell command before execution. Returns Some(error_message) if
10/// the command should be rejected, None if it's safe to run.
11pub fn validate_command(command: &str) -> Option<String> {
12    if command.len() > MAX_COMMAND_BYTES {
13        return Some(format!(
14            "ERROR: Command too large ({} bytes, limit {}). \
15             If you're writing file content, use the native Write/Edit tool instead. \
16             ctx_shell is for reading command output only (git, cargo, npm, etc.).",
17            command.len(),
18            MAX_COMMAND_BYTES
19        ));
20    }
21
22    if has_file_write_redirect(command) {
23        return Some(
24            "ERROR: ctx_shell detected a file-write command (shell redirect > or >>). \
25             Use the native Write tool to create/modify files. \
26             ctx_shell is ONLY for reading command output (git status, cargo test, npm run, etc.). \
27             File writes via shell cause MCP protocol corruption on large payloads."
28                .to_string(),
29        );
30    }
31
32    let cmd_lower = command.to_lowercase();
33
34    if cmd_lower.starts_with("tee ") || cmd_lower.contains("| tee ") {
35        return Some(
36            "ERROR: ctx_shell detected a file-write command (tee). \
37             Use the native Write tool to create/modify files. \
38             ctx_shell is ONLY for reading command output."
39                .to_string(),
40        );
41    }
42
43    if is_heredoc_file_write(command) {
44        return Some(
45            "ERROR: ctx_shell detected a heredoc writing to a file. \
46             Use the native Write tool to create/modify files. \
47             ctx_shell is ONLY for reading command output. \
48             Note: heredocs for input piping (e.g. psql <<EOF) are allowed."
49                .to_string(),
50        );
51    }
52
53    None
54}
55
56/// Returns true only for heredocs that redirect to files (the dangerous pattern).
57/// Legitimate heredoc uses (input piping, inline scripts) are allowed through.
58fn is_heredoc_file_write(command: &str) -> bool {
59    let has_heredoc = command.contains("<<");
60    if !has_heredoc {
61        return false;
62    }
63    // Only block: heredoc combined with file output redirect
64    // e.g. `cat <<EOF > file.txt` or `cat <<'EOF' >> output.log`
65    let cmd_lower = command.to_lowercase();
66    let heredoc_patterns = ["<<eof", "<<'eof'", "<<\"eof\"", "<<end", "<<'end'"];
67    let has_known_heredoc = heredoc_patterns.iter().any(|p| cmd_lower.contains(p));
68    if !has_known_heredoc {
69        return false;
70    }
71    has_file_write_redirect(command)
72}
73
/// Detects shell redirect operators (`>` or `>>`) that write command output to a file.
///
/// This is a lightweight scan, not a shell parser. It ignores:
/// - `>` inside single or double quotes, and (on non-Windows builds) any
///   character preceded by a backslash escape;
/// - stderr redirects `2>` / `2>>`, where `2` is a standalone descriptor
///   number rather than the tail of a word such as `file2>out`;
/// - descriptor duplication or closing such as `>&2`, `1>&2`, `>&-`;
/// - redirects whose target is exactly `/dev/null`.
///
/// Any other unquoted `>` followed by a non-empty target returns `true`.
fn has_file_write_redirect(command: &str) -> bool {
    // Backslash escapes the next character in POSIX shells, but it is a path
    // separator on Windows (`"C:\dir\"`), so it is only honoured elsewhere.
    let backslash_escapes = !cfg!(windows);
    let bytes = command.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut in_single_quote = false;
    let mut in_double_quote = false;

    while i < len {
        match bytes[i] {
            // Inside single quotes a backslash is literal.
            b'\\' if backslash_escapes && !in_single_quote => {
                i += 2;
                continue;
            }
            b'\'' if !in_double_quote => in_single_quote = !in_single_quote,
            b'"' if !in_single_quote => in_double_quote = !in_double_quote,
            b'>' if !in_single_quote && !in_double_quote => {
                let op_end = if bytes.get(i + 1) == Some(&b'>') {
                    i + 2
                } else {
                    i + 1
                };
                let is_stderr = i > 0
                    && bytes[i - 1] == b'2'
                    && (i == 1 || starts_shell_word(bytes[i - 2]));
                // `>` is ASCII, so `op_end` always lies on a char boundary.
                let rest = &command[op_end..];
                if !is_stderr && !is_fd_duplication(rest) {
                    let target = redirect_target(rest);
                    if !target.is_empty() && target != "/dev/null" {
                        return true;
                    }
                }
                i = op_end;
                continue;
            }
            _ => {}
        }
        i += 1;
    }
    false
}

/// Returns true if `b` is a byte after which a new shell word begins.
fn starts_shell_word(b: u8) -> bool {
    b.is_ascii_whitespace() || matches!(b, b';' | b'|' | b'&' | b'(' | b')')
}

/// Returns true if `c` terminates an unquoted redirect target word.
fn ends_redirect_word(c: char) -> bool {
    c.is_whitespace() || matches!(c, ';' | '|' | '&' | ')')
}

/// Returns true when the text after a redirect operator is `&N` or `&-`,
/// i.e. descriptor duplication/closing rather than a file target.
fn is_fd_duplication(rest: &str) -> bool {
    match rest.strip_prefix('&') {
        Some(after) => {
            let word = after.split(ends_redirect_word).next().unwrap_or("");
            word == "-" || (!word.is_empty() && word.bytes().all(|b| b.is_ascii_digit()))
        }
        None => false,
    }
}

/// Extracts the target word following a redirect operator. A leading `|`
/// (`>|`) or `&` (`>&file`) belongs to the operator and is skipped.
fn redirect_target(rest: &str) -> &str {
    let rest = rest
        .strip_prefix(&['|', '&'][..])
        .unwrap_or(rest)
        .trim_start();
    let end = rest.find(ends_redirect_word).unwrap_or(rest.len());
    &rest[..end]
}
116
117/// On Windows cmd.exe, `;` is not a valid command separator.
118/// Convert `cmd1; cmd2` to `cmd1 && cmd2` when running under cmd.exe.
119pub fn normalize_command_for_shell(command: &str) -> String {
120    if !cfg!(windows) {
121        return command.to_string();
122    }
123    let (_, flag) = crate::shell::shell_and_flag();
124    if flag != "/C" {
125        return command.to_string();
126    }
127    let bytes = command.as_bytes();
128    let mut result = Vec::with_capacity(bytes.len() + 16);
129    let mut in_single = false;
130    let mut in_double = false;
131    for (i, &b) in bytes.iter().enumerate() {
132        if b == b'\'' && !in_double {
133            in_single = !in_single;
134        } else if b == b'"' && !in_single {
135            in_double = !in_double;
136        } else if b == b';' && !in_single && !in_double {
137            result.extend_from_slice(b" && ");
138            continue;
139        }
140        result.push(b);
141        let _ = i;
142    }
143    String::from_utf8(result).unwrap_or_else(|_| command.to_string())
144}
145
146/// Compresses shell command output using pattern matching, dedup, and symbol mapping.
147pub fn handle(command: &str, output: &str, crp_mode: CrpMode) -> String {
148    let original_tokens = count_tokens(output);
149
150    if !is_search_command(command) && contains_auth_flow(output) {
151        let savings = protocol::format_savings(original_tokens, original_tokens);
152        return format!(
153            "{output}\n[lean-ctx: auth/device-code flow detected — output preserved uncompressed]\n{savings}"
154        );
155    }
156
157    let raw_compressed = match patterns::compress_output(command, output) {
158        Some(c) if is_search_command(command) => {
159            if count_tokens(&c) <= count_tokens(output) {
160                c
161            } else {
162                output.to_string()
163            }
164        }
165        Some(c) => crate::core::compressor::safeguard_ratio(output, &c),
166        None if is_search_command(command) => {
167            let stripped = crate::core::compressor::strip_ansi(output);
168            stripped.clone()
169        }
170        None => generic_compress(output),
171    };
172    let compressed = crate::core::compressor::verbatim_compact(&raw_compressed);
173
174    if crp_mode.is_tdd() && looks_like_code(&compressed) {
175        let ext = detect_ext_from_command(command);
176        let mut sym = SymbolMap::new();
177        let idents = symbol_map::extract_identifiers(&compressed, ext);
178        for ident in &idents {
179            sym.register(ident);
180        }
181        if !sym.is_empty() {
182            let mapped = sym.apply(&compressed);
183            let sym_table = sym.format_table();
184            let result = format!("{mapped}{sym_table}");
185            let sent = count_tokens(&result);
186            let savings = protocol::format_savings(original_tokens, sent);
187            return format!("{result}\n{savings}");
188        }
189    }
190
191    let sent = count_tokens(&compressed);
192    let savings = protocol::format_savings(original_tokens, sent);
193
194    format!("{compressed}\n{savings}")
195}
196
/// Returns true when `command`, ignoring leading whitespace, starts with one
/// of the known search tools followed by a space.
fn is_search_command(command: &str) -> bool {
    const SEARCH_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

    let trimmed = command.trim_start();
    SEARCH_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(*prefix))
}
206
207fn generic_compress(output: &str) -> String {
208    let output = crate::core::compressor::strip_ansi(output);
209    let lines: Vec<&str> = output
210        .lines()
211        .filter(|l| {
212            let t = l.trim();
213            !t.is_empty()
214        })
215        .collect();
216
217    if lines.len() <= 20 {
218        return lines.join("\n");
219    }
220
221    let show_count = (lines.len() / 3).min(30);
222    let half = show_count / 2;
223    let first = &lines[..half];
224    let last = &lines[lines.len() - half..];
225    let omitted = lines.len() - (half * 2);
226    format!(
227        "{}\n[truncated: showing {}/{} lines, {} omitted. Use raw=true for full output.]\n{}",
228        first.join("\n"),
229        half * 2,
230        lines.len(),
231        omitted,
232        last.join("\n")
233    )
234}
235
/// Heuristic: returns true when `text` has at least three lines and more than
/// 15% of them contain a common source-code marker.
fn looks_like_code(text: &str) -> bool {
    const MARKERS: &[&str] = &[
        "fn ", "pub ", "let ", "const ", "impl ", "struct ", "enum ", "function ", "class ",
        "import ", "export ", "def ", "async ", "=>", "->", "::", "self.", "this.",
    ];

    let mut line_count = 0usize;
    let mut marked = 0usize;
    for line in text.lines() {
        line_count += 1;
        if MARKERS.iter().any(|marker| line.contains(*marker)) {
            marked += 1;
        }
    }

    if line_count < 3 {
        return false;
    }
    marked as f64 / line_count as f64 > 0.15
}
267
/// Guesses a source-file extension from substrings of the lowercased command.
/// Rules are tried in order (`rs`, `ts`, `py`, `go`); `rs` is the fallback.
fn detect_ext_from_command(command: &str) -> &str {
    let lowered = command.to_lowercase();
    let mentions = |needles: &[&str]| needles.iter().any(|needle| lowered.contains(*needle));

    if mentions(&["cargo", ".rs"]) {
        return "rs";
    }
    if mentions(&["npm", "node", ".ts", ".js"]) {
        return "ts";
    }
    if mentions(&["python", "pip", ".py"]) {
        return "py";
    }
    if mentions(&["go ", ".go"]) {
        return "go";
    }
    "rs"
}
286
/// Detects OAuth device code flow output that must not be compressed.
///
/// Matching is case-insensitive and two-tiered: any strong signal is enough
/// on its own, while a weak signal only counts when the output also contains
/// an `http://` or `https://` URL.
pub fn contains_auth_flow(output: &str) -> bool {
    const STRONG_SIGNALS: &[&str] = &[
        "devicelogin",
        "deviceauth",
        "device_code",
        "device code",
        "device-code",
        "verification_uri",
        "user_code",
        "one-time code",
    ];
    const WEAK_SIGNALS: &[&str] = &[
        "enter the code",
        "enter this code",
        "enter code:",
        "use the code",
        "use a web browser to open",
        "open the page",
        "authenticate by visiting",
        "sign in with the code",
        "sign in using a code",
        "verification code",
        "authorize this device",
        "waiting for authentication",
        "waiting for login",
        "waiting for you to authenticate",
        "open your browser",
        "open in your browser",
    ];
    const URL_SCHEMES: &[&str] = &["http://", "https://"];

    let haystack = output.to_lowercase();
    let mentions_any =
        |needles: &[&str]| needles.iter().any(|needle| haystack.contains(*needle));

    mentions_any(STRONG_SIGNALS) || (mentions_any(WEAK_SIGNALS) && mentions_any(URL_SCHEMES))
}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `validate_command` accepts every command in `commands`.
    fn assert_allowed(commands: &[&str]) {
        for cmd in commands {
            assert!(
                validate_command(cmd).is_none(),
                "expected command to be allowed: {cmd:?}"
            );
        }
    }

    /// Asserts that `validate_command` rejects every command in `commands`.
    fn assert_rejected(commands: &[&str]) {
        for cmd in commands {
            assert!(
                validate_command(cmd).is_some(),
                "expected command to be rejected: {cmd:?}"
            );
        }
    }

    /// Joins the numbered lines produced by `make_line` for `1..=count`.
    fn numbered_lines(count: usize, make_line: impl Fn(usize) -> String) -> String {
        (1..=count).map(make_line).collect::<Vec<_>>().join("\n")
    }

    #[test]
    fn normalize_cmd_no_change_on_unix() {
        if !cfg!(windows) {
            let input = "cd /tmp; ls -la";
            assert_eq!(normalize_command_for_shell(input), input);
        }
    }

    #[test]
    fn validate_allows_safe_commands() {
        assert_allowed(&["git status", "cargo test", "npm run build", "ls -la"]);
    }

    #[test]
    fn validate_blocks_file_writes() {
        assert_rejected(&[
            "echo 'data' > output.txt",
            "tee /tmp/file.txt",
            "printf 'hello' > test.txt",
        ]);
    }

    #[test]
    fn validate_blocks_heredoc_with_file_redirect() {
        assert_rejected(&[
            "cat > file.py <<'EOF'\nprint('hi')\nEOF",
            "cat <<EOF > output.txt\nhello\nEOF",
            "cat <<'END' >> logfile.txt\ndata\nEND",
        ]);
    }

    #[test]
    fn validate_allows_heredoc_without_file_redirect() {
        assert_allowed(&[
            "cat <<EOF\nhello world\nEOF",
            "psql -d mydb <<EOF\nSELECT 1;\nEOF",
            "git commit -m \"$(cat <<'EOF'\nfix: something\nEOF\n)\"",
            "grep pattern <<EOF\nfoo\nbar\nEOF",
        ]);
    }

    #[test]
    fn validate_blocks_oversized_commands() {
        let oversized = "x".repeat(MAX_COMMAND_BYTES + 1);
        let message = validate_command(&oversized).expect("oversized command must be rejected");
        assert!(message.contains("too large"));
    }

    #[test]
    fn validate_allows_cat_without_redirect() {
        assert_allowed(&["cat file.txt"]);
    }

    // --- Auth flow detection: strong signals (no URL needed) ---

    #[test]
    fn auth_flow_detects_azure_device_code() {
        assert!(contains_auth_flow(
            "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate."
        ));
    }

    #[test]
    fn auth_flow_detects_gh_auth_one_time_code() {
        assert!(contains_auth_flow(
            "! First copy your one-time code: ABCD-1234\n- Press Enter to open github.com in your browser..."
        ));
    }

    #[test]
    fn auth_flow_detects_device_code_json() {
        assert!(contains_auth_flow(
            r#"{"device_code":"abc123","user_code":"ABCD-1234","verification_uri":"https://example.com/activate"}"#
        ));
    }

    #[test]
    fn auth_flow_detects_verification_uri_field() {
        assert!(contains_auth_flow(
            r#"{"verification_uri": "https://login.microsoftonline.com/common/oauth2/deviceauth"}"#
        ));
    }

    #[test]
    fn auth_flow_detects_user_code_field() {
        assert!(contains_auth_flow(
            r#"{"user_code": "FGHJK-LMNOP", "expires_in": 900}"#
        ));
    }

    // --- Auth flow detection: weak signals (require URL) ---

    #[test]
    fn auth_flow_detects_gcloud_with_url() {
        assert!(contains_auth_flow(
            "Go to the following link in your browser:\n\n    https://accounts.google.com/o/oauth2/auth?response_type=code\n\nEnter verification code: "
        ));
    }

    #[test]
    fn auth_flow_detects_aws_sso_with_url() {
        assert!(contains_auth_flow(
            "If the browser does not open, open the following URL:\nhttps://device.sso.us-east-1.amazonaws.com/\n\nThen enter the code:\nABCD-EFGH"
        ));
    }

    #[test]
    fn auth_flow_detects_firebase_with_url() {
        assert!(contains_auth_flow(
            "Visit this URL on this device to log in:\nhttps://accounts.google.com/o/oauth2/auth?...\n\nWaiting for authentication..."
        ));
    }

    #[test]
    fn auth_flow_detects_generic_browser_open_with_url() {
        assert!(contains_auth_flow(
            "Open your browser to https://login.example.com/device and enter the code XYZW-1234"
        ));
    }

    // --- False positive protection ---

    #[test]
    fn auth_flow_ignores_normal_build_output() {
        assert!(!contains_auth_flow(
            "Compiling lean-ctx v2.21.9\nFinished release profile\n"
        ));
    }

    #[test]
    fn auth_flow_ignores_git_output() {
        assert!(!contains_auth_flow(
            "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean"
        ));
    }

    #[test]
    fn auth_flow_ignores_npm_install_output() {
        assert!(!contains_auth_flow(
            "added 150 packages in 3s\n\n24 packages are looking for funding\n  run `npm fund` for details\nhttps://npmjs.com/package/lean-ctx"
        ));
    }

    #[test]
    fn auth_flow_ignores_docs_mentioning_auth() {
        assert!(!contains_auth_flow(
            "The authorization code grant type is the most common OAuth flow.\nSee https://oauth.net/2/grant-types/ for details."
        ));
    }

    #[test]
    fn auth_flow_weak_signal_requires_url() {
        assert!(!contains_auth_flow(
            "Please enter the code ABC123 in the terminal"
        ));
    }

    #[test]
    fn auth_flow_weak_signal_without_url_is_ignored() {
        assert!(!contains_auth_flow(
            "Waiting for authentication to complete... done!"
        ));
    }

    #[test]
    fn auth_flow_ignores_virtualenv_activate() {
        assert!(!contains_auth_flow(
            "Created virtualenv at .venv\nRun: source .venv/bin/activate"
        ));
    }

    #[test]
    fn auth_flow_ignores_api_response_with_code_field() {
        assert!(!contains_auth_flow(
            r#"{"status": "ok", "code": 200, "message": "success"}"#
        ));
    }

    // --- Integration: handle() preserves auth flow ---

    #[test]
    fn handle_preserves_auth_flow_output_fully() {
        let raw = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.\nWaiting for you...\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12\nLine 13";
        let rendered = handle("az login --use-device-code", raw, CrpMode::Off);

        let expectations = [
            ("ABCD1234", "auth code must be preserved"),
            ("devicelogin", "URL must be preserved"),
            (
                "auth/device-code flow detected",
                "detection note must be present",
            ),
            ("Line 13", "all lines must be preserved (no truncation)"),
        ];
        for (needle, reason) in expectations.iter() {
            assert!(rendered.contains(needle), "{reason}");
        }
    }

    #[test]
    fn handle_compresses_normal_output_not_auth() {
        let raw = numbered_lines(20, |n| format!("Line {n} of output"));
        let rendered = handle("some-tool check", &raw, CrpMode::Off);
        assert!(
            !rendered.contains("auth/device-code flow detected"),
            "normal output must not trigger auth detection"
        );
        assert!(
            rendered.len() < raw.len() + 100,
            "normal output should be compressed, not inflated"
        );
    }

    #[test]
    fn is_search_command_detects_grep() {
        let search_commands = [
            "grep -r pattern src/",
            "rg pattern src/",
            "find . -name '*.rs'",
            "fd pattern",
            "ag pattern src/",
            "ack pattern",
        ];
        for cmd in search_commands.iter() {
            assert!(is_search_command(cmd), "expected a search command: {cmd:?}");
        }
    }

    #[test]
    fn is_search_command_rejects_non_search() {
        let other_commands = ["cargo build", "git status", "npm install", "cat file.rs"];
        for cmd in other_commands.iter() {
            assert!(
                !is_search_command(cmd),
                "expected a non-search command: {cmd:?}"
            );
        }
    }

    #[test]
    fn generic_compress_preserves_short_output() {
        let raw = numbered_lines(20, |n| format!("Line {n}"));
        assert_eq!(generic_compress(&raw), raw);
    }

    #[test]
    fn generic_compress_scales_with_length() {
        let raw = numbered_lines(60, |n| format!("Line {n}"));
        let compressed = generic_compress(&raw);
        assert!(compressed.contains("truncated"));

        let shown_count = compressed.lines().count();
        assert!(
            shown_count > 10,
            "should show more than old 6-line limit, got {shown_count}"
        );
        assert!(shown_count < 60, "should be truncated, not full output");
    }

    #[test]
    fn handle_preserves_search_results() {
        let raw = numbered_lines(30, |n| format!("src/file{n}.rs:42: fn search_result()"));
        let rendered = handle("rg search_result src/", &raw, CrpMode::Off);
        for n in 1..=30 {
            let needle = format!("file{n}");
            assert!(
                rendered.contains(&needle),
                "search result {needle} should be preserved in output"
            );
        }
    }
}
589                "search result file{i} should be preserved in output"
590            );
591        }
592    }
593}