Skip to main content

kaish_types/
tool.rs

//! Tool schema and argument types.
2
3use std::collections::{BTreeMap, HashSet};
4
5use crate::value::Value;
6
/// Serde fallback for the `consumes` field of `ParamSchema`: a flag takes a
/// single value token per occurrence unless a schema says otherwise.
fn default_consumes() -> usize {
    const SINGLE_TOKEN: usize = 1;
    SINGLE_TOKEN
}
10
11/// Schema for a tool parameter.
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct ParamSchema {
14    /// Parameter name.
15    pub name: String,
16    /// Type hint (string, int, bool, array, object, any).
17    pub param_type: String,
18    /// Whether this parameter is required.
19    pub required: bool,
20    /// Default value if not required.
21    pub default: Option<Value>,
22    /// Description for help text.
23    pub description: String,
24    /// Alternative names/flags for this parameter (e.g., "-r", "-R" for "recursive").
25    pub aliases: Vec<String>,
26    /// Number of positional tokens this non-bool flag consumes per occurrence.
27    ///
28    /// Default 1 (standard `--flag value`). Set to 2 for `--flag NAME VALUE`
29    /// patterns such as jq's `--arg` / `--argjson`. When `consumes > 1`, the
30    /// kernel collects each occurrence as an inner array and accumulates
31    /// repeated occurrences under the same `named` key — the tool sees a
32    /// `Value::Json(Array(Array(...)))` listing every (N-tuple) occurrence.
33    #[serde(default = "default_consumes")]
34    pub consumes: usize,
35}
36
37impl ParamSchema {
38    /// Create a required parameter.
39    pub fn required(name: impl Into<String>, param_type: impl Into<String>, description: impl Into<String>) -> Self {
40        Self {
41            name: name.into(),
42            param_type: param_type.into(),
43            required: true,
44            default: None,
45            description: description.into(),
46            aliases: Vec::new(),
47            consumes: 1,
48        }
49    }
50
51    /// Create an optional parameter with a default value.
52    pub fn optional(name: impl Into<String>, param_type: impl Into<String>, default: Value, description: impl Into<String>) -> Self {
53        Self {
54            name: name.into(),
55            param_type: param_type.into(),
56            required: false,
57            default: Some(default),
58            description: description.into(),
59            aliases: Vec::new(),
60            consumes: 1,
61        }
62    }
63
64    /// Add alternative names/flags for this parameter.
65    ///
66    /// Aliases are used for short flags like `-r`, `-R` that map to `recursive`.
67    pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
68        self.aliases = aliases.into_iter().map(Into::into).collect();
69        self
70    }
71
72    /// Declare how many positional tokens this non-bool flag consumes per
73    /// occurrence (`--flag v1 v2 ...`). Default is 1. Panics on 0 — a flag
74    /// that consumes nothing is a bool flag, not a schema-typed param.
75    pub fn consumes(mut self, n: usize) -> Self {
76        assert!(n >= 1, "ParamSchema::consumes requires n >= 1 (use a bool param for flags that take no value)");
77        self.consumes = n;
78        self
79    }
80
81    /// Check if a flag name matches this parameter or any of its aliases.
82    pub fn matches_flag(&self, flag: &str) -> bool {
83        if self.name == flag {
84            return true;
85        }
86        self.aliases.iter().any(|a| a == flag)
87    }
88}
89
90/// An example showing how to use a tool.
91#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
92pub struct Example {
93    /// Short description of what the example demonstrates.
94    pub description: String,
95    /// The example command/code.
96    pub code: String,
97}
98
99impl Example {
100    /// Create a new example.
101    pub fn new(description: impl Into<String>, code: impl Into<String>) -> Self {
102        Self {
103            description: description.into(),
104            code: code.into(),
105        }
106    }
107}
108
109/// Schema describing a tool's interface.
110#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
111pub struct ToolSchema {
112    /// Tool name.
113    pub name: String,
114    /// Short description.
115    pub description: String,
116    /// Parameter definitions.
117    pub params: Vec<ParamSchema>,
118    /// Usage examples.
119    pub examples: Vec<Example>,
120    /// Map remaining positional args to named params by schema order.
121    /// Only for MCP/external tools that expect named JSON params.
122    /// Builtins handle their own positionals and should leave this false.
123    pub map_positionals: bool,
124}
125
126impl ToolSchema {
127    /// Create a new tool schema.
128    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
129        Self {
130            name: name.into(),
131            description: description.into(),
132            params: Vec::new(),
133            examples: Vec::new(),
134            map_positionals: false,
135        }
136    }
137
138    /// Enable positional->named parameter mapping for MCP/external tools.
139    pub fn with_positional_mapping(mut self) -> Self {
140        self.map_positionals = true;
141        self
142    }
143
144    /// Add a parameter to the schema.
145    pub fn param(mut self, param: ParamSchema) -> Self {
146        self.params.push(param);
147        self
148    }
149
150    /// Add an example to the schema.
151    pub fn example(mut self, description: impl Into<String>, code: impl Into<String>) -> Self {
152        self.examples.push(Example::new(description, code));
153        self
154    }
155}
156
157/// Parsed arguments ready for tool execution.
158#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
159pub struct ToolArgs {
160    /// Positional arguments in order.
161    pub positional: Vec<Value>,
162    /// Named arguments by key.
163    pub named: BTreeMap<String, Value>,
164    /// Boolean flags (e.g., -l, --force).
165    pub flags: HashSet<String>,
166}
167
168impl ToolArgs {
169    /// Create empty args.
170    pub fn new() -> Self {
171        Self::default()
172    }
173
174    /// Get a positional argument by index.
175    pub fn get_positional(&self, index: usize) -> Option<&Value> {
176        self.positional.get(index)
177    }
178
179    /// Get a named argument by key.
180    pub fn get_named(&self, key: &str) -> Option<&Value> {
181        self.named.get(key)
182    }
183
184    /// Get a named argument or positional fallback.
185    ///
186    /// Useful for tools that accept both `cat file.txt` and `cat path=file.txt`.
187    pub fn get(&self, name: &str, positional_index: usize) -> Option<&Value> {
188        self.named.get(name).or_else(|| self.positional.get(positional_index))
189    }
190
191    /// Get a string value from args.
192    pub fn get_string(&self, name: &str, positional_index: usize) -> Option<String> {
193        self.get(name, positional_index).and_then(|v| match v {
194            Value::String(s) => Some(s.clone()),
195            Value::Int(i) => Some(i.to_string()),
196            Value::Float(f) => Some(f.to_string()),
197            Value::Bool(b) => Some(b.to_string()),
198            _ => None,
199        })
200    }
201
202    /// Get a boolean value from args.
203    pub fn get_bool(&self, name: &str, positional_index: usize) -> Option<bool> {
204        self.get(name, positional_index).and_then(|v| match v {
205            Value::Bool(b) => Some(*b),
206            Value::String(s) => match s.as_str() {
207                "true" | "yes" | "1" => Some(true),
208                "false" | "no" | "0" => Some(false),
209                _ => None,
210            },
211            Value::Int(i) => Some(*i != 0),
212            _ => None,
213        })
214    }
215
216    /// Check if a flag is set (in flags set, or named bool).
217    pub fn has_flag(&self, name: &str) -> bool {
218        // Check the flags set first (from -x or --name syntax)
219        if self.flags.contains(name) {
220            return true;
221        }
222        // Fall back to checking named args (from name=true syntax)
223        self.named.get(name).is_some_and(|v| match v {
224            Value::Bool(b) => *b,
225            Value::String(s) => !s.is_empty() && s != "false" && s != "0",
226            _ => true,
227        })
228    }
229}