// agentkit_tool_shell/lib.rs

use std::collections::BTreeMap;
use std::path::PathBuf;
use std::process::Stdio;
use std::time::Duration;

use agentkit_core::{MetadataMap, ToolOutput, ToolResultPart};
use agentkit_tools_core::{
    PermissionRequest, ShellPermissionRequest, Tool, ToolAnnotations, ToolContext, ToolError,
    ToolName, ToolRegistry, ToolRequest, ToolResult, ToolSpec,
};
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::json;
use tokio::process::Command;
use tokio::time::timeout;

16pub fn registry() -> ToolRegistry {
17    ToolRegistry::new().with(ShellExecTool::default())
18}
19
20#[derive(Clone, Debug)]
21pub struct ShellExecTool {
22    spec: ToolSpec,
23}
24
25impl Default for ShellExecTool {
26    fn default() -> Self {
27        Self {
28            spec: ToolSpec {
29                name: ToolName::new("shell.exec"),
30                description: "Execute a shell command and capture stdout, stderr, and exit status."
31                    .into(),
32                input_schema: json!({
33                    "type": "object",
34                    "properties": {
35                        "executable": { "type": "string" },
36                        "argv": {
37                            "type": "array",
38                            "items": { "type": "string" },
39                            "default": []
40                        },
41                        "cwd": { "type": "string" },
42                        "env": {
43                            "type": "object",
44                            "additionalProperties": { "type": "string" }
45                        },
46                        "timeout_ms": { "type": "integer", "minimum": 1 }
47                    },
48                    "required": ["executable"],
49                    "additionalProperties": false
50                }),
51                annotations: ToolAnnotations {
52                    destructive_hint: true,
53                    needs_approval_hint: true,
54                    ..ToolAnnotations::default()
55                },
56                metadata: MetadataMap::new(),
57            },
58        }
59    }
60}
61
62#[derive(Debug, Deserialize)]
63struct ShellExecInput {
64    executable: String,
65    #[serde(default)]
66    argv: Vec<String>,
67    cwd: Option<PathBuf>,
68    #[serde(default)]
69    env: BTreeMap<String, String>,
70    timeout_ms: Option<u64>,
71}
72
73#[async_trait]
74impl Tool for ShellExecTool {
75    fn spec(&self) -> &ToolSpec {
76        &self.spec
77    }
78
79    fn proposed_requests(
80        &self,
81        request: &ToolRequest,
82    ) -> Result<Vec<Box<dyn PermissionRequest>>, ToolError> {
83        let input: ShellExecInput = parse_input(request)?;
84        Ok(vec![Box::new(ShellPermissionRequest {
85            executable: input.executable,
86            argv: input.argv,
87            cwd: input.cwd,
88            env_keys: input.env.keys().cloned().collect(),
89            metadata: request.metadata.clone(),
90        })])
91    }
92
93    async fn invoke(
94        &self,
95        request: ToolRequest,
96        ctx: &mut ToolContext<'_>,
97    ) -> Result<ToolResult, ToolError> {
98        let input: ShellExecInput = parse_input(&request)?;
99        let mut command = Command::new(&input.executable);
100        command.args(&input.argv);
101        command.kill_on_drop(true);
102        if let Some(cwd) = &input.cwd {
103            command.current_dir(cwd);
104        }
105        for (key, value) in &input.env {
106            command.env(key, value);
107        }
108
109        let duration_start = std::time::Instant::now();
110        let output_future = command.output();
111        tokio::pin!(output_future);
112
113        let output = if let Some(timeout_ms) = input.timeout_ms {
114            if let Some(cancellation) = ctx.cancellation.clone() {
115                tokio::select! {
116                    result = &mut output_future => result.map_err(|error| {
117                        ToolError::ExecutionFailed(format!("failed to spawn command: {error}"))
118                    })?,
119                    _ = cancellation.cancelled() => return Err(ToolError::Cancelled),
120                    _ = tokio::time::sleep(Duration::from_millis(timeout_ms)) => {
121                        return Err(ToolError::ExecutionFailed(format!("command timed out after {timeout_ms}ms")));
122                    }
123                }
124            } else {
125                timeout(Duration::from_millis(timeout_ms), &mut output_future)
126                    .await
127                    .map_err(|_| {
128                        ToolError::ExecutionFailed(format!(
129                            "command timed out after {timeout_ms}ms"
130                        ))
131                    })?
132                    .map_err(|error| {
133                        ToolError::ExecutionFailed(format!("failed to spawn command: {error}"))
134                    })?
135            }
136        } else if let Some(cancellation) = ctx.cancellation.clone() {
137            tokio::select! {
138                result = &mut output_future => result.map_err(|error| {
139                    ToolError::ExecutionFailed(format!("failed to spawn command: {error}"))
140                })?,
141                _ = cancellation.cancelled() => return Err(ToolError::Cancelled),
142            }
143        } else {
144            output_future.await.map_err(|error| {
145                ToolError::ExecutionFailed(format!("failed to spawn command: {error}"))
146            })?
147        };
148
149        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
150        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
151        let status = output.status.code();
152        let success = output.status.success();
153
154        Ok(ToolResult {
155            result: ToolResultPart {
156                call_id: request.call_id,
157                output: ToolOutput::Structured(json!({
158                    "stdout": stdout,
159                    "stderr": stderr,
160                    "success": success,
161                    "exit_code": status,
162                })),
163                is_error: !success,
164                metadata: MetadataMap::new(),
165            },
166            duration: Some(duration_start.elapsed()),
167            metadata: MetadataMap::new(),
168        })
169    }
170}
171
172fn parse_input(request: &ToolRequest) -> Result<ShellExecInput, ToolError> {
173    serde_json::from_value(request.input.clone())
174        .map_err(|error| ToolError::InvalidInput(format!("invalid tool input: {error}")))
175}
176
#[cfg(test)]
mod tests {
    use agentkit_capabilities::CapabilityContext;
    use agentkit_core::{SessionId, TurnId};
    use agentkit_tools_core::{
        BasicToolExecutor, PermissionChecker, PermissionCode, PermissionDecision, PermissionDenial,
        ToolExecutionOutcome, ToolExecutor,
    };

