Skip to main content

celers_kombu/
middleware_extended.rs

1//! Extended middleware implementations.
2
3use async_trait::async_trait;
4use celers_protocol::Message;
5use std::collections::HashMap;
6use std::time::Duration;
7
8use uuid::Uuid;
9
10use crate::{BrokerError, MessageMiddleware, Result};
11
/// Error classification middleware for intelligent error routing
///
/// This middleware classifies errors into categories and can route messages
/// to different queues based on error type (e.g., transient vs permanent errors).
///
/// Patterns are matched as case-insensitive substrings of the error message.
/// A single pattern may list several alternatives separated by `|`
/// (`"timeout|connection"` matches either word). Empty alternatives are
/// ignored so that they never match every message. No other regular-expression
/// syntax is interpreted.
///
/// # Examples
///
/// ```
/// use celers_kombu::ErrorClassificationMiddleware;
///
/// let classifier = ErrorClassificationMiddleware::new()
///     .with_transient_pattern("timeout|connection")
///     .with_permanent_pattern("validation|schema")
///     .with_max_transient_retries(5);
/// ```
pub struct ErrorClassificationMiddleware {
    /// Patterns identifying transient (retryable) errors.
    transient_patterns: Vec<String>,
    /// Patterns identifying permanent errors; these are checked before transient ones.
    permanent_patterns: Vec<String>,
    /// Retry budget for transient and unclassified errors.
    max_transient_retries: u32,
    /// Retry budget for permanent errors.
    max_permanent_retries: u32,
}

impl ErrorClassificationMiddleware {
    /// Create a new error classification middleware
    ///
    /// Defaults:
    /// - transient patterns: `timeout`, `connection`, `network`, `unavailable`
    /// - permanent patterns: `validation`, `schema`, `invalid`, `forbidden`
    /// - 10 transient retries and 1 permanent retry
    pub fn new() -> Self {
        Self {
            transient_patterns: vec![
                "timeout".to_string(),
                "connection".to_string(),
                "network".to_string(),
                "unavailable".to_string(),
            ],
            permanent_patterns: vec![
                "validation".to_string(),
                "schema".to_string(),
                "invalid".to_string(),
                "forbidden".to_string(),
            ],
            max_transient_retries: 10,
            max_permanent_retries: 1,
        }
    }

    /// Add a pattern for transient errors (can be retried)
    ///
    /// `pattern` may contain `|`-separated alternatives.
    pub fn with_transient_pattern(mut self, pattern: &str) -> Self {
        self.transient_patterns.push(pattern.to_string());
        self
    }

    /// Add a pattern for permanent errors (should not be retried)
    ///
    /// `pattern` may contain `|`-separated alternatives.
    pub fn with_permanent_pattern(mut self, pattern: &str) -> Self {
        self.permanent_patterns.push(pattern.to_string());
        self
    }

    /// Set maximum retries for transient errors
    pub fn with_max_transient_retries(mut self, max_retries: u32) -> Self {
        self.max_transient_retries = max_retries;
        self
    }

    /// Set maximum retries for permanent errors
    pub fn with_max_permanent_retries(mut self, max_retries: u32) -> Self {
        self.max_permanent_retries = max_retries;
        self
    }

    /// Returns `true` if any non-empty `|`-separated alternative of any pattern
    /// occurs, case-insensitively, in `haystack` (which must already be lowercase).
    fn matches_any(patterns: &[String], haystack: &str) -> bool {
        patterns.iter().any(|pattern| {
            pattern
                .split('|')
                // An empty alternative would be a substring of every message.
                .filter(|alternative| !alternative.is_empty())
                .any(|alternative| haystack.contains(&alternative.to_lowercase()))
        })
    }

    /// Classify an error message
    ///
    /// Permanent patterns take precedence over transient ones. Returns
    /// [`ErrorClass::Unknown`] when no pattern matches.
    pub fn classify_error(&self, error_msg: &str) -> ErrorClass {
        let error_lower = error_msg.to_lowercase();

        if Self::matches_any(&self.permanent_patterns, &error_lower) {
            ErrorClass::Permanent
        } else if Self::matches_any(&self.transient_patterns, &error_lower) {
            ErrorClass::Transient
        } else {
            // Not classified here; `should_retry` applies the transient budget to it.
            ErrorClass::Unknown
        }
    }

    /// Determine if a message should be retried based on error classification
    ///
    /// Returns `true` while `current_retries` is below the budget for the
    /// error's class. Unknown errors use the transient budget.
    pub fn should_retry(&self, error_msg: &str, current_retries: u32) -> bool {
        match self.classify_error(error_msg) {
            ErrorClass::Transient => current_retries < self.max_transient_retries,
            ErrorClass::Permanent => current_retries < self.max_permanent_retries,
            ErrorClass::Unknown => current_retries < self.max_transient_retries,
        }
    }
}

impl Default for ErrorClassificationMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

