Skip to main content

network_protocol/protocol/
dispatcher.rs

1use crate::error::{constants, ProtocolError, Result};
2use crate::protocol::message::Message;
3use std::borrow::Cow;
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7type HandlerFn = dyn Fn(&Message) -> Result<Message> + Send + Sync + 'static;
8
9/// Message dispatcher with zero-copy opcode routing for statics.
10/// Uses Cow<'static, str> to avoid heap allocations for known message types.
11pub struct Dispatcher {
12    handlers: Arc<RwLock<HashMap<Cow<'static, str>, Box<HandlerFn>>>>,
13}
14
15impl Default for Dispatcher {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl Dispatcher {
22    pub fn new() -> Self {
23        Self {
24            handlers: Arc::new(RwLock::new(HashMap::new())),
25        }
26    }
27
28    pub fn register<F>(&self, opcode: &str, handler: F) -> Result<()>
29    where
30        F: Fn(&Message) -> Result<Message> + Send + Sync + 'static,
31    {
32        let mut handlers = self
33            .handlers
34            .write()
35            .map_err(|_| ProtocolError::Custom(constants::ERR_DISPATCHER_WRITE_LOCK.into()))?;
36
37        handlers.insert(Cow::Owned(opcode.to_string()), Box::new(handler));
38        Ok(())
39    }
40
41    pub fn dispatch(&self, msg: &Message) -> Result<Message> {
42        let opcode = get_opcode(msg);
43
44        let handlers = self
45            .handlers
46            .read()
47            .map_err(|_| ProtocolError::Custom(constants::ERR_DISPATCHER_READ_LOCK.into()))?;
48
49        handlers
50            .get(opcode.as_ref())
51            .ok_or(ProtocolError::UnexpectedMessage)
52            .and_then(|handler| handler(msg))
53    }
54}
55
56/// Determine message type name for routing (zero-copy for known message types).
57/// Returns Cow::Borrowed for static message type opcodes, avoiding allocations in hot path.
58/// For custom commands, returns Cow::Owned since the string comes from the message.
59#[inline]
60fn get_opcode(msg: &Message) -> Cow<'static, str> {
61    match msg {
62        Message::Ping => Cow::Borrowed("PING"),
63        Message::Pong => Cow::Borrowed("PONG"),
64        Message::Echo(_) => Cow::Borrowed("ECHO"),
65        Message::SecureHandshakeInit { .. } => Cow::Borrowed("SEC_HS_INIT"),
66        Message::SecureHandshakeResponse { .. } => Cow::Borrowed("SEC_HS_RESP"),
67        Message::SecureHandshakeConfirm { .. } => Cow::Borrowed("SEC_HS_CONFIRM"),
68        Message::Custom { command, .. } => Cow::Owned(command.clone()),
69        Message::Disconnect => Cow::Borrowed("DISCONNECT"),
70        Message::Unknown => Cow::Borrowed("UNKNOWN"),
71    }
72}