Skip to main content

kaish_kernel/tools/
traits.rs

1//! Core tool traits and types.
2
3use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5
6use crate::ast::Value;
7use crate::interpreter::{ExecResult, OutputFormat};
8use crate::validator::{IssueCode, Severity, ValidationIssue};
9
10use super::context::ExecContext;
11
12/// Schema for a tool parameter.
13#[derive(Debug, Clone)]
14pub struct ParamSchema {
15    /// Parameter name.
16    pub name: String,
17    /// Type hint (string, int, bool, array, object, any).
18    pub param_type: String,
19    /// Whether this parameter is required.
20    pub required: bool,
21    /// Default value if not required.
22    pub default: Option<Value>,
23    /// Description for help text.
24    pub description: String,
25    /// Alternative names/flags for this parameter (e.g., "-r", "-R" for "recursive").
26    pub aliases: Vec<String>,
27}
28
29impl ParamSchema {
30    /// Create a required parameter.
31    pub fn required(name: impl Into<String>, param_type: impl Into<String>, description: impl Into<String>) -> Self {
32        Self {
33            name: name.into(),
34            param_type: param_type.into(),
35            required: true,
36            default: None,
37            description: description.into(),
38            aliases: Vec::new(),
39        }
40    }
41
42    /// Create an optional parameter with a default value.
43    pub fn optional(name: impl Into<String>, param_type: impl Into<String>, default: Value, description: impl Into<String>) -> Self {
44        Self {
45            name: name.into(),
46            param_type: param_type.into(),
47            required: false,
48            default: Some(default),
49            description: description.into(),
50            aliases: Vec::new(),
51        }
52    }
53
54    /// Add alternative names/flags for this parameter.
55    ///
56    /// Aliases are used for short flags like `-r`, `-R` that map to `recursive`.
57    pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
58        self.aliases = aliases.into_iter().map(Into::into).collect();
59        self
60    }
61
62    /// Check if a flag name matches this parameter or any of its aliases.
63    pub fn matches_flag(&self, flag: &str) -> bool {
64        if self.name == flag {
65            return true;
66        }
67        self.aliases.iter().any(|a| a == flag)
68    }
69}
70
/// A worked usage example attached to a tool's schema.
#[derive(Debug, Clone)]
pub struct Example {
    /// What the example is meant to illustrate.
    pub description: String,
    /// The command/code being demonstrated.
    pub code: String,
}

