1use std::collections::{BTreeMap, HashSet};
4
5use crate::value::Value;
6
/// Serde default for the `consumes` field of [`ParamSchema`].
///
/// Payloads that omit `consumes` deserialize as taking one value token, the
/// same value every `ParamSchema` constructor sets.
fn default_consumes() -> usize {
    const SINGLE_TOKEN: usize = 1;
    SINGLE_TOKEN
}
10
/// Declarative description of a single tool parameter.
///
/// Serialized with serde. `consumes` and `positional` may be absent on the
/// wire and fall back to defaults; `positional` is also omitted from the
/// serialized form when `false`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub struct ParamSchema {
    /// Canonical parameter name; `matches_flag` compares against it.
    pub name: String,
    /// Type name as a string (this file uses e.g. `"string"` and `"bool"`).
    pub param_type: String,
    /// Set to `true` by `ParamSchema::required`, `false` by the other constructors.
    pub required: bool,
    /// Default value, if any (`ParamSchema::optional` sets it).
    pub default: Option<Value>,
    /// Human-readable description (empty when built with `ParamSchema::new`).
    pub description: String,
    /// Alternative flag names that `matches_flag` also accepts.
    pub aliases: Vec<String>,
    /// Number of values this parameter takes. Defaults to 1 when absent on
    /// deserialize; the `consumes` builder rejects 0.
    #[serde(default = "default_consumes")]
    pub consumes: usize,
    /// Marks the parameter as positional. Defaults to `false` when absent and
    /// is skipped during serialization when `false`.
    /// NOTE(review): how positionals are bound to params is decided by code
    /// not in view (likely related to `ToolSchema::map_positionals`).
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub positional: bool,
}
46
47impl ParamSchema {
48 pub fn required(name: impl Into<String>, param_type: impl Into<String>, description: impl Into<String>) -> Self {
50 Self {
51 name: name.into(),
52 param_type: param_type.into(),
53 required: true,
54 default: None,
55 description: description.into(),
56 aliases: Vec::new(),
57 consumes: 1,
58 positional: false,
59 }
60 }
61
62 pub fn optional(name: impl Into<String>, param_type: impl Into<String>, default: Value, description: impl Into<String>) -> Self {
64 Self {
65 name: name.into(),
66 param_type: param_type.into(),
67 required: false,
68 default: Some(default),
69 description: description.into(),
70 aliases: Vec::new(),
71 consumes: 1,
72 positional: false,
73 }
74 }
75
76 pub fn new(name: impl Into<String>, param_type: impl Into<String>) -> Self {
83 Self {
84 name: name.into(),
85 param_type: param_type.into(),
86 required: false,
87 default: None,
88 description: String::new(),
89 aliases: Vec::new(),
90 consumes: 1,
91 positional: false,
92 }
93 }
94
95 pub fn with_description(mut self, description: impl Into<String>) -> Self {
97 self.description = description.into();
98 self
99 }
100
101 pub fn with_required(mut self, required: bool) -> Self {
103 self.required = required;
104 self
105 }
106
107 pub fn with_default(mut self, default: Option<Value>) -> Self {
109 self.default = default;
110 self
111 }
112
113 pub fn with_positional(mut self, positional: bool) -> Self {
116 self.positional = positional;
117 self
118 }
119
120 pub fn positional(mut self) -> Self {
125 self.positional = true;
126 self
127 }
128
129 pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
133 self.aliases = aliases.into_iter().map(Into::into).collect();
134 self
135 }
136
137 pub fn consumes(mut self, n: usize) -> Self {
141 assert!(n >= 1, "ParamSchema::consumes requires n >= 1 (use a bool param for flags that take no value)");
142 self.consumes = n;
143 self
144 }
145
146 pub fn matches_flag(&self, flag: &str) -> bool {
148 if self.name == flag {
149 return true;
150 }
151 self.aliases.iter().any(|a| a == flag)
152 }
153}
154
/// A usage example stored in [`ToolSchema::examples`]: a prose description
/// paired with the code it describes.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Example {
    /// What the example demonstrates.
    pub description: String,
    /// The example code snippet.
    pub code: String,
}
163
164impl Example {
165 pub fn new(description: impl Into<String>, code: impl Into<String>) -> Self {
167 Self {
168 description: description.into(),
169 code: code.into(),
170 }
171 }
172}
173
/// Declarative schema for a tool: its parameters, examples, and an optional
/// tree of nested subcommands.
///
/// `subcommands`, `aliases`, and `owns_output` are skipped during
/// serialization when empty/`false` and default when absent. A flat tool
/// therefore serializes without those keys, and the flat wire form still
/// deserializes.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub struct ToolSchema {
    /// Command name; `matches_command` checks it together with `aliases`.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Declared parameters, in insertion order.
    pub params: Vec<ParamSchema>,
    /// Usage examples, in insertion order.
    pub examples: Vec<Example>,
    /// Set by `with_positional_mapping`. NOTE(review): presumably tells the
    /// argument binder to map positionals onto `positional` params; the
    /// consumer is not in view.
    pub map_positionals: bool,
    /// Nested subcommands (each a full `ToolSchema`). Omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub subcommands: Vec<ToolSchema>,
    /// Alternative command names accepted by `matches_command`. Omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub aliases: Vec<String>,
    /// Set by `with_owned_output`. Omitted from the wire form when `false`.
    /// NOTE(review): the flag's consumer is not in view; the name suggests the
    /// tool renders its own output (it is paired with a `json` bool param).
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub owns_output: bool,
}
217
218impl ToolSchema {
219 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
221 Self {
222 name: name.into(),
223 description: description.into(),
224 params: Vec::new(),
225 examples: Vec::new(),
226 map_positionals: false,
227 subcommands: Vec::new(),
228 aliases: Vec::new(),
229 owns_output: false,
230 }
231 }
232
233 pub fn with_positional_mapping(mut self) -> Self {
235 self.map_positionals = true;
236 self
237 }
238
239 pub fn param(mut self, param: ParamSchema) -> Self {
241 self.params.push(param);
242 self
243 }
244
245 pub fn example(mut self, description: impl Into<String>, code: impl Into<String>) -> Self {
247 self.examples.push(Example::new(description, code));
248 self
249 }
250
251 pub fn subcommand(mut self, child: ToolSchema) -> Self {
253 self.subcommands.push(child);
254 self
255 }
256
257 pub fn with_command_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
261 self.aliases = aliases.into_iter().map(Into::into).collect();
262 self
263 }
264
265 pub fn matches_command(&self, word: &str) -> bool {
268 self.name == word || self.aliases.iter().any(|a| a == word)
269 }
270
271 pub fn with_owned_output(mut self) -> Self {
280 self.mark_owned_output();
281 self
282 }
283
284 fn mark_owned_output(&mut self) {
285 self.owns_output = true;
286 if !self.params.iter().any(|p| p.name == "json") {
287 self.params.push(
288 ParamSchema::new("json", "bool").with_description("Render output as JSON"),
289 );
290 }
291 for child in &mut self.subcommands {
292 child.mark_owned_output();
293 }
294 }
295}
296
/// Arguments for a tool invocation, split into positional values, named
/// `key=value` values, and bare flags.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub struct ToolArgs {
    /// Positional values, in order.
    pub positional: Vec<Value>,
    /// Named values keyed by name. A `BTreeMap`, so iteration (and therefore
    /// `to_argv` output) is in key order.
    pub named: BTreeMap<String, Value>,
    /// Bare flags that are set. A `HashSet` is unordered; `to_argv` sorts
    /// them before emitting.
    pub flags: HashSet<String>,
}
308
309impl ToolArgs {
310 pub fn new() -> Self {
312 Self::default()
313 }
314
315 pub fn get_positional(&self, index: usize) -> Option<&Value> {
317 self.positional.get(index)
318 }
319
320 pub fn get_named(&self, key: &str) -> Option<&Value> {
322 self.named.get(key)
323 }
324
325 pub fn get(&self, name: &str, positional_index: usize) -> Option<&Value> {
329 self.named.get(name).or_else(|| self.positional.get(positional_index))
330 }
331
332 pub fn get_string(&self, name: &str, positional_index: usize) -> Option<String> {
334 self.get(name, positional_index).and_then(|v| match v {
335 Value::String(s) => Some(s.clone()),
336 Value::Int(i) => Some(i.to_string()),
337 Value::Float(f) => Some(f.to_string()),
338 Value::Bool(b) => Some(b.to_string()),
339 _ => None,
340 })
341 }
342
343 pub fn get_bool(&self, name: &str, positional_index: usize) -> Option<bool> {
345 self.get(name, positional_index).and_then(|v| match v {
346 Value::Bool(b) => Some(*b),
347 Value::String(s) => match s.as_str() {
348 "true" | "yes" | "1" => Some(true),
349 "false" | "no" | "0" => Some(false),
350 _ => None,
351 },
352 Value::Int(i) => Some(*i != 0),
353 _ => None,
354 })
355 }
356
357 pub fn has_flag(&self, name: &str) -> bool {
359 if self.flags.contains(name) {
361 return true;
362 }
363 self.named.get(name).is_some_and(|v| match v {
365 Value::Bool(b) => *b,
366 Value::String(s) => !s.is_empty() && s != "false" && s != "0",
367 _ => true,
368 })
369 }
370
371 pub fn flagify_bool_named(&mut self) {
384 let bool_keys: Vec<String> = self
385 .named
386 .iter()
387 .filter(|(_, v)| matches!(v, Value::Bool(_)))
388 .map(|(k, _)| k.clone())
389 .collect();
390 for k in bool_keys {
391 if let Some(Value::Bool(true)) = self.named.remove(&k) {
395 self.flags.insert(k);
396 }
397 }
398 }
399
400 pub fn to_argv(&self) -> Vec<String> {
413 let mut argv = Vec::with_capacity(
414 self.flags.len() + self.named.len() * 2 + self.positional.len() + 1,
415 );
416
417 let mut flags: Vec<&String> = self.flags.iter().collect();
422 flags.sort();
423 for flag in flags {
424 argv.push(flag_token(flag));
425 }
426
427 for (key, value) in &self.named {
432 for rendered in render_named_value(value) {
433 argv.push(format!("{}={}", flag_token(key), rendered));
434 }
435 }
436
437 if !self.positional.is_empty() {
440 argv.push("--".to_string());
441 for value in &self.positional {
442 argv.push(value_to_argv_token(value));
443 }
444 }
445
446 argv
447 }
448}
449
/// Formats `name` as a command-line option token.
///
/// Single-character names get one dash (`-n`); everything else gets two
/// (`--verbose`). Length is counted in `char`s, so a single multi-byte
/// character still gets the short form.
///
/// NOTE: an empty name yields `--`, which is indistinguishable from the
/// end-of-options separator.
fn flag_token(name: &str) -> String {
    let mut chars = name.chars();
    let single_char = matches!((chars.next(), chars.next()), (Some(_), None));
    let dashes = if single_char { "-" } else { "--" };
    format!("{dashes}{name}")
}
457
458fn render_named_value(value: &Value) -> Vec<String> {
459 match value {
460 Value::Json(serde_json::Value::Array(outer)) if outer.iter().all(|v| v.is_array()) => {
464 outer
465 .iter()
466 .map(|inner| {
467 inner
468 .as_array()
469 .map(|a| a.iter().map(json_value_to_token).collect::<Vec<_>>().join(" "))
470 .unwrap_or_default()
471 })
472 .collect()
473 }
474 _ => vec![value_to_argv_token(value)],
475 }
476}
477
478fn value_to_argv_token(value: &Value) -> String {
479 match value {
480 Value::Null => String::new(),
481 Value::Bool(b) => b.to_string(),
482 Value::Int(i) => i.to_string(),
483 Value::Float(f) => f.to_string(),
484 Value::String(s) => s.clone(),
485 Value::Json(j) => j.to_string(),
486 Value::Bytes(data) => format!("[binary: {} bytes]", data.len()),
490 }
491}
492
493fn json_value_to_token(value: &serde_json::Value) -> String {
494 match value {
495 serde_json::Value::Null => String::new(),
496 serde_json::Value::Bool(b) => b.to_string(),
497 serde_json::Value::Number(n) => n.to_string(),
498 serde_json::Value::String(s) => s.clone(),
499 other => other.to_string(),
500 }
501}
502
// Wire-format (serde) and tree-building tests for `ToolSchema` / `ParamSchema`.
#[cfg(test)]
mod schema_serde_tests {
    use super::*;

