1use std::collections::{BTreeMap, HashSet};
4
5use crate::value::Value;
6
/// Supplies the value used for `ParamSchema::consumes` when the field is
/// missing from deserialized input (wired up through `#[serde(default = ...)]`).
fn default_consumes() -> usize {
    const SINGLE_VALUE: usize = 1;
    SINGLE_VALUE
}
10
11#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct ParamSchema {
14 pub name: String,
16 pub param_type: String,
18 pub required: bool,
20 pub default: Option<Value>,
22 pub description: String,
24 pub aliases: Vec<String>,
26 #[serde(default = "default_consumes")]
34 pub consumes: usize,
35}
36
37impl ParamSchema {
38 pub fn required(name: impl Into<String>, param_type: impl Into<String>, description: impl Into<String>) -> Self {
40 Self {
41 name: name.into(),
42 param_type: param_type.into(),
43 required: true,
44 default: None,
45 description: description.into(),
46 aliases: Vec::new(),
47 consumes: 1,
48 }
49 }
50
51 pub fn optional(name: impl Into<String>, param_type: impl Into<String>, default: Value, description: impl Into<String>) -> Self {
53 Self {
54 name: name.into(),
55 param_type: param_type.into(),
56 required: false,
57 default: Some(default),
58 description: description.into(),
59 aliases: Vec::new(),
60 consumes: 1,
61 }
62 }
63
64 pub fn with_aliases(mut self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
68 self.aliases = aliases.into_iter().map(Into::into).collect();
69 self
70 }
71
72 pub fn consumes(mut self, n: usize) -> Self {
76 assert!(n >= 1, "ParamSchema::consumes requires n >= 1 (use a bool param for flags that take no value)");
77 self.consumes = n;
78 self
79 }
80
81 pub fn matches_flag(&self, flag: &str) -> bool {
83 if self.name == flag {
84 return true;
85 }
86 self.aliases.iter().any(|a| a == flag)
87 }
88}
89
90#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
92pub struct Example {
93 pub description: String,
95 pub code: String,
97}
98
99impl Example {
100 pub fn new(description: impl Into<String>, code: impl Into<String>) -> Self {
102 Self {
103 description: description.into(),
104 code: code.into(),
105 }
106 }
107}
108
109#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
111pub struct ToolSchema {
112 pub name: String,
114 pub description: String,
116 pub params: Vec<ParamSchema>,
118 pub examples: Vec<Example>,
120 pub map_positionals: bool,
124}
125
126impl ToolSchema {
127 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
129 Self {
130 name: name.into(),
131 description: description.into(),
132 params: Vec::new(),
133 examples: Vec::new(),
134 map_positionals: false,
135 }
136 }
137
138 pub fn with_positional_mapping(mut self) -> Self {
140 self.map_positionals = true;
141 self
142 }
143
144 pub fn param(mut self, param: ParamSchema) -> Self {
146 self.params.push(param);
147 self
148 }
149
150 pub fn example(mut self, description: impl Into<String>, code: impl Into<String>) -> Self {
152 self.examples.push(Example::new(description, code));
153 self
154 }
155}
156
157#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
159pub struct ToolArgs {
160 pub positional: Vec<Value>,
162 pub named: BTreeMap<String, Value>,
164 pub flags: HashSet<String>,
166}
167
168impl ToolArgs {
169 pub fn new() -> Self {
171 Self::default()
172 }
173
174 pub fn get_positional(&self, index: usize) -> Option<&Value> {
176 self.positional.get(index)
177 }
178
179 pub fn get_named(&self, key: &str) -> Option<&Value> {
181 self.named.get(key)
182 }
183
184 pub fn get(&self, name: &str, positional_index: usize) -> Option<&Value> {
188 self.named.get(name).or_else(|| self.positional.get(positional_index))
189 }
190
191 pub fn get_string(&self, name: &str, positional_index: usize) -> Option<String> {
193 self.get(name, positional_index).and_then(|v| match v {
194 Value::String(s) => Some(s.clone()),
195 Value::Int(i) => Some(i.to_string()),
196 Value::Float(f) => Some(f.to_string()),
197 Value::Bool(b) => Some(b.to_string()),
198 _ => None,
199 })
200 }
201
202 pub fn get_bool(&self, name: &str, positional_index: usize) -> Option<bool> {
204 self.get(name, positional_index).and_then(|v| match v {
205 Value::Bool(b) => Some(*b),
206 Value::String(s) => match s.as_str() {
207 "true" | "yes" | "1" => Some(true),
208 "false" | "no" | "0" => Some(false),
209 _ => None,
210 },
211 Value::Int(i) => Some(*i != 0),
212 _ => None,
213 })
214 }
215
216 pub fn has_flag(&self, name: &str) -> bool {
218 if self.flags.contains(name) {
220 return true;
221 }
222 self.named.get(name).is_some_and(|v| match v {
224 Value::Bool(b) => *b,
225 Value::String(s) => !s.is_empty() && s != "false" && s != "0",
226 _ => true,
227 })
228 }
229}