// lash_sansio/tool_surface.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, OnceLock};
3
4use crate::llm::types::LlmToolSpec;
5use crate::{
6    ExecutionMode, PromptContribution, PromptFingerprint, ToolAvailability, ToolContract,
7    ToolDefinition, ToolManifest, prompt_tool_names_fingerprint,
8};
9
10pub type ToolContractResolver =
11    Arc<dyn Fn(&str) -> Option<Arc<ToolContract>> + Send + Sync + 'static>;
12
13#[derive(Clone)]
14pub struct ToolSurfaceBuildInput {
15    pub tools: Vec<ToolManifest>,
16    pub mode: ExecutionMode,
17    pub resolve_contract: Option<ToolContractResolver>,
18    pub contributions: Vec<ToolSurfaceContribution>,
19}
20
21#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
22pub struct ToolSurfaceContribution {
23    pub overrides: Vec<ToolSurfaceOverride>,
24    pub tool_list_notes: Vec<String>,
25}
26
27impl ToolSurfaceContribution {
28    pub fn is_empty(&self) -> bool {
29        self.overrides.is_empty() && self.tool_list_notes.is_empty()
30    }
31}
32
33#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
34pub struct ToolSurfaceOverride {
35    pub tool_name: String,
36    pub availability: Option<ToolAvailability>,
37}
38
39#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
40pub struct ToolSurfaceEntry {
41    pub manifest: ToolManifest,
42    pub availability: ToolAvailability,
43}
44
45#[derive(serde::Serialize, serde::Deserialize)]
46pub struct ToolSurface {
47    pub tools: Vec<ToolSurfaceEntry>,
48    pub tool_list_notes: Vec<String>,
49    #[serde(skip)]
50    resolve_contract: Option<ToolContractResolver>,
51    #[serde(skip)]
52    build_model_tool_specs: bool,
53    #[serde(skip)]
54    prompt_tool_docs: OnceLock<Arc<str>>,
55    #[serde(skip)]
56    model_tool_specs: OnceLock<Arc<Vec<LlmToolSpec>>>,
57    #[serde(skip)]
58    tool_names: OnceLock<Arc<Vec<String>>>,
59    #[serde(skip)]
60    tool_names_fingerprint: OnceLock<PromptFingerprint>,
61}
62
63impl Clone for ToolSurface {
64    fn clone(&self) -> Self {
65        let clone = Self {
66            tools: self.tools.clone(),
67            tool_list_notes: self.tool_list_notes.clone(),
68            resolve_contract: self.resolve_contract.clone(),
69            build_model_tool_specs: self.build_model_tool_specs,
70            prompt_tool_docs: OnceLock::new(),
71            model_tool_specs: OnceLock::new(),
72            tool_names: OnceLock::new(),
73            tool_names_fingerprint: OnceLock::new(),
74        };
75        if let Some(value) = self.prompt_tool_docs.get() {
76            let _ = clone.prompt_tool_docs.set(Arc::clone(value));
77        }
78        if let Some(value) = self.model_tool_specs.get() {
79            let _ = clone.model_tool_specs.set(Arc::clone(value));
80        }
81        if let Some(value) = self.tool_names.get() {
82            let _ = clone.tool_names.set(Arc::clone(value));
83        }
84        if let Some(value) = self.tool_names_fingerprint.get() {
85            let _ = clone.tool_names_fingerprint.set(*value);
86        }
87        clone
88    }
89}
90
91impl std::fmt::Debug for ToolSurface {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("ToolSurface")
94            .field("tools", &self.tools)
95            .field("tool_list_notes", &self.tool_list_notes)
96            .field("build_model_tool_specs", &self.build_model_tool_specs)
97            .finish_non_exhaustive()
98    }
99}
100
101impl Default for ToolSurface {
102    fn default() -> Self {
103        Self {
104            tools: Vec::new(),
105            tool_list_notes: Vec::new(),
106            resolve_contract: None,
107            build_model_tool_specs: false,
108            prompt_tool_docs: OnceLock::new(),
109            model_tool_specs: OnceLock::new(),
110            tool_names: OnceLock::new(),
111            tool_names_fingerprint: OnceLock::new(),
112        }
113    }
114}
115
116impl ToolSurface {
117    pub fn from_tool_definitions(tools: Vec<ToolDefinition>, mode: ExecutionMode) -> Self {
118        let contracts = tools
119            .iter()
120            .map(|tool| (tool.name.clone(), Arc::new(tool.contract())))
121            .collect();
122        Self::from_tools(
123            tools.into_iter().map(|tool| tool.manifest()).collect(),
124            mode,
125            contracts,
126        )
127    }
128
129    pub fn from_tools(
130        tools: Vec<ToolManifest>,
131        mode: ExecutionMode,
132        contracts: BTreeMap<String, Arc<ToolContract>>,
133    ) -> Self {
134        let resolver_contracts = Arc::new(contracts);
135        Self::from_tool_manifests(
136            tools,
137            &mode,
138            Some(Arc::new(move |name| resolver_contracts.get(name).cloned())),
139        )
140    }
141
142    fn from_tool_manifests(
143        tools: Vec<ToolManifest>,
144        mode: &ExecutionMode,
145        resolve_contract: Option<ToolContractResolver>,
146    ) -> Self {
147        Self {
148            tools: tools
149                .into_iter()
150                .map(|manifest| ToolSurfaceEntry {
151                    availability: manifest.effective_availability(mode),
152                    manifest,
153                })
154                .collect(),
155            tool_list_notes: Vec::new(),
156            resolve_contract,
157            build_model_tool_specs: mode.plugin_id() != "rlm",
158            prompt_tool_docs: OnceLock::new(),
159            model_tool_specs: OnceLock::new(),
160            tool_names: OnceLock::new(),
161            tool_names_fingerprint: OnceLock::new(),
162        }
163    }
164
165    pub fn callable_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
166        self.tools
167            .iter()
168            .filter(|tool| tool.availability.is_callable())
169            .map(|tool| &tool.manifest)
170    }
171
172    pub fn callable_tools(&self) -> Vec<ToolManifest> {
173        self.callable_tools_iter().cloned().collect()
174    }
175
176    pub fn showcased_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
177        self.tools
178            .iter()
179            .filter(|tool| tool.availability.is_showcased())
180            .map(|tool| &tool.manifest)
181    }
182
183    pub fn showcased_tools(&self) -> Vec<ToolManifest> {
184        self.showcased_tools_iter().cloned().collect()
185    }
186
187    pub fn searchable_tools_iter(&self) -> impl Iterator<Item = &ToolSurfaceEntry> {
188        self.tools
189            .iter()
190            .filter(|tool| tool.availability.is_searchable())
191    }
192
193    pub fn omitted_tools_iter(&self) -> impl Iterator<Item = &ToolSurfaceEntry> {
194        self.searchable_tools_iter()
195            .filter(|tool| !tool.availability.is_showcased())
196    }
197
198    pub fn has_callable_tool(&self, tool_name: &str) -> bool {
199        self.tools
200            .iter()
201            .any(|tool| tool.availability.is_callable() && tool.manifest.name == tool_name)
202    }
203
204    pub fn tool_availability(&self, tool_name: &str) -> Option<ToolAvailability> {
205        self.tools
206            .iter()
207            .find(|tool| tool.manifest.name == tool_name)
208            .map(|tool| tool.availability)
209    }
210
211    pub fn tool_names(&self) -> Arc<Vec<String>> {
212        Arc::clone(self.tool_names.get_or_init(|| {
213            Arc::new(
214                self.tools
215                    .iter()
216                    .filter(|tool| tool.availability.is_callable())
217                    .map(|tool| tool.manifest.name.clone())
218                    .collect(),
219            )
220        }))
221    }
222
223    pub fn tool_names_fingerprint(&self) -> PromptFingerprint {
224        *self
225            .tool_names_fingerprint
226            .get_or_init(|| prompt_tool_names_fingerprint(&self.tool_names()))
227    }
228
229    pub fn omitted_tool_count(&self) -> usize {
230        self.omitted_tools_iter().count()
231    }
232
233    pub fn model_tool_specs(&self) -> Arc<Vec<LlmToolSpec>> {
234        if !self.build_model_tool_specs {
235            return Arc::clone(self.model_tool_specs.get_or_init(|| Arc::new(Vec::new())));
236        }
237        Arc::clone(self.model_tool_specs.get_or_init(|| {
238            Arc::new(
239                self.tools
240                    .iter()
241                    .filter(|tool| tool.availability.is_callable())
242                    .filter_map(|tool| {
243                        self.resolve_contract(&tool.manifest.name)
244                            .map(|contract| contract.model_tool(&tool.manifest))
245                    })
246                    .map(|model_tool| LlmToolSpec {
247                        name: model_tool.name,
248                        description: model_tool.description,
249                        input_schema: model_tool.input_schema,
250                        output_schema: model_tool.output_schema,
251                        input_schema_projections: model_tool.input_schema_projections,
252                        output_schema_projections: model_tool.output_schema_projections,
253                    })
254                    .collect(),
255            )
256        }))
257    }
258
259    pub fn prompt_tool_docs(&self) -> &str {
260        self.prompt_tool_docs
261            .get_or_init(|| Arc::from(self.rendered_prompt_tool_docs()))
262            .as_ref()
263    }
264
265    fn resolve_contract(&self, tool_name: &str) -> Option<Arc<ToolContract>> {
266        self.resolve_contract
267            .as_ref()
268            .and_then(|resolve| resolve(tool_name))
269    }
270
271    fn rendered_prompt_tool_docs(&self) -> String {
272        let mut docs = self
273            .tools
274            .iter()
275            .filter(|tool| tool.availability.is_showcased())
276            .filter_map(|tool| {
277                self.resolve_contract(&tool.manifest.name)
278                    .map(|contract| contract.compact_contract(&tool.manifest).render_markdown())
279            })
280            .collect::<Vec<_>>()
281            .join("\n\n");
282        for note in &self.tool_list_notes {
283            let note = note.trim();
284            if note.is_empty() {
285                continue;
286            }
287            if !docs.is_empty() {
288                docs.push_str("\n\n");
289            }
290            docs.push_str(note);
291        }
292        docs
293    }
294
295    pub fn filter_prompt_contributions(
296        &self,
297        contributions: Vec<PromptContribution>,
298    ) -> Vec<PromptContribution> {
299        contributions
300            .into_iter()
301            .filter(|contribution| self.includes_prompt_contribution(contribution))
302            .collect()
303    }
304
305    fn includes_prompt_contribution(&self, contribution: &PromptContribution) -> bool {
306        if contribution.gate.is_empty() {
307            return true;
308        }
309        contribution.gate.tools.iter().any(|tool_name| {
310            self.tool_availability(tool_name)
311                .is_some_and(|availability| availability >= contribution.gate.minimum_availability)
312        })
313    }
314}
315
316pub fn build_tool_surface(input: ToolSurfaceBuildInput) -> ToolSurface {
317    let mode = input.mode;
318    let mut surface = ToolSurface::from_tool_manifests(input.tools, &mode, input.resolve_contract);
319    for contribution in input.contributions {
320        apply_contribution(&mut surface, contribution);
321    }
322    surface
323}
324
325fn apply_contribution(surface: &mut ToolSurface, contribution: ToolSurfaceContribution) {
326    for override_ in contribution.overrides {
327        if let Some(tool) = surface
328            .tools
329            .iter_mut()
330            .find(|tool| tool.manifest.name == override_.tool_name)
331            && let Some(availability) = override_.availability
332        {
333            tool.availability = availability;
334        }
335    }
336
337    surface.tool_list_notes.extend(
338        contribution
339            .tool_list_notes
340            .into_iter()
341            .map(|note| note.trim().to_string())
342            .filter(|note| !note.is_empty()),
343    );
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToolActivation, ToolAvailabilityConfig, ToolExecutionMode};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a raw test tool named `name`.
    ///
    /// The tool takes one required `path` string input and returns a string.
    /// Availability is configured via `ToolAvailabilityConfig::same`,
    /// activation is `Always`, and the execution mode is `Parallel`.
    fn tool(name: &str, availability: ToolAvailability) -> ToolDefinition {
        let input_schema = serde_json::json!({
            "type": "object",
            "properties": { "path": { "type": "string" } },
            "required": ["path"]
        });
        let output_schema = serde_json::json!({ "type": "string" });
        let description = format!("Tool {name}");

        let mut definition = ToolDefinition::raw(name, description, input_schema, output_schema);
        definition.availability = ToolAvailabilityConfig::same(availability);
        definition.activation = ToolActivation::Always;
        definition.execution_mode = ToolExecutionMode::Parallel;
        definition
    }

    /// Packages `tools` as a build input whose resolver serves each tool's own
    /// contract, looked up by name.
    fn build_input(
        tools: Vec<ToolDefinition>,
        mode: crate::ExecutionMode,
        contributions: Vec<ToolSurfaceContribution>,
    ) -> ToolSurfaceBuildInput {
        let mut contracts = BTreeMap::new();
        for definition in &tools {
            contracts.insert(definition.name.clone(), Arc::new(definition.contract()));
        }
        let manifests = tools.iter().map(|definition| definition.manifest()).collect();
        let resolver: ToolContractResolver =
            Arc::new(move |name: &str| contracts.get(name).cloned());

        ToolSurfaceBuildInput {
            tools: manifests,
            mode,
            resolve_contract: Some(resolver),
            contributions,
        }
    }

    #[test]
    fn surface_splits_callable_and_showcased_tools() {
        let tools = vec![
            tool("search_tools", ToolAvailability::Showcased),
            tool("read_file", ToolAvailability::Showcased),
            tool("grep", ToolAvailability::Callable),
            tool("privileged_tool", ToolAvailability::Searchable),
        ];
        let input = build_input(tools, crate::ExecutionMode::new("test_mode"), Vec::new());
        let surface = build_tool_surface(input);

        assert_eq!(surface.callable_tools().len(), 3);
        assert_eq!(surface.showcased_tools().len(), 2);
        assert_eq!(surface.omitted_tool_count(), 2);
        assert!(!surface.prompt_tool_docs().contains("Catalogued tools"));
    }

    #[test]
    fn explicit_contributions_override_availability() {
        let turn_off_read_file = ToolSurfaceContribution {
            overrides: vec![ToolSurfaceOverride {
                tool_name: "read_file".to_string(),
                availability: Some(ToolAvailability::Off),
            }],
            tool_list_notes: vec!["custom note".to_string()],
        };
        let surface = build_tool_surface(build_input(
            vec![tool("read_file", ToolAvailability::Showcased)],
            crate::ExecutionMode::new("test_mode"),
            vec![turn_off_read_file],
        ));

        let read_file = surface
            .tools
            .iter()
            .find(|entry| entry.manifest.name == "read_file")
            .expect("read_file present");
        assert_eq!(read_file.availability, ToolAvailability::Off);
        assert!(surface.tool_list_notes.iter().any(|note| note == "custom note"));
    }

    #[test]
    fn prompt_gate_requires_matching_tool_availability() {
        let surface = build_tool_surface(build_input(
            vec![tool("search_tools", ToolAvailability::Showcased)],
            crate::ExecutionMode::standard(),
            Vec::new(),
        ));
        let candidates = vec![
            PromptContribution::guidance("Plain", "always"),
            PromptContribution::guidance("Discovery", "discover")
                .requires_tool("search_tools", ToolAvailability::Showcased),
            PromptContribution::guidance("Off", "off")
                .requires_tool("missing_tool", ToolAvailability::Callable),
        ];

        let kept = surface.filter_prompt_contributions(candidates);
        let kept_titles: Vec<Option<&str>> = kept
            .iter()
            .map(|contribution| contribution.title.as_deref())
            .collect();

        assert_eq!(kept.len(), 2);
        assert!(kept_titles.contains(&Some("Plain")));
        assert!(kept_titles.contains(&Some("Discovery")));
    }

    #[test]
    fn rlm_surface_does_not_resolve_searchable_only_contracts() {
        let resolutions = Arc::new(AtomicUsize::new(0));
        let searchable = tool("large_schema", ToolAvailability::Searchable);
        let showcased = tool("search_tools", ToolAvailability::Showcased);
        let manifests = vec![searchable.manifest(), showcased.manifest()];
        let counter = Arc::clone(&resolutions);
        let resolver: ToolContractResolver = Arc::new(move |name: &str| {
            counter.fetch_add(1, Ordering::SeqCst);
            let definition = match name {
                "large_schema" => &searchable,
                "search_tools" => &showcased,
                _ => return None,
            };
            Some(Arc::new(definition.contract()))
        });
        let surface = build_tool_surface(ToolSurfaceBuildInput {
            tools: manifests,
            mode: crate::ExecutionMode::new("rlm"),
            resolve_contract: Some(resolver),
            contributions: Vec::new(),
        });
        let resolution_count = || resolutions.load(Ordering::SeqCst);

        assert_eq!(
            surface.tool_availability("large_schema"),
            Some(ToolAvailability::Searchable)
        );
        assert_eq!(resolution_count(), 0);
        assert!(!surface.prompt_tool_docs().contains("large_schema"));
        assert_eq!(resolution_count(), 1);
    }

    #[test]
    fn rlm_callable_only_surface_does_not_resolve_contracts_for_specs() {
        let resolutions = Arc::new(AtomicUsize::new(0));
        let callable = tool("large_callable", ToolAvailability::Callable);
        let manifests = vec![callable.manifest()];
        let counter = Arc::clone(&resolutions);
        let resolver: ToolContractResolver = Arc::new(move |name: &str| {
            counter.fetch_add(1, Ordering::SeqCst);
            if name == "large_callable" {
                Some(Arc::new(callable.contract()))
            } else {
                None
            }
        });
        let surface = build_tool_surface(ToolSurfaceBuildInput {
            tools: manifests,
            mode: crate::ExecutionMode::new("rlm"),
            resolve_contract: Some(resolver),
            contributions: Vec::new(),
        });

        assert_eq!(*surface.tool_names(), vec!["large_callable".to_string()]);
        assert_eq!(surface.model_tool_specs().len(), 0);
        assert_eq!(surface.prompt_tool_docs(), "");
        assert_eq!(resolutions.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn standard_surface_resolves_model_specs_lazily() {
        let resolutions = Arc::new(AtomicUsize::new(0));
        let callable = tool("read_file", ToolAvailability::Callable);
        let manifests = vec![callable.manifest()];
        let counter = Arc::clone(&resolutions);
        let resolver: ToolContractResolver = Arc::new(move |name: &str| {
            counter.fetch_add(1, Ordering::SeqCst);
            if name == "read_file" {
                Some(Arc::new(callable.contract()))
            } else {
                None
            }
        });
        let surface = build_tool_surface(ToolSurfaceBuildInput {
            tools: manifests,
            mode: crate::ExecutionMode::standard(),
            resolve_contract: Some(resolver),
            contributions: Vec::new(),
        });

        // Building the surface resolves nothing...
        assert_eq!(resolutions.load(Ordering::SeqCst), 0);
        // ...the first request resolves once, and the second reuses the cache.
        for _ in 0..2 {
            assert_eq!(surface.model_tool_specs().len(), 1);
            assert_eq!(resolutions.load(Ordering::SeqCst), 1);
        }
    }

    #[test]
    fn tool_names_fingerprint_matches_prompt_hash() {
        let tools = vec![
            tool("read_file", ToolAvailability::Callable),
            tool("search_tools", ToolAvailability::Showcased),
        ];
        let surface = build_tool_surface(build_input(
            tools,
            crate::ExecutionMode::standard(),
            Vec::new(),
        ));

        let expected = prompt_tool_names_fingerprint(&surface.tool_names());
        assert_eq!(surface.tool_names_fingerprint(), expected);
    }
}