use crate::Result;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use super::{Protocol, ProtocolRequest, ProtocolResponse, SpecRegistry};
/// A pluggable handler that serves requests for a single [`Protocol`].
///
/// Handlers are stored as `Arc<dyn ProtocolHandler>` in `ProtocolRegistry`,
/// which is why they must be `Send + Sync`. Every method takes `&self`, so any
/// mutable state (enabled flag, configuration) needs interior mutability.
#[async_trait::async_trait]
pub trait ProtocolHandler: Send + Sync {
/// The protocol this handler serves; the registry uses it as the handler's key.
fn protocol(&self) -> Protocol;
/// Whether the handler currently reports itself enabled.
/// `ProtocolRegistry::register_handler` reads this at registration time.
fn is_enabled(&self) -> bool;
/// Expected to change what `is_enabled` subsequently reports (the test mock
/// stores it in a `Mutex<bool>`).
fn set_enabled(&self, enabled: bool);
/// The handler's spec registry, if it has one; `None` otherwise.
fn spec_registry(&self) -> Option<&dyn SpecRegistry>;
/// Handles one request addressed to this handler's protocol.
async fn handle_request(&self, request: ProtocolRequest) -> Result<ProtocolResponse>;
/// Checks the handler's current configuration, returning an error if it is invalid.
fn validate_configuration(&self) -> Result<()>;
/// Returns the current configuration as an owned string-to-string map.
fn get_configuration(&self) -> HashMap<String, String>;
/// Applies `config` to the handler. Whether this replaces or merges the
/// existing configuration is implementation-defined (the test mock replaces it).
fn update_configuration(&self, config: HashMap<String, String>) -> Result<()>;
}
/// Maps each [`Protocol`] to the handler that serves it and tracks which
/// protocols are currently enabled for dispatch.
pub struct ProtocolRegistry {
// At most one handler per protocol; registering again replaces the previous one.
handlers: HashMap<Protocol, Arc<dyn ProtocolHandler>>,
// Protocols that `handle_request` will dispatch; always a subset of `handlers`' keys.
enabled_protocols: HashSet<Protocol>,
}
impl ProtocolRegistry {
pub fn new() -> Self {
Self {
handlers: HashMap::new(),
enabled_protocols: HashSet::new(),
}
}
pub fn register_handler(&mut self, handler: Arc<dyn ProtocolHandler>) -> Result<()> {
let protocol = handler.protocol();
if handler.is_enabled() {
self.enabled_protocols.insert(protocol);
}
self.handlers.insert(protocol, handler);
Ok(())
}
pub fn unregister_handler(&mut self, protocol: Protocol) -> Result<()> {
if self.handlers.remove(&protocol).is_some() {
self.enabled_protocols.remove(&protocol);
Ok(())
} else {
Err(crate::Error::protocol_not_found(protocol.to_string()))
}
}
pub fn get_handler(&self, protocol: Protocol) -> Option<&Arc<dyn ProtocolHandler>> {
self.handlers.get(&protocol)
}
pub fn is_protocol_enabled(&self, protocol: Protocol) -> bool {
self.enabled_protocols.contains(&protocol)
}
pub fn enable_protocol(&mut self, protocol: Protocol) -> Result<()> {
if self.handlers.contains_key(&protocol) {
self.enabled_protocols.insert(protocol);
Ok(())
} else {
Err(crate::Error::protocol_not_found(protocol.to_string()))
}
}
pub fn disable_protocol(&mut self, protocol: Protocol) -> Result<()> {
if self.handlers.contains_key(&protocol) {
self.enabled_protocols.remove(&protocol);
Ok(())
} else {
Err(crate::Error::protocol_not_found(protocol.to_string()))
}
}
pub fn registered_protocols(&self) -> Vec<Protocol> {
self.handlers.keys().cloned().collect()
}
pub fn enabled_protocols(&self) -> Vec<Protocol> {
self.enabled_protocols.iter().cloned().collect()
}
pub async fn handle_request(&self, request: ProtocolRequest) -> Result<ProtocolResponse> {
let protocol = request.protocol;
if !self.is_protocol_enabled(protocol) {
return Err(crate::Error::protocol_disabled(protocol.to_string()));
}
if let Some(handler) = self.get_handler(protocol) {
handler.handle_request(request).await
} else {
Err(crate::Error::protocol_not_found(protocol.to_string()))
}
}
pub fn validate_all_handlers(&self) -> Result<()> {
for (protocol, handler) in &self.handlers {
if let Err(e) = handler.validate_configuration() {
return Err(crate::Error::protocol_validation_error(
protocol.to_string(),
e.to_string(),
));
}
}
Ok(())
}
pub fn get_all_configurations(&self) -> HashMap<Protocol, HashMap<String, String>> {
self.handlers
.iter()
.map(|(protocol, handler)| (*protocol, handler.get_configuration()))
.collect()
}
pub fn update_protocol_configuration(
&self,
protocol: Protocol,
config: HashMap<String, String>,
) -> Result<()> {
if let Some(handler) = self.handlers.get(&protocol) {
handler.update_configuration(config)
} else {
Err(crate::Error::protocol_not_found(protocol.to_string()))
}
}
}
impl Default for ProtocolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Mutex;

    /// In-memory handler that always succeeds and stores configuration
    /// verbatim so tests can observe updates.
    struct MockProtocolHandler {
        protocol: Protocol,
        enabled: Mutex<bool>,
        config: Mutex<HashMap<String, String>>,
    }

    impl MockProtocolHandler {
        /// Creates an enabled mock with an empty configuration.
        fn new(protocol: Protocol) -> Self {
            Self {
                protocol,
                enabled: Mutex::new(true),
                config: Mutex::new(HashMap::new()),
            }
        }
    }

    #[async_trait]
    impl ProtocolHandler for MockProtocolHandler {
        fn protocol(&self) -> Protocol {
            self.protocol
        }

        fn is_enabled(&self) -> bool {
            *self.enabled.lock().unwrap()
        }

        fn set_enabled(&self, enabled: bool) {
            *self.enabled.lock().unwrap() = enabled;
        }

        fn spec_registry(&self) -> Option<&dyn SpecRegistry> {
            None
        }

        async fn handle_request(&self, _request: ProtocolRequest) -> Result<ProtocolResponse> {
            Ok(ProtocolResponse {
                status: super::super::ResponseStatus::HttpStatus(200),
                metadata: HashMap::new(),
                body: b"mock response".to_vec(),
                content_type: "text/plain".to_string(),
            })
        }

        fn validate_configuration(&self) -> Result<()> {
            Ok(())
        }

        fn get_configuration(&self) -> HashMap<String, String> {
            self.config.lock().unwrap().clone()
        }

        fn update_configuration(&self, config: HashMap<String, String>) -> Result<()> {
            *self.config.lock().unwrap() = config;
            Ok(())
        }
    }

    /// Builds a registry containing a single enabled HTTP mock handler.
    fn registry_with_http() -> ProtocolRegistry {
        let mut registry = ProtocolRegistry::new();
        registry
            .register_handler(Arc::new(MockProtocolHandler::new(Protocol::Http)))
            .unwrap();
        registry
    }

    /// A simple `GET /test` HTTP request.
    fn http_get_request() -> ProtocolRequest {
        ProtocolRequest {
            protocol: Protocol::Http,
            operation: "GET".to_string(),
            path: "/test".to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_protocol_registry_creation() {
        let registry = ProtocolRegistry::new();
        assert_eq!(registry.registered_protocols().len(), 0);
        assert_eq!(registry.enabled_protocols().len(), 0);
    }

    #[test]
    fn test_register_handler() {
        let mut registry = ProtocolRegistry::new();
        let handler = Arc::new(MockProtocolHandler::new(Protocol::Http));
        assert!(registry.register_handler(handler).is_ok());
        assert_eq!(registry.registered_protocols(), vec![Protocol::Http]);
        assert_eq!(registry.enabled_protocols(), vec![Protocol::Http]);
    }

    #[test]
    fn test_register_disabled_handler() {
        let mut registry = ProtocolRegistry::new();
        let handler = MockProtocolHandler::new(Protocol::Http);
        handler.set_enabled(false);
        registry.register_handler(Arc::new(handler)).unwrap();
        assert_eq!(registry.registered_protocols(), vec![Protocol::Http]);
        assert!(registry.enabled_protocols().is_empty());
    }

    #[test]
    fn test_unregister_handler() {
        let mut registry = registry_with_http();
        registry.unregister_handler(Protocol::Http).unwrap();
        assert!(registry.registered_protocols().is_empty());
        assert!(!registry.is_protocol_enabled(Protocol::Http));
        // A second removal has nothing to remove.
        assert!(registry.unregister_handler(Protocol::Http).is_err());
    }

    #[test]
    fn test_enable_disable_protocol() {
        let mut registry = registry_with_http();
        assert!(registry.is_protocol_enabled(Protocol::Http));
        registry.disable_protocol(Protocol::Http).unwrap();
        assert!(!registry.is_protocol_enabled(Protocol::Http));
        registry.enable_protocol(Protocol::Http).unwrap();
        assert!(registry.is_protocol_enabled(Protocol::Http));
    }

    #[test]
    fn test_enable_disable_unknown_protocol() {
        let mut registry = ProtocolRegistry::new();
        assert!(registry.enable_protocol(Protocol::Http).is_err());
        assert!(registry.disable_protocol(Protocol::Http).is_err());
        assert!(!registry.is_protocol_enabled(Protocol::Http));
    }

    #[test]
    fn test_handle_request() {
        let registry = registry_with_http();
        let result = futures::executor::block_on(registry.handle_request(http_get_request()));
        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.body, b"mock response");
    }

    #[test]
    fn test_handle_request_disabled_protocol() {
        let mut registry = registry_with_http();
        registry.disable_protocol(Protocol::Http).unwrap();
        let result = futures::executor::block_on(registry.handle_request(http_get_request()));
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_all_handlers() {
        let registry = registry_with_http();
        assert!(registry.validate_all_handlers().is_ok());
    }

    #[test]
    fn test_update_protocol_configuration() {
        let registry = registry_with_http();
        let mut config = HashMap::new();
        config.insert("timeout".to_string(), "30".to_string());
        config.insert("retries".to_string(), "3".to_string());
        registry
            .update_protocol_configuration(Protocol::Http, config.clone())
            .unwrap();
        let configs = registry.get_all_configurations();
        let http_config = configs.get(&Protocol::Http).unwrap();
        assert_eq!(http_config.get("timeout").unwrap(), "30");
        assert_eq!(http_config.get("retries").unwrap(), "3");
    }

    #[test]
    fn test_update_protocol_configuration_not_found() {
        let registry = ProtocolRegistry::new();
        let result = registry.update_protocol_configuration(Protocol::Http, HashMap::new());
        assert!(result.is_err());
    }
}