use crate::protocol::message::{Message, MessageType};
use crate::server::connection::ConnectionId;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
/// How a routed message should be delivered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RoutingStrategy {
    /// No specific recipient: `RouteResult::broadcast` carries neither targets
    /// nor a target id (presumably delivered to every connection — confirm
    /// against the dispatcher).
    Broadcast,
    /// Explicit connections, listed in `RouteResult::targets`.
    Direct,
    /// A named room; the name travels in `RouteResult::target_id`.
    Room,
    /// A named topic; the name travels in `RouteResult::target_id`.
    Topic,
    /// Reserved for custom routing; `MessageRouter::route` currently resolves
    /// it as a broadcast.
    Custom,
}
/// A single routing rule evaluated by `MessageRouter::route`.
#[derive(Debug, Clone)]
pub struct RoutingRule {
    /// Identifier used to remove the rule via `MessageRouter::remove_rule`.
    pub name: String,
    /// Message type this rule applies to; `None` matches every message.
    pub message_type: Option<MessageType>,
    /// Strategy applied when the rule matches.
    pub strategy: RoutingStrategy,
    /// Room or topic name for the `Room`/`Topic` strategies; without it those
    /// strategies fall back to broadcast.
    pub target: Option<String>,
    /// Evaluation order: rules with a higher priority are tried first.
    pub priority: i32,
}
impl RoutingRule {
    /// Builds a rule named `name` that applies `strategy`.
    ///
    /// The rule starts with no message-type filter (so it matches every
    /// message), no target, and priority `0`.
    pub fn new(name: String, strategy: RoutingStrategy) -> Self {
        Self {
            strategy,
            name,
            priority: 0,
            target: None,
            message_type: None,
        }
    }

    /// Restricts the rule to messages whose type equals `msg_type`.
    pub fn with_message_type(self, msg_type: MessageType) -> Self {
        Self {
            message_type: Some(msg_type),
            ..self
        }
    }

    /// Sets the room/topic name the rule routes to.
    pub fn with_target(self, target: String) -> Self {
        Self {
            target: Some(target),
            ..self
        }
    }

    /// Sets the rule's priority (higher values are evaluated first).
    pub fn with_priority(self, priority: i32) -> Self {
        Self { priority, ..self }
    }

    /// Returns `true` if this rule applies to `message`.
    ///
    /// A rule without a message-type filter applies to every message.
    pub fn matches(&self, message: &Message) -> bool {
        match self.message_type {
            None => true,
            Some(wanted) => message.msg_type == wanted,
        }
    }
}
/// The outcome of routing a message: which strategy applies and to whom.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so callers can log, copy and
/// compare whole results instead of inspecting fields one by one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RouteResult {
    /// Delivery strategy that was selected.
    pub strategy: RoutingStrategy,
    /// Explicit recipient connections; filled by `RouteResult::direct` and
    /// empty for the other constructors.
    pub targets: Vec<ConnectionId>,
    /// Room or topic name for `Room`/`Topic` results; `None` otherwise.
    pub target_id: Option<String>,
}
impl RouteResult {
    /// A broadcast result: no explicit targets and no named target id.
    pub fn broadcast() -> Self {
        Self {
            target_id: None,
            targets: Vec::new(),
            strategy: RoutingStrategy::Broadcast,
        }
    }

    /// A direct result addressed to the single connection `target`.
    pub fn direct(target: ConnectionId) -> Self {
        Self {
            strategy: RoutingStrategy::Direct,
            targets: vec![target],
            ..Self::broadcast()
        }
    }

    /// A room result; `room_name` is stored in `target_id`.
    pub fn room(room_name: String) -> Self {
        Self {
            strategy: RoutingStrategy::Room,
            target_id: Some(room_name),
            ..Self::broadcast()
        }
    }

    /// A topic result; `topic_name` is stored in `target_id`.
    pub fn topic(topic_name: String) -> Self {
        Self {
            strategy: RoutingStrategy::Topic,
            target_id: Some(topic_name),
            ..Self::broadcast()
        }
    }
}
/// Rule-based router that maps a `Message` to a `RouteResult`.
///
/// The rule list sits behind an `RwLock`, so rules can be added and removed
/// through `&self`.
pub struct MessageRouter {
    /// Strategy configured at construction; read by `route` only when no rule
    /// matches.
    default_strategy: RoutingStrategy,
    /// Registered rules, kept ordered by descending priority by `add_rule`.
    rules: Arc<RwLock<Vec<RoutingRule>>>,
}
impl MessageRouter {
    /// Creates a router with no rules whose fallback strategy is
    /// `RoutingStrategy::Broadcast`.
    pub fn new() -> Self {
        Self::with_default_strategy(RoutingStrategy::Broadcast)
    }

