Skip to main content

celers_kombu/
middleware_basic.rs

1//! Basic middleware implementations.
2
3use async_trait::async_trait;
4use celers_protocol::Message;
5use std::time::Duration;
6
7use uuid::Uuid;
8
9use crate::{BrokerError, BrokerMetrics, MessageMiddleware, Result};
10
11// =============================================================================
12// Built-in Middleware Implementations
13// =============================================================================
14
15/// Validation middleware - validates message structure
16///
17/// # Examples
18///
19/// ```
20/// use celers_kombu::ValidationMiddleware;
21///
22/// // Default validation (10MB max, require task name)
23/// let validator = ValidationMiddleware::new();
24///
25/// // Custom validation
26/// let validator = ValidationMiddleware::new()
27///     .with_max_body_size(5 * 1024 * 1024)  // 5MB limit
28///     .with_require_task_name(true);
29///
30/// // Disable body size limit
31/// let validator = ValidationMiddleware::new()
32///     .without_body_size_limit();
33/// ```
34pub struct ValidationMiddleware {
35    /// Maximum message body size (bytes)
36    max_body_size: Option<usize>,
37    /// Require task name to be non-empty
38    require_task_name: bool,
39}
40
41impl ValidationMiddleware {
42    /// Create a new validation middleware
43    pub fn new() -> Self {
44        Self {
45            max_body_size: Some(10 * 1024 * 1024), // 10MB default
46            require_task_name: true,
47        }
48    }
49
50    /// Set maximum body size
51    pub fn with_max_body_size(mut self, size: usize) -> Self {
52        self.max_body_size = Some(size);
53        self
54    }
55
56    /// Disable body size check
57    pub fn without_body_size_limit(mut self) -> Self {
58        self.max_body_size = None;
59        self
60    }
61
62    /// Set whether task name is required
63    pub fn with_require_task_name(mut self, require: bool) -> Self {
64        self.require_task_name = require;
65        self
66    }
67
68    fn validate_message(&self, message: &Message) -> Result<()> {
69        // Check task name
70        if self.require_task_name && message.task_name().is_empty() {
71            return Err(BrokerError::Configuration(
72                "Task name cannot be empty".to_string(),
73            ));
74        }
75
76        // Check body size
77        if let Some(max_size) = self.max_body_size {
78            if message.body.len() > max_size {
79                return Err(BrokerError::Configuration(format!(
80                    "Message body size {} exceeds maximum {}",
81                    message.body.len(),
82                    max_size
83                )));
84            }
85        }
86
87        Ok(())
88    }
89}
90
91impl Default for ValidationMiddleware {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97#[async_trait]
98impl MessageMiddleware for ValidationMiddleware {
99    async fn before_publish(&self, message: &mut Message) -> Result<()> {
100        self.validate_message(message)
101    }
102
103    async fn after_consume(&self, message: &mut Message) -> Result<()> {
104        self.validate_message(message)
105    }
106
107    fn name(&self) -> &str {
108        "validation"
109    }
110}
111
/// Logging middleware - writes one line to standard error for every
/// published and every consumed message.
///
/// Each line contains the configured prefix, the task name and the task id.
/// After [`LoggingMiddleware::with_body_logging`] the body *size* in bytes is
/// appended as well; the body contents themselves are never printed.
///
/// # Examples
///
/// ```
/// use celers_kombu::LoggingMiddleware;
///
/// // Basic logging
/// let logger = LoggingMiddleware::new("MyApp");
///
/// // Also report body sizes
/// let verbose_logger = LoggingMiddleware::new("MyApp")
///     .with_body_logging();
/// ```
pub struct LoggingMiddleware {
    /// Text shown in square brackets at the start of every log line.
    prefix: String,
    /// When `true`, log lines also include `body_size=<bytes>`.
    log_body: bool,
}

impl LoggingMiddleware {
    /// Create a logger that tags each line with `prefix`; body sizes are not logged.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        Self {
            log_body: false,
            prefix,
        }
    }

    /// Also report the body size on every line (handy while debugging).
    pub fn with_body_logging(self) -> Self {
        Self {
            log_body: true,
            ..self
        }
    }
}
146
147#[async_trait]
148impl MessageMiddleware for LoggingMiddleware {
149    async fn before_publish(&self, message: &mut Message) -> Result<()> {
150        if self.log_body {
151            eprintln!(
152                "[{}] Publishing: task={}, id={}, body_size={}",
153                self.prefix,
154                message.task_name(),
155                message.task_id(),
156                message.body.len()
157            );
158        } else {
159            eprintln!(
160                "[{}] Publishing: task={}, id={}",
161                self.prefix,
162                message.task_name(),
163                message.task_id()
164            );
165        }
166        Ok(())
167    }
168
169    async fn after_consume(&self, message: &mut Message) -> Result<()> {
170        if self.log_body {
171            eprintln!(
172                "[{}] Consumed: task={}, id={}, body_size={}",
173                self.prefix,
174                message.task_name(),
175                message.task_id(),
176                message.body.len()
177            );
178        } else {
179            eprintln!(
180                "[{}] Consumed: task={}, id={}",
181                self.prefix,
182                message.task_name(),
183                message.task_id()
184            );
185        }
186        Ok(())
187    }
188
189    fn name(&self) -> &str {
190        "logging"
191    }
192}
193
194/// Metrics middleware - collects message statistics
195///
196/// # Examples
197///
198/// ```
199/// use celers_kombu::{MetricsMiddleware, BrokerMetrics};
200/// use std::sync::{Arc, Mutex};
201///
202/// let metrics = Arc::new(Mutex::new(BrokerMetrics::default()));
203/// let middleware = MetricsMiddleware::new(metrics.clone());
204///
205/// // Later, get metrics snapshot
206/// let snapshot = middleware.get_metrics();
207/// assert_eq!(snapshot.messages_published, 0);
208/// ```
209pub struct MetricsMiddleware {
210    metrics: std::sync::Arc<std::sync::Mutex<BrokerMetrics>>,
211}
212
213impl MetricsMiddleware {
214    /// Create a new metrics middleware
215    pub fn new(metrics: std::sync::Arc<std::sync::Mutex<BrokerMetrics>>) -> Self {
216        Self { metrics }
217    }
218
219    /// Get current metrics snapshot
220    pub fn get_metrics(&self) -> BrokerMetrics {
221        self.metrics.lock().unwrap().clone()
222    }
223}
224
225#[async_trait]
226impl MessageMiddleware for MetricsMiddleware {
227    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
228        let mut metrics = self.metrics.lock().unwrap();
229        metrics.inc_published();
230        Ok(())
231    }
232
233    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
234        let mut metrics = self.metrics.lock().unwrap();
235        metrics.inc_consumed();
236        Ok(())
237    }
238
239    fn name(&self) -> &str {
240        "metrics"
241    }
242}
243
/// Retry limit middleware - rejects consumed messages that have been retried
/// more often than allowed.
///
/// The retry count comes from the message's `retries` header; a missing header
/// counts as zero retries.
///
/// # Examples
///
/// ```
/// use celers_kombu::RetryLimitMiddleware;
///
/// // Allow up to 3 retries
/// let middleware = RetryLimitMiddleware::new(3);
/// ```
pub struct RetryLimitMiddleware {
    /// Largest retry count that is still accepted.
    max_retries: u32,
}

