1use std::collections::{BTreeMap, HashSet};
4
5use crate::value::Value;
6
/// Serde default for [`ParamSchema::consumes`] when the field is absent on the wire.
fn default_consumes() -> usize {
    1_usize
}
10
/// Declarative description of one parameter accepted by a tool: its name, type,
/// whether it is declared required, an optional default, help text, and alternate names.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub struct ParamSchema {
    /// Canonical parameter name (compared first by `matches_flag`).
    pub name: String,
    /// Type name as free-form text. `"bool"` / `"boolean"` (ASCII case-insensitive)
    /// mark the parameter as boolean for `ToolArgs::flagify_bool_named`.
    pub param_type: String,
    /// Whether the parameter is declared mandatory (`ParamSchema::required` sets `true`).
    pub required: bool,
    /// Default value, if any (`ParamSchema::optional` sets `Some`).
    pub default: Option<Value>,
    /// Human-readable help text.
    pub description: String,
    /// Alternate names also accepted by `matches_flag`.
    /// `ToolArgs::flagify_bool_named` strips leading `-` from these before comparing.
    ///
    /// NOTE(review): as written, this field has no `#[serde(default)]` (unlike
    /// `consumes` / `positional`), so serialized params must include `aliases`.
    pub aliases: Vec<String>,
    /// Number of values the parameter consumes; the `consumes` builder rejects 0.
    /// Defaults to 1 when missing from serialized input.
    #[serde(default = "default_consumes")]
    pub consumes: usize,
    /// Marks the parameter as positional. Defaults to `false` when missing from
    /// serialized input and is omitted from serialized output when `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub positional: bool,
}
46
47impl ParamSchema {
48 pub fn required(name: impl Into<String>, param_type: impl Into<String>, description: impl Into<String>) -> Self {
50 Self {
51 name: name.into(),
52 param_type: param_type.into(),
53 required: true,
54 default: None,
55 description: description.into(),
56 aliases: Vec::new(),
57 consumes: 1,
58 positional: false,
59 }
60 }
61
62 pub fn optional(name: impl Into<String>, param_type: impl Into<String>, default: Value, description: impl Into<String>) -> Self {
64 Self {
65 name: name.into(),
66 param_type: param_type.into(),
67 required: false,
68 default: Some(default),
69 description: description.into(),
70 aliases: Vec::new(),
71 consumes: 1,
72 positional: false,
73 }
74 }
75
76 pub fn new(name: impl Into<String>, param_type: impl Into<String>) -> Self {
83 Self {
84 name: name.into(),
85 param_type: param_type.into(),
86 required: false,
87 default: None,
88 description: String::new(),
89 aliases: Vec::new(),
90 consumes: 1,
91 positional: false,
92 }
93 }
94
95 pub fn with_description(mut self, description: impl Into<String>) -> Self {
97 self.description = description.into();
98 self
99 }
100
101 pub fn with_required(mut self, required: bool) -> Self {
103 self.required = required;
104 self
105 }
106
107 pub fn with_default(mut self, default: Option<Value>) -> Self {
109 self.default = default;
110 self
111 }
112
113 pub fn with_positional(mut self, positional: bool) -> Self {
116 self.positional = positional;
117 self
118 }
119
120 pub fn positional(mut self) -> Self {
125 self.positional = true;
126 self
127 }
128
129 pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
133 self.aliases = aliases.into_iter().map(Into::into).collect();
134 self
135 }
136
137 pub fn consumes(mut self, n: usize) -> Self {
141 assert!(n >= 1, "ParamSchema::consumes requires n >= 1 (use a bool param for flags that take no value)");
142 self.consumes = n;
143 self
144 }
145
146 pub fn matches_flag(&self, flag: &str) -> bool {
148 if self.name == flag {
149 return true;
150 }
151 self.aliases.iter().any(|a| a == flag)
152 }
153}
154
/// A usage example attached to a [`ToolSchema`]: a short description plus a code snippet.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Example {
    /// What the example demonstrates.
    pub description: String,
    /// The example invocation/code text.
    pub code: String,
}
163
164impl Example {
165 pub fn new(description: impl Into<String>, code: impl Into<String>) -> Self {
167 Self {
168 description: description.into(),
169 code: code.into(),
170 }
171 }
172}
173
/// Declarative description of a tool/command: name, help text, parameters, usage
/// examples, command aliases, and an optional tree of nested subcommands.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub struct ToolSchema {
    /// Command name (matched by `matches_command`).
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Parameter definitions (appended by `param`).
    pub params: Vec<ParamSchema>,
    /// Usage examples (appended by `example`).
    pub examples: Vec<Example>,
    /// Set by `with_positional_mapping`. Has no serde default, so it must be present
    /// in serialized input.
    pub map_positionals: bool,
    /// Nested subcommands (appended by `subcommand`). Defaults to empty when missing
    /// from serialized input and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub subcommands: Vec<ToolSchema>,
    /// Command aliases also accepted by `matches_command`. Defaults to empty when
    /// missing and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub aliases: Vec<String>,
    /// Output-ownership marker, set recursively by `with_owned_output` (which also adds
    /// a `json` bool param). Defaults to `false` and is omitted from output when `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub owns_output: bool,
}
217
218impl ToolSchema {
219 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
221 Self {
222 name: name.into(),
223 description: description.into(),
224 params: Vec::new(),
225 examples: Vec::new(),
226 map_positionals: false,
227 subcommands: Vec::new(),
228 aliases: Vec::new(),
229 owns_output: false,
230 }
231 }
232
233 pub fn with_positional_mapping(mut self) -> Self {
235 self.map_positionals = true;
236 self
237 }
238
239 pub fn param(mut self, param: ParamSchema) -> Self {
241 self.params.push(param);
242 self
243 }
244
245 pub fn example(mut self, description: impl Into<String>, code: impl Into<String>) -> Self {
247 self.examples.push(Example::new(description, code));
248 self
249 }
250
251 pub fn subcommand(mut self, child: ToolSchema) -> Self {
253 self.subcommands.push(child);
254 self
255 }
256
257 pub fn with_command_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
261 self.aliases = aliases.into_iter().map(Into::into).collect();
262 self
263 }
264
265 pub fn matches_command(&self, word: &str) -> bool {
268 self.name == word || self.aliases.iter().any(|a| a == word)
269 }
270
271 pub fn with_owned_output(mut self) -> Self {
280 self.mark_owned_output();
281 self
282 }
283
284 fn mark_owned_output(&mut self) {
285 self.owns_output = true;
286 if !self.params.iter().any(|p| p.name == "json") {
287 self.params.push(
288 ParamSchema::new("json", "bool").with_description("Render output as JSON"),
289 );
290 }
291 for child in &mut self.subcommands {
292 child.mark_owned_output();
293 }
294 }
295}
296
/// Parsed arguments for a tool invocation: positional values, named `key=value`
/// arguments, and bare flags.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub struct ToolArgs {
    /// Positional values, in order.
    pub positional: Vec<Value>,
    /// Named arguments keyed by name; a `BTreeMap`, so `to_argv` emits them in key order.
    pub named: BTreeMap<String, Value>,
    /// Bare flags (no value); `to_argv` sorts these before emitting.
    pub flags: HashSet<String>,
}
308
309impl ToolArgs {
310 pub fn new() -> Self {
312 Self::default()
313 }
314
315 pub fn get_positional(&self, index: usize) -> Option<&Value> {
317 self.positional.get(index)
318 }
319
320 pub fn get_named(&self, key: &str) -> Option<&Value> {
322 self.named.get(key)
323 }
324
325 pub fn get(&self, name: &str, positional_index: usize) -> Option<&Value> {
329 self.named.get(name).or_else(|| self.positional.get(positional_index))
330 }
331
332 pub fn get_string(&self, name: &str, positional_index: usize) -> Option<String> {
334 self.get(name, positional_index).and_then(|v| match v {
335 Value::String(s) => Some(s.clone()),
336 Value::Int(i) => Some(i.to_string()),
337 Value::Float(f) => Some(f.to_string()),
338 Value::Bool(b) => Some(b.to_string()),
339 _ => None,
340 })
341 }
342
343 pub fn get_bool(&self, name: &str, positional_index: usize) -> Option<bool> {
345 self.get(name, positional_index).and_then(|v| match v {
346 Value::Bool(b) => Some(*b),
347 Value::String(s) => match s.as_str() {
348 "true" | "yes" | "1" => Some(true),
349 "false" | "no" | "0" => Some(false),
350 _ => None,
351 },
352 Value::Int(i) => Some(*i != 0),
353 _ => None,
354 })
355 }
356
357 pub fn has_flag(&self, name: &str) -> bool {
359 if self.flags.contains(name) {
361 return true;
362 }
363 self.named.get(name).is_some_and(|v| match v {
365 Value::Bool(b) => *b,
366 Value::String(s) => !s.is_empty() && s != "false" && s != "0",
367 _ => true,
368 })
369 }
370
371 pub fn flagify_bool_named(&mut self, schema: &ToolSchema) {
391 let value_keys: HashSet<&str> = schema
394 .params
395 .iter()
396 .filter(|p| !p.positional && !is_bool_param_type(&p.param_type))
397 .flat_map(|p| {
398 std::iter::once(p.name.as_str())
399 .chain(p.aliases.iter().map(|a| a.trim_start_matches('-')))
400 })
401 .collect();
402
403 let bool_keys: Vec<String> = self
404 .named
405 .iter()
406 .filter(|(k, v)| matches!(v, Value::Bool(_)) && !value_keys.contains(k.as_str()))
407 .map(|(k, _)| k.clone())
408 .collect();
409 for k in bool_keys {
410 if let Some(Value::Bool(true)) = self.named.remove(&k) {
414 self.flags.insert(k);
415 }
416 }
417 }
418
419 pub fn to_argv(&self) -> Vec<String> {
432 let mut argv = Vec::with_capacity(
433 self.flags.len() + self.named.len() * 2 + self.positional.len() + 1,
434 );
435
436 let mut flags: Vec<&String> = self.flags.iter().collect();
441 flags.sort();
442 for flag in flags {
443 argv.push(flag_token(flag));
444 }
445
446 for (key, value) in &self.named {
451 for rendered in render_named_value(value) {
452 argv.push(format!("{}={}", flag_token(key), rendered));
453 }
454 }
455
456 if !self.positional.is_empty() {
459 argv.push("--".to_string());
460 for value in &self.positional {
461 argv.push(value_to_argv_token(value));
462 }
463 }
464
465 argv
466 }
467}
468
/// Prefixes `name` with `-` when it is exactly one character, otherwise with `--`.
fn flag_token(name: &str) -> String {
    let mut chars = name.chars();
    let dashes = match (chars.next(), chars.next()) {
        (Some(_), None) => "-",
        _ => "--",
    };
    format!("{dashes}{name}")
}
476
/// Returns `true` for the boolean type names `bool` / `boolean`, ignoring ASCII case.
fn is_bool_param_type(param_type: &str) -> bool {
    ["bool", "boolean"]
        .iter()
        .any(|candidate| param_type.eq_ignore_ascii_case(candidate))
}
481
482fn render_named_value(value: &Value) -> Vec<String> {
483 match value {
484 Value::Json(serde_json::Value::Array(outer)) if outer.iter().all(|v| v.is_array()) => {
488 outer
489 .iter()
490 .map(|inner| {
491 inner
492 .as_array()
493 .map(|a| a.iter().map(json_value_to_token).collect::<Vec<_>>().join(" "))
494 .unwrap_or_default()
495 })
496 .collect()
497 }
498 _ => vec![value_to_argv_token(value)],
499 }
500}
501
502fn value_to_argv_token(value: &Value) -> String {
503 match value {
504 Value::Null => String::new(),
505 Value::Bool(b) => b.to_string(),
506 Value::Int(i) => i.to_string(),
507 Value::Float(f) => f.to_string(),
508 Value::String(s) => s.clone(),
509 Value::Json(j) => j.to_string(),
510 Value::Bytes(data) => format!("[binary: {} bytes]", data.len()),
514 }
515}
516
517fn json_value_to_token(value: &serde_json::Value) -> String {
518 match value {
519 serde_json::Value::Null => String::new(),
520 serde_json::Value::Bool(b) => b.to_string(),
521 serde_json::Value::Number(n) => n.to_string(),
522 serde_json::Value::String(s) => s.clone(),
523 other => other.to_string(),
524 }
525}
526
/// Wire-format (serde) and builder tests for `ToolSchema`.
#[cfg(test)]
mod schema_serde_tests {
    use super::*;

