Skip to main content

lean_ctx/shell/
platform.rs

1use std::io::{self, IsTerminal};
2
/// Sets `LC_CTYPE=C.UTF-8` on `cmd` when no UTF-8 locale is inherited from the
/// parent process. Without this, commands treat bytes >127 as non-printable
/// (C locale), mangling Cyrillic, CJK, emoji, etc.
///
/// The inherited character-type locale is resolved the POSIX way: the first of
/// `LC_ALL`, `LC_CTYPE` and `LANG` that is set to a *non-empty* value decides.
/// An empty (or non-Unicode) variable counts as unset, so `LC_ALL=` does not
/// hide `LANG=en_US.UTF-8`. A value is treated as UTF-8 when it contains `utf`
/// case-insensitively (`en_US.UTF-8`, `C.utf8`, ...).
///
/// NOTE: a non-empty, non-UTF-8 `LC_ALL` (e.g. `LC_ALL=C`) takes precedence over
/// `LC_CTYPE` in the child, so the override set here has no effect in that case.
pub(crate) fn apply_utf8_locale(cmd: &mut std::process::Command) {
    // First non-empty value in POSIX precedence order.
    let inherited = ["LC_ALL", "LC_CTYPE", "LANG"]
        .iter()
        .filter_map(|&key| std::env::var(key).ok())
        .find(|value| !value.is_empty());

    let has_utf8 = inherited.is_some_and(|v| v.to_ascii_lowercase().contains("utf"));

    if !has_utf8 {
        cmd.env("LC_CTYPE", "C.UTF-8");
    }
}
16
/// Converts raw bytes captured from a child process into a `String`.
///
/// Well-formed UTF-8 is returned as-is. Otherwise the bytes are handed to
/// `decode_windows_output` on Windows, and elsewhere every invalid sequence is
/// replaced with U+FFFD.
pub fn decode_output(bytes: &[u8]) -> String {
    if let Ok(text) = std::str::from_utf8(bytes) {
        return text.to_owned();
    }
    #[cfg(windows)]
    {
        decode_windows_output(bytes)
    }
    #[cfg(not(windows))]
    {
        String::from_utf8_lossy(bytes).into_owned()
    }
}
32
/// Windows fallback for `decode_output`: re-decodes output that is not valid
/// UTF-8 using the system's active ANSI code page (`GetACP`).
///
/// Returns the input unchanged when it is valid UTF-8, and falls back to a lossy
/// UTF-8 decode (invalid sequences become U+FFFD) whenever the code-page
/// conversion cannot be performed.
///
/// NOTE(review): console programs often write in the OEM code page rather than
/// the ANSI one; confirm `GetACP` is the right choice for the commands run here.
#[cfg(windows)]
fn decode_windows_output(bytes: &[u8]) -> String {
    extern "system" {
        fn GetACP() -> u32;
        fn MultiByteToWideChar(
            cp: u32,
            flags: u32,
            src: *const u8,
            srclen: i32,
            dst: *mut u16,
            dstlen: i32,
        ) -> i32;
    }

    // Valid UTF-8 needs no code-page fallback. Checking validity directly (rather
    // than looking for U+FFFD in a lossy decode) also keeps UTF-8 text that
    // legitimately contains U+FFFD intact.
    if let Ok(text) = std::str::from_utf8(bytes) {
        return text.to_owned();
    }
    let lossy = || String::from_utf8_lossy(bytes).into_owned();

    // `MultiByteToWideChar` takes an `int` length. A plain `as i32` cast would
    // wrap for inputs longer than `i32::MAX` and silently decode only a prefix.
    let Ok(src_len) = i32::try_from(bytes.len()) else {
        return lossy();
    };

    // SAFETY: `GetACP` takes no arguments and only returns the active code
    // page; it cannot fail or cause undefined behaviour.
    let codepage = unsafe { GetACP() };
    // SAFETY: called with a null destination and length 0 to measure the
    // required buffer size; `bytes` is a live slice of exactly `src_len` bytes.
    let wide_len = unsafe {
        MultiByteToWideChar(
            codepage,
            0,
            bytes.as_ptr(),
            src_len,
            std::ptr::null_mut(),
            0,
        )
    };
    if wide_len <= 0 {
        return lossy();
    }
    let mut wide: Vec<u16> = vec![0u16; wide_len as usize];
    // SAFETY: `wide` holds exactly `wide_len` u16 slots and `bytes` is a live
    // slice of `src_len` bytes; the source and destination do not overlap.
    let written = unsafe {
        MultiByteToWideChar(
            codepage,
            0,
            bytes.as_ptr(),
            src_len,
            wide.as_mut_ptr(),
            wide_len,
        )
    };
    // A failed or short conversion must not leak the zero-initialised tail of
    // `wide` into the result as NUL characters.
    if written <= 0 {
        return lossy();
    }
    wide.truncate(written as usize);
    String::from_utf16_lossy(&wide)
}
92
/// Switches the console's output code page to UTF-8 (best effort; the result of
/// the Win32 call is deliberately ignored).
#[cfg(windows)]
pub(super) fn set_console_utf8() {
    /// Windows code-page identifier for UTF-8.
    const CP_UTF8: u32 = 65001;

    extern "system" {
        fn SetConsoleOutputCP(id: u32) -> i32;
    }

    // SAFETY: `SetConsoleOutputCP` receives a plain code-page id by value and
    // touches no caller memory, so it cannot cause undefined behaviour.
    let _ = unsafe { SetConsoleOutputCP(CP_UTF8) };
}
104
/// Reports whether this process appears to run inside a Docker/LXC container.
///
/// On Unix it checks, in order: the `/.dockerenv` marker file, Docker/LXC paths
/// in `/proc/1/cgroup`, and Docker container mounts in `/proc/self/mountinfo`.
/// Unreadable files count as "no evidence". Always `false` on non-Unix targets.
pub fn is_container() -> bool {
    #[cfg(unix)]
    {
        /// Reads `path` and reports whether its text mentions any of `needles`.
        fn mentions_any(path: &str, needles: &[&str]) -> bool {
            match std::fs::read_to_string(path) {
                Ok(text) => needles.iter().any(|needle| text.contains(*needle)),
                Err(_) => false,
            }
        }

        std::path::Path::new("/.dockerenv").exists()
            || mentions_any("/proc/1/cgroup", &["/docker/", "/lxc/"])
            || mentions_any("/proc/self/mountinfo", &["/docker/containers/"])
    }
    #[cfg(not(unix))]
    {
        false
    }
}
129
/// Reports whether stdin is something other than a terminal: a pipe, a
/// redirected file, `/dev/null`, and so on.
pub fn is_non_interactive() -> bool {
    let stdin = io::stdin();
    let attached_to_tty = stdin.is_terminal();
    !attached_to_tty
}
134
/// Reports whether `shell_path` names a PowerShell binary.
///
/// Only the final path component is inspected, case-insensitively; it matches
/// when it contains `powershell` or `pwsh` (e.g. `pwsh.exe`, `PowerShell.exe`).
pub(crate) fn is_powershell(shell_path: &str) -> bool {
    let basename = match std::path::Path::new(shell_path)
        .file_name()
        .and_then(|n| n.to_str())
    {
        Some(file_name) => file_name.to_ascii_lowercase(),
        None => String::new(),
    };
    ["powershell", "pwsh"]
        .iter()
        .any(|marker| basename.contains(*marker))
}
144
/// Path to the current-user, current-host PowerShell profile
/// (`$PROFILE.CurrentUserCurrentHost`) under `home`.
///
/// * Windows: `<home>\Documents\PowerShell\Microsoft.PowerShell_profile.ps1`.
///   This is the location read by PowerShell 7+ (`pwsh.exe`). Windows
///   PowerShell 5.1 (`powershell.exe`) reads `Documents\WindowsPowerShell\…`
///   instead, so a profile written here is not loaded when only
///   `powershell.exe` is used. NOTE(review): a Documents folder redirected
///   elsewhere (e.g. by OneDrive) is not handled either; confirm whether that
///   matters for callers.
/// * macOS/Linux: `<home>/.config/powershell/Microsoft.PowerShell_profile.ps1`,
///   where PowerShell (pwsh) looks on those platforms. `~/Documents` is
///   deliberately never touched there: on macOS, stat-ing anything inside it
///   pops a TCC privacy prompt ("lean-ctx would like to access files in your
///   Documents folder", #356).
pub(crate) fn powershell_profile_path(home: &std::path::Path) -> std::path::PathBuf {
    const PROFILE_FILE: &str = "Microsoft.PowerShell_profile.ps1";
    if cfg!(windows) {
        home.join("Documents").join("PowerShell").join(PROFILE_FILE)
    } else {
        home.join(".config").join("powershell").join(PROFILE_FILE)
    }
}
161
/// Windows only: returns the switch that makes a shell binary execute one
/// command string. `exe_basename` must already be ASCII-lowercase (e.g.
/// `bash.exe`, `cmd.exe`).
///
/// PowerShell flavours take `-Command`, `cmd` takes `/C`, and every other shell
/// (bash, sh, zsh, fish, …) is assumed to understand `-c`.
fn windows_shell_flag_for_exe_basename(exe_basename: &str) -> &'static str {
    let powershell_like = ["powershell", "pwsh"]
        .iter()
        .any(|marker| exe_basename.contains(*marker));
    match exe_basename {
        _ if powershell_like => "-Command",
        "cmd.exe" | "cmd" => "/C",
        _ => "-c",
    }
}
173
174pub fn shell_and_flag() -> (String, String) {
175    let shell = detect_shell();
176    let flag = if cfg!(windows) {
177        let name = std::path::Path::new(&shell)
178            .file_name()
179            .and_then(|n| n.to_str())
180            .unwrap_or("")
181            .to_ascii_lowercase();
182        windows_shell_flag_for_exe_basename(&name).to_string()
183    } else {
184        "-c".to_string()
185    };
186    (shell, flag)
187}
188
189/// Returns a short, human-readable shell name (e.g. "bash", "zsh", "powershell", "cmd").
190pub fn shell_name() -> String {
191    let shell = detect_shell();
192    let basename = std::path::Path::new(&shell)
193        .file_name()
194        .and_then(|n| n.to_str())
195        .unwrap_or("sh")
196        .to_ascii_lowercase();
197    basename
198        .strip_suffix(".exe")
199        .unwrap_or(&basename)
200        .to_string()
201}
202
203pub(super) fn detect_shell() -> String {
204    if let Ok(shell) = std::env::var("LEAN_CTX_SHELL") {
205        return shell;
206    }
207
208    if let Ok(shell) = std::env::var("SHELL") {
209        let bin = std::path::Path::new(&shell)
210            .file_name()
211            .and_then(|n| n.to_str())
212            .unwrap_or("sh");
213
214        if bin == "lean-ctx" {
215            return find_real_shell();
216        }
217        return shell;
218    }
219
220    find_real_shell()
221}
222
/// Unix fallback shell: the first of zsh, bash and sh found under `/bin`,
/// defaulting to `/bin/sh` when none of them exists.
#[cfg(unix)]
fn find_real_shell() -> String {
    const CANDIDATES: [&str; 3] = ["/bin/zsh", "/bin/bash", "/bin/sh"];
    CANDIDATES
        .iter()
        .copied()
        .find(|candidate| std::path::Path::new(candidate).exists())
        .unwrap_or("/bin/sh")
        .to_string()
}
232
/// Windows fallback shell.
///
/// Under MSYS/Git Bash it prefers `bash.exe`, then `sh.exe`, as located by
/// `where`. Otherwise, or if neither is found, it uses PowerShell (see
/// `which_powershell`), then `%COMSPEC%`, and finally plain `cmd.exe`.
#[cfg(windows)]
fn find_real_shell() -> String {
    /// Runs `where <candidate>` and returns the trimmed first line of its
    /// output, if the lookup succeeded, the output is UTF-8 and the line is non-blank.
    fn locate(candidate: &str) -> Option<String> {
        let output = std::process::Command::new("where")
            .arg(candidate)
            .output()
            .ok()?;
        if !output.status.success() {
            return None;
        }
        let text = String::from_utf8(output.stdout).ok()?;
        let first = text.lines().next()?.trim();
        if first.is_empty() {
            None
        } else {
            Some(first.to_string())
        }
    }

    if is_running_in_msys_or_gitbash() {
        if let Some(path) = ["bash.exe", "sh.exe"].iter().find_map(|c| locate(c)) {
            return path;
        }
    }
    if let Ok(pwsh) = which_powershell() {
        return pwsh;
    }
    std::env::var("COMSPEC").unwrap_or_else(|_| "cmd.exe".to_string())
}
259
/// Reports whether an MSYS/MinGW environment variable (`MSYSTEM` or
/// `MINGW_PREFIX`) is present, as set by Git Bash and MSYS2 shells.
#[cfg(windows)]
fn is_running_in_msys_or_gitbash() -> bool {
    const MARKERS: [&str; 2] = ["MSYSTEM", "MINGW_PREFIX"];
    MARKERS.iter().any(|var| std::env::var(var).is_ok())
}
264
/// Locates PowerShell via `where`, preferring `pwsh.exe` over
/// `powershell.exe`.
///
/// Returns the trimmed first line of the first successful lookup, or `Err(())`
/// if neither binary is found.
#[cfg(windows)]
fn which_powershell() -> Result<String, ()> {
    for exe in ["pwsh.exe", "powershell.exe"].iter() {
        let found = std::process::Command::new("where")
            .arg(exe)
            .output()
            .ok()
            .filter(|out| out.status.success())
            .and_then(|out| String::from_utf8(out.stdout).ok())
            .and_then(|text| {
                let first_line = text.lines().next()?.trim().to_string();
                Some(first_line)
            })
            .filter(|path| !path.is_empty());
        if let Some(path) = found {
            return Ok(path);
        }
    }
    Err(())
}
283
284/// Join multiple CLI arguments into a single command string, using quoting
285/// conventions appropriate for the detected shell.
286///
287/// On Unix, this always produces POSIX-compatible quoting.
288/// On Windows, the quoting adapts to the actual shell (PowerShell, cmd.exe,
289/// or Git Bash / MSYS).
290pub fn join_command(args: &[String]) -> String {
291    let (_, flag) = shell_and_flag();
292    join_command_for(args, &flag)
293}
294
295pub fn join_command_for(args: &[String], shell_flag: &str) -> String {
296    match shell_flag {
297        "-Command" => join_powershell(args),
298        "/C" => join_cmd(args),
299        _ => join_posix(args),
300    }
301}
302
303fn join_posix(args: &[String]) -> String {
304    args.iter()
305        .map(|a| quote_posix(a))
306        .collect::<Vec<_>>()
307        .join(" ")
308}
309
310fn join_powershell(args: &[String]) -> String {
311    if args.len() == 1 && args[0].contains(' ') {
312        return args[0].clone();
313    }
314    let quoted: Vec<String> = args.iter().map(|a| quote_powershell(a)).collect();
315    format!("& {}", quoted.join(" "))
316}
317
318fn join_cmd(args: &[String]) -> String {
319    args.iter()
320        .map(|a| quote_cmd(a))
321        .collect::<Vec<_>>()
322        .join(" ")
323}
324
/// Quotes `s` as a single word for a POSIX-family shell (sh, bash, zsh, ...).
///
/// Words made only of ASCII alphanumerics and `-_./=:@,+%^` are returned
/// verbatim, unless they start with `=`. Everything else is wrapped in single
/// quotes, with each embedded `'` written as `'\''` (close, escaped quote,
/// reopen). An empty string becomes `''` so it survives as an empty argument.
fn quote_posix(s: &str) -> String {
    if s.is_empty() {
        return "''".to_string();
    }
    // zsh performs "equals expansion" on an unquoted word starting with `=`:
    // `=ls` turns into the path of `ls`, and `=foo` aborts with "foo not found".
    // Quoting the word suppresses that; other shells are unaffected.
    let is_plain = !s.starts_with('=')
        && s.bytes()
            .all(|b| b.is_ascii_alphanumeric() || b"-_./=:@,+%^".contains(&b));
    if is_plain {
        return s.to_string();
    }
    format!("'{}'", s.replace('\'', "'\\''"))
}
336
/// Quotes `s` as a single PowerShell argument.
///
/// Words made only of ASCII alphanumerics and `-_./=:@,+%^` are returned
/// verbatim. Everything else becomes a verbatim (single-quoted) string, with
/// every single-quote character doubled. An empty string becomes `''`.
fn quote_powershell(s: &str) -> String {
    if s.is_empty() {
        return "''".to_string();
    }
    if s.bytes()
        .all(|b| b.is_ascii_alphanumeric() || b"-_./=:@,+%^".contains(&b))
    {
        return s.to_string();
    }
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for c in s.chars() {
        // PowerShell's tokenizer accepts the typographic single quotes
        // U+2018..=U+201B as string delimiters just like `'`. Each must be
        // doubled too, otherwise "don’t" would end the literal early.
        if matches!(c, '\'' | '\u{2018}' | '\u{2019}' | '\u{201A}' | '\u{201B}') {
            quoted.push(c);
        }
        quoted.push(c);
    }
    quoted.push('\'');
    quoted
}
348
/// Quotes `s` as a single argument on a cmd.exe command line.
///
/// Words made only of ASCII alphanumerics and `-_./=:@,+%\` are returned
/// verbatim. Everything else is wrapped in double quotes, with embedded `"`
/// written as `\"`. An empty string becomes `""`.
///
/// `^` is not in the verbatim set: outside double quotes it is cmd.exe's escape
/// character (and a trailing `^` is a line continuation), so `HEAD^` would reach
/// the program as `HEAD`. Inside double quotes it is literal.
///
/// NOTE(review): cmd.exe still expands `%VAR%` even inside double quotes, and
/// it does not itself understand `\"` (that escape is for the target program's
/// argument parser), so this quoting is not a full escape for arbitrary input.
fn quote_cmd(s: &str) -> String {
    if s.is_empty() {
        return "\"\"".to_string();
    }
    if s.bytes()
        .all(|b| b.is_ascii_alphanumeric() || b"-_./=:@,+%\\".contains(&b))
    {
        return s.to_string();
    }
    format!("\"{}\"", s.replace('"', "\\\""))
}
360
#[cfg(test)]
mod join_command_tests {
    use super::*;

