use crate::Message;
use std::collections::HashMap;
/// Associates a task-name pattern with a destination queue.
///
/// A pattern is either an exact task name or, when it ends with `*`, a
/// prefix pattern: the text before the trailing `*` must start the task name.
#[derive(Debug, Clone)]
pub struct RoutingRule {
    /// Exact task name, or a prefix followed by a single trailing `*`.
    pub pattern: String,
    /// Name of the queue attached to this pattern.
    pub queue: String,
    /// Optional routing key; `None` until set with [`RoutingRule::with_routing_key`].
    pub routing_key: Option<String>,
    /// Optional exchange; `None` until set with [`RoutingRule::with_exchange`].
    pub exchange: Option<String>,
}

impl RoutingRule {
    /// Creates a rule for `pattern` targeting `queue`, with no routing key
    /// and no exchange.
    pub fn new(pattern: impl Into<String>, queue: impl Into<String>) -> Self {
        let (pattern, queue) = (pattern.into(), queue.into());
        Self {
            pattern,
            queue,
            routing_key: None,
            exchange: None,
        }
    }

    /// Returns the rule with its routing key replaced by `routing_key`.
    pub fn with_routing_key(self, routing_key: impl Into<String>) -> Self {
        Self {
            routing_key: Some(routing_key.into()),
            ..self
        }
    }

    /// Returns the rule with its exchange replaced by `exchange`.
    pub fn with_exchange(self, exchange: impl Into<String>) -> Self {
        Self {
            exchange: Some(exchange.into()),
            ..self
        }
    }

    /// Reports whether `task_name` is covered by this rule's pattern.
    ///
    /// Only one trailing `*` is treated as a wildcard; any other `*` in the
    /// pattern is compared literally.
    #[inline]
    pub fn matches(&self, task_name: &str) -> bool {
        match self.pattern.strip_suffix('*') {
            // Prefix pattern: compare against everything before the `*`.
            Some(prefix) => task_name.starts_with(prefix),
            // No wildcard: the names must be identical.
            None => task_name == self.pattern,
        }
    }
}
/// Routes messages to queues by matching their task name against an ordered
/// list of [`RoutingRule`]s.
///
/// Rules are consulted in the order they were added and the first rule that
/// matches decides the outcome.
#[derive(Debug, Clone)]
pub struct MessageRouter {
    /// Rules in insertion order; earlier rules take precedence.
    rules: Vec<RoutingRule>,
    /// Queue returned when no rule matches.
    default_queue: String,
}

impl MessageRouter {
    /// Creates a router with no rules that sends everything to `default_queue`.
    pub fn new(default_queue: impl Into<String>) -> Self {
        let default_queue = default_queue.into();
        Self {
            rules: Vec::new(),
            default_queue,
        }
    }

    /// Appends `rule` after all rules added so far.
    pub fn add_rule(&mut self, rule: RoutingRule) {
        self.rules.push(rule);
    }

    /// Shorthand for appending a rule built from `pattern` and `queue` only.
    pub fn route(&mut self, pattern: impl Into<String>, queue: impl Into<String>) {
        self.add_rule(RoutingRule::new(pattern, queue));
    }

    /// Returns the queue for `message`, based on `message.headers.task`.
    #[inline]
    pub fn get_queue(&self, message: &Message) -> &str {
        self.get_queue_for_task(&message.headers.task)
    }

    /// Returns the queue of the first rule matching `task_name`, or the
    /// default queue when no rule matches.
    #[inline]
    pub fn get_queue_for_task(&self, task_name: &str) -> &str {
        self.first_match(task_name)
            .map_or(self.default_queue.as_str(), |rule| rule.queue.as_str())
    }

    /// Returns the routing key of the first rule matching the message's task.
    ///
    /// `None` means either that no rule matched or that the matching rule has
    /// no routing key.
    #[inline]
    pub fn get_routing_key(&self, message: &Message) -> Option<&str> {
        self.first_match(&message.headers.task)?
            .routing_key
            .as_deref()
    }

    /// Returns the exchange of the first rule matching the message's task.
    ///
    /// `None` means either that no rule matched or that the matching rule has
    /// no exchange.
    #[inline]
    pub fn get_exchange(&self, message: &Message) -> Option<&str> {
        self.first_match(&message.headers.task)?.exchange.as_deref()
    }

    /// Consumes `messages` and buckets them by the queue each one routes to.
    ///
    /// Within a bucket, messages keep their relative input order.
    pub fn group_by_queue(&self, messages: Vec<Message>) -> HashMap<String, Vec<Message>> {
        let mut buckets: HashMap<String, Vec<Message>> = HashMap::new();
        for message in messages {
            let queue_name = self.get_queue(&message).to_string();
            buckets.entry(queue_name).or_default().push(message);
        }
        buckets
    }

