1use std::collections::{BTreeMap, HashSet};
4
5use crate::value::Value;
6
/// Serde fallback for the `consumes` field of `ParamSchema`: used when the
/// serialized form does not carry the field at all.
fn default_consumes() -> usize {
    const SINGLE_VALUE: usize = 1;
    SINGLE_VALUE
}
10
/// Description of a single parameter accepted by a tool.
///
/// Instances are normally built through the constructors and builder methods
/// on this type; the struct is `#[non_exhaustive]`, so code outside this crate
/// cannot use a struct literal.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub struct ParamSchema {
    /// Primary name of the parameter; compared against flags by `matches_flag`.
    pub name: String,
    /// Free-form type label (the tests in this file use `"string"` and `"bool"`).
    pub param_type: String,
    /// Whether the parameter is marked as required.
    pub required: bool,
    /// Default value, if one was supplied.
    pub default: Option<Value>,
    /// Human-readable description.
    pub description: String,
    /// Alternative names that `matches_flag` accepts in addition to `name`.
    pub aliases: Vec<String>,
    /// Stored count set by the `consumes` builder, which rejects values below 1.
    /// Falls back to 1 when the field is missing from the serialized form.
    /// NOTE(review): presumably the number of argument values this parameter
    /// takes — confirm against the parser that reads it.
    #[serde(default = "default_consumes")]
    pub consumes: usize,
    /// Positional marker. Defaults to `false` on deserialization and is left
    /// out of the serialized form while it is `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub positional: bool,
}
46
47impl ParamSchema {
48 pub fn required(name: impl Into<String>, param_type: impl Into<String>, description: impl Into<String>) -> Self {
50 Self {
51 name: name.into(),
52 param_type: param_type.into(),
53 required: true,
54 default: None,
55 description: description.into(),
56 aliases: Vec::new(),
57 consumes: 1,
58 positional: false,
59 }
60 }
61
62 pub fn optional(name: impl Into<String>, param_type: impl Into<String>, default: Value, description: impl Into<String>) -> Self {
64 Self {
65 name: name.into(),
66 param_type: param_type.into(),
67 required: false,
68 default: Some(default),
69 description: description.into(),
70 aliases: Vec::new(),
71 consumes: 1,
72 positional: false,
73 }
74 }
75
76 pub fn new(name: impl Into<String>, param_type: impl Into<String>) -> Self {
83 Self {
84 name: name.into(),
85 param_type: param_type.into(),
86 required: false,
87 default: None,
88 description: String::new(),
89 aliases: Vec::new(),
90 consumes: 1,
91 positional: false,
92 }
93 }
94
95 pub fn with_description(mut self, description: impl Into<String>) -> Self {
97 self.description = description.into();
98 self
99 }
100
101 pub fn with_required(mut self, required: bool) -> Self {
103 self.required = required;
104 self
105 }
106
107 pub fn with_default(mut self, default: Option<Value>) -> Self {
109 self.default = default;
110 self
111 }
112
113 pub fn with_positional(mut self, positional: bool) -> Self {
116 self.positional = positional;
117 self
118 }
119
120 pub fn positional(mut self) -> Self {
125 self.positional = true;
126 self
127 }
128
129 pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
133 self.aliases = aliases.into_iter().map(Into::into).collect();
134 self
135 }
136
137 pub fn consumes(mut self, n: usize) -> Self {
141 assert!(n >= 1, "ParamSchema::consumes requires n >= 1 (use a bool param for flags that take no value)");
142 self.consumes = n;
143 self
144 }
145
146 pub fn matches_flag(&self, flag: &str) -> bool {
148 if self.name == flag {
149 return true;
150 }
151 self.aliases.iter().any(|a| a == flag)
152 }
153}
154
/// A usage example attached to a tool schema: a description paired with the
/// example code itself.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Example {
    /// What the example demonstrates.
    pub description: String,
    /// The example text.
    pub code: String,
}
163
164impl Example {
165 pub fn new(description: impl Into<String>, code: impl Into<String>) -> Self {
167 Self {
168 description: description.into(),
169 code: code.into(),
170 }
171 }
172}
173
/// Description of a tool (command): its parameters, usage examples and,
/// optionally, a tree of nested subcommands.
///
/// The struct is `#[non_exhaustive]`; build it with `ToolSchema::new` and the
/// builder methods.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub struct ToolSchema {
    /// Primary command name; compared against words by `matches_command`.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameters accepted by the tool, in the order they were added.
    pub params: Vec<ParamSchema>,
    /// Usage examples, in the order they were added.
    pub examples: Vec<Example>,
    /// Set by `with_positional_mapping`.
    /// NOTE(review): presumably asks the argument parser to map positional
    /// arguments onto `params` — the consumer is not visible here; confirm.
    pub map_positionals: bool,
    /// Nested subcommands. Defaults to empty on deserialization and is omitted
    /// from the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub subcommands: Vec<ToolSchema>,
    /// Alternative command names accepted by `matches_command`. Defaults to
    /// empty on deserialization and is omitted from the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub aliases: Vec<String>,
    /// Set (for this node and its current subcommands) by `with_owned_output`.
    /// Defaults to `false` on deserialization and is omitted from the
    /// serialized form while `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub owns_output: bool,
}
217
218impl ToolSchema {
219 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
221 Self {
222 name: name.into(),
223 description: description.into(),
224 params: Vec::new(),
225 examples: Vec::new(),
226 map_positionals: false,
227 subcommands: Vec::new(),
228 aliases: Vec::new(),
229 owns_output: false,
230 }
231 }
232
233 pub fn with_positional_mapping(mut self) -> Self {
235 self.map_positionals = true;
236 self
237 }
238
239 pub fn param(mut self, param: ParamSchema) -> Self {
241 self.params.push(param);
242 self
243 }
244
245 pub fn example(mut self, description: impl Into<String>, code: impl Into<String>) -> Self {
247 self.examples.push(Example::new(description, code));
248 self
249 }
250
251 pub fn subcommand(mut self, child: ToolSchema) -> Self {
253 self.subcommands.push(child);
254 self
255 }
256
257 pub fn with_command_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
261 self.aliases = aliases.into_iter().map(Into::into).collect();
262 self
263 }
264
265 pub fn matches_command(&self, word: &str) -> bool {
268 self.name == word || self.aliases.iter().any(|a| a == word)
269 }
270
271 pub fn with_owned_output(mut self) -> Self {
280 self.mark_owned_output();
281 self
282 }
283
284 fn mark_owned_output(&mut self) {
285 self.owns_output = true;
286 if !self.params.iter().any(|p| p.name == "json") {
287 self.params.push(
288 ParamSchema::new("json", "bool").with_description("Render output as JSON"),
289 );
290 }
291 for child in &mut self.subcommands {
292 child.mark_owned_output();
293 }
294 }
295}
296
/// Arguments handed to a tool, split into positional values, named values
/// and bare flags.
///
/// `Default` (and `ToolArgs::new`) yields three empty collections.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub struct ToolArgs {
    /// Positional values, in order.
    pub positional: Vec<Value>,
    /// Named values keyed by name; a `BTreeMap`, so iteration is in key order.
    pub named: BTreeMap<String, Value>,
    /// Names of flags that are present without a value.
    pub flags: HashSet<String>,
}
308
309impl ToolArgs {
310 pub fn new() -> Self {
312 Self::default()
313 }
314
315 pub fn get_positional(&self, index: usize) -> Option<&Value> {
317 self.positional.get(index)
318 }
319
320 pub fn get_named(&self, key: &str) -> Option<&Value> {
322 self.named.get(key)
323 }
324
325 pub fn get(&self, name: &str, positional_index: usize) -> Option<&Value> {
329 self.named.get(name).or_else(|| self.positional.get(positional_index))
330 }
331
332 pub fn get_string(&self, name: &str, positional_index: usize) -> Option<String> {
334 self.get(name, positional_index).and_then(|v| match v {
335 Value::String(s) => Some(s.clone()),
336 Value::Int(i) => Some(i.to_string()),
337 Value::Float(f) => Some(f.to_string()),
338 Value::Bool(b) => Some(b.to_string()),
339 _ => None,
340 })
341 }
342
343 pub fn get_bool(&self, name: &str, positional_index: usize) -> Option<bool> {
345 self.get(name, positional_index).and_then(|v| match v {
346 Value::Bool(b) => Some(*b),
347 Value::String(s) => match s.as_str() {
348 "true" | "yes" | "1" => Some(true),
349 "false" | "no" | "0" => Some(false),
350 _ => None,
351 },
352 Value::Int(i) => Some(*i != 0),
353 _ => None,
354 })
355 }
356
357 pub fn has_flag(&self, name: &str) -> bool {
359 if self.flags.contains(name) {
361 return true;
362 }
363 self.named.get(name).is_some_and(|v| match v {
365 Value::Bool(b) => *b,
366 Value::String(s) => !s.is_empty() && s != "false" && s != "0",
367 _ => true,
368 })
369 }
370
371 pub fn flagify_bool_named(&mut self) {
384 let bool_keys: Vec<String> = self
385 .named
386 .iter()
387 .filter(|(_, v)| matches!(v, Value::Bool(_)))
388 .map(|(k, _)| k.clone())
389 .collect();
390 for k in bool_keys {
391 if let Some(Value::Bool(true)) = self.named.remove(&k) {
395 self.flags.insert(k);
396 }
397 }
398 }
399
400 pub fn to_argv(&self) -> Vec<String> {
413 let mut argv = Vec::with_capacity(
414 self.flags.len() + self.named.len() * 2 + self.positional.len() + 1,
415 );
416
417 let mut flags: Vec<&String> = self.flags.iter().collect();
422 flags.sort();
423 for flag in flags {
424 argv.push(flag_token(flag));
425 }
426
427 for (key, value) in &self.named {
432 for rendered in render_named_value(value) {
433 argv.push(format!("{}={}", flag_token(key), rendered));
434 }
435 }
436
437 if !self.positional.is_empty() {
440 argv.push("--".to_string());
441 for value in &self.positional {
442 argv.push(value_to_argv_token(value));
443 }
444 }
445
446 argv
447 }
448}
449
/// Prefixes a flag name with dashes: a name of exactly one character gets a
/// single dash (`-n`), every other name (including the empty string) gets two
/// (`--name`).
fn flag_token(name: &str) -> String {
    // Looking at the first two chars is enough to tell "exactly one" apart.
    let mut chars = name.chars();
    let is_single_char = chars.next().is_some() && chars.next().is_none();
    let dashes = if is_single_char { "-" } else { "--" };
    format!("{dashes}{name}")
}
457
458fn render_named_value(value: &Value) -> Vec<String> {
459 match value {
460 Value::Json(serde_json::Value::Array(outer)) if outer.iter().all(|v| v.is_array()) => {
464 outer
465 .iter()
466 .map(|inner| {
467 inner
468 .as_array()
469 .map(|a| a.iter().map(json_value_to_token).collect::<Vec<_>>().join(" "))
470 .unwrap_or_default()
471 })
472 .collect()
473 }
474 _ => vec![value_to_argv_token(value)],
475 }
476}
477
478fn value_to_argv_token(value: &Value) -> String {
479 match value {
480 Value::Null => String::new(),
481 Value::Bool(b) => b.to_string(),
482 Value::Int(i) => i.to_string(),
483 Value::Float(f) => f.to_string(),
484 Value::String(s) => s.clone(),
485 Value::Json(j) => j.to_string(),
486 Value::Blob(b) => format!("[blob: {} {}]", b.formatted_size(), b.content_type),
487 }
488}
489
490fn json_value_to_token(value: &serde_json::Value) -> String {
491 match value {
492 serde_json::Value::Null => String::new(),
493 serde_json::Value::Bool(b) => b.to_string(),
494 serde_json::Value::Number(n) => n.to_string(),
495 serde_json::Value::String(s) => s.clone(),
496 other => other.to_string(),
497 }
498}
499
// Serialization tests for `ToolSchema`: which optional fields appear on the
// wire, and that subcommand trees survive a round trip.
#[cfg(test)]
mod schema_serde_tests {
    use super::*;

