celers_protocol/
middleware.rs

1//! Message transformation middleware
2//!
3//! This module provides a middleware pattern for transforming messages through
4//! a pipeline of transformations (validation, signing, encryption, etc.).
5//!
6//! # Example
7//!
8//! ```
9//! use celers_protocol::middleware::{MessagePipeline, ValidationMiddleware};
10//! use celers_protocol::{Message, TaskArgs};
11//! use uuid::Uuid;
12//!
13//! let task_id = Uuid::new_v4();
14//! let body = serde_json::to_vec(&TaskArgs::new()).unwrap();
15//! let msg = Message::new("tasks.add".to_string(), task_id, body);
16//!
17//! let mut pipeline = MessagePipeline::new();
18//! pipeline.add(Box::new(ValidationMiddleware));
19//!
20//! let result = pipeline.process(msg);
21//! assert!(result.is_ok());
22//! ```
23
24use crate::Message;
25use std::fmt;
26
/// Error produced by a [`Middleware`] stage or a [`MessagePipeline`].
///
/// Each variant carries a human-readable description. The variants
/// distinguish *why* a stage failed so callers can react differently
/// (e.g. reject vs. retry).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MiddlewareError {
    /// The message failed a validation check (size, retries, content type, ...).
    Validation(String),
    /// A stage could not transform the message.
    Transformation(String),
    /// Any other failure while processing the message.
    Processing(String),
}
37
38impl fmt::Display for MiddlewareError {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            MiddlewareError::Validation(msg) => write!(f, "Validation error: {}", msg),
42            MiddlewareError::Transformation(msg) => write!(f, "Transformation error: {}", msg),
43            MiddlewareError::Processing(msg) => write!(f, "Processing error: {}", msg),
44        }
45    }
46}
47
48impl From<crate::ValidationError> for MiddlewareError {
49    fn from(err: crate::ValidationError) -> Self {
50        MiddlewareError::Validation(err.to_string())
51    }
52}
53
54impl std::error::Error for MiddlewareError {}
55
56/// Middleware trait for message processing
57pub trait Middleware: Send + Sync {
58    /// Process a message
59    fn process(&self, message: Message) -> Result<Message, MiddlewareError>;
60
61    /// Get the name of this middleware
62    fn name(&self) -> &'static str;
63
64    /// Check if this middleware should be skipped for a message
65    fn should_skip(&self, _message: &Message) -> bool {
66        false
67    }
68}
69
70/// Message processing pipeline
71pub struct MessagePipeline {
72    middlewares: Vec<Box<dyn Middleware>>,
73}
74
75impl MessagePipeline {
76    /// Create a new empty pipeline
77    pub fn new() -> Self {
78        Self {
79            middlewares: Vec::new(),
80        }
81    }
82
83    /// Add a middleware to the pipeline
84    pub fn add(&mut self, middleware: Box<dyn Middleware>) {
85        self.middlewares.push(middleware);
86    }
87
88    /// Process a message through the pipeline
89    pub fn process(&self, mut message: Message) -> Result<Message, MiddlewareError> {
90        for middleware in &self.middlewares {
91            if !middleware.should_skip(&message) {
92                message = middleware.process(message)?;
93            }
94        }
95        Ok(message)
96    }
97
98    /// Get the number of middlewares in the pipeline
99    #[inline]
100    pub fn len(&self) -> usize {
101        self.middlewares.len()
102    }
103
104    /// Check if the pipeline is empty
105    #[inline]
106    pub fn is_empty(&self) -> bool {
107        self.middlewares.is_empty()
108    }
109
110    /// Clear all middlewares
111    pub fn clear(&mut self) {
112        self.middlewares.clear();
113    }
114}
115
116impl Default for MessagePipeline {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122/// Validation middleware
123pub struct ValidationMiddleware;
124
125impl Middleware for ValidationMiddleware {
126    fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
127        message.validate().map_err(MiddlewareError::from)?;
128        Ok(message)
129    }
130
131    fn name(&self) -> &'static str {
132        "validation"
133    }
134}
135
136/// Size limit middleware
137pub struct SizeLimitMiddleware {
138    max_size: usize,
139}
140
141impl SizeLimitMiddleware {
142    /// Create a new size limit middleware
143    pub fn new(max_size: usize) -> Self {
144        Self { max_size }
145    }
146}
147
148impl Middleware for SizeLimitMiddleware {
149    fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
150        if message.body.len() > self.max_size {
151            return Err(MiddlewareError::Validation(format!(
152                "Message body too large: {} bytes (max {})",
153                message.body.len(),
154                self.max_size
155            )));
156        }
157        Ok(message)
158    }
159
160    fn name(&self) -> &'static str {
161        "size_limit"
162    }
163}
164
165/// Retry count middleware
166pub struct RetryLimitMiddleware {
167    max_retries: u32,
168}
169
170impl RetryLimitMiddleware {
171    /// Create a new retry limit middleware
172    pub fn new(max_retries: u32) -> Self {
173        Self { max_retries }
174    }
175}
176
177impl Middleware for RetryLimitMiddleware {
178    fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
179        if let Some(retries) = message.headers.retries {
180            if retries > self.max_retries {
181                return Err(MiddlewareError::Validation(format!(
182                    "Too many retries: {} (max {})",
183                    retries, self.max_retries
184                )));
185            }
186        }
187        Ok(message)
188    }
189
190    fn name(&self) -> &'static str {
191        "retry_limit"
192    }
193}
194
195/// Content type validation middleware
196pub struct ContentTypeMiddleware {
197    allowed_types: Vec<String>,
198}
199
200impl ContentTypeMiddleware {
201    /// Create a new content type middleware
202    pub fn new(allowed_types: Vec<String>) -> Self {
203        Self { allowed_types }
204    }
205
206    /// Create middleware that only allows JSON
207    pub fn json_only() -> Self {
208        Self {
209            allowed_types: vec!["application/json".to_string()],
210        }
211    }
212}
213
214impl Middleware for ContentTypeMiddleware {
215    fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
216        if !self.allowed_types.contains(&message.content_type) {
217            return Err(MiddlewareError::Validation(format!(
218                "Content type '{}' not allowed. Allowed types: {:?}",
219                message.content_type, self.allowed_types
220            )));
221        }
222        Ok(message)
223    }
224
225    fn name(&self) -> &'static str {
226        "content_type"
227    }
228}
229
230/// Task name filter middleware
231pub struct TaskNameFilterMiddleware {
232    allowed_patterns: Vec<String>,
233}
234
235impl TaskNameFilterMiddleware {
236    /// Create a new task name filter middleware
237    pub fn new(allowed_patterns: Vec<String>) -> Self {
238        Self { allowed_patterns }
239    }
240
241    /// Check if a task name matches any allowed pattern
242    fn is_allowed(&self, task_name: &str) -> bool {
243        self.allowed_patterns.iter().any(|pattern| {
244            if pattern.ends_with('*') {
245                task_name.starts_with(&pattern[..pattern.len() - 1])
246            } else {
247                task_name == pattern
248            }
249        })
250    }
251}
252
253impl Middleware for TaskNameFilterMiddleware {
254    fn process(&self, message: Message) -> Result<Message, MiddlewareError> {
255        if !self.is_allowed(&message.headers.task) {
256            return Err(MiddlewareError::Validation(format!(
257                "Task '{}' not allowed by filter",
258                message.headers.task
259            )));
260        }
261        Ok(message)
262    }
263
264    fn name(&self) -> &'static str {
265        "task_name_filter"
266    }
267}
268
269/// Priority enforcement middleware
270pub struct PriorityMiddleware {
271    default_priority: u8,
272    enforce_limits: bool,
273}
274
275impl PriorityMiddleware {
276    /// Create a new priority middleware
277    pub fn new(default_priority: u8, enforce_limits: bool) -> Self {
278        Self {
279            default_priority,
280            enforce_limits,
281        }
282    }
283}
284
285impl Middleware for PriorityMiddleware {
286    fn process(&self, mut message: Message) -> Result<Message, MiddlewareError> {
287        if message.properties.priority.is_none() {
288            message.properties.priority = Some(self.default_priority);
289        } else if self.enforce_limits {
290            if let Some(priority) = message.properties.priority {
291                if priority > 9 {
292                    return Err(MiddlewareError::Validation(format!(
293                        "Priority {} exceeds maximum of 9",
294                        priority
295                    )));
296                }
297            }
298        }
299        Ok(message)
300    }
301
302    fn name(&self) -> &'static str {
303        "priority"
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TaskArgs;
    use uuid::Uuid;

    /// Build a message for `task` with a fresh random id and the given body.
    fn make_msg(task: &str, body: Vec<u8>) -> Message {
        Message::new(task.to_string(), Uuid::new_v4(), body)
    }

    /// JSON encoding of an empty `TaskArgs`, used as a well-formed body.
    fn args_body() -> Vec<u8> {
        serde_json::to_vec(&TaskArgs::new()).unwrap()
    }

    /// Test-only stage that always errors when it runs; `skip` controls `should_skip`.
    struct Failing {
        skip: bool,
    }

    impl Middleware for Failing {
        fn process(&self, _message: Message) -> Result<Message, MiddlewareError> {
            Err(MiddlewareError::Processing(
                "failing middleware ran".to_string(),
            ))
        }

        fn name(&self) -> &'static str {
            "failing"
        }

        fn should_skip(&self, _message: &Message) -> bool {
            self.skip
        }
    }

    #[test]
    fn test_pipeline_new() {
        let pipeline = MessagePipeline::new();
        assert_eq!(pipeline.len(), 0);
        assert!(pipeline.is_empty());
    }

    #[test]
    fn test_pipeline_default_is_empty() {
        let pipeline = MessagePipeline::default();
        assert!(pipeline.is_empty());
    }

    #[test]
    fn test_pipeline_add() {
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(ValidationMiddleware));
        assert_eq!(pipeline.len(), 1);
        assert!(!pipeline.is_empty());
    }

    #[test]
    fn test_pipeline_clear() {
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(ValidationMiddleware));
        pipeline.clear();
        assert_eq!(pipeline.len(), 0);
    }

    #[test]
    fn test_validation_middleware() {
        let msg = make_msg("tasks.add", args_body());
        let result = ValidationMiddleware.process(msg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_size_limit_middleware_ok() {
        let msg = make_msg("tasks.test", vec![1, 2, 3]);
        let result = SizeLimitMiddleware::new(1000).process(msg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_size_limit_middleware_at_limit() {
        let msg = make_msg("tasks.test", vec![0u8; 100]);
        let result = SizeLimitMiddleware::new(100).process(msg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_size_limit_middleware_exceeded() {
        let msg = make_msg("tasks.test", vec![0u8; 1000]);
        let result = SizeLimitMiddleware::new(100).process(msg);
        assert!(matches!(result, Err(MiddlewareError::Validation(_))));
    }

    #[test]
    fn test_retry_limit_middleware_ok() {
        let mut msg = make_msg("tasks.test", vec![1, 2, 3]);
        msg.headers.retries = Some(5);
        let result = RetryLimitMiddleware::new(10).process(msg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_retry_limit_middleware_at_limit() {
        let mut msg = make_msg("tasks.test", vec![1, 2, 3]);
        msg.headers.retries = Some(10);
        let result = RetryLimitMiddleware::new(10).process(msg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_retry_limit_middleware_no_retries() {
        let mut msg = make_msg("tasks.test", vec![1, 2, 3]);
        msg.headers.retries = None;
        let result = RetryLimitMiddleware::new(0).process(msg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_retry_limit_middleware_exceeded() {
        let mut msg = make_msg("tasks.test", vec![1, 2, 3]);
        msg.headers.retries = Some(15);
        let result = RetryLimitMiddleware::new(10).process(msg);
        assert!(matches!(result, Err(MiddlewareError::Validation(_))));
    }

    #[test]
    fn test_content_type_middleware_allowed() {
        let msg = make_msg("tasks.test", vec![1, 2, 3]);
        let result = ContentTypeMiddleware::json_only().process(msg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_content_type_middleware_blocked() {
        let mut msg = make_msg("tasks.test", vec![1, 2, 3]);
        msg.content_type = "application/pickle".to_string();
        let result = ContentTypeMiddleware::json_only().process(msg);
        assert!(matches!(result, Err(MiddlewareError::Validation(_))));
    }

    #[test]
    fn test_task_name_filter_exact_match() {
        let msg = make_msg("tasks.allowed", vec![1, 2, 3]);
        let middleware = TaskNameFilterMiddleware::new(vec!["tasks.allowed".to_string()]);
        assert!(middleware.process(msg).is_ok());
    }

    #[test]
    fn test_task_name_filter_exact_is_not_prefix() {
        let msg = make_msg("tasks.allowed_more", vec![1, 2, 3]);
        let middleware = TaskNameFilterMiddleware::new(vec!["tasks.allowed".to_string()]);
        assert!(middleware.process(msg).is_err());
    }

    #[test]
    fn test_task_name_filter_wildcard() {
        let msg = make_msg("tasks.something.add", vec![1, 2, 3]);
        let middleware = TaskNameFilterMiddleware::new(vec!["tasks.*".to_string()]);
        assert!(middleware.process(msg).is_ok());
    }

    #[test]
    fn test_task_name_filter_star_matches_everything() {
        let msg = make_msg("anything.at.all", vec![1, 2, 3]);
        let middleware = TaskNameFilterMiddleware::new(vec!["*".to_string()]);
        assert!(middleware.process(msg).is_ok());
    }

    #[test]
    fn test_task_name_filter_blocked() {
        let msg = make_msg("forbidden.task", vec![1, 2, 3]);
        let middleware = TaskNameFilterMiddleware::new(vec!["tasks.*".to_string()]);
        assert!(matches!(
            middleware.process(msg),
            Err(MiddlewareError::Validation(_))
        ));
    }

    #[test]
    fn test_priority_middleware_default() {
        let msg = make_msg("tasks.test", vec![1, 2, 3]);
        let result = PriorityMiddleware::new(5, false).process(msg).unwrap();
        assert_eq!(result.properties.priority, Some(5));
    }

    #[test]
    fn test_priority_middleware_keeps_existing_priority() {
        let msg = make_msg("tasks.test", vec![1, 2, 3]).with_priority(7);
        let result = PriorityMiddleware::new(5, true).process(msg).unwrap();
        assert_eq!(result.properties.priority, Some(7));
    }

    #[test]
    fn test_priority_middleware_enforce_limits() {
        let msg = make_msg("tasks.test", vec![1, 2, 3]).with_priority(15);
        let result = PriorityMiddleware::new(5, true).process(msg);
        assert!(matches!(result, Err(MiddlewareError::Validation(_))));
    }

    #[test]
    fn test_pipeline_process() {
        let msg = make_msg("tasks.add", args_body());
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(ValidationMiddleware));
        pipeline.add(Box::new(SizeLimitMiddleware::new(10000)));
        assert!(pipeline.process(msg).is_ok());
    }

    #[test]
    fn test_pipeline_threads_message_between_stages() {
        let msg = make_msg("tasks.test", vec![1, 2, 3]);
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(PriorityMiddleware::new(3, false)));
        pipeline.add(Box::new(SizeLimitMiddleware::new(10)));
        let result = pipeline.process(msg).unwrap();
        assert_eq!(result.properties.priority, Some(3));
    }

    #[test]
    fn test_pipeline_process_failure() {
        let msg = make_msg("tasks.test", vec![0u8; 1000]);
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(ValidationMiddleware));
        pipeline.add(Box::new(SizeLimitMiddleware::new(100)));
        assert!(pipeline.process(msg).is_err());
    }

    #[test]
    fn test_pipeline_skips_middleware() {
        let msg = make_msg("tasks.test", vec![1, 2, 3]);
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(Failing { skip: true }));
        assert!(pipeline.process(msg).is_ok());
    }

    #[test]
    fn test_pipeline_runs_non_skipped_middleware() {
        let msg = make_msg("tasks.test", vec![1, 2, 3]);
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(Failing { skip: false }));
        assert!(matches!(
            pipeline.process(msg),
            Err(MiddlewareError::Processing(_))
        ));
    }

    #[test]
    fn test_pipeline_stops_at_first_error() {
        let msg = make_msg("tasks.test", vec![0u8; 1000]);
        let mut pipeline = MessagePipeline::new();
        pipeline.add(Box::new(SizeLimitMiddleware::new(100)));
        pipeline.add(Box::new(Failing { skip: false }));
        // The size check fails first, so the Processing error is never produced.
        assert!(matches!(
            pipeline.process(msg),
            Err(MiddlewareError::Validation(_))
        ));
    }

    #[test]
    fn test_middleware_names() {
        assert_eq!(ValidationMiddleware.name(), "validation");
        assert_eq!(SizeLimitMiddleware::new(1).name(), "size_limit");
        assert_eq!(RetryLimitMiddleware::new(1).name(), "retry_limit");
        assert_eq!(ContentTypeMiddleware::json_only().name(), "content_type");
        assert_eq!(TaskNameFilterMiddleware::new(vec![]).name(), "task_name_filter");
        assert_eq!(PriorityMiddleware::new(0, false).name(), "priority");
    }

    #[test]
    fn test_middleware_error_display() {
        let err = MiddlewareError::Validation("test error".to_string());
        assert_eq!(err.to_string(), "Validation error: test error");

        let err = MiddlewareError::Transformation("transform failed".to_string());
        assert_eq!(err.to_string(), "Transformation error: transform failed");

        let err = MiddlewareError::Processing("process error".to_string());
        assert_eq!(err.to_string(), "Processing error: process error");
    }
}