    /// Creates a router with no rules and the given fallback strategy.
    ///
    /// The fallback is only read by [`Self::route`] when no rule matches.
    /// Because a fallback carries no room/topic name or connection, every
    /// strategy currently resolves to `RouteResult::broadcast` there.
    pub fn with_default_strategy(strategy: RoutingStrategy) -> Self {
        Self {
            rules: Arc::new(RwLock::new(Vec::new())),
            default_strategy: strategy,
        }
    }

    /// Registers `rule`, keeping the rule list ordered by descending priority.
    ///
    /// Rules with equal priority keep their insertion order, so the rule
    /// registered first wins ties in [`Self::route`].
    pub fn add_rule(&self, rule: RoutingRule) {
        let mut rules = self.rules.write();
        // The list is always sorted by descending priority, so binary-search for
        // the slot just past every rule with priority >= the new one. This yields
        // the same order as push + stable sort, without re-sorting every time.
        let pos = rules.partition_point(|r| r.priority >= rule.priority);
        rules.insert(pos, rule);
    }

    /// Removes the first rule named `name`.
    ///
    /// Returns `true` if a rule was removed and `false` if no rule had that
    /// name. The remaining rules keep their relative order.
    pub fn remove_rule(&self, name: &str) -> bool {
        let mut rules = self.rules.write();
        match rules.iter().position(|r| r.name == name) {
            Some(pos) => {
                rules.remove(pos);
                true
            }
            None => false,
        }
    }

    /// Routes `message` using the highest-priority rule that matches it.
    ///
    /// Rules are tried in priority order, with the earliest-registered rule
    /// winning ties. If no rule matches, the router's default strategy is
    /// resolved instead.
    pub fn route(&self, message: &Message) -> RouteResult {
        let rules = self.rules.read();
        match rules.iter().find(|rule| rule.matches(message)) {
            Some(rule) => Self::resolve_rule(rule),
            None => Self::resolve_default(self.default_strategy),
        }
    }

    /// Turns a matched rule into a `RouteResult`.
    ///
    /// `Room`/`Topic` rules need a `target`; without one they fall back to
    /// broadcast.
    fn resolve_rule(rule: &RoutingRule) -> RouteResult {
        match rule.strategy {
            RoutingStrategy::Broadcast => RouteResult::broadcast(),
            RoutingStrategy::Room => rule
                .target
                .clone()
                .map_or_else(RouteResult::broadcast, RouteResult::room),
            RoutingStrategy::Topic => rule
                .target
                .clone()
                .map_or_else(RouteResult::broadcast, RouteResult::topic),
            // NOTE(review): a rule's `target` is a room/topic name (`String`), not a
            // `ConnectionId`, so `Direct` rules cannot name a recipient, and `Custom`
            // has no resolver yet; both degrade to broadcast. Listed explicitly
            // (no `_` arm) so adding a variant forces this match to be revisited.
            RoutingStrategy::Direct | RoutingStrategy::Custom => RouteResult::broadcast(),
        }
    }

    /// Resolves the fallback strategy used when no rule matches.
    fn resolve_default(strategy: RoutingStrategy) -> RouteResult {
        match strategy {
            // A fallback has no room/topic name or connection to aim at, so every
            // strategy degrades to broadcast. Kept exhaustive so a new variant
            // forces a decision here.
            RoutingStrategy::Broadcast
            | RoutingStrategy::Direct
            | RoutingStrategy::Room
            | RoutingStrategy::Topic
            | RoutingStrategy::Custom => RouteResult::broadcast(),
        }
    }

    /// Returns a snapshot copy of the rules in evaluation order.
    pub fn rules(&self) -> Vec<RoutingRule> {
        self.rules.read().clone()
    }

    /// Removes every registered rule.
    pub fn clear_rules(&self) {
        self.rules.write().clear();
    }

