Skip to main content

lash_sansio/tool_catalog.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, OnceLock};
3
4use crate::llm::types::LlmToolSpec;
5use crate::{
6    PromptContribution, PromptFingerprint, ToolAvailability, ToolContract, ToolDefinition,
7    ToolManifest, prompt_tool_names_fingerprint,
8};
9
/// Shared, thread-safe callback that looks up the contract for a tool by name.
///
/// Returns `None` when the resolver has no contract for the given name.
pub type ToolContractResolver =
    Arc<dyn Fn(&str) -> Option<Arc<ToolContract>> + Send + Sync + 'static>;
12
/// Everything [`build_tool_catalog`] needs to assemble a [`ToolCatalog`].
#[derive(Clone)]
pub struct ToolCatalogBuildInput {
    /// Manifests of the tools to place in the catalog, in catalog order.
    pub tools: Vec<ToolManifest>,
    /// Optional lookup used later to resolve a tool's contract by name.
    pub resolve_contract: Option<ToolContractResolver>,
    /// Contributions applied to the catalog one after another, in this order.
    pub contributions: Vec<ToolCatalogContribution>,
}
19
/// A set of adjustments layered onto a catalog while it is being built.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct ToolCatalogContribution {
    /// Per-tool availability overrides.
    pub overrides: Vec<ToolCatalogOverride>,
    /// Free-form notes to append to the catalog's tool list notes.
    pub tool_list_notes: Vec<String>,
}
25
26impl ToolCatalogContribution {
27    pub fn is_empty(&self) -> bool {
28        self.overrides.is_empty() && self.tool_list_notes.is_empty()
29    }
30}
31
/// Overrides catalog settings for a single tool, identified by name.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct ToolCatalogOverride {
    /// Name of the tool the override targets (matched against the manifest name).
    pub tool_name: String,
    /// Replacement availability; `None` leaves the tool's availability untouched.
    pub availability: Option<ToolAvailability>,
}
37
/// A tool manifest paired with the availability the catalog assigns to it.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCatalogEntry {
    /// The tool's manifest.
    pub manifest: ToolManifest,
    /// Availability in this catalog: starts as the manifest's effective
    /// availability and may be replaced by a contribution override.
    pub availability: ToolAvailability,
}
43
/// The set of tools known to a session, plus lazily computed derived views.
///
/// Only `tools` and `tool_list_notes` are serialized; the resolver and the
/// cached views are skipped and start out empty after deserialization.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ToolCatalog {
    /// Catalog entries, in the order the manifests were supplied.
    pub tools: Vec<ToolCatalogEntry>,
    /// Notes appended after the rendered tool docs.
    pub tool_list_notes: Vec<String>,
    /// Optional name-to-contract lookup; not serialized.
    #[serde(skip)]
    resolve_contract: Option<ToolContractResolver>,
    /// Cache for the rendered prompt documentation.
    #[serde(skip)]
    prompt_tool_docs: OnceLock<Arc<str>>,
    /// Cache for the model-facing tool specs.
    #[serde(skip)]
    model_tool_specs: OnceLock<Arc<Vec<LlmToolSpec>>>,
    /// Cache for the names of callable tools.
    #[serde(skip)]
    tool_names: OnceLock<Arc<Vec<String>>>,
    /// Cache for the fingerprint of the callable tool names.
    #[serde(skip)]
    tool_names_fingerprint: OnceLock<PromptFingerprint>,
}
59
60impl Clone for ToolCatalog {
61    fn clone(&self) -> Self {
62        let clone = Self {
63            tools: self.tools.clone(),
64            tool_list_notes: self.tool_list_notes.clone(),
65            resolve_contract: self.resolve_contract.clone(),
66            prompt_tool_docs: OnceLock::new(),
67            model_tool_specs: OnceLock::new(),
68            tool_names: OnceLock::new(),
69            tool_names_fingerprint: OnceLock::new(),
70        };
71        if let Some(value) = self.prompt_tool_docs.get() {
72            let _ = clone.prompt_tool_docs.set(Arc::clone(value));
73        }
74        if let Some(value) = self.model_tool_specs.get() {
75            let _ = clone.model_tool_specs.set(Arc::clone(value));
76        }
77        if let Some(value) = self.tool_names.get() {
78            let _ = clone.tool_names.set(Arc::clone(value));
79        }
80        if let Some(value) = self.tool_names_fingerprint.get() {
81            let _ = clone.tool_names_fingerprint.set(*value);
82        }
83        clone
84    }
85}
86
87impl std::fmt::Debug for ToolCatalog {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("ToolCatalog")
90            .field("tools", &self.tools)
91            .field("tool_list_notes", &self.tool_list_notes)
92            .finish_non_exhaustive()
93    }
94}
95
96impl Default for ToolCatalog {
97    fn default() -> Self {
98        Self {
99            tools: Vec::new(),
100            tool_list_notes: Vec::new(),
101            resolve_contract: None,
102            prompt_tool_docs: OnceLock::new(),
103            model_tool_specs: OnceLock::new(),
104            tool_names: OnceLock::new(),
105            tool_names_fingerprint: OnceLock::new(),
106        }
107    }
108}
109
110impl ToolCatalog {
111    pub fn from_tool_definitions(tools: Vec<ToolDefinition>) -> Self {
112        let contracts = tools
113            .iter()
114            .map(|tool| (tool.name().to_string(), Arc::new(tool.contract())))
115            .collect();
116        Self::from_tools(
117            tools.into_iter().map(|tool| tool.manifest()).collect(),
118            contracts,
119        )
120    }
121
122    pub fn from_tools(
123        tools: Vec<ToolManifest>,
124        contracts: BTreeMap<String, Arc<ToolContract>>,
125    ) -> Self {
126        let resolver_contracts = Arc::new(contracts);
127        Self::from_tool_manifests(
128            tools,
129            Some(Arc::new(move |name| resolver_contracts.get(name).cloned())),
130        )
131    }
132
133    fn from_tool_manifests(
134        tools: Vec<ToolManifest>,
135        resolve_contract: Option<ToolContractResolver>,
136    ) -> Self {
137        Self {
138            tools: tools
139                .into_iter()
140                .map(|manifest| ToolCatalogEntry {
141                    availability: manifest.effective_availability(),
142                    manifest,
143                })
144                .collect(),
145            tool_list_notes: Vec::new(),
146            resolve_contract,
147            prompt_tool_docs: OnceLock::new(),
148            model_tool_specs: OnceLock::new(),
149            tool_names: OnceLock::new(),
150            tool_names_fingerprint: OnceLock::new(),
151        }
152    }
153
154    pub fn callable_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
155        self.tools
156            .iter()
157            .filter(|tool| tool.availability.is_callable())
158            .map(|tool| &tool.manifest)
159    }
160
161    pub fn callable_tools(&self) -> Vec<ToolManifest> {
162        self.callable_tools_iter().cloned().collect()
163    }
164
165    pub fn showcased_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
166        self.tools
167            .iter()
168            .filter(|tool| tool.availability.is_showcased())
169            .map(|tool| &tool.manifest)
170    }
171
172    pub fn showcased_tools(&self) -> Vec<ToolManifest> {
173        self.showcased_tools_iter().cloned().collect()
174    }
175
176    pub fn searchable_tools_iter(&self) -> impl Iterator<Item = &ToolCatalogEntry> {
177        self.tools
178            .iter()
179            .filter(|tool| tool.availability.is_searchable())
180    }
181
182    pub fn omitted_tools_iter(&self) -> impl Iterator<Item = &ToolCatalogEntry> {
183        self.searchable_tools_iter()
184            .filter(|tool| !tool.availability.is_showcased())
185    }
186
187    pub fn has_callable_tool(&self, tool_name: &str) -> bool {
188        self.tools
189            .iter()
190            .any(|tool| tool.availability.is_callable() && tool.manifest.name == tool_name)
191    }
192
193    pub fn tool_availability(&self, tool_name: &str) -> Option<ToolAvailability> {
194        self.tools
195            .iter()
196            .find(|tool| tool.manifest.name == tool_name)
197            .map(|tool| tool.availability)
198    }
199
200    pub fn tool_names(&self) -> Arc<Vec<String>> {
201        Arc::clone(self.tool_names.get_or_init(|| {
202            Arc::new(
203                self.tools
204                    .iter()
205                    .filter(|tool| tool.availability.is_callable())
206                    .map(|tool| tool.manifest.name.clone())
207                    .collect(),
208            )
209        }))
210    }
211
212    pub fn tool_names_fingerprint(&self) -> PromptFingerprint {
213        *self
214            .tool_names_fingerprint
215            .get_or_init(|| prompt_tool_names_fingerprint(&self.tool_names()))
216    }
217
218    pub fn omitted_tool_count(&self) -> usize {
219        self.omitted_tools_iter().count()
220    }
221
222    pub fn model_tool_specs(&self) -> Arc<Vec<LlmToolSpec>> {
223        Arc::clone(self.model_tool_specs.get_or_init(|| {
224            Arc::new(
225                self.tools
226                    .iter()
227                    .filter(|tool| tool.availability.is_callable())
228                    .filter_map(|tool| {
229                        self.resolve_contract(&tool.manifest.name)
230                            .map(|contract| contract.model_tool(&tool.manifest))
231                    })
232                    .map(|model_tool| LlmToolSpec {
233                        name: model_tool.name,
234                        description: model_tool.description,
235                        input_schema: model_tool.input_schema,
236                        output_schema: model_tool.output_schema,
237                        input_schema_projections: model_tool.input_schema_projections,
238                        output_schema_projections: model_tool.output_schema_projections,
239                    })
240                    .collect(),
241            )
242        }))
243    }
244
245    pub fn prompt_tool_docs(&self) -> &str {
246        self.prompt_tool_docs
247            .get_or_init(|| Arc::from(self.rendered_prompt_tool_docs()))
248            .as_ref()
249    }
250
251    fn resolve_contract(&self, tool_name: &str) -> Option<Arc<ToolContract>> {
252        self.resolve_contract
253            .as_ref()
254            .and_then(|resolve| resolve(tool_name))
255    }
256
257    fn rendered_prompt_tool_docs(&self) -> String {
258        let mut docs = self
259            .tools
260            .iter()
261            .filter(|tool| tool.availability.is_showcased())
262            .filter_map(|tool| {
263                self.resolve_contract(&tool.manifest.name)
264                    .map(|contract| contract.compact_contract(&tool.manifest).render_markdown())
265            })
266            .collect::<Vec<_>>()
267            .join("\n\n");
268        for note in &self.tool_list_notes {
269            let note = note.trim();
270            if note.is_empty() {
271                continue;
272            }
273            if !docs.is_empty() {
274                docs.push_str("\n\n");
275            }
276            docs.push_str(note);
277        }
278        docs
279    }
280
281    pub fn filter_prompt_contributions(
282        &self,
283        contributions: Vec<PromptContribution>,
284    ) -> Vec<PromptContribution> {
285        contributions
286            .into_iter()
287            .filter(|contribution| self.includes_prompt_contribution(contribution))
288            .collect()
289    }
290
291    fn includes_prompt_contribution(&self, contribution: &PromptContribution) -> bool {
292        if contribution.gate.is_empty() {
293            return true;
294        }
295        contribution.gate.tools.iter().any(|tool_name| {
296            self.tool_availability(tool_name)
297                .is_some_and(|availability| availability >= contribution.gate.minimum_availability)
298        })
299    }
300}
301
302pub fn build_tool_catalog(input: ToolCatalogBuildInput) -> ToolCatalog {
303    let mut surface = ToolCatalog::from_tool_manifests(input.tools, input.resolve_contract);
304    for contribution in input.contributions {
305        apply_contribution(&mut surface, contribution);
306    }
307    surface
308}
309
310fn apply_contribution(surface: &mut ToolCatalog, contribution: ToolCatalogContribution) {
311    for override_ in contribution.overrides {
312        if let Some(tool) = surface
313            .tools
314            .iter_mut()
315            .find(|tool| tool.manifest.name == override_.tool_name)
316            && let Some(availability) = override_.availability
317        {
318            tool.availability = availability;
319        }
320    }
321
322    surface.tool_list_notes.extend(
323        contribution
324            .tool_list_notes
325            .into_iter()
326            .map(|note| note.trim().to_string())
327            .filter(|note| !note.is_empty()),
328    );
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToolActivation, ToolAvailabilityConfig, ToolScheduling};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds an always-active, parallel-scheduled tool definition with a
    /// single required `path` string input and a string output.
    fn tool(name: &str, availability: ToolAvailability) -> ToolDefinition {
        let input_schema = serde_json::json!({
            "type": "object",
            "properties": { "path": { "type": "string" } },
            "required": ["path"]
        });
        let output_schema = serde_json::json!({ "type": "string" });
        let mut definition = ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            format!("Tool {name}"),
            input_schema,
            output_schema,
        );
        definition.manifest.availability = ToolAvailabilityConfig::same(availability);
        definition.manifest.activation = ToolActivation::Always;
        definition.manifest.scheduling = ToolScheduling::Parallel;
        definition
    }

    /// Packages definitions into a build input whose resolver serves each
    /// definition's own contract.
    fn build_input(
        tools: Vec<ToolDefinition>,
        contributions: Vec<ToolCatalogContribution>,
    ) -> ToolCatalogBuildInput {
        let mut contracts = BTreeMap::new();
        for definition in &tools {
            contracts.insert(
                definition.name().to_string(),
                Arc::new(definition.contract()),
            );
        }
        let manifests = tools
            .into_iter()
            .map(|definition| definition.manifest())
            .collect();
        ToolCatalogBuildInput {
            tools: manifests,
            resolve_contract: Some(Arc::new(move |name| contracts.get(name).cloned())),
            contributions,
        }
    }

    /// Reads the current value of a resolver call counter.
    fn resolutions(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    #[test]
    fn catalog_splits_callable_and_showcased_tools() {
        let definitions = vec![
            tool("search_tools", ToolAvailability::Showcased),
            tool("read_file", ToolAvailability::Showcased),
            tool("grep", ToolAvailability::Callable),
            tool("privileged_tool", ToolAvailability::Searchable),
        ];
        let catalog = build_tool_catalog(build_input(definitions, Vec::new()));

        assert_eq!(catalog.callable_tools().len(), 3);
        assert_eq!(catalog.showcased_tools().len(), 2);
        assert_eq!(catalog.omitted_tool_count(), 2);
        assert!(!catalog.prompt_tool_docs().contains("Catalogued tools"));
    }

    #[test]
    fn explicit_contributions_override_availability() {
        let contribution = ToolCatalogContribution {
            overrides: vec![ToolCatalogOverride {
                tool_name: "read_file".to_string(),
                availability: Some(ToolAvailability::Off),
            }],
            tool_list_notes: vec!["custom note".to_string()],
        };
        let catalog = build_tool_catalog(build_input(
            vec![tool("read_file", ToolAvailability::Showcased)],
            vec![contribution],
        ));

        let read_file = catalog
            .tools
            .iter()
            .find(|entry| entry.manifest.name == "read_file")
            .expect("read_file present");
        assert_eq!(read_file.availability, ToolAvailability::Off);

        let has_custom_note = catalog
            .tool_list_notes
            .iter()
            .any(|note| note == "custom note");
        assert!(has_custom_note);
    }

    #[test]
    fn prompt_gate_requires_matching_tool_availability() {
        let catalog = build_tool_catalog(build_input(
            vec![tool("search_tools", ToolAvailability::Showcased)],
            Vec::new(),
        ));

        let candidates = vec![
            PromptContribution::guidance("Plain", "always"),
            PromptContribution::guidance("Discovery", "discover")
                .requires_tool("search_tools", ToolAvailability::Showcased),
            PromptContribution::guidance("Off", "off")
                .requires_tool("missing_tool", ToolAvailability::Callable),
        ];
        let kept = catalog.filter_prompt_contributions(candidates);

        assert_eq!(kept.len(), 2);
        let titles: Vec<&str> = kept
            .iter()
            .filter_map(|contribution| contribution.title.as_deref())
            .collect();
        assert!(titles.contains(&"Plain"));
        assert!(titles.contains(&"Discovery"));
    }

    #[test]
    fn rlm_catalog_does_not_resolve_searchable_only_contracts() {
        let counter = Arc::new(AtomicUsize::new(0));
        let searchable = tool("large_schema", ToolAvailability::Searchable);
        let showcased = tool("search_tools", ToolAvailability::Showcased);
        let manifests = vec![searchable.manifest(), showcased.manifest()];
        let resolver_counter = Arc::clone(&counter);
        let catalog = build_tool_catalog(ToolCatalogBuildInput {
            tools: manifests,
            resolve_contract: Some(Arc::new(move |name| {
                resolver_counter.fetch_add(1, Ordering::SeqCst);
                let definition = match name {
                    "large_schema" => &searchable,
                    "search_tools" => &showcased,
                    _ => return None,
                };
                Some(Arc::new(definition.contract()))
            })),
            contributions: Vec::new(),
        });

        assert_eq!(
            catalog.tool_availability("large_schema"),
            Some(ToolAvailability::Searchable)
        );
        // Building the catalog and querying availability must not resolve anything.
        assert_eq!(resolutions(&counter), 0);
        assert!(!catalog.prompt_tool_docs().contains("large_schema"));
        // Rendering docs resolves only the showcased tool.
        assert_eq!(resolutions(&counter), 1);
    }

    #[test]
    fn callable_only_catalog_resolves_model_specs_lazily() {
        let counter = Arc::new(AtomicUsize::new(0));
        let callable = tool("large_callable", ToolAvailability::Callable);
        let manifests = vec![callable.manifest()];
        let resolver_counter = Arc::clone(&counter);
        let catalog = build_tool_catalog(ToolCatalogBuildInput {
            tools: manifests,
            resolve_contract: Some(Arc::new(move |name| {
                resolver_counter.fetch_add(1, Ordering::SeqCst);
                if name == "large_callable" {
                    Some(Arc::new(callable.contract()))
                } else {
                    None
                }
            })),
            contributions: Vec::new(),
        });

        assert_eq!(*catalog.tool_names(), vec!["large_callable".to_string()]);
        assert_eq!(catalog.model_tool_specs().len(), 1);
        assert_eq!(catalog.prompt_tool_docs(), "");
        assert_eq!(resolutions(&counter), 1);
    }

    #[test]
    fn standard_catalog_resolves_model_specs_lazily() {
        let counter = Arc::new(AtomicUsize::new(0));
        let callable = tool("read_file", ToolAvailability::Callable);
        let manifests = vec![callable.manifest()];
        let resolver_counter = Arc::clone(&counter);
        let catalog = build_tool_catalog(ToolCatalogBuildInput {
            tools: manifests,
            resolve_contract: Some(Arc::new(move |name| {
                resolver_counter.fetch_add(1, Ordering::SeqCst);
                if name == "read_file" {
                    Some(Arc::new(callable.contract()))
                } else {
                    None
                }
            })),
            contributions: Vec::new(),
        });

        assert_eq!(resolutions(&counter), 0);
        assert_eq!(catalog.model_tool_specs().len(), 1);
        assert_eq!(resolutions(&counter), 1);
        // A second request is served from the cache.
        assert_eq!(catalog.model_tool_specs().len(), 1);
        assert_eq!(resolutions(&counter), 1);
    }

    #[test]
    fn tool_names_fingerprint_matches_prompt_hash() {
        let definitions = vec![
            tool("read_file", ToolAvailability::Callable),
            tool("search_tools", ToolAvailability::Showcased),
        ];
        let catalog = build_tool_catalog(build_input(definitions, Vec::new()));

        let expected = prompt_tool_names_fingerprint(&catalog.tool_names());
        assert_eq!(catalog.tool_names_fingerprint(), expected);
    }
}
531}