    /// Converts `parts` to owned args, joins them for the shell selected by
    /// `flag`, and checks the result against `expected`.
    fn assert_joined(parts: &[&str], flag: &str, expected: &str) {
        let args: Vec<String> = parts.iter().map(|part| part.to_string()).collect();
        assert_eq!(join_command_for(&args, flag), expected);
    }

    #[test]
    fn posix_simple_args() {
        assert_joined(&["git", "status"], "-c", "git status");
    }

    #[test]
    fn posix_path_with_spaces() {
        assert_joined(
            &["/usr/local/my app/bin", "--help"],
            "-c",
            "'/usr/local/my app/bin' --help",
        );
    }

    #[test]
    fn posix_single_quotes_escaped() {
        assert_joined(&["echo", "it's"], "-c", "echo 'it'\\''s'");
    }

    #[test]
    fn posix_empty_arg() {
        assert_joined(&["cmd", ""], "-c", "cmd ''");
    }

    #[test]
    fn powershell_simple_args() {
        assert_joined(&["npm", "install"], "-Command", "& npm install");
    }

    #[test]
    fn powershell_path_with_spaces() {
        assert_joined(
            &["C:\\Program Files\\nodejs\\npm.cmd", "install"],
            "-Command",
            "& 'C:\\Program Files\\nodejs\\npm.cmd' install",
        );
    }

