Skip to main content

lean_ctx/
shell.rs

1use std::io::{self, BufRead, IsTerminal, Write};
2use std::process::{Command, Stdio};
3
4use crate::core::config;
5use crate::core::patterns;
6use crate::core::slow_log;
7use crate::core::stats;
8use crate::core::tokens::count_tokens;
9
10pub fn exec(command: &str) -> i32 {
11    let (shell, shell_flag) = shell_and_flag();
12
13    if std::env::var("LEAN_CTX_DISABLED").is_ok() {
14        return exec_inherit(command, &shell, &shell_flag);
15    }
16
17    let cfg = config::Config::load();
18    let force_compress = std::env::var("LEAN_CTX_COMPRESS").is_ok();
19    let raw_mode = std::env::var("LEAN_CTX_RAW").is_ok();
20
21    if raw_mode || (!force_compress && is_excluded_command(command, &cfg.excluded_commands)) {
22        return exec_inherit(command, &shell, &shell_flag);
23    }
24
25    if !force_compress && io::stdout().is_terminal() {
26        return exec_inherit_tracked(command, &shell, &shell_flag);
27    }
28
29    exec_buffered(command, &shell, &shell_flag, &cfg)
30}
31
/// Runs `command` via `shell shell_flag command` with stdin, stdout and stderr
/// inherited from this process, so the child talks to the terminal directly.
///
/// Returns the child's exit code, `1` when the child reports no exit code,
/// and `127` when the shell itself could not be started.
fn exec_inherit(command: &str, shell: &str, shell_flag: &str) -> i32 {
    let mut child = Command::new(shell);
    child
        .args([shell_flag, command])
        .env("LEAN_CTX_ACTIVE", "1")
        .stdin(Stdio::inherit())
        .stdout(Stdio::inherit())
        .stderr(Stdio::inherit());

    child.status().map_or_else(
        |e| {
            eprintln!("lean-ctx: failed to execute: {e}");
            127
        },
        |status| status.code().unwrap_or(1),
    )
}
50
51fn exec_inherit_tracked(command: &str, shell: &str, shell_flag: &str) -> i32 {
52    let code = exec_inherit(command, shell, shell_flag);
53    stats::record(command, 0, 0);
54    code
55}
56
/// Joins captured stdout and stderr into one string.
///
/// When both streams have content they are separated by a single newline
/// (stdout first); when one is empty the other is returned unchanged.
fn combine_output(stdout: &str, stderr: &str) -> String {
    match (stdout.is_empty(), stderr.is_empty()) {
        (_, true) => stdout.to_string(),
        (true, false) => stderr.to_string(),
        (false, false) => format!("{stdout}\n{stderr}"),
    }
}
66
67fn exec_buffered(command: &str, shell: &str, shell_flag: &str, cfg: &config::Config) -> i32 {
68    let start = std::time::Instant::now();
69
70    let child = Command::new(shell)
71        .arg(shell_flag)
72        .arg(command)
73        .env("LEAN_CTX_ACTIVE", "1")
74        .stdout(Stdio::piped())
75        .stderr(Stdio::piped())
76        .spawn();
77
78    let child = match child {
79        Ok(c) => c,
80        Err(e) => {
81            eprintln!("lean-ctx: failed to execute: {e}");
82            return 127;
83        }
84    };
85
86    let output = match child.wait_with_output() {
87        Ok(o) => o,
88        Err(e) => {
89            eprintln!("lean-ctx: failed to wait: {e}");
90            return 127;
91        }
92    };
93
94    let duration_ms = start.elapsed().as_millis();
95    let exit_code = output.status.code().unwrap_or(1);
96    let stdout = String::from_utf8_lossy(&output.stdout);
97    let stderr = String::from_utf8_lossy(&output.stderr);
98
99    let full_output = combine_output(&stdout, &stderr);
100    let input_tokens = count_tokens(&full_output);
101
102    let (compressed, output_tokens) = compress_and_measure(command, &stdout, &stderr);
103
104    stats::record(command, input_tokens, output_tokens);
105
106    if !compressed.is_empty() {
107        let _ = io::stdout().write_all(compressed.as_bytes());
108        if !compressed.ends_with('\n') {
109            let _ = io::stdout().write_all(b"\n");
110        }
111    }
112    let should_tee = match cfg.tee_mode {
113        config::TeeMode::Always => !full_output.trim().is_empty(),
114        config::TeeMode::Failures => exit_code != 0 && !full_output.trim().is_empty(),
115        config::TeeMode::Never => false,
116    };
117    if should_tee {
118        if let Some(path) = save_tee(command, &full_output) {
119            eprintln!("[lean-ctx: full output -> {path} (redacted, 24h TTL)]");
120        }
121    }
122
123    let threshold = cfg.slow_command_threshold_ms;
124    if threshold > 0 && duration_ms >= threshold as u128 {
125        slow_log::record(command, duration_ms, exit_code);
126    }
127
128    exit_code
129}
130
/// Commands that must never be buffered: long-running servers, streaming
/// output and interactive programs.
///
/// Matching rules (see [`is_excluded_command`]):
/// - entries are compared against the lowercased command line;
/// - an entry with a trailing space (e.g. `"ssh "`) only matches when an
///   argument follows;
/// - an entry whose last word is a flag (e.g. `"tail -f"`) also matches
///   bundled short flags such as `tail -fn 50`.
const BUILTIN_PASSTHROUGH: &[&str] = &[
    // JS/TS dev servers & watchers
    "turbo",
    "nx serve",
    "nx dev",
    "next dev",
    "vite dev",
    "vite preview",
    "vitest",
    "nuxt dev",
    "astro dev",
    "webpack serve",
    "webpack-dev-server",
    "nodemon",
    "concurrently",
    "pm2",
    "pm2 logs",
    "gatsby develop",
    "expo start",
    "react-scripts start",
    "ng serve",
    "remix dev",
    "wrangler dev",
    "hugo server",
    "hugo serve",
    "jekyll serve",
    "bun dev",
    "ember serve",
    // Docker
    "docker compose up",
    "docker-compose up",
    "docker compose logs",
    "docker-compose logs",
    "docker compose exec",
    "docker-compose exec",
    "docker compose run",
    "docker-compose run",
    "docker logs",
    "docker attach",
    "docker exec -it",
    "docker exec -ti",
    "docker run -it",
    "docker run -ti",
    "docker stats",
    "docker events",
    // Kubernetes
    "kubectl logs",
    "kubectl exec -it",
    "kubectl exec -ti",
    "kubectl attach",
    "kubectl port-forward",
    "kubectl proxy",
    // System monitors & streaming
    "top",
    "htop",
    "btop",
    "watch ",
    "tail -f",
    "tail -F",
    "journalctl -f",
    "journalctl --follow",
    "dmesg -w",
    "dmesg --follow",
    "strace",
    "tcpdump",
    "ping ",
    "ping6 ",
    "traceroute",
    // Editors & pagers
    "less",
    "more",
    "vim",
    "nvim",
    "vi ",
    "nano",
    "micro ",
    "helix ",
    "hx ",
    "emacs",
    // Terminal multiplexers
    "tmux",
    "screen",
    // Interactive shells & REPLs
    "ssh ",
    "telnet ",
    "nc ",
    "ncat ",
    "psql",
    "mysql",
    "sqlite3",
    "redis-cli",
    "mongosh",
    "mongo ",
    "python3 -i",
    "python -i",
    "irb",
    "rails console",
    "rails c ",
    "iex",
    // Rust watchers
    "cargo watch",
];

