Skip to main content

kaish_kernel/tools/
traits.rs

1//! Core tool traits and types.
2
3use async_trait::async_trait;
4use std::collections::HashSet;
5
6use crate::interpreter::{ExecResult, OutputFormat};
7use crate::validator::{IssueCode, Severity, ValidationIssue};
8
9// Data types re-exported from kaish-types.
10pub use kaish_types::{ParamSchema, ToolArgs, ToolSchema};
11
12use super::context::ExecContext;
13use crate::ast::Value;
14
15/// A tool that can be executed.
16#[async_trait]
17pub trait Tool: Send + Sync {
18    /// The tool's name (used for lookup).
19    fn name(&self) -> &str;
20
21    /// Get the tool's schema.
22    fn schema(&self) -> ToolSchema;
23
24    /// Execute the tool with the given arguments and context.
25    async fn execute(&self, args: ToolArgs, ctx: &mut ExecContext) -> ExecResult;
26
27    /// Validate arguments without executing.
28    ///
29    /// Default implementation validates against the schema.
30    /// Override this for semantic checks (regex validity, zero increment, etc.).
31    fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
32        validate_against_schema(args, &self.schema())
33    }
34}
35
36/// Validate arguments against a tool schema.
37///
38/// Checks:
39/// - Required parameters are provided
40/// - Unknown flags (warnings)
41/// - Type compatibility
42pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
43    let mut issues = Vec::new();
44
45    // Check required parameters
46    for (i, param) in schema.params.iter().enumerate() {
47        if !param.required {
48            continue;
49        }
50
51        // Check named args first, then positional
52        let has_named = args.named.contains_key(&param.name);
53        let has_positional = args.positional.len() > i;
54        let has_flag = param.param_type == "bool" && args.has_flag(&param.name);
55
56        if !has_named && !has_positional && !has_flag {
57            let code = IssueCode::MissingRequiredArg;
58            issues.push(ValidationIssue {
59                severity: code.default_severity(),
60                code,
61                message: format!("required parameter '{}' not provided", param.name),
62                span: None,
63                suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
64            });
65        }
66    }
67
68    // Check for unknown flags (only warn - tools may accept dynamic flags)
69    let known_flags: HashSet<&str> = schema
70        .params
71        .iter()
72        .filter(|p| p.param_type == "bool")
73        .flat_map(|p| {
74            std::iter::once(p.name.as_str())
75                .chain(p.aliases.iter().map(|a| a.as_str()))
76        })
77        .collect();
78
79    for flag in &args.flags {
80        // Strip leading dashes for comparison
81        let flag_name = flag.trim_start_matches('-');
82        // Global output flags are handled by the kernel, not the tool
83        if is_global_output_flag(flag_name) {
84            continue;
85        }
86        if !known_flags.contains(flag_name) && !known_flags.contains(flag.as_str()) {
87            // Check if any param matches this flag via aliases
88            let matches_alias = schema.params.iter().any(|p| p.matches_flag(flag));
89            if !matches_alias {
90                issues.push(ValidationIssue {
91                    severity: Severity::Warning,
92                    code: IssueCode::UnknownFlag,
93                    message: format!("unknown flag '{}'", flag),
94                    span: None,
95                    suggestion: None,
96                });
97            }
98        }
99    }
100
101    // Check type compatibility for named args
102    for (key, value) in &args.named {
103        if let Some(param) = schema.params.iter().find(|p| &p.name == key)
104            && let Some(issue) = check_type_compatibility(key, value, &param.param_type) {
105                issues.push(issue);
106            }
107    }
108
109    // Check type compatibility for positional args
110    for (i, value) in args.positional.iter().enumerate() {
111        if let Some(param) = schema.params.get(i)
112            && let Some(issue) = check_type_compatibility(&param.name, value, &param.param_type) {
113                issues.push(issue);
114            }
115    }
116
117    issues
118}
119
120// ============================================================
121// Global Output Flags (--json)
122// ============================================================
123
124/// Registry of global output format flags.
125const GLOBAL_OUTPUT_FLAGS: &[(&str, OutputFormat)] = &[
126    ("json", OutputFormat::Json),
127];
128
129/// Check if a flag name is a global output flag.
130pub fn is_global_output_flag(name: &str) -> bool {
131    GLOBAL_OUTPUT_FLAGS.iter().any(|(n, _)| *n == name)
132}
133
134/// Extract and remove a global output format flag from ToolArgs.
135///
136/// Only applies to known tools with a schema. External commands
137/// (schema=None) must receive their flags untouched —
138/// `cargo --json` must not have --json stripped by the kernel.
139pub fn extract_output_format(
140    args: &mut ToolArgs,
141    schema: Option<&ToolSchema>,
142) -> Option<OutputFormat> {
143    // External commands keep their flags
144    let _schema = schema?;
145
146    for (flag_name, format) in GLOBAL_OUTPUT_FLAGS {
147        if args.flags.remove(*flag_name) {
148            return Some(*format);
149        }
150    }
151    None
152}
153
154/// Check if a value is compatible with a type.
155fn check_type_compatibility(name: &str, value: &Value, expected_type: &str) -> Option<ValidationIssue> {
156    let compatible = match expected_type {
157        "any" => true,
158        "string" => true, // Everything can be a string
159        "int" => matches!(value, Value::Int(_) | Value::String(_)),
160        "float" => matches!(value, Value::Float(_) | Value::Int(_) | Value::String(_)),
161        "bool" => matches!(value, Value::Bool(_) | Value::String(_)),
162        "array" => matches!(value, Value::String(_)), // Arrays are passed as strings in kaish
163        "object" => matches!(value, Value::String(_)), // Objects are JSON strings
164        _ => true, // Unknown types pass
165    };
166
167    if compatible {
168        None
169    } else {
170        let code = IssueCode::InvalidArgType;
171        Some(ValidationIssue {
172            severity: code.default_severity(),
173            code,
174            message: format!(
175                "argument '{}' has type {:?}, expected {}",
176                name, value, expected_type
177            ),
178            span: None,
179            suggestion: None,
180        })
181    }
182}