    /// Finds the earliest-added rule whose pattern matches `task_name`.
    fn first_match(&self, task_name: &str) -> Option<&RoutingRule> {
        self.rules.iter().find(|rule| rule.matches(task_name))
    }
}
/// Routes messages to one of three queues according to their priority.
///
/// Derives `Debug` and `Clone` so it can be logged and copied like the other
/// routers in this module.
#[derive(Debug, Clone)]
pub struct PriorityRouter {
    /// Queue for priorities at or above `threshold_high`.
    high_priority_queue: String,
    /// Queue for priorities strictly between the two thresholds.
    normal_priority_queue: String,
    /// Queue for priorities at or below `threshold_low`.
    low_priority_queue: String,
    /// Inclusive lower bound of the high-priority band.
    threshold_high: u8,
    /// Inclusive upper bound of the low-priority band.
    threshold_low: u8,
}

impl PriorityRouter {
    /// Priority assumed for messages that do not carry one.
    const DEFAULT_PRIORITY: u8 = 5;

    /// Creates a router with the default thresholds: high at `7`, low at `3`.
    pub fn new(
        high_priority_queue: impl Into<String>,
        normal_priority_queue: impl Into<String>,
        low_priority_queue: impl Into<String>,
    ) -> Self {
        Self {
            high_priority_queue: high_priority_queue.into(),
            normal_priority_queue: normal_priority_queue.into(),
            low_priority_queue: low_priority_queue.into(),
            threshold_high: 7,
            threshold_low: 3,
        }
    }

    /// Returns the router with both thresholds replaced.
    ///
    /// The values are stored as given, without validation. Because the high
    /// band is tested first, overlapping bands (`high <= low`) resolve in
    /// favour of the high-priority queue.
    pub fn with_thresholds(mut self, high: u8, low: u8) -> Self {
        self.threshold_high = high;
        self.threshold_low = low;
        self
    }

    /// Returns the queue for `message` based on `message.properties.priority`.
    ///
    /// A missing priority is treated as `5`.
    #[inline]
    pub fn get_queue(&self, message: &Message) -> &str {
        let priority = message
            .properties
            .priority
            .unwrap_or(Self::DEFAULT_PRIORITY);
        // The high band is checked first, so it wins if the bands overlap.
        if priority >= self.threshold_high {
            &self.high_priority_queue
        } else if priority <= self.threshold_low {
            &self.low_priority_queue
        } else {
            &self.normal_priority_queue
        }
    }

    /// Consumes `messages` and buckets them by the queue their priority maps to.
    ///
    /// Within a bucket, messages keep their relative input order.
    pub fn group_by_priority(&self, messages: Vec<Message>) -> HashMap<String, Vec<Message>> {
        let mut groups: HashMap<String, Vec<Message>> = HashMap::new();
        for msg in messages {
            let queue = self.get_queue(&msg).to_string();
            groups.entry(queue).or_default().push(msg);
        }
        groups
    }
}
/// Hands out queue names in rotation.
///
/// The cursor is an atomic, so [`RoundRobinRouter::next_queue`] only needs a
/// shared reference.
#[derive(Debug)]
pub struct RoundRobinRouter {
    /// Queues in rotation order.
    queues: Vec<String>,
    /// Position of the queue to hand out next.
    current_index: std::sync::atomic::AtomicUsize,
}

