mod args;
mod executor;
pub use args::ToolArgs;
pub use executor::{
BatchExecutionResult, ParallelSafety, ToolCategory, ToolExecutor, ToolRegistration,
execute_batch_with,
};
/// Shared, type-erased handle to an asynchronous tool function.
///
/// The callable borrows a `serde_json::Value` and returns a boxed, pinned,
/// `Send` future that resolves to a `lellm_core::ToolResult`. The callable
/// itself is `Send + Sync` and wrapped in an `Arc`, so a handle can be cloned
/// cheaply and shared between threads.
// NOTE(review): the JSON value is presumably the tool call's arguments —
// confirm against the call sites in `executor`.
pub(crate) type ToolFn = std::sync::Arc<
dyn Fn(
&serde_json::Value,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = lellm_core::ToolResult> + Send>>
+ Send
+ Sync,
>;
/// A versioned set of tool registrations keyed by tool name.
///
/// The registrations live behind an `Arc` in an insertion-ordered map, and the
/// list of tool definitions is derived from them lazily (see `definitions()`).
pub struct ToolSnapshot {
// Version tag supplied by whoever built the snapshot.
version: u64,
// Registrations keyed by tool name; iteration follows map order.
tools: std::sync::Arc<indexmap::IndexMap<String, ToolRegistration>>,
// Cache of the cloned definitions, filled on the first `definitions()` call.
definitions: std::sync::OnceLock<Vec<lellm_core::ToolDefinition>>,
}
impl ToolSnapshot {
    /// Builds a snapshot that owns `tools` and is tagged with `version`.
    ///
    /// The definition cache starts out empty and is populated by the first
    /// call to [`ToolSnapshot::definitions`].
    pub fn new(tools: indexmap::IndexMap<String, ToolRegistration>, version: u64) -> Self {
        let tools = std::sync::Arc::new(tools);
        let definitions = std::sync::OnceLock::new();
        Self {
            version,
            tools,
            definitions,
        }
    }

    /// Looks up the registration stored under `name`, if there is one.
    pub fn get(&self, name: &str) -> Option<&ToolRegistration> {
        let registry = &*self.tools;
        registry.get(name)
    }

    /// Returns the definitions of all registered tools, in map order.
    ///
    /// The list is built by cloning each registration's definition the first
    /// time this is called; later calls return the cached list.
    pub fn definitions(&self) -> &[lellm_core::ToolDefinition] {
        let cached = self.definitions.get_or_init(|| {
            let mut defs = Vec::with_capacity(self.tools.len());
            defs.extend(self.tools.values().map(|reg| reg.definition.clone()));
            defs
        });
        cached.as_slice()
    }

    /// Returns `true` when at least one tool is registered.
    pub fn has_tools(&self) -> bool {
        !self.is_empty()
    }

    /// Returns the version tag this snapshot was created with.
    pub fn version(&self) -> u64 {
        self.version
    }

    /// Returns the number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` when no tool is registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }
}
/// A source of tool snapshots.
///
/// Implementors must be `Send + Sync`; the trait is used behind
/// `Arc<dyn ToolCatalog>` elsewhere in this file.
#[async_trait::async_trait]
pub trait ToolCatalog: Send + Sync {
/// Produces a shared snapshot of the tools this catalog currently offers.
async fn snapshot(&self) -> std::sync::Arc<ToolSnapshot>;
}
/// A catalog whose snapshot is built once at construction time and then
/// handed out unchanged on every `snapshot()` call.
pub struct StaticCatalog {
// The single pre-built snapshot (always created with version 0).
snapshot: std::sync::Arc<ToolSnapshot>,
}
impl StaticCatalog {
    /// Creates a catalog from a list of registrations.
    ///
    /// Each registration is keyed by its definition's name. When two entries
    /// share a name, the later one replaces the earlier one's value while the
    /// name keeps its original position in the map. The snapshot version is 0.
    pub fn from_tools(tools: Vec<ToolRegistration>) -> Self {
        let by_name: indexmap::IndexMap<_, _> = tools
            .into_iter()
            .map(|reg| (reg.definition.name.clone(), reg))
            .collect();
        let snapshot = std::sync::Arc::new(ToolSnapshot::new(by_name, 0));
        Self { snapshot }
    }

    /// Creates a catalog that contains no tools (snapshot version 0).
    pub fn empty() -> Self {
        Self::from_tools(Vec::new())
    }
}
#[async_trait::async_trait]
impl ToolCatalog for StaticCatalog {
    /// Hands out another reference to the snapshot built at construction time.
    async fn snapshot(&self) -> std::sync::Arc<ToolSnapshot> {
        // Only the reference count is bumped; the snapshot itself is shared.
        std::sync::Arc::clone(&self.snapshot)
    }
}
/// How a composite catalog is meant to treat two sources that offer a tool
/// with the same name.
// NOTE(review): in the code visible in this file the policy is only stored and
// copied into each `CatalogConflict`; nothing branches on it, so merging always
// shadows regardless of the selected variant — confirm whether `Error` is
// enforced elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ConflictPolicy {
/// The default policy.
#[default]
Shadow,
/// Alternative policy; presumably intended to treat a name clash as an
/// error (see the review note on the enum).
Error,
}
/// Record of one tool-name clash found while merging catalog sources.
#[derive(Debug, Clone)]
pub struct CatalogConflict {
/// Name of the tool that more than one source offered.
pub tool_name: String,
/// Label of the source whose registration was kept.
pub winner: String,
/// Label of the source whose registration was replaced.
pub loser: String,
/// The catalog's conflict policy at the time the clash was recorded.
pub policy: ConflictPolicy,
}
/// Builder for a `CompositeCatalog`.
pub struct CompositeCatalogBuilder {
// (name, catalog) pairs in the order they were added.
// NOTE(review): `build` keeps only the catalog half of each pair, so the
// names given to `add` do not reach the built catalog.
sources: Vec<(String, std::sync::Arc<dyn ToolCatalog>)>,
// Policy handed to the built catalog; starts at `ConflictPolicy::default()`.
conflict_policy: ConflictPolicy,
}
impl CompositeCatalogBuilder {
    /// Starts a builder with no sources and the default conflict policy.
    pub fn new() -> Self {
        let conflict_policy = ConflictPolicy::default();
        Self {
            sources: Vec::new(),
            conflict_policy,
        }
    }

