oxigdal_websocket/broadcast/
router.rs1use crate::protocol::message::{Message, MessageType};
4use crate::server::connection::ConnectionId;
5use parking_lot::RwLock;
6use std::collections::HashMap;
7use std::sync::Arc;
8
/// Selects how a routed message is addressed.
///
/// Only `Room` and `Topic` are resolved to a named target by
/// [`MessageRouter::route`]; the remaining strategies come back from the
/// router as a broadcast.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RoutingStrategy {
    /// No explicit target: the message fans out to all connections.
    Broadcast,
    /// Addressed to a single connection (see [`RouteResult::direct`]).
    Direct,
    /// Addressed to a named room (see [`RouteResult::room`]).
    Room,
    /// Addressed to a named topic (see [`RouteResult::topic`]).
    Topic,
    /// Application-defined routing; not interpreted by the router itself.
    Custom,
}

24#[derive(Clone)]
26pub struct RoutingRule {
27 pub name: String,
29 pub message_type: Option<MessageType>,
31 pub strategy: RoutingStrategy,
33 pub target: Option<String>,
35 pub priority: i32,
37}
38
39impl RoutingRule {
40 pub fn new(name: String, strategy: RoutingStrategy) -> Self {
42 Self {
43 name,
44 message_type: None,
45 strategy,
46 target: None,
47 priority: 0,
48 }
49 }
50
51 pub fn with_message_type(mut self, msg_type: MessageType) -> Self {
53 self.message_type = Some(msg_type);
54 self
55 }
56
57 pub fn with_target(mut self, target: String) -> Self {
59 self.target = Some(target);
60 self
61 }
62
63 pub fn with_priority(mut self, priority: i32) -> Self {
65 self.priority = priority;
66 self
67 }
68
69 pub fn matches(&self, message: &Message) -> bool {
71 if let Some(msg_type) = self.message_type {
72 message.msg_type == msg_type
73 } else {
74 true
75 }
76 }
77}
78
79pub struct RouteResult {
81 pub strategy: RoutingStrategy,
83 pub targets: Vec<ConnectionId>,
85 pub target_id: Option<String>,
87}
88
89impl RouteResult {
90 pub fn broadcast() -> Self {
92 Self {
93 strategy: RoutingStrategy::Broadcast,
94 targets: Vec::new(),
95 target_id: None,
96 }
97 }
98
99 pub fn direct(target: ConnectionId) -> Self {
101 Self {
102 strategy: RoutingStrategy::Direct,
103 targets: vec![target],
104 target_id: None,
105 }
106 }
107
108 pub fn room(room_name: String) -> Self {
110 Self {
111 strategy: RoutingStrategy::Room,
112 targets: Vec::new(),
113 target_id: Some(room_name),
114 }
115 }
116
117 pub fn topic(topic_name: String) -> Self {
119 Self {
120 strategy: RoutingStrategy::Topic,
121 targets: Vec::new(),
122 target_id: Some(topic_name),
123 }
124 }
125}
126
127pub struct MessageRouter {
129 rules: Arc<RwLock<Vec<RoutingRule>>>,
130 default_strategy: RoutingStrategy,
131}
132
133impl MessageRouter {
134 pub fn new() -> Self {
136 Self {
137 rules: Arc::new(RwLock::new(Vec::new())),
138 default_strategy: RoutingStrategy::Broadcast,
139 }
140 }
141
142 pub fn with_default_strategy(strategy: RoutingStrategy) -> Self {
144 Self {
145 rules: Arc::new(RwLock::new(Vec::new())),
146 default_strategy: strategy,
147 }
148 }
149
150 pub fn add_rule(&self, rule: RoutingRule) {
152 let mut rules = self.rules.write();
153 rules.push(rule);
154 rules.sort_by_key(|x| std::cmp::Reverse(x.priority));
156 }
157
158 pub fn remove_rule(&self, name: &str) -> bool {
160 let mut rules = self.rules.write();
161 if let Some(pos) = rules.iter().position(|r| r.name == name) {
162 rules.remove(pos);
163 true
164 } else {
165 false
166 }
167 }
168
169 pub fn route(&self, message: &Message) -> RouteResult {
171 let rules = self.rules.read();
172
173 for rule in rules.iter() {
175 if rule.matches(message) {
176 return match rule.strategy {
177 RoutingStrategy::Broadcast => RouteResult::broadcast(),
178 RoutingStrategy::Room => {
179 if let Some(target) = &rule.target {
180 RouteResult::room(target.clone())
181 } else {
182 RouteResult::broadcast()
183 }
184 }
185 RoutingStrategy::Topic => {
186 if let Some(target) = &rule.target {
187 RouteResult::topic(target.clone())
188 } else {
189 RouteResult::broadcast()
190 }
191 }
192 _ => RouteResult::broadcast(),
193 };
194 }
195 }
196
197 match self.default_strategy {
199 RoutingStrategy::Broadcast => RouteResult::broadcast(),
200 _ => RouteResult::broadcast(),
201 }
202 }
203
204 pub fn rules(&self) -> Vec<RoutingRule> {
206 self.rules.read().clone()
207 }
208
209 pub fn clear_rules(&self) {
211 self.rules.write().clear();
212 }
213
214 pub fn rule_count(&self) -> usize {
216 self.rules.read().len()
217 }
218}
219
220impl Default for MessageRouter {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
226pub struct RoutingTable {
228 routes: Arc<RwLock<HashMap<ConnectionId, Vec<ConnectionId>>>>,
230}
231
232impl RoutingTable {
233 pub fn new() -> Self {
235 Self {
236 routes: Arc::new(RwLock::new(HashMap::new())),
237 }
238 }
239
240 pub fn add_route(&self, source: ConnectionId, target: ConnectionId) {
242 let mut routes = self.routes.write();
243 routes.entry(source).or_default().push(target);
244 }
245
246 pub fn remove_route(&self, source: &ConnectionId, target: &ConnectionId) {
248 let mut routes = self.routes.write();
249 if let Some(targets) = routes.get_mut(source) {
250 targets.retain(|t| t != target);
251 }
252 }
253
254 pub fn get_targets(&self, source: &ConnectionId) -> Vec<ConnectionId> {
256 let routes = self.routes.read();
257 routes.get(source).cloned().unwrap_or_default()
258 }
259
260 pub fn remove_connection(&self, connection: &ConnectionId) {
262 let mut routes = self.routes.write();
263 routes.remove(connection);
264
265 for targets in routes.values_mut() {
267 targets.retain(|t| t != connection);
268 }
269 }
270
271 pub fn clear(&self) {
273 self.routes.write().clear();
274 }
275
276 pub fn route_count(&self) -> usize {
278 self.routes.read().len()
279 }
280}
281
282impl Default for RoutingTable {
283 fn default() -> Self {
284 Self::new()
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_routing_rule() {
        let rule = RoutingRule::new("test".to_string(), RoutingStrategy::Broadcast)
            .with_message_type(MessageType::Ping)
            .with_priority(10);

        assert_eq!(rule.name, "test");
        assert_eq!(rule.message_type, Some(MessageType::Ping));
        assert_eq!(rule.priority, 10);
    }

    #[test]
    fn test_routing_rule_matches() {
        let rule = RoutingRule::new("test".to_string(), RoutingStrategy::Broadcast)
            .with_message_type(MessageType::Ping);

        let ping = Message::ping();
        let pong = Message::pong();

        assert!(rule.matches(&ping));
        assert!(!rule.matches(&pong));
    }

    #[test]
    fn test_message_router() {
        let router = MessageRouter::new();
        assert_eq!(router.rule_count(), 0);

        let rule = RoutingRule::new("test".to_string(), RoutingStrategy::Broadcast);
        router.add_rule(rule);

        assert_eq!(router.rule_count(), 1);
    }

    #[test]
    fn test_router_route() {
        let router = MessageRouter::new();

        let rule = RoutingRule::new("ping_room".to_string(), RoutingStrategy::Room)
            .with_message_type(MessageType::Ping)
            .with_target("lobby".to_string());

        router.add_rule(rule);

        let ping = Message::ping();
        let result = router.route(&ping);

        assert_eq!(result.strategy, RoutingStrategy::Room);
        assert_eq!(result.target_id, Some("lobby".to_string()));
    }

    #[test]
    fn test_router_priority() {
        let router = MessageRouter::new();

        let rule1 =
            RoutingRule::new("low".to_string(), RoutingStrategy::Broadcast).with_priority(1);

        let rule2 = RoutingRule::new("high".to_string(), RoutingStrategy::Room)
            .with_priority(10)
            .with_target("test".to_string());

        router.add_rule(rule1);
        router.add_rule(rule2);

        let msg = Message::ping();
        let result = router.route(&msg);

        assert_eq!(result.strategy, RoutingStrategy::Room);
    }

    #[test]
    fn test_router_rule_order_is_priority_then_insertion() {
        let router = MessageRouter::new();
        for (name, priority) in [("low", 1), ("high", 10), ("mid", 5), ("mid2", 5)] {
            router.add_rule(
                RoutingRule::new(name.to_string(), RoutingStrategy::Broadcast)
                    .with_priority(priority),
            );
        }

        let names: Vec<String> = router.rules().into_iter().map(|r| r.name).collect();
        assert_eq!(names, ["high", "mid", "mid2", "low"]);
    }

    #[test]
    fn test_router_equal_priority_first_added_wins() {
        let router = MessageRouter::new();
        router.add_rule(
            RoutingRule::new("first".to_string(), RoutingStrategy::Room)
                .with_priority(5)
                .with_target("a".to_string()),
        );
        router.add_rule(
            RoutingRule::new("second".to_string(), RoutingStrategy::Room)
                .with_priority(5)
                .with_target("b".to_string()),
        );

        let result = router.route(&Message::ping());
        assert_eq!(result.target_id.as_deref(), Some("a"));
    }

    #[test]
    fn test_router_target_less_rule_falls_back_to_broadcast() {
        let router = MessageRouter::new();
        router.add_rule(RoutingRule::new("roomless".to_string(), RoutingStrategy::Room));

        let result = router.route(&Message::ping());
        assert_eq!(result.strategy, RoutingStrategy::Broadcast);
        assert!(result.target_id.is_none());
        assert!(result.targets.is_empty());
    }

    #[test]
    fn test_router_unmatched_message_uses_default() {
        let router = MessageRouter::with_default_strategy(RoutingStrategy::Room);
        router.add_rule(
            RoutingRule::new("ping_topic".to_string(), RoutingStrategy::Topic)
                .with_message_type(MessageType::Ping)
                .with_target("t".to_string()),
        );

        // The default Room strategy has no target, so it resolves to broadcast.
        let unmatched = router.route(&Message::pong());
        assert_eq!(unmatched.strategy, RoutingStrategy::Broadcast);
        assert!(unmatched.target_id.is_none());

        let matched = router.route(&Message::ping());
        assert_eq!(matched.strategy, RoutingStrategy::Topic);
        assert_eq!(matched.target_id.as_deref(), Some("t"));
    }

    #[test]
    fn test_remove_and_clear_rules() {
        let router = MessageRouter::new();
        router.add_rule(RoutingRule::new("a".to_string(), RoutingStrategy::Broadcast));
        router.add_rule(RoutingRule::new("b".to_string(), RoutingStrategy::Broadcast));

        assert!(router.remove_rule("a"));
        assert!(!router.remove_rule("a"));
        assert!(!router.remove_rule("missing"));
        assert_eq!(router.rule_count(), 1);

        router.clear_rules();
        assert_eq!(router.rule_count(), 0);
    }

    #[test]
    fn test_routing_table() {
        let table = RoutingTable::new();
        let source = uuid::Uuid::new_v4();
        let target1 = uuid::Uuid::new_v4();
        let target2 = uuid::Uuid::new_v4();

        table.add_route(source, target1);
        table.add_route(source, target2);

        let targets = table.get_targets(&source);
        assert_eq!(targets.len(), 2);
    }

    #[test]
    fn test_routing_table_remove() {
        let table = RoutingTable::new();
        let source = uuid::Uuid::new_v4();
        let target = uuid::Uuid::new_v4();

        table.add_route(source, target);
        table.remove_route(&source, &target);

        let targets = table.get_targets(&source);
        assert_eq!(targets.len(), 0);
    }

    #[test]
    fn test_routing_table_remove_connection() {
        let table = RoutingTable::new();
        let a = uuid::Uuid::new_v4();
        let b = uuid::Uuid::new_v4();
        let c = uuid::Uuid::new_v4();

        table.add_route(a, b);
        table.add_route(b, a);
        table.add_route(b, c);
        table.add_route(c, a);

        table.remove_connection(&a);

        assert!(table.get_targets(&a).is_empty());
        assert_eq!(table.get_targets(&b), vec![c]);
        assert!(table.get_targets(&c).is_empty());

        table.clear();
        assert_eq!(table.route_count(), 0);
    }

    #[test]
    fn test_defaults_are_empty() {
        assert_eq!(MessageRouter::default().rule_count(), 0);
        assert_eq!(RoutingTable::default().route_count(), 0);
    }
}