Skip to main content

lean_ctx/tools/
ctx_shell.rs

1use crate::core::patterns;
2use crate::core::protocol;
3use crate::core::symbol_map::{self, SymbolMap};
4use crate::core::tokens::count_tokens;
5use crate::tools::CrpMode;
6
/// Maximum accepted command length in bytes; `validate_command` rejects
/// anything longer.
const MAX_COMMAND_BYTES: usize = 8192;

/// Substrings that mark a heredoc inside a command.
///
/// `validate_command` lowercases both the command and each pattern before
/// comparing, so matching is case-insensitive.
const HEREDOC_PATTERNS: &[&str] = &[
    "<< 'EOF'",
    "<<'EOF'",
    "<< 'ENDOFFILE'",
    "<<'ENDOFFILE'",
    "<< 'END'",
    "<<'END'",
    "<< EOF",
    "<<EOF",
    // Matches a `cat` heredoc regardless of the delimiter word that follows.
    "cat <<",
];
20
21/// Validates a shell command before execution. Returns Some(error_message) if
22/// the command should be rejected, None if it's safe to run.
23pub fn validate_command(command: &str) -> Option<String> {
24    if command.len() > MAX_COMMAND_BYTES {
25        return Some(format!(
26            "ERROR: Command too large ({} bytes, limit {}). \
27             If you're writing file content, use the native Write/Edit tool instead. \
28             ctx_shell is for reading command output only (git, cargo, npm, etc.).",
29            command.len(),
30            MAX_COMMAND_BYTES
31        ));
32    }
33
34    if has_file_write_redirect(command) {
35        return Some(
36            "ERROR: ctx_shell detected a file-write command (shell redirect > or >>). \
37             Use the native Write tool to create/modify files. \
38             ctx_shell is ONLY for reading command output (git status, cargo test, npm run, etc.). \
39             File writes via shell cause MCP protocol corruption on large payloads."
40                .to_string(),
41        );
42    }
43
44    let cmd_lower = command.to_lowercase();
45
46    if cmd_lower.starts_with("tee ") || cmd_lower.contains("| tee ") {
47        return Some(
48            "ERROR: ctx_shell detected a file-write command (tee). \
49             Use the native Write tool to create/modify files. \
50             ctx_shell is ONLY for reading command output."
51                .to_string(),
52        );
53    }
54
55    for pattern in HEREDOC_PATTERNS {
56        if cmd_lower.contains(&pattern.to_lowercase()) {
57            return Some(
58                "ERROR: ctx_shell detected a heredoc file-write command. \
59                 Use the native Write tool to create/modify files. \
60                 ctx_shell is ONLY for reading command output."
61                    .to_string(),
62            );
63        }
64    }
65
66    None
67}
68
/// Detects shell redirect operators (`>` or `>>`) that write to a file.
///
/// The command is scanned byte by byte while tracking single- and
/// double-quote state; the function returns `true` at the first unquoted
/// redirect whose target is a file. These forms are NOT reported:
///
/// - `>` inside single or double quotes;
/// - stderr redirects `2>` / `2>>`, where `2` is a standalone descriptor
///   (so `echo v2>out.txt` is still reported);
/// - descriptor duplication or closing: `>&2`, `1>&2`, `>&-`;
/// - redirects to `/dev/null`, also when a control operator follows directly
///   (`>/dev/null; next`, `>/dev/null&&next`).
///
/// Backslash escapes are not interpreted, and an unquoted `>` followed by any
/// other word is always treated as a redirect.
fn has_file_write_redirect(command: &str) -> bool {
    let bytes = command.as_bytes();
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            b'\'' if !in_double_quote => in_single_quote = !in_single_quote,
            b'"' if !in_single_quote => in_double_quote = !in_double_quote,
            b'>' if !in_single_quote && !in_double_quote => {
                // Index just past the operator (`>` or `>>`). `>` is ASCII, so
                // this is always a valid char boundary for slicing below.
                let after_op = if bytes.get(i + 1) == Some(&b'>') {
                    i + 2
                } else {
                    i + 1
                };

                if is_stderr_descriptor(bytes, i) {
                    i = after_op;
                    continue;
                }

                let mut rest = &command[after_op..];
                if let Some(dup) = rest.strip_prefix('&') {
                    // `>&N` / `>&-` only rewires descriptors; nothing is written
                    // to a file.
                    let fd = redirect_word(dup.trim_start());
                    let is_fd = fd == "-"
                        || (!fd.is_empty() && fd.bytes().all(|b| b.is_ascii_digit()));
                    if is_fd {
                        i = after_op;
                        continue;
                    }
                    // `>&file` (bash) sends stdout and stderr to a file.
                    rest = dup;
                } else if let Some(forced) = rest.strip_prefix('|') {
                    // `>|file` is a redirect that overrides noclobber.
                    rest = forced;
                }

                let target = redirect_word(rest.trim_start());
                if target == "/dev/null" {
                    i = after_op;
                    continue;
                }
                if !target.is_empty() {
                    return true;
                }
            }
            _ => {}
        }
        i += 1;
    }
    false
}