impl RetryLimitMiddleware {
    /// Accept messages that have been retried at most `max_retries` times.
    pub fn new(max_retries: u32) -> Self {
        RetryLimitMiddleware { max_retries }
    }
}
264
265#[async_trait]
266impl MessageMiddleware for RetryLimitMiddleware {
267    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
268        // No validation on publish
269        Ok(())
270    }
271
272    async fn after_consume(&self, message: &mut Message) -> Result<()> {
273        // Check retry count from message headers
274        let retries = message.headers.retries.unwrap_or(0);
275        if retries > self.max_retries {
276            return Err(BrokerError::Configuration(format!(
277                "Message exceeded maximum retries: {} > {}",
278                retries, self.max_retries
279            )));
280        }
281        Ok(())
282    }
283
284    fn name(&self) -> &str {
285        "retry_limit"
286    }
287}
288
/// Rate limiting middleware - enforces message rate limits
///
/// Publishing is governed by a token bucket. Each published message consumes
/// one token, and tokens are replenished continuously at `max_rate` per
/// second. Consuming is never limited.
///
/// - For `max_rate >= 1` the bucket holds up to `max_rate` tokens, so a burst of
///   that many messages is allowed.
/// - For `0 < max_rate < 1` the bucket holds one token, so one message is
///   allowed every `1 / max_rate` seconds.
/// - A non-positive or NaN `max_rate` rejects every publish.
///
/// # Examples
///
/// ```
/// use celers_kombu::RateLimitingMiddleware;
///
/// // Limit to 100 messages per second
/// let middleware = RateLimitingMiddleware::new(100.0);
///
/// // One message every two seconds
/// let slow = RateLimitingMiddleware::new(0.5);
/// ```
pub struct RateLimitingMiddleware {
    /// Maximum messages per second
    max_rate: f64,
    /// Token bucket (tracks available tokens)
    tokens: std::sync::Arc<std::sync::Mutex<TokenBucket>>,
}

