Skip to main content

lean_ctx/tools/
ctx_shell.rs

1use crate::core::patterns;
2use crate::core::protocol;
3use crate::core::symbol_map::{self, SymbolMap};
4use crate::core::tokens::count_tokens;
5use crate::tools::CrpMode;
6
7const MAX_COMMAND_BYTES: usize = 8192;
8
9/// Validates a shell command before execution. Returns Some(error_message) if
10/// the command should be rejected, None if it's safe to run.
11pub fn validate_command(command: &str) -> Option<String> {
12    if command.len() > MAX_COMMAND_BYTES {
13        return Some(format!(
14            "ERROR: Command too large ({} bytes, limit {}). \
15             If you're writing file content, use the native Write/Edit tool instead. \
16             ctx_shell is for reading command output only (git, cargo, npm, etc.).",
17            command.len(),
18            MAX_COMMAND_BYTES
19        ));
20    }
21
22    if has_file_write_redirect(command) {
23        return Some(
24            "ERROR: ctx_shell detected a file-write command (shell redirect > or >>). \
25             Use the native Write tool to create/modify files. \
26             ctx_shell is ONLY for reading command output (git status, cargo test, npm run, etc.). \
27             File writes via shell cause MCP protocol corruption on large payloads."
28                .to_string(),
29        );
30    }
31
32    let cmd_lower = command.to_lowercase();
33
34    if cmd_lower.starts_with("tee ") || cmd_lower.contains("| tee ") {
35        return Some(
36            "ERROR: ctx_shell detected a file-write command (tee). \
37             Use the native Write tool to create/modify files. \
38             ctx_shell is ONLY for reading command output."
39                .to_string(),
40        );
41    }
42
43    if is_heredoc_file_write(command) {
44        return Some(
45            "ERROR: ctx_shell detected a heredoc writing to a file. \
46             Use the native Write tool to create/modify files. \
47             ctx_shell is ONLY for reading command output. \
48             Note: heredocs for input piping (e.g. psql <<EOF) are allowed."
49                .to_string(),
50        );
51    }
52
53    None
54}
55
56/// Returns true only for heredocs that redirect to files (the dangerous pattern).
57/// Legitimate heredoc uses (input piping, inline scripts) are allowed through.
58fn is_heredoc_file_write(command: &str) -> bool {
59    let has_heredoc = command.contains("<<");
60    if !has_heredoc {
61        return false;
62    }
63    // Only block: heredoc combined with file output redirect
64    // e.g. `cat <<EOF > file.txt` or `cat <<'EOF' >> output.log`
65    let cmd_lower = command.to_lowercase();
66    let heredoc_patterns = ["<<eof", "<<'eof'", "<<\"eof\"", "<<end", "<<'end'"];
67    let has_known_heredoc = heredoc_patterns.iter().any(|p| cmd_lower.contains(p));
68    if !has_known_heredoc {
69        return false;
70    }
71    has_file_write_redirect(command)
72}
73
/// Detects shell redirect operators (`>` or `>>`) that write to a file.
///
/// Scans the command byte by byte while tracking single/double quote state
/// (a plain toggle; backslash escapes are not interpreted) and reports the
/// first unquoted redirect whose target is a real file. The following are
/// deliberately not treated as file writes:
///
/// * `>` inside quotes, e.g. the comparison in `awk '$1 > 5'`;
/// * stderr redirects, i.e. an operator directly preceded by `2` (`2>`, `2>>`);
/// * redirects to `/dev/null`;
/// * descriptor duplication or closing such as `>&2`, `1>&2` or `>&-`;
/// * an operator with nothing after it.
///
/// Returns `true` as soon as one file-writing redirect is found.
fn has_file_write_redirect(command: &str) -> bool {
    let bytes = command.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut in_single_quote = false;
    let mut in_double_quote = false;

    while i < len {
        let c = bytes[i];
        if c == b'\'' && !in_double_quote {
            in_single_quote = !in_single_quote;
        } else if c == b'"' && !in_single_quote {
            in_double_quote = !in_double_quote;
        } else if c == b'>' && !in_single_quote && !in_double_quote {
            // `>>` is a single operator: when skipping, step over both bytes so
            // the second `>` is not re-read as a fresh redirect.
            let after_operator = if bytes.get(i + 1) == Some(&b'>') {
                i + 2
            } else {
                i + 1
            };

            // stderr redirect (`2>` / `2>>`): ignored by design.
            if i > 0 && bytes[i - 1] == b'2' {
                i = after_operator;
                continue;
            }

            // `after_operator` directly follows an ASCII `>`, so slicing here
            // always lands on a char boundary.
            let target = command[after_operator..]
                .split_whitespace()
                .next()
                .unwrap_or("");

            if target == "/dev/null" || is_fd_duplication_target(target) {
                i = after_operator;
                continue;
            }
            if !target.is_empty() {
                return true;
            }
        }
        i += 1;
    }
    false
}

