1use crate::Message;
25use std::fmt;
26
/// Errors that a middleware stage can report while handling a message.
///
/// Every variant carries a human-readable description. `PartialEq` and `Eq`
/// are derived so that callers and tests can compare errors directly
/// (for example with `assert_eq!`) instead of only checking `is_err()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MiddlewareError {
    /// A validation check rejected the message; displayed as
    /// `Validation error: <description>`.
    Validation(String),
    /// Displayed as `Transformation error: <description>`.
    Transformation(String),
    /// Displayed as `Processing error: <description>`.
    Processing(String),
}
37
38impl fmt::Display for MiddlewareError {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 MiddlewareError::Validation(msg) => write!(f, "Validation error: {}", msg),
42 MiddlewareError::Transformation(msg) => write!(f, "Transformation error: {}", msg),
43 MiddlewareError::Processing(msg) => write!(f, "Processing error: {}", msg),
44 }
45 }
46}
47
48impl From<crate::ValidationError> for MiddlewareError {
49 fn from(err: crate::ValidationError) -> Self {
50 MiddlewareError::Validation(err.to_string())
51 }
52}
53
// Marks `MiddlewareError` as a standard error type. No trait methods are
// overridden, so the defaults provided by `std::error::Error` apply.
impl std::error::Error for MiddlewareError {}
55
/// A single stage that inspects or rewrites a [`Message`].
///
/// Implementors must be `Send + Sync`.
pub trait Middleware: Send + Sync {
    /// Handles `message`, returning the (possibly modified) message on
    /// success or a [`MiddlewareError`] on failure.
    fn process(&self, message: Message) -> Result<Message, MiddlewareError>;

    /// Returns a static name identifying this middleware.
    fn name(&self) -> &'static str;

    /// Reports whether this middleware should be bypassed for `_message`.
    ///
    /// The default implementation returns `false`, i.e. the middleware is
    /// never skipped unless an implementor overrides this method.
    fn should_skip(&self, _message: &Message) -> bool {
        false
    }
}
69
70pub struct MessagePipeline {
72 middlewares: Vec<Box<dyn Middleware>>,
73}
74
75impl MessagePipeline {
76 pub fn new() -> Self {
78 Self {
79 middlewares: Vec::new(),
80 }
81 }
82
83 pub fn add(&mut self, middleware: Box<dyn Middleware>) {
85 self.middlewares.push(middleware);
86 }
87
88 pub fn process(&self, mut message: Message) -> Result<Message, MiddlewareError> {
90 for middleware in &self.middlewares {
91 if !middleware.should_skip(&message) {
92 message = middleware.process(message)?;
93 }
94 }
95 Ok(message)
96 }
97
98 #[inline]
100 pub fn len(&self) -> usize {
101 self.middlewares.len()
102 }
103
104 #[inline]
106 pub fn is_empty(&self) -> bool {
107 self.middlewares.is_empty()
108 }
109
110 pub fn clear(&mut self) {
112 self.middlewares.clear();
113 }
114}
115
116impl Default for MessagePipeline {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
/// Stateless middleware whose `process` delegates to `Message::validate`.
pub struct ValidationMiddleware;
124
125impl Middleware for ValidationMiddleware {
126 fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
127 message.validate().map_err(MiddlewareError::from)?;
128 Ok(message)
129 }
130
131 fn name(&self) -> &'static str {
132 "validation"
133 }
134}
135
/// Middleware configuration holding an upper bound on message body length.
pub struct SizeLimitMiddleware {
    /// Largest body length (as reported by `body.len()`) that is accepted.
    max_size: usize,
}

impl SizeLimitMiddleware {
    /// Creates a size limiter that accepts bodies of at most `max_size`.
    pub fn new(max_size: usize) -> Self {
        SizeLimitMiddleware { max_size }
    }
}
147
148impl Middleware for SizeLimitMiddleware {
149 fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
150 if message.body.len() > self.max_size {
151 return Err(MiddlewareError::Validation(format!(
152 "Message body too large: {} bytes (max {})",
153 message.body.len(),
154 self.max_size
155 )));
156 }
157 Ok(message)
158 }
159
160 fn name(&self) -> &'static str {
161 "size_limit"
162 }
163}
164
/// Middleware configuration holding an upper bound on a message's retry count.
pub struct RetryLimitMiddleware {
    /// Highest retry count that is still accepted.
    max_retries: u32,
}