/// Token bucket for rate limiting
struct TokenBucket {
    /// Current token count
    tokens: f64,
    /// Maximum tokens
    capacity: f64,
    /// Tokens added per second
    refill_rate: f64,
    /// Last refill time
    last_refill: std::time::Instant,
}

impl TokenBucket {
    /// Create a bucket that starts full (`capacity` tokens).
    fn new(capacity: f64, refill_rate: f64) -> Self {
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill: std::time::Instant::now(),
        }
    }

    /// Refill for the time elapsed since the previous call, then take `tokens`
    /// if that many are available. Returns whether the tokens were taken.
    fn try_consume(&mut self, tokens: f64) -> bool {
        // Refill tokens based on elapsed time
        let now = std::time::Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_refill = now;

        // Try to consume tokens
        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }
}

impl RateLimitingMiddleware {
    /// Create a new rate limiting middleware
    ///
    /// # Arguments
    ///
    /// * `max_rate` - Maximum messages per second
    pub fn new(max_rate: f64) -> Self {
        // Every publish needs one whole token. A bucket smaller than 1.0 could
        // therefore never admit a message, so positive rates below 1 msg/sec
        // get a one-token bucket. Non-positive/NaN rates keep their (empty)
        // capacity and continue to reject everything.
        let capacity = if max_rate > 0.0 {
            max_rate.max(1.0)
        } else {
            max_rate
        };
        Self {
            max_rate,
            tokens: std::sync::Arc::new(std::sync::Mutex::new(TokenBucket::new(
                capacity, max_rate,
            ))),
        }
    }
}
360
361#[async_trait]
362impl MessageMiddleware for RateLimitingMiddleware {
363    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
364        // Try to acquire a token
365        let mut bucket = self.tokens.lock().unwrap();
366        if !bucket.try_consume(1.0) {
367            return Err(BrokerError::OperationFailed(format!(
368                "Rate limit exceeded: {} messages/sec",
369                self.max_rate
370            )));
371        }
372        Ok(())
373    }
374
375    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
376        // No rate limiting on consume
377        Ok(())
378    }
379
380    fn name(&self) -> &str {
381        "rate_limit"
382    }
383}
384
385/// Deduplication middleware - prevents duplicate message processing
386///
387/// # Examples
388///
389/// ```
390/// use celers_kombu::DeduplicationMiddleware;
391///
392/// // Track up to 5000 message IDs
393/// let middleware = DeduplicationMiddleware::new(5000);
394///
395/// // Use default cache size (10,000)
396/// let default_middleware = DeduplicationMiddleware::with_default_cache();
397/// ```
398pub struct DeduplicationMiddleware {
399    /// Recently seen message IDs
400    seen_ids: std::sync::Arc<std::sync::Mutex<std::collections::HashSet<Uuid>>>,
401    /// Maximum size of seen IDs cache
402    max_cache_size: usize,
403}
404
405impl DeduplicationMiddleware {
406    /// Create a new deduplication middleware
407    ///
408    /// # Arguments
409    ///
410    /// * `max_cache_size` - Maximum number of message IDs to track
411    pub fn new(max_cache_size: usize) -> Self {
412        Self {
413            seen_ids: std::sync::Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
414            max_cache_size,
415        }
416    }
417
418    /// Create with default cache size (10,000 message IDs)
419    pub fn with_default_cache() -> Self {
420        Self::new(10_000)
421    }
422}
423
424impl Default for DeduplicationMiddleware {
425    fn default() -> Self {
426        Self::with_default_cache()
427    }
428}
429
430#[async_trait]
431impl MessageMiddleware for DeduplicationMiddleware {
432    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
433        // No deduplication on publish
434        Ok(())
435    }
436
437    async fn after_consume(&self, message: &mut Message) -> Result<()> {
438        let msg_id = message.task_id();
439        let mut seen = self.seen_ids.lock().unwrap();
440
441        // Check if we've seen this message before
442        if seen.contains(&msg_id) {
443            return Err(BrokerError::OperationFailed(format!(
444                "Duplicate message detected: {}",
445                msg_id
446            )));
447        }
448
449        // Add to seen set
450        seen.insert(msg_id);
451
452        // Evict oldest entries if cache is too large (simple FIFO eviction)
453        if seen.len() > self.max_cache_size {
454            // Remove first element (note: HashSet doesn't have ordering, so this is arbitrary)
455            if let Some(&id) = seen.iter().next() {
456                seen.remove(&id);
457            }
458        }
459
460        Ok(())
461    }
462
463    fn name(&self) -> &str {
464        "deduplication"
465    }
466}
467
/// Compression middleware - compresses message bodies before publishing
///
/// On publish, bodies at least `min_compress_size` bytes long (1024 by default)
/// are run through the configured compressor. The compressed bytes are kept only
/// when they are strictly smaller than the original body.
///
/// # Examples
///
/// ```
/// # #[cfg(feature = "compression")]
/// # {
/// use celers_kombu::CompressionMiddleware;
/// use celers_protocol::compression::CompressionType;
///
/// let middleware = CompressionMiddleware::new(CompressionType::Gzip)
///     .with_min_size(2048)  // Only compress messages >= 2KB
///     .with_level(6);       // Compression level 6
/// # }
/// ```
#[cfg(feature = "compression")]
pub struct CompressionMiddleware {
    /// Compressor applied on the publish path.
    compressor: celers_protocol::compression::Compressor,
    /// Bodies shorter than this many bytes are published uncompressed.
    min_compress_size: usize,
}