    #[test]
    fn powershell_single_quotes_escaped() {
        assert_joined(&["echo", "it's done"], "-Command", "& echo 'it''s done'");
    }

    #[test]
    fn cmd_simple_args() {
        assert_joined(&["npm.cmd", "install"], "/C", "npm.cmd install");
    }

    #[test]
    fn cmd_path_with_spaces() {
        assert_joined(
            &["C:\\Program Files\\nodejs\\npm.cmd", "install"],
            "/C",
            "\"C:\\Program Files\\nodejs\\npm.cmd\" install",
        );
    }

    #[test]
    fn cmd_double_quotes_escaped() {
        assert_joined(&["echo", "say \"hello\""], "/C", "echo \"say \\\"hello\\\"\"");
    }

    #[test]
    fn unknown_flag_uses_posix() {
        assert_joined(&["ls", "-la"], "--exec", "ls -la");
    }

    #[test]
    fn powershell_single_full_command_not_quoted() {
        let full = "git commit -m \"feat: add feature\"";
        let joined = join_command_for(&[full.to_string()], "-Command");
        assert_eq!(joined, full);
        assert!(
            !joined.starts_with("& '"),
            "must not wrap full command in & '...'"
        );
    }

    #[test]
    fn powershell_single_no_spaces_still_uses_call_operator() {
        assert_joined(&["git"], "-Command", "& git");
    }
}
463
#[cfg(test)]
mod is_powershell_tests {
    use super::is_powershell;

