Skip to main content

kaish_kernel/tools/
traits.rs

1//! Core tool traits and types.
2
3use async_trait::async_trait;
4use std::collections::{HashMap, HashSet};
5
6use crate::ast::Value;
7use crate::interpreter::{ExecResult, OutputFormat};
8use crate::validator::{IssueCode, Severity, ValidationIssue};
9
10use super::context::ExecContext;
11
12/// Schema for a tool parameter.
13#[derive(Debug, Clone)]
14pub struct ParamSchema {
15    /// Parameter name.
16    pub name: String,
17    /// Type hint (string, int, bool, array, object, any).
18    pub param_type: String,
19    /// Whether this parameter is required.
20    pub required: bool,
21    /// Default value if not required.
22    pub default: Option<Value>,
23    /// Description for help text.
24    pub description: String,
25    /// Alternative names/flags for this parameter (e.g., "-r", "-R" for "recursive").
26    pub aliases: Vec<String>,
27}
28
29impl ParamSchema {
30    /// Create a required parameter.
31    pub fn required(name: impl Into<String>, param_type: impl Into<String>, description: impl Into<String>) -> Self {
32        Self {
33            name: name.into(),
34            param_type: param_type.into(),
35            required: true,
36            default: None,
37            description: description.into(),
38            aliases: Vec::new(),
39        }
40    }
41
42    /// Create an optional parameter with a default value.
43    pub fn optional(name: impl Into<String>, param_type: impl Into<String>, default: Value, description: impl Into<String>) -> Self {
44        Self {
45            name: name.into(),
46            param_type: param_type.into(),
47            required: false,
48            default: Some(default),
49            description: description.into(),
50            aliases: Vec::new(),
51        }
52    }
53
54    /// Add alternative names/flags for this parameter.
55    ///
56    /// Aliases are used for short flags like `-r`, `-R` that map to `recursive`.
57    pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
58        self.aliases = aliases.into_iter().map(Into::into).collect();
59        self
60    }
61
62    /// Check if a flag name matches this parameter or any of its aliases.
63    pub fn matches_flag(&self, flag: &str) -> bool {
64        if self.name == flag {
65            return true;
66        }
67        self.aliases.iter().any(|a| a == flag)
68    }
69}
70
/// A usage example attached to a tool's schema.
#[derive(Debug, Clone)]
pub struct Example {
    /// One-line summary of what the example demonstrates.
    pub description: String,
    /// The command or code snippet being demonstrated.
    pub code: String,
}

