use crate::{Error, Handler, Result};
use rustc_hash::FxHashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
pub struct HandlerRegistry {
handlers: FxHashMap<String, Arc<dyn HandlerEntry>>,
}
trait HandlerEntry: Send + Sync {
fn dispatch(&self, params: &[u8]) -> BoxFuture<'static, Result<Vec<u8>>>;
fn input_schema(&self) -> schemars::schema::RootSchema;
fn output_schema(&self) -> schemars::schema::RootSchema;
}
struct HandlerEntryImpl<H: Handler> {
handler: Arc<H>,
}
impl<H: Handler> HandlerEntryImpl<H> {
fn new(handler: H) -> Self {
Self {
handler: Arc::new(handler),
}
}
}
impl<H> HandlerEntry for HandlerEntryImpl<H>
where
H: Handler,
H::Input: 'static,
H::Output: 'static,
{
fn dispatch(&self, params: &[u8]) -> BoxFuture<'static, Result<Vec<u8>>> {
let input: H::Input = match serde_json::from_slice(params) {
Ok(input) => input,
Err(e) => return Box::pin(async move { Err(e.into()) }),
};
let handler = self.handler.clone();
Box::pin(async move {
let output = handler.handle(input).await.map_err(Into::into)?;
serde_json::to_vec(&output).map_err(Into::into)
})
}
fn input_schema(&self) -> schemars::schema::RootSchema {
H::input_schema()
}
fn output_schema(&self) -> schemars::schema::RootSchema {
H::output_schema()
}
}
impl HandlerRegistry {
pub fn new() -> Self {
Self {
handlers: FxHashMap::default(),
}
}
pub fn register<H>(&mut self, name: impl Into<String>, handler: H)
where
H: Handler,
H::Input: 'static,
H::Output: 'static,
{
let entry = HandlerEntryImpl::new(handler);
self.handlers.insert(name.into(), Arc::new(entry));
}
pub fn has_handler(&self, name: &str) -> bool {
self.handlers.contains_key(name)
}
#[inline(always)]
pub async fn dispatch(&self, tool: &str, params: &[u8]) -> Result<Vec<u8>> {
match self.handlers.get(tool) {
Some(handler) => handler.dispatch(params).await,
None => Err(Error::ToolNotFound(tool.to_string())),
}
}
pub fn len(&self) -> usize {
self.handlers.len()
}
pub fn is_empty(&self) -> bool {
self.handlers.is_empty()
}
pub fn get_input_schema(&self, tool: &str) -> Option<schemars::schema::RootSchema> {
self.handlers.get(tool).map(|h| h.input_schema())
}
pub fn get_output_schema(&self, tool: &str) -> Option<schemars::schema::RootSchema> {
self.handlers.get(tool).map(|h| h.output_schema())
}
}
impl Default for HandlerRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct TestInput {
value: i32,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct TestOutput {
result: i32,
}
struct TestHandler;
#[async_trait]
impl crate::Handler for TestHandler {
type Input = TestInput;
type Output = TestOutput;
type Error = crate::Error;
async fn handle(&self, input: Self::Input) -> Result<Self::Output> {
Ok(TestOutput {
result: input.value * 2,
})
}
}
struct ErrorHandler;
#[async_trait]
impl crate::Handler for ErrorHandler {
type Input = TestInput;
type Output = TestOutput;
type Error = crate::Error;
async fn handle(&self, _input: Self::Input) -> Result<Self::Output> {
Err(crate::Error::Handler("test error".to_string()))
}
}
#[tokio::test]
async fn test_registry_new() {
let registry = HandlerRegistry::new();
assert!(registry.is_empty());
assert_eq!(registry.len(), 0);
}
#[tokio::test]
async fn test_registry_register() {
let mut registry = HandlerRegistry::new();
registry.register("test", TestHandler);
assert!(!registry.is_empty());
assert_eq!(registry.len(), 1);
assert!(registry.has_handler("test"));
assert!(!registry.has_handler("nonexistent"));
}
#[tokio::test]
async fn test_registry_dispatch() {
let mut registry = HandlerRegistry::new();
registry.register("test", TestHandler);
let input = TestInput { value: 21 };
let input_bytes = serde_json::to_vec(&input).unwrap();
let result = registry.dispatch("test", &input_bytes).await;
assert!(result.is_ok());
let output: TestOutput = serde_json::from_slice(&result.unwrap()).unwrap();
assert_eq!(output.result, 42);
}
#[tokio::test]
async fn test_registry_dispatch_tool_not_found() {
let registry = HandlerRegistry::new();
let input = TestInput { value: 21 };
let input_bytes = serde_json::to_vec(&input).unwrap();
let result = registry.dispatch("nonexistent", &input_bytes).await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), crate::Error::ToolNotFound(_)));
}
#[tokio::test]
async fn test_registry_dispatch_invalid_input() {
let mut registry = HandlerRegistry::new();
registry.register("test", TestHandler);
let invalid_input = b"{\"invalid\": \"json\"}";
let result = registry.dispatch("test", invalid_input).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_registry_dispatch_handler_error() {
let mut registry = HandlerRegistry::new();
registry.register("error", ErrorHandler);
let input = TestInput { value: 21 };
let input_bytes = serde_json::to_vec(&input).unwrap();
let result = registry.dispatch("error", &input_bytes).await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), crate::Error::Handler(_)));
}
#[tokio::test]
async fn test_registry_get_schemas() {
let mut registry = HandlerRegistry::new();
registry.register("test", TestHandler);
let input_schema = registry.get_input_schema("test");
assert!(input_schema.is_some());
let output_schema = registry.get_output_schema("test");
assert!(output_schema.is_some());
let missing_schema = registry.get_input_schema("nonexistent");
assert!(missing_schema.is_none());
}
#[tokio::test]
async fn test_registry_multiple_handlers() {
let mut registry = HandlerRegistry::new();
registry.register("handler1", TestHandler);
registry.register("handler2", TestHandler);
registry.register("handler3", TestHandler);
assert_eq!(registry.len(), 3);
assert!(registry.has_handler("handler1"));
assert!(registry.has_handler("handler2"));
assert!(registry.has_handler("handler3"));
}
#[tokio::test]
async fn test_schema_not_default() {
let mut registry = HandlerRegistry::new();
registry.register("test", TestHandler);
let input_schema = registry.get_input_schema("test").unwrap();
let default_schema = schemars::schema::RootSchema::default();
assert_ne!(
serde_json::to_string(&input_schema).unwrap(),
serde_json::to_string(&default_schema).unwrap(),
"Input schema should not be Default::default()"
);
let output_schema = registry.get_output_schema("test").unwrap();
assert_ne!(
serde_json::to_string(&output_schema).unwrap(),
serde_json::to_string(&default_schema).unwrap(),
"Output schema should not be Default::default()"
);
}
#[tokio::test]
async fn test_schema_properties() {
let mut registry = HandlerRegistry::new();
registry.register("test", TestHandler);
let input_schema = registry.get_input_schema("test").unwrap();
assert!(
input_schema.schema.object.is_some(),
"Input schema should have object"
);
let output_schema = registry.get_output_schema("test").unwrap();
assert!(
output_schema.schema.object.is_some(),
"Output schema should have object"
);
}
}