network_protocol/protocol/
dispatcher.rs1use crate::protocol::message::Message;
2use crate::error::{Result, ProtocolError};
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6type HandlerFn = dyn Fn(&Message) -> Result<Message> + Send + Sync + 'static;
7
8pub struct Dispatcher {
9 handlers: Arc<RwLock<HashMap<String, Box<HandlerFn>>>>,
10}
11
12impl Default for Dispatcher {
13 fn default() -> Self {
14 Self::new()
15 }
16}
17
18impl Dispatcher {
19 pub fn new() -> Self {
20 Self {
21 handlers: Arc::new(RwLock::new(HashMap::new())),
22 }
23 }
24
25 pub fn register<F>(&self, opcode: &str, handler: F) -> Result<()>
26 where
27 F: Fn(&Message) -> Result<Message> + Send + Sync + 'static,
28 {
29 match self.handlers.write() {
30 Ok(mut handlers) => {
31 handlers.insert(opcode.to_string(), Box::new(handler));
32 Ok(())
33 },
34 Err(_) => Err(ProtocolError::Custom("Failed to acquire write lock on dispatcher handlers".to_string())),
35 }
36 }
37
38 pub fn dispatch(&self, msg: &Message) -> Result<Message> {
39 let opcode = get_opcode(msg);
40
41 let handlers = match self.handlers.read() {
42 Ok(handlers) => handlers,
43 Err(_) => return Err(ProtocolError::Custom("Failed to acquire read lock on dispatcher handlers".to_string())),
44 };
45
46 match handlers.get(&opcode) {
47 Some(handler) => handler(msg),
48 None => Err(ProtocolError::UnexpectedMessage),
49 }
50 }
51}
52
53fn get_opcode(msg: &Message) -> String {
55 match msg {
56 Message::Ping => "PING",
57 Message::Pong => "PONG",
58 Message::Echo(_) => "ECHO",
59 Message::SecureHandshakeInit { .. } => "SEC_HS_INIT",
61 Message::SecureHandshakeResponse { .. } => "SEC_HS_RESP",
62 Message::SecureHandshakeConfirm { .. } => "SEC_HS_CONFIRM",
63 Message::Custom { command, .. } => command,
64 Message::Disconnect => "DISCONNECT",
65 Message::Unknown => "UNKNOWN",
66 }
67 .to_string()
68}