/// Returns true when the `>` at index `gt` is prefixed by a standalone `2`,
/// i.e. the `2` starts the command or follows whitespace / a control operator.
fn is_stderr_descriptor(bytes: &[u8], gt: usize) -> bool {
    if gt == 0 || bytes[gt - 1] != b'2' {
        return false;
    }
    match gt.checked_sub(2).map(|before| bytes[before]) {
        None => true,
        Some(prev) => prev.is_ascii_whitespace() || matches!(prev, b';' | b'|' | b'&' | b'('),
    }
}

/// Returns the leading word of `s`, ending at whitespace or at one of the
/// shell control characters `;`, `|`, `&`, `)`.
fn redirect_word(s: &str) -> &str {
    let end = s
        .find(|c: char| c.is_whitespace() || matches!(c, ';' | '|' | '&' | ')'))
        .unwrap_or(s.len());
    &s[..end]
}
111
112/// On Windows cmd.exe, `;` is not a valid command separator.
113/// Convert `cmd1; cmd2` to `cmd1 && cmd2` when running under cmd.exe.
114pub fn normalize_command_for_shell(command: &str) -> String {
115    if !cfg!(windows) {
116        return command.to_string();
117    }
118    let (_, flag) = crate::shell::shell_and_flag();
119    if flag != "/C" {
120        return command.to_string();
121    }
122    let bytes = command.as_bytes();
123    let mut result = Vec::with_capacity(bytes.len() + 16);
124    let mut in_single = false;
125    let mut in_double = false;
126    for (i, &b) in bytes.iter().enumerate() {
127        if b == b'\'' && !in_double {
128            in_single = !in_single;
129        } else if b == b'"' && !in_single {
130            in_double = !in_double;
131        } else if b == b';' && !in_single && !in_double {
132            result.extend_from_slice(b" && ");
133            continue;
134        }
135        result.push(b);
136        let _ = i;
137    }
138    String::from_utf8(result).unwrap_or_else(|_| command.to_string())
139}
140
141pub fn handle(command: &str, output: &str, crp_mode: CrpMode) -> String {
142    let original_tokens = count_tokens(output);
143
144    if contains_auth_flow(output) {
145        let savings = protocol::format_savings(original_tokens, original_tokens);
146        return format!(
147            "{output}\n[lean-ctx: auth/device-code flow detected — output preserved uncompressed]\n{savings}"
148        );
149    }
150
151    let compressed = match patterns::compress_output(command, output) {
152        Some(c) => c,
153        None if is_search_command(command) => {
154            let stripped = crate::core::compressor::strip_ansi(output);
155            stripped.to_string()
156        }
157        None => generic_compress(output),
158    };
159
160    if crp_mode.is_tdd() && looks_like_code(&compressed) {
161        let ext = detect_ext_from_command(command);
162        let mut sym = SymbolMap::new();
163        let idents = symbol_map::extract_identifiers(&compressed, ext);
164        for ident in &idents {
165            sym.register(ident);
166        }
167        if !sym.is_empty() {
168            let mapped = sym.apply(&compressed);
169            let sym_table = sym.format_table();
170            let result = format!("{mapped}{sym_table}");
171            let sent = count_tokens(&result);
172            let savings = protocol::format_savings(original_tokens, sent);
173            return format!("{result}\n{savings}");
174        }
175    }
176
177    let sent = count_tokens(&compressed);
178    let savings = protocol::format_savings(original_tokens, sent);
179
180    format!("{compressed}\n{savings}")
181}
182
/// Returns true when the command, ignoring leading whitespace, starts with one
/// of the known search tools followed by a space.
fn is_search_command(command: &str) -> bool {
    const SEARCH_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

    let trimmed = command.trim_start();
    SEARCH_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(*prefix))
}
192
193fn generic_compress(output: &str) -> String {
194    let output = crate::core::compressor::strip_ansi(output);
195    let lines: Vec<&str> = output
196        .lines()
197        .filter(|l| {
198            let t = l.trim();
199            !t.is_empty()
200        })
201        .collect();
202
203    if lines.len() <= 20 {
204        return lines.join("\n");
205    }
206
207    let show_count = (lines.len() / 3).min(30);
208    let half = show_count / 2;
209    let first = &lines[..half];
210    let last = &lines[lines.len() - half..];
211    let omitted = lines.len() - (half * 2);
212    format!(
213        "{}\n[truncated: showing {}/{} lines, {} omitted. Use raw=true for full output.]\n{}",
214        first.join("\n"),
215        half * 2,
216        lines.len(),
217        omitted,
218        last.join("\n")
219    )
220}
221
/// Heuristic: does `text` look like source code?
///
/// Text with fewer than three lines never qualifies. Otherwise the text
/// qualifies when more than 15% of its lines contain at least one of the
/// keyword/operator markers below.
fn looks_like_code(text: &str) -> bool {
    const CODE_MARKERS: &[&str] = &[
        "fn ",
        "pub ",
        "let ",
        "const ",
        "impl ",
        "struct ",
        "enum ",
        "function ",
        "class ",
        "import ",
        "export ",
        "def ",
        "async ",
        "=>",
        "->",
        "::",
        "self.",
        "this.",
    ];

    // One pass: count all lines and the lines carrying a marker.
    let (total, with_marker) = text.lines().fold((0usize, 0usize), |(total, hits), line| {
        let is_hit = CODE_MARKERS.iter().any(|marker| line.contains(*marker));
        (total + 1, hits + usize::from(is_hit))
    });

    if total < 3 {
        return false;
    }
    with_marker as f64 / total as f64 > 0.15
}
253
/// Guesses a source-file extension from substrings of the lowercased command.
///
/// Rules are tried in order and the first match wins; `"rs"` is returned when
/// nothing matches.
fn detect_ext_from_command(command: &str) -> &str {
    // (needles, extension) — order matters: "cargo" must be tested before
    // "go " because "cargo " contains it.
    const RULES: &[(&[&str], &str)] = &[
        (&["cargo", ".rs"], "rs"),
        (&["npm", "node", ".ts", ".js"], "ts"),
        (&["python", "pip", ".py"], "py"),
        (&["go ", ".go"], "go"),
    ];

    let lowered = command.to_lowercase();
    RULES
        .iter()
        .find(|(needles, _)| needles.iter().any(|needle| lowered.contains(*needle)))
        .map_or("rs", |(_, ext)| *ext)
}
272
/// Detects OAuth device-code style login output, which `handle` passes through
/// uncompressed.
///
/// Matching is case-insensitive and two-tiered: any strong signal is enough
/// on its own, while a weak signal only counts when the output also contains
/// `http://` or `https://`.
pub fn contains_auth_flow(output: &str) -> bool {
    const STRONG_SIGNALS: &[&str] = &[
        "devicelogin",
        "deviceauth",
        "device_code",
        "device code",
        "device-code",
        "verification_uri",
        "user_code",
        "one-time code",
    ];
    const WEAK_SIGNALS: &[&str] = &[
        "enter the code",
        "enter this code",
        "enter code:",
        "use the code",
        "use a web browser to open",
        "open the page",
        "authenticate by visiting",
        "sign in with the code",
        "sign in using a code",
        "verification code",
        "authorize this device",
        "waiting for authentication",
        "waiting for login",
        "waiting for you to authenticate",
        "open your browser",
        "open in your browser",
    ];
    const URL_SCHEMES: &[&str] = &["http://", "https://"];

    let haystack = output.to_lowercase();
    let mentions_any =
        |needles: &[&str]| needles.iter().any(|needle| haystack.contains(*needle));

    if mentions_any(STRONG_SIGNALS) {
        return true;
    }

    // Weak phrases are common in ordinary output, so they need a URL as well.
    mentions_any(WEAK_SIGNALS) && mentions_any(URL_SCHEMES)
}
320
// Unit tests for command validation, auth-flow detection, and the
// compression paths of `handle`.
#[cfg(test)]
mod tests {
    use super::*;

