use crate::error::{constants, ProtocolError, Result};
use crate::protocol::message::Message;
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Signature shared by every registered message handler.
///
/// A handler receives the incoming [`Message`] and returns either a reply
/// [`Message`] or an error. The `Send + Sync + 'static` bounds let handlers be
/// stored in the shared table owned by [`Dispatcher`].
type HandlerFn = dyn Fn(&Message) -> Result<Message> + Send + Sync + 'static;

/// Routes incoming [`Message`]s to handlers registered under an opcode string.
///
/// The handler table sits behind an `Arc<RwLock<..>>`, so both registration and
/// dispatch work through a shared `&self`.
pub struct Dispatcher {
    // Each handler is reference-counted so `dispatch` can clone it out of the
    // table and drop the read guard before running it.
    handlers: Arc<RwLock<HashMap<Cow<'static, str>, Arc<HandlerFn>>>>,
}

impl Default for Dispatcher {
    fn default() -> Self {
        Self::new()
    }
}

impl Dispatcher {
    /// Creates a dispatcher with an empty handler table.
    pub fn new() -> Self {
        Self {
            handlers: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `handler` for messages whose opcode equals `opcode`.
    ///
    /// Registering again under an opcode that already has a handler replaces
    /// the previous handler.
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::Custom` (carrying `ERR_DISPATCHER_WRITE_LOCK`) if
    /// the handler table's lock cannot be acquired for writing (poisoned).
    pub fn register<F>(&self, opcode: &str, handler: F) -> Result<()>
    where
        F: Fn(&Message) -> Result<Message> + Send + Sync + 'static,
    {
        let mut handlers = self
            .handlers
            .write()
            .map_err(|_| ProtocolError::Custom(constants::ERR_DISPATCHER_WRITE_LOCK.into()))?;
        handlers.insert(Cow::Owned(opcode.to_string()), Arc::new(handler));
        Ok(())
    }

    /// Invokes the handler registered for `msg`'s opcode and returns its reply.
    ///
    /// The read lock is held only while the handler is looked up; it is
    /// released before the handler runs. A handler may therefore call
    /// `register` or `dispatch` on this same dispatcher without re-acquiring
    /// the `std::sync::RwLock` on the current thread (which may deadlock or
    /// panic), and a long-running handler does not block registrations.
    ///
    /// # Errors
    ///
    /// - `ProtocolError::Custom` (carrying `ERR_DISPATCHER_READ_LOCK`) if the
    ///   lock cannot be acquired for reading (poisoned).
    /// - `ProtocolError::UnexpectedMessage` if no handler is registered for the
    ///   message's opcode.
    /// - Whatever error the handler itself returns.
    pub fn dispatch(&self, msg: &Message) -> Result<Message> {
        let opcode = get_opcode(msg);
        // Clone the `Arc` out in its own scope so the guard drops before the call.
        let handler = {
            let handlers = self
                .handlers
                .read()
                .map_err(|_| ProtocolError::Custom(constants::ERR_DISPATCHER_READ_LOCK.into()))?;
            handlers.get(&*opcode).cloned()
        };
        let handler = handler.ok_or(ProtocolError::UnexpectedMessage)?;
        handler(msg)
    }
}
/// Maps `msg` to the opcode string its handler is registered under.
///
/// Built-in variants map to fixed upper-case names; `Message::Custom` uses its
/// `command` field verbatim. The result borrows from `msg` instead of cloning
/// `command`, so dispatching a custom message does not allocate.
fn get_opcode(msg: &Message) -> Cow<'_, str> {
    let opcode: &str = match msg {
        Message::Ping => "PING",
        Message::Pong => "PONG",
        Message::Echo(_) => "ECHO",
        Message::SecureHandshakeInit { .. } => "SEC_HS_INIT",
        Message::SecureHandshakeResponse { .. } => "SEC_HS_RESP",
        Message::SecureHandshakeConfirm { .. } => "SEC_HS_CONFIRM",
        // Borrowed: the opcode is only needed for a lookup while `msg` is alive.
        Message::Custom { command, .. } => command.as_str(),
        Message::Disconnect => "DISCONNECT",
        Message::Unknown => "UNKNOWN",
    };
    // Kept as a `Cow` so existing call sites (e.g. `.as_ref()`) infer the same
    // lookup key type as before.
    Cow::Borrowed(opcode)
}