    #[test]
    fn detects_pwsh_exe() {
        let detected = is_powershell("pwsh.exe");
        assert!(detected);
    }

    #[test]
    fn detects_powershell_exe() {
        let detected = is_powershell("powershell.exe");
        assert!(detected);
    }

    #[test]
    fn rejects_cmd() {
        let detected = is_powershell("cmd.exe");
        assert!(!detected);
    }

    #[test]
    fn rejects_bash() {
        let detected = is_powershell("/usr/bin/bash");
        assert!(!detected);
    }

    #[test]
    fn case_insensitive() {
        for name in ["PWSH.EXE", "PowerShell.exe"] {
            assert!(is_powershell(name), "{name} should be detected");
        }
    }

    #[test]
    fn full_path_with_pwsh() {
        let paths = [
            "C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe",
            "/usr/local/bin/pwsh",
        ];
        for path in paths {
            assert!(is_powershell(path), "{path} should be detected");
        }
    }
}
502
#[cfg(test)]
mod powershell_profile_tests {
    use super::powershell_profile_path;
    use std::path::{Path, PathBuf};

    const PROFILE: &str = "Microsoft.PowerShell_profile.ps1";

    #[test]
    fn always_ends_with_profile_file() {
        let profile: PathBuf = powershell_profile_path(Path::new("/home/u"));
        assert!(profile.ends_with(PROFILE));
    }