/// Error classification categories
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
    /// Transient errors that should be retried (e.g., network timeouts)
    Transient,
    /// Permanent errors that should not be retried (e.g., validation errors)
    Permanent,
    /// Unknown error type (retried with the transient budget by `should_retry`)
    Unknown,
}
127
128#[async_trait]
129impl MessageMiddleware for ErrorClassificationMiddleware {
130    async fn before_publish(&self, message: &mut Message) -> Result<()> {
131        // Check if message has error information
132        if let Some(error_value) = message.headers.extra.get("error") {
133            if let Some(error_msg) = error_value.as_str() {
134                let error_class = self.classify_error(error_msg);
135                let should_retry =
136                    self.should_retry(error_msg, message.headers.retries.unwrap_or(0));
137
138                message.headers.extra.insert(
139                    "x-error-class".to_string(),
140                    serde_json::json!(match error_class {
141                        ErrorClass::Transient => "transient",
142                        ErrorClass::Permanent => "permanent",
143                        ErrorClass::Unknown => "unknown",
144                    }),
145                );
146
147                message.headers.extra.insert(
148                    "x-should-retry".to_string(),
149                    serde_json::json!(should_retry),
150                );
151
152                if !should_retry {
153                    message.headers.extra.insert(
154                        "x-max-retries-exceeded".to_string(),
155                        serde_json::json!(true),
156                    );
157                }
158            }
159        }
160        Ok(())
161    }
162
163    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
164        // No action needed on consume
165        Ok(())
166    }
167
168    fn name(&self) -> &str {
169        "error_classification"
170    }
171}
172
173/// Correlation middleware for distributed tracing
174///
175/// Automatically generates and propagates correlation IDs across service boundaries
176/// for distributed tracing and request tracking.
177///
178/// # Examples
179///
180/// ```
181/// use celers_kombu::CorrelationMiddleware;
182///
183/// let correlation = CorrelationMiddleware::new();
184/// // Automatically generates correlation ID if not present
185/// // Propagates existing correlation ID from headers
186/// ```
187pub struct CorrelationMiddleware {
188    header_name: String,
189}
190
191impl CorrelationMiddleware {
192    /// Create a new correlation middleware
193    pub fn new() -> Self {
194        Self {
195            header_name: "x-correlation-id".to_string(),
196        }
197    }
198
199    /// Create with custom header name
200    pub fn with_header_name(header_name: &str) -> Self {
201        Self {
202            header_name: header_name.to_string(),
203        }
204    }
205
206    fn get_or_generate_correlation_id(&self, message: &Message) -> String {
207        message
208            .headers
209            .extra
210            .get(&self.header_name)
211            .and_then(|v| v.as_str())
212            .map(|s| s.to_string())
213            .unwrap_or_else(|| Uuid::new_v4().to_string())
214    }
215}
216
217impl Default for CorrelationMiddleware {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
223#[async_trait]
224impl MessageMiddleware for CorrelationMiddleware {
225    async fn before_publish(&self, message: &mut Message) -> Result<()> {
226        let correlation_id = self.get_or_generate_correlation_id(message);
227        message
228            .headers
229            .extra
230            .insert(self.header_name.clone(), serde_json::json!(correlation_id));
231        Ok(())
232    }
233
234    async fn after_consume(&self, message: &mut Message) -> Result<()> {
235        // Ensure correlation ID is present for downstream processing
236        let correlation_id = self.get_or_generate_correlation_id(message);
237        message
238            .headers
239            .extra
240            .insert(self.header_name.clone(), serde_json::json!(correlation_id));
241        Ok(())
242    }
243
244    fn name(&self) -> &str {
245        "correlation"
246    }
247}
248
/// Throttling middleware with backpressure support
///
/// Implements advanced throttling with configurable backpressure behavior.
/// Unlike rate limiting which rejects messages, throttling delays them.
///
/// The rate is enforced with a token bucket. The bucket holds at most
/// `burst_size` tokens (always at least 1) and refills at `max_rate` tokens
/// per second. If `max_rate` is non-positive or NaN, the bucket never refills;
/// the reported delay is then capped at `Duration::from_millis(u64::MAX)`
/// instead of panicking.
///
/// # Examples
///
/// ```
/// use celers_kombu::ThrottlingMiddleware;
/// use std::time::Duration;
///
/// let throttle = ThrottlingMiddleware::new(100.0)  // 100 msg/sec
///     .with_burst_size(200)
///     .with_backpressure_threshold(0.8);
/// ```
pub struct ThrottlingMiddleware {
    /// Sustained refill rate, in tokens (messages) per second.
    pub(crate) max_rate: f64,
    /// Bucket capacity; at least 1 so fractional rates can still accumulate a whole token.
    pub(crate) burst_size: usize,
    /// Backpressure is flagged when the bucket is less than `1 - threshold` full.
    pub(crate) backpressure_threshold: f64,
    /// Instant of the most recent refill.
    last_refill: std::sync::Mutex<std::time::Instant>,
    /// Currently available (possibly fractional) tokens; never negative.
    available_tokens: std::sync::Mutex<f64>,
}

impl ThrottlingMiddleware {
    /// Largest delay reported: its millisecond count still fits in a `u64`.
    const MAX_DELAY: Duration = Duration::from_millis(u64::MAX);

    /// Create a new throttling middleware
    ///
    /// The default burst size is `2 * max_rate` (at least 1). The bucket starts
    /// with `max_rate` tokens, floored at 0.
    pub fn new(max_rate: f64) -> Self {
        Self {
            max_rate,
            // `as` maps negative/NaN to 0; a zero-capacity bucket could never
            // hold a whole token, so enforce a minimum of one.
            burst_size: ((max_rate * 2.0) as usize).max(1),
            backpressure_threshold: 0.8,
            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
            // `f64::max` also maps NaN to 0.0.
            available_tokens: std::sync::Mutex::new(max_rate.max(0.0)),
        }
    }

    /// Set burst size (maximum tokens that can accumulate); values below 1 are raised to 1
    pub fn with_burst_size(mut self, size: usize) -> Self {
        self.burst_size = size.max(1);
        self
    }

    /// Set backpressure threshold (0.0-1.0)
    pub fn with_backpressure_threshold(mut self, threshold: f64) -> Self {
        self.backpressure_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Add tokens accrued since the last refill, capped at `burst_size`.
    fn refill_tokens(&self) {
        let mut last_refill = self.last_refill.lock().unwrap();
        let mut tokens = self.available_tokens.lock().unwrap();

        let now = std::time::Instant::now();
        let elapsed = now.duration_since(*last_refill).as_secs_f64();

        // A negative or NaN rate must never drain the bucket.
        let new_tokens = (elapsed * self.max_rate).max(0.0);
        *tokens = (*tokens + new_tokens).min(self.burst_size as f64);
        *last_refill = now;
    }

    /// Time until one whole token is available (zero if one already is).
    ///
    /// This does not reserve a token. The result is capped at `MAX_DELAY`, which
    /// is also returned when the rate can never produce a token.
    fn calculate_delay(&self) -> Duration {
        self.refill_tokens();
        let tokens = *self.available_tokens.lock().unwrap();

        if tokens >= 1.0 {
            return Duration::from_millis(0);
        }
        if self.max_rate.is_nan() || self.max_rate <= 0.0 {
            return Self::MAX_DELAY;
        }

        let wait_secs = (1.0 - tokens) / self.max_rate;
        // `Duration::from_secs_f64` panics on non-finite or overflowing input.
        if wait_secs.is_finite() && wait_secs < Self::MAX_DELAY.as_secs_f64() {
            Duration::from_secs_f64(wait_secs)
        } else {
            Self::MAX_DELAY
        }
    }

    /// Whether the bucket's fill level is below `1 - backpressure_threshold`.
    fn should_apply_backpressure(&self) -> bool {
        self.refill_tokens();
        let tokens = *self.available_tokens.lock().unwrap();
        (tokens / self.burst_size as f64) < (1.0 - self.backpressure_threshold)
    }
}
326
327#[async_trait]
328impl MessageMiddleware for ThrottlingMiddleware {
329    async fn before_publish(&self, message: &mut Message) -> Result<()> {
330        let delay = self.calculate_delay();
331
332        if delay > Duration::from_millis(0) {
333            message.headers.extra.insert(
334                "x-throttle-delay-ms".to_string(),
335                serde_json::json!(delay.as_millis()),
336            );
337        }
338
339        if self.should_apply_backpressure() {
340            message
341                .headers
342                .extra
343                .insert("x-backpressure-active".to_string(), serde_json::json!(true));
344        }
345
346        // Consume a token
347        let mut tokens = self.available_tokens.lock().unwrap();
348        if *tokens >= 1.0 {
349            *tokens -= 1.0;
350        }
351
352        Ok(())
353    }
354
355    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
356        // No action needed
357        Ok(())
358    }
359
360    fn name(&self) -> &str {
361        "throttling"
362    }
363}
364
/// Circuit breaker middleware for fault tolerance
///
/// Implements the circuit breaker pattern to prevent cascading failures.
/// Tracks failures and opens the circuit after a threshold is reached.
///
/// # Examples
///
/// ```
/// use celers_kombu::CircuitBreakerMiddleware;
/// use std::time::Duration;
///
/// let breaker = CircuitBreakerMiddleware::new(5, Duration::from_secs(60));
/// // Opens circuit after 5 failures within 60 seconds
/// ```
pub struct CircuitBreakerMiddleware {
    /// Number of in-window failures at which the circuit counts as open.
    pub(crate) failure_threshold: usize,
    /// Sliding window; failures at least this old are forgotten.
    window: Duration,
    /// Timestamps of recorded failures, pruned lazily on every access.
    failures: std::sync::Mutex<Vec<std::time::Instant>>,
}

impl CircuitBreakerMiddleware {
    /// Create a new circuit breaker middleware
    pub fn new(failure_threshold: usize, window: Duration) -> Self {
        let failures = std::sync::Mutex::new(Vec::new());
        Self {
            failure_threshold,
            window,
            failures,
        }
    }

    /// Drop every timestamp that is at least `window` older than `now`.
    fn prune_expired(&self, failures: &mut Vec<std::time::Instant>, now: std::time::Instant) {
        let window = self.window;
        failures.retain(|&at| now.duration_since(at) < window);
    }

    /// Record a failure at the current instant, after discarding expired ones.
    fn record_failure(&self) {
        let mut failures = self.failures.lock().unwrap();
        let now = std::time::Instant::now();
        self.prune_expired(&mut failures, now);
        failures.push(now);
    }

    /// Whether the number of in-window failures has reached the threshold.
    fn is_circuit_open(&self) -> bool {
        self.get_failure_count() >= self.failure_threshold
    }

    /// Number of failures still inside the window (expired ones are pruned first).
    fn get_failure_count(&self) -> usize {
        let mut failures = self.failures.lock().unwrap();
        let now = std::time::Instant::now();
        self.prune_expired(&mut failures, now);
        failures.len()
    }
}
423
424#[async_trait]
425impl MessageMiddleware for CircuitBreakerMiddleware {
426    async fn before_publish(&self, message: &mut Message) -> Result<()> {
427        if self.is_circuit_open() {
428            message.headers.extra.insert(
429                "x-circuit-breaker-open".to_string(),
430                serde_json::json!(true),
431            );
432            message.headers.extra.insert(
433                "x-circuit-breaker-failures".to_string(),
434                serde_json::json!(self.get_failure_count()),
435            );
436            return Err(BrokerError::OperationFailed(
437                "Circuit breaker is open".to_string(),
438            ));
439        }
440        Ok(())
441    }
442
443    async fn after_consume(&self, message: &mut Message) -> Result<()> {
444        // Check if message has error indicator
445        if message.headers.extra.contains_key("error") {
446            self.record_failure();
447        }
448
449        // Add circuit state to message
450        message.headers.extra.insert(
451            "x-circuit-breaker-failures".to_string(),
452            serde_json::json!(self.get_failure_count()),
453        );
454
455        Ok(())
456    }
457
458    fn name(&self) -> &str {
459        "circuit_breaker"
460    }
461}
462
463/// Schema validation middleware - validates message structure and content
464///
465/// This middleware validates messages against configured rules before publishing
466/// and after consuming.
467///
468/// # Examples
469///
470/// ```
471/// use celers_kombu::SchemaValidationMiddleware;
472///
473/// let validator = SchemaValidationMiddleware::new()
474///     .with_required_field("user_id")
475///     .with_required_field("action")
476///     .with_max_field_count(20);
477/// ```
478pub struct SchemaValidationMiddleware {
479    pub(crate) required_fields: Vec<String>,
480    pub(crate) max_field_count: Option<usize>,
481    min_body_size: Option<usize>,
482    pub(crate) max_body_size: Option<usize>,
483}
484
485impl SchemaValidationMiddleware {
486    /// Create a new schema validation middleware
487    pub fn new() -> Self {
488        Self {
489            required_fields: Vec::new(),
490            max_field_count: None,
491            min_body_size: None,
492            max_body_size: None,
493        }
494    }
495
496    /// Add a required field to validation
497    pub fn with_required_field(mut self, field: impl Into<String>) -> Self {
498        self.required_fields.push(field.into());
499        self
500    }
501
502    /// Set maximum field count
503    pub fn with_max_field_count(mut self, count: usize) -> Self {
504        self.max_field_count = Some(count);
505        self
506    }
507
508    /// Set minimum body size
509    pub fn with_min_body_size(mut self, size: usize) -> Self {
510        self.min_body_size = Some(size);
511        self
512    }
513
514    /// Set maximum body size
515    pub fn with_max_body_size(mut self, size: usize) -> Self {
516        self.max_body_size = Some(size);
517        self
518    }
519
520    fn validate_message(&self, message: &Message) -> Result<()> {
521        // Check required fields
522        for field in &self.required_fields {
523            if !message.headers.extra.contains_key(field) {
524                return Err(BrokerError::Configuration(format!(
525                    "Missing required field: {}",
526                    field
527                )));
528            }
529        }
530
531        // Check field count
532        if let Some(max) = self.max_field_count {
533            if message.headers.extra.len() > max {
534                return Err(BrokerError::Configuration(format!(
535                    "Too many fields: {} > {}",
536                    message.headers.extra.len(),
537                    max
538                )));
539            }
540        }
541
542        // Check body size
543        let body_len = message.body.len();
544        if let Some(min) = self.min_body_size {
545            if body_len < min {
546                return Err(BrokerError::Configuration(format!(
547                    "Body too small: {} < {}",
548                    body_len, min
549                )));
550            }
551        }
552        if let Some(max) = self.max_body_size {
553            if body_len > max {
554                return Err(BrokerError::Configuration(format!(
555                    "Body too large: {} > {}",
556                    body_len, max
557                )));
558            }
559        }
560
561        Ok(())
562    }
563}
564
565impl Default for SchemaValidationMiddleware {
566    fn default() -> Self {
567        Self::new()
568    }
569}
570
571#[async_trait]
572impl MessageMiddleware for SchemaValidationMiddleware {
573    async fn before_publish(&self, message: &mut Message) -> Result<()> {
574        self.validate_message(message)?;
575        message
576            .headers
577            .extra
578            .insert("x-schema-validated".to_string(), serde_json::json!(true));
579        Ok(())
580    }
581
582    async fn after_consume(&self, message: &mut Message) -> Result<()> {
583        self.validate_message(message)
584    }
585
586    fn name(&self) -> &str {
587        "schema_validation"
588    }
589}
590
591/// Message enrichment middleware - automatically adds metadata
592///
593/// This middleware enriches messages with contextual metadata such as
594/// hostname, environment, version, and timestamps.
595///
596/// # Examples
597///
598/// ```
599/// use celers_kombu::MessageEnrichmentMiddleware;
600///
601/// let enricher = MessageEnrichmentMiddleware::new()
602///     .with_hostname("worker-01")
603///     .with_environment("production")
604///     .with_version("1.0.0")
605///     .with_add_timestamp(true);
606/// ```
607pub struct MessageEnrichmentMiddleware {
608    pub(crate) hostname: Option<String>,
609    pub(crate) environment: Option<String>,
610    pub(crate) version: Option<String>,
611    pub(crate) add_timestamp: bool,
612    custom_metadata: HashMap<String, serde_json::Value>,
613}
614
615impl MessageEnrichmentMiddleware {
616    /// Create a new message enrichment middleware
617    pub fn new() -> Self {
618        Self {
619            hostname: None,
620            environment: None,
621            version: None,
622            add_timestamp: false,
623            custom_metadata: HashMap::new(),
624        }
625    }
626
627    /// Set hostname metadata
628    pub fn with_hostname(mut self, hostname: impl Into<String>) -> Self {
629        self.hostname = Some(hostname.into());
630        self
631    }
632
633    /// Set environment metadata
634    pub fn with_environment(mut self, environment: impl Into<String>) -> Self {
635        self.environment = Some(environment.into());
636        self
637    }
638
639    /// Set version metadata
640    pub fn with_version(mut self, version: impl Into<String>) -> Self {
641        self.version = Some(version.into());
642        self
643    }
644
645    /// Enable timestamp enrichment
646    pub fn with_add_timestamp(mut self, add: bool) -> Self {
647        self.add_timestamp = add;
648        self
649    }
650
651    /// Add custom metadata
652    pub fn with_custom_metadata(
653        mut self,
654        key: impl Into<String>,
655        value: serde_json::Value,
656    ) -> Self {
657        self.custom_metadata.insert(key.into(), value);
658        self
659    }
660
661    fn enrich_message(&self, message: &mut Message) {
662        if let Some(ref hostname) = self.hostname {
663            message.headers.extra.insert(
664                "x-enrichment-hostname".to_string(),
665                serde_json::json!(hostname),
666            );
667        }
668
669        if let Some(ref environment) = self.environment {
670            message.headers.extra.insert(
671                "x-enrichment-environment".to_string(),
672                serde_json::json!(environment),
673            );
674        }
675
676        if let Some(ref version) = self.version {
677            message.headers.extra.insert(
678                "x-enrichment-version".to_string(),
679                serde_json::json!(version),
680            );
681        }
682
683        if self.add_timestamp {
684            let timestamp = std::time::SystemTime::now()
685                .duration_since(std::time::UNIX_EPOCH)
686                .unwrap()
687                .as_secs();
688            message.headers.extra.insert(
689                "x-enrichment-timestamp".to_string(),
690                serde_json::json!(timestamp),
691            );
692        }
693
694        for (key, value) in &self.custom_metadata {
695            message
696                .headers
697                .extra
698                .insert(format!("x-enrichment-{}", key), value.clone());
699        }
700    }
701}
702
703impl Default for MessageEnrichmentMiddleware {
704    fn default() -> Self {
705        Self::new()
706    }
707}
708
709#[async_trait]
710impl MessageMiddleware for MessageEnrichmentMiddleware {
711    async fn before_publish(&self, message: &mut Message) -> Result<()> {
712        self.enrich_message(message);
713        Ok(())
714    }
715
716    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
717        // No action needed on consume
718        Ok(())
719    }
720
721    fn name(&self) -> &str {
722        "message_enrichment"
723    }
724}
725
/// Retry strategy for intelligent retry handling
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryStrategy {
    /// Exponential backoff (delay doubles each time)
    Exponential,
    /// Linear backoff (delay increases by fixed amount)
    Linear,
    /// Fibonacci backoff (follows fibonacci sequence)
    Fibonacci,
    /// Fixed delay (constant delay between retries)
    Fixed,
}

