use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::session::ToolRegistry;
use crate::tool::{Tool, ToolSchema};
/// A fixed set of tools, assembled once through [`StaticToolRegistryBuilder`].
///
/// Schemas are kept as a list (for enumeration) alongside a name-keyed map of
/// the tool implementations (for lookup).
pub struct StaticToolRegistry {
    /// Schemas of every registered tool, in the order the builder produced.
    schemas: Vec<ToolSchema>,
    /// Tool implementations keyed by their schema name.
    by_name: HashMap<String, Arc<dyn Tool>>,
}

impl StaticToolRegistry {
    /// Returns a fresh builder with no tools registered yet.
    pub fn builder() -> StaticToolRegistryBuilder {
        Default::default()
    }

    /// Creates a registry that exposes no schemas and resolves no names.
    pub fn empty() -> Self {
        let schemas = Vec::default();
        let by_name = HashMap::default();
        StaticToolRegistry { schemas, by_name }
    }
}
impl ToolRegistry for StaticToolRegistry {
    /// Returns an owned copy of every registered schema, in stored order.
    fn schemas(&self) -> Vec<ToolSchema> {
        self.schemas.to_vec()
    }

    /// Looks up a tool by exact name; the returned handle shares ownership
    /// with the registry (no tool is copied).
    fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        let stored = self.by_name.get(name)?;
        Some(Arc::clone(stored))
    }
}
/// Incrementally assembles a [`StaticToolRegistry`].
///
/// Tools are keyed by the `name` field of their schema. Inserting a tool whose
/// name is already registered replaces both the earlier schema and the earlier
/// implementation, while the schema keeps its original position in the list.
#[derive(Default)]
pub struct StaticToolRegistryBuilder {
    /// Schemas in first-registration order; at most one entry per name.
    schemas: Vec<ToolSchema>,
    /// Tool implementations keyed by schema name.
    by_name: HashMap<String, Arc<dyn Tool>>,
}

impl StaticToolRegistryBuilder {
    /// Registers `tool` under the name taken from its schema and returns the
    /// builder for chaining.
    ///
    /// If a tool with the same name was inserted earlier, the new schema
    /// overwrites the old one in place (so listing order reflects the first
    /// registration of that name) and the new implementation replaces the old
    /// one for lookups.
    pub fn insert(mut self, tool: Arc<dyn Tool>) -> Self {
        let schema = tool.schema().clone();
        // Only the name needs a second owned copy: one goes to the map key,
        // the schema itself is moved into the list below.
        let name = schema.name.clone();
        // Overwrite in place instead of pushing so a name never appears twice
        // in `schemas` and its slot stays stable.
        match self.schemas.iter_mut().find(|existing| existing.name == name) {
            Some(slot) => *slot = schema,
            None => self.schemas.push(schema),
        }
        self.by_name.insert(name, tool);
        self
    }

    /// Consumes the builder and returns the finished, immutable registry.
    pub fn build(self) -> StaticToolRegistry {
        StaticToolRegistry {
            schemas: self.schemas,
            by_name: self.by_name,
        }
    }
}
/// A [`ToolRegistry`] made of two layers: a session registry and a process
/// registry. Lookups try the session layer first and fall back to the process
/// layer.
pub struct CompositeRegistry {
    /// Higher-priority layer; its tools shadow same-named process tools.
    session: Arc<dyn ToolRegistry>,
    /// Fallback layer, consulted only for names the session does not provide.
    process: Arc<dyn ToolRegistry>,
}

impl CompositeRegistry {
    /// Layers `session` on top of `process`. Only the `Arc` handles are
    /// stored; neither registry is copied.
    pub fn new(session: Arc<dyn ToolRegistry>, process: Arc<dyn ToolRegistry>) -> Self {
        CompositeRegistry { process, session }
    }
}
impl ToolRegistry for CompositeRegistry {
    /// Lists every visible schema: all session schemas first, followed by the
    /// process schemas whose names the session does not already provide.
    fn schemas(&self) -> Vec<ToolSchema> {
        let from_session = self.session.schemas();
        let from_process = self.process.schemas();

        // Process entries whose names the session claims are hidden, which
        // matches the session-first order used by `get`. The borrowed name set
        // is scoped so it ends before `from_session` is consumed.
        let unshadowed: Vec<ToolSchema> = {
            let claimed: HashSet<&str> = from_session
                .iter()
                .map(|schema| schema.name.as_str())
                .collect();
            from_process
                .into_iter()
                .filter(|schema| !claimed.contains(schema.name.as_str()))
                .collect()
        };

        from_session.into_iter().chain(unshadowed).collect()
    }

    /// Resolves `name` in the session registry, falling back to the process
    /// registry only when the session has no tool by that name.
    fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        match self.session.get(name) {
            Some(tool) => Some(tool),
            None => self.process.get(name),
        }
    }
}
// Unit tests are declared as an out-of-line `tests` module (loaded from its own
// file) and compiled only for test builds.
#[cfg(test)]
mod tests;