    /// Replaces the conflict policy that the built catalog will carry.
    pub fn conflict_policy(self, policy: ConflictPolicy) -> Self {
        Self {
            conflict_policy: policy,
            ..self
        }
    }

    /// Appends `catalog` as the next source, remembered under `name`.
    pub fn add(
        mut self,
        name: impl Into<String>,
        catalog: std::sync::Arc<dyn ToolCatalog>,
    ) -> Self {
        let entry = (name.into(), catalog);
        self.sources.push(entry);
        self
    }

    /// Finishes the builder and produces the composite catalog.
    ///
    /// Sources keep the order in which they were added. The names passed to
    /// [`CompositeCatalogBuilder::add`] are dropped here; only the catalogs
    /// are carried over. The version counter starts at 0 and the conflict
    /// list starts empty.
    pub fn build(self) -> CompositeCatalog {
        let Self {
            sources: named_sources,
            conflict_policy,
        } = self;
        let sources: Vec<_> = named_sources
            .into_iter()
            .map(|(_label, catalog)| catalog)
            .collect();
        CompositeCatalog {
            sources,
            conflict_policy,
            version_counter: std::sync::atomic::AtomicU64::new(0),
            conflicts: std::sync::Mutex::new(Vec::new()),
        }
    }
}
/// A catalog that merges the snapshots of several underlying catalogs.
///
/// When more than one source offers a tool with the same name, the source that
/// appears earlier in `sources` is the one whose registration ends up in the
/// merged snapshot.
pub struct CompositeCatalog {
// Underlying catalogs; earlier entries take precedence on name clashes.
sources: Vec<std::sync::Arc<dyn ToolCatalog>>,
// Policy value copied into every recorded `CatalogConflict`.
conflict_policy: ConflictPolicy,
// Incremented on every `snapshot()` call; the first snapshot gets version 1.
version_counter: std::sync::atomic::AtomicU64,
// Conflicts stored by `snapshot()`, readable through `conflicts()`.
conflicts: std::sync::Mutex<Vec<CatalogConflict>>,
}
impl CompositeCatalog {
    /// Returns a fresh [`CompositeCatalogBuilder`].
    pub fn builder() -> CompositeCatalogBuilder {
        CompositeCatalogBuilder::new()
    }