/// Retry strategy middleware - implements different retry strategies
///
/// This middleware applies various retry strategies to failed messages,
/// calculating appropriate delays based on the retry count.
///
/// For retry attempt `n` (0-based), the delay in milliseconds is:
/// - Exponential: `base * 2^n`
/// - Linear: `base * (n + 1)`
/// - Fibonacci: `base * fib(n)` with `fib = 1, 1, 2, 3, 5, …`
/// - Fixed: `base`
///
/// The result is capped at the maximum delay. Arithmetic saturates, so large
/// retry counts yield the cap instead of overflowing.
///
/// # Examples
///
/// ```
/// use celers_kombu::{RetryStrategyMiddleware, RetryStrategy};
/// use std::time::Duration;
///
/// let middleware = RetryStrategyMiddleware::new(RetryStrategy::Exponential)
///     .with_base_delay(Duration::from_secs(1))
///     .with_max_delay(Duration::from_secs(300))
///     .with_max_retries(5);
/// ```
pub struct RetryStrategyMiddleware {
    /// Backoff curve to apply.
    strategy: RetryStrategy,
    /// Base delay in milliseconds.
    base_delay_ms: u64,
    /// Upper bound on any computed delay, in milliseconds.
    max_delay_ms: u64,
    /// Retry counts at or above this value are rejected.
    max_retries: u32,
}

