// agcodex_core/shell.rs

1use shlex;
2
/// A detected zsh installation: the binary to launch and the rc file to source first.
#[derive(PartialEq, Eq, Debug)]
pub struct ZshShell {
    /// Path to the zsh executable (for example `/bin/zsh`).
    shell_path: String,
    /// Path to the `.zshrc` that is sourced before the wrapped command runs.
    zshrc_path: String,
}

/// The user's default shell, to the extent it could be determined.
#[derive(PartialEq, Eq, Debug)]
pub enum Shell {
    /// The user's shell is zsh.
    Zsh(ZshShell),
    /// Detection failed, or the shell is not one this module knows how to wrap.
    Unknown,
}
14
15impl Shell {
16    pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
17        match self {
18            Shell::Zsh(zsh) => {
19                if !std::path::Path::new(&zsh.zshrc_path).exists() {
20                    return None;
21                }
22
23                let mut result = vec![zsh.shell_path.clone()];
24                result.push("-lc".to_string());
25
26                let joined = strip_bash_lc(&command)
27                    .or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok());
28
29                if let Some(joined) = joined {
30                    result.push(format!("source {} && ({joined})", zsh.zshrc_path));
31                } else {
32                    return None;
33                }
34                Some(result)
35            }
36            Shell::Unknown => None,
37        }
38    }
39}
40
/// Returns `script` when `command` is exactly `["bash", "-lc", script]`, else `None`.
///
/// `Shell::format_default_shell_invocation` uses this to embed an existing `bash -lc` script
/// directly, rather than re-quoting the whole `bash` invocation. Any other program, flag, or
/// argument count yields `None`.
fn strip_bash_lc(command: &[String]) -> Option<String> {
    match command {
        [program, flag, script] if program == "bash" && flag == "-lc" => Some(script.clone()),
        _ => None,
    }
}
53
/// Detects the current user's login shell on macOS.
///
/// Runs `dscl . -read /Users/<user> UserShell` and looks for a `UserShell: <path>` line whose
/// path ends in `/zsh`. Only zsh is recognised. Any of the following yields [`Shell::Unknown`]:
/// - another shell;
/// - a failure to spawn `dscl`;
/// - a non-zero exit status.
///
/// The `.zshrc` location is `$HOME/.zshrc`, because zsh itself reads `.zshrc` from `$HOME`
/// when `ZDOTDIR` is unset. If `HOME` is unset or empty, it falls back to
/// `/Users/<user>/.zshrc`.
#[cfg(target_os = "macos")]
pub async fn default_user_shell() -> Shell {
    use tokio::process::Command;

    let user = whoami::username();
    // `/Users/<user>` is the Directory Services record path `dscl` expects. It is not
    // necessarily the user's home directory, so it serves only as a fallback for that.
    let record_path = format!("/Users/{user}");

    let Ok(output) = Command::new("dscl")
        .args([".", "-read", &record_path, "UserShell"])
        .output()
        .await
    else {
        return Shell::Unknown;
    };
    if !output.status.success() {
        return Shell::Unknown;
    }

    let home = std::env::var("HOME")
        .ok()
        .filter(|h| !h.is_empty())
        .unwrap_or(record_path);

    String::from_utf8_lossy(&output.stdout)
        .lines()
        .filter_map(|line| line.strip_prefix("UserShell: "))
        .find(|shell_path| shell_path.ends_with("/zsh"))
        .map_or(Shell::Unknown, |shell_path| {
            Shell::Zsh(ZshShell {
                shell_path: shell_path.to_string(),
                zshrc_path: format!("{home}/.zshrc"),
            })
        })
}
88
89#[cfg(not(target_os = "macos"))]
90pub async fn default_user_shell() -> Shell {
91    Shell::Unknown
92}
93
#[cfg(test)]
#[cfg(target_os = "macos")]
mod tests {
    use super::*;
    use std::process::Command;