    #[cfg(not(windows))]
    #[test]
    fn non_windows_uses_config_powershell_never_documents() {
        // #356: touching anything under ~/Documents triggers a macOS TCC prompt,
        // so off Windows the profile has to live under ~/.config/powershell.
        let profile = powershell_profile_path(Path::new("/Users/jane"));
        let expected: PathBuf = ["/Users/jane", ".config", "powershell", PROFILE]
            .iter()
            .collect();
        assert_eq!(profile, expected);
        assert!(
            !profile.to_string_lossy().contains("Documents"),
            "macOS/Linux PowerShell profile must never touch ~/Documents (#356)"
        );
    }

    #[cfg(windows)]
    #[test]
    fn windows_uses_documents_powershell() {
        let profile = powershell_profile_path(Path::new("C:\\Users\\jane"));
        let tail = Path::new("Documents").join("PowerShell").join(PROFILE);
        assert!(profile.ends_with(tail));
    }
}
537
#[cfg(test)]
mod windows_shell_flag_tests {
    use super::windows_shell_flag_for_exe_basename as flag_for;

    /// Asserts that every basename in `basenames` maps to `expected`.
    fn assert_all(basenames: &[&str], expected: &str) {
        for name in basenames {
            assert_eq!(flag_for(name), expected, "wrong flag for {name}");
        }
    }

