1use std::collections::HashSet;
4
5use async_trait::async_trait;
6
7use kaish_types::{ExecResult, ParamSchema, ToolArgs, ToolSchema, Value};
8
9use crate::ctx::ToolCtx;
10use crate::issue::{IssueCode, Severity, ValidationIssue};
11
12#[async_trait]
19pub trait Tool: Send + Sync {
20 fn name(&self) -> &str;
22
23 fn schema(&self) -> ToolSchema;
25
26 async fn execute(&self, args: ToolArgs, ctx: &mut dyn ToolCtx) -> ExecResult;
28
29 fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
34 validate_against_schema(args, &self.schema())
35 }
36}
37
38pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
51 let mut issues = Vec::new();
52
53 let positional_params: Vec<&ParamSchema> = schema.params.iter().filter(|p| p.positional).collect();
54 let flag_params: Vec<&ParamSchema> = schema.params.iter().filter(|p| !p.positional).collect();
55
56 for (slot, param) in positional_params.iter().enumerate() {
58 if !param.required {
59 continue;
60 }
61 let has_positional = args.positional.len() > slot;
62 let has_named = args.named.contains_key(¶m.name);
65 if !has_positional && !has_named {
66 let code = IssueCode::MissingRequiredArg;
67 issues.push(ValidationIssue {
68 severity: code.default_severity(),
69 code,
70 message: format!("required parameter '{}' not provided", param.name),
71 span: None,
72 suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
73 });
74 }
75 }
76
77 for param in &flag_params {
79 if !param.required {
80 continue;
81 }
82 let has_named = args.named.contains_key(¶m.name);
83 let has_flag = param.param_type == "bool" && args.has_flag(¶m.name);
84 if !has_named && !has_flag {
85 let code = IssueCode::MissingRequiredArg;
86 issues.push(ValidationIssue {
87 severity: code.default_severity(),
88 code,
89 message: format!("required parameter '{}' not provided", param.name),
90 span: None,
91 suggestion: Some(format!("add --{} <value>", param.name)),
92 });
93 }
94 }
95
96 let known_flags: HashSet<&str> = flag_params
100 .iter()
101 .filter(|p| p.param_type == "bool")
102 .flat_map(|p| {
103 std::iter::once(p.name.as_str())
104 .chain(p.aliases.iter().map(|a| a.as_str()))
105 })
106 .collect();
107
108 for flag in &args.flags {
109 let flag_name = flag.trim_start_matches('-');
111 if is_global_output_flag(flag_name) {
113 continue;
114 }
115 if !known_flags.contains(flag_name) && !known_flags.contains(flag.as_str()) {
116 let matches_alias = flag_params.iter().any(|p| p.matches_flag(flag));
118 if !matches_alias {
119 issues.push(ValidationIssue {
120 severity: Severity::Warning,
121 code: IssueCode::UnknownFlag,
122 message: format!("unknown flag '{}'", flag),
123 span: None,
124 suggestion: None,
125 });
126 }
127 }
128 }
129
130 for (key, value) in &args.named {
133 if let Some(param) = schema.params.iter().find(|p| &p.name == key)
134 && let Some(issue) = check_type_compatibility(key, value, ¶m.param_type) {
135 issues.push(issue);
136 }
137 }
138
139 for (slot, value) in args.positional.iter().enumerate() {
143 if let Some(param) = positional_params.get(slot)
144 && let Some(issue) = check_type_compatibility(¶m.name, value, ¶m.param_type) {
145 issues.push(issue);
146 }
147 }
148
149 issues
150}
151
/// Reports whether `name` is a global output flag that every tool accepts.
///
/// `name` is compared exactly, so callers should strip leading dashes first
/// (as `validate_against_schema` does). Flags listed here are never reported
/// as unknown. Currently only `json` qualifies.
pub fn is_global_output_flag(name: &str) -> bool {
    matches!(name, "json")
}
170
171fn check_type_compatibility(name: &str, value: &Value, expected_type: &str) -> Option<ValidationIssue> {
173 let compatible = match expected_type {
174 "any" => true,
175 "string" => true, "int" => matches!(value, Value::Int(_) | Value::String(_)),
177 "float" => matches!(value, Value::Float(_) | Value::Int(_) | Value::String(_)),
178 "bool" => matches!(value, Value::Bool(_) | Value::String(_)),
179 "array" => matches!(value, Value::String(_)), "object" => matches!(value, Value::String(_)), _ => true, };
183
184 if compatible {
185 None
186 } else {
187 let code = IssueCode::InvalidArgType;
188 Some(ValidationIssue {
189 severity: code.default_severity(),
190 code,
191 message: format!(
192 "argument '{}' has type {:?}, expected {}",
193 name, value, expected_type
194 ),
195 span: None,
196 suggestion: None,
197 })
198 }
199}
200
#[cfg(test)]
mod validate_tests {
    use super::*;
    use kaish_types::{ParamSchema, ToolSchema};

