1use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5
6use crate::ast::Value;
7use crate::interpreter::{ExecResult, OutputFormat};
8use crate::validator::{IssueCode, Severity, ValidationIssue};
9
10use super::context::ExecContext;
11
12#[derive(Debug, Clone)]
14pub struct ParamSchema {
15 pub name: String,
17 pub param_type: String,
19 pub required: bool,
21 pub default: Option<Value>,
23 pub description: String,
25 pub aliases: Vec<String>,
27}
28
29impl ParamSchema {
30 pub fn required(name: impl Into<String>, param_type: impl Into<String>, description: impl Into<String>) -> Self {
32 Self {
33 name: name.into(),
34 param_type: param_type.into(),
35 required: true,
36 default: None,
37 description: description.into(),
38 aliases: Vec::new(),
39 }
40 }
41
42 pub fn optional(name: impl Into<String>, param_type: impl Into<String>, default: Value, description: impl Into<String>) -> Self {
44 Self {
45 name: name.into(),
46 param_type: param_type.into(),
47 required: false,
48 default: Some(default),
49 description: description.into(),
50 aliases: Vec::new(),
51 }
52 }
53
54 pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
58 self.aliases = aliases.into_iter().map(Into::into).collect();
59 self
60 }
61
62 pub fn matches_flag(&self, flag: &str) -> bool {
64 if self.name == flag {
65 return true;
66 }
67 self.aliases.iter().any(|a| a == flag)
68 }
69}
70
/// A documented usage example for a tool.
#[derive(Debug, Clone)]
pub struct Example {
    /// What the example demonstrates.
    pub description: String,
    /// The example invocation itself.
    pub code: String,
}

