1use std::fmt::Write;
5
/// How a tool tells the model it should be invoked.
///
/// The variant selects the `Invocation:` line rendered for the tool in the
/// prompt.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InvocationHint {
    /// Invoke through a fenced code block; the payload is the tag that follows
    /// the opening triple backtick.
    FencedBlock(&'static str),
    /// Invoke through a `tool_call` object carrying `tool_id` and `params`.
    ToolCall,
}
13
14#[derive(Debug, Clone)]
15pub struct ToolDef {
16 pub id: &'static str,
17 pub description: &'static str,
18 pub schema: schemars::Schema,
19 pub invocation: InvocationHint,
20}
21
22#[derive(Debug, Default)]
23pub struct ToolRegistry {
24 tools: Vec<ToolDef>,
25}
26
27impl ToolRegistry {
28 #[must_use]
29 pub fn from_definitions(tools: Vec<ToolDef>) -> Self {
30 Self { tools }
31 }
32
33 #[must_use]
34 pub fn tools(&self) -> &[ToolDef] {
35 &self.tools
36 }
37
38 #[must_use]
39 pub fn find(&self, id: &str) -> Option<&ToolDef> {
40 self.tools.iter().find(|t| t.id == id)
41 }
42
43 #[must_use]
45 pub fn format_for_prompt_filtered(
46 &self,
47 policy: &crate::permissions::PermissionPolicy,
48 ) -> String {
49 let mut out = String::from("<tools>\n");
50 for tool in &self.tools {
51 if policy.is_fully_denied(tool.id) {
52 continue;
53 }
54 format_tool(&mut out, tool);
55 }
56 out.push_str("</tools>");
57 out
58 }
59}
60
61fn format_tool(out: &mut String, tool: &ToolDef) {
62 let _ = writeln!(out, "## {}", tool.id);
63 let _ = writeln!(out, "{}", tool.description);
64 match tool.invocation {
65 InvocationHint::FencedBlock(tag) => {
66 let _ = writeln!(out, "Invocation: use ```{tag} fenced block");
67 }
68 InvocationHint::ToolCall => {
69 let _ = writeln!(
70 out,
71 "Invocation: use tool_call with {{\"tool_id\": \"{}\", \"params\": {{...}}}}",
72 tool.id
73 );
74 }
75 }
76 format_schema_params(out, &tool.schema);
77 out.push('\n');
78}
79
80fn extract_non_null_type(obj: &serde_json::Map<String, serde_json::Value>) -> Option<&str> {
83 if let Some(arr) = obj.get("type").and_then(|v| v.as_array()) {
84 return arr.iter().filter_map(|v| v.as_str()).find(|t| *t != "null");
85 }
86 obj.get("anyOf")?
87 .as_array()?
88 .iter()
89 .filter_map(|v| v.as_object())
90 .filter_map(|o| o.get("type")?.as_str())
91 .find(|t| *t != "null")
92}
93
94fn format_schema_params(out: &mut String, schema: &schemars::Schema) {
95 let Some(obj) = schema.as_object() else {
96 return;
97 };
98 let Some(serde_json::Value::Object(props)) = obj.get("properties") else {
99 return;
100 };
101 if props.is_empty() {
102 return;
103 }
104
105 let required: Vec<&str> = obj
106 .get("required")
107 .and_then(|v| v.as_array())
108 .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
109 .unwrap_or_default();
110
111 let _ = writeln!(out, "Parameters:");
112 for (name, prop) in props {
113 let prop_obj = prop.as_object();
114 let ty = prop_obj
115 .and_then(|o| {
116 o.get("type")
117 .and_then(|v| v.as_str())
118 .or_else(|| extract_non_null_type(o))
119 })
120 .unwrap_or("string");
121 let desc = prop_obj
122 .and_then(|o| o.get("description"))
123 .and_then(|v| v.as_str())
124 .unwrap_or("");
125 let req = if required.contains(&name.as_str()) {
126 "required"
127 } else {
128 "optional"
129 };
130 let _ = writeln!(out, " - {name}: {desc} ({ty}, {req})");
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::file::ReadParams;
    use crate::permissions::{PermissionAction, PermissionPolicy, PermissionRule};
    use crate::shell::BashParams;

    /// Two definitions covering both invocation styles: a fenced-block `bash`
    /// tool and a `tool_call`-style `read` tool.
    fn sample_tools() -> Vec<ToolDef> {
        let bash = ToolDef {
            id: "bash",
            description: "Execute a shell command",
            schema: schemars::schema_for!(BashParams),
            invocation: InvocationHint::FencedBlock("bash"),
        };
        let read = ToolDef {
            id: "read",
            description: "Read file contents",
            schema: schemars::schema_for!(ReadParams),
            invocation: InvocationHint::ToolCall,
        };
        vec![bash, read]
    }

    /// Registry populated with [`sample_tools`].
    fn sample_registry() -> ToolRegistry {
        ToolRegistry::from_definitions(sample_tools())
    }

    /// Prompt rendered from the sample registry under the default policy.
    fn unrestricted_prompt() -> String {
        sample_registry().format_for_prompt_filtered(&PermissionPolicy::default())
    }

    /// Shorthand for building a single permission rule.
    fn rule(pattern: &str, action: PermissionAction) -> PermissionRule {
        PermissionRule {
            pattern: pattern.to_owned(),
            action,
        }
    }

    /// Prompt rendered from the sample registry when `bash` is governed by
    /// `bash_rules` and no other tool has any rule.
    fn prompt_with_bash_rules(bash_rules: Vec<PermissionRule>) -> String {
        let policy = PermissionPolicy::new(HashMap::from([("bash".to_owned(), bash_rules)]));
        sample_registry().format_for_prompt_filtered(&policy)
    }

    #[test]
    fn from_definitions_stores_tools() {
        assert_eq!(sample_registry().tools().len(), 2);
    }

    #[test]
    fn default_registry_is_empty() {
        assert!(ToolRegistry::default().tools().is_empty());
    }

    #[test]
    fn find_existing_tool() {
        let reg = sample_registry();
        assert!(reg.find("bash").is_some());
        assert!(reg.find("read").is_some());
    }

    #[test]
    fn find_nonexistent_returns_none() {
        assert!(sample_registry().find("nonexistent").is_none());
    }

    #[test]
    fn format_for_prompt_contains_tools() {
        let prompt = unrestricted_prompt();
        assert!(prompt.contains("<tools>"));
        assert!(prompt.contains("</tools>"));
        assert!(prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }

    #[test]
    fn format_for_prompt_shows_invocation_fenced() {
        let prompt = unrestricted_prompt();
        assert!(prompt.contains("Invocation: use ```bash fenced block"));
    }

    #[test]
    fn format_for_prompt_shows_invocation_tool_call() {
        let prompt = unrestricted_prompt();
        assert!(prompt.contains("Invocation: use tool_call"));
        assert!(prompt.contains("\"tool_id\": \"read\""));
    }

    #[test]
    fn format_for_prompt_shows_param_info() {
        let prompt = unrestricted_prompt();
        assert!(prompt.contains("command:"));
        assert!(prompt.contains("required"));
        assert!(prompt.contains("string"));
    }

    #[test]
    fn format_for_prompt_shows_optional_params() {
        let prompt = unrestricted_prompt();
        assert!(prompt.contains("offset:"));
        assert!(prompt.contains("optional"));
        assert!(
            prompt.contains("(integer, optional)"),
            "Option<u32> should render as integer, not string: {prompt}"
        );
    }

    #[test]
    fn format_filtered_excludes_fully_denied() {
        let prompt = prompt_with_bash_rules(vec![rule("*", PermissionAction::Deny)]);
        assert!(!prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }

    #[test]
    fn format_filtered_includes_mixed_rules() {
        let prompt = prompt_with_bash_rules(vec![
            rule("echo *", PermissionAction::Allow),
            rule("*", PermissionAction::Deny),
        ]);
        assert!(prompt.contains("## bash"));
    }

    #[test]
    fn format_filtered_no_rules_includes_all() {
        let prompt = unrestricted_prompt();
        assert!(prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }
}
284}