Skip to main content

lean_ctx/tools/
ctx_shell.rs

1use crate::core::patterns;
2use crate::core::protocol;
3use crate::core::symbol_map::{self, SymbolMap};
4use crate::core::tokens::count_tokens;
5use crate::tools::CrpMode;
6
/// Largest command, in bytes, that `validate_command` accepts.
const MAX_COMMAND_BYTES: usize = 8 * 1024;

/// Substrings that mark a command as using a heredoc.
///
/// `validate_command` lowercases both the command and each pattern before
/// comparing, so the matching is case-insensitive.
const HEREDOC_PATTERNS: &[&str] = &[
    // Quoted delimiters.
    "<< 'EOF'",
    "<<'EOF'",
    "<< 'ENDOFFILE'",
    "<<'ENDOFFILE'",
    "<< 'END'",
    "<<'END'",
    // Unquoted delimiters.
    "<< EOF",
    "<<EOF",
    // Any heredoc fed straight into `cat`, whatever the delimiter.
    "cat <<",
];
20
21/// Validates a shell command before execution. Returns Some(error_message) if
22/// the command should be rejected, None if it's safe to run.
23pub fn validate_command(command: &str) -> Option<String> {
24    if command.len() > MAX_COMMAND_BYTES {
25        return Some(format!(
26            "ERROR: Command too large ({} bytes, limit {}). \
27             If you're writing file content, use the native Write/Edit tool instead. \
28             ctx_shell is for reading command output only (git, cargo, npm, etc.).",
29            command.len(),
30            MAX_COMMAND_BYTES
31        ));
32    }
33
34    if has_file_write_redirect(command) {
35        return Some(
36            "ERROR: ctx_shell detected a file-write command (shell redirect > or >>). \
37             Use the native Write tool to create/modify files. \
38             ctx_shell is ONLY for reading command output (git status, cargo test, npm run, etc.). \
39             File writes via shell cause MCP protocol corruption on large payloads."
40                .to_string(),
41        );
42    }
43
44    let cmd_lower = command.to_lowercase();
45
46    if cmd_lower.starts_with("tee ") || cmd_lower.contains("| tee ") {
47        return Some(
48            "ERROR: ctx_shell detected a file-write command (tee). \
49             Use the native Write tool to create/modify files. \
50             ctx_shell is ONLY for reading command output."
51                .to_string(),
52        );
53    }
54
55    for pattern in HEREDOC_PATTERNS {
56        if cmd_lower.contains(&pattern.to_lowercase()) {
57            return Some(
58                "ERROR: ctx_shell detected a heredoc file-write command. \
59                 Use the native Write tool to create/modify files. \
60                 ctx_shell is ONLY for reading command output."
61                    .to_string(),
62            );
63        }
64    }
65
66    None
67}
68
/// Detects shell redirect operators (`>` or `>>`) that write to files.
///
/// Scans the command byte by byte while tracking single- and double-quote
/// state, so a `>` inside quotes is ignored. These unquoted redirects are
/// NOT reported:
///
/// * stderr redirects `2>` and `2>>`, but only when the `2` is a token of its
///   own (`echo log2> out.txt` redirects stdout and is reported),
/// * file-descriptor duplication or closing: `>&2`, `1>&2`, `2>&1`, `>&-`,
/// * redirects whose target is exactly `/dev/null`.
///
/// Any other unquoted `>` / `>>` followed by a non-empty target returns
/// `true`. Backslash escapes are not interpreted.
fn has_file_write_redirect(command: &str) -> bool {
    let bytes = command.as_bytes();
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            b'\'' if !in_double_quote => in_single_quote = !in_single_quote,
            b'"' if !in_single_quote => in_double_quote = !in_double_quote,
            b'>' if !in_single_quote && !in_double_quote => {
                let op_len = if bytes.get(i + 1) == Some(&b'>') { 2 } else { 1 };
                // `>` is ASCII, so `i + op_len` always lands on a char boundary.
                let rest = &command[i + op_len..];
                let target = rest.split_whitespace().next().unwrap_or("");

                let harmless = is_stderr_redirect(bytes, i)
                    || is_fd_duplication(rest)
                    || target == "/dev/null";
                if !harmless && !target.is_empty() {
                    return true;
                }

                // Step over the whole operator so the second `>` of `>>` is
                // never re-examined as a separate redirect.
                i += op_len;
                continue;
            }
            _ => {}
        }
        i += 1;
    }
    false
}