impl Example {
    /// Builds an example from a description and its code snippet.
    pub fn new(description: impl Into<String>, code: impl Into<String>) -> Self {
        let description: String = description.into();
        let code: String = code.into();
        Example { description, code }
    }
}
89
90#[derive(Debug, Clone)]
92pub struct ToolSchema {
93 pub name: String,
95 pub description: String,
97 pub params: Vec<ParamSchema>,
99 pub examples: Vec<Example>,
101 pub map_positionals: bool,
105}
106
107impl ToolSchema {
108 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
110 Self {
111 name: name.into(),
112 description: description.into(),
113 params: Vec::new(),
114 examples: Vec::new(),
115 map_positionals: false,
116 }
117 }
118
119 pub fn with_positional_mapping(mut self) -> Self {
121 self.map_positionals = true;
122 self
123 }
124
125 pub fn param(mut self, param: ParamSchema) -> Self {
127 self.params.push(param);
128 self
129 }
130
131 pub fn example(mut self, description: impl Into<String>, code: impl Into<String>) -> Self {
133 self.examples.push(Example::new(description, code));
134 self
135 }
136}
137
138#[derive(Debug, Clone, Default)]
140pub struct ToolArgs {
141 pub positional: Vec<Value>,
143 pub named: HashMap<String, Value>,
145 pub flags: HashSet<String>,
147}
148
149impl ToolArgs {
150 pub fn new() -> Self {
152 Self::default()
153 }
154
155 pub fn get_positional(&self, index: usize) -> Option<&Value> {
157 self.positional.get(index)
158 }
159
160 pub fn get_named(&self, key: &str) -> Option<&Value> {
162 self.named.get(key)
163 }
164
165 pub fn get(&self, name: &str, positional_index: usize) -> Option<&Value> {
169 self.named.get(name).or_else(|| self.positional.get(positional_index))
170 }
171
172 pub fn get_string(&self, name: &str, positional_index: usize) -> Option<String> {
174 self.get(name, positional_index).and_then(|v| match v {
175 Value::String(s) => Some(s.clone()),
176 Value::Int(i) => Some(i.to_string()),
177 Value::Float(f) => Some(f.to_string()),
178 Value::Bool(b) => Some(b.to_string()),
179 _ => None,
180 })
181 }
182
183 pub fn get_bool(&self, name: &str, positional_index: usize) -> Option<bool> {
185 self.get(name, positional_index).and_then(|v| match v {
186 Value::Bool(b) => Some(*b),
187 Value::String(s) => match s.as_str() {
188 "true" | "yes" | "1" => Some(true),
189 "false" | "no" | "0" => Some(false),
190 _ => None,
191 },
192 Value::Int(i) => Some(*i != 0),
193 _ => None,
194 })
195 }
196
197 pub fn has_flag(&self, name: &str) -> bool {
199 if self.flags.contains(name) {
201 return true;
202 }
203 self.named.get(name).is_some_and(|v| match v {
205 Value::Bool(b) => *b,
206 Value::String(s) => !s.is_empty() && s != "false" && s != "0",
207 _ => true,
208 })
209 }
210}
211
212#[async_trait]
214pub trait Tool: Send + Sync {
215 fn name(&self) -> &str;
217
218 fn schema(&self) -> ToolSchema;
220
221 async fn execute(&self, args: ToolArgs, ctx: &mut ExecContext) -> ExecResult;
223
224 fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
229 validate_against_schema(args, &self.schema())
230 }
231}
232
233pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
240 let mut issues = Vec::new();
241
242 for (i, param) in schema.params.iter().enumerate() {
244 if !param.required {
245 continue;
246 }
247
248 let has_named = args.named.contains_key(¶m.name);
250 let has_positional = args.positional.len() > i;
251 let has_flag = param.param_type == "bool" && args.has_flag(¶m.name);
252
253 if !has_named && !has_positional && !has_flag {
254 let code = IssueCode::MissingRequiredArg;
255 issues.push(ValidationIssue {
256 severity: code.default_severity(),
257 code,
258 message: format!("required parameter '{}' not provided", param.name),
259 span: None,
260 suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
261 });
262 }
263 }
264
265 let known_flags: HashSet<&str> = schema
267 .params
268 .iter()
269 .filter(|p| p.param_type == "bool")
270 .flat_map(|p| {
271 std::iter::once(p.name.as_str())
272 .chain(p.aliases.iter().map(|a| a.as_str()))
273 })
274 .collect();
275
276 for flag in &args.flags {
277 let flag_name = flag.trim_start_matches('-');
279 if is_global_output_flag(flag_name) {
281 continue;
282 }
283 if !known_flags.contains(flag_name) && !known_flags.contains(flag.as_str()) {
284 let matches_alias = schema.params.iter().any(|p| p.matches_flag(flag));
286 if !matches_alias {
287 issues.push(ValidationIssue {
288 severity: Severity::Warning,
289 code: IssueCode::UnknownFlag,
290 message: format!("unknown flag '{}'", flag),
291 span: None,
292 suggestion: None,
293 });
294 }
295 }
296 }
297
298 for (key, value) in &args.named {
300 if let Some(param) = schema.params.iter().find(|p| &p.name == key)
301 && let Some(issue) = check_type_compatibility(key, value, ¶m.param_type) {
302 issues.push(issue);
303 }
304 }
305
306 for (i, value) in args.positional.iter().enumerate() {
308 if let Some(param) = schema.params.get(i)
309 && let Some(issue) = check_type_compatibility(¶m.name, value, ¶m.param_type) {
310 issues.push(issue);
311 }
312 }
313
314 issues
315}
316
317const GLOBAL_OUTPUT_FLAGS: &[(&str, OutputFormat)] = &[
323 ("json", OutputFormat::Json),
324];
325
326pub fn is_global_output_flag(name: &str) -> bool {
328 GLOBAL_OUTPUT_FLAGS.iter().any(|(n, _)| *n == name)
329}
330
331pub fn extract_output_format(
337 args: &mut ToolArgs,
338 schema: Option<&ToolSchema>,
339) -> Option<OutputFormat> {
340 let _schema = schema?;
342
343 for (flag_name, format) in GLOBAL_OUTPUT_FLAGS {
344 if args.flags.remove(*flag_name) {
345 return Some(*format);
346 }
347 }
348 None
349}
350
351fn check_type_compatibility(name: &str, value: &Value, expected_type: &str) -> Option<ValidationIssue> {
353 let compatible = match expected_type {
354 "any" => true,
355 "string" => true, "int" => matches!(value, Value::Int(_) | Value::String(_)),
357 "float" => matches!(value, Value::Float(_) | Value::Int(_) | Value::String(_)),
358 "bool" => matches!(value, Value::Bool(_) | Value::String(_)),
359 "array" => matches!(value, Value::String(_)), "object" => matches!(value, Value::String(_)), _ => true, };
363
364 if compatible {
365 None
366 } else {
367 let code = IssueCode::InvalidArgType;
368 Some(ValidationIssue {
369 severity: code.default_severity(),
370 code,
371 message: format!(
372 "argument '{}' has type {:?}, expected {}",
373 name, value, expected_type
374 ),
375 span: None,
376 suggestion: None,
377 })
378 }
379}