use crate::{A2AError, A2AResult, AgentCard};
use protocol_transport_core::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// Shared, thread-safe callback that consumes a JSON-RPC request and yields its response.
pub type A2AMethodHandler =
    Arc<dyn Fn(JsonRpcRequest) -> A2AResult<JsonRpcResponse> + Sync + Send>;
/// Shared, thread-safe callback that consumes a JSON-RPC notification (no response payload).
pub type A2ANotificationHandler =
    Arc<dyn Fn(JsonRpcNotification) -> A2AResult<()> + Sync + Send>;
/// Descriptive information recorded for every registered method or notification.
///
/// The registry stores these values but never interprets `parameters` or `returns`.
#[derive(Clone, Debug)]
pub struct MethodMetadata {
    /// Name the entry was registered under.
    pub name: String,
    /// Human-readable description supplied at registration time.
    pub description: String,
    /// Optional JSON description of the accepted parameters (presumably a
    /// JSON-Schema-like object — confirm with callers); `None` when not supplied.
    pub parameters: Option<Value>,
    /// Optional JSON description of the result; `None` when not supplied.
    pub returns: Option<Value>,
    /// `true` for notification entries, `false` for request/response methods.
    pub is_notification: bool,
}
/// Registry that maps JSON-RPC method names to request and notification handlers,
/// together with descriptive metadata and an optional agent card.
pub struct A2AMethodRegistry {
    /// Request/response handlers, keyed by method name.
    method_handlers: HashMap<String, A2AMethodHandler>,
    /// Notification handlers, keyed by method name.
    notification_handlers: HashMap<String, A2ANotificationHandler>,
    /// Metadata for both kinds of registration, keyed by method name.
    method_metadata: HashMap<String, MethodMetadata>,
    /// Agent card associated with this registry, if any.
    agent_card: Option<AgentCard>,
}
impl A2AMethodRegistry {
pub fn new() -> Self {
Self {
method_handlers: HashMap::new(),
notification_handlers: HashMap::new(),
method_metadata: HashMap::new(),
agent_card: None,
}
}
pub fn with_agent_card(agent_card: AgentCard) -> Self {
let mut registry = Self::new();
registry.set_agent_card(agent_card);
registry
}
pub fn set_agent_card(&mut self, agent_card: AgentCard) {
self.agent_card = Some(agent_card);
}
pub fn agent_card(&self) -> Option<&AgentCard> {
self.agent_card.as_ref()
}
pub fn register_method(
&mut self,
method: impl Into<String>,
description: impl Into<String>,
handler: A2AMethodHandler,
) {
let method = method.into();
let description = description.into();
self.method_handlers.insert(method.clone(), handler);
self.method_metadata.insert(
method.clone(),
MethodMetadata {
name: method,
description,
parameters: None,
returns: None,
is_notification: false,
},
);
}
pub fn register_method_with_metadata(
&mut self,
method: impl Into<String>,
description: impl Into<String>,
parameters: Option<Value>,
returns: Option<Value>,
handler: A2AMethodHandler,
) {
let method = method.into();
let description = description.into();
self.method_handlers.insert(method.clone(), handler);
self.method_metadata.insert(
method.clone(),
MethodMetadata {
name: method,
description,
parameters,
returns,
is_notification: false,
},
);
}
pub fn register_notification(
&mut self,
method: impl Into<String>,
description: impl Into<String>,
handler: A2ANotificationHandler,
) {
let method = method.into();
let description = description.into();
self.notification_handlers.insert(method.clone(), handler);
self.method_metadata.insert(
method.clone(),
MethodMetadata {
name: method,
description,
parameters: None,
returns: None,
is_notification: true,
},
);
}
pub fn handle_request(&self, request: JsonRpcRequest) -> A2AResult<JsonRpcResponse> {
if !request.is_valid() {
return Err(A2AError::protocol_validation_error(
"Invalid JSON-RPC request",
));
}
let handler = self
.method_handlers
.get(&request.method)
.ok_or_else(|| A2AError::method_not_found(&request.method))?;
let method = request.method.clone();
handler(request).map_err(|e| A2AError::method_execution_failed(&method, e.to_string()))
}
pub fn handle_notification(&self, notification: JsonRpcNotification) -> A2AResult<()> {
if !notification.is_valid() {
return Err(A2AError::protocol_validation_error(
"Invalid JSON-RPC notification",
));
}
let handler = self
.notification_handlers
.get(¬ification.method)
.ok_or_else(|| A2AError::method_not_found(¬ification.method))?;
let method = notification.method.clone();
handler(notification).map_err(|e| A2AError::method_execution_failed(&method, e.to_string()))
}
pub fn has_method(&self, method: &str) -> bool {
self.method_handlers.contains_key(method)
}
pub fn has_notification(&self, method: &str) -> bool {
self.notification_handlers.contains_key(method)
}
pub fn get_method_metadata(&self, method: &str) -> Option<&MethodMetadata> {
self.method_metadata.get(method)
}
pub fn list_methods(&self) -> Vec<String> {
self.method_handlers.keys().cloned().collect()
}
pub fn list_notifications(&self) -> Vec<String> {
self.notification_handlers.keys().cloned().collect()
}
pub fn get_all_metadata(&self) -> &HashMap<String, MethodMetadata> {
&self.method_metadata
}
pub fn unregister_method(&mut self, method: &str) -> bool {
let removed_handler = self.method_handlers.remove(method).is_some();
self.method_metadata.remove(method);
removed_handler
}
pub fn unregister_notification(&mut self, method: &str) -> bool {
let removed_handler = self.notification_handlers.remove(method).is_some();
self.method_metadata.remove(method);
removed_handler
}
pub fn clear(&mut self) {
self.method_handlers.clear();
self.notification_handlers.clear();
self.method_metadata.clear();
}
pub fn stats(&self) -> RegistryStats {
RegistryStats {
method_count: self.method_handlers.len(),
notification_count: self.notification_handlers.len(),
total_registrations: self.method_metadata.len(),
}
}
}
impl Default for A2AMethodRegistry {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of registration counts produced by [`A2AMethodRegistry::stats`].
///
/// `PartialEq`/`Eq` allow whole snapshots to be compared directly (e.g. in tests).
/// `Default` yields the all-zero snapshot of an empty registry.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RegistryStats {
    /// Number of registered request/response methods.
    pub method_count: usize,
    /// Number of registered notifications.
    pub notification_count: usize,
    /// Overall registration count as reported by the registry.
    pub total_registrations: usize,
}
// Unit tests for A2AMethodRegistry: registration, dispatch, error mapping,
// agent-card accessors and registry bookkeeping (stats / unregister / clear).
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Registering a method exposes it via has_method/list_methods and records
// non-notification metadata with the given name and description.
#[test]
fn test_method_registration() {
let mut registry = A2AMethodRegistry::new();
registry.register_method(
"ping",
"Simple ping method",
Arc::new(|request| Ok(JsonRpcResponse::success(request.id, json!({"pong": true})))),
);
assert!(registry.has_method("ping"));
assert!(!registry.has_method("unknown"));
assert_eq!(registry.list_methods(), vec!["ping"]);
let metadata = registry.get_method_metadata("ping").unwrap();
assert_eq!(metadata.name, "ping");
assert_eq!(metadata.description, "Simple ping method");
assert!(!metadata.is_notification);
}
// Registering a notification exposes it via has_notification/list_notifications
// and records metadata flagged as a notification.
#[test]
fn test_notification_registration() {
let mut registry = A2AMethodRegistry::new();
registry.register_notification(
"log.info",
"Info log notification",
Arc::new(|_notification| Ok(())),
);
assert!(registry.has_notification("log.info"));
assert!(!registry.has_notification("unknown"));
assert_eq!(registry.list_notifications(), vec!["log.info"]);
let metadata = registry.get_method_metadata("log.info").unwrap();
assert_eq!(metadata.name, "log.info");
assert_eq!(metadata.description, "Info log notification");
assert!(metadata.is_notification);
}
// A known method is dispatched and its response returned unchanged; an unknown
// method yields MethodNotFound carrying the requested name.
#[test]
fn test_request_handling() {
let mut registry = A2AMethodRegistry::new();
registry.register_method(
"ping",
"Simple ping method",
Arc::new(|request| Ok(JsonRpcResponse::success(request.id, json!({"pong": true})))),
);
let request = JsonRpcRequest::new(json!("req-123"), "ping".to_string(), json!({}));
let response = registry.handle_request(request).unwrap();
assert!(response.is_success());
assert_eq!(response.id, json!("req-123"));
assert_eq!(response.result.unwrap()["pong"], true);
let unknown_request =
JsonRpcRequest::new(json!("req-456"), "unknown".to_string(), json!({}));
let result = registry.handle_request(unknown_request);
assert!(result.is_err());
if let Err(A2AError::MethodNotFound { method }) = result {
assert_eq!(method, "unknown");
} else {
panic!("Expected MethodNotFound error");
}
}
// Known notifications succeed; unknown notifications are rejected with an error.
#[test]
fn test_notification_handling() {
let mut registry = A2AMethodRegistry::new();
registry.register_notification(
"log.info",
"Info log notification",
Arc::new(|_notification| Ok(())),
);
let notification = JsonRpcNotification::new("log.info".to_string(), json!({"msg": "test"}));
let result = registry.handle_notification(notification);
assert!(result.is_ok());
let unknown_notification = JsonRpcNotification::new("unknown".to_string(), json!({}));
let result = registry.handle_notification(unknown_notification);
assert!(result.is_err());
}
// stats() tracks registrations through unregister_method and clear();
// unregistering an unknown name reports false.
#[test]
fn test_registry_management() {
let mut registry = A2AMethodRegistry::new();
registry.register_method(
"ping",
"Simple ping",
Arc::new(|request| Ok(JsonRpcResponse::success(request.id, json!({})))),
);
registry.register_notification("log", "Logger", Arc::new(|_| Ok(())));
let stats = registry.stats();
assert_eq!(stats.method_count, 1);
assert_eq!(stats.notification_count, 1);
assert_eq!(stats.total_registrations, 2);
assert!(registry.unregister_method("ping"));
assert!(!registry.unregister_method("unknown"));
let stats = registry.stats();
assert_eq!(stats.method_count, 0);
assert_eq!(stats.notification_count, 1);
assert_eq!(stats.total_registrations, 1);
registry.clear();
let stats = registry.stats();
assert_eq!(stats.total_registrations, 0);
}
// Parameter and return descriptions passed to register_method_with_metadata
// are stored on the method's metadata.
#[test]
fn test_method_with_metadata() {
let mut registry = A2AMethodRegistry::new();
registry.register_method_with_metadata(
"calculate",
"Calculate something",
Some(json!({"type": "object", "properties": {"a": {"type": "number"}}})),
Some(json!({"type": "number"})),
Arc::new(|request| Ok(JsonRpcResponse::success(request.id, json!(42)))),
);
let metadata = registry.get_method_metadata("calculate").unwrap();
assert!(metadata.parameters.is_some());
assert!(metadata.returns.is_some());
assert!(!metadata.is_notification);
}
// with_agent_card stores the supplied card.
#[test]
fn test_with_agent_card_constructor() {
let agent_card =
AgentCard::new("test-agent".to_string()).with_capability("test", "Test capability");
let registry = A2AMethodRegistry::with_agent_card(agent_card.clone());
assert!(registry.agent_card().is_some());
assert_eq!(registry.agent_card().unwrap().name, "test-agent");
}
// A new registry has no agent card until set_agent_card is called.
#[test]
fn test_agent_card_getter() {
let mut registry = A2AMethodRegistry::new();
assert!(registry.agent_card().is_none());
let agent_card = AgentCard::new("test-agent".to_string());
registry.set_agent_card(agent_card.clone());
assert!(registry.agent_card().is_some());
assert_eq!(registry.agent_card().unwrap().name, "test-agent");
}
// get_all_metadata covers both methods and notifications with the correct kind flag.
#[test]
fn test_get_all_metadata() {
let mut registry = A2AMethodRegistry::new();
registry.register_method(
"method1",
"First method",
Arc::new(|request| Ok(JsonRpcResponse::success(request.id, json!({})))),
);
registry.register_notification("notif1", "First notification", Arc::new(|_| Ok(())));
let all_metadata = registry.get_all_metadata();
assert_eq!(all_metadata.len(), 2);
assert!(all_metadata.contains_key("method1"));
assert!(all_metadata.contains_key("notif1"));
assert!(!all_metadata["method1"].is_notification);
assert!(all_metadata["notif1"].is_notification);
}
// unregister_notification removes a registered notification and reports false
// for an unknown name.
#[test]
fn test_unregister_notification() {
let mut registry = A2AMethodRegistry::new();
registry.register_notification("test_notif", "Test notification", Arc::new(|_| Ok(())));
assert!(registry.has_notification("test_notif"));
assert!(registry.unregister_notification("test_notif"));
assert!(!registry.has_notification("test_notif"));
assert!(!registry.unregister_notification("unknown_notif"));
}
// Default produces an empty registry without an agent card.
#[test]
fn test_default_impl() {
let registry = A2AMethodRegistry::default();
assert_eq!(registry.list_methods().len(), 0);
assert_eq!(registry.list_notifications().len(), 0);
assert!(registry.agent_card().is_none());
}
// A request with an empty method name is rejected with
// ProtocolValidationError before handler lookup.
// NOTE(review): assumes JsonRpcRequest::is_valid rejects empty method names — confirm.
#[test]
fn test_invalid_request_validation() {
let mut registry = A2AMethodRegistry::new();
registry.register_method(
"test",
"Test method",
Arc::new(|request| Ok(JsonRpcResponse::success(request.id, json!({})))),
);
let mut invalid_request =
JsonRpcRequest::new(json!("req-1"), "test".to_string(), json!({}));
invalid_request.method = "".to_string();
let result = registry.handle_request(invalid_request);
assert!(result.is_err());
match result {
Err(A2AError::ProtocolValidationError { .. }) => {
}
_ => panic!("Expected ProtocolValidationError"),
}
}
// Same validation check for notifications with an empty method name.
#[test]
fn test_invalid_notification_validation() {
let mut registry = A2AMethodRegistry::new();
registry.register_notification("test", "Test notification", Arc::new(|_| Ok(())));
let mut invalid_notification = JsonRpcNotification::new("test".to_string(), json!({}));
invalid_notification.method = "".to_string();
let result = registry.handle_notification(invalid_notification);
assert!(result.is_err());
match result {
Err(A2AError::ProtocolValidationError { .. }) => {
}
_ => panic!("Expected ProtocolValidationError"),
}
}
// A failing method handler's error is wrapped as MethodExecutionFailed naming the method.
#[test]
fn test_method_execution_error_handling() {
let mut registry = A2AMethodRegistry::new();
registry.register_method(
"failing_method",
"Method that fails",
Arc::new(|_| Err(A2AError::internal("Test method error"))),
);
let request = JsonRpcRequest::new(json!("req-1"), "failing_method".to_string(), json!({}));
let result = registry.handle_request(request);
assert!(result.is_err());
match result {
Err(A2AError::MethodExecutionFailed { method, .. }) => {
assert_eq!(method, "failing_method");
}
_ => panic!("Expected MethodExecutionFailed error"),
}
}
// A failing notification handler's error is wrapped the same way.
#[test]
fn test_notification_execution_error_handling() {
let mut registry = A2AMethodRegistry::new();
registry.register_notification(
"failing_notification",
"Notification that fails",
Arc::new(|_| Err(A2AError::internal("Test notification error"))),
);
let notification = JsonRpcNotification::new("failing_notification".to_string(), json!({}));
let result = registry.handle_notification(notification);
assert!(result.is_err());
match result {
Err(A2AError::MethodExecutionFailed { method, .. }) => {
assert_eq!(method, "failing_notification");
}
_ => panic!("Expected MethodExecutionFailed error"),
}
}
// The handler receives the original request: echoing params returns them in the result.
#[test]
fn test_method_handler_closure_execution() {
let mut registry = A2AMethodRegistry::new();
registry.register_method(
"echo",
"Echo method",
Arc::new(|request| Ok(JsonRpcResponse::success(request.id, request.params))),
);
let request =
JsonRpcRequest::new(json!("req-1"), "echo".to_string(), json!({"data": "test"}));
let response = registry.handle_request(request).unwrap();
assert!(response.is_success());
assert_eq!(response.result.unwrap()["data"], "test");
}
// A notification handler that returns Ok makes handle_notification succeed.
#[test]
fn test_notification_handler_closure_execution() {
let mut registry = A2AMethodRegistry::new();
registry.register_notification(
"log",
"Log notification",
Arc::new(|_| {
Ok(())
}),
);
let notification = JsonRpcNotification::new("log".to_string(), json!({"message": "test"}));
let result = registry.handle_notification(notification);
assert!(result.is_ok());
}
// End-to-end scenario: agent card, mixed registrations, listing, metadata,
// dispatch, unregistering, and the resulting stats.
#[test]
fn test_comprehensive_registry_operations() {
let mut registry = A2AMethodRegistry::new();
let agent_card = AgentCard::new("comprehensive-test".to_string())
.with_capability("method1", "First method")
.with_capability("method2", "Second method");
registry.set_agent_card(agent_card);
registry.register_method(
"method1",
"First method",
Arc::new(|request| Ok(JsonRpcResponse::success(request.id, json!({"result": 1})))),
);
registry.register_method_with_metadata(
"method2",
"Second method with metadata",
Some(json!({"type": "object"})),
Some(json!({"type": "number"})),
Arc::new(|request| Ok(JsonRpcResponse::success(request.id, json!({"result": 2})))),
);
registry.register_notification("notif1", "First notification", Arc::new(|_| Ok(())));
registry.register_notification("notif2", "Second notification", Arc::new(|_| Ok(())));
let stats = registry.stats();
assert_eq!(stats.method_count, 2);
assert_eq!(stats.notification_count, 2);
assert_eq!(stats.total_registrations, 4);
let methods = registry.list_methods();
assert_eq!(methods.len(), 2);
assert!(methods.contains(&"method1".to_string()));
assert!(methods.contains(&"method2".to_string()));
let notifications = registry.list_notifications();
assert_eq!(notifications.len(), 2);
assert!(notifications.contains(&"notif1".to_string()));
assert!(notifications.contains(&"notif2".to_string()));
let metadata = registry.get_method_metadata("method2").unwrap();
assert!(metadata.parameters.is_some());
assert!(metadata.returns.is_some());
let request = JsonRpcRequest::new(json!("req-1"), "method1".to_string(), json!({}));
let response = registry.handle_request(request).unwrap();
assert_eq!(response.result.unwrap()["result"], 1);
let notification = JsonRpcNotification::new("notif1".to_string(), json!({}));
let result = registry.handle_notification(notification);
assert!(result.is_ok());
assert!(registry.unregister_method("method1"));
assert_eq!(registry.list_methods().len(), 1);
assert!(registry.unregister_notification("notif1"));
assert_eq!(registry.list_notifications().len(), 1);
let final_stats = registry.stats();
assert_eq!(final_stats.method_count, 1);
assert_eq!(final_stats.notification_count, 1);
assert_eq!(final_stats.total_registrations, 2);
}
}