1use std::collections::BTreeMap;
2use std::sync::{Arc, OnceLock};
3
4use crate::llm::types::LlmToolSpec;
5use crate::{
6 PromptContribution, PromptFingerprint, ToolContract, ToolDefinition, ToolManifest,
7 prompt_tool_names_fingerprint,
8};
9
/// Shared, thread-safe callback that maps a tool name to its [`ToolContract`], or `None`
/// when it has no contract for that name.
pub type ToolContractResolver =
    Arc<dyn Fn(&str) -> Option<Arc<ToolContract>> + Send + Sync + 'static>;
12
/// Everything [`build_tool_catalog`] consumes to assemble a [`ToolCatalog`].
#[derive(Clone)]
pub struct ToolCatalogBuildInput {
    /// Manifests of the tools the catalog starts out with.
    pub tools: Vec<ToolManifest>,
    /// Optional lookup from tool name to contract; handed to the catalog as-is.
    pub resolve_contract: Option<ToolContractResolver>,
    /// Adjustments applied one after another, in order, to the initial catalog.
    pub contributions: Vec<ToolCatalogContribution>,
}
19
/// An adjustment to catalog membership; the only supported adjustment is removing tools
/// by name.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct ToolCatalogContribution {
    /// Names of the tools to drop from the catalog.
    pub remove: Vec<String>,
}
29
30impl ToolCatalogContribution {
31 pub fn is_empty(&self) -> bool {
32 self.remove.is_empty()
33 }
34
35 pub fn remove_tools(tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
36 Self {
37 remove: tools.into_iter().map(Into::into).collect(),
38 }
39 }
40}
41
/// A single member of a [`ToolCatalog`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCatalogEntry {
    /// Manifest of the tool this entry represents.
    pub manifest: ToolManifest,
}
46
/// A flat list of callable tools together with lazily computed views derived from it.
///
/// Only `tools` is serialized; every other field is marked `#[serde(skip)]`.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ToolCatalog {
    /// The catalog members.
    pub tools: Vec<ToolCatalogEntry>,
    /// Optional lookup from tool name to contract, consulted by `resolve_contract()`.
    #[serde(skip)]
    resolve_contract: Option<ToolContractResolver>,
    /// Memoized result of `model_tool_specs()`; filled on first call.
    #[serde(skip)]
    model_tool_specs: OnceLock<Arc<Vec<LlmToolSpec>>>,
    /// Memoized result of `tool_names()`; filled on first call.
    #[serde(skip)]
    tool_names: OnceLock<Arc<Vec<String>>>,
    /// Memoized result of `tool_names_fingerprint()`; filled on first call.
    #[serde(skip)]
    tool_names_fingerprint: OnceLock<PromptFingerprint>,
}
59
60impl Clone for ToolCatalog {
61 fn clone(&self) -> Self {
62 let clone = Self {
63 tools: self.tools.clone(),
64 resolve_contract: self.resolve_contract.clone(),
65 model_tool_specs: OnceLock::new(),
66 tool_names: OnceLock::new(),
67 tool_names_fingerprint: OnceLock::new(),
68 };
69 if let Some(value) = self.model_tool_specs.get() {
70 let _ = clone.model_tool_specs.set(Arc::clone(value));
71 }
72 if let Some(value) = self.tool_names.get() {
73 let _ = clone.tool_names.set(Arc::clone(value));
74 }
75 if let Some(value) = self.tool_names_fingerprint.get() {
76 let _ = clone.tool_names_fingerprint.set(*value);
77 }
78 clone
79 }
80}
81
82impl std::fmt::Debug for ToolCatalog {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.debug_struct("ToolCatalog")
85 .field("tools", &self.tools)
86 .finish_non_exhaustive()
87 }
88}
89
90impl Default for ToolCatalog {
91 fn default() -> Self {
92 Self {
93 tools: Vec::new(),
94 resolve_contract: None,
95 model_tool_specs: OnceLock::new(),
96 tool_names: OnceLock::new(),
97 tool_names_fingerprint: OnceLock::new(),
98 }
99 }
100}
101
102impl ToolCatalog {
103 pub fn from_tool_definitions(tools: Vec<ToolDefinition>) -> Self {
104 let contracts = tools
105 .iter()
106 .map(|tool| (tool.name().to_string(), Arc::new(tool.contract())))
107 .collect();
108 Self::from_tools(
109 tools.into_iter().map(|tool| tool.manifest()).collect(),
110 contracts,
111 )
112 }
113
114 pub fn from_tools(
115 tools: Vec<ToolManifest>,
116 contracts: BTreeMap<String, Arc<ToolContract>>,
117 ) -> Self {
118 let resolver_contracts = Arc::new(contracts);
119 Self::from_tool_manifests(
120 tools,
121 Some(Arc::new(move |name| resolver_contracts.get(name).cloned())),
122 )
123 }
124
125 fn from_tool_manifests(
126 tools: Vec<ToolManifest>,
127 resolve_contract: Option<ToolContractResolver>,
128 ) -> Self {
129 Self {
130 tools: tools
131 .into_iter()
132 .map(|manifest| ToolCatalogEntry { manifest })
133 .collect(),
134 resolve_contract,
135 model_tool_specs: OnceLock::new(),
136 tool_names: OnceLock::new(),
137 tool_names_fingerprint: OnceLock::new(),
138 }
139 }
140
141 pub fn callable_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
143 self.tools.iter().map(|tool| &tool.manifest)
144 }
145
146 pub fn callable_tools(&self) -> Vec<ToolManifest> {
147 self.callable_tools_iter().cloned().collect()
148 }
149
150 pub fn has_callable_tool(&self, tool_name: &str) -> bool {
153 self.tools
154 .iter()
155 .any(|tool| tool.manifest.name == tool_name)
156 }
157
158 pub fn tool_names(&self) -> Arc<Vec<String>> {
159 Arc::clone(self.tool_names.get_or_init(|| {
160 Arc::new(
161 self.tools
162 .iter()
163 .map(|tool| tool.manifest.name.clone())
164 .collect(),
165 )
166 }))
167 }
168
169 pub fn tool_names_fingerprint(&self) -> PromptFingerprint {
170 *self
171 .tool_names_fingerprint
172 .get_or_init(|| prompt_tool_names_fingerprint(&self.tool_names()))
173 }
174
175 pub fn model_tool_specs(&self) -> Arc<Vec<LlmToolSpec>> {
176 Arc::clone(self.model_tool_specs.get_or_init(|| {
177 Arc::new(
178 self.tools
179 .iter()
180 .filter_map(|tool| {
181 self.resolve_contract(&tool.manifest.name)
182 .map(|contract| contract.model_tool(&tool.manifest))
183 })
184 .map(|model_tool| LlmToolSpec {
185 name: model_tool.name,
186 description: model_tool.description,
187 input_schema: model_tool.input_schema,
188 output_schema: model_tool.output_schema,
189 input_schema_projections: model_tool.input_schema_projections,
190 output_schema_projections: model_tool.output_schema_projections,
191 })
192 .collect(),
193 )
194 }))
195 }
196
197 pub fn resolve_contract(&self, tool_name: &str) -> Option<Arc<ToolContract>> {
198 self.resolve_contract
199 .as_ref()
200 .and_then(|resolve| resolve(tool_name))
201 }
202
203 pub fn filter_prompt_contributions(
204 &self,
205 contributions: Vec<PromptContribution>,
206 ) -> Vec<PromptContribution> {
207 contributions
208 .into_iter()
209 .filter(|contribution| self.includes_prompt_contribution(contribution))
210 .collect()
211 }
212
213 fn includes_prompt_contribution(&self, contribution: &PromptContribution) -> bool {
214 if contribution.gate.is_empty() {
215 return true;
216 }
217 contribution
218 .gate
219 .tools
220 .iter()
221 .any(|tool_name| self.has_callable_tool(tool_name))
222 }
223}
224
225pub fn build_tool_catalog(input: ToolCatalogBuildInput) -> ToolCatalog {
226 let mut catalog = ToolCatalog::from_tool_manifests(input.tools, input.resolve_contract);
227 for contribution in input.contributions {
228 apply_contribution(&mut catalog, contribution);
229 }
230 catalog
231}
232
233fn apply_contribution(catalog: &mut ToolCatalog, contribution: ToolCatalogContribution) {
234 if contribution.remove.is_empty() {
235 return;
236 }
237 catalog
238 .tools
239 .retain(|tool| !contribution.remove.contains(&tool.manifest.name));
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToolActivation, ToolScheduling};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Creates a tool definition named `name` that is always active and schedulable in
    /// parallel.
    fn tool(name: &str) -> ToolDefinition {
        let input_schema = serde_json::json!({
            "type": "object",
            "properties": { "path": { "type": "string" } },
            "required": ["path"]
        });
        let output_schema = serde_json::json!({ "type": "string" });
        let mut definition = ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            format!("Tool {name}"),
            input_schema,
            output_schema,
        );
        definition.manifest.activation = ToolActivation::Always;
        definition.manifest.scheduling = ToolScheduling::Parallel;
        definition
    }

    /// Turns definitions into build input whose resolver serves the definitions' own
    /// contracts.
    fn build_input(
        tools: Vec<ToolDefinition>,
        contributions: Vec<ToolCatalogContribution>,
    ) -> ToolCatalogBuildInput {
        let mut contracts = BTreeMap::new();
        for definition in &tools {
            contracts.insert(
                definition.name().to_string(),
                Arc::new(definition.contract()),
            );
        }
        let resolver: ToolContractResolver =
            Arc::new(move |name: &str| contracts.get(name).cloned());
        let manifests = tools
            .into_iter()
            .map(|definition| definition.manifest())
            .collect();
        ToolCatalogBuildInput {
            tools: manifests,
            resolve_contract: Some(resolver),
            contributions,
        }
    }

    #[test]
    fn catalog_membership_is_flat_and_callable() {
        let input = build_input(
            vec![tool("read_file"), tool("grep"), tool("write_file")],
            Vec::new(),
        );
        let catalog = build_tool_catalog(input);

        assert_eq!(catalog.callable_tools().len(), 3);
        for present in ["read_file", "grep"] {
            assert!(catalog.has_callable_tool(present));
        }
        assert!(!catalog.has_callable_tool("absent"));
    }

    #[test]
    fn contributions_remove_members() {
        let removal = ToolCatalogContribution::remove_tools(["write_file"]);
        let input = build_input(vec![tool("read_file"), tool("write_file")], vec![removal]);
        let catalog = build_tool_catalog(input);

        assert!(catalog.has_callable_tool("read_file"));
        assert!(!catalog.has_callable_tool("write_file"));
        assert_eq!(catalog.callable_tools().len(), 1);
    }

    #[test]
    fn prompt_gate_requires_member_tool() {
        let catalog = build_tool_catalog(build_input(vec![tool("read_file")], Vec::new()));
        let candidates = vec![
            PromptContribution::guidance("Plain", "always"),
            PromptContribution::guidance("WithTool", "withtool").requires_tool("read_file"),
            PromptContribution::guidance("MissingTool", "missing").requires_tool("missing_tool"),
        ];

        let kept = catalog.filter_prompt_contributions(candidates);

        assert_eq!(kept.len(), 2);
        let titles: Vec<Option<&str>> = kept
            .iter()
            .map(|contribution| contribution.title.as_deref())
            .collect();
        assert!(titles.contains(&Some("Plain")));
        assert!(titles.contains(&Some("WithTool")));
    }

    #[test]
    fn model_specs_resolve_lazily() {
        let contract_resolutions = Arc::new(AtomicUsize::new(0));
        let callable = tool("read_file");
        // Take the manifest before `callable` moves into the resolver closure.
        let manifest = callable.manifest();
        let resolver: ToolContractResolver = {
            let resolver_count = Arc::clone(&contract_resolutions);
            Arc::new(move |name: &str| {
                resolver_count.fetch_add(1, Ordering::SeqCst);
                (name == "read_file").then(|| Arc::new(callable.contract()))
            })
        };
        let catalog = build_tool_catalog(ToolCatalogBuildInput {
            tools: vec![manifest],
            resolve_contract: Some(resolver),
            contributions: Vec::new(),
        });

        // Building the catalog must not resolve anything; the first request resolves once
        // and the second is served from the memoized specs.
        assert_eq!(contract_resolutions.load(Ordering::SeqCst), 0);
        for _ in 0..2 {
            assert_eq!(catalog.model_tool_specs().len(), 1);
            assert_eq!(contract_resolutions.load(Ordering::SeqCst), 1);
        }
    }

    #[test]
    fn tool_names_fingerprint_matches_prompt_hash() {
        let input = build_input(vec![tool("read_file"), tool("grep")], Vec::new());
        let catalog = build_tool_catalog(input);

        let cached = catalog.tool_names_fingerprint();
        let recomputed = prompt_tool_names_fingerprint(&catalog.tool_names());
        assert_eq!(cached, recomputed);
    }
}