ironsbe_server/
dispatcher.rs1use crate::handler::{MessageHandler, Responder, TypedHandler};
4use ironsbe_core::header::MessageHeader;
5use std::collections::HashMap;
6use std::sync::Arc;
7
8pub struct MessageDispatcher {
10 handlers: HashMap<u16, Arc<dyn TypedHandler>>,
11 default_handler: Option<Arc<dyn MessageHandler>>,
12}
13
14impl MessageDispatcher {
15 #[must_use]
17 pub fn new() -> Self {
18 Self {
19 handlers: HashMap::new(),
20 default_handler: None,
21 }
22 }
23
24 pub fn register<H: TypedHandler + 'static>(&mut self, template_id: u16, handler: H) {
26 self.handlers.insert(template_id, Arc::new(handler));
27 }
28
29 pub fn set_default<H: MessageHandler + 'static>(&mut self, handler: H) {
31 self.default_handler = Some(Arc::new(handler));
32 }
33
34 #[must_use]
36 pub fn has_handler(&self, template_id: u16) -> bool {
37 self.handlers.contains_key(&template_id)
38 }
39}
40
41impl Default for MessageDispatcher {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl MessageHandler for MessageDispatcher {
48 fn on_message(
49 &self,
50 session_id: u64,
51 header: &MessageHeader,
52 buffer: &[u8],
53 responder: &dyn Responder,
54 ) {
55 let template_id = { header.template_id };
56 if let Some(handler) = self.handlers.get(&template_id) {
57 handler.handle(session_id, buffer, responder);
58 } else if let Some(default) = &self.default_handler {
59 default.on_message(session_id, header, buffer, responder);
60 } else {
61 tracing::warn!(
62 "No handler for template_id={} from session={}",
63 template_id,
64 session_id
65 );
66 }
67 }
68
69 fn on_session_start(&self, session_id: u64) {
70 if let Some(default) = &self.default_handler {
71 default.on_session_start(session_id);
72 }
73 }
74
75 fn on_session_end(&self, session_id: u64) {
76 if let Some(default) = &self.default_handler {
77 default.on_session_end(session_id);
78 }
79 }
80
81 fn on_error(&self, session_id: u64, error: &str) {
82 if let Some(default) = &self.default_handler {
83 default.on_error(session_id, error);
84 }
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::{FnHandler, SendError};
    use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};

    /// Responder stub: accepts every message and always reports success.
    struct MockResponder;

    impl Responder for MockResponder {
        fn send(&self, _message: &[u8]) -> Result<(), SendError> {
            Ok(())
        }

        fn send_to(&self, _session_id: u64, _message: &[u8]) -> Result<(), SendError> {
            Ok(())
        }
    }

    /// Default handler that records the last session id seen on start/end.
    struct TestDefaultHandler {
        session_started: Arc<AtomicU64>,
        session_ended: Arc<AtomicU64>,
    }

    impl MessageHandler for TestDefaultHandler {
        fn on_message(
            &self,
            _session_id: u64,
            _header: &MessageHeader,
            _buffer: &[u8],
            _responder: &dyn Responder,
        ) {
        }

        fn on_session_start(&self, session_id: u64) {
            self.session_started.store(session_id, Ordering::SeqCst);
        }

        fn on_session_end(&self, session_id: u64) {
            self.session_ended.store(session_id, Ordering::SeqCst);
        }
    }

    /// Default handler that raises a flag when `on_error` is invoked.
    struct ErrorHandler {
        error_received: Arc<AtomicBool>,
    }

    impl MessageHandler for ErrorHandler {
        fn on_message(
            &self,
            _session_id: u64,
            _header: &MessageHeader,
            _buffer: &[u8],
            _responder: &dyn Responder,
        ) {
        }

        fn on_error(&self, _session_id: u64, _error: &str) {
            self.error_received.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_dispatcher_new() {
        let fresh = MessageDispatcher::new();
        assert!(!fresh.has_handler(1));
    }

    #[test]
    fn test_dispatcher_default() {
        let fresh = MessageDispatcher::default();
        assert!(!fresh.has_handler(1));
    }

    #[test]
    fn test_dispatcher_register() {
        let mut dispatcher = MessageDispatcher::new();
        dispatcher.register(1, FnHandler::new(|_session_id, _buffer, _responder| {}));

        // Only the registered template id is reported as handled.
        assert!(dispatcher.has_handler(1));
        assert!(!dispatcher.has_handler(2));
    }

    #[test]
    fn test_dispatcher_on_message_with_handler() {
        let invoked = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&invoked);

        let mut dispatcher = MessageDispatcher::new();
        dispatcher.register(
            1,
            FnHandler::new(move |_session_id, _buffer, _responder| {
                flag.store(true, Ordering::SeqCst);
            }),
        );

        let header = MessageHeader::new(16, 1, 100, 1);
        let payload = [0u8; 24];
        dispatcher.on_message(1, &header, &payload, &MockResponder);

        assert!(invoked.load(Ordering::SeqCst));
    }

    #[test]
    fn test_dispatcher_on_message_no_handler() {
        let dispatcher = MessageDispatcher::new();
        let header = MessageHeader::new(16, 99, 100, 1);
        let payload = [0u8; 24];

        // No typed or default handler is installed; the call must simply return.
        dispatcher.on_message(1, &header, &payload, &MockResponder);
    }

    #[test]
    fn test_dispatcher_with_default_handler() {
        let started = Arc::new(AtomicU64::new(0));
        let ended = Arc::new(AtomicU64::new(0));

        let mut dispatcher = MessageDispatcher::new();
        dispatcher.set_default(TestDefaultHandler {
            session_started: Arc::clone(&started),
            session_ended: Arc::clone(&ended),
        });

        dispatcher.on_session_start(42);
        assert_eq!(started.load(Ordering::SeqCst), 42);

        dispatcher.on_session_end(43);
        assert_eq!(ended.load(Ordering::SeqCst), 43);
    }

    #[test]
    fn test_dispatcher_on_error_with_default() {
        let seen = Arc::new(AtomicBool::new(false));

        let mut dispatcher = MessageDispatcher::new();
        dispatcher.set_default(ErrorHandler {
            error_received: Arc::clone(&seen),
        });

        dispatcher.on_error(1, "test error");
        assert!(seen.load(Ordering::SeqCst));
    }
}