    /// Builds a schema whose only positional (`path`, required) is declared
    /// after two flag params. Its positional slot (0) therefore differs from
    /// its index in `params` (2).
    fn schema_with_positionals_after_flags() -> ToolSchema {
        let verbose = ParamSchema::new("verbose", "bool")
            .with_default(Some(Value::Bool(false)))
            .with_aliases(["v"]);
        let lines = ParamSchema::new("lines", "int").with_aliases(["n"]);
        let path = ParamSchema::new("path", "string")
            .with_required(true)
            .positional();

        ToolSchema::new("demo", "demo")
            .param(verbose)
            .param(lines)
            .param(path)
    }

    /// Returns true when any issue in `issues` carries `code`.
    fn has_code(issues: &[ValidationIssue], code: IssueCode) -> bool {
        issues.iter().any(|issue| issue.code == code)
    }

    /// A positional value at slot 0 satisfies the first positional param even
    /// when flags are declared before it.
    #[test]
    fn required_positional_satisfied_when_positional_sits_after_flags() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("foo.txt".into()));

        let found = validate_against_schema(&args, &schema);
        assert!(
            !has_code(&found, IssueCode::MissingRequiredArg),
            "required positional should be satisfied by positional[0]; got {:?}",
            found
        );
    }

    /// Supplying only a flag must not satisfy a required positional.
    #[test]
    fn required_positional_missing_when_no_positional_given() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.flags.insert("verbose".into());

        let found = validate_against_schema(&args, &schema);
        assert!(
            has_code(&found, IssueCode::MissingRequiredArg),
            "missing required positional should error; got {:?}",
            found
        );
    }

    /// `verbose` is params[0], so positional slot 0 must be type-checked
    /// against `count` (int) and slot 1 against `name` (string), not against
    /// params[0]/params[1].
    #[test]
    fn positional_type_check_targets_positional_slot_not_struct_index() {
        let verbose =
            ParamSchema::new("verbose", "bool").with_default(Some(Value::Bool(false)));
        let count = ParamSchema::new("count", "int")
            .with_required(true)
            .positional();
        let name = ParamSchema::new("name", "string")
            .with_required(true)
            .positional();
        let schema = ToolSchema::new("demo", "demo")
            .param(verbose)
            .param(count)
            .param(name);

        let mut args = ToolArgs::new();
        args.positional.push(Value::Int(5));
        args.positional.push(Value::String("widget".into()));

        let found = validate_against_schema(&args, &schema);
        assert!(
            !has_code(&found, IssueCode::InvalidArgType),
            "int->int and string->string slots should validate clean; got {:?}",
            found
        );
    }

    /// A required non-positional param with no value supplied is reported.
    #[test]
    fn required_flag_still_errors_when_missing() {
        let output = ParamSchema::new("output", "string")
            .with_required(true)
            .with_aliases(["o"]);
        let schema = ToolSchema::new("demo", "demo").param(output);

        let args = ToolArgs::new();
        let found = validate_against_schema(&args, &schema);
        assert!(
            has_code(&found, IssueCode::MissingRequiredArg),
            "required flag should error when missing; got {:?}",
            found
        );
    }
}