use std::process::{Command, Stdio};
use std::time::Duration;
use wait_timeout::ChildExt;
use super::Shell;
/// Returns the final component of `path`, with a trailing `.exe` suffix
/// (matched ASCII case-insensitively) removed.
///
/// Returns `None` when `path` has no final component (for example `""`, `"/"`,
/// or a path ending in `..`) or when that component is not valid UTF-8.
/// A component that is exactly `.exe` (or shorter) is returned unchanged,
/// because stripping the suffix would leave an empty name.
pub fn extract_filename_from_path(path: &str) -> Option<&str> {
    const EXE_SUFFIX: &str = ".exe";
    let filename = std::path::Path::new(path).file_name()?.to_str()?;
    let split = filename.len().saturating_sub(EXE_SUFFIX.len());
    // `str::get` yields `None` (rather than panicking like `filename[split..]`)
    // when `split` lands inside a multi-byte character, e.g. in "a€€". Since
    // the suffix is pure ASCII, such a name cannot end in ".exe" anyway.
    match filename.get(split..) {
        Some(suffix) if split > 0 && suffix.eq_ignore_ascii_case(EXE_SUFFIX) => {
            // `get` succeeded, so `split` is a char boundary and this cannot panic.
            Some(&filename[..split])
        }
        _ => Some(filename),
    }
}
/// Maps a shell executable name to a [`Shell`].
///
/// The name is first parsed exactly through `Shell`'s `FromStr`
/// implementation. If that fails, the lowercased name is matched by prefix,
/// so versioned or suffixed binaries such as `zsh-5.9`, `bash5` or
/// `pwsh-preview` are still recognised. Prefixes are checked in the order
/// zsh, bash, fish, nu, pwsh/powershell. Returns `None` for anything else
/// (e.g. `tcsh`).
pub fn shell_from_name(shell_name: &str) -> Option<Shell> {
    shell_name.parse::<Shell>().ok().or_else(|| {
        let lowered = shell_name.to_lowercase();
        let has_prefix = |prefix: &str| lowered.starts_with(prefix);

        if has_prefix("zsh") {
            Some(Shell::Zsh)
        } else if has_prefix("bash") {
            Some(Shell::Bash)
        } else if has_prefix("fish") {
            Some(Shell::Fish)
        } else if has_prefix("nu") {
            Some(Shell::Nushell)
        } else if has_prefix("pwsh") || has_prefix("powershell") {
            Some(Shell::PowerShell)
        } else {
            None
        }
    })
}
pub fn current_shell() -> Option<Shell> {
if let Ok(shell_path) = std::env::var("SHELL")
&& let Some(name) = extract_filename_from_path(&shell_path)
{
return shell_from_name(name);
}
if std::env::var_os("PSModulePath").is_some() {
return Some(Shell::PowerShell);
}
None
}
/// Probes whether zsh's completion system is initialised in the user's
/// interactive configuration.
///
/// Runs `zsh -ic` with a one-line probe that checks whether the `compdef`
/// function is defined, then looks for a marker on the probe's stdout.
///
/// Returns:
/// - `Some(true)` if zsh exited within the timeout and printed `__WT_COMPINIT_YES__`;
/// - `Some(false)` if zsh exited within the timeout without printing that marker;
/// - `None` if zsh could not be spawned, did not exit within the timeout,
///   waiting on it failed, or its stdout could not be collected in time.
///
/// The 2-second budget bounds the whole probe, including reading stdout.
/// Stdout is drained on a background thread, so chatty rc files cannot fill
/// the pipe and stall zsh. A process that outlives zsh while holding its
/// stdout cannot block this function either. In that last case the reader
/// thread is left blocked in the background.
///
/// The `WORKTRUNK_TEST_COMPINIT_CONFIGURED` / `WORKTRUNK_TEST_COMPINIT_MISSING`
/// environment variables short-circuit the probe to `Some(true)` / `Some(false)`.
pub fn detect_zsh_compinit() -> Option<bool> {
    /// Best-effort kill followed by a wait, so the child is not left running
    /// or unreaped. Errors are ignored.
    fn kill_and_reap(child: &mut std::process::Child) {
        let _ = child.kill();
        let _ = child.wait();
    }

    /// Takes the child's stdout and reads it to EOF on a background thread,
    /// delivering the bytes (or the read error) through the returned channel.
    /// Returns `None` if stdout was not piped or the thread could not start.
    fn spawn_stdout_reader(
        child: &mut std::process::Child,
    ) -> Option<std::sync::mpsc::Receiver<std::io::Result<Vec<u8>>>> {
        use std::io::Read;
        let mut stdout = child.stdout.take()?;
        let (tx, rx) = std::sync::mpsc::channel();
        std::thread::Builder::new()
            .name("zsh-compinit-probe".into())
            .spawn(move || {
                let mut buf = Vec::new();
                let result = stdout.read_to_end(&mut buf).map(|_| buf);
                // The receiver may already have given up; nothing to do then.
                let _ = tx.send(result);
            })
            .ok()?;
        Some(rx)
    }

    if std::env::var("WORKTRUNK_TEST_COMPINIT_CONFIGURED").is_ok() {
        return Some(true);
    }
    if std::env::var("WORKTRUNK_TEST_COMPINIT_MISSING").is_ok() {
        return Some(false);
    }

    let probe_cmd =
        r#"(( $+functions[compdef] )) && echo __WT_COMPINIT_YES__ || echo __WT_COMPINIT_NO__"#;
    log::debug!("$ zsh -ic '{}' (probe)", probe_cmd);

    let mut child = Command::new("zsh")
        .arg("-ic")
        .arg(probe_cmd)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::null())
        // NOTE(review): presumably suppresses compinit's insecure-directory
        // prompt so the probe never waits for input — confirm.
        .env("ZSH_DISABLE_COMPFIX", "true")
        .env_remove(crate::shell_exec::DIRECTIVE_FILE_ENV_VAR)
        .spawn()
        .ok()?;

    let timeout = Duration::from_secs(2);
    let started = std::time::Instant::now();

    let Some(output_rx) = spawn_stdout_reader(&mut child) else {
        kill_and_reap(&mut child);
        return None;
    };

    match child.wait_timeout(timeout) {
        Ok(Some(_status)) => {}
        // Timed out, or waiting failed: either way make sure zsh does not linger.
        Ok(None) | Err(_) => {
            kill_and_reap(&mut child);
            return None;
        }
    }

    // zsh has exited, so EOF normally arrives immediately. Use whatever is left
    // of the budget, with a small floor so a probe that finished right at the
    // deadline is not discarded.
    let read_budget = timeout
        .saturating_sub(started.elapsed())
        .max(Duration::from_millis(100));
    let buf = output_rx.recv_timeout(read_budget).ok()?.ok()?;
    Some(String::from_utf8_lossy(&buf).contains("__WT_COMPINIT_YES__"))
}
// Unit tests for the shell-name helpers in this file. Cases are rstest tables;
// each `case::name` becomes its own test, named after the function and the case.
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
// Platform-independent inputs: bare names, `.exe` suffix in several letter
// cases, and the empty path (which has no file name, hence `None`).
#[rstest]
#[case::just_name("bash", Some("bash"))]
#[case::just_name_exe("bash.exe", Some("bash"))]
#[case::mixed_case_exe_title("bash.Exe", Some("bash"))]
#[case::mixed_case_exe_upper("bash.EXE", Some("bash"))]
#[case::mixed_case_exe_camel("bash.eXe", Some("bash"))]
#[case::empty("", None)]
fn test_extract_filename_from_path_common(#[case] path: &str, #[case] expected: Option<&str>) {
assert_eq!(extract_filename_from_path(path), expected);
}
// Unix-style absolute paths; the Nix case checks that a versioned name such as
// `zsh-5.9` is returned as-is (version stripping is `shell_from_name`'s job).
#[cfg(unix)]
#[rstest]
#[case::unix_bash("/usr/bin/bash", Some("bash"))]
#[case::unix_zsh("/bin/zsh", Some("zsh"))]
#[case::unix_fish("/usr/local/bin/fish", Some("fish"))]
#[case::nix_versioned("/nix/store/abc123/zsh-5.9", Some("zsh-5.9"))]
fn test_extract_filename_from_path_unix(#[case] path: &str, #[case] expected: Option<&str>) {
assert_eq!(extract_filename_from_path(path), expected);
}
// Windows paths use `\` separators, which `std::path::Path` only splits on
// Windows, so these cases are compiled there only. The last case checks that
// the stem's original casing is preserved while `.EXE` is stripped.
#[cfg(windows)]
#[rstest]
#[case::windows_git_bash(r"C:\Program Files\Git\usr\bin\bash.exe", Some("bash"))]
#[case::windows_powershell(
r"C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe",
Some("powershell")
)]
#[case::windows_pwsh(r"C:\Program Files\PowerShell\7\pwsh.exe", Some("pwsh"))]
#[case::windows_zsh(r"C:\msys64\usr\bin\zsh.exe", Some("zsh"))]
#[case::uppercase_exe(r"C:\WINDOWS\SYSTEM32\BASH.EXE", Some("BASH"))]
fn test_extract_filename_from_path_windows(#[case] path: &str, #[case] expected: Option<&str>) {
assert_eq!(extract_filename_from_path(path), expected);
}
// End-to-end Windows check (named after issue 348): a full `SHELL`-style path
// goes through `extract_filename_from_path` then `shell_from_name`.
#[cfg(windows)]
#[rstest]
#[case::git_bash(r"C:\Program Files\Git\usr\bin\bash.exe", Shell::Bash)]
#[case::msys2_zsh(r"C:\msys64\usr\bin\zsh.exe", Shell::Zsh)]
#[case::powershell(
r"C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe",
Shell::PowerShell
)]
#[case::pwsh(r"C:\Program Files\PowerShell\7\pwsh.exe", Shell::PowerShell)]
fn test_issue_348_windows_shell_detection(#[case] shell_path: &str, #[case] expected: Shell) {
let shell_name = extract_filename_from_path(shell_path)
.expect("should extract filename from Windows path");
let detected =
shell_from_name(shell_name).expect("should detect shell from extracted name");
assert_eq!(detected, expected);
}
// Name-to-shell mapping: exact names, versioned/suffixed names that rely on
// the prefix fallback, and unsupported shells that must map to `None`.
#[rstest]
#[case::bash("bash", Some(Shell::Bash))]
#[case::bash_versioned("bash5", Some(Shell::Bash))]
#[case::zsh("zsh", Some(Shell::Zsh))]
#[case::zsh_versioned("zsh-5.9", Some(Shell::Zsh))]
#[case::fish("fish", Some(Shell::Fish))]
#[case::nu("nu", Some(Shell::Nushell))]
#[case::nushell("nushell", Some(Shell::Nushell))]
#[case::powershell("powershell", Some(Shell::PowerShell))]
#[case::pwsh("pwsh", Some(Shell::PowerShell))]
#[case::pwsh_preview("pwsh-preview", Some(Shell::PowerShell))]
#[case::unknown("tcsh", None)]
#[case::unknown_csh("csh", None)]
fn test_shell_from_name(#[case] name: &str, #[case] expected: Option<Shell>) {
assert_eq!(shell_from_name(name), expected);
}
}