Skip to main content

statespace_tool_runtime/
tools.rs

//! Tool domain models: HTTP methods, built-in tool descriptions, and the free-tier command allowlist.
2
3use crate::error::Error;
4use serde::{Deserialize, Serialize};
5use std::fmt;
6use std::str::FromStr;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
9#[serde(rename_all = "UPPERCASE")]
10#[non_exhaustive]
11pub enum HttpMethod {
12    #[default]
13    Get,
14    Post,
15    Put,
16    Patch,
17    Delete,
18    Head,
19    Options,
20}
21
22impl HttpMethod {
23    #[must_use]
24    pub const fn as_str(&self) -> &'static str {
25        match self {
26            Self::Get => "GET",
27            Self::Post => "POST",
28            Self::Put => "PUT",
29            Self::Patch => "PATCH",
30            Self::Delete => "DELETE",
31            Self::Head => "HEAD",
32            Self::Options => "OPTIONS",
33        }
34    }
35}
36
37impl fmt::Display for HttpMethod {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        write!(f, "{}", self.as_str())
40    }
41}
42
43impl FromStr for HttpMethod {
44    type Err = Error;
45
46    fn from_str(s: &str) -> Result<Self, Self::Err> {
47        match s.to_uppercase().as_str() {
48            "GET" => Ok(Self::Get),
49            "POST" => Ok(Self::Post),
50            "PUT" => Ok(Self::Put),
51            "PATCH" => Ok(Self::Patch),
52            "DELETE" => Ok(Self::Delete),
53            "HEAD" => Ok(Self::Head),
54            "OPTIONS" => Ok(Self::Options),
55            _ => Err(Error::InvalidCommand(format!("Unknown HTTP method: {s}"))),
56        }
57    }
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
61#[serde(tag = "type", rename_all = "lowercase")]
62#[non_exhaustive]
63pub enum BuiltinTool {
64    Curl { url: String, method: HttpMethod },
65    Exec { command: String, args: Vec<String> },
66}
67
68impl BuiltinTool {
69    /// # Errors
70    ///
71    /// Returns an error when the command is empty or malformed.
72    pub fn from_command(command: &[String]) -> Result<Self, Error> {
73        if command.is_empty() {
74            return Err(Error::InvalidCommand("Command cannot be empty".to_string()));
75        }
76
77        let cmd = &command[0];
78        Ok(Self::Exec {
79            command: cmd.clone(),
80            args: command[1..].to_vec(),
81        })
82    }
83
84    #[must_use]
85    pub const fn name(&self) -> &'static str {
86        match self {
87            Self::Curl { .. } => "curl",
88            Self::Exec { .. } => "exec",
89        }
90    }
91
92    pub const fn requires_egress(&self) -> bool {
93        matches!(self, Self::Curl { .. })
94    }
95
96    pub fn is_free_tier_allowed(&self) -> bool {
97        match self {
98            Self::Curl { .. } => false,
99            Self::Exec { command, .. } => FREE_TIER_COMMAND_ALLOWLIST.contains(&command.as_str()),
100        }
101    }
102}
103
/// Program names that an `Exec` tool may run on the free tier.
///
/// Entries are bare program names compared by exact string equality. The category
/// comments below are reading aids only; the slice is a flat list.
///
/// NOTE(review): some entries can start other programs through their arguments (for
/// example `env CMD`, `find -exec`, or `system()` inside `awk`). Membership in this list
/// says nothing about arguments, so confirm that they are vetted or sandboxed elsewhere.
pub const FREE_TIER_COMMAND_ALLOWLIST: &[&str] = &[
    // Reading and reshaping text streams.
    "cat",
    "head",
    "tail",
    "less",
    "more",
    "wc",
    "sort",
    "uniq",
    "cut",
    "paste",
    "tr",
    "tee",
    "split",
    "csplit",
    // Inspecting files and locating things.
    "ls",
    "stat",
    "file",
    "du",
    "df",
    "find",
    "which",
    "whereis",
    // Creating, moving and removing files.
    "cp",
    "mv",
    "rm",
    "mkdir",
    "rmdir",
    "touch",
    "ln",
    // Searching and transforming text.
    "grep",
    "egrep",
    "fgrep",
    "sed",
    "awk",
    // Comparing files.
    "diff",
    "comm",
    "cmp",
    // JSON processing.
    "jq",
    // Archives and compression.
    "tar",
    "gzip",
    "gunzip",
    "zcat",
    "bzip2",
    "bunzip2",
    "xz",
    "unxz",
    // Simple output and exit-status helpers.
    "echo",
    "printf",
    "true",
    "false",
    "yes",
    // Date and calendar.
    "date",
    "cal",
    // Environment.
    "env",
    "printenv",
    // Path manipulation.
    "basename",
    "dirname",
    "realpath",
    "readlink",
    "pwd",
    // User and host identity.
    "id",
    "whoami",
    "uname",
    "hostname",
    // Checksums, encodings and binary dumps.
    "md5sum",
    "sha256sum",
    "base64",
    "xxd",
    "hexdump",
    "od",
];
176
#[cfg(test)]
// Unwrapping is deliberate in tests: a panic is the failure report.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Turns string literals into the owned argv form `from_command` takes.
    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().copied().map(String::from).collect()
    }

    /// Builds the `Exec` value a command line is expected to become.
    fn exec(command: &str, args: &[&str]) -> BuiltinTool {
        BuiltinTool::Exec {
            command: command.to_string(),
            args: argv(args),
        }
    }

    /// Builds a `Curl` tool pointing at a fixed example URL with the `GET` method.
    fn curl_get() -> BuiltinTool {
        BuiltinTool::Curl {
            url: "https://example.com".to_string(),
            method: HttpMethod::Get,
        }
    }

    /// Runs `from_command` on the given parts and unwraps the result.
    fn parse(parts: &[&str]) -> BuiltinTool {
        BuiltinTool::from_command(&argv(parts)).unwrap()
    }

    #[test]
    fn test_builtin_tool_name() {
        assert_eq!(exec("ls", &[]).name(), "exec");
        assert_eq!(curl_get().name(), "curl");
    }

    #[test]
    fn test_from_command_empty_is_rejected() {
        // The empty slice is the one input every version of `from_command` must refuse.
        assert!(matches!(
            BuiltinTool::from_command(&[]),
            Err(Error::InvalidCommand(_))
        ));
    }

    #[test]
    fn test_from_command_ls() {
        assert_eq!(parse(&["ls", "docs/"]), exec("ls", &["docs/"]));
    }

    #[test]
    fn test_from_command_cat() {
        assert_eq!(parse(&["cat", "file.md"]), exec("cat", &["file.md"]));
        assert_eq!(parse(&["cat"]), exec("cat", &[]));
    }

    #[test]
    fn test_from_command_curl() {
        // `curl` on the command line stays an `Exec`; it is not turned into `Curl`.
        assert_eq!(
            parse(&["curl", "https://api.github.com"]),
            exec("curl", &["https://api.github.com"])
        );
        assert_eq!(
            parse(&["curl", "-s", "-X", "POST", "https://api.github.com"]),
            exec("curl", &["-s", "-X", "POST", "https://api.github.com"])
        );
    }

    #[test]
    fn test_from_command_custom() {
        assert_eq!(parse(&["jq", "."]), exec("jq", &["."]));
        assert_eq!(parse(&["node", "script.js"]), exec("node", &["script.js"]));
    }

    #[test]
    fn test_http_method_parsing() {
        assert_eq!("GET".parse::<HttpMethod>().unwrap(), HttpMethod::Get);
        assert_eq!("post".parse::<HttpMethod>().unwrap(), HttpMethod::Post);
        assert!("INVALID".parse::<HttpMethod>().is_err());
    }

    #[test]
    fn test_http_method_text_round_trip() {
        for (method, text) in [
            (HttpMethod::Get, "GET"),
            (HttpMethod::Post, "POST"),
            (HttpMethod::Put, "PUT"),
            (HttpMethod::Patch, "PATCH"),
            (HttpMethod::Delete, "DELETE"),
            (HttpMethod::Head, "HEAD"),
            (HttpMethod::Options, "OPTIONS"),
        ] {
            assert_eq!(method.as_str(), text);
            assert_eq!(method.to_string(), text);
            assert_eq!(text.parse::<HttpMethod>().unwrap(), method);
            assert_eq!(text.to_lowercase().parse::<HttpMethod>().unwrap(), method);
        }
    }

    #[test]
    fn test_http_method_default_is_get() {
        assert_eq!(HttpMethod::default(), HttpMethod::Get);
    }

    #[test]
    fn test_is_free_tier_allowed_curl_blocked() {
        assert!(!curl_get().is_free_tier_allowed());
    }

    #[test]
    fn test_is_free_tier_allowed_allowlisted_commands() {
        for cmd in ["cat", "ls", "grep", "sed", "awk", "jq", "head", "tail"] {
            assert!(
                exec(cmd, &[]).is_free_tier_allowed(),
                "{cmd} should be allowed"
            );
        }
    }

    #[test]
    fn test_is_free_tier_blocked_dangerous_commands() {
        for cmd in [
            "wget", "nc", "ssh", "node", "ruby", "curl", "apt", "pip", "npm",
        ] {
            assert!(
                !exec(cmd, &[]).is_free_tier_allowed(),
                "{cmd} should be blocked"
            );
        }
    }

    #[test]
    fn test_requires_egress() {
        assert!(curl_get().requires_egress());
        assert!(!exec("ls", &[]).requires_egress());
    }
}
313}