agent_code_lib/tools/
powershell.rs1use async_trait::async_trait;
7use serde_json::json;
8use std::process::Stdio;
9use tokio::process::Command;
10
11use super::{Tool, ToolContext, ToolResult};
12use crate::error::ToolError;
13
/// Largest combined stdout/stderr payload, in bytes (256 KiB), that is returned
/// before the output is cut off and marked as truncated.
const MAX_OUTPUT_BYTES: usize = 256 << 10;
16
/// Tool that runs a PowerShell command in a child process and reports its output.
///
/// The type carries no state; all behavior lives in its `Tool` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct PowerShellTool;
18
19#[async_trait]
20impl Tool for PowerShellTool {
21 fn name(&self) -> &'static str {
22 "PowerShell"
23 }
24
25 fn description(&self) -> &'static str {
26 "Executes a PowerShell command on Windows. Returns stdout and stderr."
27 }
28
29 fn input_schema(&self) -> serde_json::Value {
30 json!({
31 "type": "object",
32 "required": ["command"],
33 "properties": {
34 "command": {
35 "type": "string",
36 "description": "The PowerShell command to execute"
37 },
38 "timeout": {
39 "type": "integer",
40 "description": "Timeout in milliseconds (max 600000)"
41 }
42 }
43 })
44 }
45
46 fn is_read_only(&self) -> bool {
47 false
48 }
49
50 fn is_enabled(&self) -> bool {
51 cfg!(target_os = "windows")
52 }
53
54 async fn call(
55 &self,
56 input: serde_json::Value,
57 ctx: &ToolContext,
58 ) -> Result<ToolResult, ToolError> {
59 let command = input
60 .get("command")
61 .and_then(|v| v.as_str())
62 .ok_or_else(|| ToolError::InvalidInput("'command' is required".into()))?;
63
64 let timeout_ms = input
65 .get("timeout")
66 .and_then(|v| v.as_u64())
67 .unwrap_or(120_000)
68 .min(600_000);
69
70 let shell = if which_exists("pwsh") {
72 "pwsh"
73 } else {
74 "powershell"
75 };
76
77 let mut cmd = Command::new(shell);
78 cmd.args(["-NoProfile", "-NonInteractive", "-Command", command]);
79 cmd.current_dir(&ctx.cwd);
80 cmd.stdout(Stdio::piped());
81 cmd.stderr(Stdio::piped());
82
83 let output =
84 tokio::time::timeout(std::time::Duration::from_millis(timeout_ms), cmd.output())
85 .await
86 .map_err(|_| {
87 ToolError::ExecutionFailed(format!("Command timed out after {timeout_ms}ms"))
88 })?
89 .map_err(|e| {
90 ToolError::ExecutionFailed(format!("Failed to run PowerShell: {e}"))
91 })?;
92
93 let stdout = String::from_utf8_lossy(&output.stdout);
94 let stderr = String::from_utf8_lossy(&output.stderr);
95
96 let mut result = String::new();
97 if !stdout.is_empty() {
98 result.push_str(&stdout);
99 }
100 if !stderr.is_empty() {
101 if !result.is_empty() {
102 result.push('\n');
103 }
104 result.push_str("STDERR:\n");
105 result.push_str(&stderr);
106 }
107
108 if result.len() > MAX_OUTPUT_BYTES {
110 result.truncate(MAX_OUTPUT_BYTES);
111 result.push_str("\n\n(output truncated)");
112 }
113
114 let exit_code = output.status.code().unwrap_or(-1);
115 if exit_code != 0 {
116 result.push_str(&format!("\n\nExit code: {exit_code}"));
117 }
118
119 if output.status.success() {
120 Ok(ToolResult::success(result))
121 } else {
122 Ok(ToolResult {
123 content: result,
124 is_error: true,
125 })
126 }
127 }
128}
129
/// Reports whether an executable named `cmd` can be found through the `PATH`
/// environment variable.
///
/// The search is done in-process rather than by spawning `which`, which is not
/// available on stock Windows. On Windows each directory is also probed with
/// every extension listed in `PATHEXT` (for example `pwsh` matches `pwsh.exe`).
/// On Unix a candidate must have at least one execute permission bit set.
///
/// Returns `false` for an empty `cmd` or when `PATH` is not set.
fn which_exists(cmd: &str) -> bool {
    if cmd.is_empty() {
        return false;
    }
    let Some(search_path) = std::env::var_os("PATH") else {
        return false;
    };

    // Extensions Windows appends when resolving a bare command name.
    let suffixes: Vec<String> = if cfg!(windows) {
        std::env::var("PATHEXT")
            .unwrap_or_else(|_| String::from(".COM;.EXE;.BAT;.CMD"))
            .split(';')
            .filter(|ext| !ext.is_empty())
            .map(str::to_owned)
            .collect()
    } else {
        Vec::new()
    };

    std::env::split_paths(&search_path).any(|dir| {
        is_executable_file(&dir.join(cmd))
            || suffixes
                .iter()
                .any(|ext| is_executable_file(&dir.join(format!("{cmd}{ext}"))))
    })
}

/// Returns `true` when `path` is a regular file that is marked executable.
///
/// The execute-bit check applies on Unix only; elsewhere any regular file counts.
fn is_executable_file(path: &std::path::Path) -> bool {
    let Ok(meta) = std::fs::metadata(path) else {
        return false;
    };
    if !meta.is_file() {
        return false;
    }

    #[cfg(unix)]
    let runnable = std::os::unix::fs::PermissionsExt::mode(&meta.permissions()) & 0o111 != 0;
    #[cfg(not(unix))]
    let runnable = true;

    runnable
}