celers_protocol/
routing.rs

1//! Message routing helpers
2//!
3//! This module provides utilities for routing messages to different queues based on
4//! task names, priorities, and custom rules.
5
6use crate::Message;
7use std::collections::HashMap;
8
/// A single routing directive: task names matching `pattern` are sent to `queue`.
#[derive(Debug, Clone)]
pub struct RoutingRule {
    /// Task name pattern; a trailing `*` makes the text before it a prefix match,
    /// otherwise the task name must equal the pattern exactly
    pub pattern: String,
    /// Name of the queue that matching tasks are directed to
    pub queue: String,
    /// Routing key carried by this rule, if one was set
    pub routing_key: Option<String>,
    /// Exchange name carried by this rule, if one was set
    pub exchange: Option<String>,
}

impl RoutingRule {
    /// Build a rule sending tasks that match `pattern` to `queue`.
    ///
    /// The routing key and exchange both start out as `None`.
    pub fn new(pattern: impl Into<String>, queue: impl Into<String>) -> Self {
        let pattern = pattern.into();
        let queue = queue.into();
        Self {
            pattern,
            queue,
            routing_key: None,
            exchange: None,
        }
    }

    /// Return the rule with its routing key replaced by `routing_key`.
    pub fn with_routing_key(self, routing_key: impl Into<String>) -> Self {
        Self {
            routing_key: Some(routing_key.into()),
            ..self
        }
    }

    /// Return the rule with its exchange replaced by `exchange`.
    pub fn with_exchange(self, exchange: impl Into<String>) -> Self {
        Self {
            exchange: Some(exchange.into()),
            ..self
        }
    }