#[cfg(feature = "compression")]
impl CompressionMiddleware {
    /// Create a compression middleware using `compression_type`, compressing
    /// bodies of 1024 bytes or more.
    pub fn new(compression_type: celers_protocol::compression::CompressionType) -> Self {
        let compressor = celers_protocol::compression::Compressor::new(compression_type);
        CompressionMiddleware {
            compressor,
            min_compress_size: 1024,
        }
    }

    /// Only compress bodies of at least `size` bytes.
    pub fn with_min_size(self, size: usize) -> Self {
        CompressionMiddleware {
            min_compress_size: size,
            ..self
        }
    }

    /// Pass `level` on to the underlying compressor.
    pub fn with_level(self, level: u32) -> Self {
        CompressionMiddleware {
            compressor: self.compressor.with_level(level),
            ..self
        }
    }
}
517
/// Compression hooks (publish side only).
///
/// NOTE(review): `before_publish` may replace the body with compressed bytes
/// but records no marker header, and `after_consume` leaves the body as-is.
/// Consumers therefore receive the compressed bytes unchanged. A real
/// round trip needs a compression header on publish and a matching
/// decompression step on consume.
#[cfg(feature = "compression")]
#[async_trait]
impl MessageMiddleware for CompressionMiddleware {
    /// Compress the body when it is large enough and compression actually
    /// shrinks it.
    ///
    /// # Errors
    ///
    /// `BrokerError::Serialization` if the compressor fails.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        if message.body.len() < self.min_compress_size {
            return Ok(());
        }

        let compressed = self
            .compressor
            .compress(&message.body)
            .map_err(|e| BrokerError::Serialization(e.to_string()))?;

        // Keep the original when compression does not help.
        if compressed.len() < message.body.len() {
            message.body = compressed;
        }
        Ok(())
    }

    /// Currently a no-op (see the note on this impl).
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    fn name(&self) -> &str {
        "compression"
    }
}
550
/// Signing middleware - computes HMAC signatures over message bodies
///
/// # Examples
///
/// ```
/// # #[cfg(feature = "signing")]
/// # {
/// use celers_kombu::SigningMiddleware;
///
/// let secret_key = b"my-secret-key";
/// let middleware = SigningMiddleware::new(secret_key);
/// # }
/// ```
#[cfg(feature = "signing")]
pub struct SigningMiddleware {
    /// Signer built from the shared secret key.
    signer: celers_protocol::auth::MessageSigner,
}

