Skip to main content

oxigdal_websocket/broadcast/
router.rs

1//! Message routing and distribution
2
3use crate::protocol::message::{Message, MessageType};
4use crate::server::connection::ConnectionId;
5use parking_lot::RwLock;
6use std::collections::HashMap;
7use std::sync::Arc;
8
/// How a routed message is to be delivered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RoutingStrategy {
    /// Deliver to every connection.
    Broadcast,
    /// Deliver to one specific connection.
    Direct,
    /// Deliver to the members of a named room.
    Room,
    /// Deliver to the subscribers of a named topic.
    Topic,
    /// Delivery decided by custom routing logic.
    Custom,
}
23
24/// Routing rule
25#[derive(Clone)]
26pub struct RoutingRule {
27    /// Rule name
28    pub name: String,
29    /// Message type to match
30    pub message_type: Option<MessageType>,
31    /// Routing strategy
32    pub strategy: RoutingStrategy,
33    /// Target (room name, topic, etc.)
34    pub target: Option<String>,
35    /// Priority (higher priority rules are evaluated first)
36    pub priority: i32,
37}
38
39impl RoutingRule {
40    /// Create a new routing rule
41    pub fn new(name: String, strategy: RoutingStrategy) -> Self {
42        Self {
43            name,
44            message_type: None,
45            strategy,
46            target: None,
47            priority: 0,
48        }
49    }
50
51    /// Set message type filter
52    pub fn with_message_type(mut self, msg_type: MessageType) -> Self {
53        self.message_type = Some(msg_type);
54        self
55    }
56
57    /// Set target
58    pub fn with_target(mut self, target: String) -> Self {
59        self.target = Some(target);
60        self
61    }
62
63    /// Set priority
64    pub fn with_priority(mut self, priority: i32) -> Self {
65        self.priority = priority;
66        self
67    }
68
69    /// Check if this rule matches a message
70    pub fn matches(&self, message: &Message) -> bool {
71        if let Some(msg_type) = self.message_type {
72            message.msg_type == msg_type
73        } else {
74            true
75        }
76    }
77}
78
79/// Route result
80pub struct RouteResult {
81    /// Routing strategy to use
82    pub strategy: RoutingStrategy,
83    /// Target connections
84    pub targets: Vec<ConnectionId>,
85    /// Target identifier (room name, topic, etc.)
86    pub target_id: Option<String>,
87}
88
89impl RouteResult {
90    /// Create a broadcast route
91    pub fn broadcast() -> Self {
92        Self {
93            strategy: RoutingStrategy::Broadcast,
94            targets: Vec::new(),
95            target_id: None,
96        }
97    }
98
99    /// Create a direct route
100    pub fn direct(target: ConnectionId) -> Self {
101        Self {
102            strategy: RoutingStrategy::Direct,
103            targets: vec![target],
104            target_id: None,
105        }
106    }
107
108    /// Create a room route
109    pub fn room(room_name: String) -> Self {
110        Self {
111            strategy: RoutingStrategy::Room,
112            targets: Vec::new(),
113            target_id: Some(room_name),
114        }
115    }
116
117    /// Create a topic route
118    pub fn topic(topic_name: String) -> Self {
119        Self {
120            strategy: RoutingStrategy::Topic,
121            targets: Vec::new(),
122            target_id: Some(topic_name),
123        }
124    }
125}
126
127/// Message router
128pub struct MessageRouter {
129    rules: Arc<RwLock<Vec<RoutingRule>>>,
130    default_strategy: RoutingStrategy,
131}
132
133impl MessageRouter {
134    /// Create a new message router
135    pub fn new() -> Self {
136        Self {
137            rules: Arc::new(RwLock::new(Vec::new())),
138            default_strategy: RoutingStrategy::Broadcast,
139        }
140    }
141
142    /// Create a router with a default strategy
143    pub fn with_default_strategy(strategy: RoutingStrategy) -> Self {
144        Self {
145            rules: Arc::new(RwLock::new(Vec::new())),
146            default_strategy: strategy,
147        }
148    }
149
150    /// Add a routing rule
151    pub fn add_rule(&self, rule: RoutingRule) {
152        let mut rules = self.rules.write();
153        rules.push(rule);
154        // Sort by priority (highest first)
155        rules.sort_by_key(|x| std::cmp::Reverse(x.priority));
156    }
157
158    /// Remove a routing rule by name
159    pub fn remove_rule(&self, name: &str) -> bool {
160        let mut rules = self.rules.write();
161        if let Some(pos) = rules.iter().position(|r| r.name == name) {
162            rules.remove(pos);
163            true
164        } else {
165            false
166        }
167    }
168
169    /// Route a message
170    pub fn route(&self, message: &Message) -> RouteResult {
171        let rules = self.rules.read();
172
173        // Find first matching rule
174        for rule in rules.iter() {
175            if rule.matches(message) {
176                return match rule.strategy {
177                    RoutingStrategy::Broadcast => RouteResult::broadcast(),
178                    RoutingStrategy::Room => {
179                        if let Some(target) = &rule.target {
180                            RouteResult::room(target.clone())
181                        } else {
182                            RouteResult::broadcast()
183                        }
184                    }
185                    RoutingStrategy::Topic => {
186                        if let Some(target) = &rule.target {
187                            RouteResult::topic(target.clone())
188                        } else {
189                            RouteResult::broadcast()
190                        }
191                    }
192                    _ => RouteResult::broadcast(),
193                };
194            }
195        }
196
197        // Use default strategy if no rule matches
198        match self.default_strategy {
199            RoutingStrategy::Broadcast => RouteResult::broadcast(),
200            _ => RouteResult::broadcast(),
201        }
202    }
203
204    /// Get all rules
205    pub fn rules(&self) -> Vec<RoutingRule> {
206        self.rules.read().clone()
207    }
208
209    /// Clear all rules
210    pub fn clear_rules(&self) {
211        self.rules.write().clear();
212    }
213
214    /// Get rule count
215    pub fn rule_count(&self) -> usize {
216        self.rules.read().len()
217    }
218}
219
220impl Default for MessageRouter {
221    fn default() -> Self {
222        Self::new()
223    }
224}
225
226/// Routing table for connection-to-connection routing
227pub struct RoutingTable {
228    /// Mapping of source connections to target connections
229    routes: Arc<RwLock<HashMap<ConnectionId, Vec<ConnectionId>>>>,
230}
231
232impl RoutingTable {
233    /// Create a new routing table
234    pub fn new() -> Self {
235        Self {
236            routes: Arc::new(RwLock::new(HashMap::new())),
237        }
238    }
239
240    /// Add a route
241    pub fn add_route(&self, source: ConnectionId, target: ConnectionId) {
242        let mut routes = self.routes.write();
243        routes.entry(source).or_default().push(target);
244    }
245
246    /// Remove a route
247    pub fn remove_route(&self, source: &ConnectionId, target: &ConnectionId) {
248        let mut routes = self.routes.write();
249        if let Some(targets) = routes.get_mut(source) {
250            targets.retain(|t| t != target);
251        }
252    }
253
254    /// Get targets for a source
255    pub fn get_targets(&self, source: &ConnectionId) -> Vec<ConnectionId> {
256        let routes = self.routes.read();
257        routes.get(source).cloned().unwrap_or_default()
258    }
259
260    /// Remove all routes for a connection
261    pub fn remove_connection(&self, connection: &ConnectionId) {
262        let mut routes = self.routes.write();
263        routes.remove(connection);
264
265        // Also remove from targets
266        for targets in routes.values_mut() {
267            targets.retain(|t| t != connection);
268        }
269    }
270
271    /// Clear all routes
272    pub fn clear(&self) {
273        self.routes.write().clear();
274    }
275
276    /// Get route count
277    pub fn route_count(&self) -> usize {
278        self.routes.read().len()
279    }
280}
281
282impl Default for RoutingTable {
283    fn default() -> Self {
284        Self::new()
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_routing_rule() {
        let rule = RoutingRule::new("test".to_string(), RoutingStrategy::Broadcast)
            .with_message_type(MessageType::Ping)
            .with_priority(10);

        assert_eq!(rule.name, "test");
        assert_eq!(rule.message_type, Some(MessageType::Ping));
        assert_eq!(rule.priority, 10);
    }

    #[test]
    fn test_routing_rule_matches() {
        let rule = RoutingRule::new("test".to_string(), RoutingStrategy::Broadcast)
            .with_message_type(MessageType::Ping);

        let ping = Message::ping();
        let pong = Message::pong();

        assert!(rule.matches(&ping));
        assert!(!rule.matches(&pong));
    }

    #[test]
    fn test_routing_rule_without_filter_matches_everything() {
        let rule = RoutingRule::new("any".to_string(), RoutingStrategy::Broadcast);

        assert!(rule.matches(&Message::ping()));
        assert!(rule.matches(&Message::pong()));
    }

    #[test]
    fn test_message_router() {
        let router = MessageRouter::new();
        assert_eq!(router.rule_count(), 0);

        let rule = RoutingRule::new("test".to_string(), RoutingStrategy::Broadcast);
        router.add_rule(rule);

        assert_eq!(router.rule_count(), 1);
    }

    #[test]
    fn test_router_remove_and_clear_rules() {
        let router = MessageRouter::new();
        router.add_rule(RoutingRule::new("a".to_string(), RoutingStrategy::Broadcast));
        router.add_rule(RoutingRule::new("b".to_string(), RoutingStrategy::Broadcast));

        assert!(router.remove_rule("a"));
        assert!(!router.remove_rule("a"));
        assert_eq!(router.rule_count(), 1);

        router.clear_rules();
        assert_eq!(router.rule_count(), 0);
    }

    #[test]
    fn test_router_route() {
        let router = MessageRouter::new();

        let rule = RoutingRule::new("ping_room".to_string(), RoutingStrategy::Room)
            .with_message_type(MessageType::Ping)
            .with_target("lobby".to_string());

        router.add_rule(rule);

        let ping = Message::ping();
        let result = router.route(&ping);

        assert_eq!(result.strategy, RoutingStrategy::Room);
        assert_eq!(result.target_id, Some("lobby".to_string()));
    }

    #[test]
    fn test_router_falls_back_to_broadcast_without_rules() {
        let router = MessageRouter::with_default_strategy(RoutingStrategy::Room);
        let result = router.route(&Message::ping());

        assert_eq!(result.strategy, RoutingStrategy::Broadcast);
        assert_eq!(result.target_id, None);
        assert!(result.targets.is_empty());
    }

    #[test]
    fn test_router_room_rule_without_target_broadcasts() {
        let router = MessageRouter::new();
        router.add_rule(RoutingRule::new("room".to_string(), RoutingStrategy::Room));

        let result = router.route(&Message::ping());

        assert_eq!(result.strategy, RoutingStrategy::Broadcast);
        assert_eq!(result.target_id, None);
    }

    #[test]
    fn test_router_priority() {
        let router = MessageRouter::new();

        let rule1 =
            RoutingRule::new("low".to_string(), RoutingStrategy::Broadcast).with_priority(1);

        let rule2 = RoutingRule::new("high".to_string(), RoutingStrategy::Room)
            .with_priority(10)
            .with_target("test".to_string());

        router.add_rule(rule1);
        router.add_rule(rule2);

        let msg = Message::ping();
        let result = router.route(&msg);

        // High priority rule should match first
        assert_eq!(result.strategy, RoutingStrategy::Room);
    }

    #[test]
    fn test_router_equal_priority_keeps_insertion_order() {
        let router = MessageRouter::new();
        router.add_rule(
            RoutingRule::new("first".to_string(), RoutingStrategy::Topic)
                .with_target("t".to_string()),
        );
        router.add_rule(
            RoutingRule::new("second".to_string(), RoutingStrategy::Room)
                .with_target("r".to_string()),
        );

        let result = router.route(&Message::ping());
        assert_eq!(result.strategy, RoutingStrategy::Topic);
        assert_eq!(result.target_id, Some("t".to_string()));

        let names: Vec<String> = router.rules().into_iter().map(|r| r.name).collect();
        assert_eq!(names, vec!["first", "second"]);
    }

    #[test]
    fn test_routing_table() {
        let table = RoutingTable::new();
        let source = uuid::Uuid::new_v4();
        let target1 = uuid::Uuid::new_v4();
        let target2 = uuid::Uuid::new_v4();

        table.add_route(source, target1);
        table.add_route(source, target2);

        let targets = table.get_targets(&source);
        assert_eq!(targets.len(), 2);
    }

    #[test]
    fn test_routing_table_remove() {
        let table = RoutingTable::new();
        let source = uuid::Uuid::new_v4();
        let target = uuid::Uuid::new_v4();

        table.add_route(source, target);
        table.remove_route(&source, &target);

        let targets = table.get_targets(&source);
        assert_eq!(targets.len(), 0);
    }

    #[test]
    fn test_routing_table_remove_connection() {
        let table = RoutingTable::new();
        let a = uuid::Uuid::new_v4();
        let b = uuid::Uuid::new_v4();
        let c = uuid::Uuid::new_v4();

        table.add_route(a, b);
        table.add_route(c, a);
        table.add_route(c, b);

        table.remove_connection(&a);

        // `a` is gone both as a source and as a target.
        assert!(table.get_targets(&a).is_empty());
        assert_eq!(table.get_targets(&c), vec![b]);
    }
}
390}