// lean_ctx/tools/ctx_shell.rs

1use crate::core::patterns;
2use crate::core::protocol;
3use crate::core::symbol_map::{self, SymbolMap};
4use crate::core::tokens::count_tokens;
5use crate::tools::CrpMode;
6
/// Maximum accepted size of a shell command, in bytes (`str::len`).
const MAX_COMMAND_BYTES: usize = 8192;

/// Substrings that mark a heredoc in a command. `validate_command` lowercases both the
/// command and each pattern before a plain substring search.
///
/// Covers the unquoted, single-quoted, double-quoted and tab-stripping (`<<-`) spellings
/// of the common `EOF` / `END` delimiters, plus any `cat <<` invocation.
const HEREDOC_PATTERNS: &[&str] = &[
    "<< 'EOF'",
    "<<'EOF'",
    "<< \"EOF\"",
    "<<\"EOF\"",
    "<<-'EOF'",
    "<<-\"EOF\"",
    "<<-EOF",
    "<< 'ENDOFFILE'",
    "<<'ENDOFFILE'",
    "<< 'END'",
    "<<'END'",
    "<< \"END\"",
    "<<\"END\"",
    "<< EOF",
    "<<EOF",
    "cat <<",
];
20
21/// Validates a shell command before execution. Returns Some(error_message) if
22/// the command should be rejected, None if it's safe to run.
23pub fn validate_command(command: &str) -> Option<String> {
24    if command.len() > MAX_COMMAND_BYTES {
25        return Some(format!(
26            "ERROR: Command too large ({} bytes, limit {}). \
27             If you're writing file content, use the native Write/Edit tool instead. \
28             ctx_shell is for reading command output only (git, cargo, npm, etc.).",
29            command.len(),
30            MAX_COMMAND_BYTES
31        ));
32    }
33
34    if has_file_write_redirect(command) {
35        return Some(
36            "ERROR: ctx_shell detected a file-write command (shell redirect > or >>). \
37             Use the native Write tool to create/modify files. \
38             ctx_shell is ONLY for reading command output (git status, cargo test, npm run, etc.). \
39             File writes via shell cause MCP protocol corruption on large payloads."
40                .to_string(),
41        );
42    }
43
44    let cmd_lower = command.to_lowercase();
45
46    if cmd_lower.starts_with("tee ") || cmd_lower.contains("| tee ") {
47        return Some(
48            "ERROR: ctx_shell detected a file-write command (tee). \
49             Use the native Write tool to create/modify files. \
50             ctx_shell is ONLY for reading command output."
51                .to_string(),
52        );
53    }
54
55    for pattern in HEREDOC_PATTERNS {
56        if cmd_lower.contains(&pattern.to_lowercase()) {
57            return Some(
58                "ERROR: ctx_shell detected a heredoc file-write command. \
59                 Use the native Write tool to create/modify files. \
60                 ctx_shell is ONLY for reading command output."
61                    .to_string(),
62            );
63        }
64    }
65
66    None
67}
68
/// Detects shell output redirects (`>`, `>>`, `>|`, `&>`, `N>`) that write to a file.
///
/// The scan is lexical, not a full shell parse. The following are NOT treated as file
/// writes:
/// - `>` inside single or double quotes, or escaped with a backslash (`\>`);
/// - stderr redirects `2>` / `2>>`, where `2` is a standalone file-descriptor number
///   (in `file2>out.txt` the `2` belongs to a word, so that redirect is still detected);
/// - file-descriptor duplication such as `2>&1`, `>&2` or `1>&-`;
/// - redirects whose target is `/dev/null`.
///
/// A redirect target ends at whitespace or at a shell operator character
/// (`;`, `&`, `|`, `(`, `)`, `<`, `>`), so `cmd >/dev/null; echo ok` is not flagged.
fn has_file_write_redirect(command: &str) -> bool {
    // True for bytes that end a shell word (whitespace or a control-operator character).
    fn is_word_boundary(b: u8) -> bool {
        b.is_ascii_whitespace() || matches!(b, b';' | b'&' | b'|' | b'(' | b')')
    }

    // Leading shell word of `s`, stopping at whitespace or an operator character.
    fn leading_word(s: &str) -> &str {
        let end = s
            .find(|c: char| {
                c.is_whitespace() || matches!(c, ';' | '&' | '|' | '(' | ')' | '<' | '>')
            })
            .unwrap_or(s.len());
        &s[..end]
    }

    let bytes = command.as_bytes();
    let len = bytes.len();
    let mut i = 0;
    let mut in_single_quote = false;
    let mut in_double_quote = false;

    while i < len {
        let c = bytes[i];

        // Inside single quotes nothing is special except the closing quote.
        if in_single_quote {
            if c == b'\'' {
                in_single_quote = false;
            }
            i += 1;
            continue;
        }

        match c {
            // A backslash makes the next byte literal (this also covers `\"` inside
            // double quotes, which must not end the quoted string).
            b'\\' => {
                i += 2;
                continue;
            }
            b'\'' if !in_double_quote => in_single_quote = true,
            b'"' => in_double_quote = !in_double_quote,
            b'>' if !in_double_quote => {
                // `2>` / `2>>` only redirect stderr when `2` is a standalone fd number.
                let is_stderr =
                    i > 0 && bytes[i - 1] == b'2' && (i == 1 || is_word_boundary(bytes[i - 2]));

                // Consume the second character of `>>` or `>|` so it is not rescanned.
                let mut op_end = i + 1;
                if op_end < len && matches!(bytes[op_end], b'>' | b'|') {
                    op_end += 1;
                }

                if !is_stderr {
                    // `op_end` follows an ASCII byte, so it is a valid char boundary.
                    let rest = command[op_end..].trim_start();
                    let target = match rest.strip_prefix('&') {
                        Some(after) => {
                            let word = leading_word(after.trim_start());
                            // `>&2`, `>&-`: fd duplication. `>&file`: stdout+stderr to a file.
                            if !word.is_empty()
                                && word.bytes().all(|b| b.is_ascii_digit() || b == b'-')
                            {
                                ""
                            } else {
                                word
                            }
                        }
                        None => leading_word(rest),
                    };
                    if !target.is_empty() && target != "/dev/null" {
                        return true;
                    }
                }

                i = op_end;
                continue;
            }
            _ => {}
        }
        i += 1;
    }
    false
}
111
112pub fn handle(command: &str, output: &str, crp_mode: CrpMode) -> String {
113    let original_tokens = count_tokens(output);
114
115    if contains_auth_flow(output) {
116        let savings = protocol::format_savings(original_tokens, original_tokens);
117        return format!(
118            "{output}\n[lean-ctx: auth/device-code flow detected — output preserved uncompressed]\n{savings}"
119        );
120    }
121
122    let compressed = match patterns::compress_output(command, output) {
123        Some(c) => c,
124        None => generic_compress(output),
125    };
126
127    if crp_mode.is_tdd() && looks_like_code(&compressed) {
128        let ext = detect_ext_from_command(command);
129        let mut sym = SymbolMap::new();
130        let idents = symbol_map::extract_identifiers(&compressed, ext);
131        for ident in &idents {
132            sym.register(ident);
133        }
134        if !sym.is_empty() {
135            let mapped = sym.apply(&compressed);
136            let sym_table = sym.format_table();
137            let result = format!("{mapped}{sym_table}");
138            let sent = count_tokens(&result);
139            let savings = protocol::format_savings(original_tokens, sent);
140            return format!("{result}\n{savings}");
141        }
142    }
143
144    let sent = count_tokens(&compressed);
145    let savings = protocol::format_savings(original_tokens, sent);
146
147    format!("{compressed}\n{savings}")
148}
149
150fn generic_compress(output: &str) -> String {
151    let output = crate::core::compressor::strip_ansi(output);
152    let lines: Vec<&str> = output
153        .lines()
154        .filter(|l| {
155            let t = l.trim();
156            !t.is_empty()
157        })
158        .collect();
159
160    if lines.len() <= 10 {
161        return lines.join("\n");
162    }
163
164    let first_3 = &lines[..3];
165    let last_3 = &lines[lines.len() - 3..];
166    let omitted = lines.len() - 6;
167    format!(
168        "{}\n[truncated: showing 6/{} lines, {} omitted. Use raw=true for full output.]\n{}",
169        first_3.join("\n"),
170        lines.len(),
171        omitted,
172        last_3.join("\n")
173    )
174}
175
/// Heuristically decides whether `text` is source code.
///
/// Text with fewer than three lines is never treated as code. Otherwise a line is
/// "code-like" when it contains one of `MARKERS`, and the text qualifies when strictly
/// more than 15% of its lines are code-like.
fn looks_like_code(text: &str) -> bool {
    const MARKERS: &[&str] = &[
        "fn ", "pub ", "let ", "const ", "impl ", "struct ", "enum ", "function ", "class ",
        "import ", "export ", "def ", "async ", "=>", "->", "::", "self.", "this.",
    ];

    let (line_count, code_line_count) =
        text.lines().fold((0usize, 0usize), |(all, code), line| {
            let is_code = MARKERS.iter().any(|marker| line.contains(marker));
            (all + 1, code + usize::from(is_code))
        });

    line_count >= 3 && code_line_count as f64 / line_count as f64 > 0.15
}
207
/// Guesses a source-file extension for symbol extraction from the command text.
///
/// The lowercased command is tested against the keyword groups in `RULES`, in order. The
/// first group with any substring hit decides the extension; `"rs"` is the fallback.
fn detect_ext_from_command(command: &str) -> &str {
    const RULES: [(&[&str], &str); 4] = [
        (&["cargo", ".rs"], "rs"),
        (&["npm", "node", ".ts", ".js"], "ts"),
        (&["python", "pip", ".py"], "py"),
        (&["go ", ".go"], "go"),
    ];

    let lowered = command.to_lowercase();
    for (needles, ext) in RULES {
        if needles.iter().any(|needle| lowered.contains(needle)) {
            return ext;
        }
    }
    "rs"
}
226
/// Reports whether `output` looks like an OAuth device-code or browser-login prompt.
///
/// `handle` passes such output through uncompressed so the user can see the code and URL.
/// Matching is case-insensitive and works in two tiers:
/// - a *strong* signal (e.g. `device_code`, `user_code`, `one-time code`) is enough on
///   its own;
/// - a *weak* signal (e.g. "enter the code", "open your browser") only counts when the
///   output also contains an `http://` or `https://` URL. A bare domain without a scheme
///   does not satisfy this requirement.
pub fn contains_auth_flow(output: &str) -> bool {
    const STRONG_SIGNALS: &[&str] = &[
        "devicelogin",
        "deviceauth",
        "device_code",
        "device code",
        "device-code",
        "verification_uri",
        "user_code",
        "one-time code",
    ];

    const WEAK_SIGNALS: &[&str] = &[
        "enter the code",
        "enter this code",
        "enter code:",
        "use the code",
        "use a web browser to open",
        "open the page",
        "authenticate by visiting",
        "sign in with the code",
        "sign in using a code",
        "verification code",
        "authorize this device",
        "waiting for authentication",
        "waiting for login",
        "waiting for you to authenticate",
        "open your browser",
        "open in your browser",
    ];

    let haystack = output.to_lowercase();
    let mentions_any = |phrases: &[&str]| phrases.iter().any(|phrase| haystack.contains(phrase));

    if mentions_any(STRONG_SIGNALS) {
        return true;
    }

    mentions_any(WEAK_SIGNALS) && (haystack.contains("http://") || haystack.contains("https://"))
}
274
#[cfg(test)]
mod tests {
    // Unit tests for command validation, auth-flow detection and `handle`.

