1use std::sync::OnceLock;
8use std::time::{Duration, Instant};
9
10use atd_protocol::{
11 BindingProtocol, SafetyLevel, ToolBinding, ToolCapability, ToolDefinition, ToolResources,
12 ToolSafety, ToolTrust, ToolVisibility, TrustLevel,
13};
14
15use crate::shared::{RunError, RunRequest, run};
16use atd_runtime::context::CallContext;
17use atd_runtime::error::ToolCallError;
18use atd_runtime::registry::{CallFuture, Tool};
19
20static DEFINITION: OnceLock<ToolDefinition> = OnceLock::new();
21
22fn definition() -> &'static ToolDefinition {
23 DEFINITION.get_or_init(|| ToolDefinition {
24 id: "ref:shell.pwsh".into(),
25 name: "PowerShell Execute".into(),
26 description: "Run a command via PowerShell. Prefers `pwsh` (PS 7+ cross-platform); on Windows falls back to `powershell`. Returns exit code + separated stdout/stderr. -NoProfile is applied to skip $PROFILE scripts.".into(),
27 version: "0.1.0".into(),
28 capability: ToolCapability {
29 domain: "shell".into(),
30 actions: vec!["pwsh".into()],
31 tags: vec!["shell".into(), "powershell".into(), "subprocess".into()],
32 intent_examples: vec![
33 "list directories via PowerShell".into(),
34 "run a PS cmdlet".into(),
35 ],
36 },
37 input_schema: serde_json::json!({
38 "type": "object",
39 "properties": {
40 "command": { "type": "string", "minLength": 1 },
41 "grace_ms": { "type": "integer", "minimum": 0 }
42 },
43 "required": ["command"]
44 }),
45 output_schema: serde_json::json!({
46 "type": "object",
47 "properties": {
48 "exit_code": { "type": ["integer", "null"] },
49 "stdout": { "type": "string" },
50 "stdout_truncated": { "type": "boolean" },
51 "stderr": { "type": "string" },
52 "stderr_truncated": { "type": "boolean" },
53 "duration_ms": { "type": "integer" }
54 }
55 }),
56 bindings: vec![ToolBinding {
57 protocol: BindingProtocol::Cli,
58 config: serde_json::json!({}),
59 }],
60 safety: ToolSafety {
61 level: SafetyLevel::Destructive,
62 dry_run: true,
63 side_effects: vec!["subprocess".into(), "filesystem".into(), "network".into()],
64 data_sensitivity: Some("depends on command".into()),
65 },
66 resources: ToolResources {
67 timeout_ms: 60_000,
68 max_concurrent: 10,
69 rate_limit_per_min: None,
70 estimated_tokens: Some(500),
71 },
72 trust: ToolTrust {
73 publisher: "atd-ref-server".into(),
74 trust_level: TrustLevel::L2Tested,
75 signature: None,
76 },
77 visibility: ToolVisibility::Dangerous,
78 required_capabilities: vec![],
79 tier: None,
80 errors: vec![],
81 })
82}
83
/// Tool that executes a single command through PowerShell (`ref:shell.pwsh`).
///
/// The type carries no state, so it is freely `Copy`able; `Debug` is derived
/// so it can appear in diagnostics like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShellPwshTool;