    /// Creates a composite over `sources` with the default conflict policy,
    /// a version counter starting at 0 and an empty conflict list.
    pub fn new(sources: Vec<std::sync::Arc<dyn ToolCatalog>>) -> Self {
        let conflict_policy = ConflictPolicy::default();
        let version_counter = std::sync::atomic::AtomicU64::new(0);
        let conflicts = std::sync::Mutex::new(Vec::new());
        Self {
            sources,
            conflict_policy,
            version_counter,
            conflicts,
        }
    }

    /// Returns a copy of the conflict list most recently stored by `snapshot`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn conflicts(&self) -> Vec<CatalogConflict> {
        let recorded = self.conflicts.lock().unwrap();
        recorded.to_vec()
    }
}
#[async_trait::async_trait]
impl ToolCatalog for CompositeCatalog {
    /// Merges the snapshots of all sources into one new snapshot.
    ///
    /// Sources are visited from the last entry of `self.sources` to the first,
    /// and every tool is inserted into the merged map, so on a name clash the
    /// registration from the source that appears earlier in the list replaces
    /// the one from the later source. A clashing name keeps the map position
    /// it got when it was first inserted.
    ///
    /// Every clash is logged with `tracing::warn!` and recorded as a
    /// [`CatalogConflict`]. Sources are labelled `source_{i}`, where `i` is the
    /// source's position in `self.sources`; `winner` is the source being
    /// merged and `loser` is the source whose registration it replaced. The
    /// stored conflict list is replaced on every call, so it always describes
    /// the most recent snapshot (and is empty when that snapshot had no
    /// clashes).
    ///
    /// Each call increments the version counter; the returned snapshot carries
    /// the incremented value.
    ///
    /// # Panics
    ///
    /// Panics if the internal conflicts mutex has been poisoned.
    // NOTE(review): `self.conflict_policy` is only recorded in the conflict
    // entries; shadowing happens for every policy value, including `Error`.
    async fn snapshot(&self) -> std::sync::Arc<ToolSnapshot> {
        let mut merged = indexmap::IndexMap::new();
        // Tool name -> label of the source whose registration is currently in
        // `merged`; used to report the source that actually gets shadowed.
        let mut owners: std::collections::HashMap<String, String> =
            std::collections::HashMap::new();
        let mut conflicts = Vec::new();

        // `enumerate` runs before `rev` so that `idx` is the source's real
        // position in `self.sources` rather than its position in the reversed
        // walk.
        for (idx, source) in self.sources.iter().enumerate().rev() {
            let snap = source.snapshot().await;
            let source_name = format!("source_{idx}");
            for (name, tool) in snap.tools.iter() {
                // `insert` hands back the previous owner when the name was
                // already contributed by a source visited earlier.
                if let Some(previous_owner) = owners.insert(name.clone(), source_name.clone()) {
                    tracing::warn!(
                        tool_name = %name,
                        winner = %source_name,
                        loser = %previous_owner,
                        "Tool conflict detected in CompositeCatalog. Higher priority tool shadows the lower one."
                    );
                    conflicts.push(CatalogConflict {
                        tool_name: name.clone(),
                        winner: source_name.clone(),
                        loser: previous_owner,
                        policy: self.conflict_policy,
                    });
                }
                merged.insert(name.clone(), tool.clone());
            }
        }

        // Always overwrite, even with an empty list, so `conflicts()` never
        // reports clashes from an older snapshot.
        *self.conflicts.lock().unwrap() = conflicts;

        let version = self
            .version_counter
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
            + 1;
        std::sync::Arc::new(ToolSnapshot::new(merged, version))
    }
}