Skip to main content

statespace_tool_runtime/
spec.rs

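//! Tool specification parsing and validation.
//!
//! A tool specification is a list of parts matched positionally against a
//! command's arguments. A string part is a literal that the argument must equal
//! exactly; an object part is a placeholder that accepts any single argument,
//! optionally constrained by a regex. Placeholder regexes are matched with
//! search semantics (the match may occur anywhere in the argument), so anchor
//! them with `^`/`$` when the whole argument must match.
//!
//! ```yaml
//! tools:
//!   - [ls]                                 # Simple command, extra args allowed
//!   - [cat, { }]                           # Placeholder accepts any value
//!   - [cat, { regex: ".*\\.md$" }]         # Regex-constrained placeholder
//!   - [psql, -c, { regex: "^SELECT" }, ;]  # Trailing ; disables extra args
//! ```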
10
11use regex::Regex;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum ToolPart {
15    Literal(String),
16    Placeholder { regex: Option<CompiledRegex> },
17}
18
19#[derive(Debug, Clone)]
20pub struct CompiledRegex {
21    pub pattern: String,
22    pub regex: Regex,
23}
24
25impl PartialEq for CompiledRegex {
26    fn eq(&self, other: &Self) -> bool {
27        self.pattern == other.pattern
28    }
29}
30
31impl Eq for CompiledRegex {}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct ToolSpec {
35    pub parts: Vec<ToolPart>,
36    pub options_disabled: bool,
37}
38
39#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
40#[non_exhaustive]
41pub enum SpecError {
42    #[error("invalid regex pattern '{pattern}': {message}")]
43    InvalidRegex { pattern: String, message: String },
44    #[error("empty tool specification")]
45    EmptySpec,
46    #[error("invalid tool part: {0}")]
47    InvalidPart(String),
48}
49
50pub type SpecResult<T> = Result<T, SpecError>;
51
52impl ToolSpec {
53    /// # Errors
54    ///
55    /// Returns `SpecError` when the tool specification is empty or invalid.
56    pub fn parse(raw: &[serde_json::Value]) -> SpecResult<Self> {
57        if raw.is_empty() {
58            return Err(SpecError::EmptySpec);
59        }
60
61        let options_disabled = raw.last().is_some_and(|v| v.as_str() == Some(";"));
62
63        let parts = raw
64            .iter()
65            .filter(|v| v.as_str() != Some(";"))
66            .map(Self::parse_part)
67            .collect::<SpecResult<Vec<_>>>()?;
68
69        if parts.is_empty() {
70            return Err(SpecError::EmptySpec);
71        }
72
73        Ok(Self {
74            parts,
75            options_disabled,
76        })
77    }
78
79    fn parse_part(value: &serde_json::Value) -> SpecResult<ToolPart> {
80        match value {
81            serde_json::Value::String(s) => Ok(ToolPart::Literal(s.clone())),
82
83            serde_json::Value::Object(obj) => {
84                if obj.is_empty() {
85                    return Ok(ToolPart::Placeholder { regex: None });
86                }
87
88                if let Some(pattern) = obj.get("regex").and_then(|v| v.as_str()) {
89                    let regex = Regex::new(pattern).map_err(|e| SpecError::InvalidRegex {
90                        pattern: pattern.to_string(),
91                        message: e.to_string(),
92                    })?;
93                    return Ok(ToolPart::Placeholder {
94                        regex: Some(CompiledRegex {
95                            pattern: pattern.to_string(),
96                            regex,
97                        }),
98                    });
99                }
100
101                Err(SpecError::InvalidPart(format!(
102                    "unknown object keys: {:?}",
103                    obj.keys().collect::<Vec<_>>()
104                )))
105            }
106
107            _ => Err(SpecError::InvalidPart(format!(
108                "expected string or object, got: {value}"
109            ))),
110        }
111    }
112}
113
114#[must_use]
115pub fn is_valid_tool_call(command: &[String], specs: &[ToolSpec]) -> bool {
116    if command.is_empty() {
117        return false;
118    }
119    specs.iter().any(|spec| matches_spec(command, spec))
120}
121
122fn matches_spec(command: &[String], spec: &ToolSpec) -> bool {
123    if command.len() < spec.parts.len() {
124        return false;
125    }
126
127    if command.len() > spec.parts.len() && spec.options_disabled {
128        return false;
129    }
130
131    for (i, part) in spec.parts.iter().enumerate() {
132        let cmd_part = &command[i];
133
134        match part {
135            ToolPart::Literal(lit) => {
136                if cmd_part != lit {
137                    return false;
138                }
139            }
140            ToolPart::Placeholder { regex: None } => {}
141            ToolPart::Placeholder {
142                regex: Some(compiled),
143            } => {
144                if !compiled.regex.is_match(cmd_part) {
145                    return false;
146                }
147            }
148        }
149    }
150
151    true
152}
153
#[cfg(test)]
// Fixtures are built from known-good literals; a panic here is a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use serde_json::json;

    fn make_spec(parts: Vec<ToolPart>, options_disabled: bool) -> ToolSpec {
        ToolSpec {
            parts,
            options_disabled,
        }
    }

    fn lit(s: &str) -> ToolPart {
        ToolPart::Literal(s.to_string())
    }

    fn placeholder() -> ToolPart {
        ToolPart::Placeholder { regex: None }
    }

    fn regex_placeholder(pattern: &str) -> ToolPart {
        ToolPart::Placeholder {
            regex: Some(CompiledRegex {
                pattern: pattern.to_string(),
                regex: Regex::new(pattern).unwrap(),
            }),
        }
    }

    #[test]
    fn validate_simple_match() {
        let specs = vec![make_spec(vec![lit("ls")], false)];
        assert!(is_valid_tool_call(&["ls".to_string()], &specs));
    }

    #[test]
    fn validate_with_extra_args_allowed() {
        let specs = vec![make_spec(vec![lit("ls")], false)];
        assert!(is_valid_tool_call(
            &["ls".to_string(), "-la".to_string()],
            &specs
        ));
    }

    #[test]
    fn validate_with_extra_args_disabled() {
        let specs = vec![make_spec(vec![lit("ls")], true)];
        assert!(!is_valid_tool_call(
            &["ls".to_string(), "-la".to_string()],
            &specs
        ));
    }

    #[test]
    fn validate_placeholder_matches_any() {
        let specs = vec![make_spec(vec![lit("cat"), placeholder()], false)];

        assert!(is_valid_tool_call(
            &["cat".to_string(), "file.txt".to_string()],
            &specs
        ));
        assert!(is_valid_tool_call(
            &["cat".to_string(), "anything".to_string()],
            &specs
        ));
    }

    #[test]
    fn validate_regex_placeholder() {
        let specs = vec![make_spec(
            vec![lit("cat"), regex_placeholder(r".*\.md$")],
            false,
        )];

        assert!(is_valid_tool_call(
            &["cat".to_string(), "README.md".to_string()],
            &specs
        ));
        assert!(!is_valid_tool_call(
            &["cat".to_string(), "README.txt".to_string()],
            &specs
        ));
    }

    #[test]
    fn validate_regex_with_options_disabled() {
        let specs = vec![make_spec(
            vec![lit("cat"), regex_placeholder(r".*\.md$")],
            true,
        )];

        assert!(is_valid_tool_call(
            &["cat".to_string(), "file.md".to_string()],
            &specs
        ));

        assert!(!is_valid_tool_call(
            &["cat".to_string(), "file.md".to_string(), "-n".to_string()],
            &specs
        ));

        assert!(!is_valid_tool_call(
            &["cat".to_string(), "file.txt".to_string()],
            &specs
        ));
    }

    #[test]
    fn validate_complex_psql_spec() {
        let specs = vec![make_spec(
            vec![lit("psql"), lit("-c"), regex_placeholder("^SELECT")],
            true,
        )];

        assert!(is_valid_tool_call(
            &[
                "psql".to_string(),
                "-c".to_string(),
                "SELECT * FROM users".to_string()
            ],
            &specs
        ));

        assert!(!is_valid_tool_call(
            &[
                "psql".to_string(),
                "-c".to_string(),
                "INSERT INTO users VALUES (1)".to_string()
            ],
            &specs
        ));

        assert!(!is_valid_tool_call(
            &[
                "psql".to_string(),
                "-c".to_string(),
                "SELECT 1".to_string(),
                "--extra".to_string()
            ],
            &specs
        ));
    }

    #[test]
    fn validate_empty_command() {
        let specs = vec![make_spec(vec![lit("ls")], false)];
        assert!(!is_valid_tool_call(&[], &specs));
    }

    #[test]
    fn validate_placeholder_is_required() {
        let specs = vec![make_spec(vec![lit("ls"), placeholder()], false)];

        assert!(!is_valid_tool_call(&["ls".into()], &specs));
        assert!(is_valid_tool_call(&["ls".into(), "dir".into()], &specs));
        assert!(is_valid_tool_call(
            &["ls".into(), "dir".into(), "-la".into()],
            &specs
        ));
    }

    #[test]
    fn validate_multiple_specs() {
        let specs = vec![
            make_spec(vec![lit("ls")], false),
            make_spec(vec![lit("cat"), placeholder()], false),
        ];

        assert!(is_valid_tool_call(&["ls".to_string()], &specs));
        assert!(is_valid_tool_call(
            &["cat".to_string(), "file.txt".to_string()],
            &specs
        ));
        assert!(!is_valid_tool_call(&["rm".to_string()], &specs));
    }

    #[test]
    fn parse_literals_and_placeholders() {
        let spec = ToolSpec::parse(&[
            json!("cat"),
            json!({}),
            json!({ "regex": ".*\\.md$" }),
        ])
        .unwrap();

        assert_eq!(
            spec,
            make_spec(
                vec![lit("cat"), placeholder(), regex_placeholder(r".*\.md$")],
                false
            )
        );
    }

    #[test]
    fn parse_trailing_semicolon_disables_extra_args() {
        let spec = ToolSpec::parse(&[
            json!("psql"),
            json!("-c"),
            json!({ "regex": "^SELECT" }),
            json!(";"),
        ])
        .unwrap();

        assert_eq!(
            spec,
            make_spec(
                vec![lit("psql"), lit("-c"), regex_placeholder("^SELECT")],
                true
            )
        );
    }

    #[test]
    fn parse_without_semicolon_allows_extra_args() {
        let spec = ToolSpec::parse(&[json!("ls")]).unwrap();
        assert_eq!(spec, make_spec(vec![lit("ls")], false));
    }

    #[test]
    fn parse_rejects_empty_spec() {
        assert_eq!(ToolSpec::parse(&[]), Err(SpecError::EmptySpec));
        assert_eq!(ToolSpec::parse(&[json!(";")]), Err(SpecError::EmptySpec));
    }

    #[test]
    fn parse_rejects_invalid_regex() {
        let err = ToolSpec::parse(&[json!("cat"), json!({ "regex": "(" })]).unwrap_err();
        assert!(matches!(
            &err,
            SpecError::InvalidRegex { pattern, .. } if pattern == "("
        ));
    }

    #[test]
    fn parse_rejects_unsupported_parts() {
        assert!(matches!(
            ToolSpec::parse(&[json!(42)]),
            Err(SpecError::InvalidPart(_))
        ));
        assert!(matches!(
            ToolSpec::parse(&[json!("cat"), json!({ "glob": "*.md" })]),
            Err(SpecError::InvalidPart(_))
        ));
        assert!(matches!(
            ToolSpec::parse(&[json!("cat"), json!({ "regex": 5 })]),
            Err(SpecError::InvalidPart(_))
        ));
    }
}