Skip to main content

lean_ctx/tools/
ctx_shell.rs

1use crate::core::patterns;
2use crate::core::protocol;
3use crate::core::symbol_map::{self, SymbolMap};
4use crate::core::tokens::count_tokens;
5use crate::tools::CrpMode;
6
/// Upper bound, in bytes, on the length of a command accepted by
/// `validate_command` (8 KiB).
const MAX_COMMAND_BYTES: usize = 8 * 1024;
8
9/// Validates a shell command before execution. Returns Some(error_message) if
10/// the command should be rejected, None if it's safe to run.
11pub fn validate_command(command: &str) -> Option<String> {
12    if command.len() > MAX_COMMAND_BYTES {
13        return Some(format!(
14            "ERROR: Command too large ({} bytes, limit {}). \
15             If you're writing file content, use the native Write/Edit tool instead. \
16             ctx_shell is for reading command output only (git, cargo, npm, etc.).",
17            command.len(),
18            MAX_COMMAND_BYTES
19        ));
20    }
21
22    if has_file_write_redirect(command) {
23        return Some(
24            "ERROR: ctx_shell detected a file-write command (shell redirect > or >>). \
25             Use the native Write tool to create/modify files. \
26             ctx_shell is ONLY for reading command output (git status, cargo test, npm run, etc.). \
27             File writes via shell cause MCP protocol corruption on large payloads."
28                .to_string(),
29        );
30    }
31
32    let cmd_lower = command.to_lowercase();
33
34    if cmd_lower.starts_with("tee ") || cmd_lower.contains("| tee ") {
35        return Some(
36            "ERROR: ctx_shell detected a file-write command (tee). \
37             Use the native Write tool to create/modify files. \
38             ctx_shell is ONLY for reading command output."
39                .to_string(),
40        );
41    }
42
43    if is_heredoc_file_write(command) {
44        return Some(
45            "ERROR: ctx_shell detected a heredoc writing to a file. \
46             Use the native Write tool to create/modify files. \
47             ctx_shell is ONLY for reading command output. \
48             Note: heredocs for input piping (e.g. psql <<EOF) are allowed."
49                .to_string(),
50        );
51    }
52
53    None
54}
55
56/// Returns true only for heredocs that redirect to files (the dangerous pattern).
57/// Legitimate heredoc uses (input piping, inline scripts) are allowed through.
58fn is_heredoc_file_write(command: &str) -> bool {
59    let has_heredoc = command.contains("<<");
60    if !has_heredoc {
61        return false;
62    }
63    // Only block: heredoc combined with file output redirect
64    // e.g. `cat <<EOF > file.txt` or `cat <<'EOF' >> output.log`
65    let cmd_lower = command.to_lowercase();
66    let heredoc_patterns = ["<<eof", "<<'eof'", "<<\"eof\"", "<<end", "<<'end'"];
67    let has_known_heredoc = heredoc_patterns.iter().any(|p| cmd_lower.contains(p));
68    if !has_known_heredoc {
69        return false;
70    }
71    has_file_write_redirect(command)
72}
73
/// Detects shell redirect operators (`>` or `>>`) that write to files.
///
/// The scan tracks single and double quotes (without backslash-escape
/// handling) and ignores:
///
/// * `>` inside quotes;
/// * stderr redirects `2>` and `2>>`, but only when the `2` is a token of its
///   own (`cmd 2> log`); in `echo data2> out` the digit belongs to the word
///   and stdout is written to the file, so that form is reported;
/// * redirects whose target is exactly `/dev/null`;
/// * file-descriptor duplication / closing such as `>&2`, `1>&2`, `>&-`;
/// * a redirect operator with nothing after it.
///
/// Returns `true` as soon as one redirect to a real target is found.
fn has_file_write_redirect(command: &str) -> bool {
    /// True when `byte` ends the previous shell word, so the next byte starts a new token.
    fn is_token_boundary(byte: u8) -> bool {
        byte.is_ascii_whitespace() || matches!(byte, b';' | b'|' | b'&' | b'(')
    }

    /// True for `&N` / `&-` targets (optionally followed by a command
    /// separator), which duplicate or close a descriptor instead of naming a file.
    fn is_fd_duplication(target: &str) -> bool {
        target.strip_prefix('&').map_or(false, |rest| {
            let fd_len = if rest.starts_with('-') {
                1
            } else {
                rest.bytes().take_while(u8::is_ascii_digit).count()
            };
            // `>& file` (no descriptor) and `>&2file` (not a pure number) are real writes.
            fd_len > 0
                && rest[fd_len..]
                    .chars()
                    .next()
                    .map_or(true, |c| matches!(c, ';' | ')' | '|' | '&'))
        })
    }

    let bytes = command.as_bytes();
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            b'\'' if !in_double_quote => in_single_quote = !in_single_quote,
            b'"' if !in_single_quote => in_double_quote = !in_double_quote,
            b'>' if !in_single_quote && !in_double_quote => {
                // Treat `>>` as a single operator so both characters are consumed together.
                let after_operator = if bytes.get(i + 1) == Some(&b'>') {
                    i + 2
                } else {
                    i + 1
                };

                // Stderr redirect: a lone `2` immediately before the operator.
                let is_stderr = i > 0
                    && bytes[i - 1] == b'2'
                    && (i == 1 || is_token_boundary(bytes[i - 2]));
                if is_stderr {
                    i = after_operator;
                    continue;
                }

                // `after_operator` always follows an ASCII `>`, so it is a char boundary.
                let target = command[after_operator..]
                    .split_whitespace()
                    .next()
                    .unwrap_or("");
                if target == "/dev/null" || is_fd_duplication(target) {
                    i = after_operator;
                    continue;
                }
                if !target.is_empty() {
                    return true;
                }
            }
            _ => {}
        }
        i += 1;
    }
    false
}
116
117/// On Windows cmd.exe, `;` is not a valid command separator.
118/// Convert `cmd1; cmd2` to `cmd1 && cmd2` when running under cmd.exe.
119pub fn normalize_command_for_shell(command: &str) -> String {
120    if !cfg!(windows) {
121        return command.to_string();
122    }
123    let (_, flag) = crate::shell::shell_and_flag();
124    if flag != "/C" {
125        return command.to_string();
126    }
127    let bytes = command.as_bytes();
128    let mut result = Vec::with_capacity(bytes.len() + 16);
129    let mut in_single = false;
130    let mut in_double = false;
131    for (i, &b) in bytes.iter().enumerate() {
132        if b == b'\'' && !in_double {
133            in_single = !in_single;
134        } else if b == b'"' && !in_single {
135            in_double = !in_double;
136        } else if b == b';' && !in_single && !in_double {
137            result.extend_from_slice(b" && ");
138            continue;
139        }
140        result.push(b);
141        let _ = i;
142    }
143    String::from_utf8(result).unwrap_or_else(|_| command.to_string())
144}
145
146/// Compresses shell command output using pattern matching, dedup, and symbol mapping.
147pub fn handle(command: &str, output: &str, crp_mode: CrpMode) -> String {
148    let original_tokens = count_tokens(output);
149
150    // OutputPolicy gate: protected commands bypass ALL compression.
151    let policy = crate::shell::output_policy::classify(command, &[]);
152    if policy.is_protected() {
153        let savings = protocol::format_savings(original_tokens, original_tokens);
154        return format!("{output}\n{savings}");
155    }
156
157    if !is_search_command(command) && contains_auth_flow(output) {
158        let savings = protocol::format_savings(original_tokens, original_tokens);
159        return format!(
160            "{output}\n[lean-ctx: auth/device-code flow detected — output preserved uncompressed]\n{savings}"
161        );
162    }
163
164    let raw_compressed = match patterns::compress_output(command, output) {
165        Some(c) if is_search_command(command) => {
166            if count_tokens(&c) <= count_tokens(output) {
167                c
168            } else {
169                output.to_string()
170            }
171        }
172        Some(c) => crate::core::compressor::safeguard_ratio(output, &c),
173        None if is_search_command(command) => {
174            let stripped = crate::core::compressor::strip_ansi(output);
175            stripped.clone()
176        }
177        None => generic_compress(output),
178    };
179    let compressed = crate::core::compressor::verbatim_compact(&raw_compressed);
180
181    if crp_mode.is_tdd() && looks_like_code(&compressed) {
182        let ext = detect_ext_from_command(command);
183        let mut sym = SymbolMap::new();
184        let idents = symbol_map::extract_identifiers(&compressed, ext);
185        for ident in &idents {
186            sym.register(ident);
187        }
188        if !sym.is_empty() {
189            let mapped = sym.apply(&compressed);
190            let sym_table = sym.format_table();
191            let result = format!("{mapped}{sym_table}");
192            let sent = count_tokens(&result);
193            let savings = protocol::format_savings(original_tokens, sent);
194            return format!("{result}\n{savings}");
195        }
196    }
197
198    let sent = count_tokens(&compressed);
199    let savings = protocol::format_savings(original_tokens, sent);
200
201    format!("{compressed}\n{savings}")
202}
203
/// Reports whether `command`, ignoring leading whitespace, starts with one of
/// the recognised search tools followed by a space.
fn is_search_command(command: &str) -> bool {
    const SEARCH_PREFIXES: [&str; 6] = ["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

    let trimmed = command.trim_start();
    SEARCH_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(*prefix))
}
213
214fn generic_compress(output: &str) -> String {
215    let output = crate::core::compressor::strip_ansi(output);
216    let lines: Vec<&str> = output
217        .lines()
218        .filter(|l| {
219            let t = l.trim();
220            !t.is_empty()
221        })
222        .collect();
223
224    if lines.len() <= 20 {
225        return lines.join("\n");
226    }
227
228    let show_count = (lines.len() / 3).min(30);
229    let half = show_count / 2;
230    let first = &lines[..half];
231    let last = &lines[lines.len() - half..];
232    let omitted = lines.len() - (half * 2);
233    format!(
234        "{}\n[truncated: showing {}/{} lines, {} omitted. Use raw=true for full output.]\n{}",
235        first.join("\n"),
236        half * 2,
237        lines.len(),
238        omitted,
239        last.join("\n")
240    )
241}
242
/// Heuristic check for source code in `text`.
///
/// A line counts as code when it contains any of the markers below. Text with
/// fewer than three lines is never treated as code; otherwise more than 15% of
/// the lines must match.
fn looks_like_code(text: &str) -> bool {
    const MARKERS: &[&str] = &[
        "fn ", "pub ", "let ", "const ", "impl ", "struct ", "enum ", "function ", "class ",
        "import ", "export ", "def ", "async ", "=>", "->", "::", "self.", "this.",
    ];

    // Single pass: count all lines and the ones containing a marker.
    let (total, hits) = text.lines().fold((0usize, 0usize), |(total, hits), line| {
        let is_code = MARKERS.iter().any(|marker| line.contains(*marker));
        (total + 1, hits + usize::from(is_code))
    });

    if total < 3 {
        return false;
    }
    hits as f64 / total as f64 > 0.15
}
274
/// Guesses a source-file extension from the command text.
///
/// The lowercased command is checked against each rule in order and the first
/// rule with a matching substring wins. `"rs"` is returned when nothing matches.
fn detect_ext_from_command(command: &str) -> &str {
    const RULES: &[(&[&str], &str)] = &[
        (&["cargo", ".rs"], "rs"),
        (&["npm", "node", ".ts", ".js"], "ts"),
        (&["python", "pip", ".py"], "py"),
        (&["go ", ".go"], "go"),
    ];

    let lowered = command.to_lowercase();
    RULES
        .iter()
        .find(|(needles, _)| needles.iter().any(|needle| lowered.contains(*needle)))
        .map_or("rs", |&(_, ext)| ext)
}
293
/// Detects OAuth device-code style login output that must not be compressed.
///
/// Matching is case-insensitive and two-tiered: a strong signal is sufficient
/// on its own, while a weak signal only counts when the same output also
/// contains an `http://` or `https://` URL.
pub fn contains_auth_flow(output: &str) -> bool {
    // Phrases specific enough to device-code flows to match alone.
    const STRONG_SIGNALS: &[&str] = &[
        "devicelogin",
        "deviceauth",
        "device_code",
        "device code",
        "device-code",
        "verification_uri",
        "user_code",
        "one-time code",
    ];

    // Generic login wording; requires a URL alongside it.
    const WEAK_SIGNALS: &[&str] = &[
        "enter the code",
        "enter this code",
        "enter code:",
        "use the code",
        "use a web browser to open",
        "open the page",
        "authenticate by visiting",
        "sign in with the code",
        "sign in using a code",
        "verification code",
        "authorize this device",
        "waiting for authentication",
        "waiting for login",
        "waiting for you to authenticate",
        "open your browser",
        "open in your browser",
    ];

    let haystack = output.to_lowercase();
    let mentions_any = |signals: &[&str]| signals.iter().any(|signal| haystack.contains(*signal));

    if mentions_any(STRONG_SIGNALS) {
        return true;
    }

    mentions_any(WEAK_SIGNALS) && mentions_any(&["http://", "https://"])
}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Joins `count` generated lines (numbered from 1) with newlines.
    fn numbered_lines(count: usize, render: impl Fn(usize) -> String) -> String {
        (1..=count).map(render).collect::<Vec<_>>().join("\n")
    }

    #[test]
    fn normalize_cmd_no_change_on_unix() {
        if !cfg!(windows) {
            let input = "cd /tmp; ls -la";
            assert_eq!(normalize_command_for_shell(input), input);
        }
    }

    // --- Command validation ---

    #[test]
    fn validate_allows_safe_commands() {
        for cmd in ["git status", "cargo test", "npm run build", "ls -la"] {
            assert!(validate_command(cmd).is_none(), "{cmd:?} should be allowed");
        }
    }

    #[test]
    fn validate_blocks_file_writes() {
        for cmd in [
            "echo 'data' > output.txt",
            "tee /tmp/file.txt",
            "printf 'hello' > test.txt",
        ] {
            assert!(validate_command(cmd).is_some(), "{cmd:?} should be blocked");
        }
    }

    #[test]
    fn validate_blocks_heredoc_with_file_redirect() {
        for cmd in [
            "cat > file.py <<'EOF'\nprint('hi')\nEOF",
            "cat <<EOF > output.txt\nhello\nEOF",
            "cat <<'END' >> logfile.txt\ndata\nEND",
        ] {
            assert!(validate_command(cmd).is_some(), "{cmd:?} should be blocked");
        }
    }

    #[test]
    fn validate_allows_heredoc_without_file_redirect() {
        for cmd in [
            "cat <<EOF\nhello world\nEOF",
            "psql -d mydb <<EOF\nSELECT 1;\nEOF",
            "git commit -m \"$(cat <<'EOF'\nfix: something\nEOF\n)\"",
            "grep pattern <<EOF\nfoo\nbar\nEOF",
        ] {
            assert!(validate_command(cmd).is_none(), "{cmd:?} should be allowed");
        }
    }

    #[test]
    fn validate_blocks_oversized_commands() {
        let oversized = "x".repeat(MAX_COMMAND_BYTES + 1);
        let message = validate_command(&oversized).expect("oversized command must be rejected");
        assert!(message.contains("too large"));
    }

    #[test]
    fn validate_allows_cat_without_redirect() {
        assert!(validate_command("cat file.txt").is_none());
    }

    // --- Auth flow detection: strong signals (no URL needed) ---

    #[test]
    fn auth_flow_detects_azure_device_code() {
        assert!(contains_auth_flow(
            "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate."
        ));
    }

    #[test]
    fn auth_flow_detects_gh_auth_one_time_code() {
        assert!(contains_auth_flow(
            "! First copy your one-time code: ABCD-1234\n- Press Enter to open github.com in your browser..."
        ));
    }

    #[test]
    fn auth_flow_detects_device_code_json() {
        assert!(contains_auth_flow(
            r#"{"device_code":"abc123","user_code":"ABCD-1234","verification_uri":"https://example.com/activate"}"#
        ));
    }

    #[test]
    fn auth_flow_detects_verification_uri_field() {
        assert!(contains_auth_flow(
            r#"{"verification_uri": "https://login.microsoftonline.com/common/oauth2/deviceauth"}"#
        ));
    }

    #[test]
    fn auth_flow_detects_user_code_field() {
        assert!(contains_auth_flow(
            r#"{"user_code": "FGHJK-LMNOP", "expires_in": 900}"#
        ));
    }

    // --- Auth flow detection: weak signals (require URL) ---

    #[test]
    fn auth_flow_detects_gcloud_with_url() {
        assert!(contains_auth_flow(
            "Go to the following link in your browser:\n\n    https://accounts.google.com/o/oauth2/auth?response_type=code\n\nEnter verification code: "
        ));
    }

    #[test]
    fn auth_flow_detects_aws_sso_with_url() {
        assert!(contains_auth_flow(
            "If the browser does not open, open the following URL:\nhttps://device.sso.us-east-1.amazonaws.com/\n\nThen enter the code:\nABCD-EFGH"
        ));
    }

    #[test]
    fn auth_flow_detects_firebase_with_url() {
        assert!(contains_auth_flow(
            "Visit this URL on this device to log in:\nhttps://accounts.google.com/o/oauth2/auth?...\n\nWaiting for authentication..."
        ));
    }

    #[test]
    fn auth_flow_detects_generic_browser_open_with_url() {
        assert!(contains_auth_flow(
            "Open your browser to https://login.example.com/device and enter the code XYZW-1234"
        ));
    }

    // --- False positive protection ---

    #[test]
    fn auth_flow_ignores_normal_build_output() {
        assert!(!contains_auth_flow(
            "Compiling lean-ctx v2.21.9\nFinished release profile\n"
        ));
    }

    #[test]
    fn auth_flow_ignores_git_output() {
        assert!(!contains_auth_flow(
            "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean"
        ));
    }

    #[test]
    fn auth_flow_ignores_npm_install_output() {
        assert!(!contains_auth_flow(
            "added 150 packages in 3s\n\n24 packages are looking for funding\n  run `npm fund` for details\nhttps://npmjs.com/package/lean-ctx"
        ));
    }

    #[test]
    fn auth_flow_ignores_docs_mentioning_auth() {
        assert!(!contains_auth_flow(
            "The authorization code grant type is the most common OAuth flow.\nSee https://oauth.net/2/grant-types/ for details."
        ));
    }

    #[test]
    fn auth_flow_weak_signal_requires_url() {
        assert!(!contains_auth_flow(
            "Please enter the code ABC123 in the terminal"
        ));
    }

    #[test]
    fn auth_flow_weak_signal_without_url_is_ignored() {
        assert!(!contains_auth_flow(
            "Waiting for authentication to complete... done!"
        ));
    }

    #[test]
    fn auth_flow_ignores_virtualenv_activate() {
        assert!(!contains_auth_flow(
            "Created virtualenv at .venv\nRun: source .venv/bin/activate"
        ));
    }

    #[test]
    fn auth_flow_ignores_api_response_with_code_field() {
        assert!(!contains_auth_flow(
            r#"{"status": "ok", "code": 200, "message": "success"}"#
        ));
    }

    // --- Integration: handle() preserves auth flow ---

    #[test]
    fn handle_preserves_auth_flow_output_fully() {
        let login_output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.\nWaiting for you...\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12\nLine 13";
        // az login is Passthrough via OutputPolicy, so all content is preserved
        let rendered = handle("az login --use-device-code", login_output, CrpMode::Off);
        assert!(rendered.contains("ABCD1234"), "auth code must be preserved");
        assert!(rendered.contains("devicelogin"), "URL must be preserved");
        assert!(
            rendered.contains("Line 13"),
            "all lines must be preserved (no truncation)"
        );
    }

    #[test]
    fn handle_compresses_normal_output_not_auth() {
        let output = numbered_lines(20, |i| format!("Line {i} of output"));
        let rendered = handle("some-tool check", &output, CrpMode::Off);
        assert!(
            !rendered.contains("auth/device-code flow detected"),
            "normal output must not trigger auth detection"
        );
        assert!(
            rendered.len() < output.len() + 100,
            "normal output should be compressed, not inflated"
        );
    }

    // --- Search command classification ---

    #[test]
    fn is_search_command_detects_grep() {
        for cmd in [
            "grep -r pattern src/",
            "rg pattern src/",
            "find . -name '*.rs'",
            "fd pattern",
            "ag pattern src/",
            "ack pattern",
        ] {
            assert!(is_search_command(cmd), "{cmd:?} should be a search command");
        }
    }

    #[test]
    fn is_search_command_rejects_non_search() {
        for cmd in ["cargo build", "git status", "npm install", "cat file.rs"] {
            assert!(!is_search_command(cmd), "{cmd:?} is not a search command");
        }
    }

    // --- Generic compression ---

    #[test]
    fn generic_compress_preserves_short_output() {
        let output = numbered_lines(20, |i| format!("Line {i}"));
        assert_eq!(generic_compress(&output), output);
    }

    #[test]
    fn generic_compress_scales_with_length() {
        let output = numbered_lines(60, |i| format!("Line {i}"));
        let compressed = generic_compress(&output);
        assert!(compressed.contains("truncated"));
        let shown_count = compressed.lines().count();
        assert!(
            shown_count > 10,
            "should show more than old 6-line limit, got {shown_count}"
        );
        assert!(shown_count < 60, "should be truncated, not full output");
    }

    #[test]
    fn handle_preserves_search_results() {
        let output = numbered_lines(30, |i| format!("src/file{i}.rs:42: fn search_result()"));
        let rendered = handle("rg search_result src/", &output, CrpMode::Off);
        for i in 1..=30 {
            assert!(
                rendered.contains(&format!("file{i}")),
                "search result file{i} should be preserved in output"
            );
        }
    }
}