/// True when `c` may directly precede a command name: whitespace, a path
/// separator, a shell operator or an opening quote/parenthesis.
fn is_command_start_boundary(c: char) -> bool {
    c.is_whitespace() || matches!(c, '/' | '\\' | '|' | ';' | '&' | '(' | '`' | '"' | '\'')
}

/// True when `c` may directly follow a complete word of a command:
/// whitespace, a shell operator or a closing quote/parenthesis.
fn is_command_end_boundary(c: char) -> bool {
    c.is_whitespace() || matches!(c, '|' | ';' | '&' | ')' | '`' | '"' | '\'')
}

/// Checks whether `pattern` occurs in `cmd` as a whole command fragment
/// rather than as an arbitrary substring.
///
/// The left edge of a hit must be the start of `cmd` or a start boundary, so
/// `nc ` no longer matches inside `rsync `. The right edge must be the end of
/// `cmd` or an end boundary, so `mysql` no longer matches `mysqldump`. The
/// right-edge check is skipped for patterns that end with a space (already
/// delimited) and for patterns whose last word is a flag, because short flags
/// can be bundled (`journalctl -fu nginx`).
fn passthrough_pattern_matches(cmd: &str, pattern: &str) -> bool {
    let open_ended = pattern.ends_with(' ')
        || pattern
            .rsplit(' ')
            .next()
            .is_some_and(|last_word| last_word.starts_with('-'));

    cmd.match_indices(pattern).any(|(start, hit)| {
        let before = cmd[..start].chars().next_back();
        let after = cmd[start + hit.len()..].chars().next();
        let left_ok = before.map_or(true, is_command_start_boundary);
        let right_ok = open_ended || after.map_or(true, is_command_end_boundary);
        left_ok && right_ok
    })
}