    /// A schema without subcommands or command aliases must not serialize those keys.
    #[test]
    fn flat_schema_omits_new_fields_on_wire() {
        let schema = ToolSchema::new("cat", "concatenate")
            .param(ParamSchema::required("path", "string", "file to read").positional());
        let json = serde_json::to_value(&schema).expect("serialize");
        let obj = json.as_object().expect("object");
        assert!(!obj.contains_key("subcommands"), "flat tool leaks subcommands: {json}");
        assert!(!obj.contains_key("aliases"), "flat tool leaks command aliases: {json}");
    }

    /// JSON lacking `subcommands` / `aliases` still deserializes, with both empty.
    #[test]
    fn flat_wire_form_deserializes_to_empty() {
        let flat = serde_json::json!({
            "name": "cat",
            "description": "concatenate",
            "params": [],
            "examples": [],
            "map_positionals": false
        });
        let schema: ToolSchema = serde_json::from_value(flat).expect("deserialize flat form");
        assert!(schema.subcommands.is_empty());
        assert!(schema.aliases.is_empty());
    }

    /// `with_owned_output` marks every node of an existing subcommand tree and adds `json`.
    #[test]
    fn with_owned_output_marks_tree_and_advertises_json() {
        let schema = ToolSchema::new("kj", "kaijutsu")
            .subcommand(
                ToolSchema::new("context", "ctx")
                    .subcommand(ToolSchema::new("list", "list contexts")),
            )
            .with_owned_output();

        assert!(schema.owns_output, "root marked");
        assert!(schema.params.iter().any(|p| p.name == "json"), "root advertises json");
        let context = &schema.subcommands[0];
        assert!(context.owns_output, "child marked");
        let list = &context.subcommands[0];
        assert!(list.owns_output, "grandchild marked");
        assert!(list.params.iter().any(|p| p.name == "json"), "leaf advertises json");
    }