impl RoundRobinRouter {
    /// Creates a router that cycles through `queues`, starting at the first.
    pub fn new(queues: Vec<String>) -> Self {
        Self {
            queues,
            current_index: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Returns the next queue in the rotation and advances the cursor.
    ///
    /// Returns the literal `"default"` when the router has no queues.
    #[inline]
    pub fn next_queue(&self) -> &str {
        use std::sync::atomic::Ordering;

        let len = self.queues.len();
        if len == 0 {
            return "default";
        }
        // Store the cursor already reduced modulo `len`. A counter that grows
        // without bound wraps at `usize::MAX`, and unless `len` divides
        // `usize::MAX + 1` the rotation skips or repeats at that point.
        let advance = |cursor: usize| Some((cursor % len + 1) % len);
        let previous = match self
            .current_index
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, advance)
        {
            // The closure never returns `None`, so `Err` cannot occur; both
            // arms carry the value observed before the update.
            Ok(value) | Err(value) => value,
        };
        &self.queues[previous % len]
    }

    /// Consumes `messages` and assigns them to queues in rotation.
    ///
    /// The rotation continues from wherever the cursor currently is.
    pub fn distribute(&self, messages: Vec<Message>) -> HashMap<String, Vec<Message>> {
        let mut groups: HashMap<String, Vec<Message>> = HashMap::new();
        for msg in messages {
            let queue = self.next_queue().to_string();
            groups.entry(queue).or_default().push(msg);
        }
        groups
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::MessageBuilder;

    /// Builds a message for `task` through `MessageBuilder`, panicking if the
    /// build fails.
    fn create_test_message(task: &str) -> Message {
        MessageBuilder::new(task).build().unwrap()
    }

    /// Builds a message for the task `"task"` with an explicit priority.
    fn message_with_priority(priority: u8) -> Message {
        let mut msg = create_test_message("task");
        msg.properties.priority = Some(priority);
        msg
    }

    #[test]
    fn test_routing_rule_matches() {
        let exact = RoutingRule::new("tasks.add", "math_queue");
        for &(task, expected) in &[("tasks.add", true), ("tasks.subtract", false)] {
            assert_eq!(exact.matches(task), expected, "exact rule vs {task:?}");
        }

        let prefix_rule = RoutingRule::new("tasks.*", "task_queue");
        let cases = [
            ("tasks.add", true),
            ("tasks.subtract", true),
            ("email.send", false),
        ];
        for &(task, expected) in &cases {
            assert_eq!(prefix_rule.matches(task), expected, "prefix rule vs {task:?}");
        }
    }

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new("default_queue");
        for &(pattern, queue) in &[
            ("tasks.*", "task_queue"),
            ("email.*", "email_queue"),
            ("tasks.priority", "priority_queue"),
        ] {
            router.route(pattern, queue);
        }

        let expectations = [
            ("tasks.add", "task_queue"),
            ("email.send", "email_queue"),
            // The earlier `tasks.*` rule is consulted first and therefore
            // shadows the exact `tasks.priority` rule.
            ("tasks.priority", "task_queue"),
            // Nothing matches, so the default queue is used.
            ("unknown.task", "default_queue"),
        ];
        for &(task, queue) in &expectations {
            assert_eq!(router.get_queue_for_task(task), queue);
        }
    }

    #[test]
    fn test_message_router_with_message() {
        let mut router = MessageRouter::new("default");
        router.route("tasks.*", "task_queue");

        let routed = create_test_message("tasks.add");
        assert_eq!(router.get_queue(&routed), "task_queue");
    }

    #[test]
    fn test_message_router_group_by_queue() {
        let mut router = MessageRouter::new("default");
        router.route("tasks.*", "task_queue");
        router.route("email.*", "email_queue");

        let messages: Vec<Message> = ["tasks.add", "tasks.subtract", "email.send", "other.task"]
            .iter()
            .map(|task| create_test_message(task))
            .collect();

        let groups = router.group_by_queue(messages);
        assert_eq!(groups.len(), 3);
        for &(queue, count) in &[("task_queue", 2), ("email_queue", 1), ("default", 1)] {
            assert_eq!(groups.get(queue).unwrap().len(), count);
        }
    }

    #[test]
    fn test_priority_router() {
        let router =
            PriorityRouter::new("high_queue", "normal_queue", "low_queue").with_thresholds(7, 3);

        let cases = [(9, "high_queue"), (5, "normal_queue"), (1, "low_queue")];
        for &(priority, queue) in &cases {
            let msg = message_with_priority(priority);
            assert_eq!(router.get_queue(&msg), queue);
        }
    }

    #[test]
    fn test_priority_router_default() {
        let router = PriorityRouter::new("high", "normal", "low");
        // The priority is left exactly as the builder produced it.
        let untouched = create_test_message("task");
        assert_eq!(router.get_queue(&untouched), "normal");
    }

    #[test]
    fn test_round_robin_router() {
        let names = ["queue1", "queue2", "queue3"];
        let router = RoundRobinRouter::new(names.iter().map(|name| name.to_string()).collect());

        // One full cycle, plus one more call that wraps back to the start.
        for expected in names.iter().chain(names.iter().take(1)) {
            assert_eq!(router.next_queue(), *expected);
        }
    }

    #[test]
    fn test_round_robin_distribute() {
        let router = RoundRobinRouter::new(vec!["queue1".to_string(), "queue2".to_string()]);
        let messages: Vec<Message> = ["task1", "task2", "task3", "task4"]
            .iter()
            .map(|task| create_test_message(task))
            .collect();

        let groups = router.distribute(messages);
        assert_eq!(groups.len(), 2);
        for queue in &["queue1", "queue2"] {
            assert_eq!(groups.get(*queue).unwrap().len(), 2);
        }
    }

    #[test]
    fn test_routing_rule_with_routing_key() {
        let base = RoutingRule::new("tasks.*", "task_queue");
        let rule = base.with_routing_key("tasks.#").with_exchange("celery");

        assert_eq!(rule.routing_key.as_deref(), Some("tasks.#"));
        assert_eq!(rule.exchange.as_deref(), Some("celery"));
    }
}