#[cfg(feature = "signing")]
impl SigningMiddleware {
    /// Build a signing middleware from the secret `key` used for HMAC.
    pub fn new(key: &[u8]) -> Self {
        let signer = celers_protocol::auth::MessageSigner::new(key);
        SigningMiddleware { signer }
    }
}
583
/// Signing hooks.
///
/// NOTE(review): the signature computed in `before_publish` is discarded; it
/// is never attached to the message. `after_consume` only re-signs the body
/// and compares it against nothing. As written, this middleware surfaces signer
/// failures but provides no tamper detection.
#[cfg(feature = "signing")]
#[async_trait]
impl MessageMiddleware for SigningMiddleware {
    /// Sign the outgoing body (the signature is currently dropped).
    ///
    /// # Errors
    ///
    /// `BrokerError::OperationFailed` if signing fails.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        self.signer
            .sign(&message.body)
            .map(|_signature| ())
            .map_err(|e| BrokerError::OperationFailed(format!("signing failed: {}", e)))
    }

    /// Re-sign the incoming body; no verification takes place yet.
    ///
    /// # Errors
    ///
    /// `BrokerError::OperationFailed` if signing fails.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        match self.signer.sign(&message.body) {
            Ok(_) => Ok(()),
            Err(e) => Err(BrokerError::OperationFailed(format!(
                "signing failed: {}",
                e
            ))),
        }
    }

    fn name(&self) -> &str {
        "signing"
    }
}
621
/// Encryption middleware - encrypts/decrypts message bodies using AES-256-GCM
///
/// Bodies are encrypted on publish and decrypted on consume with the same key.
///
/// # Examples
///
/// ```
/// # #[cfg(feature = "encryption")]
/// # {
/// use celers_kombu::EncryptionMiddleware;
///
/// // 32-byte key for AES-256
/// let key = [0u8; 32];
/// let middleware = EncryptionMiddleware::new(&key).expect("valid key");
/// # }
/// ```
#[cfg(feature = "encryption")]
pub struct EncryptionMiddleware {
    /// Cipher state built from the secret key.
    encryptor: celers_protocol::crypto::MessageEncryptor,
}