/// Returns true when a redirect target only duplicates or closes a file
/// descriptor (`&1`, `&2`, `&-`) instead of naming a file.
///
/// The descriptor may be followed directly by a command separator (as in
/// `>&2;`); anything else after it, such as `&file`, is treated as a file name.
fn is_fd_duplication_target(target: &str) -> bool {
    let rest = match target.strip_prefix('&') {
        Some(rest) => rest,
        None => return false,
    };
    // `-` closes the descriptor; otherwise the descriptor is a run of digits.
    let spec_len = if rest.starts_with('-') {
        1
    } else {
        rest.bytes().take_while(u8::is_ascii_digit).count()
    };
    spec_len > 0
        && rest[spec_len..]
            .chars()
            .next()
            .map_or(true, |next| matches!(next, ';' | '|' | '&' | ')'))
}
116
117/// On Windows cmd.exe, `;` is not a valid command separator.
118/// Convert `cmd1; cmd2` to `cmd1 && cmd2` when running under cmd.exe.
119pub fn normalize_command_for_shell(command: &str) -> String {
120    if !cfg!(windows) {
121        return command.to_string();
122    }
123    let (_, flag) = crate::shell::shell_and_flag();
124    if flag != "/C" {
125        return command.to_string();
126    }
127    let bytes = command.as_bytes();
128    let mut result = Vec::with_capacity(bytes.len() + 16);
129    let mut in_single = false;
130    let mut in_double = false;
131    for (i, &b) in bytes.iter().enumerate() {
132        if b == b'\'' && !in_double {
133            in_single = !in_single;
134        } else if b == b'"' && !in_single {
135            in_double = !in_double;
136        } else if b == b';' && !in_single && !in_double {
137            result.extend_from_slice(b" && ");
138            continue;
139        }
140        result.push(b);
141        let _ = i;
142    }
143    String::from_utf8(result).unwrap_or_else(|_| command.to_string())
144}
145
146/// Compresses shell command output using pattern matching, dedup, and symbol mapping.
147pub fn handle(command: &str, output: &str, crp_mode: CrpMode) -> String {
148    let original_tokens = count_tokens(output);
149
150    if contains_auth_flow(output) {
151        let savings = protocol::format_savings(original_tokens, original_tokens);
152        return format!(
153            "{output}\n[lean-ctx: auth/device-code flow detected — output preserved uncompressed]\n{savings}"
154        );
155    }
156
157    let raw_compressed = match patterns::compress_output(command, output) {
158        Some(c) => crate::core::compressor::safeguard_ratio(output, &c),
159        None if is_search_command(command) => {
160            let stripped = crate::core::compressor::strip_ansi(output);
161            stripped.clone()
162        }
163        None => generic_compress(output),
164    };
165    let compressed = crate::core::compressor::verbatim_compact(&raw_compressed);
166
167    if crp_mode.is_tdd() && looks_like_code(&compressed) {
168        let ext = detect_ext_from_command(command);
169        let mut sym = SymbolMap::new();
170        let idents = symbol_map::extract_identifiers(&compressed, ext);
171        for ident in &idents {
172            sym.register(ident);
173        }
174        if !sym.is_empty() {
175            let mapped = sym.apply(&compressed);
176            let sym_table = sym.format_table();
177            let result = format!("{mapped}{sym_table}");
178            let sent = count_tokens(&result);
179            let savings = protocol::format_savings(original_tokens, sent);
180            return format!("{result}\n{savings}");
181        }
182    }
183
184    let sent = count_tokens(&compressed);
185    let savings = protocol::format_savings(original_tokens, sent);
186
187    format!("{compressed}\n{savings}")
188}
189
/// Returns true when the command, ignoring leading whitespace, starts with one
/// of the search tools `grep`, `rg`, `find`, `fd`, `ag` or `ack` followed by a space.
fn is_search_command(command: &str) -> bool {
    const SEARCH_PREFIXES: [&str; 6] = ["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

    let trimmed = command.trim_start();
    SEARCH_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(*prefix))
}
199
200fn generic_compress(output: &str) -> String {
201    let output = crate::core::compressor::strip_ansi(output);
202    let lines: Vec<&str> = output
203        .lines()
204        .filter(|l| {
205            let t = l.trim();
206            !t.is_empty()
207        })
208        .collect();
209
210    if lines.len() <= 20 {
211        return lines.join("\n");
212    }
213
214    let show_count = (lines.len() / 3).min(30);
215    let half = show_count / 2;
216    let first = &lines[..half];
217    let last = &lines[lines.len() - half..];
218    let omitted = lines.len() - (half * 2);
219    format!(
220        "{}\n[truncated: showing {}/{} lines, {} omitted. Use raw=true for full output.]\n{}",
221        first.join("\n"),
222        half * 2,
223        lines.len(),
224        omitted,
225        last.join("\n")
226    )
227}
228
/// Heuristic: does `text` look like source code?
///
/// A line counts as code when it contains any marker from a fixed list of
/// keywords and operators. Text shorter than three lines is never code;
/// otherwise more than 15% of the lines must contain a marker.
fn looks_like_code(text: &str) -> bool {
    const CODE_MARKERS: &[&str] = &[
        "fn ", "pub ", "let ", "const ", "impl ", "struct ", "enum ", "function ", "class ",
        "import ", "export ", "def ", "async ", "=>", "->", "::", "self.", "this.",
    ];

    // Single pass: count all lines and the ones carrying a marker.
    let mut line_total = 0usize;
    let mut marked_total = 0usize;
    for line in text.lines() {
        line_total += 1;
        if CODE_MARKERS.iter().any(|marker| line.contains(*marker)) {
            marked_total += 1;
        }
    }

    if line_total < 3 {
        return false;
    }
    marked_total as f64 / line_total as f64 > 0.15
}
260
/// Guesses a source-file extension from substrings of the lowercased command.
///
/// Rules are tried in order and the first match wins: Rust, then
/// JavaScript/TypeScript (reported as `"ts"`), then Python, then Go. Commands
/// matching no rule default to `"rs"`.
fn detect_ext_from_command(command: &str) -> &str {
    // (substrings to look for, extension to report); order matters.
    const RULES: &[(&[&str], &str)] = &[
        (&["cargo", ".rs"], "rs"),
        (&["npm", "node", ".ts", ".js"], "ts"),
        (&["python", "pip", ".py"], "py"),
        (&["go ", ".go"], "go"),
    ];

    let lowered = command.to_lowercase();
    for &(needles, ext) in RULES {
        if needles.iter().any(|needle| lowered.contains(*needle)) {
            return ext;
        }
    }
    "rs"
}
279
/// Detects OAuth device-code style login output, which must not be compressed.
///
/// Matching is case-insensitive and two-tiered: a strong signal (a phrase
/// specific to device-code flows) is sufficient alone, while a weak signal
/// (a generic "enter the code" / "open your browser" phrase) only counts when
/// the output also contains an `http://` or `https://` URL.
pub fn contains_auth_flow(output: &str) -> bool {
    const STRONG_SIGNALS: &[&str] = &[
        "devicelogin",
        "deviceauth",
        "device_code",
        "device code",
        "device-code",
        "verification_uri",
        "user_code",
        "one-time code",
    ];
    const WEAK_SIGNALS: &[&str] = &[
        "enter the code",
        "enter this code",
        "enter code:",
        "use the code",
        "use a web browser to open",
        "open the page",
        "authenticate by visiting",
        "sign in with the code",
        "sign in using a code",
        "verification code",
        "authorize this device",
        "waiting for authentication",
        "waiting for login",
        "waiting for you to authenticate",
        "open your browser",
        "open in your browser",
    ];

    let lower = output.to_lowercase();
    let mentions = |signal: &&str| lower.contains(*signal);

    if STRONG_SIGNALS.iter().any(mentions) {
        return true;
    }

    // Weak phrases are common in ordinary output, so they need a URL as well.
    WEAK_SIGNALS.iter().any(mentions)
        && (lower.contains("http://") || lower.contains("https://"))
}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` newline-joined lines, rendering each 1-based index with `render`.
    fn numbered_lines(count: usize, render: impl Fn(usize) -> String) -> String {
        (1..=count).map(render).collect::<Vec<_>>().join("\n")
    }

    // --- Command normalization and validation ---

    #[test]
    fn normalize_cmd_no_change_on_unix() {
        // Only meaningful off Windows, where the command must pass through untouched.
        if !cfg!(windows) {
            let command = "cd /tmp; ls -la";
            assert_eq!(normalize_command_for_shell(command), command);
        }
    }

    #[test]
    fn validate_allows_safe_commands() {
        for command in ["git status", "cargo test", "npm run build", "ls -la"] {
            assert!(validate_command(command).is_none());
        }
    }

    #[test]
    fn validate_blocks_file_writes() {
        for command in [
            "echo 'data' > output.txt",
            "tee /tmp/file.txt",
            "printf 'hello' > test.txt",
        ] {
            assert!(validate_command(command).is_some());
        }
    }

    #[test]
    fn validate_blocks_heredoc_with_file_redirect() {
        for command in [
            "cat > file.py <<'EOF'\nprint('hi')\nEOF",
            "cat <<EOF > output.txt\nhello\nEOF",
            "cat <<'END' >> logfile.txt\ndata\nEND",
        ] {
            assert!(validate_command(command).is_some());
        }
    }

    #[test]
    fn validate_allows_heredoc_without_file_redirect() {
        for command in [
            "cat <<EOF\nhello world\nEOF",
            "psql -d mydb <<EOF\nSELECT 1;\nEOF",
            "git commit -m \"$(cat <<'EOF'\nfix: something\nEOF\n)\"",
            "grep pattern <<EOF\nfoo\nbar\nEOF",
        ] {
            assert!(validate_command(command).is_none());
        }
    }

    #[test]
    fn validate_blocks_oversized_commands() {
        let oversized = "x".repeat(MAX_COMMAND_BYTES + 1);
        let rejects_as_too_large = validate_command(&oversized)
            .map_or(false, |message| message.contains("too large"));
        assert!(rejects_as_too_large);
    }

    #[test]
    fn validate_allows_cat_without_redirect() {
        assert!(validate_command("cat file.txt").is_none());
    }

    // --- Auth flow detection: strong signals (no URL needed) ---

    #[test]
    fn auth_flow_detects_azure_device_code() {
        assert!(contains_auth_flow(
            "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate."
        ));
    }

    #[test]
    fn auth_flow_detects_gh_auth_one_time_code() {
        assert!(contains_auth_flow(
            "! First copy your one-time code: ABCD-1234\n- Press Enter to open github.com in your browser..."
        ));
    }

    #[test]
    fn auth_flow_detects_device_code_json() {
        assert!(contains_auth_flow(
            r#"{"device_code":"abc123","user_code":"ABCD-1234","verification_uri":"https://example.com/activate"}"#
        ));
    }

    #[test]
    fn auth_flow_detects_verification_uri_field() {
        assert!(contains_auth_flow(
            r#"{"verification_uri": "https://login.microsoftonline.com/common/oauth2/deviceauth"}"#
        ));
    }

    #[test]
    fn auth_flow_detects_user_code_field() {
        assert!(contains_auth_flow(
            r#"{"user_code": "FGHJK-LMNOP", "expires_in": 900}"#
        ));
    }

    // --- Auth flow detection: weak signals (require URL) ---

    #[test]
    fn auth_flow_detects_gcloud_with_url() {
        assert!(contains_auth_flow(
            "Go to the following link in your browser:\n\n    https://accounts.google.com/o/oauth2/auth?response_type=code\n\nEnter verification code: "
        ));
    }

    #[test]
    fn auth_flow_detects_aws_sso_with_url() {
        assert!(contains_auth_flow(
            "If the browser does not open, open the following URL:\nhttps://device.sso.us-east-1.amazonaws.com/\n\nThen enter the code:\nABCD-EFGH"
        ));
    }

    #[test]
    fn auth_flow_detects_firebase_with_url() {
        assert!(contains_auth_flow(
            "Visit this URL on this device to log in:\nhttps://accounts.google.com/o/oauth2/auth?...\n\nWaiting for authentication..."
        ));
    }

    #[test]
    fn auth_flow_detects_generic_browser_open_with_url() {
        assert!(contains_auth_flow(
            "Open your browser to https://login.example.com/device and enter the code XYZW-1234"
        ));
    }

    // --- False positive protection ---

    #[test]
    fn auth_flow_ignores_normal_build_output() {
        assert!(!contains_auth_flow(
            "Compiling lean-ctx v2.21.9\nFinished release profile\n"
        ));
    }

    #[test]
    fn auth_flow_ignores_git_output() {
        assert!(!contains_auth_flow(
            "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean"
        ));
    }

    #[test]
    fn auth_flow_ignores_npm_install_output() {
        assert!(!contains_auth_flow(
            "added 150 packages in 3s\n\n24 packages are looking for funding\n  run `npm fund` for details\nhttps://npmjs.com/package/lean-ctx"
        ));
    }

    #[test]
    fn auth_flow_ignores_docs_mentioning_auth() {
        assert!(!contains_auth_flow(
            "The authorization code grant type is the most common OAuth flow.\nSee https://oauth.net/2/grant-types/ for details."
        ));
    }

    #[test]
    fn auth_flow_weak_signal_requires_url() {
        assert!(!contains_auth_flow(
            "Please enter the code ABC123 in the terminal"
        ));
    }

    #[test]
    fn auth_flow_weak_signal_without_url_is_ignored() {
        assert!(!contains_auth_flow(
            "Waiting for authentication to complete... done!"
        ));
    }

    #[test]
    fn auth_flow_ignores_virtualenv_activate() {
        assert!(!contains_auth_flow(
            "Created virtualenv at .venv\nRun: source .venv/bin/activate"
        ));
    }

    #[test]
    fn auth_flow_ignores_api_response_with_code_field() {
        assert!(!contains_auth_flow(
            r#"{"status": "ok", "code": 200, "message": "success"}"#
        ));
    }

    // --- Integration: handle() preserves auth flow ---

    #[test]
    fn handle_preserves_auth_flow_output_fully() {
        let transcript = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.\nWaiting for you...\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12\nLine 13";
        let rendered = handle("az login --use-device-code", transcript, CrpMode::Off);

        // (text that must survive, reason reported on failure)
        let expectations = [
            ("ABCD1234", "auth code must be preserved"),
            ("devicelogin", "URL must be preserved"),
            (
                "auth/device-code flow detected",
                "detection note must be present",
            ),
            ("Line 13", "all lines must be preserved (no truncation)"),
        ];
        for (needle, reason) in expectations {
            assert!(rendered.contains(needle), "{reason}");
        }
    }

    #[test]
    fn handle_compresses_normal_output_not_auth() {
        let output = numbered_lines(20, |i| format!("Line {i} of output"));
        let result = handle("some-tool check", &output, CrpMode::Off);
        assert!(
            !result.contains("auth/device-code flow detected"),
            "normal output must not trigger auth detection"
        );
        assert!(
            result.len() < output.len() + 100,
            "normal output should be compressed, not inflated"
        );
    }

    // --- Search commands and generic compression ---

    #[test]
    fn is_search_command_detects_grep() {
        for command in [
            "grep -r pattern src/",
            "rg pattern src/",
            "find . -name '*.rs'",
            "fd pattern",
            "ag pattern src/",
            "ack pattern",
        ] {
            assert!(is_search_command(command));
        }
    }

    #[test]
    fn is_search_command_rejects_non_search() {
        for command in ["cargo build", "git status", "npm install", "cat file.rs"] {
            assert!(!is_search_command(command));
        }
    }

    #[test]
    fn generic_compress_preserves_short_output() {
        let output = numbered_lines(20, |i| format!("Line {i}"));
        assert_eq!(generic_compress(&output), output);
    }

    #[test]
    fn generic_compress_scales_with_length() {
        let output = numbered_lines(60, |i| format!("Line {i}"));
        let result = generic_compress(&output);
        assert!(result.contains("truncated"));

        let shown_count = result.lines().count();
        assert!(
            shown_count > 10,
            "should show more than old 6-line limit, got {shown_count}"
        );
        assert!(shown_count < 60, "should be truncated, not full output");
    }

    #[test]
    fn handle_preserves_search_results() {
        let output = numbered_lines(30, |i| format!("src/file{i}.rs:42: fn search_result()"));
        let result = handle("rg search_result src/", &output, CrpMode::Off);
        for i in 1..=30 {
            assert!(
                result.contains(&format!("file{i}")),
                "search result file{i} should be preserved in output"
            );
        }
    }
}