impl ShellPwshTool {
    /// Creates a tool instance; identical to [`ShellPwshTool::default`].
    pub fn new() -> Self {
        Self
    }
}
97
98#[derive(serde::Deserialize)]
99struct PwshArgs {
100 command: String,
101 #[serde(default)]
102 grace_ms: Option<u64>,
103}
104
/// PowerShell executables to try, in order of preference.
///
/// `pwsh` is always tried first; on Windows, `powershell` is tried next if
/// `pwsh` cannot be found.
fn pwsh_programs() -> &'static [&'static str] {
    if cfg!(windows) {
        &["pwsh", "powershell"]
    } else {
        &["pwsh"]
    }
}
116
117impl Tool for ShellPwshTool {
118 fn definition(&self) -> &ToolDefinition {
119 definition()
120 }
121
122 fn call<'a>(&'a self, args: serde_json::Value, ctx: &'a CallContext) -> CallFuture<'a> {
123 Box::pin(async move {
124 let args: PwshArgs = serde_json::from_value(args)
125 .map_err(|e| ToolCallError::InvalidArgs(e.to_string()))?;
126 if args.command.trim().is_empty() {
127 return Err(ToolCallError::InvalidArgs(
128 "command is empty or whitespace-only".into(),
129 ));
130 }
131
132 let deadline = ctx
133 .deadline
134 .or_else(|| Some(Instant::now() + Duration::from_secs(60)));
135 let half = ctx.max_output_bytes / 2;
136 let grace_ms = args.grace_ms.unwrap_or(1000);
137
138 for &program in pwsh_programs() {
140 let req = RunRequest {
141 program,
142 args: &["-NoProfile", "-Command", &args.command],
143 cwd: &ctx.cwd,
144 deadline,
145 grace_ms,
146 max_stdout_bytes: half,
147 max_stderr_bytes: half,
148 };
149 match run(req).await {
150 Ok(out) => {
151 return Ok(serde_json::json!({
152 "exit_code": out.exit_code,
153 "stdout": out.stdout,
154 "stdout_truncated": out.stdout_truncated,
155 "stderr": out.stderr,
156 "stderr_truncated": out.stderr_truncated,
157 "duration_ms": out.duration_ms,
158 }));
159 }
160 Err(RunError::NotFound { .. }) => continue, Err(RunError::TimedOut { after_ms }) => {
162 return Err(ToolCallError::ExecutionFailed {
163 code: "TIMEOUT".into(),
164 message: format!("command timed out after {after_ms}ms"),
165 retryable: true,
166 });
167 }
168 Err(RunError::SpawnFailed(e)) | Err(RunError::Io(e)) => {
169 return Err(ToolCallError::ExecutionFailed {
170 code: "IO".into(),
171 message: format!("io: {e}"),
172 retryable: true,
173 });
174 }
175 }
176 }
177
178 Err(ToolCallError::ExecutionFailed {
180 code: "NOT_AVAILABLE".into(),
181 message: "neither `pwsh` nor `powershell` is on PATH".into(),
182 retryable: false,
183 })
184 })
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports whether any PowerShell candidate can be spawned on this host.
    ///
    /// Only spawn success counts; the exit status of `-Version` is ignored.
    fn pwsh_available() -> bool {
        use std::process::{Command, Stdio};

        pwsh_programs().iter().any(|&program| {
            Command::new(program)
                .arg("-Version")
                .stdin(Stdio::null())
                .stdout(Stdio::null())
                .stderr(Stdio::null())
                .status()
                .is_ok()
        })
    }

    #[tokio::test]
    async fn happy_path_when_pwsh_available() {
        if !pwsh_available() {
            // Nothing to exercise on hosts without PowerShell.
            return;
        }
        let tool = ShellPwshTool::new();
        let ctx = CallContext::for_test();
        let out = tool
            .call(serde_json::json!({"command": "Write-Output 'hi'"}), &ctx)
            .await
            .unwrap();
        assert_eq!(out["exit_code"], 0);
        let stdout = out["stdout"].as_str().unwrap();
        assert!(stdout.contains("hi"));
    }

    #[tokio::test]
    async fn exit_code_passes_through() {
        if !pwsh_available() {
            return;
        }
        let tool = ShellPwshTool::new();
        let ctx = CallContext::for_test();
        let out = tool
            .call(serde_json::json!({"command": "exit 5"}), &ctx)
            .await
            .unwrap();
        assert_eq!(out["exit_code"], 5);
    }

    #[tokio::test]
    async fn not_available_when_no_pwsh() {
        if pwsh_available() {
            // Only meaningful on hosts where no candidate can be spawned.
            return;
        }
        let tool = ShellPwshTool::new();
        let ctx = CallContext::for_test();
        let err = tool
            .call(serde_json::json!({"command": "Write-Output 'hi'"}), &ctx)
            .await
            .unwrap_err();
        let ToolCallError::ExecutionFailed {
            code, retryable, ..
        } = err
        else {
            panic!("expected NOT_AVAILABLE");
        };
        assert_eq!(code, "NOT_AVAILABLE");
        assert!(!retryable);
    }

    #[tokio::test]
    async fn empty_command_is_invalid_args() {
        let tool = ShellPwshTool::new();
        let ctx = CallContext::for_test();
        let outcome = tool.call(serde_json::json!({"command": ""}), &ctx).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, ToolCallError::InvalidArgs(_)));
    }

    #[tokio::test]
    async fn grace_ms_override_is_accepted() {
        if !pwsh_available() {
            return;
        }
        let tool = ShellPwshTool::new();
        let mut ctx = CallContext::for_test();
        // A tight deadline against a long sleep: the call must return promptly.
        ctx.deadline = Some(Instant::now() + Duration::from_millis(150));
        let args = serde_json::json!({
            "command": "Start-Sleep -Seconds 10",
            "grace_ms": 100
        });
        let started = Instant::now();
        let _ = tool.call(args, &ctx).await;
        let elapsed = started.elapsed();
        assert!(elapsed < Duration::from_secs(3), "too slow: {elapsed:?}");
    }
}