Skip to main content

lash_tool_support/
static_provider.rs

//! [`StaticToolProvider`] — a reusable [`ToolProvider`] for the common case of
//! serving a *fixed* set of [`ToolDefinition`]s.
//!
//! Almost every single- (or fixed-multi-) tool provider in the workspace used
//! to hand-roll the same idiom: `tool_manifests()` rebuilt `def.manifest()` and
//! `resolve_contract()` rebuilt `def.contract()` on *every* call, re-running
//! schema and doc generation each time. `StaticToolProvider` derives the
//! manifests and contracts **once** in its constructor and serves them from a
//! cache, delegating only `execute` (and, by default, the identity
//! `prepare_tool_call`) to a small [`StaticToolExecute`] implementation that
//! holds the tool's runtime state and behavior.
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use lash_core::{
17    ToolCall, ToolContract, ToolDefinition, ToolId, ToolManifest, ToolPrepareCall,
18    ToolPrepareContext, ToolProvider, ToolResult, sansio::PendingToolCall,
19};
20
21/// Per-call execution behavior for a [`StaticToolProvider`].
22///
23/// Implement this on the struct that owns the tool's runtime state (HTTP
24/// clients, shared mutable state, configuration flags, ...). The provider's
25/// manifests and contracts come from the [`ToolDefinition`]s passed to
26/// [`StaticToolProvider::new`]; this trait supplies only the dynamic behavior.
27#[async_trait::async_trait]
28pub trait StaticToolExecute: Send + Sync + 'static {
29    /// Execute a resolved tool call. Dispatch on `call.name` when serving more
30    /// than one tool.
31    async fn execute(&self, call: ToolCall<'_>) -> ToolResult;
32
33    /// Optional argument-preparation hook, mirroring
34    /// [`ToolProvider::prepare_tool_call`]. Defaults to the identity transform.
35    async fn prepare_tool_call(
36        &self,
37        tool_id: &ToolId,
38        pending: PendingToolCall,
39        _context: &ToolPrepareContext,
40    ) -> Result<lash_core::PreparedToolCall, ToolResult> {
41        Ok(lash_core::PreparedToolCall::identity(
42            tool_id.clone(),
43            pending,
44        ))
45    }
46}
47
48/// A [`ToolProvider`] that serves a fixed set of [`ToolDefinition`]s from a
49/// cache, delegating execution to an [`StaticToolExecute`].
50pub struct StaticToolProvider<E: StaticToolExecute> {
51    manifests: Vec<ToolManifest>,
52    contracts: HashMap<String, Arc<ToolContract>>,
53    contracts_by_id: HashMap<ToolId, Arc<ToolContract>>,
54    executor: E,
55}
56
57impl<E: StaticToolExecute> StaticToolProvider<E> {
58    /// Build a provider from a fixed set of definitions and an executor.
59    ///
60    /// Manifests and contracts are derived once, here, and reused for the life
61    /// of the provider.
62    pub fn new(definitions: Vec<ToolDefinition>, executor: E) -> Self {
63        let mut manifests = Vec::with_capacity(definitions.len());
64        let mut contracts = HashMap::with_capacity(definitions.len());
65        let mut contracts_by_id = HashMap::with_capacity(definitions.len());
66        for def in &definitions {
67            let manifest = def.manifest();
68            let contract = Arc::new(def.contract());
69            contracts.insert(manifest.name.clone(), Arc::clone(&contract));
70            contracts_by_id.insert(manifest.id.clone(), contract);
71            manifests.push(manifest);
72        }
73        Self {
74            manifests,
75            contracts,
76            contracts_by_id,
77            executor,
78        }
79    }
80
81    /// Borrow the underlying executor. Useful for tests that need to inspect
82    /// the executor's internal state.
83    pub fn executor(&self) -> &E {
84        &self.executor
85    }
86}
87
88#[async_trait::async_trait]
89impl<E: StaticToolExecute> ToolProvider for StaticToolProvider<E> {
90    fn tool_manifests(&self) -> Vec<ToolManifest> {
91        self.manifests.clone()
92    }
93
94    fn resolve_manifest(&self, name: &str) -> Option<ToolManifest> {
95        self.manifests
96            .iter()
97            .find(|manifest| manifest.name == name)
98            .cloned()
99    }
100
101    fn resolve_manifest_by_id(&self, id: &ToolId) -> Option<ToolManifest> {
102        self.manifests
103            .iter()
104            .find(|manifest| manifest.id == *id)
105            .cloned()
106    }
107
108    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
109        self.contracts.get(name).cloned()
110    }
111
112    fn resolve_contract_by_id(&self, id: &ToolId) -> Option<Arc<ToolContract>> {
113        self.contracts_by_id.get(id).cloned()
114    }
115
116    async fn prepare_tool_call(
117        &self,
118        call: ToolPrepareCall<'_>,
119    ) -> Result<lash_core::PreparedToolCall, ToolResult> {
120        self.executor
121            .prepare_tool_call(&call.tool_id, call.pending, call.context)
122            .await
123    }
124
125    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
126        self.executor.execute(call).await
127    }
128}