    /// Number of registered rules.
    pub fn rule_count(&self) -> usize {
        self.rules.read().len()
    }
}
impl Default for MessageRouter {
fn default() -> Self {
Self::new()
}
}
/// Directed routes between connections: each source connection maps to a
/// list of target connections (see `get_targets`).
///
/// The map sits behind an `RwLock`, so every method works through `&self`.
pub struct RoutingTable {
    /// Source connection -> its target connections.
    routes: Arc<RwLock<HashMap<ConnectionId, Vec<ConnectionId>>>>,
}
impl RoutingTable {
    /// Creates an empty routing table.
    pub fn new() -> Self {
        Self {
            routes: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Adds a route from `source` to `target`.
    ///
    /// Adding a route that already exists is a no-op, so a target never
    /// appears twice in [`Self::get_targets`].
    pub fn add_route(&self, source: ConnectionId, target: ConnectionId) {
        let mut routes = self.routes.write();
        let targets = routes.entry(source).or_default();
        // Duplicates would make `get_targets` list the same connection twice.
        if !targets.contains(&target) {
            targets.push(target);
        }
    }

    /// Removes the route from `source` to `target`, if present.
    ///
    /// A source whose last target is removed is dropped from the table, so
    /// [`Self::route_count`] only counts sources that still have targets.
    pub fn remove_route(&self, source: &ConnectionId, target: &ConnectionId) {
        let mut routes = self.routes.write();
        if let Some(targets) = routes.get_mut(source) {
            targets.retain(|t| t != target);
            if targets.is_empty() {
                routes.remove(source);
            }
        }
    }

    /// Returns a snapshot of the targets routed from `source` (empty if none).
    pub fn get_targets(&self, source: &ConnectionId) -> Vec<ConnectionId> {
        self.routes.read().get(source).cloned().unwrap_or_default()
    }

    /// Removes `connection` everywhere: its own outgoing routes and every
    /// route that targets it.
    ///
    /// Sources left without any target are dropped as well.
    pub fn remove_connection(&self, connection: &ConnectionId) {
        self.routes.write().retain(|source, targets| {
            if source == connection {
                return false;
            }
            targets.retain(|t| t != connection);
            !targets.is_empty()
        });
    }

    /// Removes every route.
    pub fn clear(&self) {
        self.routes.write().clear();
    }

    /// Number of source connections that currently have at least one target.
    pub fn route_count(&self) -> usize {
        self.routes.read().len()
    }
}
impl Default for RoutingTable {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_routing_rule() {
        let rule = RoutingRule::new("test".to_string(), RoutingStrategy::Broadcast)
            .with_message_type(MessageType::Ping)
            .with_priority(10);
        assert_eq!(rule.name, "test");
        assert_eq!(rule.message_type, Some(MessageType::Ping));
        assert_eq!(rule.priority, 10);
    }

    #[test]
    fn test_routing_rule_matches() {
        let rule = RoutingRule::new("test".to_string(), RoutingStrategy::Broadcast)
            .with_message_type(MessageType::Ping);
        let ping = Message::ping();
        let pong = Message::pong();
        assert!(rule.matches(&ping));
        assert!(!rule.matches(&pong));
    }

    #[test]
    fn test_message_router() {
        let router = MessageRouter::new();
        assert_eq!(router.rule_count(), 0);
        let rule = RoutingRule::new("test".to_string(), RoutingStrategy::Broadcast);
        router.add_rule(rule);
        assert_eq!(router.rule_count(), 1);
    }

    #[test]
    fn test_router_route() {
        let router = MessageRouter::new();
        let rule = RoutingRule::new("ping_room".to_string(), RoutingStrategy::Room)
            .with_message_type(MessageType::Ping)
            .with_target("lobby".to_string());
        router.add_rule(rule);
        let ping = Message::ping();
        let result = router.route(&ping);
        assert_eq!(result.strategy, RoutingStrategy::Room);
        assert_eq!(result.target_id, Some("lobby".to_string()));
    }

    #[test]
    fn test_router_priority() {
        let router = MessageRouter::new();
        let rule1 =
            RoutingRule::new("low".to_string(), RoutingStrategy::Broadcast).with_priority(1);
        let rule2 = RoutingRule::new("high".to_string(), RoutingStrategy::Room)
            .with_priority(10)
            .with_target("test".to_string());
        router.add_rule(rule1);
        router.add_rule(rule2);
        let msg = Message::ping();
        let result = router.route(&msg);
        assert_eq!(result.strategy, RoutingStrategy::Room);
    }

    #[test]
    fn test_router_falls_back_to_broadcast_when_no_rule_matches() {
        let router = MessageRouter::new();
        router.add_rule(
            RoutingRule::new("ping_room".to_string(), RoutingStrategy::Room)
                .with_message_type(MessageType::Ping)
                .with_target("lobby".to_string()),
        );
        let result = router.route(&Message::pong());
        assert_eq!(result.strategy, RoutingStrategy::Broadcast);
        assert!(result.targets.is_empty());
        assert_eq!(result.target_id, None);
    }

    #[test]
    fn test_router_room_without_target_broadcasts() {
        let router = MessageRouter::new();
        router.add_rule(RoutingRule::new(
            "no_target".to_string(),
            RoutingStrategy::Room,
        ));
        let result = router.route(&Message::ping());
        assert_eq!(result.strategy, RoutingStrategy::Broadcast);
        assert_eq!(result.target_id, None);
    }

    #[test]
    fn test_router_topic() {
        let router = MessageRouter::new();
        router.add_rule(
            RoutingRule::new("news".to_string(), RoutingStrategy::Topic)
                .with_target("news".to_string()),
        );
        let result = router.route(&Message::ping());
        assert_eq!(result.strategy, RoutingStrategy::Topic);
        assert_eq!(result.target_id, Some("news".to_string()));
    }

    #[test]
    fn test_router_equal_priority_keeps_insertion_order() {
        let router = MessageRouter::new();
        router.add_rule(
            RoutingRule::new("first".to_string(), RoutingStrategy::Room)
                .with_target("a".to_string())
                .with_priority(5),
        );
        router.add_rule(
            RoutingRule::new("second".to_string(), RoutingStrategy::Room)
                .with_target("b".to_string())
                .with_priority(5),
        );
        let result = router.route(&Message::ping());
        assert_eq!(result.target_id, Some("a".to_string()));
        let names: Vec<String> = router.rules().into_iter().map(|r| r.name).collect();
        assert_eq!(names, vec!["first", "second"]);
    }

    #[test]
    fn test_remove_and_clear_rules() {
        let router = MessageRouter::new();
        router.add_rule(RoutingRule::new("a".to_string(), RoutingStrategy::Broadcast));
        router.add_rule(RoutingRule::new("b".to_string(), RoutingStrategy::Broadcast));
        assert!(router.remove_rule("a"));
        assert!(!router.remove_rule("a"));
        assert!(!router.remove_rule("missing"));
        assert_eq!(router.rule_count(), 1);
        router.clear_rules();
        assert_eq!(router.rule_count(), 0);
    }

    #[test]
    fn test_route_result_direct() {
        let id = uuid::Uuid::new_v4();
        let result = RouteResult::direct(id);
        assert_eq!(result.strategy, RoutingStrategy::Direct);
        assert_eq!(result.targets, vec![id]);
        assert_eq!(result.target_id, None);
    }

    #[test]
    fn test_routing_table() {
        let table = RoutingTable::new();
        let source = uuid::Uuid::new_v4();
        let target1 = uuid::Uuid::new_v4();
        let target2 = uuid::Uuid::new_v4();
        table.add_route(source, target1);
        table.add_route(source, target2);
        let targets = table.get_targets(&source);
        assert_eq!(targets.len(), 2);
    }

    #[test]
    fn test_routing_table_remove() {
        let table = RoutingTable::new();
        let source = uuid::Uuid::new_v4();
        let target = uuid::Uuid::new_v4();
        table.add_route(source, target);
        table.remove_route(&source, &target);
        let targets = table.get_targets(&source);
        assert_eq!(targets.len(), 0);
    }

    #[test]
    fn test_routing_table_remove_connection() {
        let table = RoutingTable::new();
        let a = uuid::Uuid::new_v4();
        let b = uuid::Uuid::new_v4();
        let c = uuid::Uuid::new_v4();
        table.add_route(a, b);
        table.add_route(c, a);
        table.add_route(c, b);
        table.remove_connection(&a);
        assert!(table.get_targets(&a).is_empty());
        assert_eq!(table.get_targets(&c), vec![b]);
        table.clear();
        assert_eq!(table.route_count(), 0);
    }

    #[test]
    fn test_defaults_are_empty() {
        assert_eq!(MessageRouter::default().rule_count(), 0);
        assert_eq!(RoutingTable::default().route_count(), 0);
    }
}