    use super::*;

    // --- Command validation ---

    #[test]
    fn validate_allows_safe_commands() {
        assert!(validate_command("git status").is_none());
        assert!(validate_command("cargo test").is_none());
        assert!(validate_command("npm run build").is_none());
        assert!(validate_command("ls -la").is_none());
    }

    #[test]
    fn validate_blocks_file_writes() {
        assert!(validate_command("cat > file.py << 'EOF'\nprint('hi')\nEOF").is_some());
        assert!(validate_command("echo 'data' > output.txt").is_some());
        assert!(validate_command("tee /tmp/file.txt").is_some());
        assert!(validate_command("printf 'hello' > test.txt").is_some());
        assert!(validate_command("cat << EOF\ncontent\nEOF").is_some());
    }

    #[test]
    fn validate_blocks_append_redirect_and_piped_tee() {
        assert!(validate_command("echo line >> notes.txt").is_some());
        assert!(validate_command("cargo test | tee test.log").is_some());
    }

    // Redirects that do not create files (stderr merge, /dev/null) must stay allowed.
    #[test]
    fn validate_allows_stderr_and_dev_null_redirects() {
        assert!(validate_command("cargo build 2>&1").is_none());
        assert!(validate_command("cargo test > /dev/null 2>&1").is_none());
        assert!(validate_command("npm ci >> /dev/null").is_none());
    }

