use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use thiserror::Error;
#[derive(Error, Debug, Clone)]
pub enum RegistryError {
#[error("Component {0} not found")]
NotFound(String),
#[error("Component {0} already exists")]
AlreadyExists(String),
#[error("Type mismatch for component {0}")]
TypeMismatch(String),
}
/// Converts a registry failure into the crate-wide `McpError`.
///
/// `AlreadyExists` maps to `McpError::invalid_params`; `NotFound` and
/// `TypeMismatch` map to `McpError::internal`. Every resulting error is tagged
/// with the `"registry"` component.
impl From<RegistryError> for crate::McpError {
    fn from(err: RegistryError) -> Self {
        let base = match err {
            RegistryError::NotFound(name) => {
                Self::internal(format!("Component '{}' not found in registry", name))
            }
            RegistryError::AlreadyExists(name) => {
                Self::invalid_params(format!("Component '{}' already exists in registry", name))
            }
            RegistryError::TypeMismatch(name) => Self::internal(format!(
                "Type mismatch when accessing component '{}' in registry",
                name
            )),
        };
        base.with_component("registry")
    }
}
/// Name → type-erased shared component instance.
type ComponentMap = HashMap<String, Arc<dyn Any + Send + Sync>>;

/// Name → `TypeId` recorded when the component was registered.
type TypeIndex = HashMap<String, TypeId>;

/// Name-keyed store of type-erased components guarded by read-write locks.
///
/// Components are held as `Arc<dyn Any + Send + Sync>` and handed back to
/// callers as `Arc<T>` after a checked downcast.
#[derive(Debug)]
pub struct Registry {
    /// Registered component instances, keyed by name.
    components: RwLock<ComponentMap>,
    /// `TypeId` of each registered component, keyed by the same names.
    type_map: RwLock<TypeIndex>,
}
/// Fluent builder that registers components one by one and then yields the
/// populated `Registry` from `build()`.
#[derive(Debug)]
#[must_use = "a RegistryBuilder does nothing until `build()` is called"]
pub struct RegistryBuilder {
    /// The registry being populated.
    registry: Registry,
}
impl Registry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            components: RwLock::new(HashMap::new()),
            type_map: RwLock::new(HashMap::new()),
        }
    }

    /// Returns a [`RegistryBuilder`] wrapping a fresh, empty registry.
    #[must_use]
    pub fn builder() -> RegistryBuilder {
        RegistryBuilder {
            registry: Self::new(),
        }
    }

    /// Stores `component` under `name` and records its `TypeId`.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::AlreadyExists`] if a component is already
    /// registered under `name`; the existing entry is left untouched.
    pub fn register<T>(&self, name: impl Into<String>, component: T) -> Result<(), RegistryError>
    where
        T: 'static + Send + Sync,
    {
        let name = name.into();
        // Lock order for every method that takes both locks: `components`, then
        // `type_map`. Holding both for the whole update keeps the two maps in
        // sync even when `remove`/`clear` run concurrently.
        let mut components = self.components.write();
        let mut type_map = self.type_map.write();
        if components.contains_key(&name) {
            return Err(RegistryError::AlreadyExists(name));
        }
        type_map.insert(name.clone(), TypeId::of::<T>());
        components.insert(name, Arc::new(component));
        Ok(())
    }

    /// Returns a shared handle to the component registered under `name`.
    ///
    /// # Errors
    ///
    /// * [`RegistryError::NotFound`] if nothing is registered under `name`.
    /// * [`RegistryError::TypeMismatch`] if the stored component is not a `T`.
    pub fn get<T>(&self, name: &str) -> Result<Arc<T>, RegistryError>
    where
        T: 'static + Send + Sync,
    {
        // Clone the `Arc` out so the read lock is released before downcasting.
        let component = self
            .components
            .read()
            .get(name)
            .cloned()
            .ok_or_else(|| RegistryError::NotFound(name.to_string()))?;
        component
            .downcast::<T>()
            .map_err(|_| RegistryError::TypeMismatch(name.to_string()))
    }

    /// Returns `true` if a component is registered under `name`.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        self.components.read().contains_key(name)
    }

    /// Returns the names of all registered components, in unspecified
    /// (hash-map iteration) order.
    #[must_use]
    pub fn component_names(&self) -> Vec<String> {
        self.components.read().keys().cloned().collect()
    }

    /// Removes the component registered under `name`, returning it if it was
    /// present. Its `TypeId` record is removed in the same critical section.
    pub fn remove(&self, name: &str) -> Option<Arc<dyn Any + Send + Sync>> {
        // Same lock order as `register`: `components`, then `type_map`.
        let mut components = self.components.write();
        let mut type_map = self.type_map.write();
        type_map.remove(name);
        components.remove(name)
    }

    /// Removes every component and every `TypeId` record atomically.
    pub fn clear(&self) {
        // Same lock order as `register`: `components`, then `type_map`.
        let mut components = self.components.write();
        let mut type_map = self.type_map.write();
        components.clear();
        type_map.clear();
    }

    /// Returns the number of registered components.
    #[must_use]
    pub fn len(&self) -> usize {
        self.components.read().len()
    }

    /// Returns `true` if no components are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.components.read().is_empty()
    }
}
impl RegistryBuilder {
    /// Registers `component` under `name` in the registry being built and hands
    /// the builder back for further chaining.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Registry::register` (for example
    /// `RegistryError::AlreadyExists` on a duplicate name). The builder is
    /// consumed in that case.
    pub fn register<T>(self, name: impl Into<String>, component: T) -> Result<Self, RegistryError>
    where
        T: 'static + Send + Sync,
    {
        self.registry
            .register(name, component)
            .map(|()| self)
    }