    // A tool without subcommands or command aliases must not emit those keys
    // at the top level of its serialized form.
    #[test]
    fn flat_schema_omits_new_fields_on_wire() {
        let schema = ToolSchema::new("cat", "concatenate")
            .param(ParamSchema::required("path", "string", "file to read").positional());
        let json = serde_json::to_value(&schema).expect("serialize");
        let obj = json.as_object().expect("object");
        assert!(!obj.contains_key("subcommands"), "flat tool leaks subcommands: {json}");
        assert!(!obj.contains_key("aliases"), "flat tool leaks command aliases: {json}");
    }

    // A serialized form lacking `subcommands` and `aliases` must deserialize,
    // with both fields defaulting to empty.
    #[test]
    fn flat_wire_form_deserializes_to_empty() {
        let flat = serde_json::json!({
            "name": "cat",
            "description": "concatenate",
            "params": [],
            "examples": [],
            "map_positionals": false
        });
        let schema: ToolSchema = serde_json::from_value(flat).expect("deserialize flat form");
        assert!(schema.subcommands.is_empty());
        assert!(schema.aliases.is_empty());
    }

    // `with_owned_output` must reach every level of the subcommand tree and
    // add the `json` parameter at the root and at the leaf.
    #[test]
    fn with_owned_output_marks_tree_and_advertises_json() {
        let schema = ToolSchema::new("kj", "kaijutsu")
            .subcommand(
                ToolSchema::new("context", "ctx")
                    .subcommand(ToolSchema::new("list", "list contexts")),
            )
            .with_owned_output();

        assert!(schema.owns_output, "root marked");
        assert!(schema.params.iter().any(|p| p.name == "json"), "root advertises json");
        let context = &schema.subcommands[0];
        assert!(context.owns_output, "child marked");
        let list = &context.subcommands[0];
        assert!(list.owns_output, "grandchild marked");
        assert!(list.params.iter().any(|p| p.name == "json"), "leaf advertises json");
    }

