1use std::borrow::Cow;
5use std::fmt::Write;
6
/// Tells the model, in the rendered prompt, how a given tool is to be invoked.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InvocationHint {
    /// Invoke by writing a fenced code block whose info string is the wrapped tag
    /// (for example `bash`).
    FencedBlock(&'static str),
    /// Invoke through a structured `tool_call` carrying the tool's `tool_id` and its `params`.
    ToolCall,
}
14
15#[derive(Debug, Clone)]
16pub struct ToolDef {
17 pub id: Cow<'static, str>,
18 pub description: Cow<'static, str>,
19 pub schema: schemars::Schema,
20 pub invocation: InvocationHint,
21 pub output_schema: Option<serde_json::Value>,
25}
26
27#[derive(Debug, Default)]
28pub struct ToolRegistry {
29 tools: Vec<ToolDef>,
30}
31
32impl ToolRegistry {
33 #[must_use]
34 pub fn from_definitions(tools: Vec<ToolDef>) -> Self {
35 Self { tools }
36 }
37
38 #[must_use]
39 pub fn tools(&self) -> &[ToolDef] {
40 &self.tools
41 }
42
43 #[must_use]
44 pub fn find(&self, id: &str) -> Option<&ToolDef> {
45 self.tools.iter().find(|t| t.id.as_ref() == id)
46 }
47
48 #[must_use]
50 pub fn format_for_prompt_filtered(
51 &self,
52 policy: &crate::permissions::PermissionPolicy,
53 ) -> String {
54 let mut out = String::from("<tools>\n");
55 for tool in &self.tools {
56 if policy.is_fully_denied(&tool.id) {
57 continue;
58 }
59 format_tool(&mut out, tool);
60 }
61 out.push_str("</tools>");
62 out
63 }
64}
65
66fn format_tool(out: &mut String, tool: &ToolDef) {
67 let _ = writeln!(out, "## {}", tool.id);
68 let _ = writeln!(out, "{}", tool.description);
69 match tool.invocation {
70 InvocationHint::FencedBlock(tag) => {
71 let _ = writeln!(out, "Invocation: use ```{tag} fenced block");
72 }
73 InvocationHint::ToolCall => {
74 let _ = writeln!(
75 out,
76 "Invocation: use tool_call with {{\"tool_id\": \"{}\", \"params\": {{...}}}}",
77 tool.id
78 );
79 }
80 }
81 format_schema_params(out, &tool.schema);
82 out.push('\n');
83}
84
85fn extract_non_null_type(obj: &serde_json::Map<String, serde_json::Value>) -> Option<&str> {
88 if let Some(arr) = obj.get("type").and_then(|v| v.as_array()) {
89 return arr.iter().filter_map(|v| v.as_str()).find(|t| *t != "null");
90 }
91 obj.get("anyOf")?
92 .as_array()?
93 .iter()
94 .filter_map(|v| v.as_object())
95 .filter_map(|o| o.get("type")?.as_str())
96 .find(|t| *t != "null")
97}
98
99fn format_schema_params(out: &mut String, schema: &schemars::Schema) {
100 let Some(obj) = schema.as_object() else {
101 return;
102 };
103 let Some(serde_json::Value::Object(props)) = obj.get("properties") else {
104 return;
105 };
106 if props.is_empty() {
107 return;
108 }
109
110 let required: Vec<&str> = obj
111 .get("required")
112 .and_then(|v| v.as_array())
113 .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
114 .unwrap_or_default();
115
116 let _ = writeln!(out, "Parameters:");
117 for (name, prop) in props {
118 let prop_obj = prop.as_object();
119 let ty = prop_obj
120 .and_then(|o| {
121 o.get("type")
122 .and_then(|v| v.as_str())
123 .or_else(|| extract_non_null_type(o))
124 })
125 .unwrap_or("string");
126 let desc = prop_obj
127 .and_then(|o| o.get("description"))
128 .and_then(|v| v.as_str())
129 .unwrap_or("");
130 let req = if required.contains(&name.as_str()) {
131 "required"
132 } else {
133 "optional"
134 };
135 let _ = writeln!(out, " - {name}: {desc} ({ty}, {req})");
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::file::ReadParams;
    use crate::permissions::{PermissionAction, PermissionPolicy, PermissionRule};
    use crate::shell::BashParams;
    use std::collections::HashMap;

    /// Two representative tools: `bash` (fenced-block invocation) followed by `read`
    /// (`tool_call` invocation).
    fn sample_tools() -> Vec<ToolDef> {
        let bash = ToolDef {
            id: "bash".into(),
            description: "Execute a shell command".into(),
            schema: schemars::schema_for!(BashParams),
            invocation: InvocationHint::FencedBlock("bash"),
            output_schema: None,
        };
        let read = ToolDef {
            id: "read".into(),
            description: "Read file contents".into(),
            schema: schemars::schema_for!(ReadParams),
            invocation: InvocationHint::ToolCall,
            output_schema: None,
        };
        vec![bash, read]
    }

    /// Registry built from [`sample_tools`].
    fn sample_registry() -> ToolRegistry {
        ToolRegistry::from_definitions(sample_tools())
    }

    /// Prompt for the sample registry under `policy`.
    fn render(policy: &PermissionPolicy) -> String {
        sample_registry().format_for_prompt_filtered(policy)
    }

    /// Prompt for the sample registry under the default (rule-less) policy.
    fn render_unrestricted() -> String {
        render(&PermissionPolicy::default())
    }

    fn rule(pattern: &str, action: PermissionAction) -> PermissionRule {
        PermissionRule {
            pattern: pattern.to_owned(),
            action,
        }
    }

    /// Policy whose only entry applies `rules` to the `bash` tool.
    fn bash_policy(rules: Vec<PermissionRule>) -> PermissionPolicy {
        PermissionPolicy::new(HashMap::from([("bash".to_owned(), rules)]))
    }

    #[test]
    fn from_definitions_stores_tools() {
        assert_eq!(sample_registry().tools().len(), 2);
    }

    #[test]
    fn default_registry_is_empty() {
        assert!(ToolRegistry::default().tools().is_empty());
    }

    #[test]
    fn find_existing_tool() {
        let reg = sample_registry();
        for id in ["bash", "read"] {
            assert!(reg.find(id).is_some());
        }
    }

    #[test]
    fn find_nonexistent_returns_none() {
        assert!(sample_registry().find("nonexistent").is_none());
    }

    #[test]
    fn format_for_prompt_contains_tools() {
        let prompt = render_unrestricted();
        for needle in ["<tools>", "</tools>", "## bash", "## read"] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn format_for_prompt_shows_invocation_fenced() {
        assert!(render_unrestricted().contains("Invocation: use ```bash fenced block"));
    }

    #[test]
    fn format_for_prompt_shows_invocation_tool_call() {
        let prompt = render_unrestricted();
        assert!(prompt.contains("Invocation: use tool_call"));
        assert!(prompt.contains("\"tool_id\": \"read\""));
    }

    #[test]
    fn format_for_prompt_shows_param_info() {
        let prompt = render_unrestricted();
        for needle in ["command:", "required", "string"] {
            assert!(prompt.contains(needle));
        }
    }

    #[test]
    fn format_for_prompt_shows_optional_params() {
        let prompt = render_unrestricted();
        assert!(prompt.contains("offset:"));
        assert!(prompt.contains("optional"));
        assert!(
            prompt.contains("(integer, optional)"),
            "Option<u32> should render as integer, not string: {prompt}"
        );
    }

    #[test]
    fn format_filtered_excludes_fully_denied() {
        let prompt = render(&bash_policy(vec![rule("*", PermissionAction::Deny)]));
        assert!(!prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }

    #[test]
    fn format_filtered_includes_mixed_rules() {
        let policy = bash_policy(vec![
            rule("echo *", PermissionAction::Allow),
            rule("*", PermissionAction::Deny),
        ]);
        assert!(render(&policy).contains("## bash"));
    }

    #[test]
    fn format_filtered_no_rules_includes_all() {
        let prompt = render(&PermissionPolicy::default());
        assert!(prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }
}