#[cfg(feature = "encryption")]
impl EncryptionMiddleware {
    /// Build an encryption middleware from a secret key.
    ///
    /// # Arguments
    ///
    /// * `key` - 32-byte secret key for AES-256
    ///
    /// # Errors
    ///
    /// `BrokerError::Configuration` with the underlying message if
    /// `MessageEncryptor::new` rejects the key.
    pub fn new(key: &[u8]) -> Result<Self> {
        celers_protocol::crypto::MessageEncryptor::new(key)
            .map(|encryptor| EncryptionMiddleware { encryptor })
            .map_err(|e| BrokerError::Configuration(e.to_string()))
    }
}
660
#[cfg(feature = "encryption")]
#[async_trait]
impl MessageMiddleware for EncryptionMiddleware {
    /// Encrypt the body and replace it with the frame `nonce || ciphertext`.
    ///
    /// # Errors
    ///
    /// `BrokerError::Serialization` if encryption fails.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let (ciphertext, nonce) = self
            .encryptor
            .encrypt(&message.body)
            .map_err(|e| BrokerError::Serialization(e.to_string()))?;

        // The nonce travels in-band, ahead of the ciphertext, because there
        // is no dedicated header field for it.
        let mut framed = nonce.to_vec();
        framed.extend_from_slice(&ciphertext);
        message.body = framed;

        Ok(())
    }

    /// Split the `nonce || ciphertext` frame and replace the body with the
    /// decrypted plaintext.
    ///
    /// # Errors
    ///
    /// `BrokerError::Serialization` if the body is shorter than a nonce or if
    /// decryption fails.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        let nonce_len = celers_protocol::crypto::NONCE_SIZE;
        let frame = &message.body;
        if frame.len() < nonce_len {
            return Err(BrokerError::Serialization(
                "Message too short to contain nonce".to_string(),
            ));
        }

        let (nonce_bytes, ciphertext) = frame.split_at(nonce_len);
        let plaintext = self
            .encryptor
            .decrypt(ciphertext, nonce_bytes)
            .map_err(|e| BrokerError::Serialization(e.to_string()))?;

        message.body = plaintext;
        Ok(())
    }

    fn name(&self) -> &str {
        "encryption"
    }
}
704
/// Timeout middleware - advertises a processing time limit on outgoing messages.
///
/// The middleware only records the limit as message metadata. Enforcing it is
/// left to the consumer.
///
/// # Examples
///
/// ```
/// use celers_kombu::TimeoutMiddleware;
/// use std::time::Duration;
///
/// // Set 30 second timeout for message processing
/// let middleware = TimeoutMiddleware::new(Duration::from_secs(30));
/// ```
pub struct TimeoutMiddleware {
    /// Processing time limit to advertise.
    timeout: Duration,
}

impl TimeoutMiddleware {
    /// Build the middleware with the given processing time limit.
    pub fn new(timeout: Duration) -> Self {
        TimeoutMiddleware { timeout }
    }