impl Example {
    /// Build an example from its description and its command/code text.
    pub fn new(description: impl Into<String>, code: impl Into<String>) -> Self {
        let description = description.into();
        let code = code.into();
        Self { description, code }
    }
}
89
90/// Schema describing a tool's interface.
91#[derive(Debug, Clone)]
92pub struct ToolSchema {
93    /// Tool name.
94    pub name: String,
95    /// Short description.
96    pub description: String,
97    /// Parameter definitions.
98    pub params: Vec<ParamSchema>,
99    /// Usage examples.
100    pub examples: Vec<Example>,
101}
102
103impl ToolSchema {
104    /// Create a new tool schema.
105    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
106        Self {
107            name: name.into(),
108            description: description.into(),
109            params: Vec::new(),
110            examples: Vec::new(),
111        }
112    }
113
114    /// Add a parameter to the schema.
115    pub fn param(mut self, param: ParamSchema) -> Self {
116        self.params.push(param);
117        self
118    }
119
120    /// Add an example to the schema.
121    pub fn example(mut self, description: impl Into<String>, code: impl Into<String>) -> Self {
122        self.examples.push(Example::new(description, code));
123        self
124    }
125}
126
127/// Parsed arguments ready for tool execution.
128#[derive(Debug, Clone, Default)]
129pub struct ToolArgs {
130    /// Positional arguments in order.
131    pub positional: Vec<Value>,
132    /// Named arguments by key.
133    pub named: HashMap<String, Value>,
134    /// Boolean flags (e.g., -l, --force).
135    pub flags: HashSet<String>,
136}
137
138impl ToolArgs {
139    /// Create empty args.
140    pub fn new() -> Self {
141        Self::default()
142    }
143
144    /// Get a positional argument by index.
145    pub fn get_positional(&self, index: usize) -> Option<&Value> {
146        self.positional.get(index)
147    }
148
149    /// Get a named argument by key.
150    pub fn get_named(&self, key: &str) -> Option<&Value> {
151        self.named.get(key)
152    }
153
154    /// Get a named argument or positional fallback.
155    ///
156    /// Useful for tools that accept both `cat file.txt` and `cat path=file.txt`.
157    pub fn get(&self, name: &str, positional_index: usize) -> Option<&Value> {
158        self.named.get(name).or_else(|| self.positional.get(positional_index))
159    }
160
161    /// Get a string value from args.
162    pub fn get_string(&self, name: &str, positional_index: usize) -> Option<String> {
163        self.get(name, positional_index).and_then(|v| match v {
164            Value::String(s) => Some(s.clone()),
165            Value::Int(i) => Some(i.to_string()),
166            Value::Float(f) => Some(f.to_string()),
167            Value::Bool(b) => Some(b.to_string()),
168            _ => None,
169        })
170    }
171
172    /// Get a boolean value from args.
173    pub fn get_bool(&self, name: &str, positional_index: usize) -> Option<bool> {
174        self.get(name, positional_index).and_then(|v| match v {
175            Value::Bool(b) => Some(*b),
176            Value::String(s) => match s.as_str() {
177                "true" | "yes" | "1" => Some(true),
178                "false" | "no" | "0" => Some(false),
179                _ => None,
180            },
181            Value::Int(i) => Some(*i != 0),
182            _ => None,
183        })
184    }
185
186    /// Check if a flag is set (in flags set, or named bool).
187    pub fn has_flag(&self, name: &str) -> bool {
188        // Check the flags set first (from -x or --name syntax)
189        if self.flags.contains(name) {
190            return true;
191        }
192        // Fall back to checking named args (from name=true syntax)
193        self.named.get(name).is_some_and(|v| match v {
194            Value::Bool(b) => *b,
195            Value::String(s) => !s.is_empty() && s != "false" && s != "0",
196            _ => true,
197        })
198    }
199}
200
201/// A tool that can be executed.
202#[async_trait]
203pub trait Tool: Send + Sync {
204    /// The tool's name (used for lookup).
205    fn name(&self) -> &str;
206
207    /// Get the tool's schema.
208    fn schema(&self) -> ToolSchema;
209
210    /// Execute the tool with the given arguments and context.
211    async fn execute(&self, args: ToolArgs, ctx: &mut ExecContext) -> ExecResult;
212
213    /// Validate arguments without executing.
214    ///
215    /// Default implementation validates against the schema.
216    /// Override this for semantic checks (regex validity, zero increment, etc.).
217    fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
218        validate_against_schema(args, &self.schema())
219    }
220}
221
222/// Validate arguments against a tool schema.
223///
224/// Checks:
225/// - Required parameters are provided
226/// - Unknown flags (warnings)
227/// - Type compatibility
228pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
229    let mut issues = Vec::new();
230
231    // Check required parameters
232    for (i, param) in schema.params.iter().enumerate() {
233        if !param.required {
234            continue;
235        }
236
237        // Check named args first, then positional
238        let has_named = args.named.contains_key(&param.name);
239        let has_positional = args.positional.len() > i;
240        let has_flag = param.param_type == "bool" && args.has_flag(&param.name);
241
242        if !has_named && !has_positional && !has_flag {
243            let code = IssueCode::MissingRequiredArg;
244            issues.push(ValidationIssue {
245                severity: code.default_severity(),
246                code,
247                message: format!("required parameter '{}' not provided", param.name),
248                span: None,
249                suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
250            });
251        }
252    }
253
254    // Check for unknown flags (only warn - tools may accept dynamic flags)
255    let known_flags: HashSet<&str> = schema
256        .params
257        .iter()
258        .filter(|p| p.param_type == "bool")
259        .flat_map(|p| {
260            std::iter::once(p.name.as_str())
261                .chain(p.aliases.iter().map(|a| a.as_str()))
262        })
263        .collect();
264
265    for flag in &args.flags {
266        // Strip leading dashes for comparison
267        let flag_name = flag.trim_start_matches('-');
268        // Global output flags are handled by the kernel, not the tool
269        if is_global_output_flag(flag_name) {
270            continue;
271        }
272        if !known_flags.contains(flag_name) && !known_flags.contains(flag.as_str()) {
273            // Check if any param matches this flag via aliases
274            let matches_alias = schema.params.iter().any(|p| p.matches_flag(flag));
275            if !matches_alias {
276                issues.push(ValidationIssue {
277                    severity: Severity::Warning,
278                    code: IssueCode::UnknownFlag,
279                    message: format!("unknown flag '{}'", flag),
280                    span: None,
281                    suggestion: None,
282                });
283            }
284        }
285    }
286
287    // Check type compatibility for named args
288    for (key, value) in &args.named {
289        if let Some(param) = schema.params.iter().find(|p| &p.name == key)
290            && let Some(issue) = check_type_compatibility(key, value, &param.param_type) {
291                issues.push(issue);
292            }
293    }
294
295    // Check type compatibility for positional args
296    for (i, value) in args.positional.iter().enumerate() {
297        if let Some(param) = schema.params.get(i)
298            && let Some(issue) = check_type_compatibility(&param.name, value, &param.param_type) {
299                issues.push(issue);
300            }
301    }
302
303    issues
304}
305
306// ============================================================
307// Global Output Flags (--json)
308// ============================================================
309
310/// Registry of global output format flags.
311const GLOBAL_OUTPUT_FLAGS: &[(&str, OutputFormat)] = &[
312    ("json", OutputFormat::Json),
313];
314
315/// Check if a flag name is a global output flag.
316pub fn is_global_output_flag(name: &str) -> bool {
317    GLOBAL_OUTPUT_FLAGS.iter().any(|(n, _)| *n == name)
318}
319
320/// Extract and remove a global output format flag from ToolArgs.
321///
322/// Only applies to known tools with a schema. External commands
323/// (schema=None) must receive their flags untouched —
324/// `cargo --json` must not have --json stripped by the kernel.
325pub fn extract_output_format(
326    args: &mut ToolArgs,
327    schema: Option<&ToolSchema>,
328) -> Option<OutputFormat> {
329    // External commands keep their flags
330    let _schema = schema?;
331
332    for (flag_name, format) in GLOBAL_OUTPUT_FLAGS {
333        if args.flags.remove(*flag_name) {
334            return Some(*format);
335        }
336    }
337    None
338}
339
340/// Check if a value is compatible with a type.
341fn check_type_compatibility(name: &str, value: &Value, expected_type: &str) -> Option<ValidationIssue> {
342    let compatible = match expected_type {
343        "any" => true,
344        "string" => true, // Everything can be a string
345        "int" => matches!(value, Value::Int(_) | Value::String(_)),
346        "float" => matches!(value, Value::Float(_) | Value::Int(_) | Value::String(_)),
347        "bool" => matches!(value, Value::Bool(_) | Value::String(_)),
348        "array" => matches!(value, Value::String(_)), // Arrays are passed as strings in kaish
349        "object" => matches!(value, Value::String(_)), // Objects are JSON strings
350        _ => true, // Unknown types pass
351    };
352
353    if compatible {
354        None
355    } else {
356        let code = IssueCode::InvalidArgType;
357        Some(ValidationIssue {
358            severity: code.default_severity(),
359            code,
360            message: format!(
361                "argument '{}' has type {:?}, expected {}",
362                name, value, expected_type
363            ),
364            span: None,
365            suggestion: None,
366        })
367    }
368}