use dashmap::DashMap;
use std::sync::Arc;
use crate::handlers::{
CompletionProvider, ElicitationHandler, HandlerCapabilities, PingHandler,
ResourceTemplateHandler,
};
use crate::registry::{Registry, RegistryError};
// Registers `$handler` under `$name` in `$map` and sets the `$cap_field` flag on
// the `$caps` entry for the same name.
//
// Must be expanded inside a function returning `Result<(), RegistryError>`:
// on a duplicate name the macro *returns from the enclosing function* with
// `RegistryError::AlreadyExists(name)`; otherwise it evaluates to `Ok(())`.
//
// The existence check and the insertion go through a single `entry` call so
// that they happen under one shard lock. A separate `contains_key` + `insert`
// would let two concurrent registrations of the same name both succeed, with
// the second silently replacing the first.
macro_rules! register_handler {
    ($map:expr, $caps:expr, $name:expr, $handler:expr, $cap_field:ident) => {{
        let name: String = $name.into();
        match $map.entry(name.clone()) {
            dashmap::mapref::entry::Entry::Occupied(_) => {
                return Err(RegistryError::AlreadyExists(name));
            }
            dashmap::mapref::entry::Entry::Vacant(slot) => {
                // The guard returned by `insert` is a temporary and is released
                // at the end of this statement, before `$caps` is touched.
                slot.insert($handler);
            }
        }
        $caps.entry(name).or_default().$cap_field = true;
        Ok(())
    }};
}
/// A registry that wraps a base [`Registry`] and additionally keeps
/// name-keyed maps of elicitation handlers, completion providers, resource
/// template handlers and ping handlers, plus a per-name capability record.
///
/// Every map is a [`DashMap`] behind an [`Arc`], and all methods take `&self`.
pub struct EnhancedRegistry {
    /// The wrapped base registry.
    base: Registry,
    /// Elicitation handlers keyed by registration name.
    elicitation_handlers: Arc<DashMap<String, Arc<dyn ElicitationHandler>>>,
    /// Completion providers keyed by registration name.
    completion_providers: Arc<DashMap<String, Arc<dyn CompletionProvider>>>,
    /// Resource template handlers keyed by registration name.
    template_handlers: Arc<DashMap<String, Arc<dyn ResourceTemplateHandler>>>,
    /// Ping handlers keyed by registration name.
    ping_handlers: Arc<DashMap<String, Arc<dyn PingHandler>>>,
    /// Capability flags keyed by registration name; one entry per distinct
    /// name, shared by all handler kinds registered under that name.
    capabilities: Arc<DashMap<String, HandlerCapabilities>>,
}
impl EnhancedRegistry {
    /// Creates a registry with a fresh base [`Registry`] and no handlers.
    pub fn new() -> Self {
        Self {
            base: Registry::new(),
            elicitation_handlers: Arc::new(DashMap::new()),
            completion_providers: Arc::new(DashMap::new()),
            template_handlers: Arc::new(DashMap::new()),
            ping_handlers: Arc::new(DashMap::new()),
            capabilities: Arc::new(DashMap::new()),
        }
    }

    /// Registers an elicitation handler under `name` and sets the
    /// `elicitation` capability flag for that name.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::AlreadyExists`] if an elicitation handler is
    /// already registered under `name`.
    pub fn register_elicitation_handler(
        &self,
        name: impl Into<String>,
        handler: Arc<dyn ElicitationHandler>,
    ) -> Result<(), RegistryError> {
        register_handler!(
            self.elicitation_handlers,
            self.capabilities,
            name,
            handler,
            elicitation
        )
    }

    /// Returns the elicitation handler registered under `name`, if any.
    pub fn get_elicitation_handler(&self, name: &str) -> Option<Arc<dyn ElicitationHandler>> {
        self.elicitation_handlers
            .get(name)
            .map(|entry| Arc::clone(entry.value()))
    }

    /// Returns the names of all registered elicitation handlers, in the
    /// map's iteration order.
    pub fn list_elicitation_handlers(&self) -> Vec<String> {
        self.elicitation_handlers
            .iter()
            .map(|entry| entry.key().clone())
            .collect()
    }

    /// Registers a completion provider under `name` and sets the
    /// `completion` capability flag for that name.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::AlreadyExists`] if a completion provider is
    /// already registered under `name`.
    pub fn register_completion_provider(
        &self,
        name: impl Into<String>,
        provider: Arc<dyn CompletionProvider>,
    ) -> Result<(), RegistryError> {
        register_handler!(
            self.completion_providers,
            self.capabilities,
            name,
            provider,
            completion
        )
    }

    /// Returns the completion provider registered under `name`, if any.
    pub fn get_completion_provider(&self, name: &str) -> Option<Arc<dyn CompletionProvider>> {
        self.completion_providers
            .get(name)
            .map(|entry| Arc::clone(entry.value()))
    }

