Skip to main content

lean_ctx/
shell.rs

1use std::io::{self, BufRead, IsTerminal, Write};
2use std::process::{Command, Stdio};
3
4use crate::core::config;
5use crate::core::patterns;
6use crate::core::slow_log;
7use crate::core::stats;
8use crate::core::tokens::count_tokens;
9
/// Converts raw process output bytes into a `String`.
///
/// Valid UTF-8 is returned as-is. Otherwise, on Windows the bytes are decoded with the
/// active ANSI code page (`decode_windows_output`); on other platforms every invalid
/// sequence is replaced with U+FFFD.
pub fn decode_output(bytes: &[u8]) -> String {
    if let Ok(text) = std::str::from_utf8(bytes) {
        return text.to_owned();
    }

    #[cfg(windows)]
    {
        decode_windows_output(bytes)
    }
    #[cfg(not(windows))]
    {
        String::from_utf8_lossy(bytes).into_owned()
    }
}
25
/// Decodes non-UTF-8 process output using the system's active ANSI code page (`GetACP`).
///
/// Falls back to lossy UTF-8 decoding in three cases:
/// - the input is too long to be described by the API's `i32` length;
/// - the code-page conversion fails;
/// - the conversion writes no characters.
#[cfg(windows)]
fn decode_windows_output(bytes: &[u8]) -> String {
    extern "system" {
        fn GetACP() -> u32;
        fn MultiByteToWideChar(
            cp: u32,
            flags: u32,
            src: *const u8,
            srclen: i32,
            dst: *mut u16,
            dstlen: i32,
        ) -> i32;
    }

    let lossy = || String::from_utf8_lossy(bytes).into_owned();

    // The API takes an i32 byte count; a plain `as` cast would silently truncate
    // inputs over 2 GiB.
    let Ok(src_len) = i32::try_from(bytes.len()) else {
        return lossy();
    };

    // SAFETY: GetACP takes no arguments and has no preconditions.
    let codepage = unsafe { GetACP() };

    // SAFETY: `bytes` is valid for reads of `src_len` bytes. A null destination with a
    // length of 0 only queries the required buffer size and writes nothing.
    let wide_len = unsafe {
        MultiByteToWideChar(
            codepage,
            0,
            bytes.as_ptr(),
            src_len,
            std::ptr::null_mut(),
            0,
        )
    };
    if wide_len <= 0 {
        return lossy();
    }

    let mut wide: Vec<u16> = vec![0u16; wide_len as usize];
    // SAFETY: `wide` has room for exactly `wide_len` u16 units, which is the
    // destination length passed to the call.
    let written = unsafe {
        MultiByteToWideChar(
            codepage,
            0,
            bytes.as_ptr(),
            src_len,
            wide.as_mut_ptr(),
            wide_len,
        )
    };
    if written <= 0 {
        return lossy();
    }
    // Keep only the units the call actually produced.
    wide.truncate(written as usize);

    // Unpaired surrogates become U+FFFD.
    String::from_utf16_lossy(&wide)
}
71
/// Switches the console's output code page to UTF-8 (Windows only).
///
/// This is best-effort: the API's status result is ignored.
#[cfg(windows)]
fn set_console_utf8() {
    extern "system" {
        fn SetConsoleOutputCP(id: u32) -> i32;
    }

    /// Windows code page identifier for UTF-8.
    const CP_UTF8: u32 = 65001;

    // SAFETY: SetConsoleOutputCP takes a plain integer and has no memory-safety
    // preconditions.
    let _ = unsafe { SetConsoleOutputCP(CP_UTF8) };
}
81
/// Best-effort detection of whether this process runs inside a Docker/LXC-style container.
///
/// On Unix the checks are tried in order, and the first one that hits returns `true`:
/// 1. `/.dockerenv` exists;
/// 2. `/proc/1/cgroup` mentions `/docker/` or `/lxc/`;
/// 3. `/proc/self/mountinfo` mentions `/docker/containers/`.
///
/// Files that cannot be read are treated as a miss. Always `false` on non-Unix platforms.
pub fn is_container() -> bool {
    #[cfg(unix)]
    {
        let file_mentions_any = |path: &str, needles: &[&str]| {
            std::fs::read_to_string(path)
                .map_or(false, |text| needles.iter().any(|needle| text.contains(needle)))
        };

        std::path::Path::new("/.dockerenv").exists()
            || file_mentions_any("/proc/1/cgroup", &["/docker/", "/lxc/"])
            || file_mentions_any("/proc/self/mountinfo", &["/docker/containers/"])
    }
    #[cfg(not(unix))]
    {
        false
    }
}
106
/// Reports whether stdin is *not* attached to a terminal (e.g. a pipe or `/dev/null`).
pub fn is_non_interactive() -> bool {
    let stdin_is_tty = io::stdin().is_terminal();
    !stdin_is_tty
}
111
112pub fn exec(command: &str) -> i32 {
113    let (shell, shell_flag) = shell_and_flag();
114    let command = crate::tools::ctx_shell::normalize_command_for_shell(command);
115    let command = command.as_str();
116
117    if std::env::var("LEAN_CTX_DISABLED").is_ok() || std::env::var("LEAN_CTX_ACTIVE").is_ok() {
118        return exec_inherit(command, &shell, &shell_flag);
119    }
120
121    let cfg = config::Config::load();
122    let force_compress = std::env::var("LEAN_CTX_COMPRESS").is_ok();
123    let raw_mode = std::env::var("LEAN_CTX_RAW").is_ok();
124
125    if raw_mode || (!force_compress && is_excluded_command(command, &cfg.excluded_commands)) {
126        return exec_inherit(command, &shell, &shell_flag);
127    }
128
129    if !force_compress {
130        if io::stdout().is_terminal() {
131            return exec_inherit_tracked(command, &shell, &shell_flag);
132        }
133        return exec_inherit(command, &shell, &shell_flag);
134    }
135
136    exec_buffered(command, &shell, &shell_flag, &cfg)
137}
138
/// Runs `command` via `shell shell_flag command` with stdin, stdout and stderr inherited,
/// and sets `LEAN_CTX_ACTIVE=1` in the child's environment.
///
/// Returns:
/// - the child's exit code;
/// - on Unix, `128 + signal` when the child was terminated by a signal (the
///   convention shells use);
/// - 1 when no status information is available;
/// - 127 when the shell could not be started.
fn exec_inherit(command: &str, shell: &str, shell_flag: &str) -> i32 {
    let status = Command::new(shell)
        .arg(shell_flag)
        .arg(command)
        .env("LEAN_CTX_ACTIVE", "1")
        .stdin(Stdio::inherit())
        .stdout(Stdio::inherit())
        .stderr(Stdio::inherit())
        .status();

    match status {
        Ok(s) => s.code().unwrap_or_else(|| {
            // `code()` is None when the child was killed by a signal; report it the
            // way a shell would instead of collapsing every signal to 1.
            #[cfg(unix)]
            {
                use std::os::unix::process::ExitStatusExt;
                if let Some(signal) = s.signal() {
                    return 128 + signal;
                }
            }
            1
        }),
        Err(e) => {
            eprintln!("lean-ctx: failed to execute: {e}");
            127
        }
    }
}
157
158fn exec_inherit_tracked(command: &str, shell: &str, shell_flag: &str) -> i32 {
159    let code = exec_inherit(command, shell, shell_flag);
160    stats::record(command, 0, 0);
161    code
162}
163
/// Joins captured stdout and stderr into one string.
///
/// If only one stream is non-empty, it is returned alone. If both are non-empty, they
/// are joined with a single newline, stdout first.
fn combine_output(stdout: &str, stderr: &str) -> String {
    match (stdout.is_empty(), stderr.is_empty()) {
        (_, true) => stdout.to_owned(),
        (true, false) => stderr.to_owned(),
        (false, false) => format!("{stdout}\n{stderr}"),
    }
}
173
174fn exec_buffered(command: &str, shell: &str, shell_flag: &str, cfg: &config::Config) -> i32 {
175    #[cfg(windows)]
176    set_console_utf8();
177
178    let start = std::time::Instant::now();
179
180    let mut cmd = Command::new(shell);
181    cmd.arg(shell_flag);
182
183    #[cfg(windows)]
184    {
185        let is_powershell =
186            shell.to_lowercase().contains("powershell") || shell.to_lowercase().contains("pwsh");
187        if is_powershell {
188            cmd.arg(format!(
189                "[Console]::OutputEncoding = [System.Text.Encoding]::UTF8; {command}"
190            ));
191        } else {
192            cmd.arg(command);
193        }
194    }
195    #[cfg(not(windows))]
196    cmd.arg(command);
197
198    let child = cmd
199        .env("LEAN_CTX_ACTIVE", "1")
200        .env_remove("DISPLAY")
201        .env_remove("XAUTHORITY")
202        .env_remove("WAYLAND_DISPLAY")
203        .stdout(Stdio::piped())
204        .stderr(Stdio::piped())
205        .spawn();
206
207    let child = match child {
208        Ok(c) => c,
209        Err(e) => {
210            eprintln!("lean-ctx: failed to execute: {e}");
211            return 127;
212        }
213    };
214
215    let output = match child.wait_with_output() {
216        Ok(o) => o,
217        Err(e) => {
218            eprintln!("lean-ctx: failed to wait: {e}");
219            return 127;
220        }
221    };
222
223    let duration_ms = start.elapsed().as_millis();
224    let exit_code = output.status.code().unwrap_or(1);
225    let stdout = decode_output(&output.stdout);
226    let stderr = decode_output(&output.stderr);
227
228    let full_output = combine_output(&stdout, &stderr);
229    let input_tokens = count_tokens(&full_output);
230
231    let (compressed, output_tokens) = compress_and_measure(command, &stdout, &stderr);
232
233    stats::record(command, input_tokens, output_tokens);
234
235    if !compressed.is_empty() {
236        let _ = io::stdout().write_all(compressed.as_bytes());
237        if !compressed.ends_with('\n') {
238            let _ = io::stdout().write_all(b"\n");
239        }
240    }
241    let should_tee = match cfg.tee_mode {
242        config::TeeMode::Always => !full_output.trim().is_empty(),
243        config::TeeMode::Failures => exit_code != 0 && !full_output.trim().is_empty(),
244        config::TeeMode::Never => false,
245    };
246    if should_tee {
247        if let Some(path) = save_tee(command, &full_output) {
248            eprintln!("[lean-ctx: full output -> {path} (redacted, 24h TTL)]");
249        }
250    }
251
252    let threshold = cfg.slow_command_threshold_ms;
253    if threshold > 0 && duration_ms >= threshold as u128 {
254        slow_log::record(command, duration_ms, exit_code);
255    }
256
257    exit_code
258}
259
/// Commands that always bypass output capture and run with inherited stdio.
///
/// `is_excluded_command` matches entries against the trimmed, lowercased command:
/// - an entry starting with `--` matches anywhere, as a plain substring;
/// - an entry ending in a space matches only at the start of the command, or when the
///   whole command equals the entry without that space;
/// - every other entry must occur word-bounded: at the start, after a space or `|`,
///   and followed by a space, a tab, or the end of the command.
///
/// Entry order does not matter; a single match is enough.
/// NOTE(review): `"tail -F"` contains an uppercase letter and so can never match the
/// lowercased command; `"tail -f"` already covers it.
const BUILTIN_PASSTHROUGH: &[&str] = &[
    // JS/TS dev servers & watchers
    "turbo", "nx serve", "nx dev", "next dev", "vite dev", "vite preview", "vitest",
    "nuxt dev", "astro dev", "webpack serve", "webpack-dev-server", "nodemon",
    "concurrently", "pm2", "pm2 logs", "gatsby develop", "expo start",
    "react-scripts start", "ng serve", "remix dev", "wrangler dev", "hugo server",
    "hugo serve", "jekyll serve", "bun dev", "ember serve",
    // Docker
    "docker compose up", "docker-compose up", "docker compose logs", "docker-compose logs",
    "docker compose exec", "docker-compose exec", "docker compose run", "docker-compose run",
    "docker logs", "docker attach", "docker exec -it", "docker exec -ti", "docker run -it",
    "docker run -ti", "docker stats", "docker events",
    // Kubernetes
    "kubectl logs", "kubectl exec -it", "kubectl exec -ti", "kubectl attach",
    "kubectl port-forward", "kubectl proxy",
    // System monitors & streaming
    "top", "htop", "btop", "watch ", "tail -f", "tail -F", "journalctl -f",
    "journalctl --follow", "dmesg -w", "dmesg --follow", "strace", "tcpdump", "ping ",
    "ping6 ", "traceroute",
    // Editors & pagers
    "less", "more", "vim", "nvim", "vi ", "nano", "micro ", "helix ", "hx ", "emacs",
    // Terminal multiplexers
    "tmux", "screen",
    // Interactive shells & REPLs
    "ssh ", "telnet ", "nc ", "ncat ", "psql", "mysql", "sqlite3", "redis-cli", "mongosh",
    "mongo ", "python3 -i", "python -i", "irb", "rails console", "rails c ", "iex",
    // Rust watchers
    "cargo watch",
    // Authentication flows (device code, OAuth, SSO — output contains codes users must see)
    "az login", "az account", "gh auth", "gcloud auth", "gcloud init", "aws sso",
    "aws configure sso", "firebase login", "netlify login", "vercel login", "heroku login",
    "flyctl auth", "fly auth", "railway login", "supabase login", "wrangler login",
    "doppler login", "vault login", "oc login", "kubelogin", "--use-device-code",
];

