use std::collections::HashMap;
use std::sync::Arc;
use lash_core::{
ToolCall, ToolContract, ToolDefinition, ToolManifest, ToolPrepareCall, ToolPrepareContext,
ToolProvider, ToolResult, sansio::PendingToolCall,
};
/// Execution backend plugged into a [`StaticToolProvider`].
///
/// Implementors must supply [`execute`](StaticToolExecute::execute).
/// [`prepare_tool_call`](StaticToolExecute::prepare_tool_call) has a default that
/// ignores the preparation context and wraps the pending call with
/// `PreparedToolCall::identity`; override it to customise preparation.
#[async_trait::async_trait]
pub trait StaticToolExecute: Send + Sync + 'static {
    /// Produces the [`ToolResult`] for a single tool call.
    async fn execute(&self, call: ToolCall<'_>) -> ToolResult;

    /// Converts a pending call into a prepared call ahead of execution.
    ///
    /// # Errors
    ///
    /// An override may return `Err(ToolResult)` to short-circuit the call. The default
    /// implementation never fails.
    async fn prepare_tool_call(
        &self,
        pending: PendingToolCall,
        _context: &ToolPrepareContext,
    ) -> Result<lash_core::PreparedToolCall, ToolResult> {
        let prepared = lash_core::PreparedToolCall::identity(pending);
        Ok(prepared)
    }
}
pub struct StaticToolProvider<E: StaticToolExecute> {
manifests: Vec<ToolManifest>,
contracts: HashMap<String, Arc<ToolContract>>,
executor: E,
}
impl<E: StaticToolExecute> StaticToolProvider<E> {
    /// Builds a provider from static tool definitions and the executor that runs them.
    ///
    /// Manifests are stored in the order given, including any with repeated names.
    /// When several definitions share a manifest name, only the contract of the
    /// *first* one is kept. This matches `resolve_manifest`, which also returns the first
    /// manifest with a matching name, so both lookups agree on which definition wins.
    pub fn new(definitions: Vec<ToolDefinition>, executor: E) -> Self {
        let mut manifests = Vec::with_capacity(definitions.len());
        let mut contracts = HashMap::with_capacity(definitions.len());
        for def in &definitions {
            let manifest = def.manifest();
            // First definition wins. A plain `insert` would keep the *last* contract
            // and disagree with the first-match manifest lookup. `or_insert_with` also
            // avoids building contracts that would be discarded.
            contracts
                .entry(manifest.name.clone())
                .or_insert_with(|| Arc::new(def.contract()));
            manifests.push(manifest);
        }
        Self {
            manifests,
            contracts,
            executor,
        }
    }

    /// Returns a reference to the wrapped executor.
    pub fn executor(&self) -> &E {
        &self.executor
    }
}
#[async_trait::async_trait]
impl<E: StaticToolExecute> ToolProvider for StaticToolProvider<E> {
fn tool_manifests(&self) -> Vec<ToolManifest> {
self.manifests.clone()
}
fn resolve_manifest(&self, name: &str) -> Option<ToolManifest> {
self.manifests
.iter()
.find(|manifest| manifest.name == name)
.cloned()
}
fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
self.contracts.get(name).cloned()
}
async fn prepare_tool_call(
&self,
call: ToolPrepareCall<'_>,
) -> Result<lash_core::PreparedToolCall, ToolResult> {
self.executor
.prepare_tool_call(call.pending, call.context)
.await
}
async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
self.executor.execute(call).await
}
}