Skip to main content

lean_ctx/tools/
ctx_shell.rs

1use crate::tools::CrpMode;
2
/// Upper bound, in bytes, on the length of a command accepted by `validate_command`.
/// Longer commands are rejected with a hint to use the native Write/Edit tool instead.
const MAX_COMMAND_BYTES: usize = 8 * 1024;
4
5/// Validates a shell command before execution. Returns Some(error_message) if
6/// the command should be rejected, None if it's safe to run.
7pub fn validate_command(command: &str) -> Option<String> {
8    if command.len() > MAX_COMMAND_BYTES {
9        return Some(format!(
10            "ERROR: Command too large ({} bytes, limit {}). \
11             If you're writing file content, use the native Write/Edit tool instead. \
12             ctx_shell is for reading command output only (git, cargo, npm, etc.).",
13            command.len(),
14            MAX_COMMAND_BYTES
15        ));
16    }
17
18    if has_file_write_redirect(command) {
19        return Some(
20            "ERROR: ctx_shell detected a file-write command (shell redirect > or >>). \
21             Use the native Write tool to create/modify files. \
22             ctx_shell is ONLY for reading command output (git status, cargo test, npm run, etc.). \
23             File writes via shell cause MCP protocol corruption on large payloads."
24                .to_string(),
25        );
26    }
27
28    let cmd_lower = command.to_lowercase();
29
30    if cmd_lower.starts_with("tee ") || cmd_lower.contains("| tee ") {
31        return Some(
32            "ERROR: ctx_shell detected a file-write command (tee). \
33             Use the native Write tool to create/modify files. \
34             ctx_shell is ONLY for reading command output."
35                .to_string(),
36        );
37    }
38
39    if is_heredoc_file_write(command) {
40        return Some(
41            "ERROR: ctx_shell detected a heredoc writing to a file. \
42             Use the native Write tool to create/modify files. \
43             ctx_shell is ONLY for reading command output. \
44             Note: heredocs for input piping (e.g. psql <<EOF) are allowed."
45                .to_string(),
46        );
47    }
48
49    None
50}
51
52/// Returns true only for heredocs that redirect to files (the dangerous pattern).
53/// Legitimate heredoc uses (input piping, inline scripts) are allowed through.
54fn is_heredoc_file_write(command: &str) -> bool {
55    let has_heredoc = command.contains("<<");
56    if !has_heredoc {
57        return false;
58    }
59    // Only block: heredoc combined with file output redirect
60    // e.g. `cat <<EOF > file.txt` or `cat <<'EOF' >> output.log`
61    let cmd_lower = command.to_lowercase();
62    let heredoc_patterns = ["<<eof", "<<'eof'", "<<\"eof\"", "<<end", "<<'end'"];
63    let has_known_heredoc = heredoc_patterns.iter().any(|p| cmd_lower.contains(p));
64    if !has_known_heredoc {
65        return false;
66    }
67    has_file_write_redirect(command)
68}
69
/// Detects shell redirect operators (`>`, `>>`, `>|`) that write command output to a file.
///
/// This is a lightweight approximation of shell parsing, not a full parser. It does NOT
/// report:
/// - `>` inside single or double quotes, or escaped as `\>`;
/// - stderr redirects whose fd is a standalone `2` (`2>`, `2>>`, `2>&1`);
/// - file-descriptor duplication or closing (`>&2`, `1>&2`, `>&-`);
/// - redirects to `/dev/null`, `/dev/stdout` or `/dev/stderr`;
/// - process substitution `>(...)`;
/// - text inside heredoc bodies (`<<EOF` ... `EOF`), which is input data rather than shell
///   syntax (e.g. `psql <<EOF` with `WHERE a > 1` in the body).
///
/// Unquoted comparison-style uses such as `[ $a > $b ]` ARE reported, because the shell
/// itself parses them as redirects.
fn has_file_write_redirect(command: &str) -> bool {
    // Redirect targets that do not create or modify a regular file.
    const NULL_SINKS: [&str; 3] = ["/dev/null", "/dev/stdout", "/dev/stderr"];

    // True for characters that terminate a redirect target word.
    fn ends_word(ch: char) -> bool {
        ch.is_whitespace() || matches!(ch, ';' | '|' | '&' | '(' | ')' | '<' | '>')
    }

    // The shell word starting at byte `from` (leading whitespace skipped).
    // `from` is always just past an ASCII byte, so it lies on a char boundary.
    fn word_at(command: &str, from: usize) -> &str {
        command[from..]
            .trim_start()
            .split(ends_word)
            .next()
            .unwrap_or("")
    }

    // Parses the delimiter of the `<<` operator whose first `<` is at `at`.
    // Returns the delimiter text and the index just past it.
    fn heredoc_delimiter(command: &str, at: usize) -> Option<(String, usize)> {
        let bytes = command.as_bytes();
        let mut j = at + 2;
        if bytes.get(j) == Some(&b'-') {
            j += 1;
        }
        while matches!(bytes.get(j).copied(), Some(b' ' | b'\t')) {
            j += 1;
        }
        match bytes.get(j).copied() {
            Some(quote @ (b'\'' | b'"')) => {
                let body_start = j + 1;
                let close = body_start + command[body_start..].find(char::from(quote))?;
                Some((command[body_start..close].to_string(), close + 1))
            }
            _ => {
                let word = command[j..].split(ends_word).next().unwrap_or("");
                if word.is_empty() {
                    None
                } else {
                    Some((word.to_string(), j + word.len()))
                }
            }
        }
    }

    // Skips one heredoc body per delimiter, in order, starting at `from` (just past a
    // newline). Returns the index after the last terminator line, or `None` if any
    // terminator is missing. In that case the caller keeps scanning normally, so a stray
    // `<<` (e.g. an arithmetic shift) can never hide a later redirect.
    fn skip_heredoc_bodies(command: &str, mut from: usize, delimiters: &[String]) -> Option<usize> {
        for delimiter in delimiters {
            loop {
                if from >= command.len() {
                    return None;
                }
                let line_end = command[from..]
                    .find('\n')
                    .map_or(command.len(), |k| from + k);
                let line = &command[from..line_end];
                from = (line_end + 1).min(command.len());
                if line.trim() == delimiter.as_str() {
                    break;
                }
            }
        }
        Some(from)
    }

    let bytes = command.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    // Delimiters of heredocs opened on the current line; their bodies start after the next
    // unquoted newline.
    let mut pending_heredocs: Vec<String> = Vec::new();

    while i < len {
        let c = bytes[i];
        if in_single_quote {
            // Single quotes are fully literal: only the closing quote matters.
            in_single_quote = c != b'\'';
            i += 1;
            continue;
        }
        if c == b'\\' {
            // A backslash escapes the next character, both unquoted and inside double quotes,
            // so `\>` is literal and `\"` does not end a double-quoted string.
            i += 2;
            continue;
        }
        if in_double_quote {
            in_double_quote = c != b'"';
            i += 1;
            continue;
        }
        match c {
            b'\'' => in_single_quote = true,
            b'"' => in_double_quote = true,
            b'\n' if !pending_heredocs.is_empty() => {
                let delimiters = std::mem::take(&mut pending_heredocs);
                if let Some(next) = skip_heredoc_bodies(command, i + 1, &delimiters) {
                    i = next;
                    continue;
                }
            }
            b'<' if bytes.get(i + 1) == Some(&b'<') => {
                if bytes.get(i + 2) == Some(&b'<') {
                    // `<<<` here-string: inline data, no body follows.
                    i += 3;
                    continue;
                }
                if let Some((delimiter, end)) = heredoc_delimiter(command, i) {
                    pending_heredocs.push(delimiter);
                    i = end;
                    continue;
                }
                i += 2;
                continue;
            }
            b'>' => {
                // `>>` (append) and `>|` (override noclobber) are single operators.
                let op_end = match bytes.get(i + 1).copied() {
                    Some(b'>' | b'|') => i + 2,
                    _ => i + 1,
                };
                // A standalone `2` right before the operator redirects stderr, which is
                // allowed (`2>&1`, `2>/dev/null`, `2>>err.log`); in `v2>out` the `2` belongs
                // to the word `v2` and stdout is redirected.
                let is_stderr = i >= 1
                    && bytes[i - 1] == b'2'
                    && (i == 1
                        || matches!(
                            bytes[i - 2],
                            b' ' | b'\t' | b'\n' | b';' | b'|' | b'&' | b'('
                        ));
                if is_stderr {
                    i = op_end;
                    continue;
                }
                let writes_file = if bytes.get(op_end) == Some(&b'&') {
                    // `>&2` / `>&-` duplicate or close a descriptor; `>&word` writes to `word`.
                    let fd = word_at(command, op_end + 1);
                    fd != "-" && !fd.bytes().all(|b| b.is_ascii_digit())
                } else {
                    let target = word_at(command, op_end);
                    let unquoted = target.trim_matches(|ch: char| ch == '"' || ch == '\'');
                    !target.is_empty() && !NULL_SINKS.contains(&unquoted)
                };
                if writes_file {
                    return true;
                }
                i = op_end;
                continue;
            }
            _ => {}
        }
        i += 1;
    }
    false
}
112
113/// On Windows cmd.exe, `;` is not a valid command separator.
114/// Convert `cmd1; cmd2` to `cmd1 && cmd2` when running under cmd.exe.
115pub fn normalize_command_for_shell(command: &str) -> String {
116    if !cfg!(windows) {
117        return command.to_string();
118    }
119    let (_, flag) = crate::shell::shell_and_flag();
120    if flag != "/C" {
121        return command.to_string();
122    }
123    let bytes = command.as_bytes();
124    let mut result = Vec::with_capacity(bytes.len() + 16);
125    let mut in_single = false;
126    let mut in_double = false;
127    for (i, &b) in bytes.iter().enumerate() {
128        if b == b'\'' && !in_double {
129            in_single = !in_single;
130        } else if b == b'"' && !in_single {
131            in_double = !in_double;
132        } else if b == b';' && !in_single && !in_double {
133            result.extend_from_slice(b" && ");
134            continue;
135        }
136        result.push(b);
137        let _ = i;
138    }
139    String::from_utf8(result).unwrap_or_else(|_| command.to_string())
140}
141
142/// Compresses shell command output using the unified compression pipeline.
143/// This delegates to the same `compress_if_beneficial` logic used by the CLI,
144/// ensuring consistent behavior for excluded_commands, structural routing, and terse.
145pub fn handle(command: &str, output: &str, _crp_mode: CrpMode) -> String {
146    crate::shell::compress::engine::compress_if_beneficial(command, output)
147}
148
#[cfg(test)]
fn is_search_command(command: &str) -> bool {
    // Each prefix carries a trailing space so e.g. `fdisk` is not mistaken for `fd`.
    const SEARCH_PREFIXES: [&str; 6] = ["grep ", "rg ", "find ", "fd ", "ag ", "ack "];
    let trimmed = command.trim_start();
    SEARCH_PREFIXES
        .iter()
        .any(|&prefix| trimmed.starts_with(prefix))
}
159
#[cfg(test)]
fn generic_compress(output: &str) -> String {
    let cleaned = crate::core::compressor::strip_ansi(output);
    let kept: Vec<&str> = cleaned
        .lines()
        .filter(|line| !line.trim().is_empty())
        .collect();
    let total = kept.len();

    // Short outputs (blank lines removed) are returned whole.
    if total <= 20 {
        return kept.join("\n");
    }

    // Show roughly a third of the lines (capped at 30), split evenly between head and tail.
    let half = (total / 3).min(30) / 2;
    let head = kept[..half].join("\n");
    let tail = kept[total - half..].join("\n");
    format!(
        "{}\n[truncated: showing {}/{} lines, {} omitted. Use raw=true for full output.]\n{}",
        head,
        half * 2,
        total,
        total - half * 2,
        tail
    )
}
189
/// Detects OAuth device-code flow output that must not be compressed.
///
/// Matching is case-insensitive and uses two tiers:
/// - strong signals (phrases specific to device-code flows such as `device_code`,
///   `user_code` or `one-time code`) match on their own;
/// - weak signals (generic prompts such as "enter the code" or "open your browser") only
///   count when the output also contains an `http://` or `https://` URL.
pub fn contains_auth_flow(output: &str) -> bool {
    const STRONG_SIGNALS: &[&str] = &[
        "devicelogin",
        "deviceauth",
        "device_code",
        "device code",
        "device-code",
        "verification_uri",
        "user_code",
        "one-time code",
    ];

    const WEAK_SIGNALS: &[&str] = &[
        "enter the code",
        "enter this code",
        "enter code:",
        "use the code",
        "use a web browser to open",
        "open the page",
        "authenticate by visiting",
        "sign in with the code",
        "sign in using a code",
        "verification code",
        "authorize this device",
        "waiting for authentication",
        "waiting for login",
        "waiting for you to authenticate",
        "open your browser",
        "open in your browser",
    ];

    fn mentions_any(text: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| text.contains(needle))
    }

    let text = output.to_lowercase();
    if mentions_any(&text, STRONG_SIGNALS) {
        return true;
    }
    mentions_any(&text, WEAK_SIGNALS) && mentions_any(&text, &["http://", "https://"])
}
237
#[cfg(test)]
mod tests {
    // Unit tests for command validation, Windows command normalization, auth-flow
    // detection, search-command detection, generic compression and `handle`.
    use super::*;

