1use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5
6use crate::ast::Value;
7use crate::interpreter::{ExecResult, OutputFormat};
8use crate::validator::{IssueCode, Severity, ValidationIssue};
9
10use super::context::ExecContext;
11
12#[derive(Debug, Clone)]
14pub struct ParamSchema {
15 pub name: String,
17 pub param_type: String,
19 pub required: bool,
21 pub default: Option<Value>,
23 pub description: String,
25 pub aliases: Vec<String>,
27}
28
29impl ParamSchema {
30 pub fn required(name: impl Into<String>, param_type: impl Into<String>, description: impl Into<String>) -> Self {
32 Self {
33 name: name.into(),
34 param_type: param_type.into(),
35 required: true,
36 default: None,
37 description: description.into(),
38 aliases: Vec::new(),
39 }
40 }
41
42 pub fn optional(name: impl Into<String>, param_type: impl Into<String>, default: Value, description: impl Into<String>) -> Self {
44 Self {
45 name: name.into(),
46 param_type: param_type.into(),
47 required: false,
48 default: Some(default),
49 description: description.into(),
50 aliases: Vec::new(),
51 }
52 }
53
54 pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
58 self.aliases = aliases.into_iter().map(Into::into).collect();
59 self
60 }
61
62 pub fn matches_flag(&self, flag: &str) -> bool {
64 if self.name == flag {
65 return true;
66 }
67 self.aliases.iter().any(|a| a == flag)
68 }
69}
70
/// A usage example attached to a tool schema: a short description plus code.
#[derive(Debug, Clone)]
pub struct Example {
    /// What the example demonstrates.
    pub description: String,
    /// The example invocation itself.
    pub code: String,
}

