// lash_sansio/tool_catalog.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, OnceLock};
3
4use crate::llm::types::LlmToolSpec;
5use crate::{
6    PromptContribution, PromptFingerprint, ToolContract, ToolDefinition, ToolManifest,
7    prompt_tool_names_fingerprint,
8};
9
10pub type ToolContractResolver =
11    Arc<dyn Fn(&str) -> Option<Arc<ToolContract>> + Send + Sync + 'static>;
12
13#[derive(Clone)]
14pub struct ToolCatalogBuildInput {
15    pub tools: Vec<ToolManifest>,
16    pub resolve_contract: Option<ToolContractResolver>,
17    pub contributions: Vec<ToolCatalogContribution>,
18}
19
20/// A trusted plugin's contribution to catalog assembly. Membership is the
21/// execution gate, so the only override a contribution can express is *removal*
22/// of a member (authority hiding, plan-mode gating). Adding members happens by
23/// a [`crate::ToolProvider`] including them in its manifest list.
24#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
25pub struct ToolCatalogContribution {
26    /// Names of tools to remove from the catalog (non-membership).
27    pub remove: Vec<String>,
28}
29
30impl ToolCatalogContribution {
31    pub fn is_empty(&self) -> bool {
32        self.remove.is_empty()
33    }
34
35    pub fn remove_tools(tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
36        Self {
37            remove: tools.into_iter().map(Into::into).collect(),
38        }
39    }
40}
41
42#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
43pub struct ToolCatalogEntry {
44    pub manifest: ToolManifest,
45}
46
47#[derive(serde::Serialize, serde::Deserialize)]
48pub struct ToolCatalog {
49    pub tools: Vec<ToolCatalogEntry>,
50    #[serde(skip)]
51    resolve_contract: Option<ToolContractResolver>,
52    #[serde(skip)]
53    model_tool_specs: OnceLock<Arc<Vec<LlmToolSpec>>>,
54    #[serde(skip)]
55    tool_names: OnceLock<Arc<Vec<String>>>,
56    #[serde(skip)]
57    tool_names_fingerprint: OnceLock<PromptFingerprint>,
58}
59
60impl Clone for ToolCatalog {
61    fn clone(&self) -> Self {
62        let clone = Self {
63            tools: self.tools.clone(),
64            resolve_contract: self.resolve_contract.clone(),
65            model_tool_specs: OnceLock::new(),
66            tool_names: OnceLock::new(),
67            tool_names_fingerprint: OnceLock::new(),
68        };
69        if let Some(value) = self.model_tool_specs.get() {
70            let _ = clone.model_tool_specs.set(Arc::clone(value));
71        }
72        if let Some(value) = self.tool_names.get() {
73            let _ = clone.tool_names.set(Arc::clone(value));
74        }
75        if let Some(value) = self.tool_names_fingerprint.get() {
76            let _ = clone.tool_names_fingerprint.set(*value);
77        }
78        clone
79    }
80}
81
82impl std::fmt::Debug for ToolCatalog {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.debug_struct("ToolCatalog")
85            .field("tools", &self.tools)
86            .finish_non_exhaustive()
87    }
88}
89
90impl Default for ToolCatalog {
91    fn default() -> Self {
92        Self {
93            tools: Vec::new(),
94            resolve_contract: None,
95            model_tool_specs: OnceLock::new(),
96            tool_names: OnceLock::new(),
97            tool_names_fingerprint: OnceLock::new(),
98        }
99    }
100}
101
102impl ToolCatalog {
103    pub fn from_tool_definitions(tools: Vec<ToolDefinition>) -> Self {
104        let contracts = tools
105            .iter()
106            .map(|tool| (tool.name().to_string(), Arc::new(tool.contract())))
107            .collect();
108        Self::from_tools(
109            tools.into_iter().map(|tool| tool.manifest()).collect(),
110            contracts,
111        )
112    }
113
114    pub fn from_tools(
115        tools: Vec<ToolManifest>,
116        contracts: BTreeMap<String, Arc<ToolContract>>,
117    ) -> Self {
118        let resolver_contracts = Arc::new(contracts);
119        Self::from_tool_manifests(
120            tools,
121            Some(Arc::new(move |name| resolver_contracts.get(name).cloned())),
122        )
123    }
124
125    fn from_tool_manifests(
126        tools: Vec<ToolManifest>,
127        resolve_contract: Option<ToolContractResolver>,
128    ) -> Self {
129        Self {
130            tools: tools
131                .into_iter()
132                .map(|manifest| ToolCatalogEntry { manifest })
133                .collect(),
134            resolve_contract,
135            model_tool_specs: OnceLock::new(),
136            tool_names: OnceLock::new(),
137            tool_names_fingerprint: OnceLock::new(),
138        }
139    }
140
141    /// All catalog members. Membership is callability; there is no filtering.
142    pub fn callable_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
143        self.tools.iter().map(|tool| &tool.manifest)
144    }
145
146    pub fn callable_tools(&self) -> Vec<ToolManifest> {
147        self.callable_tools_iter().cloned().collect()
148    }
149
150    /// Membership test: a tool is in the catalog (callable) or it does not
151    /// exist to the model.
152    pub fn has_callable_tool(&self, tool_name: &str) -> bool {
153        self.tools
154            .iter()
155            .any(|tool| tool.manifest.name == tool_name)
156    }
157
158    pub fn tool_names(&self) -> Arc<Vec<String>> {
159        Arc::clone(self.tool_names.get_or_init(|| {
160            Arc::new(
161                self.tools
162                    .iter()
163                    .map(|tool| tool.manifest.name.clone())
164                    .collect(),
165            )
166        }))
167    }
168
169    pub fn tool_names_fingerprint(&self) -> PromptFingerprint {
170        *self
171            .tool_names_fingerprint
172            .get_or_init(|| prompt_tool_names_fingerprint(&self.tool_names()))
173    }
174
175    pub fn model_tool_specs(&self) -> Arc<Vec<LlmToolSpec>> {
176        Arc::clone(self.model_tool_specs.get_or_init(|| {
177            Arc::new(
178                self.tools
179                    .iter()
180                    .filter_map(|tool| {
181                        self.resolve_contract(&tool.manifest.name)
182                            .map(|contract| contract.model_tool(&tool.manifest))
183                    })
184                    .map(|model_tool| LlmToolSpec {
185                        name: model_tool.name,
186                        description: model_tool.description,
187                        input_schema: model_tool.input_schema,
188                        output_schema: model_tool.output_schema,
189                        input_schema_projections: model_tool.input_schema_projections,
190                        output_schema_projections: model_tool.output_schema_projections,
191                    })
192                    .collect(),
193            )
194        }))
195    }
196
197    pub fn resolve_contract(&self, tool_name: &str) -> Option<Arc<ToolContract>> {
198        self.resolve_contract
199            .as_ref()
200            .and_then(|resolve| resolve(tool_name))
201    }
202
203    pub fn filter_prompt_contributions(
204        &self,
205        contributions: Vec<PromptContribution>,
206    ) -> Vec<PromptContribution> {
207        contributions
208            .into_iter()
209            .filter(|contribution| self.includes_prompt_contribution(contribution))
210            .collect()
211    }
212
213    fn includes_prompt_contribution(&self, contribution: &PromptContribution) -> bool {
214        if contribution.gate.is_empty() {
215            return true;
216        }
217        contribution
218            .gate
219            .tools
220            .iter()
221            .any(|tool_name| self.has_callable_tool(tool_name))
222    }
223}
224
225pub fn build_tool_catalog(input: ToolCatalogBuildInput) -> ToolCatalog {
226    let mut catalog = ToolCatalog::from_tool_manifests(input.tools, input.resolve_contract);
227    for contribution in input.contributions {
228        apply_contribution(&mut catalog, contribution);
229    }
230    catalog
231}
232
233fn apply_contribution(catalog: &mut ToolCatalog, contribution: ToolCatalogContribution) {
234    if contribution.remove.is_empty() {
235        return;
236    }
237    catalog
238        .tools
239        .retain(|tool| !contribution.remove.contains(&tool.manifest.name));
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToolActivation, ToolScheduling};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Fixture: a definition named `name` that takes one required string
    /// `path` input, returns a string, and has activation `Always` and
    /// scheduling `Parallel`.
    fn tool(name: &str) -> ToolDefinition {
        let input_schema = serde_json::json!({
            "type": "object",
            "properties": { "path": { "type": "string" } },
            "required": ["path"]
        });
        let output_schema = serde_json::json!({ "type": "string" });
        let mut definition = ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            format!("Tool {name}"),
            input_schema,
            output_schema,
        );
        definition.manifest.activation = ToolActivation::Always;
        definition.manifest.scheduling = ToolScheduling::Parallel;
        definition
    }

    /// Fixture: build input whose resolver serves the contracts of `tools`.
    fn build_input(
        tools: Vec<ToolDefinition>,
        contributions: Vec<ToolCatalogContribution>,
    ) -> ToolCatalogBuildInput {
        let mut contracts = BTreeMap::new();
        for definition in &tools {
            contracts.insert(
                definition.name().to_string(),
                Arc::new(definition.contract()),
            );
        }
        let resolver: ToolContractResolver =
            Arc::new(move |name: &str| contracts.get(name).cloned());
        let manifests = tools
            .into_iter()
            .map(|definition| definition.manifest())
            .collect();
        ToolCatalogBuildInput {
            tools: manifests,
            resolve_contract: Some(resolver),
            contributions,
        }
    }

    #[test]
    fn catalog_membership_is_flat_and_callable() {
        let definitions = vec![tool("read_file"), tool("grep"), tool("write_file")];
        let catalog = build_tool_catalog(build_input(definitions, Vec::new()));

        assert_eq!(catalog.callable_tools().len(), 3);
        for member in ["read_file", "grep"] {
            assert!(catalog.has_callable_tool(member));
        }
        assert!(!catalog.has_callable_tool("absent"));
    }

    #[test]
    fn contributions_remove_members() {
        let removal = ToolCatalogContribution::remove_tools(["write_file"]);
        let catalog = build_tool_catalog(build_input(
            vec![tool("read_file"), tool("write_file")],
            vec![removal],
        ));

        assert!(catalog.has_callable_tool("read_file"));
        assert!(!catalog.has_callable_tool("write_file"));
        assert_eq!(catalog.callable_tools().len(), 1);
    }

    #[test]
    fn prompt_gate_requires_member_tool() {
        let catalog = build_tool_catalog(build_input(vec![tool("read_file")], Vec::new()));

        let kept = catalog.filter_prompt_contributions(vec![
            PromptContribution::guidance("Plain", "always"),
            PromptContribution::guidance("WithTool", "withtool").requires_tool("read_file"),
            PromptContribution::guidance("MissingTool", "missing").requires_tool("missing_tool"),
        ]);

        assert_eq!(kept.len(), 2);
        let kept_has_title = |title: &str| {
            kept.iter()
                .any(|contribution| contribution.title.as_deref() == Some(title))
        };
        assert!(kept_has_title("Plain"));
        assert!(kept_has_title("WithTool"));
    }

    #[test]
    fn model_specs_resolve_lazily() {
        let contract_resolutions = Arc::new(AtomicUsize::new(0));
        let callable = tool("read_file");
        let manifest = callable.manifest();
        let resolver: ToolContractResolver = {
            let resolver_count = Arc::clone(&contract_resolutions);
            Arc::new(move |name: &str| {
                resolver_count.fetch_add(1, Ordering::SeqCst);
                (name == "read_file").then(|| Arc::new(callable.contract()))
            })
        };
        let catalog = build_tool_catalog(ToolCatalogBuildInput {
            tools: vec![manifest],
            resolve_contract: Some(resolver),
            contributions: Vec::new(),
        });
        let resolutions = || contract_resolutions.load(Ordering::SeqCst);

        // Nothing is resolved until the specs are first requested, and the
        // second request is served from the cache.
        assert_eq!(resolutions(), 0);
        assert_eq!(catalog.model_tool_specs().len(), 1);
        assert_eq!(resolutions(), 1);
        assert_eq!(catalog.model_tool_specs().len(), 1);
        assert_eq!(resolutions(), 1);
    }

    #[test]
    fn tool_names_fingerprint_matches_prompt_hash() {
        let definitions = vec![tool("read_file"), tool("grep")];
        let catalog = build_tool_catalog(build_input(definitions, Vec::new()));

        let expected = prompt_tool_names_fingerprint(&catalog.tool_names());
        assert_eq!(catalog.tool_names_fingerprint(), expected);
    }
}