    // --- normalize_command_for_shell ---

    #[test]
    fn normalize_cmd_no_change_on_unix() {
        // The `;` rewrite only applies on Windows; elsewhere the command passes through as-is.
        if cfg!(windows) {
            return;
        }
        assert_eq!(
            normalize_command_for_shell("cd /tmp; ls -la"),
            "cd /tmp; ls -la"
        );
    }

    // --- validate_command ---

    #[test]
    fn validate_allows_safe_commands() {
        assert!(validate_command("git status").is_none());
        assert!(validate_command("cargo test").is_none());
        assert!(validate_command("npm run build").is_none());
        assert!(validate_command("ls -la").is_none());
    }

    #[test]
    fn validate_blocks_file_writes() {
        assert!(validate_command("echo 'data' > output.txt").is_some());
        assert!(validate_command("tee /tmp/file.txt").is_some());
        assert!(validate_command("printf 'hello' > test.txt").is_some());
    }

    #[test]
    fn validate_blocks_heredoc_with_file_redirect() {
        // Redirect before and after the heredoc operator, quoted and unquoted delimiters.
        assert!(validate_command("cat > file.py <<'EOF'\nprint('hi')\nEOF").is_some());
        assert!(validate_command("cat <<EOF > output.txt\nhello\nEOF").is_some());
        assert!(validate_command("cat <<'END' >> logfile.txt\ndata\nEND").is_some());
    }