/// Returns true if `command` should bypass compression and run with inherited stdio.
///
/// The command is trimmed and lowercased, then checked against [`BUILTIN_PASSTHROUGH`].
/// Next it is checked against the user-configured `excluded` prefixes. Each prefix
/// matches when the command equals it (trimmed, case-insensitively) or starts with it
/// followed by a space.
fn is_excluded_command(command: &str, excluded: &[String]) -> bool {
    let cmd = command.trim().to_lowercase();

    if BUILTIN_PASSTHROUGH
        .iter()
        .any(|pattern| builtin_pattern_matches(&cmd, pattern))
    {
        return true;
    }

    excluded.iter().any(|entry| {
        let prefix = entry.trim().to_lowercase();
        cmd == prefix
            || cmd
                .strip_prefix(prefix.as_str())
                .is_some_and(|rest| rest.starts_with(' '))
    })
}

/// Tests a single `BUILTIN_PASSTHROUGH` entry against an already trimmed and lowercased
/// command, using the rules documented on that constant.
fn builtin_pattern_matches(cmd: &str, pattern: &str) -> bool {
    if pattern.starts_with("--") {
        return cmd.contains(pattern);
    }
    if pattern.ends_with(|c: char| c == ' ' || c == '\t') {
        return cmd == pattern.trim() || cmd.starts_with(pattern);
    }

    // Word-bounded occurrence: preceded by start/space/`|`, followed by end/space/tab.
    // Comparing raw bytes gives the same result as `str` matching, because the boundary
    // characters are ASCII and never occur inside a multi-byte UTF-8 sequence.
    let (hay, needle) = (cmd.as_bytes(), pattern.as_bytes());
    if needle.len() > hay.len() {
        return false;
    }
    (0..=hay.len() - needle.len()).any(|start| {
        let end = start + needle.len();
        let left_ok = start == 0 || matches!(hay[start - 1], b' ' | b'|');
        let right_ok = end == hay.len() || matches!(hay[end], b' ' | b'\t');
        left_ok && right_ok && &hay[start..end] == needle
    })
}
417
418pub fn interactive() {
419    let real_shell = detect_shell();
420
421    eprintln!(
422        "lean-ctx shell v{} (wrapping {real_shell})",
423        env!("CARGO_PKG_VERSION")
424    );
425    eprintln!("All command output is automatically compressed.");
426    eprintln!("Type 'exit' to quit.\n");
427
428    let stdin = io::stdin();
429    let mut stdout = io::stdout();
430
431    loop {
432        let _ = write!(stdout, "lean-ctx> ");
433        let _ = stdout.flush();
434
435        let mut line = String::new();
436        match stdin.lock().read_line(&mut line) {
437            Ok(0) => break,
438            Ok(_) => {}
439            Err(_) => break,
440        }
441
442        let cmd = line.trim();
443        if cmd.is_empty() {
444            continue;
445        }
446        if cmd == "exit" || cmd == "quit" {
447            break;
448        }
449        if cmd == "gain" {
450            println!("{}", stats::format_gain());
451            continue;
452        }
453
454        let exit_code = exec(cmd);
455
456        if exit_code != 0 {
457            let _ = writeln!(stdout, "[exit: {exit_code}]");
458        }
459    }
460}
461
462fn compress_and_measure(command: &str, stdout: &str, stderr: &str) -> (String, usize) {
463    let compressed_stdout = compress_if_beneficial(command, stdout);
464    let compressed_stderr = compress_if_beneficial(command, stderr);
465
466    let mut result = String::new();
467    if !compressed_stdout.is_empty() {
468        result.push_str(&compressed_stdout);
469    }
470    if !compressed_stderr.is_empty() {
471        if !result.is_empty() {
472            result.push('\n');
473        }
474        result.push_str(&compressed_stderr);
475    }
476
477    // Count tokens on content BEFORE the [lean-ctx: ...] footer to avoid
478    // counting the annotation overhead against savings.
479    let content_for_counting = if let Some(pos) = result.rfind("\n[lean-ctx: ") {
480        &result[..pos]
481    } else {
482        &result
483    };
484    let output_tokens = count_tokens(content_for_counting);
485    (result, output_tokens)
486}
487
488fn compress_if_beneficial(command: &str, output: &str) -> String {
489    if output.trim().is_empty() {
490        return String::new();
491    }
492
493    if crate::tools::ctx_shell::contains_auth_flow(output) {
494        return output.to_string();
495    }
496
497    let original_tokens = count_tokens(output);
498
499    if original_tokens < 50 {
500        return output.to_string();
501    }
502
503    let min_output_tokens = 5;
504
505    if let Some(compressed) = patterns::compress_output(command, output) {
506        if !compressed.trim().is_empty() {
507            let compressed_tokens = count_tokens(&compressed);
508            if compressed_tokens >= min_output_tokens && compressed_tokens < original_tokens {
509                let saved = original_tokens - compressed_tokens;
510                let pct = (saved as f64 / original_tokens as f64 * 100.0).round() as usize;
511                if pct >= 5 {
512                    return format!(
513                        "{compressed}\n[lean-ctx: {original_tokens}→{compressed_tokens} tok, -{pct}%]"
514                    );
515                }
516                return compressed;
517            }
518            if compressed_tokens < min_output_tokens {
519                return output.to_string();
520            }
521        }
522    }
523
524    // Apply lightweight cleanup to remove whitespace-only lines and collapse braces
525    let cleaned = crate::core::compressor::lightweight_cleanup(output);
526    let cleaned_tokens = count_tokens(&cleaned);
527    if cleaned_tokens < original_tokens {
528        let lines: Vec<&str> = cleaned.lines().collect();
529        if lines.len() > 30 {
530            let first = &lines[..5];
531            let last = &lines[lines.len() - 5..];
532            let omitted = lines.len() - 10;
533            let total = lines.len();
534            let compressed = format!(
535                "{}\n[truncated: showing 10/{total} lines, {omitted} omitted]\n{}",
536                first.join("\n"),
537                last.join("\n")
538            );
539            let ct = count_tokens(&compressed);
540            if ct < original_tokens {
541                let saved = original_tokens - ct;
542                let pct = (saved as f64 / original_tokens as f64 * 100.0).round() as usize;
543                if pct >= 5 {
544                    return format!(
545                        "{compressed}\n[lean-ctx: {original_tokens}→{ct} tok, -{pct}%]"
546                    );
547                }
548                return compressed;
549            }
550        }
551        if cleaned_tokens < original_tokens {
552            let saved = original_tokens - cleaned_tokens;
553            let pct = (saved as f64 / original_tokens as f64 * 100.0).round() as usize;
554            if pct >= 5 {
555                return format!(
556                    "{cleaned}\n[lean-ctx: {original_tokens}→{cleaned_tokens} tok, -{pct}%]"
557                );
558            }
559            return cleaned;
560        }
561    }
562
563    let lines: Vec<&str> = output.lines().collect();
564    if lines.len() > 30 {
565        let first = &lines[..5];
566        let last = &lines[lines.len() - 5..];
567        let omitted = lines.len() - 10;
568        let compressed = format!(
569            "{}\n... ({omitted} lines omitted) ...\n{}",
570            first.join("\n"),
571            last.join("\n")
572        );
573        let compressed_tokens = count_tokens(&compressed);
574        if compressed_tokens < original_tokens {
575            let saved = original_tokens - compressed_tokens;
576            let pct = (saved as f64 / original_tokens as f64 * 100.0).round() as usize;
577            if pct >= 5 {
578                return format!(
579                    "{compressed}\n[lean-ctx: {original_tokens}→{compressed_tokens} tok, -{pct}%]"
580                );
581            }
582            return compressed;
583        }
584    }
585
586    output.to_string()
587}
588
/// Windows only: returns the argument that makes `exe_basename` run a single command string.
///
/// `exe_basename` is expected to be ASCII-lowercase already (e.g. `bash.exe`, `cmd.exe`).
fn windows_shell_flag_for_exe_basename(exe_basename: &str) -> &'static str {
    const POWERSHELL_MARKERS: [&str; 2] = ["powershell", "pwsh"];

    if POWERSHELL_MARKERS
        .iter()
        .copied()
        .any(|marker| exe_basename.contains(marker))
    {
        return "-Command";
    }

    match exe_basename {
        "cmd.exe" | "cmd" => "/C",
        // POSIX-style shells: Git Bash / MSYS (`bash`, `sh`, `zsh`, `fish`, …).
        // `/C` is only valid for `cmd.exe`; using it with bash produced
        // `/C: Is a directory` and exit 126 (see github.com/yvgude/lean-ctx/issues/7).
        _ => "-c",
    }
}
603
604pub fn shell_and_flag() -> (String, String) {
605    let shell = detect_shell();
606    let flag = if cfg!(windows) {
607        let name = std::path::Path::new(&shell)
608            .file_name()
609            .and_then(|n| n.to_str())
610            .unwrap_or("")
611            .to_ascii_lowercase();
612        windows_shell_flag_for_exe_basename(&name).to_string()
613    } else {
614        "-c".to_string()
615    };
616    (shell, flag)
617}
618
619fn detect_shell() -> String {
620    if let Ok(shell) = std::env::var("LEAN_CTX_SHELL") {
621        return shell;
622    }
623
624    if let Ok(shell) = std::env::var("SHELL") {
625        let bin = std::path::Path::new(&shell)
626            .file_name()
627            .and_then(|n| n.to_str())
628            .unwrap_or("sh");
629
630        if bin == "lean-ctx" {
631            return find_real_shell();
632        }
633        return shell;
634    }
635
636    find_real_shell()
637}
638
/// Unix fallback: the first of `/bin/zsh`, `/bin/bash`, `/bin/sh` that exists, or
/// `/bin/sh` if none do.
#[cfg(unix)]
fn find_real_shell() -> String {
    const CANDIDATES: [&str; 3] = ["/bin/zsh", "/bin/bash", "/bin/sh"];
    CANDIDATES
        .iter()
        .copied()
        .find(|path| std::path::Path::new(path).exists())
        .unwrap_or("/bin/sh")
        .to_owned()
}
648
/// Windows fallback shell, tried in this order:
/// 1. PowerShell, when we appear to run inside it and `where` can locate it;
/// 2. `%COMSPEC%`;
/// 3. `cmd.exe`.
#[cfg(windows)]
fn find_real_shell() -> String {
    let powershell = if is_running_in_powershell() {
        which_powershell().ok()
    } else {
        None
    };
    powershell
        .or_else(|| std::env::var("COMSPEC").ok())
        .unwrap_or_else(|| "cmd.exe".to_string())
}
661
/// Heuristic: treats a readable `PSModulePath` environment variable as "running inside PowerShell".
///
/// NOTE(review): `PSModulePath` may also be defined machine-wide on default Windows
/// installs, in which case this reports true from cmd.exe as well — confirm before
/// relying on it.
#[cfg(windows)]
fn is_running_in_powershell() -> bool {
    matches!(std::env::var("PSModulePath"), Ok(_))
}
666
/// Locates `pwsh.exe` (preferred) or `powershell.exe` using `where`.
///
/// Returns the first non-empty line `where` prints for the first candidate that succeeds,
/// or `Err(())` if neither candidate is found.
#[cfg(windows)]
fn which_powershell() -> Result<String, ()> {
    ["pwsh.exe", "powershell.exe"]
        .iter()
        .find_map(|exe| {
            let found = std::process::Command::new("where").arg(exe).output().ok()?;
            if !found.status.success() {
                return None;
            }
            let listing = String::from_utf8(found.stdout).ok()?;
            let first = listing.lines().next()?.trim();
            (!first.is_empty()).then(|| first.to_string())
        })
        .ok_or(())
}
685
686pub fn save_tee(command: &str, output: &str) -> Option<String> {
687    let tee_dir = dirs::home_dir()?.join(".lean-ctx").join("tee");
688    std::fs::create_dir_all(&tee_dir).ok()?;
689
690    cleanup_old_tee_logs(&tee_dir);
691
692    let cmd_slug: String = command
693        .chars()
694        .take(40)
695        .map(|c| {
696            if c.is_alphanumeric() || c == '-' {
697                c
698            } else {
699                '_'
700            }
701        })
702        .collect();
703    let ts = chrono::Local::now().format("%Y-%m-%d_%H%M%S");
704    let filename = format!("{ts}_{cmd_slug}.log");
705    let path = tee_dir.join(&filename);
706
707    let masked = mask_sensitive_data(output);
708    std::fs::write(&path, masked).ok()?;
709    Some(path.to_string_lossy().to_string())
710}
711
712fn mask_sensitive_data(input: &str) -> String {
713    use regex::Regex;
714
715    let patterns: Vec<(&str, Regex)> = vec![
716        ("Bearer token", Regex::new(r"(?i)(bearer\s+)[a-zA-Z0-9\-_\.]{8,}").unwrap()),
717        ("Authorization header", Regex::new(r"(?i)(authorization:\s*(?:basic|bearer|token)\s+)[^\s\r\n]+").unwrap()),
718        ("API key param", Regex::new(r#"(?i)((?:api[_-]?key|apikey|access[_-]?key|secret[_-]?key|token|password|passwd|pwd|secret)\s*[=:]\s*)[^\s\r\n,;&"']+"#).unwrap()),
719        ("AWS key", Regex::new(r"(AKIA[0-9A-Z]{12,})").unwrap()),
720        ("Private key block", Regex::new(r"(?s)(-----BEGIN\s+(?:RSA\s+)?PRIVATE\s+KEY-----).+?(-----END\s+(?:RSA\s+)?PRIVATE\s+KEY-----)").unwrap()),
721        ("GitHub token", Regex::new(r"(gh[pousr]_)[a-zA-Z0-9]{20,}").unwrap()),
722        ("Generic long hex/base64 secret", Regex::new(r#"(?i)(?:key|token|secret|password|credential|auth)\s*[=:]\s*['"]?([a-zA-Z0-9+/=\-_]{32,})['"]?"#).unwrap()),
723    ];
724
725    let mut result = input.to_string();
726    for (label, re) in &patterns {
727        result = re
728            .replace_all(&result, |caps: &regex::Captures| {
729                if let Some(prefix) = caps.get(1) {
730                    format!("{}[REDACTED:{}]", prefix.as_str(), label)
731                } else {
732                    format!("[REDACTED:{}]", label)
733                }
734            })
735            .to_string();
736    }
737    result
738}
739
/// Deletes entries in `tee_dir` whose modification time is more than 24 hours in the past.
///
/// Best effort: unreadable directories, missing metadata, and failed deletions are
/// silently skipped.
fn cleanup_old_tee_logs(tee_dir: &std::path::Path) {
    const MAX_AGE: std::time::Duration = std::time::Duration::from_secs(24 * 60 * 60);

    let Some(cutoff) = std::time::SystemTime::now().checked_sub(MAX_AGE) else {
        return;
    };
    let Ok(entries) = std::fs::read_dir(tee_dir) else {
        return;
    };

    entries
        .flatten()
        .filter(|entry| {
            entry
                .metadata()
                .and_then(|meta| meta.modified())
                .map_or(false, |modified| modified < cutoff)
        })
        .for_each(|entry| {
            let _ = std::fs::remove_file(entry.path());
        });
}
760
761/// Join multiple CLI arguments into a single command string, using quoting
762/// conventions appropriate for the detected shell.
763///
764/// On Unix, this always produces POSIX-compatible quoting.
765/// On Windows, the quoting adapts to the actual shell (PowerShell, cmd.exe,
766/// or Git Bash / MSYS).
767pub fn join_command(args: &[String]) -> String {
768    let (_, flag) = shell_and_flag();
769    join_command_for(args, &flag)
770}
771
/// Joins `args` into one command string using the quoting rules of the shell that
/// `shell_flag` belongs to:
/// - `-Command` selects PowerShell;
/// - `/C` selects cmd.exe;
/// - anything else selects POSIX `sh`-style quoting.
fn join_command_for(args: &[String], shell_flag: &str) -> String {
    match shell_flag {
        "-Command" => join_powershell(args),
        "/C" => join_cmd(args),
        _ => join_posix(args),
    }
}

/// Space-joins `args`, each quoted with [`quote_posix`].
fn join_posix(args: &[String]) -> String {
    args.iter()
        .map(|a| quote_posix(a))
        .collect::<Vec<_>>()
        .join(" ")
}

/// Space-joins `args`, each quoted with [`quote_powershell`], behind the call operator `&`.
///
/// Without `&`, a quoted first word would be evaluated as a string rather than executed.
fn join_powershell(args: &[String]) -> String {
    let quoted: Vec<String> = args.iter().map(|a| quote_powershell(a)).collect();
    format!("& {}", quoted.join(" "))
}

/// Space-joins `args`, each quoted with [`quote_cmd`].
fn join_cmd(args: &[String]) -> String {
    args.iter()
        .map(|a| quote_cmd(a))
        .collect::<Vec<_>>()
        .join(" ")
}

/// Quotes `s` for a POSIX shell unless it only contains characters that need no quoting.
///
/// Single quotes are used, and an embedded `'` becomes `'\''`. A word starting with `=`
/// is always quoted, because zsh (with its default `EQUALS` option) expands a leading
/// `=word` to the path of the command `word`.
fn quote_posix(s: &str) -> String {
    if s.is_empty() {
        return "''".to_string();
    }
    let needs_no_quotes = !s.starts_with('=')
        && s.bytes()
            .all(|b| b.is_ascii_alphanumeric() || b"-_./=:@,+%^".contains(&b));
    if needs_no_quotes {
        return s.to_string();
    }
    format!("'{}'", s.replace('\'', "'\\''"))
}

/// Quotes `s` for PowerShell with single quotes (an embedded `'` is doubled) unless it
/// only contains characters that need no quoting.
fn quote_powershell(s: &str) -> String {
    if s.is_empty() {
        return "''".to_string();
    }
    if s.bytes()
        .all(|b| b.is_ascii_alphanumeric() || b"-_./=:@,+%^".contains(&b))
    {
        return s.to_string();
    }
    format!("'{}'", s.replace('\'', "''"))
}

/// Quotes `s` for cmd.exe with double quotes (an embedded `"` becomes `\"`) unless it
/// only contains characters that need no quoting.
///
/// `^` is deliberately not in the safe set. Outside quotes, cmd.exe treats `^` as its
/// escape character and strips it (`a^b` would arrive as `ab`); inside quotes it is
/// passed through literally.
/// NOTE(review): cmd.exe expands `%VAR%` even inside double quotes, so `%` cannot be
/// neutralised by quoting here.
fn quote_cmd(s: &str) -> String {
    if s.is_empty() {
        return "\"\"".to_string();
    }
    if s.bytes()
        .all(|b| b.is_ascii_alphanumeric() || b"-_./=:@,+%\\".contains(&b))
    {
        return s.to_string();
    }
    format!("\"{}\"", s.replace('"', "\\\""))
}
834
#[cfg(test)]
mod join_command_tests {
    use super::*;

    /// Converts string slices into the owned argument vector `join_command_for` expects.
    fn owned(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|p| (*p).to_owned()).collect()
    }

    #[test]
    fn posix_simple_args() {
        assert_eq!(join_command_for(&owned(&["git", "status"]), "-c"), "git status");
    }

    #[test]
    fn posix_path_with_spaces() {
        let joined = join_command_for(&owned(&["/usr/local/my app/bin", "--help"]), "-c");
        assert_eq!(joined, "'/usr/local/my app/bin' --help");
    }

    #[test]
    fn posix_single_quotes_escaped() {
        let joined = join_command_for(&owned(&["echo", "it's"]), "-c");
        assert_eq!(joined, "echo 'it'\\''s'");
    }

    #[test]
    fn posix_empty_arg() {
        assert_eq!(join_command_for(&owned(&["cmd", ""]), "-c"), "cmd ''");
    }

    #[test]
    fn powershell_simple_args() {
        let joined = join_command_for(&owned(&["npm", "install"]), "-Command");
        assert_eq!(joined, "& npm install");
    }

    #[test]
    fn powershell_path_with_spaces() {
        let joined = join_command_for(
            &owned(&["C:\\Program Files\\nodejs\\npm.cmd", "install"]),
            "-Command",
        );
        assert_eq!(joined, "& 'C:\\Program Files\\nodejs\\npm.cmd' install");
    }

    #[test]
    fn powershell_single_quotes_escaped() {
        let joined = join_command_for(&owned(&["echo", "it's done"]), "-Command");
        assert_eq!(joined, "& echo 'it''s done'");
    }

    #[test]
    fn cmd_simple_args() {
        let joined = join_command_for(&owned(&["npm.cmd", "install"]), "/C");
        assert_eq!(joined, "npm.cmd install");
    }

    #[test]
    fn cmd_path_with_spaces() {
        let joined = join_command_for(
            &owned(&["C:\\Program Files\\nodejs\\npm.cmd", "install"]),
            "/C",
        );
        assert_eq!(joined, "\"C:\\Program Files\\nodejs\\npm.cmd\" install");
    }

    #[test]
    fn cmd_double_quotes_escaped() {
        let joined = join_command_for(&owned(&["echo", "say \"hello\""]), "/C");
        assert_eq!(joined, "echo \"say \\\"hello\\\"\"");
    }

    #[test]
    fn unknown_flag_uses_posix() {
        assert_eq!(join_command_for(&owned(&["ls", "-la"]), "--exec"), "ls -la");
    }
}
920
#[cfg(test)]
mod windows_shell_flag_tests {
    use super::windows_shell_flag_for_exe_basename;

    /// Asserts that every basename in `names` maps to the `expected` flag.
    fn assert_flag(names: &[&str], expected: &str) {
        for name in names {
            assert_eq!(
                windows_shell_flag_for_exe_basename(name),
                expected,
                "basename: {name}"
            );
        }
    }

    #[test]
    fn cmd_uses_slash_c() {
        assert_flag(&["cmd.exe", "cmd"], "/C");
    }

    #[test]
    fn powershell_uses_command() {
        assert_flag(&["powershell.exe", "pwsh.exe"], "-Command");
    }

    #[test]
    fn posix_shells_use_dash_c() {
        assert_flag(&["bash.exe", "bash", "sh.exe", "zsh.exe", "fish.exe"], "-c");
    }
}
949
#[cfg(test)]
mod passthrough_tests {
    // Tests for `is_excluded_command`: which shell commands are classified as
    // passthrough (excluded), with and without user-configured exclusions.
    use super::is_excluded_command;

    /// Asserts that every command in `cmds` is excluded when no user
    /// exclusions are configured.
    ///
    /// The offending command is included in the panic message because, inside
    /// a loop, the default `assert!` text would not say which entry failed.
    fn assert_excluded(cmds: &[&str]) {
        for &cmd in cmds {
            assert!(
                is_excluded_command(cmd, &[]),
                "expected `{cmd}` to be excluded (passthrough)"
            );
        }
    }

    /// Asserts that no command in `cmds` is excluded when no user exclusions
    /// are configured.
    fn assert_not_excluded(cmds: &[&str]) {
        for &cmd in cmds {
            assert!(
                !is_excluded_command(cmd, &[]),
                "expected `{cmd}` NOT to be excluded"
            );
        }
    }

    #[test]
    fn turbo_is_passthrough() {
        assert_excluded(&[
            "turbo run dev",
            "turbo run build",
            "pnpm turbo run dev",
            "npx turbo run dev",
        ]);
    }

    #[test]
    fn dev_servers_are_passthrough() {
        assert_excluded(&[
            "next dev",
            "vite dev",
            "nuxt dev",
            "astro dev",
            "nodemon server.js",
        ]);
    }

    #[test]
    fn interactive_tools_are_passthrough() {
        assert_excluded(&[
            "vim file.rs",
            "nvim",
            "htop",
            "ssh user@host",
            "tail -f /var/log/syslog",
        ]);
    }

    #[test]
    fn docker_streaming_is_passthrough() {
        assert_excluded(&[
            "docker logs my-container",
            "docker logs -f webapp",
            "docker attach my-container",
            "docker exec -it web bash",
            "docker exec -ti web bash",
            "docker run -it ubuntu bash",
            "docker compose exec web bash",
            "docker stats",
            "docker events",
        ]);
    }

    #[test]
    fn kubectl_is_passthrough() {
        assert_excluded(&[
            "kubectl logs my-pod",
            "kubectl logs -f deploy/web",
            "kubectl exec -it pod -- bash",
            "kubectl port-forward svc/web 8080:80",
            "kubectl attach my-pod",
            "kubectl proxy",
        ]);
    }

    #[test]
    fn database_repls_are_passthrough() {
        assert_excluded(&[
            "psql -U user mydb",
            "mysql -u root -p",
            "sqlite3 data.db",
            "redis-cli",
            "mongosh",
        ]);
    }

    #[test]
    fn streaming_tools_are_passthrough() {
        assert_excluded(&[
            "journalctl -f",
            "ping 8.8.8.8",
            "strace -p 1234",
            "tcpdump -i eth0",
            "tail -F /var/log/app.log",
            "tmux new -s work",
            "screen -S dev",
        ]);
    }

    #[test]
    fn additional_dev_servers_are_passthrough() {
        assert_excluded(&[
            "gatsby develop",
            "ng serve --port 4200",
            "remix dev",
            "wrangler dev",
            "hugo server",
            "bun dev",
            "cargo watch -x test",
        ]);
    }

    #[test]
    fn normal_commands_not_excluded() {
        assert_not_excluded(&["git status", "cargo test", "npm run build", "ls -la"]);
    }

    #[test]
    fn user_exclusions_work() {
        // A fixed-size array is enough here: the value is only ever borrowed
        // as a slice, so `vec!` would be a needless heap allocation
        // (clippy::useless_vec).
        let excl = ["myapp".to_string()];
        assert!(is_excluded_command("myapp serve", &excl));
        assert!(!is_excluded_command("git status", &excl));
    }

    // Smoke test only: the result presumably depends on the host environment,
    // so nothing is asserted beyond the call completing without a panic.
    #[test]
    fn is_container_returns_bool() {
        let _ = super::is_container();
    }

    // Smoke test only, for the same reason as `is_container_returns_bool`.
    #[test]
    fn is_non_interactive_returns_bool() {
        let _ = super::is_non_interactive();
    }

    #[test]
    fn auth_commands_excluded() {
        assert_excluded(&[
            "az login --use-device-code",
            "gh auth login",
            "gcloud auth login",
            "aws sso login",
            "firebase login",
            "vercel login",
            "heroku login",
            "az login",
            "kubelogin convert-kubeconfig",
            "vault login -method=oidc",
            "flyctl auth login",
        ]);
    }

    #[test]
    fn auth_exclusion_does_not_affect_normal_commands() {
        // Same CLIs as above, but non-login subcommands must not be caught.
        assert_not_excluded(&[
            "git log",
            "npm run build",
            "cargo test",
            "aws s3 ls",
            "gcloud compute instances list",
            "az vm list",
        ]);
    }
}