1use regex::Regex;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum ToolPart {
15 Literal(String),
16 Placeholder { regex: Option<CompiledRegex> },
17}
18
19#[derive(Debug, Clone)]
20pub struct CompiledRegex {
21 pub pattern: String,
22 pub regex: Regex,
23}
24
25impl PartialEq for CompiledRegex {
26 fn eq(&self, other: &Self) -> bool {
27 self.pattern == other.pattern
28 }
29}
30
31impl Eq for CompiledRegex {}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct ToolSpec {
35 pub parts: Vec<ToolPart>,
36 pub options_disabled: bool,
37}
38
39#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
40#[non_exhaustive]
41pub enum SpecError {
42 #[error("invalid regex pattern '{pattern}': {message}")]
43 InvalidRegex { pattern: String, message: String },
44 #[error("empty tool specification")]
45 EmptySpec,
46 #[error("invalid tool part: {0}")]
47 InvalidPart(String),
48}
49
50pub type SpecResult<T> = Result<T, SpecError>;
51
52impl ToolSpec {
53 pub fn parse(raw: &[serde_json::Value]) -> SpecResult<Self> {
57 if raw.is_empty() {
58 return Err(SpecError::EmptySpec);
59 }
60
61 let options_disabled = raw.last().is_some_and(|v| v.as_str() == Some(";"));
62
63 let parts = raw
64 .iter()
65 .filter(|v| v.as_str() != Some(";"))
66 .map(Self::parse_part)
67 .collect::<SpecResult<Vec<_>>>()?;
68
69 if parts.is_empty() {
70 return Err(SpecError::EmptySpec);
71 }
72
73 Ok(Self {
74 parts,
75 options_disabled,
76 })
77 }
78
79 fn parse_part(value: &serde_json::Value) -> SpecResult<ToolPart> {
80 match value {
81 serde_json::Value::String(s) => Ok(ToolPart::Literal(s.clone())),
82
83 serde_json::Value::Object(obj) => {
84 if obj.is_empty() {
85 return Ok(ToolPart::Placeholder { regex: None });
86 }
87
88 if let Some(pattern) = obj.get("regex").and_then(|v| v.as_str()) {
89 let regex = Regex::new(pattern).map_err(|e| SpecError::InvalidRegex {
90 pattern: pattern.to_string(),
91 message: e.to_string(),
92 })?;
93 return Ok(ToolPart::Placeholder {
94 regex: Some(CompiledRegex {
95 pattern: pattern.to_string(),
96 regex,
97 }),
98 });
99 }
100
101 Err(SpecError::InvalidPart(format!(
102 "unknown object keys: {:?}",
103 obj.keys().collect::<Vec<_>>()
104 )))
105 }
106
107 _ => Err(SpecError::InvalidPart(format!(
108 "expected string or object, got: {value}"
109 ))),
110 }
111 }
112}
113
114#[must_use]
115pub fn is_valid_tool_call(command: &[String], specs: &[ToolSpec]) -> bool {
116 if command.is_empty() {
117 return false;
118 }
119 specs.iter().any(|spec| matches_spec(command, spec))
120}
121
122fn matches_spec(command: &[String], spec: &ToolSpec) -> bool {
123 if command.len() < spec.parts.len() {
124 return false;
125 }
126
127 if command.len() > spec.parts.len() && spec.options_disabled {
128 return false;
129 }
130
131 for (i, part) in spec.parts.iter().enumerate() {
132 let cmd_part = &command[i];
133
134 match part {
135 ToolPart::Literal(lit) => {
136 if cmd_part != lit {
137 return false;
138 }
139 }
140 ToolPart::Placeholder { regex: None } => {}
141 ToolPart::Placeholder {
142 regex: Some(compiled),
143 } => {
144 if !compiled.regex.is_match(cmd_part) {
145 return false;
146 }
147 }
148 }
149 }
150
151 true
152}
153
#[cfg(test)]
// Tests are expected to panic on failure, so `unwrap` is the clearest way to
// assert that a fallible call succeeded.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a spec directly, bypassing `ToolSpec::parse`.
    fn make_spec(parts: Vec<ToolPart>, options_disabled: bool) -> ToolSpec {
        ToolSpec {
            parts,
            options_disabled,
        }
    }

    fn lit(s: &str) -> ToolPart {
        ToolPart::Literal(s.to_string())
    }

    fn placeholder() -> ToolPart {
        ToolPart::Placeholder { regex: None }
    }

    fn regex_placeholder(pattern: &str) -> ToolPart {
        ToolPart::Placeholder {
            regex: Some(CompiledRegex {
                pattern: pattern.to_string(),
                regex: Regex::new(pattern).unwrap(),
            }),
        }
    }

    /// Converts borrowed arguments into the owned form `is_valid_tool_call` takes.
    fn cmd(args: &[&str]) -> Vec<String> {
        args.iter().copied().map(String::from).collect()
    }

    // ---- matching -------------------------------------------------------

    #[test]
    fn validate_simple_match() {
        let specs = vec![make_spec(vec![lit("ls")], false)];
        assert!(is_valid_tool_call(&cmd(&["ls"]), &specs));
    }

    #[test]
    fn validate_with_extra_args_allowed() {
        let specs = vec![make_spec(vec![lit("ls")], false)];
        assert!(is_valid_tool_call(&cmd(&["ls", "-la"]), &specs));
    }

    #[test]
    fn validate_with_extra_args_disabled() {
        let specs = vec![make_spec(vec![lit("ls")], true)];
        assert!(!is_valid_tool_call(&cmd(&["ls", "-la"]), &specs));
    }

    #[test]
    fn validate_placeholder_matches_any() {
        let specs = vec![make_spec(vec![lit("cat"), placeholder()], false)];

        assert!(is_valid_tool_call(&cmd(&["cat", "file.txt"]), &specs));
        assert!(is_valid_tool_call(&cmd(&["cat", "anything"]), &specs));
    }

    #[test]
    fn validate_regex_placeholder() {
        let specs = vec![make_spec(
            vec![lit("cat"), regex_placeholder(r".*\.md$")],
            false,
        )];

        assert!(is_valid_tool_call(&cmd(&["cat", "README.md"]), &specs));
        assert!(!is_valid_tool_call(&cmd(&["cat", "README.txt"]), &specs));
    }

    #[test]
    fn validate_regex_with_options_disabled() {
        let specs = vec![make_spec(
            vec![lit("cat"), regex_placeholder(r".*\.md$")],
            true,
        )];

        assert!(is_valid_tool_call(&cmd(&["cat", "file.md"]), &specs));
        assert!(!is_valid_tool_call(&cmd(&["cat", "file.md", "-n"]), &specs));
        assert!(!is_valid_tool_call(&cmd(&["cat", "file.txt"]), &specs));
    }

    #[test]
    fn validate_complex_psql_spec() {
        let specs = vec![make_spec(
            vec![lit("psql"), lit("-c"), regex_placeholder("^SELECT")],
            true,
        )];

        assert!(is_valid_tool_call(
            &cmd(&["psql", "-c", "SELECT * FROM users"]),
            &specs
        ));
        assert!(!is_valid_tool_call(
            &cmd(&["psql", "-c", "INSERT INTO users VALUES (1)"]),
            &specs
        ));
        assert!(!is_valid_tool_call(
            &cmd(&["psql", "-c", "SELECT 1", "--extra"]),
            &specs
        ));
    }

    #[test]
    fn validate_empty_command() {
        let specs = vec![make_spec(vec![lit("ls")], false)];
        assert!(!is_valid_tool_call(&[], &specs));
    }

    #[test]
    fn validate_placeholder_is_required() {
        let specs = vec![make_spec(vec![lit("ls"), placeholder()], false)];

        assert!(!is_valid_tool_call(&cmd(&["ls"]), &specs));
        assert!(is_valid_tool_call(&cmd(&["ls", "dir"]), &specs));
        assert!(is_valid_tool_call(&cmd(&["ls", "dir", "-la"]), &specs));
    }

    #[test]
    fn validate_multiple_specs() {
        let specs = vec![
            make_spec(vec![lit("ls")], false),
            make_spec(vec![lit("cat"), placeholder()], false),
        ];

        assert!(is_valid_tool_call(&cmd(&["ls"]), &specs));
        assert!(is_valid_tool_call(&cmd(&["cat", "file.txt"]), &specs));
        assert!(!is_valid_tool_call(&cmd(&["rm"]), &specs));
    }

    // ---- parsing --------------------------------------------------------

    #[test]
    fn parse_rejects_empty_input() {
        assert_eq!(ToolSpec::parse(&[]), Err(SpecError::EmptySpec));
    }

    #[test]
    fn parse_rejects_lone_terminator() {
        assert_eq!(ToolSpec::parse(&[json!(";")]), Err(SpecError::EmptySpec));
    }

    #[test]
    fn parse_literals_and_placeholders() {
        let spec =
            ToolSpec::parse(&[json!("cat"), json!({}), json!({ "regex": r".*\.md$" })])
                .unwrap();
        assert_eq!(
            spec,
            make_spec(
                vec![lit("cat"), placeholder(), regex_placeholder(r".*\.md$")],
                false
            )
        );
    }

    #[test]
    fn parse_trailing_terminator_disables_options() {
        let spec = ToolSpec::parse(&[json!("ls"), json!(";")]).unwrap();
        assert_eq!(spec, make_spec(vec![lit("ls")], true));
    }

    #[test]
    fn parse_reports_invalid_regex() {
        let err = ToolSpec::parse(&[json!({ "regex": "(" })]).unwrap_err();
        assert!(matches!(&err, SpecError::InvalidRegex { pattern, .. } if pattern == "("));
    }

    #[test]
    fn parse_rejects_unsupported_parts() {
        let bad_parts = [
            json!(42),
            json!(null),
            json!(["nested"]),
            json!({ "other": "x" }),
            json!({ "regex": 5 }),
        ];
        for bad in &bad_parts {
            assert!(
                matches!(
                    ToolSpec::parse(std::slice::from_ref(bad)),
                    Err(SpecError::InvalidPart(_))
                ),
                "expected InvalidPart for {bad}"
            );
        }
    }
}