Skip to main content

lash_sansio/
tool_surface.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::sync::{Arc, OnceLock};
3
4use crate::llm::types::LlmToolSpec;
5use crate::{
6    PromptContribution, PromptFingerprint, ToolAvailability, ToolContract, ToolDefinition,
7    ToolManifest, prompt_tool_names_fingerprint,
8};
9
10pub type ToolContractResolver =
11    Arc<dyn Fn(&str) -> Option<Arc<ToolContract>> + Send + Sync + 'static>;
12
13#[derive(Clone)]
14pub struct ToolSurfaceBuildInput {
15    pub tools: Vec<ToolManifest>,
16    pub resolve_contract: Option<ToolContractResolver>,
17    pub contributions: Vec<ToolSurfaceContribution>,
18}
19
20#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
21pub struct ToolSurfaceContribution {
22    pub overrides: Vec<ToolSurfaceOverride>,
23    pub tool_list_notes: Vec<String>,
24}
25
26impl ToolSurfaceContribution {
27    pub fn is_empty(&self) -> bool {
28        self.overrides.is_empty() && self.tool_list_notes.is_empty()
29    }
30}
31
32#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
33pub struct ToolSurfaceOverride {
34    pub tool_name: String,
35    pub availability: Option<ToolAvailability>,
36}
37
38#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
39pub struct ToolSurfaceEntry {
40    pub manifest: ToolManifest,
41    pub availability: ToolAvailability,
42}
43
44#[derive(serde::Serialize, serde::Deserialize)]
45pub struct ToolSurface {
46    pub tools: Vec<ToolSurfaceEntry>,
47    pub tool_list_notes: Vec<String>,
48    #[serde(skip)]
49    resolve_contract: Option<ToolContractResolver>,
50    #[serde(skip)]
51    prompt_tool_docs: OnceLock<Arc<str>>,
52    #[serde(skip)]
53    model_tool_specs: OnceLock<Arc<Vec<LlmToolSpec>>>,
54    #[serde(skip)]
55    tool_names: OnceLock<Arc<Vec<String>>>,
56    #[serde(skip)]
57    tool_names_fingerprint: OnceLock<PromptFingerprint>,
58}
59
60impl Clone for ToolSurface {
61    fn clone(&self) -> Self {
62        let clone = Self {
63            tools: self.tools.clone(),
64            tool_list_notes: self.tool_list_notes.clone(),
65            resolve_contract: self.resolve_contract.clone(),
66            prompt_tool_docs: OnceLock::new(),
67            model_tool_specs: OnceLock::new(),
68            tool_names: OnceLock::new(),
69            tool_names_fingerprint: OnceLock::new(),
70        };
71        if let Some(value) = self.prompt_tool_docs.get() {
72            let _ = clone.prompt_tool_docs.set(Arc::clone(value));
73        }
74        if let Some(value) = self.model_tool_specs.get() {
75            let _ = clone.model_tool_specs.set(Arc::clone(value));
76        }
77        if let Some(value) = self.tool_names.get() {
78            let _ = clone.tool_names.set(Arc::clone(value));
79        }
80        if let Some(value) = self.tool_names_fingerprint.get() {
81            let _ = clone.tool_names_fingerprint.set(*value);
82        }
83        clone
84    }
85}
86
87impl std::fmt::Debug for ToolSurface {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("ToolSurface")
90            .field("tools", &self.tools)
91            .field("tool_list_notes", &self.tool_list_notes)
92            .finish_non_exhaustive()
93    }
94}
95
96impl Default for ToolSurface {
97    fn default() -> Self {
98        Self {
99            tools: Vec::new(),
100            tool_list_notes: Vec::new(),
101            resolve_contract: None,
102            prompt_tool_docs: OnceLock::new(),
103            model_tool_specs: OnceLock::new(),
104            tool_names: OnceLock::new(),
105            tool_names_fingerprint: OnceLock::new(),
106        }
107    }
108}
109
110impl ToolSurface {
111    pub fn from_tool_definitions(tools: Vec<ToolDefinition>) -> Self {
112        let contracts = tools
113            .iter()
114            .map(|tool| (tool.name().to_string(), Arc::new(tool.contract())))
115            .collect();
116        Self::from_tools(
117            tools.into_iter().map(|tool| tool.manifest()).collect(),
118            contracts,
119        )
120    }
121
122    pub fn from_tools(
123        tools: Vec<ToolManifest>,
124        contracts: BTreeMap<String, Arc<ToolContract>>,
125    ) -> Self {
126        let resolver_contracts = Arc::new(contracts);
127        Self::from_tool_manifests(
128            tools,
129            Some(Arc::new(move |name| resolver_contracts.get(name).cloned())),
130        )
131    }
132
133    fn from_tool_manifests(
134        tools: Vec<ToolManifest>,
135        resolve_contract: Option<ToolContractResolver>,
136    ) -> Self {
137        Self {
138            tools: tools
139                .into_iter()
140                .map(|manifest| ToolSurfaceEntry {
141                    availability: manifest.effective_availability(),
142                    manifest,
143                })
144                .collect(),
145            tool_list_notes: Vec::new(),
146            resolve_contract,
147            prompt_tool_docs: OnceLock::new(),
148            model_tool_specs: OnceLock::new(),
149            tool_names: OnceLock::new(),
150            tool_names_fingerprint: OnceLock::new(),
151        }
152    }
153
154    pub fn callable_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
155        self.tools
156            .iter()
157            .filter(|tool| tool.availability.is_callable())
158            .map(|tool| &tool.manifest)
159    }
160
161    pub fn callable_tools(&self) -> Vec<ToolManifest> {
162        self.callable_tools_iter().cloned().collect()
163    }
164
165    pub fn showcased_tools_iter(&self) -> impl Iterator<Item = &ToolManifest> {
166        self.tools
167            .iter()
168            .filter(|tool| tool.availability.is_showcased())
169            .map(|tool| &tool.manifest)
170    }
171
172    pub fn showcased_tools(&self) -> Vec<ToolManifest> {
173        self.showcased_tools_iter().cloned().collect()
174    }
175
176    pub fn searchable_tools_iter(&self) -> impl Iterator<Item = &ToolSurfaceEntry> {
177        self.tools
178            .iter()
179            .filter(|tool| tool.availability.is_searchable())
180    }
181
182    pub fn omitted_tools_iter(&self) -> impl Iterator<Item = &ToolSurfaceEntry> {
183        self.searchable_tools_iter()
184            .filter(|tool| !tool.availability.is_showcased())
185    }
186
187    pub fn has_callable_tool(&self, tool_name: &str) -> bool {
188        self.tools
189            .iter()
190            .any(|tool| tool.availability.is_callable() && tool.manifest.name == tool_name)
191    }
192
193    pub fn tool_availability(&self, tool_name: &str) -> Option<ToolAvailability> {
194        self.tools
195            .iter()
196            .find(|tool| tool.manifest.name == tool_name)
197            .map(|tool| tool.availability)
198    }
199
200    pub fn tool_names(&self) -> Arc<Vec<String>> {
201        Arc::clone(self.tool_names.get_or_init(|| {
202            Arc::new(
203                self.tools
204                    .iter()
205                    .filter(|tool| tool.availability.is_callable())
206                    .map(|tool| tool.manifest.name.clone())
207                    .collect(),
208            )
209        }))
210    }
211
212    pub fn tool_names_fingerprint(&self) -> PromptFingerprint {
213        *self
214            .tool_names_fingerprint
215            .get_or_init(|| prompt_tool_names_fingerprint(&self.tool_names()))
216    }
217
218    pub fn omitted_tool_count(&self) -> usize {
219        self.omitted_tools_iter().count()
220    }
221
222    pub fn model_tool_specs(&self) -> Arc<Vec<LlmToolSpec>> {
223        Arc::clone(self.model_tool_specs.get_or_init(|| {
224            Arc::new(
225                self.tools
226                    .iter()
227                    .filter(|tool| tool.availability.is_callable())
228                    .filter_map(|tool| {
229                        self.resolve_contract(&tool.manifest.name)
230                            .map(|contract| contract.model_tool(&tool.manifest))
231                    })
232                    .map(|model_tool| LlmToolSpec {
233                        name: model_tool.name,
234                        description: model_tool.description,
235                        input_schema: model_tool.input_schema,
236                        output_schema: model_tool.output_schema,
237                        input_schema_projections: model_tool.input_schema_projections,
238                        output_schema_projections: model_tool.output_schema_projections,
239                    })
240                    .collect(),
241            )
242        }))
243    }
244
245    pub fn prompt_tool_docs(&self) -> &str {
246        self.prompt_tool_docs
247            .get_or_init(|| Arc::from(self.rendered_prompt_tool_docs()))
248            .as_ref()
249    }
250
251    fn resolve_contract(&self, tool_name: &str) -> Option<Arc<ToolContract>> {
252        self.resolve_contract
253            .as_ref()
254            .and_then(|resolve| resolve(tool_name))
255    }
256
257    fn rendered_prompt_tool_docs(&self) -> String {
258        let mut docs = self
259            .tools
260            .iter()
261            .filter(|tool| tool.availability.is_showcased())
262            .filter_map(|tool| {
263                self.resolve_contract(&tool.manifest.name)
264                    .map(|contract| contract.compact_contract(&tool.manifest).render_markdown())
265            })
266            .collect::<Vec<_>>()
267            .join("\n\n");
268        for note in &self.tool_list_notes {
269            let note = note.trim();
270            if note.is_empty() {
271                continue;
272            }
273            if !docs.is_empty() {
274                docs.push_str("\n\n");
275            }
276            docs.push_str(note);
277        }
278        docs
279    }
280
281    pub fn filter_prompt_contributions(
282        &self,
283        contributions: Vec<PromptContribution>,
284    ) -> Vec<PromptContribution> {
285        contributions
286            .into_iter()
287            .filter(|contribution| self.includes_prompt_contribution(contribution))
288            .collect()
289    }
290
291    fn includes_prompt_contribution(&self, contribution: &PromptContribution) -> bool {
292        if contribution.gate.is_empty() {
293            return true;
294        }
295        contribution.gate.tools.iter().any(|tool_name| {
296            self.tool_availability(tool_name)
297                .is_some_and(|availability| availability >= contribution.gate.minimum_availability)
298        })
299    }
300}
301
302pub fn build_tool_surface(input: ToolSurfaceBuildInput) -> ToolSurface {
303    let mut surface = ToolSurface::from_tool_manifests(input.tools, input.resolve_contract);
304    for contribution in input.contributions {
305        apply_contribution(&mut surface, contribution);
306    }
307    validate_agent_surface(&surface.tools);
308    surface
309}
310
311fn apply_contribution(surface: &mut ToolSurface, contribution: ToolSurfaceContribution) {
312    for override_ in contribution.overrides {
313        if let Some(tool) = surface
314            .tools
315            .iter_mut()
316            .find(|tool| tool.manifest.name == override_.tool_name)
317            && let Some(availability) = override_.availability
318        {
319            tool.availability = availability;
320        }
321    }
322
323    surface.tool_list_notes.extend(
324        contribution
325            .tool_list_notes
326            .into_iter()
327            .map(|note| note.trim().to_string())
328            .filter(|note| !note.is_empty()),
329    );
330}
331
332fn validate_agent_surface(tools: &[ToolSurfaceEntry]) {
333    let mut seen = BTreeSet::new();
334    for tool in tools.iter().filter(|tool| tool.availability.is_callable()) {
335        let identity = tool
336            .manifest
337            .agent_surface
338            .executable_for(&tool.manifest.name);
339        validate_module_segments(&identity.module_path, &tool.manifest.name);
340        validate_module_segment(&identity.operation, &tool.manifest.name, "operation");
341        let key = format!("{}.{}", identity.module_path.join("."), identity.operation);
342        assert!(
343            seen.insert(key.clone()),
344            "duplicate agent module operation path `{key}`"
345        );
346    }
347}
348
349fn validate_module_segments(segments: &[String], tool_name: &str) {
350    assert!(
351        !segments.is_empty(),
352        "tool `{tool_name}` has empty agent module path"
353    );
354    for segment in segments {
355        validate_module_segment(segment, tool_name, "module path segment");
356    }
357}
358
359fn validate_module_segment(segment: &str, tool_name: &str, field: &str) {
360    assert!(
361        is_module_segment(segment),
362        "tool `{tool_name}` has invalid agent {field} `{segment}`"
363    );
364}
365
/// Returns `true` when `segment` is a valid agent module path segment or operation name.
///
/// A valid segment consists only of ASCII lowercase letters, ASCII digits and
/// underscores, and contains at least one lowercase letter. That rule also rejects
/// the empty string.
///
/// Surrounding whitespace is deliberately *not* trimmed away. Callers use the segment
/// verbatim when building the `module.path.operation` key. Accepting `" read"` would let
/// it validate alongside `"read"` while producing a distinct key, which bypasses duplicate
/// detection and leaks whitespace into the module path.
fn is_module_segment(segment: &str) -> bool {
    let allowed = |byte: u8| byte.is_ascii_lowercase() || byte.is_ascii_digit() || byte == b'_';
    // Byte-wise checks are sufficient: every non-ASCII byte fails `allowed`.
    segment.bytes().all(allowed) && segment.bytes().any(|byte| byte.is_ascii_lowercase())
}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToolActivation, ToolAvailabilityConfig, ToolScheduling};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a raw tool definition named `name` with one required `path` string input,
    /// a string output, and the given availability.
    fn make_tool(name: &str, availability: ToolAvailability) -> ToolDefinition {
        let input_schema = serde_json::json!({
            "type": "object",
            "properties": { "path": { "type": "string" } },
            "required": ["path"]
        });
        let output_schema = serde_json::json!({ "type": "string" });
        let mut definition = ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            format!("Tool {name}"),
            input_schema,
            output_schema,
        );
        let manifest = &mut definition.manifest;
        manifest.availability = ToolAvailabilityConfig::same(availability);
        manifest.activation = ToolActivation::Always;
        manifest.scheduling = ToolScheduling::Parallel;
        definition
    }

    /// Wraps `tools` into a build input whose resolver answers from a precomputed map.
    fn input_with(
        tools: Vec<ToolDefinition>,
        contributions: Vec<ToolSurfaceContribution>,
    ) -> ToolSurfaceBuildInput {
        let contracts: BTreeMap<_, _> = tools
            .iter()
            .map(|tool| (tool.name().to_string(), Arc::new(tool.contract())))
            .collect();
        let manifests = tools.into_iter().map(|tool| tool.manifest()).collect();
        let resolver: ToolContractResolver =
            Arc::new(move |name: &str| contracts.get(name).cloned());
        ToolSurfaceBuildInput {
            tools: manifests,
            resolve_contract: Some(resolver),
            contributions,
        }
    }

    /// Builds a surface over a single tool whose resolver counts every lookup.
    ///
    /// Returns the surface together with the shared lookup counter.
    fn counted_single_tool_surface(
        name: &'static str,
        availability: ToolAvailability,
    ) -> (ToolSurface, Arc<AtomicUsize>) {
        let lookups = Arc::new(AtomicUsize::new(0));
        let definition = make_tool(name, availability);
        let counter = Arc::clone(&lookups);
        let tools = vec![definition.manifest()];
        let resolver: ToolContractResolver = Arc::new(move |requested: &str| {
            counter.fetch_add(1, Ordering::SeqCst);
            (requested == name).then(|| Arc::new(definition.contract()))
        });
        let surface = build_tool_surface(ToolSurfaceBuildInput {
            tools,
            resolve_contract: Some(resolver),
            contributions: Vec::new(),
        });
        (surface, lookups)
    }

    #[test]
    fn surface_splits_callable_and_showcased_tools() {
        let tools = vec![
            make_tool("search_tools", ToolAvailability::Showcased),
            make_tool("read_file", ToolAvailability::Showcased),
            make_tool("grep", ToolAvailability::Callable),
            make_tool("privileged_tool", ToolAvailability::Searchable),
        ];
        let surface = build_tool_surface(input_with(tools, Vec::new()));

        assert_eq!(surface.callable_tools().len(), 3);
        assert_eq!(surface.showcased_tools().len(), 2);
        assert_eq!(surface.omitted_tool_count(), 2);
        let docs = surface.prompt_tool_docs();
        assert!(!docs.contains("Catalogued tools"));
    }

    #[test]
    fn explicit_contributions_override_availability() {
        let contribution = ToolSurfaceContribution {
            overrides: vec![ToolSurfaceOverride {
                tool_name: "read_file".to_string(),
                availability: Some(ToolAvailability::Off),
            }],
            tool_list_notes: vec!["custom note".to_string()],
        };
        let surface = build_tool_surface(input_with(
            vec![make_tool("read_file", ToolAvailability::Showcased)],
            vec![contribution],
        ));

        let read_file = surface
            .tools
            .iter()
            .find(|entry| entry.manifest.name == "read_file")
            .expect("read_file present");
        assert_eq!(read_file.availability, ToolAvailability::Off);
        assert!(surface.tool_list_notes.contains(&"custom note".to_string()));
    }

    #[test]
    fn prompt_gate_requires_matching_tool_availability() {
        let surface = build_tool_surface(input_with(
            vec![make_tool("search_tools", ToolAvailability::Showcased)],
            Vec::new(),
        ));
        let candidates = vec![
            PromptContribution::guidance("Plain", "always"),
            PromptContribution::guidance("Discovery", "discover")
                .requires_tool("search_tools", ToolAvailability::Showcased),
            PromptContribution::guidance("Off", "off")
                .requires_tool("missing_tool", ToolAvailability::Callable),
        ];

        let kept = surface.filter_prompt_contributions(candidates);
        let titles: Vec<Option<&str>> = kept
            .iter()
            .map(|contribution| contribution.title.as_deref())
            .collect();

        assert_eq!(kept.len(), 2);
        assert!(titles.contains(&Some("Plain")));
        assert!(titles.contains(&Some("Discovery")));
    }

    #[test]
    fn rlm_surface_does_not_resolve_searchable_only_contracts() {
        let resolutions = Arc::new(AtomicUsize::new(0));
        let large_schema = make_tool("large_schema", ToolAvailability::Searchable);
        let search_tools = make_tool("search_tools", ToolAvailability::Showcased);
        let counter = Arc::clone(&resolutions);
        let tools = vec![large_schema.manifest(), search_tools.manifest()];
        let resolver: ToolContractResolver = Arc::new(move |name: &str| {
            counter.fetch_add(1, Ordering::SeqCst);
            match name {
                "large_schema" => Some(Arc::new(large_schema.contract())),
                "search_tools" => Some(Arc::new(search_tools.contract())),
                _ => None,
            }
        });
        let surface = build_tool_surface(ToolSurfaceBuildInput {
            tools,
            resolve_contract: Some(resolver),
            contributions: Vec::new(),
        });
        let resolution_count = || resolutions.load(Ordering::SeqCst);

        assert_eq!(
            surface.tool_availability("large_schema"),
            Some(ToolAvailability::Searchable)
        );
        assert_eq!(resolution_count(), 0);
        assert!(!surface.prompt_tool_docs().contains("large_schema"));
        assert_eq!(resolution_count(), 1);
    }

    #[test]
    fn callable_only_surface_resolves_model_specs_lazily() {
        let (surface, lookups) =
            counted_single_tool_surface("large_callable", ToolAvailability::Callable);

        assert_eq!(
            surface.tool_names().as_ref(),
            &vec!["large_callable".to_string()]
        );
        assert_eq!(surface.model_tool_specs().len(), 1);
        assert_eq!(surface.prompt_tool_docs(), "");
        assert_eq!(lookups.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn standard_surface_resolves_model_specs_lazily() {
        let (surface, lookups) =
            counted_single_tool_surface("read_file", ToolAvailability::Callable);
        let lookup_count = || lookups.load(Ordering::SeqCst);

        assert_eq!(lookup_count(), 0);
        // The second call must be served from the cache without another lookup.
        for _ in 0..2 {
            assert_eq!(surface.model_tool_specs().len(), 1);
            assert_eq!(lookup_count(), 1);
        }
    }

    #[test]
    fn tool_names_fingerprint_matches_prompt_hash() {
        let surface = build_tool_surface(input_with(
            vec![
                make_tool("read_file", ToolAvailability::Callable),
                make_tool("search_tools", ToolAvailability::Showcased),
            ],
            Vec::new(),
        ));

        let expected = prompt_tool_names_fingerprint(&surface.tool_names());
        assert_eq!(surface.tool_names_fingerprint(), expected);
    }
}