1use crate::Message;
7use std::collections::HashMap;
8
/// A single routing rule mapping task names to a destination queue.
///
/// `pattern` is either an exact task name (e.g. `"tasks.add"`) or a prefix followed by one
/// trailing `*` (e.g. `"tasks.*"`), which matches every task name that starts with the text
/// before the `*`. No other wildcard syntax is interpreted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoutingRule {
    /// Exact task name, or a prefix terminated by a single trailing `*`.
    pub pattern: String,
    /// Queue that matching tasks are routed to.
    pub queue: String,
    /// Optional routing key reported for matching tasks.
    pub routing_key: Option<String>,
    /// Optional exchange reported for matching tasks.
    pub exchange: Option<String>,
}

impl RoutingRule {
    /// Creates a rule that sends tasks matching `pattern` to `queue`.
    ///
    /// The rule starts with no routing key and no exchange.
    pub fn new(pattern: impl Into<String>, queue: impl Into<String>) -> Self {
        Self {
            pattern: pattern.into(),
            queue: queue.into(),
            routing_key: None,
            exchange: None,
        }
    }

    /// Returns this rule with its routing key set to `routing_key` (builder style).
    #[must_use]
    pub fn with_routing_key(mut self, routing_key: impl Into<String>) -> Self {
        self.routing_key = Some(routing_key.into());
        self
    }

    /// Returns this rule with its exchange set to `exchange` (builder style).
    #[must_use]
    pub fn with_exchange(mut self, exchange: impl Into<String>) -> Self {
        self.exchange = Some(exchange.into());
        self
    }