    /// The processing time limit this middleware advertises.
    pub fn timeout(&self) -> Duration {
        let Self { timeout } = self;
        *timeout
    }
}
731
732#[async_trait]
733impl MessageMiddleware for TimeoutMiddleware {
734    async fn before_publish(&self, message: &mut Message) -> Result<()> {
735        // Store timeout in message headers for consumer
736        message.headers.extra.insert(
737            "x-timeout-ms".to_string(),
738            serde_json::Value::Number((self.timeout.as_millis() as u64).into()),
739        );
740        Ok(())
741    }
742
743    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
744        // Timeout checking is implementation-specific and would be handled
745        // by the consumer/worker. This middleware just sets the metadata.
746        Ok(())
747    }
748
749    fn name(&self) -> &str {
750        "timeout"
751    }
752}
753
754/// Filter middleware - filters messages based on custom criteria
755///
756/// # Examples
757///
758/// ```
759/// use celers_kombu::FilterMiddleware;
760/// use celers_protocol::Message;
761///
762/// // Create filter that only allows high-priority tasks
763/// let filter = FilterMiddleware::new(|msg: &Message| {
764///     msg.task_name().starts_with("critical_")
765/// });
766/// ```
767pub struct FilterMiddleware {
768    predicate: Box<dyn Fn(&Message) -> bool + Send + Sync>,
769}
770
771impl FilterMiddleware {
772    /// Create a new filter middleware with a predicate function
773    pub fn new<F>(predicate: F) -> Self
774    where
775        F: Fn(&Message) -> bool + Send + Sync + 'static,
776    {
777        Self {
778            predicate: Box::new(predicate),
779        }
780    }
781
782    /// Check if a message passes the filter
783    pub fn matches(&self, message: &Message) -> bool {
784        (self.predicate)(message)
785    }
786}
787
788#[async_trait]
789impl MessageMiddleware for FilterMiddleware {
790    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
791        // No filtering on publish
792        Ok(())
793    }
794
795    async fn after_consume(&self, message: &mut Message) -> Result<()> {
796        if !self.matches(message) {
797            return Err(BrokerError::Configuration(
798                "Message filtered out by predicate".to_string(),
799            ));
800        }
801        Ok(())
802    }
803
804    fn name(&self) -> &str {
805        "filter"
806    }
807}
808
/// Sampling middleware for statistical message sampling.
///
/// Lets only a fraction of consumed messages through, which is useful for
/// monitoring, testing, or load reduction. Sampling is deterministic and evenly
/// spread: with rate `r`, exactly `floor(n * r)` of the first `n` messages
/// pass. For example, with `r = 0.1` every 10th message passes.
///
/// # Examples
///
/// ```
/// use celers_kombu::SamplingMiddleware;
///
/// // Sample 10% of messages
/// let sampler = SamplingMiddleware::new(0.1);
/// assert_eq!(sampler.sample_rate(), 0.1);
/// ```
pub struct SamplingMiddleware {
    /// Fraction of messages to let through, clamped to `[0.0, 1.0]`.
    sample_rate: f64,
    /// Number of sampling decisions made so far.
    counter: std::sync::atomic::AtomicU64,
}

impl SamplingMiddleware {
    /// Create a new sampling middleware with the given sample rate.
    ///
    /// The rate is clamped to `[0.0, 1.0]`:
    /// - 0.0 = sample nothing
    /// - 1.0 = sample everything
    /// - 0.1 = sample 10% of messages (every 10th one)
    pub fn new(sample_rate: f64) -> Self {
        Self {
            sample_rate: sample_rate.clamp(0.0, 1.0),
            counter: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Get the configured sample rate
    pub fn sample_rate(&self) -> f64 {
        self.sample_rate
    }

    /// Decide whether the next message should be sampled.
    ///
    /// Decision number `k` (0-based, in call order) passes when the running
    /// quota `k * rate` crosses an integer, i.e.
    /// `floor((k + 1) * rate) > floor(k * rate)`. Summed over the first `n`
    /// decisions this admits exactly `floor(n * rate)` messages.
    fn should_sample(&self) -> bool {
        // Relaxed is enough: only a unique, monotonically assigned index is needed.
        let index = self
            .counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        // Fast paths; they also avoid float precision loss for huge counters.
        if self.sample_rate >= 1.0 {
            return true;
        }
        if self.sample_rate <= 0.0 {
            return false;
        }

        // A NaN rate falls through here and compares false, i.e. never samples.
        let quota = |k: u64| (k as f64 * self.sample_rate).floor();
        quota(index.saturating_add(1)) > quota(index)
    }
}
857
858#[async_trait]
859impl MessageMiddleware for SamplingMiddleware {
860    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
861        // No sampling on publish
862        Ok(())
863    }
864
865    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
866        if !self.should_sample() {
867            return Err(BrokerError::Configuration(
868                "Message filtered out by sampling".to_string(),
869            ));
870        }
871        Ok(())
872    }
873
874    fn name(&self) -> &str {
875        "sampling"
876    }
877}
878
/// Transformation middleware - rewrites message bodies with a user-supplied function.
///
/// # Examples
///
/// ```
/// use celers_kombu::TransformationMiddleware;
///
/// // Create a transformer that uppercases text
/// let transformer = TransformationMiddleware::new(|body: Vec<u8>| {
///     String::from_utf8_lossy(&body).to_uppercase().into_bytes()
/// });
/// ```
pub struct TransformationMiddleware {
    /// Takes ownership of a body and returns its replacement.
    transform_fn: Box<dyn Fn(Vec<u8>) -> Vec<u8> + Send + Sync>,
}