impl RetryStrategyMiddleware {
    /// Create a new retry strategy middleware
    pub fn new(strategy: RetryStrategy) -> Self {
        Self {
            strategy,
            base_delay_ms: 1000,   // 1 second default
            max_delay_ms: 300_000, // 5 minutes default
            max_retries: 5,
        }
    }

    /// Whole milliseconds of `delay`, saturating at `u64::MAX` rather than truncating.
    fn duration_to_millis(delay: Duration) -> u64 {
        delay.as_millis().min(u128::from(u64::MAX)) as u64
    }

    /// Set the base delay
    pub fn with_base_delay(mut self, delay: Duration) -> Self {
        self.base_delay_ms = Self::duration_to_millis(delay);
        self
    }

    /// Set the maximum delay
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay_ms = Self::duration_to_millis(delay);
        self
    }

    /// Set the maximum number of retries
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Delay in milliseconds for the given 0-based retry attempt, capped at `max_delay_ms`.
    fn calculate_delay(&self, retry_count: u32) -> u64 {
        let multiplier = match self.strategy {
            // base * 2^retry_count
            RetryStrategy::Exponential => 2_u64.saturating_pow(retry_count),
            // base * (retry_count + 1)
            RetryStrategy::Linear => u64::from(retry_count) + 1,
            // base * fib(retry_count)
            RetryStrategy::Fibonacci => self.fibonacci(retry_count as usize),
            // always the base delay
            RetryStrategy::Fixed => 1,
        };

        self.base_delay_ms
            .saturating_mul(multiplier)
            .min(self.max_delay_ms)
    }

    /// Fibonacci term `n` of the sequence 1, 1, 2, 3, 5, …, saturating at `u64::MAX`.
    fn fibonacci(&self, n: usize) -> u64 {
        let (mut prev, mut curr) = (1u64, 1u64);
        for _ in 2..=n {
            let next = prev.saturating_add(curr);
            prev = curr;
            curr = next;
            // Once saturated every later term is u64::MAX too; stop early so huge
            // retry counts don't loop billions of times.
            if curr == u64::MAX {
                break;
            }
        }
        curr
    }
}