/// Returns true when the `>` at `gt_index` is prefixed by a standalone `2`,
/// i.e. the redirect applies to stderr.
fn is_stderr_redirect(bytes: &[u8], gt_index: usize) -> bool {
    if gt_index == 0 || bytes[gt_index - 1] != b'2' {
        return false;
    }
    // The digit only names a descriptor when it is not glued to a preceding
    // word; in `log2>` the `2` belongs to the argument.
    gt_index == 1
        || matches!(
            bytes[gt_index - 2],
            b' ' | b'\t' | b'\n' | b';' | b'|' | b'&' | b'('
        )
}

/// Returns true when the text right after a redirect operator is `&N` or `&-`,
/// which duplicates or closes a descriptor instead of naming a file.
fn is_fd_duplication(rest: &str) -> bool {
    rest.strip_prefix('&').map_or(false, |word| {
        let fd = word
            .split(|c: char| c.is_whitespace() || matches!(c, ';' | '|' | '&' | ')'))
            .next()
            .unwrap_or("");
        fd == "-" || (!fd.is_empty() && fd.bytes().all(|b| b.is_ascii_digit()))
    })
}
111
112/// On Windows cmd.exe, `;` is not a valid command separator.
113/// Convert `cmd1; cmd2` to `cmd1 && cmd2` when running under cmd.exe.
114pub fn normalize_command_for_shell(command: &str) -> String {
115    if !cfg!(windows) {
116        return command.to_string();
117    }
118    let (_, flag) = crate::shell::shell_and_flag();
119    if flag != "/C" {
120        return command.to_string();
121    }
122    let bytes = command.as_bytes();
123    let mut result = Vec::with_capacity(bytes.len() + 16);
124    let mut in_single = false;
125    let mut in_double = false;
126    for (i, &b) in bytes.iter().enumerate() {
127        if b == b'\'' && !in_double {
128            in_single = !in_single;
129        } else if b == b'"' && !in_single {
130            in_double = !in_double;
131        } else if b == b';' && !in_single && !in_double {
132            result.extend_from_slice(b" && ");
133            continue;
134        }
135        result.push(b);
136        let _ = i;
137    }
138    String::from_utf8(result).unwrap_or_else(|_| command.to_string())
139}
140
141pub fn handle(command: &str, output: &str, crp_mode: CrpMode) -> String {
142    let original_tokens = count_tokens(output);
143
144    if contains_auth_flow(output) {
145        let savings = protocol::format_savings(original_tokens, original_tokens);
146        return format!(
147            "{output}\n[lean-ctx: auth/device-code flow detected — output preserved uncompressed]\n{savings}"
148        );
149    }
150
151    let compressed = match patterns::compress_output(command, output) {
152        Some(c) => c,
153        None => generic_compress(output),
154    };
155
156    if crp_mode.is_tdd() && looks_like_code(&compressed) {
157        let ext = detect_ext_from_command(command);
158        let mut sym = SymbolMap::new();
159        let idents = symbol_map::extract_identifiers(&compressed, ext);
160        for ident in &idents {
161            sym.register(ident);
162        }
163        if !sym.is_empty() {
164            let mapped = sym.apply(&compressed);
165            let sym_table = sym.format_table();
166            let result = format!("{mapped}{sym_table}");
167            let sent = count_tokens(&result);
168            let savings = protocol::format_savings(original_tokens, sent);
169            return format!("{result}\n{savings}");
170        }
171    }
172
173    let sent = count_tokens(&compressed);
174    let savings = protocol::format_savings(original_tokens, sent);
175
176    format!("{compressed}\n{savings}")
177}
178
179fn generic_compress(output: &str) -> String {
180    let output = crate::core::compressor::strip_ansi(output);
181    let lines: Vec<&str> = output
182        .lines()
183        .filter(|l| {
184            let t = l.trim();
185            !t.is_empty()
186        })
187        .collect();
188
189    if lines.len() <= 10 {
190        return lines.join("\n");
191    }
192
193    let first_3 = &lines[..3];
194    let last_3 = &lines[lines.len() - 3..];
195    let omitted = lines.len() - 6;
196    format!(
197        "{}\n[truncated: showing 6/{} lines, {} omitted. Use raw=true for full output.]\n{}",
198        first_3.join("\n"),
199        lines.len(),
200        omitted,
201        last_3.join("\n")
202    )
203}
204
/// Heuristic: does `text` look like source code?
///
/// A line counts as code when it contains any of the marker substrings below.
/// Text with fewer than 3 lines is never considered code; otherwise the text
/// qualifies when more than 15% of its lines are code lines.
fn looks_like_code(text: &str) -> bool {
    const INDICATORS: &[&str] = &[
        "fn ",
        "pub ",
        "let ",
        "const ",
        "impl ",
        "struct ",
        "enum ",
        "function ",
        "class ",
        "import ",
        "export ",
        "def ",
        "async ",
        "=>",
        "->",
        "::",
        "self.",
        "this.",
    ];

    // Count total and code lines in a single pass.
    let mut total_lines = 0usize;
    let mut code_lines = 0usize;
    for line in text.lines() {
        total_lines += 1;
        if INDICATORS.iter().any(|marker| line.contains(*marker)) {
            code_lines += 1;
        }
    }

    if total_lines < 3 {
        return false;
    }
    code_lines as f64 / total_lines as f64 > 0.15
}
236
/// Guesses a source-file extension from the text of a shell command.
///
/// The command is lowercased and checked, in order, for:
///
/// * `cargo` or `.rs` -> `"rs"`
/// * `npm`, `node`, `.ts` or `.js` -> `"ts"`
/// * `python`, `pip` or `.py` -> `"py"`
/// * `go ` or `.go` -> `"go"`
///
/// Anything else falls back to `"rs"`. The result is always a string literal,
/// so it is returned as `&'static str` and does not borrow from `command`.
fn detect_ext_from_command(command: &str) -> &'static str {
    /// True when `haystack` contains at least one of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }

    let cmd = command.to_lowercase();
    if contains_any(&cmd, &["cargo", ".rs"]) {
        "rs"
    } else if contains_any(&cmd, &["npm", "node", ".ts", ".js"]) {
        "ts"
    } else if contains_any(&cmd, &["python", "pip", ".py"]) {
        "py"
    } else if contains_any(&cmd, &["go ", ".go"]) {
        "go"
    } else {
        "rs"
    }
}
255
/// Detects OAuth device-code / login-prompt output that must not be compressed.
///
/// Matching is case-insensitive and works in two tiers:
///
/// * a strong signal is enough on its own,
/// * a weak signal only counts when the output also contains `http://` or
///   `https://`.
pub fn contains_auth_flow(output: &str) -> bool {
    // Phrases specific enough to device-code flows to match alone.
    const STRONG_SIGNALS: &[&str] = &[
        "devicelogin",
        "deviceauth",
        "device_code",
        "device code",
        "device-code",
        "verification_uri",
        "user_code",
        "one-time code",
    ];

    // Phrases that only count together with a URL.
    const WEAK_SIGNALS: &[&str] = &[
        "enter the code",
        "enter this code",
        "enter code:",
        "use the code",
        "use a web browser to open",
        "open the page",
        "authenticate by visiting",
        "sign in with the code",
        "sign in using a code",
        "verification code",
        "authorize this device",
        "waiting for authentication",
        "waiting for login",
        "waiting for you to authenticate",
        "open your browser",
        "open in your browser",
    ];

    let haystack = output.to_lowercase();

    STRONG_SIGNALS.iter().any(|signal| haystack.contains(signal))
        || (WEAK_SIGNALS.iter().any(|signal| haystack.contains(signal))
            && (haystack.contains("http://") || haystack.contains("https://")))
}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `output` is classified as an auth flow.
    #[track_caller]
    fn expect_auth_flow(output: &str) {
        assert!(contains_auth_flow(output));
    }

    /// Fails the calling test if `output` is classified as an auth flow.
    #[track_caller]
    fn expect_no_auth_flow(output: &str) {
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn normalize_cmd_no_change_on_unix() {
        // Only meaningful off Windows, where commands must pass through as-is.
        if !cfg!(windows) {
            let cmd = "cd /tmp; ls -la";
            assert_eq!(normalize_command_for_shell(cmd), cmd);
        }
    }

    #[test]
    fn validate_allows_safe_commands() {
        for cmd in ["git status", "cargo test", "npm run build", "ls -la"] {
            assert!(validate_command(cmd).is_none());
        }
    }

    #[test]
    fn validate_blocks_file_writes() {
        let rejected = [
            "cat > file.py << 'EOF'\nprint('hi')\nEOF",
            "echo 'data' > output.txt",
            "tee /tmp/file.txt",
            "printf 'hello' > test.txt",
            "cat << EOF\ncontent\nEOF",
        ];
        for cmd in rejected {
            assert!(validate_command(cmd).is_some());
        }
    }

    #[test]
    fn validate_blocks_oversized_commands() {
        let oversized = "x".repeat(MAX_COMMAND_BYTES + 1);
        let rejection = validate_command(&oversized);
        assert!(rejection.is_some());
        assert!(rejection.unwrap().contains("too large"));
    }

    #[test]
    fn validate_allows_cat_without_redirect() {
        assert!(validate_command("cat file.txt").is_none());
    }

    // --- Auth flow detection: strong signals (no URL needed) ---

    #[test]
    fn auth_flow_detects_azure_device_code() {
        expect_auth_flow(
            "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.",
        );
    }

    #[test]
    fn auth_flow_detects_gh_auth_one_time_code() {
        expect_auth_flow(
            "! First copy your one-time code: ABCD-1234\n- Press Enter to open github.com in your browser...",
        );
    }

    #[test]
    fn auth_flow_detects_device_code_json() {
        expect_auth_flow(
            r#"{"device_code":"abc123","user_code":"ABCD-1234","verification_uri":"https://example.com/activate"}"#,
        );
    }

    #[test]
    fn auth_flow_detects_verification_uri_field() {
        expect_auth_flow(
            r#"{"verification_uri": "https://login.microsoftonline.com/common/oauth2/deviceauth"}"#,
        );
    }

    #[test]
    fn auth_flow_detects_user_code_field() {
        expect_auth_flow(r#"{"user_code": "FGHJK-LMNOP", "expires_in": 900}"#);
    }

    // --- Auth flow detection: weak signals (require URL) ---

    #[test]
    fn auth_flow_detects_gcloud_with_url() {
        expect_auth_flow(
            "Go to the following link in your browser:\n\n    https://accounts.google.com/o/oauth2/auth?response_type=code\n\nEnter verification code: ",
        );
    }

    #[test]
    fn auth_flow_detects_aws_sso_with_url() {
        expect_auth_flow(
            "If the browser does not open, open the following URL:\nhttps://device.sso.us-east-1.amazonaws.com/\n\nThen enter the code:\nABCD-EFGH",
        );
    }

    #[test]
    fn auth_flow_detects_firebase_with_url() {
        expect_auth_flow(
            "Visit this URL on this device to log in:\nhttps://accounts.google.com/o/oauth2/auth?...\n\nWaiting for authentication...",
        );
    }

    #[test]
    fn auth_flow_detects_generic_browser_open_with_url() {
        expect_auth_flow(
            "Open your browser to https://login.example.com/device and enter the code XYZW-1234",
        );
    }

    // --- False positive protection ---

    #[test]
    fn auth_flow_ignores_normal_build_output() {
        expect_no_auth_flow("Compiling lean-ctx v2.21.9\nFinished release profile\n");
    }

    #[test]
    fn auth_flow_ignores_git_output() {
        expect_no_auth_flow(
            "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean",
        );
    }

    #[test]
    fn auth_flow_ignores_npm_install_output() {
        expect_no_auth_flow(
            "added 150 packages in 3s\n\n24 packages are looking for funding\n  run `npm fund` for details\nhttps://npmjs.com/package/lean-ctx",
        );
    }

    #[test]
    fn auth_flow_ignores_docs_mentioning_auth() {
        expect_no_auth_flow(
            "The authorization code grant type is the most common OAuth flow.\nSee https://oauth.net/2/grant-types/ for details.",
        );
    }

    #[test]
    fn auth_flow_weak_signal_requires_url() {
        expect_no_auth_flow("Please enter the code ABC123 in the terminal");
    }

    #[test]
    fn auth_flow_weak_signal_without_url_is_ignored() {
        expect_no_auth_flow("Waiting for authentication to complete... done!");
    }

    #[test]
    fn auth_flow_ignores_virtualenv_activate() {
        expect_no_auth_flow("Created virtualenv at .venv\nRun: source .venv/bin/activate");
    }

    #[test]
    fn auth_flow_ignores_api_response_with_code_field() {
        expect_no_auth_flow(r#"{"status": "ok", "code": 200, "message": "success"}"#);
    }

    // --- Integration: handle() preserves auth flow ---

    #[test]
    fn handle_preserves_auth_flow_output_fully() {
        let raw = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.\nWaiting for you...\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12\nLine 13";
        let rendered = handle("az login --use-device-code", raw, CrpMode::Off);

        assert!(rendered.contains("ABCD1234"), "auth code must be preserved");
        assert!(rendered.contains("devicelogin"), "URL must be preserved");
        assert!(
            rendered.contains("auth/device-code flow detected"),
            "detection note must be present"
        );
        assert!(
            rendered.contains("Line 13"),
            "all lines must be preserved (no truncation)"
        );
    }

    #[test]
    fn handle_compresses_normal_output_not_auth() {
        let raw = (1..=20)
            .map(|i| format!("Line {i} of output"))
            .collect::<Vec<String>>()
            .join("\n");
        let rendered = handle("some-tool check", &raw, CrpMode::Off);

        assert!(
            !rendered.contains("auth/device-code flow detected"),
            "normal output must not trigger auth detection"
        );
        assert!(
            rendered.len() < raw.len() + 100,
            "normal output should be compressed, not inflated"
        );
    }
}