Skip to main content

atd_tools_shell/
pwsh.rs

//! `ref:shell.pwsh` — PowerShell execution.
//!
//! Tries `pwsh` (PowerShell 7+, cross-platform) first. On NotFound, Windows
//! falls back to `powershell` (Windows PowerShell, built in since Windows 7 /
//! Server 2008 R2). Other platforms without `pwsh` return NOT_AVAILABLE.
6
7use std::sync::OnceLock;
8use std::time::{Duration, Instant};
9
10use atd_protocol::{
11    BindingProtocol, SafetyLevel, ToolBinding, ToolCapability, ToolDefinition, ToolResources,
12    ToolSafety, ToolTrust, ToolVisibility, TrustLevel,
13};
14
15use crate::shared::{RunError, RunRequest, run};
16use atd_runtime::context::CallContext;
17use atd_runtime::error::ToolCallError;
18use atd_runtime::registry::{CallFuture, Tool};
19
20static DEFINITION: OnceLock<ToolDefinition> = OnceLock::new();
21
22fn definition() -> &'static ToolDefinition {
23    DEFINITION.get_or_init(|| ToolDefinition {
24        id: "ref:shell.pwsh".into(),
25        name: "PowerShell Execute".into(),
26        description: "Run a command via PowerShell. Prefers `pwsh` (PS 7+ cross-platform); on Windows falls back to `powershell`. Returns exit code + separated stdout/stderr. -NoProfile is applied to skip $PROFILE scripts.".into(),
27        version: "0.1.0".into(),
28        capability: ToolCapability {
29            domain: "shell".into(),
30            actions: vec!["pwsh".into()],
31            tags: vec!["shell".into(), "powershell".into(), "subprocess".into()],
32            intent_examples: vec![
33                "list directories via PowerShell".into(),
34                "run a PS cmdlet".into(),
35            ],
36        },
37        input_schema: serde_json::json!({
38            "type": "object",
39            "properties": {
40                "command":  { "type": "string", "minLength": 1 },
41                "grace_ms": { "type": "integer", "minimum": 0 }
42            },
43            "required": ["command"]
44        }),
45        output_schema: serde_json::json!({
46            "type": "object",
47            "properties": {
48                "exit_code":        { "type": ["integer", "null"] },
49                "stdout":           { "type": "string" },
50                "stdout_truncated": { "type": "boolean" },
51                "stderr":           { "type": "string" },
52                "stderr_truncated": { "type": "boolean" },
53                "duration_ms":      { "type": "integer" }
54            }
55        }),
56        bindings: vec![ToolBinding {
57            protocol: BindingProtocol::Cli,
58            config: serde_json::json!({}),
59        }],
60        safety: ToolSafety {
61            level: SafetyLevel::Destructive,
62            dry_run: true,
63            side_effects: vec!["subprocess".into(), "filesystem".into(), "network".into()],
64            data_sensitivity: Some("depends on command".into()),
65        },
66        resources: ToolResources {
67            timeout_ms: 60_000,
68            max_concurrent: 10,
69            rate_limit_per_min: None,
70            estimated_tokens: Some(500),
71        },
72        trust: ToolTrust {
73            publisher: "atd-ref-server".into(),
74            trust_level: TrustLevel::L2Tested,
75            signature: None,
76        },
77        visibility: ToolVisibility::Dangerous,
78        required_capabilities: vec![],
79        tier: None,
80        errors: vec![],
81    })
82}
83
84pub struct ShellPwshTool;
85
86impl ShellPwshTool {
87    pub fn new() -> Self {
88        Self
89    }
90}
91
92impl Default for ShellPwshTool {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98#[derive(serde::Deserialize)]
99struct PwshArgs {
100    command: String,
101    #[serde(default)]
102    grace_ms: Option<u64>,
103}
104
105/// List of program names to try in order, per-platform.
106fn pwsh_programs() -> &'static [&'static str] {
107    #[cfg(windows)]
108    {
109        &["pwsh", "powershell"]
110    }
111    #[cfg(not(windows))]
112    {
113        &["pwsh"]
114    }
115}
116
117impl Tool for ShellPwshTool {
118    fn definition(&self) -> &ToolDefinition {
119        definition()
120    }
121
122    fn call<'a>(&'a self, args: serde_json::Value, ctx: &'a CallContext) -> CallFuture<'a> {
123        Box::pin(async move {
124            let args: PwshArgs = serde_json::from_value(args)
125                .map_err(|e| ToolCallError::InvalidArgs(e.to_string()))?;
126            if args.command.trim().is_empty() {
127                return Err(ToolCallError::InvalidArgs(
128                    "command is empty or whitespace-only".into(),
129                ));
130            }
131
132            let deadline = ctx
133                .deadline
134                .or_else(|| Some(Instant::now() + Duration::from_secs(60)));
135            let half = ctx.max_output_bytes / 2;
136            let grace_ms = args.grace_ms.unwrap_or(1000);
137
138            // Try each candidate program; on NotFound, try the next.
139            for &program in pwsh_programs() {
140                let req = RunRequest {
141                    program,
142                    args: &["-NoProfile", "-Command", &args.command],
143                    cwd: &ctx.cwd,
144                    deadline,
145                    grace_ms,
146                    max_stdout_bytes: half,
147                    max_stderr_bytes: half,
148                };
149                match run(req).await {
150                    Ok(out) => {
151                        return Ok(serde_json::json!({
152                            "exit_code": out.exit_code,
153                            "stdout": out.stdout,
154                            "stdout_truncated": out.stdout_truncated,
155                            "stderr": out.stderr,
156                            "stderr_truncated": out.stderr_truncated,
157                            "duration_ms": out.duration_ms,
158                        }));
159                    }
160                    Err(RunError::NotFound { .. }) => continue, // try next candidate
161                    Err(RunError::TimedOut { after_ms }) => {
162                        return Err(ToolCallError::ExecutionFailed {
163                            code: "TIMEOUT".into(),
164                            message: format!("command timed out after {after_ms}ms"),
165                            retryable: true,
166                        });
167                    }
168                    Err(RunError::SpawnFailed(e)) | Err(RunError::Io(e)) => {
169                        return Err(ToolCallError::ExecutionFailed {
170                            code: "IO".into(),
171                            message: format!("io: {e}"),
172                            retryable: true,
173                        });
174                    }
175                }
176            }
177
178            // All candidates were NotFound.
179            Err(ToolCallError::ExecutionFailed {
180                code: "NOT_AVAILABLE".into(),
181                message: "neither `pwsh` nor `powershell` is on PATH".into(),
182                retryable: false,
183            })
184        })
185    }
186}
187
188#[cfg(test)]
189mod tests {
190    use super::*;
191
192    /// Detect PowerShell availability at runtime; use it to decide the
193    /// expected test branch.
194    fn pwsh_available() -> bool {
195        let candidates = pwsh_programs();
196        for &program in candidates {
197            if std::process::Command::new(program)
198                .arg("-Version")
199                .stdin(std::process::Stdio::null())
200                .stdout(std::process::Stdio::null())
201                .stderr(std::process::Stdio::null())
202                .status()
203                .is_ok()
204            {
205                return true;
206            }
207        }
208        false
209    }
210
211    #[tokio::test]
212    async fn happy_path_when_pwsh_available() {
213        if !pwsh_available() {
214            // Skip on systems without PowerShell.
215            return;
216        }
217        let t = ShellPwshTool::new();
218        let ctx = CallContext::for_test();
219        let r = t
220            .call(serde_json::json!({"command": "Write-Output 'hi'"}), &ctx)
221            .await
222            .unwrap();
223        assert_eq!(r["exit_code"], 0);
224        assert!(r["stdout"].as_str().unwrap().contains("hi"));
225    }
226
227    #[tokio::test]
228    async fn exit_code_passes_through() {
229        if !pwsh_available() {
230            return;
231        }
232        let t = ShellPwshTool::new();
233        let ctx = CallContext::for_test();
234        let r = t
235            .call(serde_json::json!({"command": "exit 5"}), &ctx)
236            .await
237            .unwrap();
238        assert_eq!(r["exit_code"], 5);
239    }
240
241    #[tokio::test]
242    async fn not_available_when_no_pwsh() {
243        if pwsh_available() {
244            // Skip on systems with PowerShell — this test only makes sense
245            // when the shell is absent.
246            return;
247        }
248        let t = ShellPwshTool::new();
249        let ctx = CallContext::for_test();
250        let err = t
251            .call(serde_json::json!({"command": "Write-Output 'hi'"}), &ctx)
252            .await
253            .unwrap_err();
254        match err {
255            ToolCallError::ExecutionFailed {
256                code, retryable, ..
257            } => {
258                assert_eq!(code, "NOT_AVAILABLE");
259                assert!(!retryable);
260            }
261            _ => panic!("expected NOT_AVAILABLE"),
262        }
263    }
264
265    #[tokio::test]
266    async fn empty_command_is_invalid_args() {
267        let t = ShellPwshTool::new();
268        let ctx = CallContext::for_test();
269        let err = t
270            .call(serde_json::json!({"command": ""}), &ctx)
271            .await
272            .unwrap_err();
273        assert!(matches!(err, ToolCallError::InvalidArgs(_)));
274    }
275
276    #[tokio::test]
277    async fn grace_ms_override_is_accepted() {
278        // Schema accepts the optional grace_ms; behaviorally we can't easily
279        // distinguish grace values with PS, but the call should at least not
280        // reject the argument and should complete promptly on deadline.
281        if !pwsh_available() {
282            return;
283        }
284        let t = ShellPwshTool::new();
285        let mut ctx = CallContext::for_test();
286        ctx.deadline = Some(Instant::now() + Duration::from_millis(150));
287        let start = Instant::now();
288        let _ = t
289            .call(
290                serde_json::json!({
291                    "command": "Start-Sleep -Seconds 10",
292                    "grace_ms": 100
293                }),
294                &ctx,
295            )
296            .await;
297        let elapsed = start.elapsed();
298        assert!(elapsed < Duration::from_secs(3), "too slow: {elapsed:?}");
299    }
300}