    // A `>` inside quotes is an argument, not a redirect.
    #[test]
    fn validate_ignores_redirect_chars_inside_quotes() {
        assert!(validate_command("echo '>' done").is_none());
        assert!(validate_command("git log --format=\"%h > %s\"").is_none());
    }

    #[test]
    fn validate_blocks_oversized_commands() {
        let huge = "x".repeat(MAX_COMMAND_BYTES + 1);
        let result = validate_command(&huge);
        assert!(result.is_some());
        assert!(result.unwrap().contains("too large"));
    }

    #[test]
    fn validate_allows_cat_without_redirect() {
        assert!(validate_command("cat file.txt").is_none());
    }

    // --- Auth flow detection: strong signals (no URL needed) ---

    #[test]
    fn auth_flow_detects_azure_device_code() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_gh_auth_one_time_code() {
        let output =
            "! First copy your one-time code: ABCD-1234\n- Press Enter to open github.com in your browser...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_device_code_json() {
        let output = r#"{"device_code":"abc123","user_code":"ABCD-1234","verification_uri":"https://example.com/activate"}"#;
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_verification_uri_field() {
        let output =
            r#"{"verification_uri": "https://login.microsoftonline.com/common/oauth2/deviceauth"}"#;
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_user_code_field() {
        let output = r#"{"user_code": "FGHJK-LMNOP", "expires_in": 900}"#;
        assert!(contains_auth_flow(output));
    }

    // --- Auth flow detection: weak signals (require URL) ---

    #[test]
    fn auth_flow_detects_gcloud_with_url() {
        let output = "Go to the following link in your browser:\n\n    https://accounts.google.com/o/oauth2/auth?response_type=code\n\nEnter verification code: ";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_aws_sso_with_url() {
        let output = "If the browser does not open, open the following URL:\nhttps://device.sso.us-east-1.amazonaws.com/\n\nThen enter the code:\nABCD-EFGH";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_firebase_with_url() {
        let output = "Visit this URL on this device to log in:\nhttps://accounts.google.com/o/oauth2/auth?...\n\nWaiting for authentication...";
        assert!(contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_detects_generic_browser_open_with_url() {
        let output =
            "Open your browser to https://login.example.com/device and enter the code XYZW-1234";
        assert!(contains_auth_flow(output));
    }

    // --- False positive protection ---

    #[test]
    fn auth_flow_ignores_normal_build_output() {
        let output = "Compiling lean-ctx v2.21.9\nFinished release profile\n";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_git_output() {
        let output = "On branch main\nYour branch is up to date with 'origin/main'.\nnothing to commit, working tree clean";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_npm_install_output() {
        let output = "added 150 packages in 3s\n\n24 packages are looking for funding\n  run `npm fund` for details\nhttps://npmjs.com/package/lean-ctx";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_docs_mentioning_auth() {
        let output = "The authorization code grant type is the most common OAuth flow.\nSee https://oauth.net/2/grant-types/ for details.";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_weak_signal_requires_url() {
        let output = "Please enter the code ABC123 in the terminal";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_weak_signal_without_url_is_ignored() {
        let output = "Waiting for authentication to complete... done!";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_virtualenv_activate() {
        let output = "Created virtualenv at .venv\nRun: source .venv/bin/activate";
        assert!(!contains_auth_flow(output));
    }

    #[test]
    fn auth_flow_ignores_api_response_with_code_field() {
        let output = r#"{"status": "ok", "code": 200, "message": "success"}"#;
        assert!(!contains_auth_flow(output));
    }

    // --- Integration: handle() preserves auth flow ---

    #[test]
    fn handle_preserves_auth_flow_output_fully() {
        let output = "To sign in, use a web browser to open the page https://microsoft.com/devicelogin and enter the code ABCD1234 to authenticate.\nWaiting for you...\nLine 3\nLine 4\nLine 5\nLine 6\nLine 7\nLine 8\nLine 9\nLine 10\nLine 11\nLine 12\nLine 13";
        let result = handle("az login --use-device-code", output, CrpMode::Off);
        assert!(result.contains("ABCD1234"), "auth code must be preserved");
        assert!(result.contains("devicelogin"), "URL must be preserved");
        assert!(
            result.contains("auth/device-code flow detected"),
            "detection note must be present"
        );
        assert!(
            result.contains("Line 13"),
            "all lines must be preserved (no truncation)"
        );
    }

    #[test]
    fn handle_compresses_normal_output_not_auth() {
        let lines: Vec<String> = (1..=20).map(|i| format!("Line {i} of output")).collect();
        let output = lines.join("\n");
        let result = handle("some-tool check", &output, CrpMode::Off);
        assert!(
            !result.contains("auth/device-code flow detected"),
            "normal output must not trigger auth detection"
        );
        assert!(
            result.len() < output.len() + 100,
            "normal output should be compressed, not inflated"
        );
    }
}