impl Example {
    /// Build an example from its summary and the code it illustrates.
    pub fn new(description: impl Into<String>, code: impl Into<String>) -> Self {
        let description: String = description.into();
        let code: String = code.into();
        Example { description, code }
    }
}
89
90/// Schema describing a tool's interface.
91#[derive(Debug, Clone)]
92pub struct ToolSchema {
93    /// Tool name.
94    pub name: String,
95    /// Short description.
96    pub description: String,
97    /// Parameter definitions.
98    pub params: Vec<ParamSchema>,
99    /// Usage examples.
100    pub examples: Vec<Example>,
101    /// Map remaining positional args to named params by schema order.
102    /// Only for MCP/external tools that expect named JSON params.
103    /// Builtins handle their own positionals and should leave this false.
104    pub map_positionals: bool,
105}
106
107impl ToolSchema {
108    /// Create a new tool schema.
109    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
110        Self {
111            name: name.into(),
112            description: description.into(),
113            params: Vec::new(),
114            examples: Vec::new(),
115            map_positionals: false,
116        }
117    }
118
119    /// Enable positional→named parameter mapping for MCP/external tools.
120    pub fn with_positional_mapping(mut self) -> Self {
121        self.map_positionals = true;
122        self
123    }
124
125    /// Add a parameter to the schema.
126    pub fn param(mut self, param: ParamSchema) -> Self {
127        self.params.push(param);
128        self
129    }
130
131    /// Add an example to the schema.
132    pub fn example(mut self, description: impl Into<String>, code: impl Into<String>) -> Self {
133        self.examples.push(Example::new(description, code));
134        self
135    }
136}
137
138/// Parsed arguments ready for tool execution.
139#[derive(Debug, Clone, Default)]
140pub struct ToolArgs {
141    /// Positional arguments in order.
142    pub positional: Vec<Value>,
143    /// Named arguments by key.
144    pub named: HashMap<String, Value>,
145    /// Boolean flags (e.g., -l, --force).
146    pub flags: HashSet<String>,
147}
148
149impl ToolArgs {
150    /// Create empty args.
151    pub fn new() -> Self {
152        Self::default()
153    }
154
155    /// Get a positional argument by index.
156    pub fn get_positional(&self, index: usize) -> Option<&Value> {
157        self.positional.get(index)
158    }
159
160    /// Get a named argument by key.
161    pub fn get_named(&self, key: &str) -> Option<&Value> {
162        self.named.get(key)
163    }
164
165    /// Get a named argument or positional fallback.
166    ///
167    /// Useful for tools that accept both `cat file.txt` and `cat path=file.txt`.
168    pub fn get(&self, name: &str, positional_index: usize) -> Option<&Value> {
169        self.named.get(name).or_else(|| self.positional.get(positional_index))
170    }
171
172    /// Get a string value from args.
173    pub fn get_string(&self, name: &str, positional_index: usize) -> Option<String> {
174        self.get(name, positional_index).and_then(|v| match v {
175            Value::String(s) => Some(s.clone()),
176            Value::Int(i) => Some(i.to_string()),
177            Value::Float(f) => Some(f.to_string()),
178            Value::Bool(b) => Some(b.to_string()),
179            _ => None,
180        })
181    }
182
183    /// Get a boolean value from args.
184    pub fn get_bool(&self, name: &str, positional_index: usize) -> Option<bool> {
185        self.get(name, positional_index).and_then(|v| match v {
186            Value::Bool(b) => Some(*b),
187            Value::String(s) => match s.as_str() {
188                "true" | "yes" | "1" => Some(true),
189                "false" | "no" | "0" => Some(false),
190                _ => None,
191            },
192            Value::Int(i) => Some(*i != 0),
193            _ => None,
194        })
195    }
196
197    /// Check if a flag is set (in flags set, or named bool).
198    pub fn has_flag(&self, name: &str) -> bool {
199        // Check the flags set first (from -x or --name syntax)
200        if self.flags.contains(name) {
201            return true;
202        }
203        // Fall back to checking named args (from name=true syntax)
204        self.named.get(name).is_some_and(|v| match v {
205            Value::Bool(b) => *b,
206            Value::String(s) => !s.is_empty() && s != "false" && s != "0",
207            _ => true,
208        })
209    }
210}
211
212/// A tool that can be executed.
213#[async_trait]
214pub trait Tool: Send + Sync {
215    /// The tool's name (used for lookup).
216    fn name(&self) -> &str;
217
218    /// Get the tool's schema.
219    fn schema(&self) -> ToolSchema;
220
221    /// Execute the tool with the given arguments and context.
222    async fn execute(&self, args: ToolArgs, ctx: &mut ExecContext) -> ExecResult;
223
224    /// Validate arguments without executing.
225    ///
226    /// Default implementation validates against the schema.
227    /// Override this for semantic checks (regex validity, zero increment, etc.).
228    fn validate(&self, args: &ToolArgs) -> Vec<ValidationIssue> {
229        validate_against_schema(args, &self.schema())
230    }
231}
232
233/// Validate arguments against a tool schema.
234///
235/// Checks:
236/// - Required parameters are provided
237/// - Unknown flags (warnings)
238/// - Type compatibility
239pub fn validate_against_schema(args: &ToolArgs, schema: &ToolSchema) -> Vec<ValidationIssue> {
240    let mut issues = Vec::new();
241
242    // Check required parameters
243    for (i, param) in schema.params.iter().enumerate() {
244        if !param.required {
245            continue;
246        }
247
248        // Check named args first, then positional
249        let has_named = args.named.contains_key(&param.name);
250        let has_positional = args.positional.len() > i;
251        let has_flag = param.param_type == "bool" && args.has_flag(&param.name);
252
253        if !has_named && !has_positional && !has_flag {
254            let code = IssueCode::MissingRequiredArg;
255            issues.push(ValidationIssue {
256                severity: code.default_severity(),
257                code,
258                message: format!("required parameter '{}' not provided", param.name),
259                span: None,
260                suggestion: Some(format!("add {} or {}=<value>", param.name, param.name)),
261            });
262        }
263    }
264
265    // Check for unknown flags (only warn - tools may accept dynamic flags)
266    let known_flags: HashSet<&str> = schema
267        .params
268        .iter()
269        .filter(|p| p.param_type == "bool")
270        .flat_map(|p| {
271            std::iter::once(p.name.as_str())
272                .chain(p.aliases.iter().map(|a| a.as_str()))
273        })
274        .collect();
275
276    for flag in &args.flags {
277        // Strip leading dashes for comparison
278        let flag_name = flag.trim_start_matches('-');
279        // Global output flags are handled by the kernel, not the tool
280        if is_global_output_flag(flag_name) {
281            continue;
282        }
283        if !known_flags.contains(flag_name) && !known_flags.contains(flag.as_str()) {
284            // Check if any param matches this flag via aliases
285            let matches_alias = schema.params.iter().any(|p| p.matches_flag(flag));
286            if !matches_alias {
287                issues.push(ValidationIssue {
288                    severity: Severity::Warning,
289                    code: IssueCode::UnknownFlag,
290                    message: format!("unknown flag '{}'", flag),
291                    span: None,
292                    suggestion: None,
293                });
294            }
295        }
296    }
297
298    // Check type compatibility for named args
299    for (key, value) in &args.named {
300        if let Some(param) = schema.params.iter().find(|p| &p.name == key)
301            && let Some(issue) = check_type_compatibility(key, value, &param.param_type) {
302                issues.push(issue);
303            }
304    }
305
306    // Check type compatibility for positional args
307    for (i, value) in args.positional.iter().enumerate() {
308        if let Some(param) = schema.params.get(i)
309            && let Some(issue) = check_type_compatibility(&param.name, value, &param.param_type) {
310                issues.push(issue);
311            }
312    }
313
314    issues
315}
316
317// ============================================================
318// Global Output Flags (--json)
319// ============================================================
320
321/// Registry of global output format flags.
322const GLOBAL_OUTPUT_FLAGS: &[(&str, OutputFormat)] = &[
323    ("json", OutputFormat::Json),
324];
325
326/// Check if a flag name is a global output flag.
327pub fn is_global_output_flag(name: &str) -> bool {
328    GLOBAL_OUTPUT_FLAGS.iter().any(|(n, _)| *n == name)
329}
330
331/// Extract and remove a global output format flag from ToolArgs.
332///
333/// Only applies to known tools with a schema. External commands
334/// (schema=None) must receive their flags untouched —
335/// `cargo --json` must not have --json stripped by the kernel.
336pub fn extract_output_format(
337    args: &mut ToolArgs,
338    schema: Option<&ToolSchema>,
339) -> Option<OutputFormat> {
340    // External commands keep their flags
341    let _schema = schema?;
342
343    for (flag_name, format) in GLOBAL_OUTPUT_FLAGS {
344        if args.flags.remove(*flag_name) {
345            return Some(*format);
346        }
347    }
348    None
349}
350
351/// Check if a value is compatible with a type.
352fn check_type_compatibility(name: &str, value: &Value, expected_type: &str) -> Option<ValidationIssue> {
353    let compatible = match expected_type {
354        "any" => true,
355        "string" => true, // Everything can be a string
356        "int" => matches!(value, Value::Int(_) | Value::String(_)),
357        "float" => matches!(value, Value::Float(_) | Value::Int(_) | Value::String(_)),
358        "bool" => matches!(value, Value::Bool(_) | Value::String(_)),
359        "array" => matches!(value, Value::String(_)), // Arrays are passed as strings in kaish
360        "object" => matches!(value, Value::String(_)), // Objects are JSON strings
361        _ => true, // Unknown types pass
362    };
363
364    if compatible {
365        None
366    } else {
367        let code = IssueCode::InvalidArgType;
368        Some(ValidationIssue {
369            severity: code.default_severity(),
370            code,
371            message: format!(
372                "argument '{}' has type {:?}, expected {}",
373                name, value, expected_type
374            ),
375            span: None,
376            suggestion: None,
377        })
378    }
379}