    // `normalize_command_for_shell` returns its input untouched on non-Windows
    // builds; the test is a no-op on Windows.
    #[test]
    fn normalize_cmd_no_change_on_unix() {
        if cfg!(windows) {
            return;
        }
        assert_eq!(
            normalize_command_for_shell("cd /tmp; ls -la"),
            "cd /tmp; ls -la"
        );
    }

    // --- Command validation ---

    #[test]
    fn validate_allows_safe_commands() {
        assert!(validate_command("git status").is_none());
        assert!(validate_command("cargo test").is_none());
        assert!(validate_command("npm run build").is_none());
        assert!(validate_command("ls -la").is_none());
    }

    // Covers redirects (`>`), `tee`, and heredocs.
    #[test]
    fn validate_blocks_file_writes() {
        assert!(validate_command("cat > file.py << 'EOF'\nprint('hi')\nEOF").is_some());
        assert!(validate_command("echo 'data' > output.txt").is_some());
        assert!(validate_command("tee /tmp/file.txt").is_some());
        assert!(validate_command("printf 'hello' > test.txt").is_some());
        assert!(validate_command("cat << EOF\ncontent\nEOF").is_some());
    }

    // One byte over the limit must be rejected with the size message.
    #[test]
    fn validate_blocks_oversized_commands() {
        let huge = "x".repeat(MAX_COMMAND_BYTES + 1);
        let result = validate_command(&huge);
        assert!(result.is_some());
        assert!(result.unwrap().contains("too large"));
    }

