1use std::fmt;
2use std::str::FromStr;
3
/// A shell for which an integration script can be generated.
///
/// Values are obtained by parsing a shell name (see the `FromStr`
/// implementation) and are consumed by `export_script`.
///
/// The type is a plain `Copy` tag; it also implements `Hash` so it can be used
/// directly as a key in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Shell {
    /// Selected by the name `bash`.
    Bash,
    /// Selected by the name `pwsh`.
    Pwsh,
    /// Selected by the name `clink`.
    Clink,
    /// Selected by the name `nu`.
    Nu,
}
11
12impl FromStr for Shell {
13 type Err = ShellParseError;
14
15 fn from_str(s: &str) -> Result<Self, Self::Err> {
16 match s.to_ascii_lowercase().as_str() {
17 "bash" => Ok(Shell::Bash),
18 "pwsh" => Ok(Shell::Pwsh),
19 "clink" => Ok(Shell::Clink),
20 "nu" => Ok(Shell::Nu),
21 _ => Err(ShellParseError(s.to_string())),
22 }
23 }
24}
25
/// Error produced when a string does not name a supported shell.
///
/// The public field holds the rejected input; its `Display` output quotes that
/// input and lists the accepted names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShellParseError(pub String);

impl fmt::Display for ShellParseError {
    /// Writes the rejected input followed by the list of accepted shell names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ShellParseError(rejected) = self;
        f.write_fmt(format_args!(
            "unknown shell '{}' (expected: bash, pwsh, clink, nu)",
            rejected
        ))
    }
}

// No underlying cause to expose, so the default `Error` methods are sufficient.
impl std::error::Error for ShellParseError {}
40
41pub fn export_script(shell: Shell, bin: &str) -> String {
45 let template = match shell {
46 Shell::Bash => include_str!("templates/bash.sh"),
47 Shell::Pwsh => include_str!("templates/pwsh.ps1"),
48 Shell::Clink => include_str!("templates/clink.lua"),
49 Shell::Nu => include_str!("templates/nu.nu"),
50 };
51 template.replace("\r\n", "\n").replace("{BIN}", bin)
52}
53
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, for checks that must hold for all shells.
    const EVERY_SHELL: [Shell; 4] = [Shell::Bash, Shell::Pwsh, Shell::Clink, Shell::Nu];

    /// Parses `name`, panicking if it is rejected.
    fn parsed(name: &str) -> Shell {
        name.parse::<Shell>().unwrap()
    }

    #[test]
    fn parse_bash() {
        assert_eq!(parsed("bash"), Shell::Bash);
    }

    #[test]
    fn parse_case_insensitive() {
        let cases = [
            ("PWSH", Shell::Pwsh),
            ("Clink", Shell::Clink),
            ("Nu", Shell::Nu),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(parsed(input), expected);
        }
    }

    #[test]
    fn parse_unknown_errors() {
        // The error must carry the rejected input unchanged.
        let ShellParseError(rejected) = "fish".parse::<Shell>().unwrap_err();
        assert_eq!(rejected, "fish");
    }

    #[test]
    fn export_script_contains_bin() {
        for &shell in EVERY_SHELL.iter() {
            let rendered = export_script(shell, "my-runex");
            assert!(
                rendered.contains("my-runex"),
                "{shell:?} script must contain the bin name"
            );
        }
    }

    #[test]
    fn bash_script_has_bind() {
        let rendered = export_script(Shell::Bash, "runex");
        assert!(rendered.contains("bind"), "bash script must use bind");
        assert!(
            rendered.contains("READLINE_LINE"),
            "bash script must use READLINE_LINE"
        );
    }

    #[test]
    fn pwsh_script_has_psreadline() {
        let rendered = export_script(Shell::Pwsh, "runex");
        assert!(
            rendered.contains("Set-PSReadLineKeyHandler"),
            "pwsh script must use PSReadLine"
        );
    }

    #[test]
    fn clink_script_has_clink() {
        let rendered = export_script(Shell::Clink, "runex");
        assert!(rendered.contains("clink"), "clink script must reference clink");
    }

    #[test]
    fn nu_script_has_keybindings() {
        let rendered = export_script(Shell::Nu, "runex");
        assert!(
            rendered.contains("keybindings"),
            "nu script must reference keybindings"
        );
    }
}