    use super::*;

    /// Permission checker that allows every request.
    struct AllowAll;

    impl PermissionChecker for AllowAll {
        fn evaluate(
            &self,
            _request: &dyn agentkit_tools_core::PermissionRequest,
        ) -> PermissionDecision {
            PermissionDecision::Allow
        }
    }

    /// Permission checker that denies every request with `CommandNotAllowed`.
    struct DenyCommands;

    impl PermissionChecker for DenyCommands {
        fn evaluate(
            &self,
            _request: &dyn agentkit_tools_core::PermissionRequest,
        ) -> PermissionDecision {
            let denial = PermissionDenial {
                code: PermissionCode::CommandNotAllowed,
                message: "commands denied in test".into(),
                metadata: MetadataMap::new(),
            };
            PermissionDecision::Deny(denial)
        }
    }

    /// With permissions granted, the command runs and its stdout is captured.
    #[tokio::test]
    async fn shell_tool_executes_and_captures_output() {
        let executor = BasicToolExecutor::new(registry());
        let metadata = MetadataMap::new();
        let session_id = SessionId::new("session-1");
        let turn_id = TurnId::new("turn-1");
        let mut ctx = ToolContext {
            capability: CapabilityContext {
                session_id: Some(&session_id),
                turn_id: Some(&turn_id),
                metadata: &metadata,
            },
            permissions: &AllowAll,
            resources: &(),
            cancellation: None,
        };

        let request = ToolRequest {
            call_id: "call-1".into(),
            tool_name: ToolName::new("shell.exec"),
            input: json!({
                "executable": "sh",
                "argv": ["-c", "printf hello"]
            }),
            session_id: "session-1".into(),
            turn_id: "turn-1".into(),
            metadata: MetadataMap::new(),
        };

        let outcome = executor.execute(request, &mut ctx).await;

        let completed = match outcome {
            ToolExecutionOutcome::Completed(completed) => completed,
            other => panic!("unexpected outcome: {other:?}"),
        };
        let value = match completed.result.output {
            ToolOutput::Structured(value) => value,
            other => panic!("unexpected output: {other:?}"),
        };
        assert_eq!(value["stdout"], "hello");
        assert_eq!(value["success"], true);
    }

    /// With permissions denied, execution fails with `PermissionDenied`.
    #[tokio::test]
    async fn shell_tool_respects_permission_denial() {
        let executor = BasicToolExecutor::new(registry());
        let metadata = MetadataMap::new();
        let session_id = SessionId::new("session-1");
        let turn_id = TurnId::new("turn-1");
        let mut ctx = ToolContext {
            capability: CapabilityContext {
                session_id: Some(&session_id),
                turn_id: Some(&turn_id),
                metadata: &metadata,
            },
            permissions: &DenyCommands,
            resources: &(),
            cancellation: None,
        };

        let request = ToolRequest {
            call_id: "call-2".into(),
            tool_name: ToolName::new("shell.exec"),
            input: json!({
                "executable": "sh",
                "argv": ["-c", "printf nope"]
            }),
            session_id: "session-1".into(),
            turn_id: "turn-1".into(),
            metadata: MetadataMap::new(),
        };

        let outcome = executor.execute(request, &mut ctx).await;

        assert!(matches!(
            outcome,
            ToolExecutionOutcome::Failed(ToolError::PermissionDenied(_))
        ));
    }
}