Skip to main content

kaish_tool_api/
tool.rs

1//! The `Tool` trait and argument validation.
2
3use std::collections::HashSet;
4
5use async_trait::async_trait;
6
7use kaish_types::{ExecResult, ParamSchema, ToolArgs, ToolSchema, Value};
8
9use crate::ctx::ToolCtx;
10use crate::issue::{IssueCode, Severity, ValidationIssue};
11
12/// A tool that can be executed.
13///
14/// Every kaish command — builtin or third-party — implements this trait. The
15/// `execute` method receives a `&mut dyn ToolCtx`, the trimmed portable
16/// context; tools needing deeper kernel state downcast via
17/// [`ToolCtx::as_any_mut`](crate::ToolCtx::as_any_mut).
18#[async_trait]
19pub trait Tool: Send + Sync {
20    /// The tool's name (used for lookup).
21    fn name(&self) -> &str;
22
23    /// Get the tool's schema.
24    fn schema(&self) -> ToolSchema;
25
26    /// Execute the tool with the given arguments and context.
27    async fn execute(&self, args: ToolArgs, ctx: &mut dyn ToolCtx) -> ExecResult;
28
29    /// Validate arguments without executing.
30    ///
31    /// Default implementation validates against the schema.
32    /// Override this for semantic checks (regex validity, zero increment, etc.).
33    fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
34        validate_against_schema(args, &self.schema())
35    }
36}
37
38/// Validate arguments against a tool schema.
39///
40/// Splits `schema.params` into positional and named/flag groups so the
41/// positional slot index never conflates with the struct-field index. With
42/// clap-derived schemas, positionals sit *after* the flags in struct order;
43/// the old single-index walk would have falsely failed `mkdir foo` because
44/// the path slot lives at struct index 1+.
45///
46/// Checks:
47/// - Required parameters are provided (positionals by slot, flags by name).
48/// - Unknown flags (warning).
49/// - Type compatibility for both positional and named args.
50pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
51    let mut issues = Vec::new();
52
53    let positional_params: Vec<&ParamSchema> = schema.params.iter().filter(|p| p.positional).collect();
54    let flag_params: Vec<&ParamSchema> = schema.params.iter().filter(|p| !p.positional).collect();
55
56    // Required positionals: matched by slot among positional params only.
57    for (slot, param) in positional_params.iter().enumerate() {
58        if !param.required {
59            continue;
60        }
61        let has_positional = args.positional.len() > slot;
62        // A required positional can also be supplied as a named arg if the
63        // caller knows the param name (e.g. `mkdir paths=foo`).
64        let has_named = args.named.contains_key(&param.name);
65        if !has_positional && !has_named {
66            let code = IssueCode::MissingRequiredArg;
67            issues.push(ValidationIssue {
68                severity: code.default_severity(),
69                code,
70                message: format!("required parameter '{}' not provided", param.name),
71                span: None,
72                suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
73            });
74        }
75    }
76
77    // Required flags: matched by name (or alias) against args.named / args.flags.
78    for param in &flag_params {
79        if !param.required {
80            continue;
81        }
82        let has_named = args.named.contains_key(&param.name);
83        let has_flag = param.param_type == "bool" && args.has_flag(&param.name);
84        if !has_named && !has_flag {
85            let code = IssueCode::MissingRequiredArg;
86            issues.push(ValidationIssue {
87                severity: code.default_severity(),
88                code,
89                message: format!("required parameter '{}' not provided", param.name),
90                span: None,
91                suggestion: Some(format!("add --{} <value>", param.name)),
92            });
93        }
94    }
95
96    // Check for unknown flags (only warn - tools may accept dynamic flags).
97    // Only bool flags are gathered for the strict known-flag set; the
98    // alias-fallback below catches value-taking flags via `matches_flag`.
99    let known_flags: HashSet<&str> = flag_params
100        .iter()
101        .filter(|p| p.param_type == "bool")
102        .flat_map(|p| {
103            std::iter::once(p.name.as_str())
104                .chain(p.aliases.iter().map(|a| a.as_str()))
105        })
106        .collect();
107
108    for flag in &args.flags {
109        // Strip leading dashes for comparison
110        let flag_name = flag.trim_start_matches('-');
111        // Global output flags are handled by the kernel, not the tool
112        if is_global_output_flag(flag_name) {
113            continue;
114        }
115        if !known_flags.contains(flag_name) && !known_flags.contains(flag.as_str()) {
116            // Check if any flag param matches this flag via aliases
117            let matches_alias = flag_params.iter().any(|p| p.matches_flag(flag));
118            if !matches_alias {
119                issues.push(ValidationIssue {
120                    severity: Severity::Warning,
121                    code: IssueCode::UnknownFlag,
122                    message: format!("unknown flag '{}'", flag),
123                    span: None,
124                    suggestion: None,
125                });
126            }
127        }
128    }
129
130    // Type compatibility for named args (search the full schema — callers
131    // may name either a positional or a flag param).
132    for (key, value) in &args.named {
133        if let Some(param) = schema.params.iter().find(|p| &p.name == key)
134            && let Some(issue) = check_type_compatibility(key, value, &param.param_type) {
135                issues.push(issue);
136            }
137    }
138
139    // Type compatibility for positional args (matched by slot among
140    // positional params). Extra positionals past the schema are ignored —
141    // many builtins (cat, cp, mkdir) accept variadic positionals.
142    for (slot, value) in args.positional.iter().enumerate() {
143        if let Some(param) = positional_params.get(slot)
144            && let Some(issue) = check_type_compatibility(&param.name, value, &param.param_type) {
145                issues.push(issue);
146            }
147    }
148
149    issues
150}
151
152// ============================================================
153// Global Output Flags (--json)
154// ============================================================
155//
156// `--json` is declared per-builtin via `GlobalFlags` flatten
157// (`crate::global_flags`). Builtins parse it inside execute() and write the
158// output format via `ToolCtx::set_output_format`; the kernel applies the
159// format after execute() returns. See `docs/clap-migration.md`.
160
/// Return `true` when `name` is the kernel-owned `--json` flag.
///
/// The comparison is against the bare name `json`, so callers strip leading
/// dashes before asking. This helper is retained for the validator's
/// unknown-flag check. External commands (no schema) bypass clap entirely
/// and the kernel leaves their argv alone, so `cargo --json` and similar
/// behave as expected.
pub fn is_global_output_flag(name: &str) -> bool {
    matches!(name, "json")
}
170
171/// Check if a value is compatible with a type.
172fn check_type_compatibility(name: &str, value: &Value, expected_type: &str) -> Option<ValidationIssue> {
173    let compatible = match expected_type {
174        "any" => true,
175        "string" => true, // Everything can be a string
176        "int" => matches!(value, Value::Int(_) | Value::String(_)),
177        "float" => matches!(value, Value::Float(_) | Value::Int(_) | Value::String(_)),
178        "bool" => matches!(value, Value::Bool(_) | Value::String(_)),
179        "array" => matches!(value, Value::String(_)), // Arrays are passed as strings in kaish
180        "object" => matches!(value, Value::String(_)), // Objects are JSON strings
181        _ => true, // Unknown types pass
182    };
183
184    if compatible {
185        None
186    } else {
187        let code = IssueCode::InvalidArgType;
188        Some(ValidationIssue {
189            severity: code.default_severity(),
190            code,
191            message: format!(
192                "argument '{}' has type {:?}, expected {}",
193                name, value, expected_type
194            ),
195            span: None,
196            suggestion: None,
197        })
198    }
199}
200
#[cfg(test)]
mod validate_tests {
    use super::*;
    use kaish_types::{ParamSchema, ToolSchema};