impl RetryLimitMiddleware {
    /// Creates a retry limiter that accepts up to `max_retries` retries.
    pub fn new(max_retries: u32) -> Self {
        RetryLimitMiddleware { max_retries }
    }
}
176
177impl Middleware for RetryLimitMiddleware {
178 fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
179 if let Some(retries) = message.headers.retries {
180 if retries > self.max_retries {
181 return Err(MiddlewareError::Validation(format!(
182 "Too many retries: {} (max {})",
183 retries, self.max_retries
184 )));
185 }
186 }
187 Ok(message)
188 }
189
190 fn name(&self) -> &'static str {
191 "retry_limit"
192 }
193}
194
/// Middleware configuration holding the list of accepted content types.
pub struct ContentTypeMiddleware {
    /// Content types that are accepted; compared by exact string equality.
    allowed_types: Vec<String>,
}

impl ContentTypeMiddleware {
    /// Creates a filter that accepts exactly the given content types.
    pub fn new(allowed_types: Vec<String>) -> Self {
        ContentTypeMiddleware { allowed_types }
    }

    /// Creates a filter that accepts only `application/json`.
    pub fn json_only() -> Self {
        Self::new(vec![String::from("application/json")])
    }
}
213
214impl Middleware for ContentTypeMiddleware {
215 fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
216 if !self.allowed_types.contains(&message.content_type) {
217 return Err(MiddlewareError::Validation(format!(
218 "Content type '{}' not allowed. Allowed types: {:?}",
219 message.content_type, self.allowed_types
220 )));
221 }
222 Ok(message)
223 }
224
225 fn name(&self) -> &'static str {
226 "content_type"
227 }
228}
229
/// Middleware configuration holding the task-name patterns that are accepted.
pub struct TaskNameFilterMiddleware {
    /// Accepted patterns. A pattern ending in `*` is a prefix match on the
    /// text before that `*`; any other pattern must match exactly.
    allowed_patterns: Vec<String>,
}

impl TaskNameFilterMiddleware {
    /// Creates a filter from the given list of patterns.
    pub fn new(allowed_patterns: Vec<String>) -> Self {
        TaskNameFilterMiddleware { allowed_patterns }
    }

    /// Returns `true` when `task_name` matches at least one pattern.
    ///
    /// Only a single trailing `*` is treated as a wildcard; a `*` anywhere
    /// else in the pattern is compared literally.
    fn is_allowed(&self, task_name: &str) -> bool {
        for pattern in &self.allowed_patterns {
            let matched = match pattern.strip_suffix('*') {
                // Trailing wildcard: everything before it must be a prefix.
                Some(prefix) => task_name.starts_with(prefix),
                // No wildcard: the whole name must be identical.
                None => task_name == pattern.as_str(),
            };
            if matched {
                return true;
            }
        }
        false
    }
}
252
253impl Middleware for TaskNameFilterMiddleware {
254 fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
255 if !self.is_allowed(&message.headers.task) {
256 return Err(MiddlewareError::Validation(format!(
257 "Task '{}' not allowed by filter",
258 message.headers.task
259 )));
260 }
261 Ok(message)
262 }
263
264 fn name(&self) -> &'static str {
265 "task_name_filter"
266 }
267}
268
/// Middleware configuration for defaulting and optionally bounding priority.
pub struct PriorityMiddleware {
    /// Priority assigned to messages that carry none.
    default_priority: u8,
    /// When `true`, an explicitly set priority is range-checked.
    enforce_limits: bool,
}