    // A schema that already declares `json` must not get a second copy.
    #[test]
    fn with_owned_output_does_not_double_add_json() {
        let schema = ToolSchema::new("kj", "kaijutsu")
            .param(ParamSchema::new("json", "bool"))
            .with_owned_output();
        let json_count = schema.params.iter().filter(|p| p.name == "json").count();
        assert_eq!(json_count, 1, "json should appear exactly once");
    }

    // `owns_output` is omitted from the wire while false and survives a
    // serialize/deserialize round trip when true.
    #[test]
    fn owns_output_serde() {
        let flat = ToolSchema::new("ls", "list");
        let json = serde_json::to_value(&flat).expect("serialize");
        let obj = json.as_object().expect("object");
        assert!(!obj.contains_key("owns_output"), "false omitted: {json}");

        let owned = ToolSchema::new("kj", "kaijutsu").with_owned_output();
        let wire = serde_json::to_string(&owned).expect("serialize");
        let back: ToolSchema = serde_json::from_str(&wire).expect("deserialize");
        assert!(back.owns_output);
    }

    // A two-level subcommand tree with command aliases must round-trip
    // through JSON with names and aliases still matching.
    #[test]
    fn subcommand_tree_round_trips() {
        let schema = ToolSchema::new("kj", "kaijutsu")
            .subcommand(
                ToolSchema::new("context", "context ops")
                    .with_command_aliases(["ctx"])
                    .subcommand(ToolSchema::new("list", "list contexts").with_command_aliases(["ls"])),
            );
        let json = serde_json::to_string(&schema).expect("serialize");
        let back: ToolSchema = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.subcommands.len(), 1);
        let context = &back.subcommands[0];
        assert!(context.matches_command("context"));
        assert!(context.matches_command("ctx"));
        assert_eq!(context.subcommands.len(), 1);
        assert!(context.subcommands[0].matches_command("ls"));
    }
}
596
// Tests for `ToolArgs::to_argv` and `ToolArgs::flagify_bool_named`.
#[cfg(test)]
mod to_argv_tests {
    use super::*;

