Skip to main content

lash_sansio/
tool_catalog.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, OnceLock};
3
4use crate::llm::types::LlmToolSpec;
5use crate::{
6    PromptContribution, PromptFingerprint, ToolContract, ToolDefinition, ToolManifest,
7    prompt_tool_names_fingerprint,
8};
9
/// Looks up a tool's [`ToolContract`] by tool name, returning `None` when it
/// has no contract for that name. Held behind an `Arc` and required to be
/// `Send + Sync`, so catalogs holding one can be cloned cheaply and shared
/// across threads.
pub type ToolContractResolver =
    Arc<dyn Fn(&str) -> Option<Arc<ToolContract>> + Send + Sync + 'static>;
12
/// Everything [`build_tool_catalog`] needs to assemble a [`ToolCatalog`].
#[derive(Clone)]
pub struct ToolCatalogBuildInput {
    /// Candidate members, in the order they will appear in the catalog.
    pub tools: Vec<ToolManifest>,
    /// Optional name -> contract lookup. Members it cannot resolve stay
    /// callable but get no entry in the catalog's model tool specs.
    pub resolve_contract: Option<ToolContractResolver>,
    /// Contributions applied one after another once the members are set.
    pub contributions: Vec<ToolCatalogContribution>,
}
19
/// A trusted plugin's contribution to catalog assembly. Membership is the
/// execution gate, so the only override a contribution can express is *removal*
/// of a member (authority hiding, plan-mode gating). Adding members happens by
/// a [`crate::ToolProvider`] including them in its manifest list.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct ToolCatalogContribution {
    /// Names of tools to remove from the catalog (non-membership). Matching is
    /// by exact tool name; names that are not members are ignored.
    pub remove: Vec<String>,
}
29
30impl ToolCatalogContribution {
31    pub fn is_empty(&self) -> bool {
32        self.remove.is_empty()
33    }
34
35    pub fn remove_tools(tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
36        Self {
37            remove: tools.into_iter().map(Into::into).collect(),
38        }
39    }
40}
41
/// One catalog member. Currently just the tool's manifest, whose `name` is the
/// key used for membership checks and contract lookup.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCatalogEntry {
    /// The member tool's manifest.
    pub manifest: ToolManifest,
}
46
/// The flat set of tools a model may call, plus lazily computed views of it.
///
/// Serialization covers only `tools`. The contract resolver and the cached
/// views are `#[serde(skip)]`, so a deserialized catalog starts with no
/// resolver (no contract resolves, so its model tool specs come out empty)
/// and with empty caches that are rebuilt on demand.
///
/// NOTE: the cached views are filled on first access, and nothing resets them
/// when the public `tools` field is modified directly. Changing `tools` after
/// calling `tool_names`, `tool_names_fingerprint` or `model_tool_specs`
/// therefore leaves those views stale.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ToolCatalog {
    /// Catalog members in order; being listed here is what makes a tool callable.
    pub tools: Vec<ToolCatalogEntry>,
    /// Name -> contract lookup used to build model specs; `None` means no
    /// member has a resolvable contract.
    #[serde(skip)]
    resolve_contract: Option<ToolContractResolver>,
    /// Cache behind `ToolCatalog::model_tool_specs`.
    #[serde(skip)]
    model_tool_specs: OnceLock<Arc<Vec<LlmToolSpec>>>,
    /// Cache behind `ToolCatalog::tool_names`.
    #[serde(skip)]
    tool_names: OnceLock<Arc<Vec<String>>>,
    /// Cache behind `ToolCatalog::tool_names_fingerprint`.
    #[serde(skip)]
    tool_names_fingerprint: OnceLock<PromptFingerprint>,
}
59
60impl Clone for ToolCatalog {
61    fn clone(&self) -> Self {
62        let clone = Self {
63            tools: self.tools.clone(),
64            resolve_contract: self.resolve_contract.clone(),
65            model_tool_specs: OnceLock::new(),
66            tool_names: OnceLock::new(),
67            tool_names_fingerprint: OnceLock::new(),
68        };
69        if let Some(value) = self.model_tool_specs.get() {
70            let _ = clone.model_tool_specs.set(Arc::clone(value));
71        }
72        if let Some(value) = self.tool_names.get() {
73            let _ = clone.tool_names.set(Arc::clone(value));
74        }
75        if let Some(value) = self.tool_names_fingerprint.get() {
76            let _ = clone.tool_names_fingerprint.set(*value);
77        }
78        clone
79    }
80}
81
82impl std::fmt::Debug for ToolCatalog {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.debug_struct("ToolCatalog")
85            .field("tools", &self.tools)
86            .finish_non_exhaustive()
87    }
88}
89
90impl Default for ToolCatalog {
91    fn default() -> Self {
92        Self {
93            tools: Vec::new(),
94            resolve_contract: None,
95            model_tool_specs: OnceLock::new(),
96            tool_names: OnceLock::new(),
97            tool_names_fingerprint: OnceLock::new(),
98        }
99    }
100}
101
102impl ToolCatalog {
103    pub fn from_tool_definitions(tools: Vec<ToolDefinition>) -> Self {
104        let contracts = tools
105            .iter()
106            .map(|tool| (tool.name().to_string(), Arc::new(tool.contract())))
107            .collect();
108        Self::from_tools(
109            tools.into_iter().map(|tool| tool.manifest()).collect(),
110            contracts,
111        )
112    }
113
114    pub fn from_tools(
115        tools: Vec<ToolManifest>,
116        contracts: BTreeMap<String, Arc<ToolContract>>,
117    ) -> Self {
118        let resolver_contracts = Arc::new(contracts);
119        Self::from_tool_manifests(
120            tools,
121            Some(Arc::new(move |name| resolver_contracts.get(name).cloned())),
122        )
123    }
124
125    fn from_tool_manifests(
126        tools: Vec<ToolManifest>,
127        resolve_contract: Option<ToolContractResolver>,
128    ) -> Self {
129        Self {
130            tools: tools
131                .into_iter()
132                .map(|manifest| ToolCatalogEntry { manifest })
133                .collect(),
134            resolve_contract,
135            model_tool_specs: OnceLock::new(),
136            tool_names: OnceLock::new(),
137            tool_names_fingerprint: OnceLock::new(),
138        }
139    }
140
141    /// All catalog members. Membership is callability; there is no filtering.
142    pub fn callable_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
143        self.tools.iter().map(|tool| &tool.manifest)
144    }
145
146    pub fn callable_tools(&self) -> Vec<ToolManifest> {
147        self.callable_tools_iter().cloned().collect()
148    }
149
150    /// Membership test: a tool is in the catalog (callable) or it does not
151    /// exist to the model.
152    pub fn has_callable_tool(&self, tool_name: &str) -> bool {
153        self.tools
154            .iter()
155            .any(|tool| tool.manifest.name == tool_name)
156    }
157
158    pub fn tool_names(&self) -> Arc<Vec<String>> {
159        Arc::clone(self.tool_names.get_or_init(|| {
160            Arc::new(
161                self.tools
162                    .iter()
163                    .map(|tool| tool.manifest.name.clone())
164                    .collect(),
165            )
166        }))
167    }
168
169    pub fn tool_names_fingerprint(&self) -> PromptFingerprint {
170        *self
171            .tool_names_fingerprint
172            .get_or_init(|| prompt_tool_names_fingerprint(&self.tool_names()))
173    }
174
175    pub fn model_tool_specs(&self) -> Arc<Vec<LlmToolSpec>> {
176        Arc::clone(self.model_tool_specs.get_or_init(|| {
177            Arc::new(
178                self.tools
179                    .iter()
180                    .filter_map(|tool| {
181                        self.resolve_contract(&tool.manifest.name)
182                            .map(|contract| contract.model_tool(&tool.manifest))
183                    })
184                    .map(|model_tool| LlmToolSpec {
185                        name: model_tool.name,
186                        description: model_tool.description,
187                        input_schema: model_tool.input_schema,
188                        output_schema: model_tool.output_schema,
189                    })
190                    .collect(),
191            )
192        }))
193    }
194
195    pub fn resolve_contract(&self, tool_name: &str) -> Option<Arc<ToolContract>> {
196        self.resolve_contract
197            .as_ref()
198            .and_then(|resolve| resolve(tool_name))
199    }
200
201    pub fn filter_prompt_contributions(
202        &self,
203        contributions: Vec<PromptContribution>,
204    ) -> Vec<PromptContribution> {
205        contributions
206            .into_iter()
207            .filter(|contribution| self.includes_prompt_contribution(contribution))
208            .collect()
209    }
210
211    fn includes_prompt_contribution(&self, contribution: &PromptContribution) -> bool {
212        if contribution.gate.is_empty() {
213            return true;
214        }
215        contribution
216            .gate
217            .tools
218            .iter()
219            .any(|tool_name| self.has_callable_tool(tool_name))
220    }
221}
222
223pub fn build_tool_catalog(input: ToolCatalogBuildInput) -> ToolCatalog {
224    let mut catalog = ToolCatalog::from_tool_manifests(input.tools, input.resolve_contract);
225    for contribution in input.contributions {
226        apply_contribution(&mut catalog, contribution);
227    }
228    catalog
229}
230
231fn apply_contribution(catalog: &mut ToolCatalog, contribution: ToolCatalogContribution) {
232    if contribution.remove.is_empty() {
233        return;
234    }
235    catalog
236        .tools
237        .retain(|tool| !contribution.remove.contains(&tool.manifest.name));
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToolActivation, ToolScheduling};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a raw tool definition named `name`. It has an object input schema
    /// with one required string field (`path`) and a string output schema, and
    /// is marked always-active and parallel-scheduled.
    fn tool(name: &str) -> ToolDefinition {
        let mut definition = ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            format!("Tool {name}"),
            serde_json::json!({
                "type": "object",
                "properties": { "path": { "type": "string" } },
                "required": ["path"]
            }),
            serde_json::json!({ "type": "string" }),
        );
        definition.manifest.activation = ToolActivation::Always;
        definition.manifest.scheduling = ToolScheduling::Parallel;
        definition
    }

    /// Wraps `tools` and `contributions` in a build input whose resolver serves
    /// exactly the contracts of `tools`.
    fn build_input(
        tools: Vec<ToolDefinition>,
        contributions: Vec<ToolCatalogContribution>,
    ) -> ToolCatalogBuildInput {
        let contracts = tools
            .iter()
            .map(|tool| (tool.name().to_string(), Arc::new(tool.contract())))
            .collect::<BTreeMap<_, _>>();
        ToolCatalogBuildInput {
            tools: tools.into_iter().map(|tool| tool.manifest()).collect(),
            resolve_contract: Some(Arc::new(move |name| contracts.get(name).cloned())),
            contributions,
        }
    }

    #[test]
    fn catalog_membership_is_flat_and_callable() {
        let catalog = build_tool_catalog(build_input(
            vec![tool("read_file"), tool("grep"), tool("write_file")],
            Vec::new(),
        ));

        assert_eq!(catalog.callable_tools().len(), 3);
        assert!(catalog.has_callable_tool("read_file"));
        assert!(catalog.has_callable_tool("grep"));
        assert!(!catalog.has_callable_tool("absent"));
    }

    #[test]
    fn contributions_remove_members() {
        let catalog = build_tool_catalog(build_input(
            vec![tool("read_file"), tool("write_file")],
            vec![ToolCatalogContribution::remove_tools(["write_file"])],
        ));

        assert!(catalog.has_callable_tool("read_file"));
        assert!(!catalog.has_callable_tool("write_file"));
        assert_eq!(catalog.callable_tools().len(), 1);
    }

    /// Several contributions apply in sequence; unknown names and empty
    /// contributions are harmless, and surviving members keep their order.
    #[test]
    fn contributions_compose_and_ignore_unknown_names() {
        let catalog = build_tool_catalog(build_input(
            vec![
                tool("read_file"),
                tool("grep"),
                tool("write_file"),
                tool("list_dir"),
            ],
            vec![
                ToolCatalogContribution::remove_tools(["grep"]),
                ToolCatalogContribution::remove_tools(["list_dir", "not_a_member"]),
                ToolCatalogContribution::default(),
            ],
        ));

        assert_eq!(
            *catalog.tool_names(),
            vec!["read_file".to_string(), "write_file".to_string()]
        );
        assert_eq!(catalog.model_tool_specs().len(), 2);
    }

    #[test]
    fn prompt_gate_requires_member_tool() {
        let catalog = build_tool_catalog(build_input(vec![tool("read_file")], Vec::new()));

        let kept = catalog.filter_prompt_contributions(vec![
            PromptContribution::guidance("Plain", "always"),
            PromptContribution::guidance("WithTool", "withtool").requires_tool("read_file"),
            PromptContribution::guidance("MissingTool", "missing").requires_tool("missing_tool"),
        ]);

        assert_eq!(kept.len(), 2);
        assert!(
            kept.iter()
                .any(|contribution| contribution.title.as_deref() == Some("Plain"))
        );
        assert!(
            kept.iter()
                .any(|contribution| contribution.title.as_deref() == Some("WithTool"))
        );
    }

    #[test]
    fn model_specs_resolve_lazily() {
        let contract_resolutions = Arc::new(AtomicUsize::new(0));
        let callable = tool("read_file");
        let resolver_count = Arc::clone(&contract_resolutions);
        let catalog = build_tool_catalog(ToolCatalogBuildInput {
            tools: vec![callable.manifest()],
            resolve_contract: Some(Arc::new(move |name| {
                resolver_count.fetch_add(1, Ordering::SeqCst);
                (name == "read_file").then(|| Arc::new(callable.contract()))
            })),
            contributions: Vec::new(),
        });

        assert_eq!(contract_resolutions.load(Ordering::SeqCst), 0);
        assert_eq!(catalog.model_tool_specs().len(), 1);
        assert_eq!(contract_resolutions.load(Ordering::SeqCst), 1);
        assert_eq!(catalog.model_tool_specs().len(), 1);
        assert_eq!(contract_resolutions.load(Ordering::SeqCst), 1);
    }

    /// Without a resolver, a member is still callable but the model is given
    /// no spec for it.
    #[test]
    fn members_without_contract_are_callable_but_have_no_model_spec() {
        let catalog = build_tool_catalog(ToolCatalogBuildInput {
            tools: vec![tool("read_file").manifest()],
            resolve_contract: None,
            contributions: Vec::new(),
        });

        assert!(catalog.has_callable_tool("read_file"));
        assert!(catalog.resolve_contract("read_file").is_none());
        assert!(catalog.model_tool_specs().is_empty());
    }

    /// Cloning after the views were computed reuses them rather than
    /// recomputing: the cached `Arc`s are the same allocations.
    #[test]
    fn clone_shares_already_computed_views() {
        let catalog = build_tool_catalog(build_input(vec![tool("read_file")], Vec::new()));
        let names = catalog.tool_names();
        let specs = catalog.model_tool_specs();

        let cloned = catalog.clone();

        assert!(Arc::ptr_eq(&names, &cloned.tool_names()));
        assert!(Arc::ptr_eq(&specs, &cloned.model_tool_specs()));
        assert_eq!(
            catalog.tool_names_fingerprint(),
            cloned.tool_names_fingerprint()
        );
    }

    #[test]
    fn tool_names_fingerprint_matches_prompt_hash() {
        let catalog = build_tool_catalog(build_input(
            vec![tool("read_file"), tool("grep")],
            Vec::new(),
        ));

        assert_eq!(
            catalog.tool_names_fingerprint(),
            prompt_tool_names_fingerprint(&catalog.tool_names())
        );
    }
}