impl Default for RetryStrategyMiddleware {
    fn default() -> Self {
        Self::new(RetryStrategy::Exponential)
    }
}
838
839#[async_trait]
840impl MessageMiddleware for RetryStrategyMiddleware {
841    async fn before_publish(&self, message: &mut Message) -> Result<()> {
842        // Get retry count from headers
843        let retry_count = message
844            .headers
845            .extra
846            .get("x-retry-count")
847            .and_then(|v| v.as_u64())
848            .unwrap_or(0) as u32;
849
850        // Check if max retries exceeded
851        if retry_count >= self.max_retries {
852            return Err(BrokerError::OperationFailed(format!(
853                "Max retries ({}) exceeded",
854                self.max_retries
855            )));
856        }
857
858        // Calculate and set delay
859        let delay_ms = self.calculate_delay(retry_count);
860        message
861            .headers
862            .extra
863            .insert("x-retry-delay-ms".to_string(), serde_json::json!(delay_ms));
864        message.headers.extra.insert(
865            "x-retry-strategy".to_string(),
866            serde_json::json!(format!("{:?}", self.strategy)),
867        );
868
869        Ok(())
870    }
871
872    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
873        // No action needed on consume
874        Ok(())
875    }
876
877    fn name(&self) -> &str {
878        "retry_strategy"
879    }
880}
881
882/// Tenant isolation middleware - provides multi-tenancy support
883///
884/// This middleware enforces tenant isolation by validating and routing
885/// messages based on tenant identifiers.
886///
887/// # Examples
888///
889/// ```
890/// use celers_kombu::TenantIsolationMiddleware;
891///
892/// let middleware = TenantIsolationMiddleware::new()
893///     .with_required_tenant(true)
894///     .with_tenant_header("x-tenant-id");
895/// ```
896pub struct TenantIsolationMiddleware {
897    required: bool,
898    tenant_header: String,
899    allowed_tenants: Option<Vec<String>>,
900}
901
902impl TenantIsolationMiddleware {
903    /// Create a new tenant isolation middleware
904    pub fn new() -> Self {
905        Self {
906            required: true,
907            tenant_header: "x-tenant-id".to_string(),
908            allowed_tenants: None,
909        }
910    }
911
912    /// Set whether tenant ID is required
913    pub fn with_required_tenant(mut self, required: bool) -> Self {
914        self.required = required;
915        self
916    }
917
918    /// Set the tenant header name
919    pub fn with_tenant_header(mut self, header: impl Into<String>) -> Self {
920        self.tenant_header = header.into();
921        self
922    }
923
924    /// Set allowed tenants (whitelist)
925    pub fn with_allowed_tenants(mut self, tenants: Vec<String>) -> Self {
926        self.allowed_tenants = Some(tenants);
927        self
928    }
929
930    fn validate_tenant(&self, tenant_id: Option<&str>) -> Result<()> {
931        // Check if tenant is required but missing
932        if self.required && tenant_id.is_none() {
933            return Err(BrokerError::Configuration(format!(
934                "Missing required tenant header: {}",
935                self.tenant_header
936            )));
937        }
938
939        // Check if tenant is in whitelist (if whitelist exists)
940        if let (Some(tenant), Some(allowed)) = (tenant_id, &self.allowed_tenants) {
941            if !allowed.contains(&tenant.to_string()) {
942                return Err(BrokerError::Configuration(format!(
943                    "Tenant '{}' not in allowed list",
944                    tenant
945                )));
946            }
947        }
948
949        Ok(())
950    }
951}
952
953impl Default for TenantIsolationMiddleware {
954    fn default() -> Self {
955        Self::new()
956    }
957}
958
959#[async_trait]
960impl MessageMiddleware for TenantIsolationMiddleware {
961    async fn before_publish(&self, message: &mut Message) -> Result<()> {
962        let tenant_id = message
963            .headers
964            .extra
965            .get(&self.tenant_header)
966            .and_then(|v| v.as_str());
967
968        self.validate_tenant(tenant_id)?;
969
970        // Add tenant validation marker
971        message
972            .headers
973            .extra
974            .insert("x-tenant-validated".to_string(), serde_json::json!(true));
975
976        Ok(())
977    }
978
979    async fn after_consume(&self, message: &mut Message) -> Result<()> {
980        let tenant_id = message
981            .headers
982            .extra
983            .get(&self.tenant_header)
984            .and_then(|v| v.as_str());
985
986        self.validate_tenant(tenant_id)
987    }
988
989    fn name(&self) -> &str {
990        "tenant_isolation"
991    }
992}
993
/// Partitioning middleware for distributed load balancing
///
/// Automatically assigns partition keys to messages for distributed processing.
/// Useful for ensuring related messages are processed by the same worker.
///
/// # Examples
///
/// ```
/// use celers_kombu::PartitioningMiddleware;
///
/// let partitioner = PartitioningMiddleware::new(8); // 8 partitions
/// assert_eq!(partitioner.partition_count(), 8);
/// ```
#[derive(Debug, Clone)]
pub struct PartitioningMiddleware {
    /// Number of partitions; `new` clamps it to at least 1.
    partition_count: usize,
    /// Header key for the partition ID (default `x-partition-id`).
    partition_header: String,
    /// Optional field name from which the partition key is extracted; when
    /// `None`, the task ID is used. Despite the `_fn` suffix, this stores a
    /// field name, not a function.
    partition_key_fn: Option<String>, // Field name to use for partitioning
}
1013
1014impl PartitioningMiddleware {
1015    /// Create a new partitioning middleware with specified partition count
1016    pub fn new(partition_count: usize) -> Self {
1017        Self {
1018            partition_count: partition_count.max(1),
1019            partition_header: "x-partition-id".to_string(),
1020            partition_key_fn: None,
1021        }
1022    }
1023
1024    /// Set the partition header name
1025    pub fn with_partition_header(mut self, header: impl Into<String>) -> Self {
1026        self.partition_header = header.into();
1027        self
1028    }
1029
1030    /// Set the field name to use for partition key extraction
1031    pub fn with_partition_key_field(mut self, field: impl Into<String>) -> Self {
1032        self.partition_key_fn = Some(field.into());
1033        self
1034    }
1035
1036    /// Get the partition count
1037    pub fn partition_count(&self) -> usize {
1038        self.partition_count
1039    }
1040
1041    fn calculate_partition(&self, message: &Message) -> usize {
1042        use std::collections::hash_map::DefaultHasher;
1043        use std::hash::{Hash, Hasher};
1044
1045        // Try to extract partition key from specified field or use task ID
1046        let task_id_str = message.headers.id.to_string();
1047        let key = if let Some(field) = &self.partition_key_fn {
1048            message
1049                .headers
1050                .extra
1051                .get(field)
1052                .and_then(|v| v.as_str())
1053                .unwrap_or(&task_id_str)
1054        } else {
1055            // Default to task ID for partitioning
1056            &task_id_str
1057        };
1058
1059        // Hash the key to determine partition
1060        let mut hasher = DefaultHasher::new();
1061        key.hash(&mut hasher);
1062        let hash = hasher.finish();
1063
1064        (hash % self.partition_count as u64) as usize
1065    }
1066}
1067
1068impl Default for PartitioningMiddleware {
1069    fn default() -> Self {
1070        Self::new(4) // Default to 4 partitions
1071    }
1072}
1073
1074#[async_trait]
1075impl MessageMiddleware for PartitioningMiddleware {
1076    async fn before_publish(&self, message: &mut Message) -> Result<()> {
1077        let partition_id = self.calculate_partition(message);
1078
1079        // Inject partition ID into message headers
1080        message.headers.extra.insert(
1081            self.partition_header.clone(),
1082            serde_json::json!(partition_id),
1083        );
1084
1085        // Also add total partition count for consumer reference
1086        message.headers.extra.insert(
1087            "x-partition-count".to_string(),
1088            serde_json::json!(self.partition_count),
1089        );
1090
1091        Ok(())
1092    }
1093
1094    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
1095        // No action needed on consume
1096        Ok(())
1097    }
1098
1099    fn name(&self) -> &str {
1100        "partitioning"
1101    }
1102}
1103
/// Adaptive timeout middleware with dynamic timeout adjustment.
///
/// Derives a message timeout from a percentile of recorded processing times,
/// plus 20% headroom, bounded by configurable minimum and maximum timeouts.
/// With no samples, the base timeout is used unchanged.
///
/// Note: the `MessageMiddleware::after_consume` hook does not currently record
/// processing times. Until samples are supplied, `calculate_adaptive_timeout`
/// returns the base timeout.
///
/// # Examples
///
/// ```
/// use celers_kombu::AdaptiveTimeoutMiddleware;
/// use std::time::Duration;
///
/// let adaptive = AdaptiveTimeoutMiddleware::new(Duration::from_secs(30));
/// assert!(!adaptive.has_samples());
/// ```
#[derive(Debug, Clone)]
pub struct AdaptiveTimeoutMiddleware {
    base_timeout: Duration,
    min_timeout: Duration,
    max_timeout: Duration,
    /// Recorded processing times, in milliseconds.
    samples: Vec<u64>,
    // Intended bound on the sample window once sample recording exists;
    // nothing reads it yet, hence the allow.
    #[allow(dead_code)]
    max_samples: usize,
    /// Percentile of `samples` used for the timeout (e.g. 0.95 for p95).
    percentile: f64,
}