    /// Returns every registered completion provider whose `can_provide`
    /// accepts `context`, ordered from highest to lowest `priority()`.
    ///
    /// The sort is stable, so providers with equal priority stay in the
    /// order the map iteration produced them.
    pub fn get_matching_completion_providers(
        &self,
        context: &crate::context::CompletionContext,
    ) -> Vec<Arc<dyn CompletionProvider>> {
        let mut providers: Vec<Arc<dyn CompletionProvider>> = self
            .completion_providers
            .iter()
            .filter(|entry| entry.value().can_provide(context))
            .map(|entry| Arc::clone(entry.value()))
            .collect();
        // `Reverse` yields descending order without negating the priority;
        // `-priority` overflows for `i32::MIN` (panic in debug builds, and in
        // release it wraps so the lowest priority would sort first).
        providers.sort_by_key(|provider| std::cmp::Reverse(provider.priority()));
        providers
    }

    /// Registers a resource template handler under `name` and sets the
    /// `templates` capability flag for that name.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::AlreadyExists`] if a template handler is
    /// already registered under `name`.
    pub fn register_template_handler(
        &self,
        name: impl Into<String>,
        handler: Arc<dyn ResourceTemplateHandler>,
    ) -> Result<(), RegistryError> {
        register_handler!(
            self.template_handlers,
            self.capabilities,
            name,
            handler,
            templates
        )
    }

    /// Returns the resource template handler registered under `name`, if any.
    pub fn get_template_handler(&self, name: &str) -> Option<Arc<dyn ResourceTemplateHandler>> {
        self.template_handlers
            .get(name)
            .map(|entry| Arc::clone(entry.value()))
    }

    /// Registers a ping handler under `name` and sets the `ping` capability
    /// flag for that name.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::AlreadyExists`] if a ping handler is already
    /// registered under `name`.
    pub fn register_ping_handler(
        &self,
        name: impl Into<String>,
        handler: Arc<dyn PingHandler>,
    ) -> Result<(), RegistryError> {
        register_handler!(self.ping_handlers, self.capabilities, name, handler, ping)
    }

    /// Returns the ping handler registered under `name`, if any.
    pub fn get_ping_handler(&self, name: &str) -> Option<Arc<dyn PingHandler>> {
        self.ping_handlers
            .get(name)
            .map(|entry| Arc::clone(entry.value()))
    }

    /// Returns a copy of the capability flags recorded for `name`, if any
    /// handler has been registered under it.
    pub fn get_capabilities(&self, name: &str) -> Option<HandlerCapabilities> {
        self.capabilities
            .get(name)
            .map(|entry| entry.value().clone())
    }

    /// Returns the names whose recorded capabilities satisfy `filter`, in
    /// the map's iteration order.
    pub fn find_by_capabilities(
        &self,
        filter: impl Fn(&HandlerCapabilities) -> bool,
    ) -> Vec<String> {
        self.capabilities
            .iter()
            .filter(|entry| filter(entry.value()))
            .map(|entry| entry.key().clone())
            .collect()
    }

    /// Removes every handler of every kind along with all capability
    /// records. The base registry is not touched.
    pub fn clear_handlers(&self) {
        self.elicitation_handlers.clear();
        self.completion_providers.clear();
        self.template_handlers.clear();
        self.ping_handlers.clear();
        self.capabilities.clear();
    }

    /// Returns the current size of each handler map.
    ///
    /// `total_components` is the number of capability records, i.e. the
    /// number of distinct names that have at least one registration.
    pub fn handler_stats(&self) -> HandlerStats {
        HandlerStats {
            elicitation_handlers: self.elicitation_handlers.len(),
            completion_providers: self.completion_providers.len(),
            template_handlers: self.template_handlers.len(),
            ping_handlers: self.ping_handlers.len(),
            total_components: self.capabilities.len(),
        }
    }