    // A tool with no subcommands or aliases must not emit those keys (the
    // test name calls them "new fields"), keeping the flat wire form clean.
    #[test]
    fn flat_schema_omits_new_fields_on_wire() {
        let schema = ToolSchema::new("cat", "concatenate")
            .param(ParamSchema::required("path", "string", "file to read").positional());
        let json = serde_json::to_value(&schema).expect("serialize");
        let obj = json.as_object().expect("object");
        assert!(!obj.contains_key("subcommands"), "flat tool leaks subcommands: {json}");
        assert!(!obj.contains_key("aliases"), "flat tool leaks command aliases: {json}");
    }

    // The flat wire form (no `subcommands` / `aliases` keys) must still
    // deserialize, with those fields defaulting to empty.
    #[test]
    fn flat_wire_form_deserializes_to_empty() {
        let flat = serde_json::json!({
            "name": "cat",
            "description": "concatenate",
            "params": [],
            "examples": [],
            "map_positionals": false
        });
        let schema: ToolSchema = serde_json::from_value(flat).expect("deserialize flat form");
        assert!(schema.subcommands.is_empty());
        assert!(schema.aliases.is_empty());
    }

    // `with_owned_output` applied last marks root, child, and grandchild, and
    // each marked node gains a `json` param.
    #[test]
    fn with_owned_output_marks_tree_and_advertises_json() {
        let schema = ToolSchema::new("kj", "kaijutsu")
            .subcommand(
                ToolSchema::new("context", "ctx")
                    .subcommand(ToolSchema::new("list", "list contexts")),
            )
            .with_owned_output();

        assert!(schema.owns_output, "root marked");
        assert!(schema.params.iter().any(|p| p.name == "json"), "root advertises json");
        let context = &schema.subcommands[0];
        assert!(context.owns_output, "child marked");
        let list = &context.subcommands[0];
        assert!(list.owns_output, "grandchild marked");
        assert!(list.params.iter().any(|p| p.name == "json"), "leaf advertises json");
    }