impl AdaptiveTimeoutMiddleware {
    /// Create a new adaptive timeout middleware.
    ///
    /// Defaults:
    /// - minimum timeout: 1 s
    /// - maximum timeout: 5× `base_timeout` (saturating at `Duration::MAX`)
    /// - sample window: 100
    /// - percentile: p95
    pub fn new(base_timeout: Duration) -> Self {
        Self {
            base_timeout,
            min_timeout: Duration::from_secs(1),
            // `saturating_mul` is exact and cannot panic, unlike `mul_f64(5.0)`,
            // which panics on overflow and rounds through f64.
            max_timeout: base_timeout.saturating_mul(5),
            samples: Vec::new(),
            max_samples: 100,
            percentile: 0.95,
        }
    }

    /// Set the minimum timeout (lower bound applied to adaptive values).
    pub fn with_min_timeout(mut self, timeout: Duration) -> Self {
        self.min_timeout = timeout;
        self
    }

    /// Set the maximum timeout (upper bound applied to adaptive values).
    ///
    /// If this ends up below the minimum timeout, the maximum takes precedence.
    pub fn with_max_timeout(mut self, timeout: Duration) -> Self {
        self.max_timeout = timeout;
        self
    }

    /// Set the percentile used for timeout calculation, clamped to `[0.0, 1.0]`.
    pub fn with_percentile(mut self, percentile: f64) -> Self {
        self.percentile = percentile.clamp(0.0, 1.0);
        self
    }