    /// Borrows the wrapped base [`Registry`].
    pub fn base(&self) -> &Registry {
        &self.base
    }
}
impl Default for EnhancedRegistry {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for EnhancedRegistry {
    /// Formats the base registry followed by the entry count of each handler
    /// map; the handlers themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EnhancedRegistry");
        out.field("base", &self.base);
        out.field(
            "elicitation_handlers_count",
            &self.elicitation_handlers.len(),
        );
        out.field(
            "completion_providers_count",
            &self.completion_providers.len(),
        );
        out.field("template_handlers_count", &self.template_handlers.len());
        out.field("ping_handlers_count", &self.ping_handlers.len());
        out.field("capabilities_count", &self.capabilities.len());
        out.finish()
    }
}
/// Point-in-time entry counts for an `EnhancedRegistry`, as produced by its
/// `handler_stats` method.
///
/// Derives `PartialEq`/`Eq` so whole snapshots can be compared directly, and
/// `Default` for an all-zero snapshot.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HandlerStats {
    /// Number of registered elicitation handlers.
    pub elicitation_handlers: usize,
    /// Number of registered completion providers.
    pub completion_providers: usize,
    /// Number of registered resource template handlers.
    pub template_handlers: usize,
    /// Number of registered ping handlers.
    pub ping_handlers: usize,
    /// Number of capability records, i.e. distinct registered names.
    pub total_components: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Result;
    use crate::context::{CompletionContext, CompletionReference, ElicitationContext};
    use crate::handlers::{CompletionItem, ElicitationResponse};
    use std::future::Future;
    use std::pin::Pin;

    /// Elicitation handler that accepts every request with no content.
    struct TestElicitationHandler;

    impl ElicitationHandler for TestElicitationHandler {
        fn handle_elicitation(
            &self,
            _context: &ElicitationContext,
        ) -> Pin<Box<dyn Future<Output = Result<ElicitationResponse>> + Send + '_>> {
            Box::pin(async move {
                Ok(ElicitationResponse {
                    accepted: true,
                    content: None,
                    decline_reason: None,
                })
            })
        }

        fn can_handle(&self, _context: &ElicitationContext) -> bool {
            true
        }
    }

    /// Completion provider that matches every context, returns no items and
    /// reports the wrapped value as its priority.
    struct TestCompletionProvider(i32);

    impl CompletionProvider for TestCompletionProvider {
        fn provide_completions(
            &self,
            _context: &CompletionContext,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<CompletionItem>>> + Send + '_>> {
            Box::pin(async move { Ok(vec![]) })
        }

        fn can_provide(&self, _context: &CompletionContext) -> bool {
            true
        }

        fn priority(&self) -> i32 {
            self.0
        }
    }

    /// Builds the tool-argument completion context shared by the provider tests.
    fn tool_context() -> CompletionContext {
        CompletionContext::new(CompletionReference::Tool {
            name: "test".to_string(),
            argument: "arg".to_string(),
        })
    }

    #[test]
    fn test_enhanced_registry() {
        let registry = EnhancedRegistry::new();
        let handler = Arc::new(TestElicitationHandler);
        registry
            .register_elicitation_handler("test_handler", handler)
            .unwrap();

        assert!(registry.get_elicitation_handler("test_handler").is_some());
        assert_eq!(registry.list_elicitation_handlers(), vec!["test_handler"]);

        let caps = registry.get_capabilities("test_handler").unwrap();
        assert!(caps.elicitation);
        assert!(!caps.completion);
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        let registry = EnhancedRegistry::new();
        registry
            .register_elicitation_handler("dup", Arc::new(TestElicitationHandler))
            .unwrap();

        let err = registry
            .register_elicitation_handler("dup", Arc::new(TestElicitationHandler))
            .unwrap_err();

        assert!(matches!(err, RegistryError::AlreadyExists(name) if name == "dup"));
        // The rejected registration must not add a second entry.
        assert_eq!(registry.handler_stats().elicitation_handlers, 1);
    }

    #[test]
    fn test_completion_provider_priority() {
        let registry = EnhancedRegistry::new();
        let provider = Arc::new(TestCompletionProvider(10));
        registry
            .register_completion_provider("test_provider", provider)
            .unwrap();

        let providers = registry.get_matching_completion_providers(&tool_context());
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].priority(), 10);
    }

    #[test]
    fn test_completion_providers_sorted_by_descending_priority() {
        let registry = EnhancedRegistry::new();
        // Distinct priorities, so the expected order does not depend on the
        // map's iteration order.
        for (name, priority) in [("low", 1), ("high", 10), ("mid", 5)] {
            registry
                .register_completion_provider(name, Arc::new(TestCompletionProvider(priority)))
                .unwrap();
        }

        let priorities: Vec<i32> = registry
            .get_matching_completion_providers(&tool_context())
            .iter()
            .map(|provider| provider.priority())
            .collect();
        assert_eq!(priorities, vec![10, 5, 1]);
    }

    #[test]
    fn test_handler_stats() {
        let registry = EnhancedRegistry::new();
        registry
            .register_elicitation_handler("elicit1", Arc::new(TestElicitationHandler))
            .unwrap();
        registry
            .register_completion_provider("comp1", Arc::new(TestCompletionProvider(10)))
            .unwrap();

        let stats = registry.handler_stats();
        assert_eq!(stats.elicitation_handlers, 1);
        assert_eq!(stats.completion_providers, 1);
        assert_eq!(stats.template_handlers, 0);
        assert_eq!(stats.ping_handlers, 0);
        assert_eq!(stats.total_components, 2);
    }

    #[test]
    fn test_clear_handlers_empties_registry() {
        let registry = EnhancedRegistry::new();
        registry
            .register_elicitation_handler("elicit1", Arc::new(TestElicitationHandler))
            .unwrap();
        registry
            .register_completion_provider("comp1", Arc::new(TestCompletionProvider(10)))
            .unwrap();

        registry.clear_handlers();

        let stats = registry.handler_stats();
        assert_eq!(stats.elicitation_handlers, 0);
        assert_eq!(stats.completion_providers, 0);
        assert_eq!(stats.total_components, 0);
        assert!(registry.get_elicitation_handler("elicit1").is_none());
        assert!(registry.get_capabilities("comp1").is_none());
    }
}