    #[test]
    fn cmd_uses_slash_c() {
        assert_all(&["cmd.exe", "cmd"], "/C");
    }

    #[test]
    fn powershell_uses_command() {
        assert_all(&["powershell.exe", "pwsh.exe"], "-Command");
    }

    #[test]
    fn posix_shells_use_dash_c() {
        assert_all(&["bash.exe", "bash", "sh.exe", "zsh.exe", "fish.exe"], "-c");
    }
}
566
#[cfg(test)]
mod platform_tests {
    // Smoke test: probing the host must not panic (the value depends on the host).
    #[test]
    fn is_container_returns_bool() {
        let _ = super::is_container();
    }

    // Smoke test: the stdin probe must not panic (the value depends on the runner).
    #[test]
    fn is_non_interactive_returns_bool() {
        let _ = super::is_non_interactive();
    }

    #[test]
    fn join_command_preserves_structure() {
        let args = vec![
            "git".to_string(),
            "commit".to_string(),
            "-m".to_string(),
            "my message".to_string(),
        ];
        let joined = super::join_command(&args);
        assert!(joined.contains("git"));
        assert!(joined.contains("commit"));
        // Whatever shell is detected (POSIX / PowerShell single quotes, cmd double
        // quotes), the space-containing message must stay ONE quoted word.
        assert!(
            joined.contains("'my message'") || joined.contains("\"my message\""),
            "message containing a space must be quoted as a single argument: {joined}"
        );
    }