impl PriorityMiddleware {
    /// Creates a priority middleware with the given default and enforcement flag.
    pub fn new(default_priority: u8, enforce_limits: bool) -> Self {
        PriorityMiddleware {
            default_priority,
            enforce_limits,
        }
    }
}
284
285impl Middleware for PriorityMiddleware {
286 fn process(&self, mut message: Message) -> Result<Message, MiddlewareError> {
287 if message.properties.priority.is_none() {
288 message.properties.priority = Some(self.default_priority);
289 } else if self.enforce_limits {
290 if let Some(priority) = message.properties.priority {
291 if priority > 9 {
292 return Err(MiddlewareError::Validation(format!(
293 "Priority {} exceeds maximum of 9",
294 priority
295 )));
296 }
297 }
298 }
299 Ok(message)
300 }
301
302 fn name(&self) -> &'static str {
303 "priority"
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TaskArgs;
    use uuid::Uuid;

    /// Builds a message for `task` with a freshly generated id and `body`.
    fn message_for(task: &str, body: Vec<u8>) -> Message {
        Message::new(task.to_string(), Uuid::new_v4(), body)
    }

    /// JSON-serialized `TaskArgs::new()`, used where the body must validate.
    fn json_args_body() -> Vec<u8> {
        serde_json::to_vec(&TaskArgs::new()).unwrap()
    }

    /// Three-byte payload for tests that do not care about body content.
    fn small_body() -> Vec<u8> {
        vec![1, 2, 3]
    }

    #[test]
    fn test_pipeline_new() {
        let pipeline = MessagePipeline::new();
        assert_eq!(pipeline.len(), 0);
        assert!(pipeline.is_empty());
    }

    #[test]
    fn test_pipeline_add() {
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(ValidationMiddleware));
        assert_eq!(pipeline.len(), 1);
        assert!(!pipeline.is_empty());
    }

    #[test]
    fn test_pipeline_clear() {
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(ValidationMiddleware));
        pipeline.clear();
        assert_eq!(pipeline.len(), 0);
    }

    #[test]
    fn test_validation_middleware() {
        let msg = message_for("tasks.add", json_args_body());
        let outcome = ValidationMiddleware.process(msg);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_size_limit_middleware_ok() {
        let msg = message_for("tasks.test", small_body());
        let outcome = SizeLimitMiddleware::new(1000).process(msg);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_size_limit_middleware_exceeded() {
        let msg = message_for("tasks.test", vec![0u8; 1000]);
        let outcome = SizeLimitMiddleware::new(100).process(msg);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_retry_limit_middleware_ok() {
        let mut msg = message_for("tasks.test", small_body());
        msg.headers.retries = Some(5);

        let outcome = RetryLimitMiddleware::new(10).process(msg);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_retry_limit_middleware_exceeded() {
        let mut msg = message_for("tasks.test", small_body());
        msg.headers.retries = Some(15);

        let outcome = RetryLimitMiddleware::new(10).process(msg);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_content_type_middleware_allowed() {
        let msg = message_for("tasks.test", small_body());
        let outcome = ContentTypeMiddleware::json_only().process(msg);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_content_type_middleware_blocked() {
        let mut msg = message_for("tasks.test", small_body());
        msg.content_type = "application/pickle".to_string();

        let outcome = ContentTypeMiddleware::json_only().process(msg);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_task_name_filter_exact_match() {
        let msg = message_for("tasks.allowed", small_body());
        let filter = TaskNameFilterMiddleware::new(vec!["tasks.allowed".to_string()]);
        assert!(filter.process(msg).is_ok());
    }

    #[test]
    fn test_task_name_filter_wildcard() {
        let msg = message_for("tasks.something.add", small_body());
        let filter = TaskNameFilterMiddleware::new(vec!["tasks.*".to_string()]);
        assert!(filter.process(msg).is_ok());
    }

    #[test]
    fn test_task_name_filter_blocked() {
        let msg = message_for("forbidden.task", small_body());
        let filter = TaskNameFilterMiddleware::new(vec!["tasks.*".to_string()]);
        assert!(filter.process(msg).is_err());
    }

    #[test]
    fn test_priority_middleware_default() {
        let msg = message_for("tasks.test", small_body());
        let processed = PriorityMiddleware::new(5, false).process(msg).unwrap();
        assert_eq!(processed.properties.priority, Some(5));
    }

    #[test]
    fn test_priority_middleware_enforce_limits() {
        let msg = message_for("tasks.test", small_body()).with_priority(15);
        let outcome = PriorityMiddleware::new(5, true).process(msg);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_pipeline_process() {
        let msg = message_for("tasks.add", json_args_body());

        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(ValidationMiddleware));
        pipeline.add(Box::new(SizeLimitMiddleware::new(10000)));

        assert!(pipeline.process(msg).is_ok());
    }

    #[test]
    fn test_pipeline_process_failure() {
        let msg = message_for("tasks.test", vec![0u8; 1000]);

        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(ValidationMiddleware));
        pipeline.add(Box::new(SizeLimitMiddleware::new(100)));

        assert!(pipeline.process(msg).is_err());
    }

    #[test]
    fn test_middleware_error_display() {
        let validation = MiddlewareError::Validation("test error".to_string());
        assert_eq!(validation.to_string(), "Validation error: test error");

        let transformation = MiddlewareError::Transformation("transform failed".to_string());
        assert_eq!(
            transformation.to_string(),
            "Transformation error: transform failed"
        );

        let processing = MiddlewareError::Processing("process error".to_string());
        assert_eq!(processing.to_string(), "Processing error: process error");
    }
}