use super::shell_provider::{ShellError, ShellExecCommand};
use super::shell_tool_utils::ShellType;
use crate::constants::env::system;
use std::collections::HashMap;
/// Builds the argument list for running `cmd` via PowerShell's `-Command` mode.
///
/// The returned vector is, in order: `-NoProfile`, `-NonInteractive`, `-Command`,
/// followed by `cmd` itself as the final argument (passed through unchanged).
pub fn build_powershell_args(cmd: &str) -> Vec<String> {
    ["-NoProfile", "-NonInteractive", "-Command", cmd]
        .iter()
        .map(|arg| (*arg).to_owned())
        .collect()
}
fn encode_powershell_command(ps_command: &str) -> String {
let utf16: Vec<u8> = ps_command
.encode_utf16()
.flat_map(|c| c.to_le_bytes())
.collect();
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &utf16)
}
/// Shell provider that runs commands through a PowerShell executable.
#[derive(Debug)]
pub struct PowerShellProvider {
    /// Path (or bare name) of the PowerShell executable used to launch commands.
    shell_path: String,
    /// Sandbox temp directory remembered by `build_exec_command`; when set,
    /// `get_environment_overrides` exports it as `TMPDIR` and `AI_TMPDIR`.
    current_sandbox_tmp_dir: Option<String>,
}
impl PowerShellProvider {
pub fn new(shell_path: &str) -> Self {
Self {
shell_path: shell_path.to_string(),
current_sandbox_tmp_dir: None,
}
}
pub fn get_type(&self) -> ShellType {
ShellType::PowerShell
}
pub fn get_shell_path(&self) -> &str {
&self.shell_path
}
pub fn is_detached(&self) -> bool {
false
}
pub async fn build_exec_command(
&mut self,
command: &str,
id: usize,
sandbox_tmp_dir: Option<&str>,
use_sandbox: bool,
) -> Result<ShellExecCommand, ShellError> {
self.current_sandbox_tmp_dir = sandbox_tmp_dir.map(|s| s.to_string());
let cwd_file_path = if use_sandbox && sandbox_tmp_dir.is_some() {
format!("{}/claude-pwd-ps-{}", sandbox_tmp_dir.unwrap(), id)
} else {
let tmpdir = std::env::temp_dir();
tmpdir
.join(format!("claude-pwd-ps-{}", id))
.to_string_lossy()
.to_string()
};
let escaped_cwd_file_path = cwd_file_path.replace('\'', "''");
let cwd_tracking = format!(
"\n; $_ec = if ($null -ne $LASTEXITCODE) {{ $LASTEXITCODE }} elseif ($?) {{ 0 }} else {{ 1 }}\n; (Get-Location).Path | Out-File -FilePath '{}' -Encoding utf8 -NoNewline\n; exit $_ec",
escaped_cwd_file_path
);
let ps_command = format!("{}{}", command, cwd_tracking);
let command_string = if use_sandbox {
let encoded = encode_powershell_command(&ps_command);
format!(
"'{}' -NoProfile -NonInteractive -EncodedCommand {}",
self.shell_path.replace('\'', "'\\''"),
encoded
)
} else {
ps_command
};
Ok(ShellExecCommand {
command_string,
cwd_file_path,
})
}
pub fn get_spawn_args(&self, command_string: &str) -> Vec<String> {
build_powershell_args(command_string)
}
pub async fn get_environment_overrides(&self) -> HashMap<String, String> {
let mut env = HashMap::new();
if let Some(ref tmpdir) = self.current_sandbox_tmp_dir {
env.insert("TMPDIR".to_string(), tmpdir.clone());
env.insert("AI_TMPDIR".to_string(), tmpdir.clone());
}
env
}
}
/// Finds the first regular file named `pwsh` in the directories listed in `path_var`.
///
/// Returns the full path to that file (not the directory containing it), or `None`.
#[cfg(not(target_os = "windows"))]
fn find_pwsh_in_path(path_var: &std::ffi::OsStr) -> Option<String> {
    std::env::split_paths(path_var)
        // An empty PATH entry would make `pwsh` resolve relative to the current
        // working directory; never pick up a binary that way.
        .filter(|dir| !dir.as_os_str().is_empty())
        .map(|dir| dir.join("pwsh"))
        .find(|candidate| candidate.is_file())
        .map(|candidate| candidate.to_string_lossy().into_owned())
}

impl Default for PowerShellProvider {
    /// Builds a provider for the platform's default PowerShell executable.
    ///
    /// On Windows this is `powershell.exe`. Elsewhere it is the full path of the
    /// first `pwsh` found on `PATH`, falling back to the bare name `pwsh`.
    fn default() -> Self {
        #[cfg(target_os = "windows")]
        let shell = "powershell.exe".to_string();
        #[cfg(not(target_os = "windows"))]
        let shell = std::env::var_os(system::PATH)
            .and_then(|path_var| find_pwsh_in_path(&path_var))
            .unwrap_or_else(|| "pwsh".to_string());
        Self::new(&shell)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_powershell_args() {
        let args = build_powershell_args("echo hello");
        assert_eq!(args, ["-NoProfile", "-NonInteractive", "-Command", "echo hello"]);
    }

    #[test]
    fn test_encode_powershell_command() {
        // "echo hello" as UTF-16LE bytes, base64-encoded: the `-EncodedCommand` format.
        assert_eq!(
            encode_powershell_command("echo hello"),
            "ZQBjAGgAbwAgAGgAZQBsAGwAbwA="
        );
        assert_eq!(encode_powershell_command(""), "");
    }

    #[test]
    fn test_encode_powershell_command_round_trips_non_ascii() {
        // Covers a BMP non-ASCII char and a surrogate pair (emoji).
        let script = "Write-Output 'h\u{e9}llo \u{1F680}'";
        let encoded = encode_powershell_command(script);
        let bytes = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &encoded)
            .expect("encoder output must be valid base64");
        assert_eq!(bytes.len() % 2, 0);
        let units: Vec<u16> = bytes
            .chunks_exact(2)
            .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
            .collect();
        assert_eq!(String::from_utf16(&units).unwrap(), script);
    }
}