    /// An existing `json` param is not duplicated by `with_owned_output`.
    #[test]
    fn with_owned_output_does_not_double_add_json() {
        let schema = ToolSchema::new("kj", "kaijutsu")
            .param(ParamSchema::new("json", "bool"))
            .with_owned_output();
        let json_count = schema.params.iter().filter(|p| p.name == "json").count();
        assert_eq!(json_count, 1, "json should appear exactly once");
    }

    /// `owns_output` is omitted when false and survives a serialize/deserialize round trip when true.
    #[test]
    fn owns_output_serde() {
        let flat = ToolSchema::new("ls", "list");
        let json = serde_json::to_value(&flat).expect("serialize");
        let obj = json.as_object().expect("object");
        assert!(!obj.contains_key("owns_output"), "false omitted: {json}");

        let owned = ToolSchema::new("kj", "kaijutsu").with_owned_output();
        let wire = serde_json::to_string(&owned).expect("serialize");
        let back: ToolSchema = serde_json::from_str(&wire).expect("deserialize");
        assert!(back.owns_output);
    }

    /// Nested subcommands and their command aliases survive a JSON round trip.
    #[test]
    fn subcommand_tree_round_trips() {
        let schema = ToolSchema::new("kj", "kaijutsu")
            .subcommand(
                ToolSchema::new("context", "context ops")
                    .with_command_aliases(["ctx"])
                    .subcommand(ToolSchema::new("list", "list contexts").with_command_aliases(["ls"])),
            );
        let json = serde_json::to_string(&schema).expect("serialize");
        let back: ToolSchema = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.subcommands.len(), 1);
        let context = &back.subcommands[0];
        assert!(context.matches_command("context"));
        assert!(context.matches_command("ctx"));
        assert_eq!(context.subcommands.len(), 1);
        assert!(context.subcommands[0].matches_command("ls"));
    }
}
623
/// Tests for `ToolArgs::to_argv` rendering and `ToolArgs::flagify_bool_named`.
#[cfg(test)]
mod to_argv_tests {
    use super::*;