    /// If `$SHELL` is zsh, `default_user_shell` should report it together with
    /// `$HOME/.zshrc`. On machines with any other shell this test checks nothing.
    #[tokio::test]
    async fn test_current_shell_detects_zsh() {
        let shell = Command::new("sh")
            .arg("-c")
            .arg("echo $SHELL")
            .output()
            .unwrap();

        let home = std::env::var("HOME").unwrap();
        let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string();
        if shell_path.ends_with("/zsh") {
            assert_eq!(
                default_user_shell().await,
                Shell::Zsh(ZshShell {
                    shell_path,
                    zshrc_path: format!("{home}/.zshrc"),
                })
            );
        }
    }

    /// Without an existing `.zshrc` no wrapped invocation is produced.
    #[test]
    fn test_run_with_profile_zshrc_not_exists() {
        let shell = Shell::Zsh(ZshShell {
            shell_path: "/bin/zsh".to_string(),
            zshrc_path: "/does/not/exist/.zshrc".to_string(),
        });
        let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
        assert_eq!(actual_cmd, None);
    }

    /// Checks both the exact argv produced for several inputs and that running it through
    /// zsh (with a temporary `.zshrc`) yields the expected output.
    #[tokio::test]
    async fn test_run_with_profile_escaping_and_execution() {
        use std::collections::HashMap;
        use std::path::PathBuf;

        use crate::exec::ExecParams;
        use crate::exec::SandboxType;
        use crate::exec::process_exec_tool_call;
        use crate::protocol::SandboxPolicy;

        let shell_path = "/bin/zsh";

        let cases = [
            // A function defined in `.zshrc` must be callable.
            (
                vec!["myecho"],
                vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
                Some("It works!\n"),
            ),
            // Generic argv is shell-quoted as a whole.
            (
                vec!["bash", "-c", "echo 'single' \"double\""],
                vec![
                    shell_path,
                    "-lc",
                    "source ZSHRC_PATH && (bash -c \"echo 'single' \\\"double\\\"\")",
                ],
                Some("single double\n"),
            ),
            // `bash -lc <script>` has its script embedded verbatim.
            (
                vec!["bash", "-lc", "echo 'single' \"double\""],
                vec![
                    shell_path,
                    "-lc",
                    "source ZSHRC_PATH && (echo 'single' \"double\")",
                ],
                Some("single double\n"),
            ),
        ];
        for (input, expected_cmd, expected_output) in cases {
            // Create a temp directory with a zshrc file in it.
            let temp_home = tempfile::tempdir().unwrap();
            let zshrc_path = temp_home.path().join(".zshrc");
            std::fs::write(
                &zshrc_path,
                r#"
                    set -x
                    function myecho {
                        echo 'It works!'
                    }
                    "#,
            )
            .unwrap();
            let zshrc_str = zshrc_path.to_str().unwrap();
            let shell = Shell::Zsh(ZshShell {
                shell_path: shell_path.to_string(),
                zshrc_path: zshrc_str.to_string(),
            });

            let actual_cmd = shell
                .format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
            let expected_cmd: Vec<String> = expected_cmd
                .iter()
                .map(|s| s.replace("ZSHRC_PATH", zshrc_str))
                .collect();

            assert_eq!(actual_cmd, Some(expected_cmd));
            // Actually run the command and check output/exit code.
            let output = process_exec_tool_call(
                ExecParams {
                    command: actual_cmd.unwrap(),
                    cwd: PathBuf::from(temp_home.path()),
                    timeout_ms: None,
                    env: HashMap::from([(
                        "HOME".to_string(),
                        temp_home.path().to_str().unwrap().to_string(),
                    )]),
                    with_escalated_permissions: None,
                    justification: None,
                },
                SandboxType::None,
                &SandboxPolicy::DangerFullAccess,
                &None,
                None,
            )
            .await
            .unwrap();

            assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
            if let Some(expected) = expected_output {
                assert_eq!(
                    output.stdout.text, expected,
                    "input: {input:?} output: {output:?}"
                );
            }
        }
    }
}