impl Example {
    /// Creates an example from any string-like description and code.
    pub fn new(description: impl Into<String>, code: impl Into<String>) -> Self {
        let description = description.into();
        let code = code.into();
        Example { description, code }
    }
}
89
90#[derive(Debug, Clone)]
92pub struct ToolSchema {
93 pub name: String,
95 pub description: String,
97 pub params: Vec<ParamSchema>,
99 pub examples: Vec<Example>,
101}
102
103impl ToolSchema {
104 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
106 Self {
107 name: name.into(),
108 description: description.into(),
109 params: Vec::new(),
110 examples: Vec::new(),
111 }
112 }
113
114 pub fn param(mut self, param: ParamSchema) -> Self {
116 self.params.push(param);
117 self
118 }
119
120 pub fn example(mut self, description: impl Into<String>, code: impl Into<String>) -> Self {
122 self.examples.push(Example::new(description, code));
123 self
124 }
125}
126
127#[derive(Debug, Clone, Default)]
129pub struct ToolArgs {
130 pub positional: Vec<Value>,
132 pub named: HashMap<String, Value>,
134 pub flags: HashSet<String>,
136}
137
138impl ToolArgs {
139 pub fn new() -> Self {
141 Self::default()
142 }
143
144 pub fn get_positional(&self, index: usize) -> Option<&Value> {
146 self.positional.get(index)
147 }
148
149 pub fn get_named(&self, key: &str) -> Option<&Value> {
151 self.named.get(key)
152 }
153
154 pub fn get(&self, name: &str, positional_index: usize) -> Option<&Value> {
158 self.named.get(name).or_else(|| self.positional.get(positional_index))
159 }
160
161 pub fn get_string(&self, name: &str, positional_index: usize) -> Option<String> {
163 self.get(name, positional_index).and_then(|v| match v {
164 Value::String(s) => Some(s.clone()),
165 Value::Int(i) => Some(i.to_string()),
166 Value::Float(f) => Some(f.to_string()),
167 Value::Bool(b) => Some(b.to_string()),
168 _ => None,
169 })
170 }
171
172 pub fn get_bool(&self, name: &str, positional_index: usize) -> Option<bool> {
174 self.get(name, positional_index).and_then(|v| match v {
175 Value::Bool(b) => Some(*b),
176 Value::String(s) => match s.as_str() {
177 "true" | "yes" | "1" => Some(true),
178 "false" | "no" | "0" => Some(false),
179 _ => None,
180 },
181 Value::Int(i) => Some(*i != 0),
182 _ => None,
183 })
184 }
185
186 pub fn has_flag(&self, name: &str) -> bool {
188 if self.flags.contains(name) {
190 return true;
191 }
192 self.named.get(name).is_some_and(|v| match v {
194 Value::Bool(b) => *b,
195 Value::String(s) => !s.is_empty() && s != "false" && s != "0",
196 _ => true,
197 })
198 }
199}
200
201#[async_trait]
203pub trait Tool: Send + Sync {
204 fn name(&self) -> &str;
206
207 fn schema(&self) -> ToolSchema;
209
210 async fn execute(&self, args: ToolArgs, ctx: &mut ExecContext) -> ExecResult;
212
213 fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
218 validate_against_schema(args, &self.schema())
219 }
220}
221
222pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
229 let mut issues = Vec::new();
230
231 for (i, param) in schema.params.iter().enumerate() {
233 if !param.required {
234 continue;
235 }
236
237 let has_named = args.named.contains_key(¶m.name);
239 let has_positional = args.positional.len() > i;
240 let has_flag = param.param_type == "bool" && args.has_flag(¶m.name);
241
242 if !has_named && !has_positional && !has_flag {
243 let code = IssueCode::MissingRequiredArg;
244 issues.push(ValidationIssue {
245 severity: code.default_severity(),
246 code,
247 message: format!("required parameter '{}' not provided", param.name),
248 span: None,
249 suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
250 });
251 }
252 }
253
254 let known_flags: HashSet<&str> = schema
256 .params
257 .iter()
258 .filter(|p| p.param_type == "bool")
259 .flat_map(|p| {
260 std::iter::once(p.name.as_str())
261 .chain(p.aliases.iter().map(|a| a.as_str()))
262 })
263 .collect();
264
265 for flag in &args.flags {
266 let flag_name = flag.trim_start_matches('-');
268 if is_global_output_flag(flag_name) {
270 continue;
271 }
272 if !known_flags.contains(flag_name) && !known_flags.contains(flag.as_str()) {
273 let matches_alias = schema.params.iter().any(|p| p.matches_flag(flag));
275 if !matches_alias {
276 issues.push(ValidationIssue {
277 severity: Severity::Warning,
278 code: IssueCode::UnknownFlag,
279 message: format!("unknown flag '{}'", flag),
280 span: None,
281 suggestion: None,
282 });
283 }
284 }
285 }
286
287 for (key, value) in &args.named {
289 if let Some(param) = schema.params.iter().find(|p| &p.name == key)
290 && let Some(issue) = check_type_compatibility(key, value, ¶m.param_type) {
291 issues.push(issue);
292 }
293 }
294
295 for (i, value) in args.positional.iter().enumerate() {
297 if let Some(param) = schema.params.get(i)
298 && let Some(issue) = check_type_compatibility(¶m.name, value, ¶m.param_type) {
299 issues.push(issue);
300 }
301 }
302
303 issues
304}
305
306const GLOBAL_OUTPUT_FLAGS: &[(&str, OutputFormat)] = &[
312 ("json", OutputFormat::Json),
313];
314
315pub fn is_global_output_flag(name: &str) -> bool {
317 GLOBAL_OUTPUT_FLAGS.iter().any(|(n, _)| *n == name)
318}
319
320pub fn extract_output_format(
326 args: &mut ToolArgs,
327 schema: Option<&ToolSchema>,
328) -> Option<OutputFormat> {
329 let _schema = schema?;
331
332 for (flag_name, format) in GLOBAL_OUTPUT_FLAGS {
333 if args.flags.remove(*flag_name) {
334 return Some(*format);
335 }
336 }
337 None
338}
339
340fn check_type_compatibility(name: &str, value: &Value, expected_type: &str) -> Option<ValidationIssue> {
342 let compatible = match expected_type {
343 "any" => true,
344 "string" => true, "int" => matches!(value, Value::Int(_) | Value::String(_)),
346 "float" => matches!(value, Value::Float(_) | Value::Int(_) | Value::String(_)),
347 "bool" => matches!(value, Value::Bool(_) | Value::String(_)),
348 "array" => matches!(value, Value::String(_)), "object" => matches!(value, Value::String(_)), _ => true, };
352
353 if compatible {
354 None
355 } else {
356 let code = IssueCode::InvalidArgType;
357 Some(ValidationIssue {
358 severity: code.default_severity(),
359 code,
360 message: format!(
361 "argument '{}' has type {:?}, expected {}",
362 name, value, expected_type
363 ),
364 span: None,
365 suggestion: None,
366 })
367 }
368}