    #[test]
    fn validate_allows_cat_without_redirect() {
        assert!(validate_command("cat file.txt").is_none());
    }

    // --- Auth flow detection: strong signals (no URL needed) ---

    #[test]
    fn auth_flow_detects_azure_device_code() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_gh_auth_one_time_code() {
        let output =
            "! First copy your one-time code: ABCD-1234\n- Press Enter to open github.com in your browser...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_device_code_json() {
        let output = r#"{"device_code":"abc123","user_code":"ABCD-1234","verification_uri":"https://example.com/activate"}"#;
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_verification_uri_field() {
        let output =
            r#"{"verification_uri": "https://login.microsoftonline.com/common/oauth2/deviceauth"}"#;
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_user_code_field() {
        let output = r#"{"user_code": "FGHJK-LMNOP", "expires_in": 900}"#;
        assert!(contains_auth_flow(output));
    }

    // --- Auth flow detection: weak signals (require URL) ---

    #[test]
    fn auth_flow_detects_gcloud_with_url() {
        let output = "Go to the following link in your browser:\n\n    https://accounts.google.com/o/oauth2/auth?response_type=code\n\nEnter verification code: ";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_aws_sso_with_url() {
        let output = "If the browser does not open, open the following URL:\nhttps://device.sso.us-east-1.amazonaws.com/\n\nThen enter the code:\nABCD-EFGH";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_firebase_with_url() {
        let output = "Visit this URL on this device to log in:\nhttps://accounts.google.com/o/oauth2/auth?...\n\nWaiting for authentication...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_generic_browser_open_with_url() {
        let output =
            "Open your browser to https://login.example.com/device and enter the code XYZW-1234";
        assert!(contains_auth_flow(output));
    }

    // --- False positive protection ---

    #[test]
    fn auth_flow_ignores_normal_build_output() {
        let output = "Compiling lean-ctx v2.21.9\nFinished release profile\n";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_git_output() {
        let output = "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean";
        assert!(!contains_auth_flow(output));
    }

    // A URL alone, without any weak or strong phrase, must not match.
    #[test]
    fn auth_flow_ignores_npm_install_output() {
        let output = "added 150 packages in 3s\n\n24 packages are looking for funding\n  run `npm fund` for details\nhttps://npmjs.com/package/lean-ctx";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_docs_mentioning_auth() {
        let output = "The authorization code grant type is the most common OAuth flow.\nSee https://oauth.net/2/grant-types/ for details.";
        assert!(!contains_auth_flow(output));
    }

    // A weak phrase alone, without a URL, must not match.
    #[test]
    fn auth_flow_weak_signal_requires_url() {
        let output = "Please enter the code ABC123 in the terminal";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_weak_signal_without_url_is_ignored() {
        let output = "Waiting for authentication to complete... done!";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_virtualenv_activate() {
        let output = "Created virtualenv at .venv\nRun: source .venv/bin/activate";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_api_response_with_code_field() {
        let output = r#"{"status": "ok", "code": 200, "message": "success"}"#;
        assert!(!contains_auth_flow(output));
    }

    // --- Integration: handle() preserves auth flow ---

    #[test]
    fn handle_preserves_auth_flow_output_fully() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.\nWaiting for you...\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12\nLine 13";
        let result = handle("az login --use-device-code", output, CrpMode::Off);
        assert!(result.contains("ABCD1234"), "auth code must be preserved");
        assert!(result.contains("devicelogin"), "URL must be preserved");
        assert!(
            result.contains("auth/device-code flow detected"),
            "detection note must be present"
        );
        assert!(
            result.contains("Line 13"),
            "all lines must be preserved (no truncation)"
        );
    }

    // The +100 slack leaves room for the savings footer that `handle` appends.
    #[test]
    fn handle_compresses_normal_output_not_auth() {
        let lines: Vec<String> = (1..=20).map(|i| format!("Line {i} of output")).collect();
        let output = lines.join("\n");
        let result = handle("some-tool check", &output, CrpMode::Off);
        assert!(
            !result.contains("auth/device-code flow detected"),
            "normal output must not trigger auth detection"
        );
        assert!(
            result.len() < output.len() + 100,
            "normal output should be compressed, not inflated"
        );
    }

    // --- Search command detection ---

    #[test]
    fn is_search_command_detects_grep() {
        assert!(is_search_command("grep -r pattern src/"));
        assert!(is_search_command("rg pattern src/"));
        assert!(is_search_command("find . -name '*.rs'"));
        assert!(is_search_command("fd pattern"));
        assert!(is_search_command("ag pattern src/"));
        assert!(is_search_command("ack pattern"));
    }

    #[test]
    fn is_search_command_rejects_non_search() {
        assert!(!is_search_command("cargo build"));
        assert!(!is_search_command("git status"));
        assert!(!is_search_command("npm install"));
        assert!(!is_search_command("cat file.rs"));
    }

    // --- Generic compression ---

    // 20 lines is exactly the threshold at which output is still returned whole.
    #[test]
    fn generic_compress_preserves_short_output() {
        let lines: Vec<String> = (1..=20).map(|i| format!("Line {i}")).collect();
        let output = lines.join("\n");
        let result = generic_compress(&output);
        assert_eq!(result, output);
    }

    #[test]
    fn generic_compress_scales_with_length() {
        let lines: Vec<String> = (1..=60).map(|i| format!("Line {i}")).collect();
        let output = lines.join("\n");
        let result = generic_compress(&output);
        assert!(result.contains("truncated"));
        let shown_count = result.lines().count();
        assert!(
            shown_count > 10,
            "should show more than old 6-line limit, got {shown_count}"
        );
        assert!(shown_count < 60, "should be truncated, not full output");
    }

    // 30 lines would be truncated by `generic_compress`; for a search command
    // every hit must survive.
    // NOTE(review): assumes `patterns::compress_output` does not drop lines for
    // `rg` output — confirm against that module.
    #[test]
    fn handle_preserves_search_results() {
        let lines: Vec<String> = (1..=30)
            .map(|i| format!("src/file{i}.rs:42: fn search_result()"))
            .collect();
        let output = lines.join("\n");
        let result = handle("rg search_result src/", &output, CrpMode::Off);
        for i in 1..=30 {
            assert!(
                result.contains(&format!("file{i}")),
                "search result file{i} should be preserved in output"
            );
        }
    }
}