    #[test]
    fn validate_allows_heredoc_without_file_redirect() {
        // Heredocs used as stdin input (no `>`/`>>` to a file) must stay allowed.
        assert!(validate_command("cat <<EOF\nhello world\nEOF").is_none());
        assert!(validate_command("psql -d mydb <<EOF\nSELECT 1;\nEOF").is_none());
        assert!(
            validate_command("git commit -m \"$(cat <<'EOF'\nfix: something\nEOF\n)\"").is_none()
        );
        assert!(validate_command("grep pattern <<EOF\nfoo\nbar\nEOF").is_none());
    }

    #[test]
    fn validate_blocks_oversized_commands() {
        let huge = "x".repeat(MAX_COMMAND_BYTES + 1);
        let result = validate_command(&huge);
        assert!(result.is_some());
        assert!(result.unwrap().contains("too large"));
    }

    #[test]
    fn validate_allows_cat_without_redirect() {
        assert!(validate_command("cat file.txt").is_none());
    }

    // --- Auth flow detection: strong signals (no URL needed) ---

    #[test]
    fn auth_flow_detects_azure_device_code() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_gh_auth_one_time_code() {
        let output =
            "! First copy your one-time code: ABCD-1234\n- Press Enter to open github.com in your browser...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_device_code_json() {
        let output = r#"{"device_code":"abc123","user_code":"ABCD-1234","verification_uri":"https://example.com/activate"}"#;
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_verification_uri_field() {
        let output =
            r#"{"verification_uri": "https://login.microsoftonline.com/common/oauth2/deviceauth"}"#;
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_user_code_field() {
        let output = r#"{"user_code": "FGHJK-LMNOP", "expires_in": 900}"#;
        assert!(contains_auth_flow(output));
    }

    // --- Auth flow detection: weak signals (require URL) ---

    #[test]
    fn auth_flow_detects_gcloud_with_url() {
        let output = "Go to the following link in your browser:\n\n    https://accounts.google.com/o/oauth2/auth?response_type=code\n\nEnter verification code: ";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_aws_sso_with_url() {
        let output = "If the browser does not open, open the following URL:\nhttps://device.sso.us-east-1.amazonaws.com/\n\nThen enter the code:\nABCD-EFGH";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_firebase_with_url() {
        let output = "Visit this URL on this device to log in:\nhttps://accounts.google.com/o/oauth2/auth?...\n\nWaiting for authentication...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_generic_browser_open_with_url() {
        let output =
            "Open your browser to https://login.example.com/device and enter the code XYZW-1234";
        assert!(contains_auth_flow(output));
    }

    // --- False positive protection ---

    #[test]
    fn auth_flow_ignores_normal_build_output() {
        let output = "Compiling lean-ctx v2.21.9\nFinished release profile\n";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_git_output() {
        let output = "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_npm_install_output() {
        // Contains a URL but no weak or strong signal.
        let output = "added 150 packages in 3s\n\n24 packages are looking for funding\n  run `npm fund` for details\nhttps://npmjs.com/package/lean-ctx";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_docs_mentioning_auth() {
        let output = "The authorization code grant type is the most common OAuth flow.\nSee https://oauth.net/2/grant-types/ for details.";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_weak_signal_requires_url() {
        let output = "Please enter the code ABC123 in the terminal";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_weak_signal_without_url_is_ignored() {
        let output = "Waiting for authentication to complete... done!";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_virtualenv_activate() {
        let output = "Created virtualenv at .venv\nRun: source .venv/bin/activate";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_api_response_with_code_field() {
        let output = r#"{"status": "ok", "code": 200, "message": "success"}"#;
        assert!(!contains_auth_flow(output));
    }

    // --- Integration: handle() preserves auth flow ---

    #[test]
    fn handle_preserves_auth_flow_output_fully() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.\nWaiting for you...\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12\nLine 13";
        // az login is Passthrough via OutputPolicy, so all content is preserved
        let result = handle("az login --use-device-code", output, CrpMode::Off);
        assert!(result.contains("ABCD1234"), "auth code must be preserved");
        assert!(result.contains("devicelogin"), "URL must be preserved");
        assert!(
            result.contains("Line 13"),
            "all lines must be preserved (no truncation)"
        );
    }

    #[test]
    fn handle_compresses_normal_output_not_auth() {
        let lines: Vec<String> = (1..=20).map(|i| format!("Line {i} of output")).collect();
        let output = lines.join("\n");
        let result = handle("some-tool check", &output, CrpMode::Off);
        assert!(
            !result.contains("auth/device-code flow detected"),
            "normal output must not trigger auth detection"
        );
        // Allows a small overhead but rejects significant growth of the output.
        assert!(
            result.len() < output.len() + 100,
            "normal output should be compressed, not inflated"
        );
    }

    // --- is_search_command ---

    #[test]
    fn is_search_command_detects_grep() {
        assert!(is_search_command("grep -r pattern src/"));
        assert!(is_search_command("rg pattern src/"));
        assert!(is_search_command("find . -name '*.rs'"));
        assert!(is_search_command("fd pattern"));
        assert!(is_search_command("ag pattern src/"));
        assert!(is_search_command("ack pattern"));
    }

    #[test]
    fn is_search_command_rejects_non_search() {
        assert!(!is_search_command("cargo build"));
        assert!(!is_search_command("git status"));
        assert!(!is_search_command("npm install"));
        assert!(!is_search_command("cat file.rs"));
    }

    // --- generic_compress ---

    #[test]
    fn generic_compress_preserves_short_output() {
        // Exactly 20 non-blank lines is the largest input returned unchanged.
        let lines: Vec<String> = (1..=20).map(|i| format!("Line {i}")).collect();
        let output = lines.join("\n");
        let result = generic_compress(&output);
        assert_eq!(result, output);
    }

    #[test]
    fn generic_compress_scales_with_length() {
        let lines: Vec<String> = (1..=60).map(|i| format!("Line {i}")).collect();
        let output = lines.join("\n");
        let result = generic_compress(&output);
        assert!(result.contains("truncated"));
        let shown_count = result.lines().count();
        assert!(
            shown_count > 10,
            "should show more than old 6-line limit, got {shown_count}"
        );
        assert!(shown_count < 60, "should be truncated, not full output");
    }

    // --- handle() keeps every search hit ---

    #[test]
    fn handle_preserves_search_results() {
        let lines: Vec<String> = (1..=30)
            .map(|i| format!("src/file{i}.rs:42: fn search_result()"))
            .collect();
        let output = lines.join("\n");
        let result = handle("rg search_result src/", &output, CrpMode::Off);
        for i in 1..=30 {
            assert!(
                result.contains(&format!("file{i}")),
                "search result file{i} should be preserved in output"
            );
        }
    }
}