    #[test]
    fn quote_posix_handles_em_dash() {
        let result = super::quote_posix("closing — see #407");
        assert!(
            result.starts_with('\''),
            "em-dash args must be single-quoted: {result}"
        );
    }

    #[test]
    fn quote_posix_handles_nested_single_quotes() {
        let result = super::quote_posix("it's a test");
        assert!(
            result.contains("\\'"),
            "single quotes must be escaped: {result}"
        );
    }

    #[test]
    fn quote_posix_safe_chars_unquoted() {
        let result = super::quote_posix("simple_word");
        assert_eq!(result, "simple_word");
    }

    #[test]
    fn quote_posix_empty_string() {
        let result = super::quote_posix("");
        assert_eq!(result, "''");
    }

    #[test]
    fn quote_posix_dollar_expansion_protected() {
        let result = super::quote_posix("$HOME/test");
        assert!(
            result.starts_with('\''),
            "dollar signs must be single-quoted: {result}"
        );
    }

    #[test]
    fn quote_posix_backtick_protected() {
        let result = super::quote_posix("echo `date`");
        assert!(
            result.starts_with('\''),
            "backticks must be single-quoted: {result}"
        );
    }

    #[test]
    fn quote_posix_double_quotes_protected() {
        let result = super::quote_posix(r#"he said "hello""#);
        assert!(
            result.starts_with('\''),
            "double quotes must be single-quoted: {result}"
        );
    }
}