    /// Check whether any processing-time samples have been collected.
    pub fn has_samples(&self) -> bool {
        !self.samples.is_empty()
    }

    /// Calculate the adaptive timeout from the collected samples.
    ///
    /// Steps:
    /// 1. Return `base_timeout` when there are no samples.
    /// 2. Otherwise take the nearest-rank percentile of the samples.
    /// 3. Add 20% headroom.
    /// 4. Bound the result by the minimum, then the maximum, timeout.
    pub fn calculate_adaptive_timeout(&self) -> Duration {
        if self.samples.is_empty() {
            return self.base_timeout;
        }

        let mut sorted = self.samples.clone();
        sorted.sort_unstable();

        // Nearest-rank percentile: the smallest sample with at least
        // `percentile` of all samples <= it, i.e. 1-based rank ceil(p * n).
        // `p = 0` (or NaN) yields rank 0, which maps to the smallest sample.
        let rank = (sorted.len() as f64 * self.percentile).ceil() as usize;
        let index = rank.saturating_sub(1).min(sorted.len() - 1);
        let percentile_ms = sorted[index];

        // 20% headroom; the float->int `as` cast saturates instead of wrapping.
        let buffered = Duration::from_millis((percentile_ms as f64 * 1.2) as u64);

        // Floor then cap instead of `Duration::clamp`, which panics when
        // min > max. That happens e.g. with `new(base)` for base < 200ms, where
        // 5x base falls below the 1s default floor. The maximum wins then.
        buffered.max(self.min_timeout).min(self.max_timeout)
    }
}
1187
1188impl Default for AdaptiveTimeoutMiddleware {
1189    fn default() -> Self {
1190        Self::new(Duration::from_secs(30))
1191    }
1192}
1193
1194#[async_trait]
1195impl MessageMiddleware for AdaptiveTimeoutMiddleware {
1196    async fn before_publish(&self, message: &mut Message) -> Result<()> {
1197        let timeout = self.calculate_adaptive_timeout();
1198
1199        // Inject adaptive timeout into message headers
1200        message.headers.extra.insert(
1201            "x-adaptive-timeout".to_string(),
1202            serde_json::json!(timeout.as_millis() as u64),
1203        );
1204
1205        // Also add the percentile used for transparency
1206        message.headers.extra.insert(
1207            "x-timeout-percentile".to_string(),
1208            serde_json::json!(self.percentile),
1209        );
1210
1211        Ok(())
1212    }
1213
1214    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
1215        // In a real implementation, we would record the actual processing time here
1216        // For now, this is a placeholder
1217        Ok(())
1218    }
1219
1220    fn name(&self) -> &str {
1221        "adaptive_timeout"
1222    }
1223}