    /// Schema in clap-derived order: flag fields first, positionals last.
    fn schema_with_positionals_after_flags() -> ToolSchema {
        ToolSchema::new("demo", "demo")
            .param(
                ParamSchema::new("verbose", "bool")
                    .with_default(Some(Value::Bool(false)))
                    .with_aliases(["v"]),
            )
            .param(ParamSchema::new("lines", "int").with_aliases(["n"]))
            .param(
                ParamSchema::new("path", "string")
                    .with_required(true)
                    .positional(),
            )
    }

    /// Schema with a bool flag at struct index 0, then two required
    /// positionals: `count` (int, slot 0) and `name` (string, slot 1).
    fn schema_with_two_positionals() -> ToolSchema {
        ToolSchema::new("demo", "demo")
            .param(ParamSchema::new("verbose", "bool").with_default(Some(Value::Bool(false))))
            .param(
                ParamSchema::new("count", "int")
                    .with_required(true)
                    .positional(),
            )
            .param(
                ParamSchema::new("name", "string")
                    .with_required(true)
                    .positional(),
            )
    }

    /// True when any issue in `issues` carries `code`.
    fn has_code(issues: &[ValidationIssue], code: IssueCode) -> bool {
        issues.iter().any(|i| i.code == code)
    }

    /// Regression for the clap-migration index-mismatch: `cat foo.txt` should
    /// satisfy the required positional `path` even though `path` is at struct
    /// index 2 (after `verbose`/`lines`). The old code matched positional[0]
    /// against `verbose` and required positional[2] to exist.
    #[test]
    fn required_positional_satisfied_when_positional_sits_after_flags() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("foo.txt".into()));

        let issues = validate_against_schema(&args, &schema);
        assert!(
            !has_code(&issues, IssueCode::MissingRequiredArg),
            "required positional should be satisfied by positional[0]; got {:?}",
            issues
        );
    }

    #[test]
    fn required_positional_missing_when_no_positional_given() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.flags.insert("verbose".into());

        let issues = validate_against_schema(&args, &schema);
        assert!(
            has_code(&issues, IssueCode::MissingRequiredArg),
            "missing required positional should error; got {:?}",
            issues
        );
    }

    /// Positional type checks must look up the positional *slot*, not the
    /// struct-field index. Here positional[0] is an int for the int slot
    /// `count`. The old struct-index walk would have checked it against the
    /// bool param `verbose` (struct index 0) and falsely reported a type
    /// mismatch.
    #[test]
    fn positional_type_check_targets_positional_slot_not_struct_index() {
        let schema = schema_with_two_positionals();

        let mut args = ToolArgs::new();
        args.positional.push(Value::Int(5));
        args.positional.push(Value::String("widget".into()));

        let issues = validate_against_schema(&args, &schema);
        assert!(
            !has_code(&issues, IssueCode::InvalidArgType),
            "int->int and string->string slots should validate clean; got {:?}",
            issues
        );
    }

    /// Converse of the test above: a bool cannot fill the int slot `count`.
    /// The old struct-index walk would have compared it with the bool
    /// `verbose` and accepted it.
    #[test]
    fn positional_type_mismatch_in_int_slot_errors() {
        let schema = schema_with_two_positionals();

        let mut args = ToolArgs::new();
        args.positional.push(Value::Bool(true));
        args.positional.push(Value::String("widget".into()));

        let issues = validate_against_schema(&args, &schema);
        assert!(
            has_code(&issues, IssueCode::InvalidArgType),
            "bool in the int slot should be rejected; got {:?}",
            issues
        );
    }

    /// Required *flag* (non-positional) must still fire MissingRequiredArg
    /// when absent — separating the loops shouldn't silently drop the check.
    #[test]
    fn required_flag_still_errors_when_missing() {
        let schema = ToolSchema::new("demo", "demo").param(
            ParamSchema::new("output", "string")
                .with_required(true)
                .with_aliases(["o"]),
        );

        let args = ToolArgs::new();
        let issues = validate_against_schema(&args, &schema);
        assert!(
            has_code(&issues, IssueCode::MissingRequiredArg),
            "required flag should error when missing; got {:?}",
            issues
        );
    }

    /// A required value-taking flag is satisfied by a named arg of the same name.
    #[test]
    fn required_flag_satisfied_by_named_arg() {
        let schema = ToolSchema::new("demo", "demo").param(
            ParamSchema::new("output", "string")
                .with_required(true)
                .with_aliases(["o"]),
        );

        let mut args = ToolArgs::new();
        args.named.insert("output".into(), Value::String("out.txt".into()));

        let issues = validate_against_schema(&args, &schema);
        assert!(
            !has_code(&issues, IssueCode::MissingRequiredArg),
            "named arg should satisfy the required flag; got {:?}",
            issues
        );
    }

    /// `--json` is kernel-owned and never reported; a genuinely unknown flag is.
    #[test]
    fn unknown_flag_warns_but_json_is_accepted() {
        let schema = schema_with_positionals_after_flags();
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("foo.txt".into()));
        args.flags.insert("json".into());

        let issues = validate_against_schema(&args, &schema);
        assert!(
            !has_code(&issues, IssueCode::UnknownFlag),
            "--json must not be reported as unknown; got {:?}",
            issues
        );

        args.flags.insert("bogus".into());
        let issues = validate_against_schema(&args, &schema);
        assert!(
            has_code(&issues, IssueCode::UnknownFlag),
            "an undeclared flag should warn; got {:?}",
            issues
        );
    }
}