Skip to main content

zeph_tools/
registry.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt::Write;
5
/// How the model is told to invoke a tool in the rendered prompt.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InvocationHint {
    /// The model emits a fenced block tagged with the given string,
    /// i.e. ```{tag}\n...\n``` inside its response.
    FencedBlock(&'static str),
    /// The model emits a structured `ToolCall` JSON payload.
    ToolCall,
}
13
14#[derive(Debug, Clone)]
15pub struct ToolDef {
16    pub id: &'static str,
17    pub description: &'static str,
18    pub schema: schemars::Schema,
19    pub invocation: InvocationHint,
20}
21
22#[derive(Debug, Default)]
23pub struct ToolRegistry {
24    tools: Vec<ToolDef>,
25}
26
27impl ToolRegistry {
28    #[must_use]
29    pub fn from_definitions(tools: Vec<ToolDef>) -> Self {
30        Self { tools }
31    }
32
33    #[must_use]
34    pub fn tools(&self) -> &[ToolDef] {
35        &self.tools
36    }
37
38    #[must_use]
39    pub fn find(&self, id: &str) -> Option<&ToolDef> {
40        self.tools.iter().find(|t| t.id == id)
41    }
42
43    /// Format tools for prompt, excluding tools fully denied by policy.
44    #[must_use]
45    pub fn format_for_prompt_filtered(
46        &self,
47        policy: &crate::permissions::PermissionPolicy,
48    ) -> String {
49        let mut out = String::from("<tools>\n");
50        for tool in &self.tools {
51            if policy.is_fully_denied(tool.id) {
52                continue;
53            }
54            format_tool(&mut out, tool);
55        }
56        out.push_str("</tools>");
57        out
58    }
59}
60
61fn format_tool(out: &mut String, tool: &ToolDef) {
62    let _ = writeln!(out, "## {}", tool.id);
63    let _ = writeln!(out, "{}", tool.description);
64    match tool.invocation {
65        InvocationHint::FencedBlock(tag) => {
66            let _ = writeln!(out, "Invocation: use ```{tag} fenced block");
67        }
68        InvocationHint::ToolCall => {
69            let _ = writeln!(
70                out,
71                "Invocation: use tool_call with {{\"tool_id\": \"{}\", \"params\": {{...}}}}",
72                tool.id
73            );
74        }
75    }
76    format_schema_params(out, &tool.schema);
77    out.push('\n');
78}
79
80/// Extract the primary type when schemars renders `Option<T>` as `"type": ["T", "null"]`
81/// or `"anyOf": [{"type": "T"}, {"type": "null"}]`.
82fn extract_non_null_type(obj: &serde_json::Map<String, serde_json::Value>) -> Option<&str> {
83    if let Some(arr) = obj.get("type").and_then(|v| v.as_array()) {
84        return arr.iter().filter_map(|v| v.as_str()).find(|t| *t != "null");
85    }
86    obj.get("anyOf")?
87        .as_array()?
88        .iter()
89        .filter_map(|v| v.as_object())
90        .filter_map(|o| o.get("type")?.as_str())
91        .find(|t| *t != "null")
92}
93
94fn format_schema_params(out: &mut String, schema: &schemars::Schema) {
95    let Some(obj) = schema.as_object() else {
96        return;
97    };
98    let Some(serde_json::Value::Object(props)) = obj.get("properties") else {
99        return;
100    };
101    if props.is_empty() {
102        return;
103    }
104
105    let required: Vec<&str> = obj
106        .get("required")
107        .and_then(|v| v.as_array())
108        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
109        .unwrap_or_default();
110
111    let _ = writeln!(out, "Parameters:");
112    for (name, prop) in props {
113        let prop_obj = prop.as_object();
114        let ty = prop_obj
115            .and_then(|o| {
116                o.get("type")
117                    .and_then(|v| v.as_str())
118                    .or_else(|| extract_non_null_type(o))
119            })
120            .unwrap_or("string");
121        let desc = prop_obj
122            .and_then(|o| o.get("description"))
123            .and_then(|v| v.as_str())
124            .unwrap_or("");
125        let req = if required.contains(&name.as_str()) {
126            "required"
127        } else {
128            "optional"
129        };
130        let _ = writeln!(out, "  - {name}: {desc} ({ty}, {req})");
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::file::ReadParams;
    use crate::permissions::{PermissionAction, PermissionPolicy, PermissionRule};
    use crate::shell::BashParams;
    use std::collections::HashMap;

    /// Two representative tools: `bash` (fenced-block invocation) and `read`
    /// (structured tool-call invocation).
    fn sample_tools() -> Vec<ToolDef> {
        vec![
            ToolDef {
                id: "bash",
                description: "Execute a shell command",
                schema: schemars::schema_for!(BashParams),
                invocation: InvocationHint::FencedBlock("bash"),
            },
            ToolDef {
                id: "read",
                description: "Read file contents",
                schema: schemars::schema_for!(ReadParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

    /// Renders the sample registry under the default permission policy.
    fn default_prompt() -> String {
        ToolRegistry::from_definitions(sample_tools())
            .format_for_prompt_filtered(&PermissionPolicy::default())
    }

    /// Builds a policy whose only rules apply to the `bash` tool.
    fn bash_policy(bash_rules: Vec<PermissionRule>) -> PermissionPolicy {
        let mut rules = HashMap::new();
        rules.insert("bash".to_owned(), bash_rules);
        PermissionPolicy::new(rules)
    }

    /// Returns the JSON object inside `value`; test fixtures are always objects.
    fn as_obj(value: &serde_json::Value) -> &serde_json::Map<String, serde_json::Value> {
        value.as_object().expect("fixture must be a JSON object")
    }

    #[test]
    fn from_definitions_stores_tools() {
        let reg = ToolRegistry::from_definitions(sample_tools());
        assert_eq!(reg.tools().len(), 2);
    }

    #[test]
    fn default_registry_is_empty() {
        let reg = ToolRegistry::default();
        assert!(reg.tools().is_empty());
    }

    #[test]
    fn find_existing_tool() {
        let reg = ToolRegistry::from_definitions(sample_tools());
        assert!(reg.find("bash").is_some());
        assert!(reg.find("read").is_some());
    }

    #[test]
    fn find_nonexistent_returns_none() {
        let reg = ToolRegistry::from_definitions(sample_tools());
        assert!(reg.find("nonexistent").is_none());
    }

    #[test]
    fn format_for_prompt_contains_tools() {
        let prompt = default_prompt();
        assert!(prompt.contains("<tools>"));
        assert!(prompt.contains("</tools>"));
        assert!(prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }

    #[test]
    fn format_for_prompt_shows_invocation_fenced() {
        let prompt = default_prompt();
        assert!(prompt.contains("Invocation: use ```bash fenced block"));
    }

    #[test]
    fn format_for_prompt_shows_invocation_tool_call() {
        let prompt = default_prompt();
        assert!(prompt.contains("Invocation: use tool_call"));
        assert!(prompt.contains("\"tool_id\": \"read\""));
    }

    #[test]
    fn format_for_prompt_shows_param_info() {
        let prompt = default_prompt();
        assert!(prompt.contains("command:"));
        assert!(prompt.contains("required"));
        assert!(prompt.contains("string"));
    }

    #[test]
    fn format_for_prompt_shows_optional_params() {
        let prompt = default_prompt();
        assert!(prompt.contains("offset:"));
        assert!(prompt.contains("optional"));
        assert!(
            prompt.contains("(integer, optional)"),
            "Option<u32> should render as integer, not string: {prompt}"
        );
    }

    #[test]
    fn format_filtered_excludes_fully_denied() {
        let policy = bash_policy(vec![PermissionRule {
            pattern: "*".to_owned(),
            action: PermissionAction::Deny,
        }]);
        let reg = ToolRegistry::from_definitions(sample_tools());
        let prompt = reg.format_for_prompt_filtered(&policy);
        assert!(!prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }

    #[test]
    fn format_filtered_includes_mixed_rules() {
        let policy = bash_policy(vec![
            PermissionRule {
                pattern: "echo *".to_owned(),
                action: PermissionAction::Allow,
            },
            PermissionRule {
                pattern: "*".to_owned(),
                action: PermissionAction::Deny,
            },
        ]);
        let reg = ToolRegistry::from_definitions(sample_tools());
        let prompt = reg.format_for_prompt_filtered(&policy);
        assert!(prompt.contains("## bash"));
    }

    #[test]
    fn format_filtered_no_rules_includes_all() {
        let prompt = default_prompt();
        assert!(prompt.contains("## bash"));
        assert!(prompt.contains("## read"));
    }

    #[test]
    fn extract_non_null_type_reads_type_array() {
        let schema = serde_json::json!({ "type": ["integer", "null"] });
        assert_eq!(extract_non_null_type(as_obj(&schema)), Some("integer"));
    }

    #[test]
    fn extract_non_null_type_reads_any_of() {
        let schema = serde_json::json!({
            "anyOf": [{ "type": "null" }, { "type": "boolean" }]
        });
        assert_eq!(extract_non_null_type(as_obj(&schema)), Some("boolean"));
    }

    #[test]
    fn extract_non_null_type_returns_none_without_concrete_type() {
        let only_null = serde_json::json!({ "type": ["null"] });
        assert_eq!(extract_non_null_type(as_obj(&only_null)), None);

        let ref_only = serde_json::json!({
            "anyOf": [{ "$ref": "#/$defs/Mode" }, { "type": "null" }]
        });
        assert_eq!(extract_non_null_type(as_obj(&ref_only)), None);

        let untyped = serde_json::json!({ "description": "free-form" });
        assert_eq!(extract_non_null_type(as_obj(&untyped)), None);
    }
}