    /// No flags, named values, or positionals means an empty argv (no stray `--`).
    #[test]
    fn empty_args_produce_empty_argv() {
        assert!(ToolArgs::new().to_argv().is_empty());
    }

    /// Positionals follow a `--` separator, in order.
    #[test]
    fn positionals_emitted_after_double_dash() {
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("hello".into()));
        args.positional.push(Value::String("world".into()));
        assert_eq!(args.to_argv(), vec!["--", "hello", "world"]);
    }

    /// One-character flags render as `-x`, longer ones as `--name`; flags are sorted.
    #[test]
    fn single_char_flags_emit_short_form() {
        let mut args = ToolArgs::new();
        args.flags.insert("n".into());
        args.flags.insert("verbose".into());
        assert_eq!(args.to_argv(), vec!["-n", "--verbose"]);
    }

    /// Named values render as `--key=value`, in key order.
    #[test]
    fn named_values_use_equals_form() {
        let mut args = ToolArgs::new();
        args.named.insert("count".into(), Value::Int(5));
        args.named.insert("name".into(), Value::String("foo".into()));
        assert_eq!(args.to_argv(), vec!["--count=5", "--name=foo"]);
    }

    /// One-character named keys use the short form together with `=`.
    #[test]
    fn single_char_named_emits_short_equals() {
        let mut args = ToolArgs::new();
        args.named.insert("n".into(), Value::Int(5));
        assert_eq!(args.to_argv(), vec!["-n=5"]);
    }

    /// A positional that looks like an option is protected by the `--` separator.
    #[test]
    fn positional_with_leading_dash_survives_double_dash() {
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("-n".into()));
        assert_eq!(args.to_argv(), vec!["--", "-n"]);
    }

    /// Full ordering: flags, then named values, then `--` and positionals.
    #[test]
    fn mixed_flags_named_positionals() {
        let mut args = ToolArgs::new();
        args.flags.insert("verbose".into());
        args.named.insert("limit".into(), Value::Int(10));
        args.positional.push(Value::String("file.txt".into()));
        assert_eq!(
            args.to_argv(),
            vec!["--verbose", "--limit=10", "--", "file.txt"]
        );
    }

    /// A named `true` becomes a bare flag; non-bool named values are untouched.
    #[test]
    fn flagify_bool_named_promotes_true_to_flag() {
        let mut args = ToolArgs::new();
        args.named.insert("recursive".into(), Value::Bool(true));
        args.named.insert("limit".into(), Value::Int(5));

        args.flagify_bool_named(&ToolSchema::new("t", ""));

        assert!(args.flags.contains("recursive"));
        assert!(!args.named.contains_key("recursive"));
        assert_eq!(args.named.get("limit"), Some(&Value::Int(5)));
    }

    /// A named `false` is removed entirely and does not become a flag.
    #[test]
    fn flagify_bool_named_drops_false() {
        let mut args = ToolArgs::new();
        args.named.insert("recursive".into(), Value::Bool(false));

        args.flagify_bool_named(&ToolSchema::new("t", ""));

        assert!(!args.flags.contains("recursive"));
        assert!(!args.named.contains_key("recursive"));
    }

    /// Running the conversion twice leaves the promoted flag in place.
    #[test]
    fn flagify_bool_named_is_idempotent() {
        let mut args = ToolArgs::new();
        args.named.insert("recursive".into(), Value::Bool(true));
        args.flagify_bool_named(&ToolSchema::new("t", ""));
        args.flagify_bool_named(&ToolSchema::new("t", ""));
        assert!(args.flags.contains("recursive"));
    }

    /// A promoted single-character flag renders as `-R` with no `=value` in argv.
    #[test]
    fn flagify_bool_named_round_trips_through_to_argv() {
        let mut args = ToolArgs::new();
        args.named.insert("R".into(), Value::Bool(true));
        args.flagify_bool_named(&ToolSchema::new("t", ""));
        let argv = args.to_argv();
        assert!(argv.contains(&"-R".to_string()), "expected -R, got {:?}", argv);
        assert!(!argv.iter().any(|s| s.contains('=')), "no =value should appear, got {:?}", argv);
    }

    /// A bool value for a param the schema declares as value-taking stays `--key=true`.
    #[test]
    fn flagify_bool_named_keeps_value_flag_value() {
        let mut schema = ToolSchema::new("spawn", "");
        schema.params.push(ParamSchema::new("command", "string"));

        let mut args = ToolArgs::new();
        args.named.insert("command".into(), Value::Bool(true));
        args.flagify_bool_named(&schema);

        assert!(!args.flags.contains("command"), "value flag must not collapse to a bare flag");
        assert_eq!(args.named.get("command"), Some(&Value::Bool(true)));
        let argv = args.to_argv();
        assert!(
            argv.iter().any(|s| s == "--command=true"),
            "expected --command=true, got {:?}",
            argv
        );
    }

    /// Bool-typed params are promoted while string-typed params keep their value.
    #[test]
    fn flagify_bool_named_distinguishes_bool_from_value_param() {
        let mut schema = ToolSchema::new("t", "");
        schema.params.push(ParamSchema::new("verbose", "bool"));
        schema.params.push(ParamSchema::new("command", "string"));

        let mut args = ToolArgs::new();
        args.named.insert("verbose".into(), Value::Bool(true));
        args.named.insert("command".into(), Value::Bool(true));
        args.flagify_bool_named(&schema);

        assert!(args.flags.contains("verbose"));
        assert!(!args.named.contains_key("verbose"));
        assert!(!args.flags.contains("command"));
        assert_eq!(args.named.get("command"), Some(&Value::Bool(true)));
    }
}