    /// Report whether `task_name` is covered by this rule's pattern.
    ///
    /// A pattern ending in `*` matches every task name that starts with the
    /// text preceding the `*`; any other pattern requires exact equality.
    #[inline]
    pub fn matches(&self, task_name: &str) -> bool {
        match self.pattern.strip_suffix('*') {
            // Wildcard form: only the part before the trailing `*` has to match.
            Some(prefix) => task_name.starts_with(prefix),
            None => task_name == self.pattern,
        }
    }
}
56
57/// Router for directing messages to queues
58#[derive(Debug, Clone)]
59pub struct MessageRouter {
60    rules: Vec<RoutingRule>,
61    default_queue: String,
62}
63
64impl MessageRouter {
65    /// Create a new message router with a default queue
66    pub fn new(default_queue: impl Into<String>) -> Self {
67        Self {
68            rules: Vec::new(),
69            default_queue: default_queue.into(),
70        }
71    }
72
73    /// Add a routing rule
74    pub fn add_rule(&mut self, rule: RoutingRule) {
75        self.rules.push(rule);
76    }
77
78    /// Add a simple routing rule (pattern -> queue)
79    pub fn route(&mut self, pattern: impl Into<String>, queue: impl Into<String>) {
80        self.rules.push(RoutingRule::new(pattern, queue));
81    }
82
83    /// Get the queue name for a message
84    #[inline]
85    pub fn get_queue(&self, message: &Message) -> &str {
86        self.get_queue_for_task(&message.headers.task)
87    }
88
89    /// Get the queue name for a task name
90    #[inline]
91    pub fn get_queue_for_task(&self, task_name: &str) -> &str {
92        for rule in &self.rules {
93            if rule.matches(task_name) {
94                return &rule.queue;
95            }
96        }
97        &self.default_queue
98    }
99
100    /// Get the routing key for a message
101    #[inline]
102    pub fn get_routing_key(&self, message: &Message) -> Option<&str> {
103        for rule in &self.rules {
104            if rule.matches(&message.headers.task) {
105                return rule.routing_key.as_deref();
106            }
107        }
108        None
109    }
110
111    /// Get the exchange for a message
112    #[inline]
113    pub fn get_exchange(&self, message: &Message) -> Option<&str> {
114        for rule in &self.rules {
115            if rule.matches(&message.headers.task) {
116                return rule.exchange.as_deref();
117            }
118        }
119        None
120    }
121
122    /// Group messages by their target queues
123    pub fn group_by_queue(&self, messages: Vec<Message>) -> HashMap<String, Vec<Message>> {
124        let mut groups = HashMap::new();
125        for msg in messages {
126            let queue = self.get_queue(&msg).to_string();
127            groups.entry(queue).or_insert_with(Vec::new).push(msg);
128        }
129        groups
130    }
131}
132
133/// Priority-based router
134pub struct PriorityRouter {
135    high_priority_queue: String,
136    normal_priority_queue: String,
137    low_priority_queue: String,
138    threshold_high: u8,
139    threshold_low: u8,
140}
141
142impl PriorityRouter {
143    /// Create a new priority-based router
144    pub fn new(
145        high_priority_queue: impl Into<String>,
146        normal_priority_queue: impl Into<String>,
147        low_priority_queue: impl Into<String>,
148    ) -> Self {
149        Self {
150            high_priority_queue: high_priority_queue.into(),
151            normal_priority_queue: normal_priority_queue.into(),
152            low_priority_queue: low_priority_queue.into(),
153            threshold_high: 7,
154            threshold_low: 3,
155        }
156    }
157
158    /// Set priority thresholds
159    pub fn with_thresholds(mut self, high: u8, low: u8) -> Self {
160        self.threshold_high = high;
161        self.threshold_low = low;
162        self
163    }
164
165    /// Get the queue for a message based on priority
166    #[inline]
167    pub fn get_queue(&self, message: &Message) -> &str {
168        let priority = message.properties.priority.unwrap_or(5);
169
170        if priority >= self.threshold_high {
171            &self.high_priority_queue
172        } else if priority <= self.threshold_low {
173            &self.low_priority_queue
174        } else {
175            &self.normal_priority_queue
176        }
177    }
178
179    /// Group messages by priority queue
180    pub fn group_by_priority(&self, messages: Vec<Message>) -> HashMap<String, Vec<Message>> {
181        let mut groups = HashMap::new();
182        for msg in messages {
183            let queue = self.get_queue(&msg).to_string();
184            groups.entry(queue).or_insert_with(Vec::new).push(msg);
185        }
186        groups
187    }
188}
189
190/// Round-robin router for load balancing
191pub struct RoundRobinRouter {
192    queues: Vec<String>,
193    current_index: std::sync::atomic::AtomicUsize,
194}
195
196impl RoundRobinRouter {
197    /// Create a new round-robin router
198    pub fn new(queues: Vec<String>) -> Self {
199        Self {
200            queues,
201            current_index: std::sync::atomic::AtomicUsize::new(0),
202        }
203    }
204
205    /// Get the next queue in round-robin order
206    #[inline]
207    pub fn next_queue(&self) -> &str {
208        if self.queues.is_empty() {
209            return "default";
210        }
211
212        let index = self
213            .current_index
214            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
215            % self.queues.len();
216        &self.queues[index]
217    }
218
219    /// Distribute messages across queues in round-robin fashion
220    pub fn distribute(&self, messages: Vec<Message>) -> HashMap<String, Vec<Message>> {
221        let mut groups = HashMap::new();
222        for msg in messages {
223            let queue = self.next_queue().to_string();
224            groups.entry(queue).or_insert_with(Vec::new).push(msg);
225        }
226        groups
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::MessageBuilder;

    /// Build a message for `task` with the builder's defaults.
    ///
    /// Panics if the builder rejects the message, which fails the calling test.
    fn create_test_message(task: &str) -> Message {
        MessageBuilder::new(task).build().unwrap()
    }

    #[test]
    fn test_routing_rule_matches() {
        let rule = RoutingRule::new("tasks.add", "math_queue");
        assert!(rule.matches("tasks.add"));
        assert!(!rule.matches("tasks.subtract"));

        let prefix_rule = RoutingRule::new("tasks.*", "task_queue");
        assert!(prefix_rule.matches("tasks.add"));
        assert!(prefix_rule.matches("tasks.subtract"));
        assert!(!prefix_rule.matches("email.send"));
    }

    #[test]
    fn test_routing_rule_wildcard_only() {
        // A bare `*` leaves an empty prefix, which every task name starts with.
        let catch_all = RoutingRule::new("*", "all_queue");
        assert!(catch_all.matches("tasks.add"));
        assert!(catch_all.matches(""));

        // Without a trailing `*` the comparison is exact, so prefixes do not match.
        let exact = RoutingRule::new("tasks.add", "math_queue");
        assert!(!exact.matches("tasks.add.more"));
        assert!(!exact.matches("tasks"));
    }

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new("default_queue");
        router.route("tasks.*", "task_queue");
        router.route("email.*", "email_queue");
        router.route("tasks.priority", "priority_queue");

        assert_eq!(router.get_queue_for_task("tasks.add"), "task_queue");
        assert_eq!(router.get_queue_for_task("email.send"), "email_queue");
        assert_eq!(router.get_queue_for_task("tasks.priority"), "task_queue"); // First match wins
        assert_eq!(router.get_queue_for_task("unknown.task"), "default_queue");
    }

    #[test]
    fn test_message_router_with_message() {
        let mut router = MessageRouter::new("default");
        router.route("tasks.*", "task_queue");

        let msg = create_test_message("tasks.add");
        assert_eq!(router.get_queue(&msg), "task_queue");
    }

    #[test]
    fn test_message_router_routing_key_and_exchange() {
        let mut router = MessageRouter::new("default");
        router.add_rule(
            RoutingRule::new("tasks.*", "task_queue")
                .with_routing_key("tasks.#")
                .with_exchange("celery"),
        );
        router.route("email.*", "email_queue");

        // Matching rule that carries both values.
        let task_msg = create_test_message("tasks.add");
        assert_eq!(router.get_queue(&task_msg), "task_queue");
        assert_eq!(router.get_routing_key(&task_msg), Some("tasks.#"));
        assert_eq!(router.get_exchange(&task_msg), Some("celery"));

        // Matching rule that carries neither.
        let email_msg = create_test_message("email.send");
        assert_eq!(router.get_queue(&email_msg), "email_queue");
        assert_eq!(router.get_routing_key(&email_msg), None);
        assert_eq!(router.get_exchange(&email_msg), None);

        // No matching rule at all.
        let other_msg = create_test_message("other.task");
        assert_eq!(router.get_queue(&other_msg), "default");
        assert_eq!(router.get_routing_key(&other_msg), None);
        assert_eq!(router.get_exchange(&other_msg), None);
    }

    #[test]
    fn test_message_router_group_by_queue() {
        let mut router = MessageRouter::new("default");
        router.route("tasks.*", "task_queue");
        router.route("email.*", "email_queue");

        let messages = vec![
            create_test_message("tasks.add"),
            create_test_message("tasks.subtract"),
            create_test_message("email.send"),
            create_test_message("other.task"),
        ];

        let groups = router.group_by_queue(messages);
        assert_eq!(groups.len(), 3);
        assert_eq!(groups.get("task_queue").unwrap().len(), 2);
        assert_eq!(groups.get("email_queue").unwrap().len(), 1);
        assert_eq!(groups.get("default").unwrap().len(), 1);
    }

    #[test]
    fn test_priority_router() {
        let router =
            PriorityRouter::new("high_queue", "normal_queue", "low_queue").with_thresholds(7, 3);

        let mut high_msg = create_test_message("task");
        high_msg.properties.priority = Some(9);
        assert_eq!(router.get_queue(&high_msg), "high_queue");

        let mut normal_msg = create_test_message("task");
        normal_msg.properties.priority = Some(5);
        assert_eq!(router.get_queue(&normal_msg), "normal_queue");

        let mut low_msg = create_test_message("task");
        low_msg.properties.priority = Some(1);
        assert_eq!(router.get_queue(&low_msg), "low_queue");
    }

    #[test]
    fn test_priority_router_threshold_boundaries() {
        // Default thresholds are 7 (high) and 3 (low); both bounds are inclusive.
        let router = PriorityRouter::new("high", "normal", "low");
        let mut msg = create_test_message("task");

        msg.properties.priority = Some(7);
        assert_eq!(router.get_queue(&msg), "high");

        msg.properties.priority = Some(6);
        assert_eq!(router.get_queue(&msg), "normal");

        msg.properties.priority = Some(4);
        assert_eq!(router.get_queue(&msg), "normal");

        msg.properties.priority = Some(3);
        assert_eq!(router.get_queue(&msg), "low");
    }

    #[test]
    fn test_priority_router_default() {
        let router = PriorityRouter::new("high", "normal", "low");

        let msg = create_test_message("task"); // No priority set
        assert_eq!(router.get_queue(&msg), "normal");
    }

    #[test]
    fn test_round_robin_router() {
        let router = RoundRobinRouter::new(vec![
            "queue1".to_string(),
            "queue2".to_string(),
            "queue3".to_string(),
        ]);

        assert_eq!(router.next_queue(), "queue1");
        assert_eq!(router.next_queue(), "queue2");
        assert_eq!(router.next_queue(), "queue3");
        assert_eq!(router.next_queue(), "queue1"); // Wraps around
    }

    #[test]
    fn test_round_robin_router_empty() {
        // With no queues configured the router always answers "default".
        let router = RoundRobinRouter::new(Vec::new());
        assert_eq!(router.next_queue(), "default");
        assert_eq!(router.next_queue(), "default");

        let groups = router.distribute(vec![create_test_message("task1")]);
        assert_eq!(groups.len(), 1);
        assert_eq!(groups.get("default").unwrap().len(), 1);
    }

    #[test]
    fn test_round_robin_distribute() {
        let router = RoundRobinRouter::new(vec!["queue1".to_string(), "queue2".to_string()]);

        let messages = vec![
            create_test_message("task1"),
            create_test_message("task2"),
            create_test_message("task3"),
            create_test_message("task4"),
        ];

        let groups = router.distribute(messages);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups.get("queue1").unwrap().len(), 2);
        assert_eq!(groups.get("queue2").unwrap().len(), 2);
    }

    #[test]
    fn test_routing_rule_with_routing_key() {
        let rule = RoutingRule::new("tasks.*", "task_queue")
            .with_routing_key("tasks.#")
            .with_exchange("celery");

        assert_eq!(rule.routing_key.as_deref(), Some("tasks.#"));
        assert_eq!(rule.exchange.as_deref(), Some("celery"));
    }
}