Skip to main content

ai_agent/utils/shell/
powershell_provider.rs

1//! PowerShell provider implementation.
2
3use super::shell_provider::{ShellError, ShellExecCommand};
4use super::shell_tool_utils::ShellType;
5use crate::constants::env::system;
6use std::collections::HashMap;
7
/// Assemble the argument list used to run `cmd` through PowerShell.
///
/// The list is `-NoProfile`, `-NonInteractive`, `-Command`, followed by `cmd`
/// passed through verbatim. `PowerShellProvider::get_spawn_args` delegates here,
/// and other call paths can reuse the same flag set.
pub fn build_powershell_args(cmd: &str) -> Vec<String> {
    const LEADING_FLAGS: [&str; 3] = ["-NoProfile", "-NonInteractive", "-Command"];

    LEADING_FLAGS
        .iter()
        .copied()
        .chain(std::iter::once(cmd))
        .map(String::from)
        .collect()
}
18
19/// Base64-encode a string as UTF-16LE for PowerShell's -EncodedCommand.
20/// This encoding survives ANY shell-quoting layer.
21fn encode_powershell_command(ps_command: &str) -> String {
22    // Convert to UTF-16LE bytes
23    let utf16: Vec<u8> = ps_command
24        .encode_utf16()
25        .flat_map(|c| c.to_le_bytes())
26        .collect();
27
28    // Base64 encode
29    base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &utf16)
30}
31
/// PowerShell shell provider implementation.
///
/// Holds the executable to launch and the sandbox tmp dir remembered from the
/// most recent `build_exec_command` call.
#[derive(Debug)]
pub struct PowerShellProvider {
    /// Path (or bare command name) of the PowerShell executable.
    shell_path: String,
    /// `sandbox_tmp_dir` from the last `build_exec_command` call; read by
    /// `get_environment_overrides`.
    current_sandbox_tmp_dir: Option<String>,
}
37
38impl PowerShellProvider {
39    /// Create a new PowerShellProvider
40    pub fn new(shell_path: &str) -> Self {
41        Self {
42            shell_path: shell_path.to_string(),
43            current_sandbox_tmp_dir: None,
44        }
45    }
46
47    /// Get shell type
48    pub fn get_type(&self) -> ShellType {
49        ShellType::PowerShell
50    }
51
52    /// Get shell path
53    pub fn get_shell_path(&self) -> &str {
54        &self.shell_path
55    }
56
57    /// Whether the shell is detached
58    pub fn is_detached(&self) -> bool {
59        false
60    }
61
62    /// Build the full command string including all PowerShell-specific setup.
63    pub async fn build_exec_command(
64        &mut self,
65        command: &str,
66        id: usize,
67        sandbox_tmp_dir: Option<&str>,
68        use_sandbox: bool,
69    ) -> Result<ShellExecCommand, ShellError> {
70        // Stash sandbox_tmp_dir for get_environment_overrides
71        self.current_sandbox_tmp_dir = sandbox_tmp_dir.map(|s| s.to_string());
72
73        let cwd_file_path = if use_sandbox && sandbox_tmp_dir.is_some() {
74            format!("{}/claude-pwd-ps-{}", sandbox_tmp_dir.unwrap(), id)
75        } else {
76            let tmpdir = std::env::temp_dir();
77            tmpdir
78                .join(format!("claude-pwd-ps-{}", id))
79                .to_string_lossy()
80                .to_string()
81        };
82
83        let escaped_cwd_file_path = cwd_file_path.replace('\'', "''");
84
85        // Exit-code capture: prefer $LASTEXITCODE when a native exe ran.
86        // Fall back to $? for cmdlet-only pipelines.
87        let cwd_tracking = format!(
88            "\n; $_ec = if ($null -ne $LASTEXITCODE) {{ $LASTEXITCODE }} elseif ($?) {{ 0 }} else {{ 1 }}\n; (Get-Location).Path | Out-File -FilePath '{}' -Encoding utf8 -NoNewline\n; exit $_ec",
89            escaped_cwd_file_path
90        );
91
92        let ps_command = format!("{}{}", command, cwd_tracking);
93
94        // For sandbox path, build a command that invokes pwsh with the full flag set
95        // For non-sandbox path, return the bare PS command
96        let command_string = if use_sandbox {
97            // Shell path is single-quoted, base64 encoded command
98            let encoded = encode_powershell_command(&ps_command);
99            format!(
100                "'{}' -NoProfile -NonInteractive -EncodedCommand {}",
101                self.shell_path.replace('\'', "'\\''"),
102                encoded
103            )
104        } else {
105            ps_command
106        };
107
108        Ok(ShellExecCommand {
109            command_string,
110            cwd_file_path,
111        })
112    }
113
114    /// Get shell args for spawn
115    pub fn get_spawn_args(&self, command_string: &str) -> Vec<String> {
116        build_powershell_args(command_string)
117    }
118
119    /// Get extra env vars for this shell type
120    pub async fn get_environment_overrides(&self) -> HashMap<String, String> {
121        let mut env = HashMap::new();
122
123        // Apply session env vars set via /env
124        // This would be implemented with session env vars integration
125        // for (key, value) in get_session_env_vars() {
126        //     env.insert(key, value);
127        // }
128
129        if let Some(ref tmpdir) = self.current_sandbox_tmp_dir {
130            // PowerShell on Linux/macOS honors TMPDIR
131            env.insert("TMPDIR".to_string(), tmpdir.clone());
132            env.insert("AI_TMPDIR".to_string(), tmpdir.clone());
133        }
134
135        env
136    }
137}
138
139impl Default for PowerShellProvider {
140    fn default() -> Self {
141        #[cfg(target_os = "windows")]
142        let shell = "powershell.exe".to_string();
143        #[cfg(not(target_os = "windows"))]
144        let shell = std::env::var(system::PATH)
145            .ok()
146            .and_then(|p| {
147                p.split(':')
148                    .find(|p| std::path::Path::new(&format!("{}/pwsh", p)).exists())
149                    .map(|s| s.to_string())
150            })
151            .unwrap_or_else(|| "pwsh".to_string());
152
153        Self::new(&shell)
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Drive a future that never suspends (no real I/O inside) to completion.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        use std::sync::Arc;
        use std::task::{Context, Poll, Wake, Waker};

        struct NoopWake;
        impl Wake for NoopWake {
            fn wake(self: Arc<Self>) {}
        }

        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(value) => value,
            Poll::Pending => panic!("future unexpectedly pending"),
        }
    }

    #[test]
    fn test_build_powershell_args() {
        let args = build_powershell_args("echo hello");
        assert_eq!(
            args,
            ["-NoProfile", "-NonInteractive", "-Command", "echo hello"]
        );
    }

    #[test]
    fn test_encode_powershell_command() {
        // Base64 of "echo hello" in UTF-16LE.
        assert_eq!(
            encode_powershell_command("echo hello"),
            "ZQBjAGgAbwAgAGgAZQBsAGwAbwA="
        );
        // U+00E9 is a single UTF-16 unit: bytes E9 00.
        assert_eq!(encode_powershell_command("\u{e9}"), "6QA=");
        // U+1F600 is a surrogate pair D83D DE00: bytes 3D D8 00 DE.
        assert_eq!(encode_powershell_command("\u{1F600}"), "PdgA3g==");
        assert_eq!(encode_powershell_command(""), "");
    }

    #[test]
    fn test_build_exec_command_sandbox() {
        let mut provider = PowerShellProvider::new("/opt/it's/pwsh");
        let exec = match block_on(provider.build_exec_command("Get-Date", 3, Some("/sbx"), true)) {
            Ok(exec) => exec,
            Err(_) => panic!("build_exec_command failed"),
        };
        assert_eq!(exec.cwd_file_path, "/sbx/claude-pwd-ps-3");
        assert!(exec
            .command_string
            .starts_with("'/opt/it'\\''s/pwsh' -NoProfile -NonInteractive -EncodedCommand "));

        let env = block_on(provider.get_environment_overrides());
        assert_eq!(env.get("TMPDIR").map(String::as_str), Some("/sbx"));
        assert_eq!(env.get("AI_TMPDIR").map(String::as_str), Some("/sbx"));
    }

    #[test]
    fn test_build_exec_command_plain() {
        let mut provider = PowerShellProvider::new("pwsh");
        let exec = match block_on(provider.build_exec_command("Get-Date", 7, None, false)) {
            Ok(exec) => exec,
            Err(_) => panic!("build_exec_command failed"),
        };
        assert!(exec.cwd_file_path.ends_with("claude-pwd-ps-7"));
        assert!(exec.command_string.starts_with("Get-Date\n; $_ec = "));
        assert!(exec.command_string.ends_with("\n; exit $_ec"));

        let env = block_on(provider.get_environment_overrides());
        assert!(env.is_empty());
    }
}