    // An existing param named `json` is not duplicated.
    #[test]
    fn with_owned_output_does_not_double_add_json() {
        let schema = ToolSchema::new("kj", "kaijutsu")
            .param(ParamSchema::new("json", "bool"))
            .with_owned_output();
        let json_count = schema.params.iter().filter(|p| p.name == "json").count();
        assert_eq!(json_count, 1, "json should appear exactly once");
    }

    // `owns_output` is omitted when false and survives a round trip when true.
    #[test]
    fn owns_output_serde() {
        let flat = ToolSchema::new("ls", "list");
        let json = serde_json::to_value(&flat).expect("serialize");
        let obj = json.as_object().expect("object");
        assert!(!obj.contains_key("owns_output"), "false omitted: {json}");

        let owned = ToolSchema::new("kj", "kaijutsu").with_owned_output();
        let wire = serde_json::to_string(&owned).expect("serialize");
        let back: ToolSchema = serde_json::from_str(&wire).expect("deserialize");
        assert!(back.owns_output);
    }

    // A two-level subcommand tree with aliases round-trips through JSON, and
    // `matches_command` still recognizes both names and aliases afterwards.
    #[test]
    fn subcommand_tree_round_trips() {
        let schema = ToolSchema::new("kj", "kaijutsu")
            .subcommand(
                ToolSchema::new("context", "context ops")
                    .with_command_aliases(["ctx"])
                    .subcommand(ToolSchema::new("list", "list contexts").with_command_aliases(["ls"])),
            );
        let json = serde_json::to_string(&schema).expect("serialize");
        let back: ToolSchema = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.subcommands.len(), 1);
        let context = &back.subcommands[0];
        assert!(context.matches_command("context"));
        assert!(context.matches_command("ctx"));
        assert_eq!(context.subcommands.len(), 1);
        assert!(context.subcommands[0].matches_command("ls"));
    }
}
599
// Tests for `ToolArgs::to_argv` ordering/formatting and `flagify_bool_named`.
#[cfg(test)]
mod to_argv_tests {
    use super::*;

