1use shlex;
2
/// A zsh login shell paired with the rc file that is sourced before running
/// commands inside it.
#[derive(Debug, Eq, PartialEq)]
pub struct ZshShell {
    /// Path to the zsh executable (for example `/bin/zsh`).
    shell_path: String,
    /// Path to the `.zshrc` file sourced ahead of each wrapped command.
    zshrc_path: String,
}
8
9#[derive(Debug, PartialEq, Eq)]
10pub enum Shell {
11 Zsh(ZshShell),
12 Unknown,
13}
14
15impl Shell {
16 pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
17 match self {
18 Shell::Zsh(zsh) => {
19 if !std::path::Path::new(&zsh.zshrc_path).exists() {
20 return None;
21 }
22
23 let mut result = vec![zsh.shell_path.clone()];
24 result.push("-lc".to_string());
25
26 let joined = strip_bash_lc(&command)
27 .or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok());
28
29 if let Some(joined) = joined {
30 result.push(format!("source {} && ({joined})", zsh.zshrc_path));
31 } else {
32 return None;
33 }
34 Some(result)
35 }
36 Shell::Unknown => None,
37 }
38 }
39}
40
/// Returns `<script>` unchanged when `command` is exactly `["bash", "-lc", <script>]`.
///
/// Such a command already carries a shell script string. That string can be embedded directly
/// into another shell's `-lc` argument without adding a second layer of quoting.
///
/// Any other shape yields `None`: a different program, a different flag, or too few or too
/// many arguments.
fn strip_bash_lc(command: &[String]) -> Option<String> {
    match command {
        [program, flag, script] if program == "bash" && flag == "-lc" => Some(script.clone()),
        _ => None,
    }
}
53
/// Looks up the current user's login shell through Directory Services (`dscl`).
///
/// Runs `dscl . -read /Users/<user> UserShell` and scans the output for the first
/// `UserShell: <path>` line whose path ends in `/zsh`. On a match it returns [`Shell::Zsh`]
/// with `/Users/<user>/.zshrc` as the rc file.
///
/// Returns [`Shell::Unknown`] when `dscl` cannot be run, exits unsuccessfully, or reports
/// no zsh shell.
#[cfg(target_os = "macos")]
pub async fn default_user_shell() -> Shell {
    use tokio::process::Command;

    let user = whoami::username();
    // NOTE(review): this dscl record path is also assumed to be the home directory; users
    // with a non-default home (or a ZDOTDIR) would get a different rc path — confirm.
    let home = format!("/Users/{user}");

    let Ok(output) = Command::new("dscl")
        .args([".", "-read", &home, "UserShell"])
        .output()
        .await
    else {
        return Shell::Unknown;
    };
    if !output.status.success() {
        return Shell::Unknown;
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    let zsh_path = stdout.lines().find_map(|line| {
        line.strip_prefix("UserShell: ")
            .filter(|path| path.ends_with("/zsh"))
    });

    match zsh_path {
        Some(shell_path) => Shell::Zsh(ZshShell {
            shell_path: shell_path.to_string(),
            zshrc_path: format!("{home}/.zshrc"),
        }),
        None => Shell::Unknown,
    }
}
88
89#[cfg(not(target_os = "macos"))]
90pub async fn default_user_shell() -> Shell {
91 Shell::Unknown
92}
93
#[cfg(test)]
#[cfg(target_os = "macos")]
mod tests {
    use super::*;

    use std::collections::HashMap;
    use std::path::PathBuf;

    use crate::exec::ExecParams;
    use crate::exec::SandboxType;
    use crate::exec::process_exec_tool_call;
    use crate::protocol::SandboxPolicy;

    /// When `$SHELL` is zsh, `default_user_shell` must report the same binary and a
    /// `.zshrc` under `$HOME`. Passes trivially for any other shell.
    #[tokio::test]
    async fn test_current_shell_detects_zsh() {
        // An unset $SHELL is treated like a non-zsh shell.
        let shell_path = std::env::var("SHELL").unwrap_or_default();
        if !shell_path.ends_with("/zsh") {
            return;
        }

        let home = std::env::var("HOME").unwrap();
        assert_eq!(
            default_user_shell().await,
            Shell::Zsh(ZshShell {
                shell_path,
                zshrc_path: format!("{home}/.zshrc"),
            })
        );
    }

    /// A missing rc file means no wrapped invocation is produced.
    #[test]
    fn test_run_with_profile_zshrc_not_exists() {
        let shell = Shell::Zsh(ZshShell {
            shell_path: "/bin/zsh".to_string(),
            zshrc_path: "/does/not/exist/.zshrc".to_string(),
        });
        let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
        assert_eq!(actual_cmd, None);
    }

    /// End-to-end check of quoting and execution.
    ///
    /// For each case, the generated argv must match the expected shape, and running it under
    /// a temporary `$HOME` with a custom `.zshrc` must print the expected stdout.
    #[tokio::test]
    async fn test_run_with_profile_escaping_and_execution() {
        let shell_path = "/bin/zsh";

        // (input argv, expected argv with a ZSHRC_PATH placeholder, expected stdout)
        let cases = [
            (
                // A function defined in .zshrc is visible to the wrapped command.
                vec!["myecho"],
                vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
                Some("It works!\n"),
            ),
            (
                // Arbitrary argv is shell-quoted before being embedded.
                vec!["bash", "-c", "echo 'single' \"double\""],
                vec![
                    shell_path,
                    "-lc",
                    "source ZSHRC_PATH && (bash -c \"echo 'single' \\\"double\\\"\")",
                ],
                Some("single double\n"),
            ),
            (
                // `bash -lc <script>` is unwrapped and the script embedded verbatim.
                vec!["bash", "-lc", "echo 'single' \"double\""],
                vec![
                    shell_path,
                    "-lc",
                    "source ZSHRC_PATH && (echo 'single' \"double\")",
                ],
                Some("single double\n"),
            ),
        ];

        for (input, expected_cmd, expected_output) in cases {
            let temp_home = tempfile::tempdir().unwrap();
            let zshrc_path = temp_home.path().join(".zshrc");
            std::fs::write(
                &zshrc_path,
                r#"
            set -x
            function myecho {
                echo 'It works!'
            }
            "#,
            )
            .unwrap();
            let zshrc_str = zshrc_path.to_str().unwrap();
            let shell = Shell::Zsh(ZshShell {
                shell_path: shell_path.to_string(),
                zshrc_path: zshrc_str.to_string(),
            });

            let actual_cmd = shell
                .format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
            let expected_cmd: Vec<String> = expected_cmd
                .iter()
                .map(|s| s.replace("ZSHRC_PATH", zshrc_str))
                .collect();

            assert_eq!(actual_cmd, Some(expected_cmd));
            let output = process_exec_tool_call(
                ExecParams {
                    command: actual_cmd.unwrap(),
                    cwd: PathBuf::from(temp_home.path()),
                    timeout_ms: None,
                    env: HashMap::from([(
                        "HOME".to_string(),
                        temp_home.path().to_str().unwrap().to_string(),
                    )]),
                    with_escalated_permissions: None,
                    justification: None,
                },
                SandboxType::None,
                &SandboxPolicy::DangerFullAccess,
                &None,
                None,
            )
            .await
            .unwrap();

            assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
            if let Some(expected) = expected_output {
                assert_eq!(
                    output.stdout.text, expected,
                    "input: {input:?} output: {output:?}"
                );
            }
        }
    }
}