/// Decides whether `command` should bypass buffering and compression.
///
/// A command is excluded when any [`BUILTIN_PASSTHROUGH`] entry appears in it
/// at a word boundary (anywhere in the line, so wrappers such as
/// `pnpm turbo run dev`, absolute paths and pipelines are covered), or when it
/// equals / starts with one of the user-configured `excluded` entries followed
/// by a space. Comparison is case-insensitive and ignores surrounding
/// whitespace.
fn is_excluded_command(command: &str, excluded: &[String]) -> bool {
    let cmd = command.trim().to_lowercase();

    if BUILTIN_PASSTHROUGH
        .iter()
        .any(|pattern| passthrough_pattern_matches(&cmd, pattern))
    {
        return true;
    }

    excluded.iter().any(|excl| {
        let excl_lower = excl.trim().to_lowercase();
        cmd == excl_lower
            || cmd
                .strip_prefix(excl_lower.as_str())
                .is_some_and(|rest| rest.starts_with(' '))
    })
}
249
250pub fn interactive() {
251    let real_shell = detect_shell();
252
253    eprintln!(
254        "lean-ctx shell v{} (wrapping {real_shell})",
255        env!("CARGO_PKG_VERSION")
256    );
257    eprintln!("All command output is automatically compressed.");
258    eprintln!("Type 'exit' to quit.\n");
259
260    let stdin = io::stdin();
261    let mut stdout = io::stdout();
262
263    loop {
264        let _ = write!(stdout, "lean-ctx> ");
265        let _ = stdout.flush();
266
267        let mut line = String::new();
268        match stdin.lock().read_line(&mut line) {
269            Ok(0) => break,
270            Ok(_) => {}
271            Err(_) => break,
272        }
273
274        let cmd = line.trim();
275        if cmd.is_empty() {
276            continue;
277        }
278        if cmd == "exit" || cmd == "quit" {
279            break;
280        }
281        if cmd == "gain" {
282            println!("{}", stats::format_gain());
283            continue;
284        }
285
286        let exit_code = exec(cmd);
287
288        if exit_code != 0 {
289            let _ = writeln!(stdout, "[exit: {exit_code}]");
290        }
291    }
292}
293
294fn compress_and_measure(command: &str, stdout: &str, stderr: &str) -> (String, usize) {
295    let compressed_stdout = compress_if_beneficial(command, stdout);
296    let compressed_stderr = compress_if_beneficial(command, stderr);
297
298    let mut result = String::new();
299    if !compressed_stdout.is_empty() {
300        result.push_str(&compressed_stdout);
301    }
302    if !compressed_stderr.is_empty() {
303        if !result.is_empty() {
304            result.push('\n');
305        }
306        result.push_str(&compressed_stderr);
307    }
308
309    let output_tokens = count_tokens(&result);
310    (result, output_tokens)
311}
312
313fn compress_if_beneficial(command: &str, output: &str) -> String {
314    if output.trim().is_empty() {
315        return String::new();
316    }
317
318    let original_tokens = count_tokens(output);
319
320    if original_tokens < 50 {
321        return output.to_string();
322    }
323
324    let min_output_tokens = 5;
325
326    if let Some(compressed) = patterns::compress_output(command, output) {
327        if !compressed.trim().is_empty() {
328            let compressed_tokens = count_tokens(&compressed);
329            if compressed_tokens >= min_output_tokens && compressed_tokens < original_tokens {
330                let saved = original_tokens - compressed_tokens;
331                let pct = (saved as f64 / original_tokens as f64 * 100.0).round() as usize;
332                return format!(
333                    "{compressed}\n[lean-ctx: {original_tokens}→{compressed_tokens} tok, -{pct}%]"
334                );
335            }
336            if compressed_tokens < min_output_tokens {
337                return output.to_string();
338            }
339        }
340    }
341
342    // Apply lightweight cleanup to remove whitespace-only lines and collapse braces
343    let cleaned = crate::core::compressor::lightweight_cleanup(output);
344    let cleaned_tokens = count_tokens(&cleaned);
345    if cleaned_tokens < original_tokens {
346        let lines: Vec<&str> = cleaned.lines().collect();
347        if lines.len() > 30 {
348            let first = &lines[..5];
349            let last = &lines[lines.len() - 5..];
350            let omitted = lines.len() - 10;
351            let total = lines.len();
352            let compressed = format!(
353                "{}\n[truncated: showing 10/{total} lines, {omitted} omitted]\n{}",
354                first.join("\n"),
355                last.join("\n")
356            );
357            let ct = count_tokens(&compressed);
358            if ct < original_tokens {
359                let saved = original_tokens - ct;
360                let pct = (saved as f64 / original_tokens as f64 * 100.0).round() as usize;
361                return format!("{compressed}\n[lean-ctx: {original_tokens}→{ct} tok, -{pct}%]");
362            }
363        }
364        if cleaned_tokens < original_tokens {
365            let saved = original_tokens - cleaned_tokens;
366            let pct = (saved as f64 / original_tokens as f64 * 100.0).round() as usize;
367            return format!(
368                "{cleaned}\n[lean-ctx: {original_tokens}→{cleaned_tokens} tok, -{pct}%]"
369            );
370        }
371    }
372
373    let lines: Vec<&str> = output.lines().collect();
374    if lines.len() > 30 {
375        let first = &lines[..5];
376        let last = &lines[lines.len() - 5..];
377        let omitted = lines.len() - 10;
378        let compressed = format!(
379            "{}\n... ({omitted} lines omitted) ...\n{}",
380            first.join("\n"),
381            last.join("\n")
382        );
383        let compressed_tokens = count_tokens(&compressed);
384        if compressed_tokens < original_tokens {
385            let saved = original_tokens - compressed_tokens;
386            let pct = (saved as f64 / original_tokens as f64 * 100.0).round() as usize;
387            return format!(
388                "{compressed}\n[lean-ctx: {original_tokens}→{compressed_tokens} tok, -{pct}%]"
389            );
390        }
391    }
392
393    output.to_string()
394}
395
/// Windows only: picks the flag that makes a shell binary run one command string.
///
/// `exe_basename` must already be ASCII-lowercase (e.g. `bash.exe`, `cmd.exe`).
/// PowerShell variants get `-Command`, `cmd` gets `/C`, and everything else is
/// assumed to be a POSIX-style shell taking `-c`.
fn windows_shell_flag_for_exe_basename(exe_basename: &str) -> &'static str {
    let is_powershell = ["powershell", "pwsh"]
        .iter()
        .any(|needle| exe_basename.contains(needle));

    match exe_basename {
        _ if is_powershell => "-Command",
        "cmd.exe" | "cmd" => "/C",
        // POSIX-style shells: Git Bash / MSYS (`bash`, `sh`, `zsh`, `fish`, …).
        // `/C` is only valid for `cmd.exe`; using it with bash produced
        // `/C: Is a directory` and exit 126 (see github.com/yvgude/lean-ctx/issues/7).
        _ => "-c",
    }
}
410
411pub fn shell_and_flag() -> (String, String) {
412    let shell = detect_shell();
413    let flag = if cfg!(windows) {
414        let name = std::path::Path::new(&shell)
415            .file_name()
416            .and_then(|n| n.to_str())
417            .unwrap_or("")
418            .to_ascii_lowercase();
419        windows_shell_flag_for_exe_basename(&name).to_string()
420    } else {
421        "-c".to_string()
422    };
423    (shell, flag)
424}
425
426fn detect_shell() -> String {
427    if let Ok(shell) = std::env::var("LEAN_CTX_SHELL") {
428        return shell;
429    }
430
431    if let Ok(shell) = std::env::var("SHELL") {
432        let bin = std::path::Path::new(&shell)
433            .file_name()
434            .and_then(|n| n.to_str())
435            .unwrap_or("sh");
436
437        if bin == "lean-ctx" {
438            return find_real_shell();
439        }
440        return shell;
441    }
442
443    find_real_shell()
444}
445
/// Unix fallback: the first of zsh, bash, sh that exists under `/bin`,
/// defaulting to `/bin/sh` when none of them is found.
#[cfg(unix)]
fn find_real_shell() -> String {
    const CANDIDATES: [&str; 3] = ["/bin/zsh", "/bin/bash", "/bin/sh"];

    CANDIDATES
        .iter()
        .copied()
        .find(|candidate| std::path::Path::new(candidate).exists())
        .unwrap_or("/bin/sh")
        .to_string()
}
455
/// Windows fallback: PowerShell if one can be located, then `%COMSPEC%`,
/// then plain `cmd.exe`.
#[cfg(windows)]
fn find_real_shell() -> String {
    // Always prefer PowerShell over cmd.exe — AI agents send bash-like syntax
    // that cmd.exe cannot parse (e.g. `&&`, pipes, subshells).
    // PSModulePath may not be set when the MCP server is spawned by an IDE.
    which_powershell()
        .or_else(|_| std::env::var("COMSPEC").map_err(|_| ()))
        .unwrap_or_else(|_| "cmd.exe".to_string())
}
469
/// Reports whether the `PSModulePath` environment variable is set (and valid
/// Unicode), which is used here as the signal for a PowerShell session.
///
/// NOTE(review): nothing in the visible part of this file calls this helper.
#[cfg(windows)]
fn is_running_in_powershell() -> bool {
    let module_path = std::env::var("PSModulePath");
    module_path.is_ok()
}
474
/// Locates a PowerShell executable by asking `where`, trying `pwsh.exe`
/// before `powershell.exe`.
///
/// Returns the first non-empty line of the first successful lookup, or
/// `Err(())` when neither candidate can be found.
#[cfg(windows)]
fn which_powershell() -> Result<String, ()> {
    ["pwsh.exe", "powershell.exe"]
        .iter()
        .find_map(|candidate| {
            let lookup = std::process::Command::new("where")
                .arg(candidate)
                .output()
                .ok()?;
            if !lookup.status.success() {
                return None;
            }
            let listing = String::from_utf8(lookup.stdout).ok()?;
            let first_hit = listing.lines().next()?.trim();
            if first_hit.is_empty() {
                None
            } else {
                Some(first_hit.to_string())
            }
        })
        .ok_or(())
}
493
494pub fn save_tee(command: &str, output: &str) -> Option<String> {
495    let tee_dir = dirs::home_dir()?.join(".lean-ctx").join("tee");
496    std::fs::create_dir_all(&tee_dir).ok()?;
497
498    cleanup_old_tee_logs(&tee_dir);
499
500    let cmd_slug: String = command
501        .chars()
502        .take(40)
503        .map(|c| {
504            if c.is_alphanumeric() || c == '-' {
505                c
506            } else {
507                '_'
508            }
509        })
510        .collect();
511    let ts = chrono::Local::now().format("%Y-%m-%d_%H%M%S");
512    let filename = format!("{ts}_{cmd_slug}.log");
513    let path = tee_dir.join(&filename);
514
515    let masked = mask_sensitive_data(output);
516    std::fs::write(&path, masked).ok()?;
517    Some(path.to_string_lossy().to_string())
518}
519
520fn mask_sensitive_data(input: &str) -> String {
521    use regex::Regex;
522
523    let patterns: Vec<(&str, Regex)> = vec![
524        ("Bearer token", Regex::new(r"(?i)(bearer\s+)[a-zA-Z0-9\-_\.]{8,}").unwrap()),
525        ("Authorization header", Regex::new(r"(?i)(authorization:\s*(?:basic|bearer|token)\s+)[^\s\r\n]+").unwrap()),
526        ("API key param", Regex::new(r#"(?i)((?:api[_-]?key|apikey|access[_-]?key|secret[_-]?key|token|password|passwd|pwd|secret)\s*[=:]\s*)[^\s\r\n,;&"']+"#).unwrap()),
527        ("AWS key", Regex::new(r"(AKIA[0-9A-Z]{12,})").unwrap()),
528        ("Private key block", Regex::new(r"(?s)(-----BEGIN\s+(?:RSA\s+)?PRIVATE\s+KEY-----).+?(-----END\s+(?:RSA\s+)?PRIVATE\s+KEY-----)").unwrap()),
529        ("GitHub token", Regex::new(r"(gh[pousr]_)[a-zA-Z0-9]{20,}").unwrap()),
530        ("Generic long hex/base64 secret", Regex::new(r#"(?i)(?:key|token|secret|password|credential|auth)\s*[=:]\s*['"]?([a-zA-Z0-9+/=\-_]{32,})['"]?"#).unwrap()),
531    ];
532
533    let mut result = input.to_string();
534    for (label, re) in &patterns {
535        result = re
536            .replace_all(&result, |caps: &regex::Captures| {
537                if let Some(prefix) = caps.get(1) {
538                    format!("{}[REDACTED:{}]", prefix.as_str(), label)
539                } else {
540                    format!("[REDACTED:{}]", label)
541                }
542            })
543            .to_string();
544    }
545    result
546}
547
/// Best-effort removal of every entry in `tee_dir` whose modification time is
/// more than 24 hours in the past.
///
/// Unreadable directories, entries without usable metadata and failed
/// deletions are silently skipped.
fn cleanup_old_tee_logs(tee_dir: &std::path::Path) {
    let max_age = std::time::Duration::from_secs(24 * 60 * 60);
    let Some(cutoff) = std::time::SystemTime::now().checked_sub(max_age) else {
        return;
    };
    let Ok(entries) = std::fs::read_dir(tee_dir) else {
        return;
    };

    for entry in entries.flatten() {
        let is_expired = entry
            .metadata()
            .and_then(|meta| meta.modified())
            .map(|modified| modified < cutoff)
            .unwrap_or(false);
        if is_expired {
            let _ = std::fs::remove_file(entry.path());
        }
    }
}
568
/// Flag selection for the common Windows shell binaries.
#[cfg(test)]
mod windows_shell_flag_tests {
    use super::windows_shell_flag_for_exe_basename as flag_for;

    /// Asserts that every basename in `basenames` maps to `expected`.
    fn assert_flag(basenames: &[&str], expected: &str) {
        for basename in basenames {
            assert_eq!(flag_for(basename), expected, "basename {basename:?}");
        }
    }

    #[test]
    fn cmd_uses_slash_c() {
        assert_flag(&["cmd.exe", "cmd"], "/C");
    }

    #[test]
    fn powershell_uses_command() {
        assert_flag(&["powershell.exe", "pwsh.exe"], "-Command");
    }

    #[test]
    fn posix_shells_use_dash_c() {
        assert_flag(
            &["bash.exe", "bash", "sh.exe", "zsh.exe", "fish.exe"],
            "-c",
        );
    }
}
597
/// Coverage for the built-in passthrough list and user-configured exclusions.
#[cfg(test)]
mod passthrough_tests {
    use super::is_excluded_command;

    /// Asserts that each command is excluded with no user exclusions configured.
    fn assert_passthrough(commands: &[&str]) {
        for command in commands {
            assert!(
                is_excluded_command(command, &[]),
                "expected passthrough: {command:?}"
            );
        }
    }

    #[test]
    fn turbo_is_passthrough() {
        assert_passthrough(&[
            "turbo run dev",
            "turbo run build",
            "pnpm turbo run dev",
            "npx turbo run dev",
        ]);
    }

    #[test]
    fn dev_servers_are_passthrough() {
        assert_passthrough(&[
            "next dev",
            "vite dev",
            "nuxt dev",
            "astro dev",
            "nodemon server.js",
        ]);
    }

    #[test]
    fn interactive_tools_are_passthrough() {
        assert_passthrough(&[
            "vim file.rs",
            "nvim",
            "htop",
            "ssh user@host",
            "tail -f /var/log/syslog",
        ]);
    }

    #[test]
    fn docker_streaming_is_passthrough() {
        assert_passthrough(&[
            "docker logs my-container",
            "docker logs -f webapp",
            "docker attach my-container",
            "docker exec -it web bash",
            "docker exec -ti web bash",
            "docker run -it ubuntu bash",
            "docker compose exec web bash",
            "docker stats",
            "docker events",
        ]);
    }

    #[test]
    fn kubectl_is_passthrough() {
        assert_passthrough(&[
            "kubectl logs my-pod",
            "kubectl logs -f deploy/web",
            "kubectl exec -it pod -- bash",
            "kubectl port-forward svc/web 8080:80",
            "kubectl attach my-pod",
            "kubectl proxy",
        ]);
    }

    #[test]
    fn database_repls_are_passthrough() {
        assert_passthrough(&[
            "psql -U user mydb",
            "mysql -u root -p",
            "sqlite3 data.db",
            "redis-cli",
            "mongosh",
        ]);
    }

    #[test]
    fn streaming_tools_are_passthrough() {
        assert_passthrough(&[
            "journalctl -f",
            "ping 8.8.8.8",
            "strace -p 1234",
            "tcpdump -i eth0",
            "tail -F /var/log/app.log",
            "tmux new -s work",
            "screen -S dev",
        ]);
    }

    #[test]
    fn additional_dev_servers_are_passthrough() {
        assert_passthrough(&[
            "gatsby develop",
            "ng serve --port 4200",
            "remix dev",
            "wrangler dev",
            "hugo server",
            "bun dev",
            "cargo watch -x test",
        ]);
    }

    #[test]
    fn normal_commands_not_excluded() {
        for command in ["git status", "cargo test", "npm run build", "ls -la"] {
            assert!(
                !is_excluded_command(command, &[]),
                "expected compression: {command:?}"
            );
        }
    }

    #[test]
    fn user_exclusions_work() {
        let excl = vec!["myapp".to_string()];
        assert!(is_excluded_command("myapp serve", &excl));
        assert!(!is_excluded_command("git status", &excl));
    }
}
700
/// Tests for shell detection.
///
/// These tests change process-wide environment variables, while the test
/// harness runs tests on multiple threads. Every test therefore holds
/// `ENV_LOCK` for its whole body, and all changes go through `EnvVarGuard`,
/// which restores the previous value on drop (including on panic).
#[cfg(test)]
mod shell_detection_tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test in this module that reads or writes env vars.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the env lock, ignoring poisoning left behind by a failed test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Overrides one environment variable and puts the old state back on drop.
    struct EnvVarGuard {
        name: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Sets `name` to `value`, remembering what was there before.
        fn set(name: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(name);
            std::env::set_var(name, value);
            Self { name, previous }
        }

        /// Removes `name`, remembering what was there before.
        fn unset(name: &'static str) -> Self {
            let previous = std::env::var_os(name);
            std::env::remove_var(name);
            Self { name, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.name, value),
                None => std::env::remove_var(self.name),
            }
        }
    }

    #[test]
    fn lean_ctx_shell_env_takes_priority() {
        let _env = lock_env();
        let _override = EnvVarGuard::set("LEAN_CTX_SHELL", "/custom/shell");
        assert_eq!(detect_shell(), "/custom/shell");
    }

    #[test]
    fn shell_env_respected_when_no_override() {
        let _env = lock_env();
        let _no_override = EnvVarGuard::unset("LEAN_CTX_SHELL");
        let _shell = EnvVarGuard::set("SHELL", "/bin/bash");
        assert_eq!(detect_shell(), "/bin/bash");
    }

    #[test]
    fn lean_ctx_shell_self_reference_triggers_fallback() {
        let _env = lock_env();
        let _no_override = EnvVarGuard::unset("LEAN_CTX_SHELL");
        let _shell = EnvVarGuard::set("SHELL", "/usr/local/bin/lean-ctx");
        assert_ne!(detect_shell(), "/usr/local/bin/lean-ctx");
    }

    #[test]
    fn shell_and_flag_returns_dash_c_on_unix() {
        // Reads the environment, so it must not overlap with the tests above.
        let _env = lock_env();
        let (_shell, flag) = shell_and_flag();
        if !cfg!(windows) {
            assert_eq!(flag, "-c");
        }
    }
}
759
/// Additional flag-selection cases: extension-less PowerShell, Git Bash
/// launchers and unknown shells.
#[cfg(test)]
mod windows_shell_detection_tests {
    use super::windows_shell_flag_for_exe_basename as flag_for;

    /// Asserts that every basename in `basenames` maps to `expected`.
    fn assert_flag(basenames: &[&str], expected: &str) {
        for basename in basenames {
            assert_eq!(flag_for(basename), expected, "basename {basename:?}");
        }
    }

    #[test]
    fn windows_flag_powershell_variants() {
        assert_flag(&["pwsh.exe", "powershell.exe", "pwsh"], "-Command");
    }

    #[test]
    fn windows_flag_git_bash_variants() {
        assert_flag(&["bash.exe", "sh.exe", "git-bash.exe"], "-c");
    }

    #[test]
    fn windows_flag_unknown_defaults_to_posix() {
        assert_flag(&["unknown.exe", "myshell"], "-c");
    }
}
787}