    /// Consumes the builder and yields the populated registry.
    pub fn build(self) -> Registry {
        let Self { registry } = self;
        registry
    }
}
impl Default for Registry {
fn default() -> Self {
Self::new()
}
}
/// A value that knows the name it should be registered under.
pub trait Component: 'static + Send + Sync {
    /// The registry key for this component.
    fn name(&self) -> &'static str;

    /// Moves `self` into `registry` under the key returned by `name()`.
    ///
    /// # Errors
    ///
    /// Returns whatever `Registry::register` returns, e.g.
    /// `RegistryError::AlreadyExists` when the key is already taken.
    fn register_in(self, registry: &Registry) -> Result<(), RegistryError>
    where
        Self: Sized,
    {
        let key = self.name();
        registry.register(key, self)
    }
}
/// Registers one or more components by calling `.register(name, component)` on
/// the given registry expression.
///
/// Two forms are accepted:
///
/// * `register_component!(registry, name, component)` expands to
///   `registry.register(name, component)` and evaluates to that call's `Result`.
/// * `register_component!(registry, n1 => c1, n2 => c2, ...)` registers each
///   pair in order. Each call is followed by `?`, so the first failure returns
///   early from the *enclosing function*, which must return a `Result` whose
///   error type accepts a `RegistryError`. If every call succeeds, the block
///   evaluates to `Ok::<(), RegistryError>(())`.
///
/// The `$registry` expression is re-evaluated once per pair.
#[macro_export]
macro_rules! register_component {
    // Single registration: yields the `Result` of `register` unchanged.
    ($registry:expr, $name:expr, $component:expr) => {
        $registry.register($name, $component)
    };
    // Multiple registrations: `?` propagates to the caller's function.
    // NOTE(review): `$crate::registry::RegistryError` assumes this module is
    // mounted at `crate::registry` — confirm against the crate root.
    ($registry:expr, $($name:expr => $component:expr),+ $(,)?) => {
        {
            $(
                $registry.register($name, $component)?;
            )+
            Ok::<(), $crate::registry::RegistryError>(())
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Small stateful component; the counter lets tests confirm that a
    /// retrieved `Arc` refers to the registered instance.
    #[derive(Debug)]
    struct TestService {
        id: u32,
        counter: AtomicU32,
    }

    impl TestService {
        fn new(id: u32) -> Self {
            Self {
                id,
                counter: AtomicU32::new(0),
            }
        }

        fn increment(&self) -> u32 {
            self.counter.fetch_add(1, Ordering::SeqCst) + 1
        }

        fn get_id(&self) -> u32 {
            self.id
        }
    }

    impl Component for TestService {
        fn name(&self) -> &'static str {
            "test_service"
        }
    }

    #[test]
    fn test_registry_basic_operations() {
        let registry = Registry::new();
        let service = TestService::new(42);
        assert!(registry.register("test", service).is_ok());
        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));
        let retrieved: Arc<TestService> = registry.get("test").unwrap();
        assert_eq!(retrieved.get_id(), 42);
        assert_eq!(retrieved.increment(), 1);
        assert_eq!(retrieved.increment(), 2);
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_registry_errors() {
        let registry = Registry::new();
        let result: Result<Arc<TestService>, _> = registry.get("nonexistent");
        assert!(matches!(result, Err(RegistryError::NotFound(_))));
        let service1 = TestService::new(1);
        let service2 = TestService::new(2);
        assert!(registry.register("duplicate", service1).is_ok());
        let result = registry.register("duplicate", service2);
        assert!(matches!(result, Err(RegistryError::AlreadyExists(_))));
    }

    #[test]
    fn test_registry_type_mismatch() {
        let registry = Registry::new();
        registry.register("number", 7_u32).unwrap();
        let result: Result<Arc<TestService>, _> = registry.get("number");
        assert!(matches!(result, Err(RegistryError::TypeMismatch(_))));
        // The correctly-typed lookup still succeeds afterwards.
        let number: Arc<u32> = registry.get("number").unwrap();
        assert_eq!(*number, 7);
    }

    #[test]
    fn test_registry_builder() {
        let registry = Registry::builder()
            .register("service1", TestService::new(1))
            .unwrap()
            .register("service2", TestService::new(2))
            .unwrap()
            .build();
        assert_eq!(registry.len(), 2);
        let service1: Arc<TestService> = registry.get("service1").unwrap();
        let service2: Arc<TestService> = registry.get("service2").unwrap();
        assert_eq!(service1.get_id(), 1);
        assert_eq!(service2.get_id(), 2);
    }

    #[test]
    fn test_registry_builder_rejects_duplicates() {
        let result = Registry::builder()
            .register("dup", TestService::new(1))
            .unwrap()
            .register("dup", TestService::new(2));
        assert!(matches!(result, Err(RegistryError::AlreadyExists(_))));
    }

    #[test]
    fn test_component_trait() {
        let registry = Registry::new();
        let service = TestService::new(123);
        assert!(service.register_in(&registry).is_ok());
        let retrieved: Arc<TestService> = registry.get("test_service").unwrap();
        assert_eq!(retrieved.get_id(), 123);
    }

    #[test]
    fn test_registry_removal() {
        let registry = Registry::new();
        let service = TestService::new(42);
        registry.register("test", service).unwrap();
        assert!(registry.contains("test"));
        let removed = registry.remove("test");
        assert!(removed.is_some());
        assert!(!registry.contains("test"));
        let removed = registry.remove("nonexistent");
        assert!(removed.is_none());
    }

    #[test]
    fn test_registry_clear() {
        let registry = Registry::new();
        registry.register("service1", TestService::new(1)).unwrap();
        registry.register("service2", TestService::new(2)).unwrap();
        assert_eq!(registry.len(), 2);
        registry.clear();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_component_names() {
        let registry = Registry::new();
        registry.register("alpha", TestService::new(1)).unwrap();
        registry.register("beta", TestService::new(2)).unwrap();
        let mut names = registry.component_names();
        names.sort();
        assert_eq!(names, vec!["alpha", "beta"]);
    }
}