    #[test]
    fn empty_args_produce_empty_argv() {
        assert!(ToolArgs::new().to_argv().is_empty());
    }

    // Positionals come after a `--` separator, in order.
    #[test]
    fn positionals_emitted_after_double_dash() {
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("hello".into()));
        args.positional.push(Value::String("world".into()));
        assert_eq!(args.to_argv(), vec!["--", "hello", "world"]);
    }

    // One-char flags use `-x`, longer ones `--name`; flags are emitted sorted,
    // so `n` precedes `verbose` regardless of HashSet order.
    #[test]
    fn single_char_flags_emit_short_form() {
        let mut args = ToolArgs::new();
        args.flags.insert("n".into());
        args.flags.insert("verbose".into());
        assert_eq!(args.to_argv(), vec!["-n", "--verbose"]);
    }

    // Named values render as `--key=value`, in BTreeMap key order.
    #[test]
    fn named_values_use_equals_form() {
        let mut args = ToolArgs::new();
        args.named.insert("count".into(), Value::Int(5));
        args.named.insert("name".into(), Value::String("foo".into()));
        assert_eq!(args.to_argv(), vec!["--count=5", "--name=foo"]);
    }

    #[test]
    fn single_char_named_emits_short_equals() {
        let mut args = ToolArgs::new();
        args.named.insert("n".into(), Value::Int(5));
        assert_eq!(args.to_argv(), vec!["-n=5"]);
    }

    // A dash-leading positional is protected by the preceding `--`.
    #[test]
    fn positional_with_leading_dash_survives_double_dash() {
        let mut args = ToolArgs::new();
        args.positional.push(Value::String("-n".into()));
        assert_eq!(args.to_argv(), vec!["--", "-n"]);
    }

    // Full ordering: flags, then named, then `--` and positionals.
    #[test]
    fn mixed_flags_named_positionals() {
        let mut args = ToolArgs::new();
        args.flags.insert("verbose".into());
        args.named.insert("limit".into(), Value::Int(10));
        args.positional.push(Value::String("file.txt".into()));
        assert_eq!(
            args.to_argv(),
            vec!["--verbose", "--limit=10", "--", "file.txt"]
        );
    }

    // `Bool(true)` moves from `named` into `flags`; non-bool values stay put.
    #[test]
    fn flagify_bool_named_promotes_true_to_flag() {
        let mut args = ToolArgs::new();
        args.named.insert("recursive".into(), Value::Bool(true));
        args.named.insert("limit".into(), Value::Int(5));

        args.flagify_bool_named();

        assert!(args.flags.contains("recursive"));
        assert!(!args.named.contains_key("recursive"));
        assert_eq!(args.named.get("limit"), Some(&Value::Int(5)));
    }

    // `Bool(false)` is removed from `named` without becoming a flag.
    #[test]
    fn flagify_bool_named_drops_false() {
        let mut args = ToolArgs::new();
        args.named.insert("recursive".into(), Value::Bool(false));

        args.flagify_bool_named();

        assert!(!args.flags.contains("recursive"));
        assert!(!args.named.contains_key("recursive"));
    }

    #[test]
    fn flagify_bool_named_is_idempotent() {
        let mut args = ToolArgs::new();
        args.named.insert("recursive".into(), Value::Bool(true));
        args.flagify_bool_named();
        args.flagify_bool_named();
        assert!(args.flags.contains("recursive"));
    }

    // After flagifying, a true bool renders as a bare `-R` with no `=value`.
    #[test]
    fn flagify_bool_named_round_trips_through_to_argv() {
        let mut args = ToolArgs::new();
        args.named.insert("R".into(), Value::Bool(true));
        args.flagify_bool_named();
        let argv = args.to_argv();
        assert!(argv.contains(&"-R".to_string()), "expected -R, got {:?}", argv);
        assert!(!argv.iter().any(|s| s.contains('=')), "no =value should appear, got {:?}", argv);
    }
}