1use std::collections::BTreeMap;
2use std::sync::{Arc, OnceLock};
3
4use crate::llm::types::LlmToolSpec;
5use crate::{
6 PromptContribution, PromptFingerprint, ToolContract, ToolDefinition, ToolManifest,
7 prompt_tool_names_fingerprint,
8};
9
10pub type ToolContractResolver =
11 Arc<dyn Fn(&str) -> Option<Arc<ToolContract>> + Send + Sync + 'static>;
12
13#[derive(Clone)]
14pub struct ToolCatalogBuildInput {
15 pub tools: Vec<ToolManifest>,
16 pub resolve_contract: Option<ToolContractResolver>,
17 pub contributions: Vec<ToolCatalogContribution>,
18}
19
20#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
25pub struct ToolCatalogContribution {
26 pub remove: Vec<String>,
28}
29
30impl ToolCatalogContribution {
31 pub fn is_empty(&self) -> bool {
32 self.remove.is_empty()
33 }
34
35 pub fn remove_tools(tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
36 Self {
37 remove: tools.into_iter().map(Into::into).collect(),
38 }
39 }
40}
41
42#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
43pub struct ToolCatalogEntry {
44 pub manifest: ToolManifest,
45}
46
47#[derive(serde::Serialize, serde::Deserialize)]
48pub struct ToolCatalog {
49 pub tools: Vec<ToolCatalogEntry>,
50 #[serde(skip)]
51 resolve_contract: Option<ToolContractResolver>,
52 #[serde(skip)]
53 model_tool_specs: OnceLock<Arc<Vec<LlmToolSpec>>>,
54 #[serde(skip)]
55 tool_names: OnceLock<Arc<Vec<String>>>,
56 #[serde(skip)]
57 tool_names_fingerprint: OnceLock<PromptFingerprint>,
58}
59
60impl Clone for ToolCatalog {
61 fn clone(&self) -> Self {
62 let clone = Self {
63 tools: self.tools.clone(),
64 resolve_contract: self.resolve_contract.clone(),
65 model_tool_specs: OnceLock::new(),
66 tool_names: OnceLock::new(),
67 tool_names_fingerprint: OnceLock::new(),
68 };
69 if let Some(value) = self.model_tool_specs.get() {
70 let _ = clone.model_tool_specs.set(Arc::clone(value));
71 }
72 if let Some(value) = self.tool_names.get() {
73 let _ = clone.tool_names.set(Arc::clone(value));
74 }
75 if let Some(value) = self.tool_names_fingerprint.get() {
76 let _ = clone.tool_names_fingerprint.set(*value);
77 }
78 clone
79 }
80}
81
82impl std::fmt::Debug for ToolCatalog {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.debug_struct("ToolCatalog")
85 .field("tools", &self.tools)
86 .finish_non_exhaustive()
87 }
88}
89
90impl Default for ToolCatalog {
91 fn default() -> Self {
92 Self {
93 tools: Vec::new(),
94 resolve_contract: None,
95 model_tool_specs: OnceLock::new(),
96 tool_names: OnceLock::new(),
97 tool_names_fingerprint: OnceLock::new(),
98 }
99 }
100}
101
102impl ToolCatalog {
103 pub fn from_tool_definitions(tools: Vec<ToolDefinition>) -> Self {
104 let contracts = tools
105 .iter()
106 .map(|tool| (tool.name().to_string(), Arc::new(tool.contract())))
107 .collect();
108 Self::from_tools(
109 tools.into_iter().map(|tool| tool.manifest()).collect(),
110 contracts,
111 )
112 }
113
114 pub fn from_tools(
115 tools: Vec<ToolManifest>,
116 contracts: BTreeMap<String, Arc<ToolContract>>,
117 ) -> Self {
118 let resolver_contracts = Arc::new(contracts);
119 Self::from_tool_manifests(
120 tools,
121 Some(Arc::new(move |name| resolver_contracts.get(name).cloned())),
122 )
123 }
124
125 fn from_tool_manifests(
126 tools: Vec<ToolManifest>,
127 resolve_contract: Option<ToolContractResolver>,
128 ) -> Self {
129 Self {
130 tools: tools
131 .into_iter()
132 .map(|manifest| ToolCatalogEntry { manifest })
133 .collect(),
134 resolve_contract,
135 model_tool_specs: OnceLock::new(),
136 tool_names: OnceLock::new(),
137 tool_names_fingerprint: OnceLock::new(),
138 }
139 }
140
141 pub fn callable_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
143 self.tools.iter().map(|tool| &tool.manifest)
144 }
145
146 pub fn callable_tools(&self) -> Vec<ToolManifest> {
147 self.callable_tools_iter().cloned().collect()
148 }
149
150 pub fn has_callable_tool(&self, tool_name: &str) -> bool {
153 self.tools
154 .iter()
155 .any(|tool| tool.manifest.name == tool_name)
156 }
157
158 pub fn tool_names(&self) -> Arc<Vec<String>> {
159 Arc::clone(self.tool_names.get_or_init(|| {
160 Arc::new(
161 self.tools
162 .iter()
163 .map(|tool| tool.manifest.name.clone())
164 .collect(),
165 )
166 }))
167 }
168
169 pub fn tool_names_fingerprint(&self) -> PromptFingerprint {
170 *self
171 .tool_names_fingerprint
172 .get_or_init(|| prompt_tool_names_fingerprint(&self.tool_names()))
173 }
174
175 pub fn model_tool_specs(&self) -> Arc<Vec<LlmToolSpec>> {
176 Arc::clone(self.model_tool_specs.get_or_init(|| {
177 Arc::new(
178 self.tools
179 .iter()
180 .filter_map(|tool| {
181 self.resolve_contract(&tool.manifest.name)
182 .map(|contract| contract.model_tool(&tool.manifest))
183 })
184 .map(|model_tool| LlmToolSpec {
185 name: model_tool.name,
186 description: model_tool.description,
187 input_schema: model_tool.input_schema,
188 output_schema: model_tool.output_schema,
189 })
190 .collect(),
191 )
192 }))
193 }
194
195 pub fn resolve_contract(&self, tool_name: &str) -> Option<Arc<ToolContract>> {
196 self.resolve_contract
197 .as_ref()
198 .and_then(|resolve| resolve(tool_name))
199 }
200
201 pub fn filter_prompt_contributions(
202 &self,
203 contributions: Vec<PromptContribution>,
204 ) -> Vec<PromptContribution> {
205 contributions
206 .into_iter()
207 .filter(|contribution| self.includes_prompt_contribution(contribution))
208 .collect()
209 }
210
211 fn includes_prompt_contribution(&self, contribution: &PromptContribution) -> bool {
212 if contribution.gate.is_empty() {
213 return true;
214 }
215 contribution
216 .gate
217 .tools
218 .iter()
219 .any(|tool_name| self.has_callable_tool(tool_name))
220 }
221}
222
223pub fn build_tool_catalog(input: ToolCatalogBuildInput) -> ToolCatalog {
224 let mut catalog = ToolCatalog::from_tool_manifests(input.tools, input.resolve_contract);
225 for contribution in input.contributions {
226 apply_contribution(&mut catalog, contribution);
227 }
228 catalog
229}
230
231fn apply_contribution(catalog: &mut ToolCatalog, contribution: ToolCatalogContribution) {
232 if contribution.remove.is_empty() {
233 return;
234 }
235 catalog
236 .tools
237 .retain(|tool| !contribution.remove.contains(&tool.manifest.name));
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToolActivation, ToolScheduling};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Raw tool `name` that requires a string `path` input and returns a string;
    /// always active and scheduled in parallel.
    fn tool(name: &str) -> ToolDefinition {
        let input_schema = serde_json::json!({
            "type": "object",
            "properties": { "path": { "type": "string" } },
            "required": ["path"]
        });
        let output_schema = serde_json::json!({ "type": "string" });
        let mut definition = ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            format!("Tool {name}"),
            input_schema,
            output_schema,
        );
        definition.manifest.activation = ToolActivation::Always;
        definition.manifest.scheduling = ToolScheduling::Parallel;
        definition
    }

    /// Build input whose resolver answers from the definitions' own contracts.
    fn build_input(
        tools: Vec<ToolDefinition>,
        contributions: Vec<ToolCatalogContribution>,
    ) -> ToolCatalogBuildInput {
        let mut contracts = BTreeMap::new();
        for definition in &tools {
            contracts.insert(definition.name().to_string(), Arc::new(definition.contract()));
        }
        let manifests: Vec<ToolManifest> = tools
            .iter()
            .map(|definition| definition.manifest())
            .collect();
        ToolCatalogBuildInput {
            tools: manifests,
            resolve_contract: Some(Arc::new(move |name: &str| contracts.get(name).cloned())),
            contributions,
        }
    }

    #[test]
    fn catalog_membership_is_flat_and_callable() {
        let definitions = vec![tool("read_file"), tool("grep"), tool("write_file")];
        let catalog = build_tool_catalog(build_input(definitions, Vec::new()));

        assert_eq!(catalog.callable_tools().len(), 3);
        for member in ["read_file", "grep"] {
            assert!(catalog.has_callable_tool(member));
        }
        assert!(!catalog.has_callable_tool("absent"));
    }

    #[test]
    fn contributions_remove_members() {
        let removal = ToolCatalogContribution::remove_tools(["write_file"]);
        let catalog = build_tool_catalog(build_input(
            vec![tool("read_file"), tool("write_file")],
            vec![removal],
        ));

        assert!(catalog.has_callable_tool("read_file"));
        assert!(!catalog.has_callable_tool("write_file"));
        assert_eq!(catalog.callable_tools().len(), 1);
    }

    #[test]
    fn prompt_gate_requires_member_tool() {
        let catalog = build_tool_catalog(build_input(vec![tool("read_file")], Vec::new()));
        let candidates = vec![
            PromptContribution::guidance("Plain", "always"),
            PromptContribution::guidance("WithTool", "withtool").requires_tool("read_file"),
            PromptContribution::guidance("MissingTool", "missing").requires_tool("missing_tool"),
        ];

        let kept = catalog.filter_prompt_contributions(candidates);
        let has_title = |title: &str| {
            kept.iter()
                .any(|contribution| contribution.title.as_deref() == Some(title))
        };

        assert_eq!(kept.len(), 2);
        assert!(has_title("Plain"));
        assert!(has_title("WithTool"));
    }

    #[test]
    fn model_specs_resolve_lazily() {
        let resolutions = Arc::new(AtomicUsize::new(0));
        let callable = tool("read_file");
        let manifest = callable.manifest();
        let counter = Arc::clone(&resolutions);
        let resolver: ToolContractResolver = Arc::new(move |name: &str| {
            counter.fetch_add(1, Ordering::SeqCst);
            (name == "read_file").then(|| Arc::new(callable.contract()))
        });
        let catalog = build_tool_catalog(ToolCatalogBuildInput {
            tools: vec![manifest],
            resolve_contract: Some(resolver),
            contributions: Vec::new(),
        });
        let resolved = || resolutions.load(Ordering::SeqCst);

        // Building the catalog must not touch the resolver; the first spec request
        // resolves once and the second is served from the cache.
        assert_eq!(resolved(), 0);
        assert_eq!(catalog.model_tool_specs().len(), 1);
        assert_eq!(resolved(), 1);
        assert_eq!(catalog.model_tool_specs().len(), 1);
        assert_eq!(resolved(), 1);
    }

    #[test]
    fn tool_names_fingerprint_matches_prompt_hash() {
        let catalog = build_tool_catalog(build_input(
            vec![tool("read_file"), tool("grep")],
            Vec::new(),
        ));
        let expected = prompt_tool_names_fingerprint(&catalog.tool_names());

        assert_eq!(catalog.tool_names_fingerprint(), expected);
    }
}