    /// Returns `true` if `task_name` is matched by this rule's pattern.
    ///
    /// With a trailing `*`, the rest of the pattern is a prefix. So `"tasks.*"` matches
    /// `"tasks.add"` and `"tasks."` but not `"tasks"`, and a bare `"*"` matches every name.
    /// Without a trailing `*`, the comparison is exact and case-sensitive.
    #[inline]
    pub fn matches(&self, task_name: &str) -> bool {
        // `strip_suffix` replaces the manual `ends_with` + byte-slice combination.
        match self.pattern.strip_suffix('*') {
            Some(prefix) => task_name.starts_with(prefix),
            None => task_name == self.pattern,
        }
    }
}
56
57#[derive(Debug, Clone)]
59pub struct MessageRouter {
60 rules: Vec<RoutingRule>,
61 default_queue: String,
62}
63
64impl MessageRouter {
65 pub fn new(default_queue: impl Into<String>) -> Self {
67 Self {
68 rules: Vec::new(),
69 default_queue: default_queue.into(),
70 }
71 }
72
73 pub fn add_rule(&mut self, rule: RoutingRule) {
75 self.rules.push(rule);
76 }
77
78 pub fn route(&mut self, pattern: impl Into<String>, queue: impl Into<String>) {
80 self.rules.push(RoutingRule::new(pattern, queue));
81 }
82
83 #[inline]
85 pub fn get_queue(&self, message: &Message) -> &str {
86 self.get_queue_for_task(&message.headers.task)
87 }
88
89 #[inline]
91 pub fn get_queue_for_task(&self, task_name: &str) -> &str {
92 for rule in &self.rules {
93 if rule.matches(task_name) {
94 return &rule.queue;
95 }
96 }
97 &self.default_queue
98 }
99
100 #[inline]
102 pub fn get_routing_key(&self, message: &Message) -> Option<&str> {
103 for rule in &self.rules {
104 if rule.matches(&message.headers.task) {
105 return rule.routing_key.as_deref();
106 }
107 }
108 None
109 }
110
111 #[inline]
113 pub fn get_exchange(&self, message: &Message) -> Option<&str> {
114 for rule in &self.rules {
115 if rule.matches(&message.headers.task) {
116 return rule.exchange.as_deref();
117 }
118 }
119 None
120 }
121
122 pub fn group_by_queue(&self, messages: Vec<Message>) -> HashMap<String, Vec<Message>> {
124 let mut groups = HashMap::new();
125 for msg in messages {
126 let queue = self.get_queue(&msg).to_string();
127 groups.entry(queue).or_insert_with(Vec::new).push(msg);
128 }
129 groups
130 }
131}
132
133pub struct PriorityRouter {
135 high_priority_queue: String,
136 normal_priority_queue: String,
137 low_priority_queue: String,
138 threshold_high: u8,
139 threshold_low: u8,
140}
141
142impl PriorityRouter {
143 pub fn new(
145 high_priority_queue: impl Into<String>,
146 normal_priority_queue: impl Into<String>,
147 low_priority_queue: impl Into<String>,
148 ) -> Self {
149 Self {
150 high_priority_queue: high_priority_queue.into(),
151 normal_priority_queue: normal_priority_queue.into(),
152 low_priority_queue: low_priority_queue.into(),
153 threshold_high: 7,
154 threshold_low: 3,
155 }
156 }
157
158 pub fn with_thresholds(mut self, high: u8, low: u8) -> Self {
160 self.threshold_high = high;
161 self.threshold_low = low;
162 self
163 }
164
165 #[inline]
167 pub fn get_queue(&self, message: &Message) -> &str {
168 let priority = message.properties.priority.unwrap_or(5);
169
170 if priority >= self.threshold_high {
171 &self.high_priority_queue
172 } else if priority <= self.threshold_low {
173 &self.low_priority_queue
174 } else {
175 &self.normal_priority_queue
176 }
177 }
178
179 pub fn group_by_priority(&self, messages: Vec<Message>) -> HashMap<String, Vec<Message>> {
181 let mut groups = HashMap::new();
182 for msg in messages {
183 let queue = self.get_queue(&msg).to_string();
184 groups.entry(queue).or_insert_with(Vec::new).push(msg);
185 }
186 groups
187 }
188}
189
190pub struct RoundRobinRouter {
192 queues: Vec<String>,
193 current_index: std::sync::atomic::AtomicUsize,
194}
195
196impl RoundRobinRouter {
197 pub fn new(queues: Vec<String>) -> Self {
199 Self {
200 queues,
201 current_index: std::sync::atomic::AtomicUsize::new(0),
202 }
203 }
204
205 #[inline]
207 pub fn next_queue(&self) -> &str {
208 if self.queues.is_empty() {
209 return "default";
210 }
211
212 let index = self
213 .current_index
214 .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
215 % self.queues.len();
216 &self.queues[index]
217 }
218
219 pub fn distribute(&self, messages: Vec<Message>) -> HashMap<String, Vec<Message>> {
221 let mut groups = HashMap::new();
222 for msg in messages {
223 let queue = self.next_queue().to_string();
224 groups.entry(queue).or_insert_with(Vec::new).push(msg);
225 }
226 groups
227 }
228}
229
// Unit tests for `RoutingRule` and the three router kinds.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::MessageBuilder;

    /// Builds a message for `task` using the builder's defaults.
    fn create_test_message(task: &str) -> Message {
        MessageBuilder::new(task).build().unwrap()
    }

    /// Builds a message for a fixed task with the given priority.
    fn message_with_priority(priority: Option<u8>) -> Message {
        let mut msg = create_test_message("task");
        msg.properties.priority = priority;
        msg
    }

    #[test]
    fn test_routing_rule_matches() {
        let rule = RoutingRule::new("tasks.add", "math_queue");
        assert!(rule.matches("tasks.add"));
        assert!(!rule.matches("tasks.subtract"));
        assert!(!rule.matches("tasks.add2"));

        let prefix_rule = RoutingRule::new("tasks.*", "task_queue");
        assert!(prefix_rule.matches("tasks.add"));
        assert!(prefix_rule.matches("tasks.subtract"));
        assert!(!prefix_rule.matches("email.send"));
        // The prefix is everything before `*`, including the dot.
        assert!(!prefix_rule.matches("tasks"));

        let catch_all = RoutingRule::new("*", "any_queue");
        assert!(catch_all.matches("anything"));
        assert!(catch_all.matches(""));
    }

    #[test]
    fn test_message_router() {
        let mut router = MessageRouter::new("default_queue");
        router.route("tasks.*", "task_queue");
        router.route("email.*", "email_queue");
        router.route("tasks.priority", "priority_queue");

        assert_eq!(router.get_queue_for_task("tasks.add"), "task_queue");
        assert_eq!(router.get_queue_for_task("email.send"), "email_queue");
        // First matching rule wins, so the earlier "tasks.*" shadows "tasks.priority".
        assert_eq!(router.get_queue_for_task("tasks.priority"), "task_queue");
        assert_eq!(router.get_queue_for_task("unknown.task"), "default_queue");
    }

    #[test]
    fn test_message_router_with_message() {
        let mut router = MessageRouter::new("default");
        router.route("tasks.*", "task_queue");

        let msg = create_test_message("tasks.add");
        assert_eq!(router.get_queue(&msg), "task_queue");
    }

    #[test]
    fn test_message_router_routing_key_and_exchange() {
        let mut router = MessageRouter::new("default");
        router.add_rule(
            RoutingRule::new("tasks.*", "task_queue")
                .with_routing_key("tasks.#")
                .with_exchange("celery"),
        );
        router.route("email.*", "email_queue");

        let task_msg = create_test_message("tasks.add");
        assert_eq!(router.get_routing_key(&task_msg), Some("tasks.#"));
        assert_eq!(router.get_exchange(&task_msg), Some("celery"));

        // A matching rule without a key/exchange yields None.
        let email_msg = create_test_message("email.send");
        assert_eq!(router.get_queue(&email_msg), "email_queue");
        assert_eq!(router.get_routing_key(&email_msg), None);
        assert_eq!(router.get_exchange(&email_msg), None);

        let other_msg = create_test_message("other.task");
        assert_eq!(router.get_queue(&other_msg), "default");
        assert_eq!(router.get_routing_key(&other_msg), None);
        assert_eq!(router.get_exchange(&other_msg), None);
    }

    #[test]
    fn test_message_router_group_by_queue() {
        let mut router = MessageRouter::new("default");
        router.route("tasks.*", "task_queue");
        router.route("email.*", "email_queue");

        let messages = vec![
            create_test_message("tasks.add"),
            create_test_message("tasks.subtract"),
            create_test_message("email.send"),
            create_test_message("other.task"),
        ];

        let groups = router.group_by_queue(messages);
        assert_eq!(groups.len(), 3);
        assert_eq!(groups.get("task_queue").unwrap().len(), 2);
        assert_eq!(groups.get("email_queue").unwrap().len(), 1);
        assert_eq!(groups.get("default").unwrap().len(), 1);
    }

    #[test]
    fn test_priority_router() {
        let router =
            PriorityRouter::new("high_queue", "normal_queue", "low_queue").with_thresholds(7, 3);

        let high_msg = message_with_priority(Some(9));
        assert_eq!(router.get_queue(&high_msg), "high_queue");

        let normal_msg = message_with_priority(Some(5));
        assert_eq!(router.get_queue(&normal_msg), "normal_queue");

        let low_msg = message_with_priority(Some(1));
        assert_eq!(router.get_queue(&low_msg), "low_queue");
    }

    #[test]
    fn test_priority_router_default() {
        let router = PriorityRouter::new("high", "normal", "low");

        let msg = create_test_message("task");
        assert_eq!(router.get_queue(&msg), "normal");
    }

    #[test]
    fn test_priority_router_threshold_boundaries() {
        let router = PriorityRouter::new("high", "normal", "low");
        let cases: [(u8, &str); 4] = [(7, "high"), (6, "normal"), (4, "normal"), (3, "low")];
        for &(priority, expected) in cases.iter() {
            let msg = message_with_priority(Some(priority));
            assert_eq!(router.get_queue(&msg), expected, "priority {}", priority);
        }
    }

    #[test]
    fn test_priority_router_group_by_priority() {
        let router = PriorityRouter::new("high", "normal", "low");
        let messages = vec![
            message_with_priority(Some(9)),
            message_with_priority(Some(1)),
            message_with_priority(None),
            message_with_priority(Some(2)),
        ];

        let groups = router.group_by_priority(messages);
        assert_eq!(groups.len(), 3);
        assert_eq!(groups.get("high").unwrap().len(), 1);
        assert_eq!(groups.get("normal").unwrap().len(), 1);
        assert_eq!(groups.get("low").unwrap().len(), 2);
    }

    #[test]
    fn test_round_robin_router() {
        let router = RoundRobinRouter::new(vec![
            "queue1".to_string(),
            "queue2".to_string(),
            "queue3".to_string(),
        ]);

        assert_eq!(router.next_queue(), "queue1");
        assert_eq!(router.next_queue(), "queue2");
        assert_eq!(router.next_queue(), "queue3");
        assert_eq!(router.next_queue(), "queue1");
    }

    #[test]
    fn test_round_robin_distribute() {
        let router = RoundRobinRouter::new(vec!["queue1".to_string(), "queue2".to_string()]);

        let messages = vec![
            create_test_message("task1"),
            create_test_message("task2"),
            create_test_message("task3"),
            create_test_message("task4"),
        ];

        let groups = router.distribute(messages);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups.get("queue1").unwrap().len(), 2);
        assert_eq!(groups.get("queue2").unwrap().len(), 2);
    }

    #[test]
    fn test_round_robin_empty_uses_default() {
        let router = RoundRobinRouter::new(Vec::new());
        assert_eq!(router.next_queue(), "default");

        let groups = router.distribute(vec![create_test_message("task")]);
        assert_eq!(groups.len(), 1);
        assert_eq!(groups.get("default").unwrap().len(), 1);
    }

    #[test]
    fn test_routing_rule_with_routing_key() {
        let rule = RoutingRule::new("tasks.*", "task_queue")
            .with_routing_key("tasks.#")
            .with_exchange("celery");

        assert_eq!(rule.routing_key.as_deref(), Some("tasks.#"));
        assert_eq!(rule.exchange.as_deref(), Some("celery"));
    }
}