impl TransformationMiddleware {
    /// Build the middleware around `transform_fn`.
    pub fn new<F>(transform_fn: F) -> Self
    where
        F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
    {
        TransformationMiddleware {
            transform_fn: Box::new(transform_fn),
        }
    }

    /// Run the configured function on `body` and return the result.
    fn transform(&self, body: Vec<u8>) -> Vec<u8> {
        let apply = &*self.transform_fn;
        apply(body)
    }
}
913
914#[async_trait]
915impl MessageMiddleware for TransformationMiddleware {
916    async fn before_publish(&self, message: &mut Message) -> Result<()> {
917        // Transform on publish
918        let transformed = self.transform(message.body.clone());
919        message.body = transformed;
920        Ok(())
921    }
922
923    async fn after_consume(&self, message: &mut Message) -> Result<()> {
924        // Transform on consume
925        let transformed = self.transform(message.body.clone());
926        message.body = transformed;
927        Ok(())
928    }
929
930    fn name(&self) -> &str {
931        "transformation"
932    }
933}
934
/// Tracing middleware for distributed tracing
///
/// Stamps trace metadata (trace id, span id, service name, timestamp) into the
/// message headers on publish, and records the consuming service on consume.
///
/// # Examples
///
/// ```
/// use celers_kombu::{TracingMiddleware, MessageMiddleware};
///
/// let middleware = TracingMiddleware::new("service-name");
/// assert_eq!(middleware.name(), "tracing");
/// ```
#[derive(Debug, Clone)]
pub struct TracingMiddleware {
    /// Name written into the tracing headers for this service.
    service_name: String,
}

impl TracingMiddleware {
    /// Create a tracing middleware that identifies itself as `service_name`.
    pub fn new(service_name: impl Into<String>) -> Self {
        let service_name = service_name.into();
        TracingMiddleware { service_name }
    }
}
960
961#[async_trait]
962impl MessageMiddleware for TracingMiddleware {
963    async fn before_publish(&self, message: &mut Message) -> Result<()> {
964        // Inject trace ID if not present
965        if !message.headers.extra.contains_key("trace-id") {
966            let trace_id = uuid::Uuid::new_v4().to_string();
967            message
968                .headers
969                .extra
970                .insert("trace-id".to_string(), serde_json::json!(trace_id));
971        }
972
973        // Add service name
974        message.headers.extra.insert(
975            "service-name".to_string(),
976            serde_json::json!(self.service_name.clone()),
977        );
978
979        // Add span ID for this operation
980        let span_id = uuid::Uuid::new_v4().to_string();
981        message
982            .headers
983            .extra
984            .insert("span-id".to_string(), serde_json::json!(span_id));
985
986        // Add timestamp
987        message.headers.extra.insert(
988            "trace-timestamp".to_string(),
989            serde_json::json!(std::time::SystemTime::now()
990                .duration_since(std::time::UNIX_EPOCH)
991                .unwrap()
992                .as_millis()),
993        );
994
995        Ok(())
996    }
997
998    async fn after_consume(&self, message: &mut Message) -> Result<()> {
999        // Extract and log trace information
1000        if let Some(trace_id) = message.headers.extra.get("trace-id").cloned() {
1001            // In production, this would be sent to a tracing system
1002            message.headers.extra.insert(
1003                "consumer-service".to_string(),
1004                serde_json::json!(self.service_name.clone()),
1005            );
1006            message
1007                .headers
1008                .extra
1009                .insert("trace-id-consumed".to_string(), trace_id);
1010        }
1011        Ok(())
1012    }
1013
1014    fn name(&self) -> &str {
1015        "tracing"
1016    }
1017}