kaish_kernel/tools/
traits.rs1use async_trait::async_trait;
4use std::collections::HashSet;
5
6use crate::interpreter::{ExecResult, OutputFormat};
7use crate::validator::{IssueCode, Severity, ValidationIssue};
8
9pub use kaish_types::{ParamSchema, ToolArgs, ToolSchema};
11
12use super::context::ExecContext;
13use crate::ast::Value;
14
15#[async_trait]
17pub trait Tool: Send + Sync {
18 fn name(&self) -> &str;
20
21 fn schema(&self) -> ToolSchema;
23
24 async fn execute(&self, args: ToolArgs, ctx: &mut ExecContext) -> ExecResult;
26
27 fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
32 validate_against_schema(args, &self.schema())
33 }
34}
35
36pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
43 let mut issues = Vec::new();
44
45 for (i, param) in schema.params.iter().enumerate() {
47 if !param.required {
48 continue;
49 }
50
51 let has_named = args.named.contains_key(¶m.name);
53 let has_positional = args.positional.len() > i;
54 let has_flag = param.param_type == "bool" && args.has_flag(¶m.name);
55
56 if !has_named && !has_positional && !has_flag {
57 let code = IssueCode::MissingRequiredArg;
58 issues.push(ValidationIssue {
59 severity: code.default_severity(),
60 code,
61 message: format!("required parameter '{}' not provided", param.name),
62 span: None,
63 suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
64 });
65 }
66 }
67
68 let known_flags: HashSet<&str> = schema
70 .params
71 .iter()
72 .filter(|p| p.param_type == "bool")
73 .flat_map(|p| {
74 std::iter::once(p.name.as_str())
75 .chain(p.aliases.iter().map(|a| a.as_str()))
76 })
77 .collect();
78
79 for flag in &args.flags {
80 let flag_name = flag.trim_start_matches('-');
82 if is_global_output_flag(flag_name) {
84 continue;
85 }
86 if !known_flags.contains(flag_name) && !known_flags.contains(flag.as_str()) {
87 let matches_alias = schema.params.iter().any(|p| p.matches_flag(flag));
89 if !matches_alias {
90 issues.push(ValidationIssue {
91 severity: Severity::Warning,
92 code: IssueCode::UnknownFlag,
93 message: format!("unknown flag '{}'", flag),
94 span: None,
95 suggestion: None,
96 });
97 }
98 }
99 }
100
101 for (key, value) in &args.named {
103 if let Some(param) = schema.params.iter().find(|p| &p.name == key)
104 && let Some(issue) = check_type_compatibility(key, value, ¶m.param_type) {
105 issues.push(issue);
106 }
107 }
108
109 for (i, value) in args.positional.iter().enumerate() {
111 if let Some(param) = schema.params.get(i)
112 && let Some(issue) = check_type_compatibility(¶m.name, value, ¶m.param_type) {
113 issues.push(issue);
114 }
115 }
116
117 issues
118}
119
120const GLOBAL_OUTPUT_FLAGS: &[(&str, OutputFormat)] = &[
126 ("json", OutputFormat::Json),
127];
128
129pub fn is_global_output_flag(name: &str) -> bool {
131 GLOBAL_OUTPUT_FLAGS.iter().any(|(n, _)| *n == name)
132}
133
134pub fn extract_output_format(
140 args: &mut ToolArgs,
141 schema: Option<&ToolSchema>,
142) -> Option<OutputFormat> {
143 let _schema = schema?;
145
146 for (flag_name, format) in GLOBAL_OUTPUT_FLAGS {
147 if args.flags.remove(*flag_name) {
148 return Some(*format);
149 }
150 }
151 None
152}
153
154fn check_type_compatibility(name: &str, value: &Value, expected_type: &str) -> Option<ValidationIssue> {
156 let compatible = match expected_type {
157 "any" => true,
158 "string" => true, "int" => matches!(value, Value::Int(_) | Value::String(_)),
160 "float" => matches!(value, Value::Float(_) | Value::Int(_) | Value::String(_)),
161 "bool" => matches!(value, Value::Bool(_) | Value::String(_)),
162 "array" => matches!(value, Value::String(_)), "object" => matches!(value, Value::String(_)), _ => true, };
166
167 if compatible {
168 None
169 } else {
170 let code = IssueCode::InvalidArgType;
171 Some(ValidationIssue {
172 severity: code.default_severity(),
173 code,
174 message: format!(
175 "argument '{}' has type {:?}, expected {}",
176 name, value, expected_type
177 ),
178 span: None,
179 suggestion: None,
180 })
181 }
182}