Skip to main content

lean_ctx/tools/
ctx_shell.rs

1use crate::core::patterns;
2use crate::core::protocol;
3use crate::core::symbol_map::{self, SymbolMap};
4use crate::core::tokens::count_tokens;
5use crate::tools::CrpMode;
6
/// Largest command, in bytes, that `validate_command` accepts before
/// rejecting it as "too large".
const MAX_COMMAND_BYTES: usize = 8 * 1024;

/// Heredoc markers that `validate_command` looks for (after lowercasing both
/// the command and the marker) in order to reject heredoc-style commands.
const HEREDOC_PATTERNS: &[&str] = &[
    // Quoted delimiters, with and without a space after `<<`.
    "<< 'EOF'",
    "<<'EOF'",
    "<< 'ENDOFFILE'",
    "<<'ENDOFFILE'",
    "<< 'END'",
    "<<'END'",
    // Unquoted `EOF` delimiter.
    "<< EOF",
    "<<EOF",
    // Any heredoc fed to `cat`, whatever the delimiter.
    "cat <<",
];
20
21/// Validates a shell command before execution. Returns Some(error_message) if
22/// the command should be rejected, None if it's safe to run.
23pub fn validate_command(command: &str) -> Option<String> {
24    if command.len() > MAX_COMMAND_BYTES {
25        return Some(format!(
26            "ERROR: Command too large ({} bytes, limit {}). \
27             If you're writing file content, use the native Write/Edit tool instead. \
28             ctx_shell is for reading command output only (git, cargo, npm, etc.).",
29            command.len(),
30            MAX_COMMAND_BYTES
31        ));
32    }
33
34    if has_file_write_redirect(command) {
35        return Some(
36            "ERROR: ctx_shell detected a file-write command (shell redirect > or >>). \
37             Use the native Write tool to create/modify files. \
38             ctx_shell is ONLY for reading command output (git status, cargo test, npm run, etc.). \
39             File writes via shell cause MCP protocol corruption on large payloads."
40                .to_string(),
41        );
42    }
43
44    let cmd_lower = command.to_lowercase();
45
46    if cmd_lower.starts_with("tee ") || cmd_lower.contains("| tee ") {
47        return Some(
48            "ERROR: ctx_shell detected a file-write command (tee). \
49             Use the native Write tool to create/modify files. \
50             ctx_shell is ONLY for reading command output."
51                .to_string(),
52        );
53    }
54
55    for pattern in HEREDOC_PATTERNS {
56        if cmd_lower.contains(&pattern.to_lowercase()) {
57            return Some(
58                "ERROR: ctx_shell detected a heredoc file-write command. \
59                 Use the native Write tool to create/modify files. \
60                 ctx_shell is ONLY for reading command output."
61                    .to_string(),
62            );
63        }
64    }
65
66    None
67}
68
/// Detects shell redirect operators (`>` or `>>`) that write to files.
///
/// Returns `true` as soon as an unquoted, unescaped `>` / `>>` with a
/// non-empty target is found. The following are NOT reported:
///
/// - `>` inside single or double quotes, or escaped with a backslash;
/// - stderr redirects `2>` / `2>>` where the `2` is a standalone descriptor
///   (`foo2>file` is a stdout redirect of the word `foo2` and IS reported);
/// - file-descriptor duplication / closing such as `>&2`, `1>&2`, `>&-`;
/// - redirects whose target is exactly `/dev/null`.
fn has_file_write_redirect(command: &str) -> bool {
    // True for bytes after which a bare file-descriptor number can start.
    fn starts_new_word(b: u8) -> bool {
        b.is_ascii_whitespace() || matches!(b, b';' | b'|' | b'&' | b'(')
    }

    // True when `rest` (the bytes right after the redirect operator) is
    // `&<digits>` or `&-` followed by the end of the word, i.e. the redirect
    // only duplicates or closes a descriptor and names no file.
    fn is_fd_duplication(rest: &[u8]) -> bool {
        let tail = match rest.split_first() {
            Some((&b'&', tail)) => tail,
            _ => return false,
        };
        let spec_len = match tail.first() {
            Some(&b'-') => 1,
            _ => tail.iter().take_while(|b| b.is_ascii_digit()).count(),
        };
        if spec_len == 0 {
            return false;
        }
        match tail.get(spec_len) {
            None => true,
            Some(&b) => b.is_ascii_whitespace() || matches!(b, b';' | b'|' | b'&' | b')'),
        }
    }

    let bytes = command.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut in_single_quote = false;
    let mut in_double_quote = false;

    while i < len {
        let c = bytes[i];

        // A backslash escapes the next byte everywhere except inside single
        // quotes; skipping it keeps `\"` from toggling the quote state and
        // keeps `\>` from being read as a redirect.
        if c == b'\\' && !in_single_quote {
            i += 2;
            continue;
        }

        if c == b'\'' && !in_double_quote {
            in_single_quote = !in_single_quote;
        } else if c == b'"' && !in_single_quote {
            in_double_quote = !in_double_quote;
        } else if c == b'>' && !in_single_quote && !in_double_quote {
            // Treat `>>` as a single operator.
            let target_start = if bytes.get(i + 1) == Some(&b'>') {
                i + 2
            } else {
                i + 1
            };

            // `2` only names stderr when it is a word of its own.
            let is_stderr =
                i > 0 && bytes[i - 1] == b'2' && (i == 1 || starts_new_word(bytes[i - 2]));
            if is_stderr || is_fd_duplication(&bytes[target_start..]) {
                i = target_start;
                continue;
            }

            // `>` is ASCII, so `target_start` is always a char boundary.
            let target = command[target_start..]
                .split_whitespace()
                .next()
                .unwrap_or("");
            if target == "/dev/null" {
                i = target_start;
                continue;
            }
            if !target.is_empty() {
                return true;
            }
        }
        i += 1;
    }
    false
}
111
112/// On Windows cmd.exe, `;` is not a valid command separator.
113/// Convert `cmd1; cmd2` to `cmd1 && cmd2` when running under cmd.exe.
114pub fn normalize_command_for_shell(command: &str) -> String {
115    if !cfg!(windows) {
116        return command.to_string();
117    }
118    let (_, flag) = crate::shell::shell_and_flag();
119    if flag != "/C" {
120        return command.to_string();
121    }
122    let bytes = command.as_bytes();
123    let mut result = Vec::with_capacity(bytes.len() + 16);
124    let mut in_single = false;
125    let mut in_double = false;
126    for (i, &b) in bytes.iter().enumerate() {
127        if b == b'\'' && !in_double {
128            in_single = !in_single;
129        } else if b == b'"' && !in_single {
130            in_double = !in_double;
131        } else if b == b';' && !in_single && !in_double {
132            result.extend_from_slice(b" && ");
133            continue;
134        }
135        result.push(b);
136        let _ = i;
137    }
138    String::from_utf8(result).unwrap_or_else(|_| command.to_string())
139}
140
141pub fn handle(command: &str, output: &str, crp_mode: CrpMode) -> String {
142    let original_tokens = count_tokens(output);
143
144    if contains_auth_flow(output) {
145        let savings = protocol::format_savings(original_tokens, original_tokens);
146        return format!(
147            "{output}\n[lean-ctx: auth/device-code flow detected — output preserved uncompressed]\n{savings}"
148        );
149    }
150
151    let raw_compressed = match patterns::compress_output(command, output) {
152        Some(c) => crate::core::compressor::safeguard_ratio(output, &c),
153        None if is_search_command(command) => {
154            let stripped = crate::core::compressor::strip_ansi(output);
155            stripped.to_string()
156        }
157        None => generic_compress(output),
158    };
159    let compressed = crate::core::compressor::verbatim_compact(&raw_compressed);
160
161    if crp_mode.is_tdd() && looks_like_code(&compressed) {
162        let ext = detect_ext_from_command(command);
163        let mut sym = SymbolMap::new();
164        let idents = symbol_map::extract_identifiers(&compressed, ext);
165        for ident in &idents {
166            sym.register(ident);
167        }
168        if !sym.is_empty() {
169            let mapped = sym.apply(&compressed);
170            let sym_table = sym.format_table();
171            let result = format!("{mapped}{sym_table}");
172            let sent = count_tokens(&result);
173            let savings = protocol::format_savings(original_tokens, sent);
174            return format!("{result}\n{savings}");
175        }
176    }
177
178    let sent = count_tokens(&compressed);
179    let savings = protocol::format_savings(original_tokens, sent);
180
181    format!("{compressed}\n{savings}")
182}
183
/// Returns `true` when `command`, ignoring leading whitespace, starts with
/// one of the search tools `grep`, `rg`, `find`, `fd`, `ag` or `ack`
/// followed by a space.
fn is_search_command(command: &str) -> bool {
    const SEARCH_PREFIXES: &[&str] = &["grep ", "rg ", "find ", "fd ", "ag ", "ack "];

    let trimmed = command.trim_start();
    SEARCH_PREFIXES
        .iter()
        .any(|prefix| trimmed.starts_with(prefix))
}
193
194fn generic_compress(output: &str) -> String {
195    let output = crate::core::compressor::strip_ansi(output);
196    let lines: Vec<&str> = output
197        .lines()
198        .filter(|l| {
199            let t = l.trim();
200            !t.is_empty()
201        })
202        .collect();
203
204    if lines.len() <= 20 {
205        return lines.join("\n");
206    }
207
208    let show_count = (lines.len() / 3).min(30);
209    let half = show_count / 2;
210    let first = &lines[..half];
211    let last = &lines[lines.len() - half..];
212    let omitted = lines.len() - (half * 2);
213    format!(
214        "{}\n[truncated: showing {}/{} lines, {} omitted. Use raw=true for full output.]\n{}",
215        first.join("\n"),
216        half * 2,
217        lines.len(),
218        omitted,
219        last.join("\n")
220    )
221}
222
/// Heuristic: does `text` look like source code?
///
/// A line counts as code when it contains any of a fixed set of keyword and
/// operator markers. Text with fewer than three lines is never code;
/// otherwise the result is `true` when more than 15% of the lines are code.
fn looks_like_code(text: &str) -> bool {
    const CODE_MARKERS: &[&str] = &[
        "fn ", "pub ", "let ", "const ", "impl ", "struct ", "enum ", "function ", "class ",
        "import ", "export ", "def ", "async ", "=>", "->", "::", "self.", "this.",
    ];

    let mut line_count = 0usize;
    let mut code_count = 0usize;
    for line in text.lines() {
        line_count += 1;
        if CODE_MARKERS.iter().any(|marker| line.contains(marker)) {
            code_count += 1;
        }
    }

    if line_count < 3 {
        return false;
    }
    code_count as f64 / line_count as f64 > 0.15
}
254
/// Guesses a source-file extension from the text of `command`.
///
/// The lowercased command is checked against the rules below in order and
/// the first rule with a matching substring wins; `"rs"` is returned when
/// no rule matches.
fn detect_ext_from_command(command: &str) -> &str {
    // (substrings to look for, extension to report)
    const RULES: &[(&[&str], &str)] = &[
        (&["cargo", ".rs"], "rs"),
        (&["npm", "node", ".ts", ".js"], "ts"),
        (&["python", "pip", ".py"], "py"),
        (&["go ", ".go"], "go"),
    ];

    let lowered = command.to_lowercase();
    RULES
        .iter()
        .find(|(needles, _)| needles.iter().any(|needle| lowered.contains(needle)))
        .map_or("rs", |(_, ext)| *ext)
}
273
/// Detects OAuth device-code style output that must not be compressed.
///
/// Matching is case-insensitive and two-tiered:
/// - any *strong* signal (terms specific to device-code flows) is enough on
///   its own;
/// - a *weak* signal (generic "enter the code" / "open your browser"
///   wording) only counts when the output also contains `http://` or
///   `https://`.
pub fn contains_auth_flow(output: &str) -> bool {
    const STRONG_SIGNALS: &[&str] = &[
        "devicelogin",
        "deviceauth",
        "device_code",
        "device code",
        "device-code",
        "verification_uri",
        "user_code",
        "one-time code",
    ];

    const WEAK_SIGNALS: &[&str] = &[
        "enter the code",
        "enter this code",
        "enter code:",
        "use the code",
        "use a web browser to open",
        "open the page",
        "authenticate by visiting",
        "sign in with the code",
        "sign in using a code",
        "verification code",
        "authorize this device",
        "waiting for authentication",
        "waiting for login",
        "waiting for you to authenticate",
        "open your browser",
        "open in your browser",
    ];

    let haystack = output.to_lowercase();
    let mentions_any =
        |signals: &[&str]| signals.iter().any(|signal| haystack.contains(signal));

    mentions_any(STRONG_SIGNALS)
        || (mentions_any(WEAK_SIGNALS)
            && (haystack.contains("http://") || haystack.contains("https://")))
}
321
// Unit tests for command validation, auth-flow detection and output
// compression in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Command normalization ---

    // Only meaningful off Windows; on Windows the test returns early.
    #[test]
    fn normalize_cmd_no_change_on_unix() {
        if cfg!(windows) {
            return;
        }
        assert_eq!(
            normalize_command_for_shell("cd /tmp; ls -la"),
            "cd /tmp; ls -la"
        );
    }

    // --- Command validation ---

    #[test]
    fn validate_allows_safe_commands() {
        assert!(validate_command("git status").is_none());
        assert!(validate_command("cargo test").is_none());
        assert!(validate_command("npm run build").is_none());
        assert!(validate_command("ls -la").is_none());
    }

    // Redirects, `tee` and heredocs must all be rejected.
    #[test]
    fn validate_blocks_file_writes() {
        assert!(validate_command("cat > file.py << 'EOF'\nprint('hi')\nEOF").is_some());
        assert!(validate_command("echo 'data' > output.txt").is_some());
        assert!(validate_command("tee /tmp/file.txt").is_some());
        assert!(validate_command("printf 'hello' > test.txt").is_some());
        assert!(validate_command("cat << EOF\ncontent\nEOF").is_some());
    }

    // One byte over the limit is enough to trigger the size check.
    #[test]
    fn validate_blocks_oversized_commands() {
        let huge = "x".repeat(MAX_COMMAND_BYTES + 1);
        let result = validate_command(&huge);
        assert!(result.is_some());
        assert!(result.unwrap().contains("too large"));
    }

    #[test]
    fn validate_allows_cat_without_redirect() {
        assert!(validate_command("cat file.txt").is_none());
    }

    // --- Auth flow detection: strong signals (no URL needed) ---

    #[test]
    fn auth_flow_detects_azure_device_code() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_gh_auth_one_time_code() {
        let output =
            "! First copy your one-time code: ABCD-1234\n- Press Enter to open github.com in your browser...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_device_code_json() {
        let output = r#"{"device_code":"abc123","user_code":"ABCD-1234","verification_uri":"https://example.com/activate"}"#;
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_verification_uri_field() {
        let output =
            r#"{"verification_uri": "https://login.microsoftonline.com/common/oauth2/deviceauth"}"#;
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_user_code_field() {
        let output = r#"{"user_code": "FGHJK-LMNOP", "expires_in": 900}"#;
        assert!(contains_auth_flow(output));
    }

    // --- Auth flow detection: weak signals (require URL) ---

    #[test]
    fn auth_flow_detects_gcloud_with_url() {
        let output = "Go to the following link in your browser:\n\n    https://accounts.google.com/o/oauth2/auth?response_type=code\n\nEnter verification code: ";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_aws_sso_with_url() {
        let output = "If the browser does not open, open the following URL:\nhttps://device.sso.us-east-1.amazonaws.com/\n\nThen enter the code:\nABCD-EFGH";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_firebase_with_url() {
        let output = "Visit this URL on this device to log in:\nhttps://accounts.google.com/o/oauth2/auth?...\n\nWaiting for authentication...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_generic_browser_open_with_url() {
        let output =
            "Open your browser to https://login.example.com/device and enter the code XYZW-1234";
        assert!(contains_auth_flow(output));
    }

    // --- False positive protection ---

    #[test]
    fn auth_flow_ignores_normal_build_output() {
        let output = "Compiling lean-ctx v2.21.9\nFinished release profile\n";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_git_output() {
        let output = "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean";
        assert!(!contains_auth_flow(output));
    }

    // A URL alone, without any auth wording, must not trigger detection.
    #[test]
    fn auth_flow_ignores_npm_install_output() {
        let output = "added 150 packages in 3s\n\n24 packages are looking for funding\n  run `npm fund` for details\nhttps://npmjs.com/package/lean-ctx";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_docs_mentioning_auth() {
        let output = "The authorization code grant type is the most common OAuth flow.\nSee https://oauth.net/2/grant-types/ for details.";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_weak_signal_requires_url() {
        let output = "Please enter the code ABC123 in the terminal";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_weak_signal_without_url_is_ignored() {
        let output = "Waiting for authentication to complete... done!";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_virtualenv_activate() {
        let output = "Created virtualenv at .venv\nRun: source .venv/bin/activate";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_api_response_with_code_field() {
        let output = r#"{"status": "ok", "code": 200, "message": "success"}"#;
        assert!(!contains_auth_flow(output));
    }

    // --- Integration: handle() preserves auth flow ---

    #[test]
    fn handle_preserves_auth_flow_output_fully() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.\nWaiting for you...\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12\nLine 13";
        let result = handle("az login --use-device-code", output, CrpMode::Off);
        assert!(result.contains("ABCD1234"), "auth code must be preserved");
        assert!(result.contains("devicelogin"), "URL must be preserved");
        assert!(
            result.contains("auth/device-code flow detected"),
            "detection note must be present"
        );
        assert!(
            result.contains("Line 13"),
            "all lines must be preserved (no truncation)"
        );
    }

    #[test]
    fn handle_compresses_normal_output_not_auth() {
        let lines: Vec<String> = (1..=20).map(|i| format!("Line {i} of output")).collect();
        let output = lines.join("\n");
        let result = handle("some-tool check", &output, CrpMode::Off);
        assert!(
            !result.contains("auth/device-code flow detected"),
            "normal output must not trigger auth detection"
        );
        assert!(
            result.len() < output.len() + 100,
            "normal output should be compressed, not inflated"
        );
    }

    // --- Search command detection ---

    #[test]
    fn is_search_command_detects_grep() {
        assert!(is_search_command("grep -r pattern src/"));
        assert!(is_search_command("rg pattern src/"));
        assert!(is_search_command("find . -name '*.rs'"));
        assert!(is_search_command("fd pattern"));
        assert!(is_search_command("ag pattern src/"));
        assert!(is_search_command("ack pattern"));
    }

    #[test]
    fn is_search_command_rejects_non_search() {
        assert!(!is_search_command("cargo build"));
        assert!(!is_search_command("git status"));
        assert!(!is_search_command("npm install"));
        assert!(!is_search_command("cat file.rs"));
    }

    // --- Generic compression ---

    // 20 lines is the largest input that is returned without truncation.
    #[test]
    fn generic_compress_preserves_short_output() {
        let lines: Vec<String> = (1..=20).map(|i| format!("Line {i}")).collect();
        let output = lines.join("\n");
        let result = generic_compress(&output);
        assert_eq!(result, output);
    }

    #[test]
    fn generic_compress_scales_with_length() {
        let lines: Vec<String> = (1..=60).map(|i| format!("Line {i}")).collect();
        let output = lines.join("\n");
        let result = generic_compress(&output);
        assert!(result.contains("truncated"));
        let shown_count = result.lines().count();
        assert!(
            shown_count > 10,
            "should show more than old 6-line limit, got {shown_count}"
        );
        assert!(shown_count < 60, "should be truncated, not full output");
    }

    // Search commands bypass truncation, so every hit must survive.
    #[test]
    fn handle_preserves_search_results() {
        let lines: Vec<String> = (1..=30)
            .map(|i| format!("src/file{i}.rs:42: fn search_result()"))
            .collect();
        let output = lines.join("\n");
        let result = handle("rg search_result src/", &output, CrpMode::Off);
        for i in 1..=30 {
            assert!(
                result.contains(&format!("file{i}")),
                "search result file{i} should be preserved in output"
            );
        }
    }
}