    // No flags, named values or positionals: no tokens, not even `--`.
    #[test]
    fn empty_args_produce_empty_argv() {
        assert!(ToolArgs::new().to_argv().is_empty());
    }

    // Positionals are preceded by a literal `--` and keep their order.
    #[test]
    fn positionals_emitted_after_double_dash() {
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("hello".into()));
        args.positional.push(Value::String("world".into()));
        assert_eq!(args.to_argv(), vec!["--", "hello", "world"]);
    }

    // One-character flags get a single dash, longer ones a double dash;
    // flags come out sorted by name.
    #[test]
    fn single_char_flags_emit_short_form() {
        let mut args = ToolArgs::new();
        args.flags.insert("n".into());
        args.flags.insert("verbose".into());
        assert_eq!(args.to_argv(), vec!["-n", "--verbose"]);
    }

    // Named values are rendered as `--key=value`, in key order.
    #[test]
    fn named_values_use_equals_form() {
        let mut args = ToolArgs::new();
        args.named.insert("count".into(), Value::Int(5));
        args.named.insert("name".into(), Value::String("foo".into()));
        assert_eq!(args.to_argv(), vec!["--count=5", "--name=foo"]);
    }

    // A one-character named key uses the short dash with `=`.
    #[test]
    fn single_char_named_emits_short_equals() {
        let mut args = ToolArgs::new();
        args.named.insert("n".into(), Value::Int(5));
        assert_eq!(args.to_argv(), vec!["-n=5"]);
    }

    // A positional that looks like a flag is emitted verbatim after `--`.
    #[test]
    fn positional_with_leading_dash_survives_double_dash() {
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("-n".into()));
        assert_eq!(args.to_argv(), vec!["--", "-n"]);
    }

    // Overall ordering: flags, then named values, then `--` and positionals.
    #[test]
    fn mixed_flags_named_positionals() {
        let mut args = ToolArgs::new();
        args.flags.insert("verbose".into());
        args.named.insert("limit".into(), Value::Int(10));
        args.positional.push(Value::String("file.txt".into()));
        assert_eq!(
            args.to_argv(),
            vec!["--verbose", "--limit=10", "--", "file.txt"]
        );
    }

    // A `true` bool moves from `named` to `flags`; non-bool entries stay put.
    #[test]
    fn flagify_bool_named_promotes_true_to_flag() {
        let mut args = ToolArgs::new();
        args.named.insert("recursive".into(), Value::Bool(true));
        args.named.insert("limit".into(), Value::Int(5));

        args.flagify_bool_named();

        assert!(args.flags.contains("recursive"));
        assert!(!args.named.contains_key("recursive"));
        assert_eq!(args.named.get("limit"), Some(&Value::Int(5)));
    }

    // A `false` bool is removed from `named` without becoming a flag.
    #[test]
    fn flagify_bool_named_drops_false() {
        let mut args = ToolArgs::new();
        args.named.insert("recursive".into(), Value::Bool(false));

        args.flagify_bool_named();

        assert!(!args.flags.contains("recursive"));
        assert!(!args.named.contains_key("recursive"));
    }

    // Running the promotion twice leaves the flag in place.
    #[test]
    fn flagify_bool_named_is_idempotent() {
        let mut args = ToolArgs::new();
        args.named.insert("recursive".into(), Value::Bool(true));
        args.flagify_bool_named();
        args.flagify_bool_named();
        assert!(args.flags.contains("recursive"));
    }

    // After promotion, a bool renders as a bare flag (`-R`), never `key=value`.
    #[test]
    fn flagify_bool_named_round_trips_through_to_argv() {
        let mut args = ToolArgs::new();
        args.named.insert("R".into(), Value::Bool(true));
        args.flagify_bool_named();
        let argv = args.to_argv();
        assert!(argv.contains(&"-R".to_string()), "expected -R, got {:?}", argv);
        assert!(!argv.iter().any(|s| s.contains('=')), "no =value should appear, got {:?}", argv);
    }
}