network_protocol/protocol/
dispatcher.rs1use crate::error::{ProtocolError, Result};
2use crate::protocol::message::Message;
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6type HandlerFn = dyn Fn(&Message) -> Result<Message> + Send + Sync + 'static;
7
8pub struct Dispatcher {
9 handlers: Arc<RwLock<HashMap<String, Box<HandlerFn>>>>,
10}
11
12impl Default for Dispatcher {
13 fn default() -> Self {
14 Self::new()
15 }
16}
17
18impl Dispatcher {
19 pub fn new() -> Self {
20 Self {
21 handlers: Arc::new(RwLock::new(HashMap::new())),
22 }
23 }
24
25 pub fn register<F>(&self, opcode: &str, handler: F) -> Result<()>
26 where
27 F: Fn(&Message) -> Result<Message> + Send + Sync + 'static,
28 {
29 match self.handlers.write() {
30 Ok(mut handlers) => {
31 handlers.insert(opcode.to_string(), Box::new(handler));
32 Ok(())
33 }
34 Err(_) => Err(ProtocolError::Custom(
35 "Failed to acquire write lock on dispatcher handlers".to_string(),
36 )),
37 }
38 }
39
40 pub fn dispatch(&self, msg: &Message) -> Result<Message> {
41 let opcode = get_opcode(msg);
42
43 let handlers = match self.handlers.read() {
44 Ok(handlers) => handlers,
45 Err(_) => {
46 return Err(ProtocolError::Custom(
47 "Failed to acquire read lock on dispatcher handlers".to_string(),
48 ))
49 }
50 };
51
52 match handlers.get(&opcode) {
53 Some(handler) => handler(msg),
54 None => Err(ProtocolError::UnexpectedMessage),
55 }
56 }
57}
58
59fn get_opcode(msg: &Message) -> String {
61 match msg {
62 Message::Ping => "PING",
63 Message::Pong => "PONG",
64 Message::Echo(_) => "ECHO",
65 Message::SecureHandshakeInit { .. } => "SEC_HS_INIT",
67 Message::SecureHandshakeResponse { .. } => "SEC_HS_RESP",
68 Message::SecureHandshakeConfirm { .. } => "SEC_HS_CONFIRM",
69 Message::Custom { command, .. } => command,
70 Message::Disconnect => "DISCONNECT",
71 Message::Unknown => "UNKNOWN",
72 }
73 .to_string()
74}