Skip to main content

celers_kombu/
middleware_advanced.rs

1//! Advanced middleware implementations.
2
3use async_trait::async_trait;
4use celers_protocol::Message;
5use std::collections::HashMap;
6use std::time::Duration;
7
8use crate::{BrokerError, MessageMiddleware, Priority, Result};
9
/// Batching middleware that annotates published messages with batching hints.
///
/// The middleware does not group messages itself: on publish it records the configured
/// size and time thresholds in the message headers so that a batching-aware component
/// can act on them.
///
/// # Examples
///
/// ```
/// use celers_kombu::{BatchingMiddleware, MessageMiddleware};
///
/// let middleware = BatchingMiddleware::new(100, 5000);
/// assert_eq!(middleware.name(), "batching");
/// ```
#[derive(Debug, Clone)]
pub struct BatchingMiddleware {
    /// Maximum number of messages per batch (published as `batch-size-hint`).
    batch_size: usize,
    /// Maximum wait time in milliseconds (published as `batch-timeout-ms`).
    batch_timeout_ms: u64,
}

impl BatchingMiddleware {
    /// Build a batching middleware from explicit thresholds.
    ///
    /// # Arguments
    ///
    /// * `batch_size` - Maximum messages per batch
    /// * `batch_timeout_ms` - Maximum wait time in milliseconds
    pub fn new(batch_size: usize, batch_timeout_ms: u64) -> Self {
        BatchingMiddleware {
            batch_timeout_ms,
            batch_size,
        }
    }

    /// Build a batching middleware with the stock thresholds:
    /// 100 messages per batch and a 5000 ms (5 second) timeout.
    pub fn with_defaults() -> Self {
        const DEFAULT_BATCH_SIZE: usize = 100;
        const DEFAULT_BATCH_TIMEOUT_MS: u64 = 5_000;
        Self::new(DEFAULT_BATCH_SIZE, DEFAULT_BATCH_TIMEOUT_MS)
    }
}
47
48#[async_trait]
49impl MessageMiddleware for BatchingMiddleware {
50    async fn before_publish(&self, message: &mut Message) -> Result<()> {
51        // Add batching metadata
52        message.headers.extra.insert(
53            "batch-size-hint".to_string(),
54            serde_json::json!(self.batch_size),
55        );
56        message.headers.extra.insert(
57            "batch-timeout-ms".to_string(),
58            serde_json::json!(self.batch_timeout_ms),
59        );
60
61        // Mark message as batch-enabled
62        message
63            .headers
64            .extra
65            .insert("batching-enabled".to_string(), serde_json::json!(true));
66
67        Ok(())
68    }
69
70    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
71        // No-op for consume side
72        Ok(())
73    }
74
75    fn name(&self) -> &str {
76        "batching"
77    }
78}
79
80/// Audit middleware for comprehensive audit logging
81///
82/// Logs all message operations for audit trails and compliance.
83///
84/// # Examples
85///
86/// ```
87/// use celers_kombu::{AuditMiddleware, MessageMiddleware};
88///
89/// let middleware = AuditMiddleware::new(true);
90/// assert_eq!(middleware.name(), "audit");
91/// ```
92#[derive(Debug, Clone)]
93pub struct AuditMiddleware {
94    log_body: bool,
95}
96
97impl AuditMiddleware {
98    /// Create a new audit middleware
99    ///
100    /// # Arguments
101    ///
102    /// * `log_body` - Whether to include message body in audit logs
103    pub fn new(log_body: bool) -> Self {
104        Self { log_body }
105    }
106
107    /// Create audit middleware with body logging enabled
108    pub fn with_body_logging() -> Self {
109        Self::new(true)
110    }
111
112    /// Create audit middleware without body logging
113    pub fn without_body_logging() -> Self {
114        Self::new(false)
115    }
116
117    fn create_audit_entry(&self, message: &Message, operation: &str) -> String {
118        let timestamp = std::time::SystemTime::now()
119            .duration_since(std::time::UNIX_EPOCH)
120            .unwrap()
121            .as_secs();
122
123        let body_info = if self.log_body {
124            format!("body_size={}", message.body.len())
125        } else {
126            "body=<redacted>".to_string()
127        };
128
129        format!(
130            "[AUDIT] timestamp={} operation={} task_id={} task_name={} {}",
131            timestamp,
132            operation,
133            message.task_id(),
134            message.task_name(),
135            body_info
136        )
137    }
138}
139
140#[async_trait]
141impl MessageMiddleware for AuditMiddleware {
142    async fn before_publish(&self, message: &mut Message) -> Result<()> {
143        let audit_entry = self.create_audit_entry(message, "PUBLISH");
144
145        // In production, this would be sent to an audit logging system
146        message
147            .headers
148            .extra
149            .insert("audit-publish".to_string(), serde_json::json!(audit_entry));
150
151        // Add audit ID
152        let audit_id = uuid::Uuid::new_v4().to_string();
153        message
154            .headers
155            .extra
156            .insert("audit-id".to_string(), serde_json::json!(audit_id));
157
158        Ok(())
159    }
160
161    async fn after_consume(&self, message: &mut Message) -> Result<()> {
162        let audit_entry = self.create_audit_entry(message, "CONSUME");
163
164        // In production, this would be sent to an audit logging system
165        message
166            .headers
167            .extra
168            .insert("audit-consume".to_string(), serde_json::json!(audit_entry));
169
170        Ok(())
171    }
172
173    fn name(&self) -> &str {
174        "audit"
175    }
176}
177
/// Middleware for enforcing hard deadlines on message processing.
///
/// Unlike TimeoutMiddleware which sets a timeout hint, DeadlineMiddleware stamps each
/// published message with an absolute deadline (publish time plus the configured
/// duration) and flags messages consumed after that point.
///
/// # Examples
///
/// ```
/// use celers_kombu::{DeadlineMiddleware, MessageMiddleware};
/// use std::time::Duration;
///
/// // Enforce 5-minute deadline from now
/// let middleware = DeadlineMiddleware::new(Duration::from_secs(300));
/// assert_eq!(middleware.name(), "deadline");
/// ```
#[derive(Debug, Clone)]
pub struct DeadlineMiddleware {
    /// Offset from the publish time at which the deadline is placed.
    deadline_duration: Duration,
}

impl DeadlineMiddleware {
    /// Build a deadline middleware whose deadlines lie `deadline_duration` after publish.
    pub fn new(deadline_duration: Duration) -> Self {
        DeadlineMiddleware { deadline_duration }
    }

    /// The configured offset between publish time and deadline.
    pub fn deadline_duration(&self) -> Duration {
        let Self { deadline_duration } = self;
        *deadline_duration
    }
}
209
210#[async_trait]
211impl MessageMiddleware for DeadlineMiddleware {
212    async fn before_publish(&self, message: &mut Message) -> Result<()> {
213        // Calculate absolute deadline timestamp
214        let now = std::time::SystemTime::now()
215            .duration_since(std::time::UNIX_EPOCH)
216            .unwrap()
217            .as_secs();
218        let deadline = now + self.deadline_duration.as_secs();
219
220        message
221            .headers
222            .extra
223            .insert("x-deadline".to_string(), serde_json::json!(deadline));
224
225        Ok(())
226    }
227
228    async fn after_consume(&self, message: &mut Message) -> Result<()> {
229        // Check if deadline has passed
230        if let Some(deadline_value) = message.headers.extra.get("x-deadline") {
231            if let Some(deadline) = deadline_value.as_u64() {
232                let now = std::time::SystemTime::now()
233                    .duration_since(std::time::UNIX_EPOCH)
234                    .unwrap()
235                    .as_secs();
236
237                if now > deadline {
238                    // Mark message as deadline-exceeded
239                    message
240                        .headers
241                        .extra
242                        .insert("x-deadline-exceeded".to_string(), serde_json::json!(true));
243                }
244            }
245        }
246
247        Ok(())
248    }
249
250    fn name(&self) -> &str {
251        "deadline"
252    }
253}
254
/// Middleware for content type validation and conversion hints.
///
/// Validates that messages have acceptable content types and can inject
/// conversion hints for consumers.
///
/// # Examples
///
/// ```
/// use celers_kombu::{ContentTypeMiddleware, MessageMiddleware};
///
/// // Only allow JSON messages
/// let middleware = ContentTypeMiddleware::new(vec!["application/json".to_string()]);
/// assert_eq!(middleware.name(), "content_type");
/// ```
#[derive(Debug, Clone)]
pub struct ContentTypeMiddleware {
    /// Accepted content types; an empty list accepts every content type.
    allowed_content_types: Vec<String>,
    /// Content type assigned on publish to messages that have none
    /// (`"application/json"` unless overridden with [`ContentTypeMiddleware::with_default`]).
    default_content_type: String,
}

impl ContentTypeMiddleware {
    /// Create a new content type middleware
    pub fn new(allowed_content_types: Vec<String>) -> Self {
        Self {
            allowed_content_types,
            default_content_type: "application/json".to_string(),
        }
    }

    /// Set the default content type for messages without one
    pub fn with_default(mut self, content_type: String) -> Self {
        self.default_content_type = content_type;
        self
    }

    /// Check if a content type is allowed
    ///
    /// Returns `true` when no allow-list is configured. Otherwise the comparison is
    /// ASCII case-insensitive, because media type names are case-insensitive
    /// (RFC 9110 §8.3.1): `"Application/JSON"` matches an allowed `"application/json"`.
    pub fn is_allowed(&self, content_type: &str) -> bool {
        self.allowed_content_types.is_empty()
            || self
                .allowed_content_types
                .iter()
                .any(|allowed| allowed.eq_ignore_ascii_case(content_type))
    }
}
298
299#[async_trait]
300impl MessageMiddleware for ContentTypeMiddleware {
301    async fn before_publish(&self, message: &mut Message) -> Result<()> {
302        // Set default content type if not present
303        if message.content_type.is_empty() {
304            message.content_type = self.default_content_type.clone();
305        }
306
307        // Validate content type
308        if !self.is_allowed(&message.content_type) {
309            return Err(BrokerError::Configuration(format!(
310                "Content type '{}' is not allowed. Allowed types: {:?}",
311                message.content_type, self.allowed_content_types
312            )));
313        }
314
315        Ok(())
316    }
317
318    async fn after_consume(&self, message: &mut Message) -> Result<()> {
319        // Validate content type on consume
320        if !self.is_allowed(&message.content_type) {
321            message.headers.extra.insert(
322                "x-content-type-warning".to_string(),
323                serde_json::json!(format!("Unexpected content type: {}", message.content_type)),
324            );
325        }
326
327        Ok(())
328    }
329
330    fn name(&self) -> &str {
331        "content_type"
332    }
333}
334
335/// Middleware for dynamic routing key assignment.
336///
337/// Assigns routing keys to messages based on custom logic or message content.
338/// Useful for implementing dynamic routing strategies.
339///
340/// # Examples
341///
342/// ```
343/// use celers_kombu::{RoutingKeyMiddleware, MessageMiddleware};
344///
345/// // Use task name as routing key
346/// let middleware = RoutingKeyMiddleware::new(|msg| {
347///     format!("tasks.{}", msg.headers.task)
348/// });
349/// assert_eq!(middleware.name(), "routing_key");
350/// ```
351pub struct RoutingKeyMiddleware {
352    key_generator: Box<dyn Fn(&Message) -> String + Send + Sync>,
353}
354
355impl RoutingKeyMiddleware {
356    /// Create a new routing key middleware with a custom key generator
357    pub fn new<F>(key_generator: F) -> Self
358    where
359        F: Fn(&Message) -> String + Send + Sync + 'static,
360    {
361        Self {
362            key_generator: Box::new(key_generator),
363        }
364    }
365
366    /// Create a routing key from task name
367    pub fn from_task_name() -> Self {
368        Self::new(|msg| format!("tasks.{}", msg.headers.task))
369    }
370
371    /// Create a routing key from task name with priority
372    pub fn from_task_and_priority() -> Self {
373        Self::new(|msg| {
374            let priority = msg
375                .headers
376                .extra
377                .get("priority")
378                .and_then(|v| v.as_u64())
379                .unwrap_or(0);
380            format!("tasks.{}.priority_{}", msg.headers.task, priority)
381        })
382    }
383}
384
385#[async_trait]
386impl MessageMiddleware for RoutingKeyMiddleware {
387    async fn before_publish(&self, message: &mut Message) -> Result<()> {
388        let routing_key = (self.key_generator)(message);
389        message
390            .headers
391            .extra
392            .insert("x-routing-key".to_string(), serde_json::json!(routing_key));
393
394        Ok(())
395    }
396
397    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
398        // No action needed on consume
399        Ok(())
400    }
401
402    fn name(&self) -> &str {
403        "routing_key"
404    }
405}
406
407/// Idempotency middleware for ensuring exactly-once message processing
408///
409/// This middleware tracks processed message IDs to prevent duplicate processing.
410/// Unlike DeduplicationMiddleware which only prevents duplicate publishing,
411/// IdempotencyMiddleware ensures that a message is processed only once even if
412/// it's delivered multiple times (e.g., due to network issues or retries).
413///
414/// # Examples
415///
416/// ```
417/// use celers_kombu::{IdempotencyMiddleware, MessageMiddleware};
418///
419/// let middleware = IdempotencyMiddleware::new(10000);
420/// assert_eq!(middleware.name(), "idempotency");
421/// ```
422pub struct IdempotencyMiddleware {
423    processed_ids: std::sync::Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
424    max_cache_size: usize,
425}
426
427impl IdempotencyMiddleware {
428    /// Create a new idempotency middleware with a custom cache size
429    pub fn new(max_cache_size: usize) -> Self {
430        Self {
431            processed_ids: std::sync::Arc::new(std::sync::Mutex::new(
432                std::collections::HashSet::new(),
433            )),
434            max_cache_size,
435        }
436    }
437
438    /// Create a new idempotency middleware with default cache size (10,000)
439    pub fn with_default_cache() -> Self {
440        Self::new(10000)
441    }
442
443    /// Check if a message ID has been processed
444    pub fn is_processed(&self, message_id: &str) -> bool {
445        self.processed_ids.lock().unwrap().contains(message_id)
446    }
447
448    /// Mark a message ID as processed
449    pub fn mark_processed(&self, message_id: String) {
450        let mut cache = self.processed_ids.lock().unwrap();
451
452        // Simple cache eviction: if we exceed max size, clear oldest 20%
453        if cache.len() >= self.max_cache_size {
454            let to_remove = self.max_cache_size / 5;
455            let ids_to_remove: Vec<String> = cache.iter().take(to_remove).cloned().collect();
456            for id in ids_to_remove {
457                cache.remove(&id);
458            }
459        }
460
461        cache.insert(message_id);
462    }
463
464    /// Clear all processed message IDs
465    pub fn clear(&self) {
466        self.processed_ids.lock().unwrap().clear();
467    }
468
469    /// Get the number of tracked message IDs
470    pub fn cache_size(&self) -> usize {
471        self.processed_ids.lock().unwrap().len()
472    }
473}
474
475#[async_trait]
476impl MessageMiddleware for IdempotencyMiddleware {
477    async fn before_publish(&self, message: &mut Message) -> Result<()> {
478        // Add idempotency key to message headers for tracking
479        let idempotency_key = format!("{}:{}", message.headers.id, message.headers.task);
480        message.headers.extra.insert(
481            "x-idempotency-key".to_string(),
482            serde_json::json!(idempotency_key),
483        );
484        Ok(())
485    }
486
487    async fn after_consume(&self, message: &mut Message) -> Result<()> {
488        // Check if message has already been processed
489        let idempotency_key = message
490            .headers
491            .extra
492            .get("x-idempotency-key")
493            .and_then(|v| v.as_str())
494            .map(|s| s.to_string())
495            .unwrap_or_else(|| {
496                // Fallback to generating key if not present
497                format!("{}:{}", message.headers.id, message.headers.task)
498            });
499
500        if self.is_processed(&idempotency_key) {
501            // Message already processed, mark it in headers
502            message
503                .headers
504                .extra
505                .insert("x-already-processed".to_string(), serde_json::json!(true));
506        } else {
507            // Mark as being processed
508            self.mark_processed(idempotency_key.clone());
509            message
510                .headers
511                .extra
512                .insert("x-already-processed".to_string(), serde_json::json!(false));
513        }
514
515        Ok(())
516    }
517
518    fn name(&self) -> &str {
519        "idempotency"
520    }
521}
522
523/// Backoff middleware for automatic retry backoff calculation
524///
525/// This middleware automatically calculates and injects retry backoff delays
526/// based on the number of retries, using exponential backoff with jitter.
527/// This helps prevent thundering herd problems when retrying failed messages.
528///
529/// # Examples
530///
531/// ```
532/// use celers_kombu::{BackoffMiddleware, MessageMiddleware};
533/// use std::time::Duration;
534///
535/// let middleware = BackoffMiddleware::new(
536///     Duration::from_secs(1),
537///     Duration::from_secs(300),
538///     2.0
539/// );
540/// assert_eq!(middleware.name(), "backoff");
541/// ```
542pub struct BackoffMiddleware {
543    initial_delay: Duration,
544    max_delay: Duration,
545    multiplier: f64,
546}
547
548impl BackoffMiddleware {
549    /// Create a new backoff middleware with custom settings
550    ///
551    /// # Arguments
552    ///
553    /// * `initial_delay` - Initial retry delay
554    /// * `max_delay` - Maximum retry delay
555    /// * `multiplier` - Backoff multiplier (typically 2.0)
556    pub fn new(initial_delay: Duration, max_delay: Duration, multiplier: f64) -> Self {
557        Self {
558            initial_delay,
559            max_delay,
560            multiplier,
561        }
562    }
563
564    /// Create a new backoff middleware with default settings
565    ///
566    /// Defaults: 1s initial, 5min max, 2.0 multiplier
567    pub fn with_defaults() -> Self {
568        Self::new(Duration::from_secs(1), Duration::from_secs(300), 2.0)
569    }
570
571    /// Calculate backoff delay for a given retry attempt
572    fn calculate_delay(&self, retry_count: u32) -> Duration {
573        let delay_secs =
574            self.initial_delay.as_secs_f64() * self.multiplier.powi(retry_count as i32);
575        let delay = Duration::from_secs_f64(delay_secs.min(self.max_delay.as_secs_f64()));
576
577        // Add jitter (0-25% of the delay)
578        let jitter = (delay.as_secs_f64() * 0.25 * rand::random::<f64>()).round() as u64;
579        delay + Duration::from_secs(jitter)
580    }
581}
582
583#[async_trait]
584impl MessageMiddleware for BackoffMiddleware {
585    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
586        // No action needed on publish
587        Ok(())
588    }
589
590    async fn after_consume(&self, message: &mut Message) -> Result<()> {
591        // Calculate and inject backoff delay based on retry count
592        let retry_count = message
593            .headers
594            .extra
595            .get("retries")
596            .and_then(|v| v.as_u64())
597            .unwrap_or(0) as u32;
598
599        let backoff_delay = self.calculate_delay(retry_count);
600
601        message.headers.extra.insert(
602            "x-backoff-delay".to_string(),
603            serde_json::json!(backoff_delay.as_millis() as u64),
604        );
605
606        message.headers.extra.insert(
607            "x-next-retry-at".to_string(),
608            serde_json::json!((std::time::SystemTime::now() + backoff_delay)
609                .duration_since(std::time::UNIX_EPOCH)
610                .unwrap()
611                .as_secs()),
612        );
613
614        Ok(())
615    }
616
617    fn name(&self) -> &str {
618        "backoff"
619    }
620}
621
622/// Caching middleware for storing message processing results
623///
624/// This middleware caches the results of message processing to avoid
625/// reprocessing identical messages. Useful for expensive operations that
626/// are idempotent (e.g., external API calls, database queries).
627///
628/// # Examples
629///
630/// ```
631/// use celers_kombu::{CachingMiddleware, MessageMiddleware};
632/// use std::time::Duration;
633///
634/// let middleware = CachingMiddleware::new(1000, Duration::from_secs(3600));
635/// assert_eq!(middleware.name(), "caching");
636/// ```
637pub struct CachingMiddleware {
638    cache: std::sync::Arc<std::sync::Mutex<CacheMap>>,
639    max_entries: usize,
640    ttl: Duration,
641}
642
643type CacheMap = std::collections::HashMap<String, (Vec<u8>, std::time::Instant)>;
644
645impl CachingMiddleware {
646    /// Create a new caching middleware with custom settings
647    ///
648    /// # Arguments
649    ///
650    /// * `max_entries` - Maximum number of cache entries
651    /// * `ttl` - Time-to-live for cache entries
652    pub fn new(max_entries: usize, ttl: Duration) -> Self {
653        Self {
654            cache: std::sync::Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
655            max_entries,
656            ttl,
657        }
658    }
659
660    /// Create a new caching middleware with default settings
661    ///
662    /// Defaults: 10,000 entries, 1 hour TTL
663    pub fn with_defaults() -> Self {
664        Self::new(10_000, Duration::from_secs(3600))
665    }
666
667    /// Generate cache key from message
668    fn cache_key(&self, message: &Message) -> String {
669        // Use message ID and task name as cache key
670        format!("{}:{}", message.headers.id, message.headers.task)
671    }
672
673    /// Check if a cached result exists and is still valid
674    pub fn get_cached(&self, message: &Message) -> Option<Vec<u8>> {
675        let key = self.cache_key(message);
676        let mut cache = self.cache.lock().unwrap();
677
678        if let Some((result, timestamp)) = cache.get(&key) {
679            if timestamp.elapsed() < self.ttl {
680                return Some(result.clone());
681            }
682            // Remove expired entry
683            cache.remove(&key);
684        }
685        None
686    }
687
688    /// Store a result in the cache
689    pub fn store_result(&self, message: &Message, result: Vec<u8>) {
690        let key = self.cache_key(message);
691        let mut cache = self.cache.lock().unwrap();
692
693        // Evict oldest entries if cache is full
694        if cache.len() >= self.max_entries {
695            let to_remove = cache.len() / 5; // Remove oldest 20%
696            let mut entries: Vec<_> = cache.iter().map(|(k, v)| (k.clone(), v.1)).collect();
697            entries.sort_by_key(|(_, timestamp)| *timestamp);
698
699            for (key, _) in entries.iter().take(to_remove) {
700                cache.remove(key);
701            }
702        }
703
704        cache.insert(key, (result, std::time::Instant::now()));
705    }
706
707    /// Clear all cached results
708    pub fn clear(&self) {
709        self.cache.lock().unwrap().clear();
710    }
711
712    /// Get the number of cached entries
713    pub fn cache_size(&self) -> usize {
714        self.cache.lock().unwrap().len()
715    }
716}
717
718#[async_trait]
719impl MessageMiddleware for CachingMiddleware {
720    async fn before_publish(&self, _message: &mut Message) -> Result<()> {
721        // No action needed on publish
722        Ok(())
723    }
724
725    async fn after_consume(&self, message: &mut Message) -> Result<()> {
726        // Check if result is cached
727        if let Some(cached_result) = self.get_cached(message) {
728            message
729                .headers
730                .extra
731                .insert("x-cache-hit".to_string(), serde_json::json!(true));
732            message.headers.extra.insert(
733                "x-cached-result-size".to_string(),
734                serde_json::json!(cached_result.len()),
735            );
736        } else {
737            message
738                .headers
739                .extra
740                .insert("x-cache-hit".to_string(), serde_json::json!(false));
741        }
742        Ok(())
743    }
744
745    fn name(&self) -> &str {
746        "caching"
747    }
748}
749
750/// Bulkhead middleware for limiting concurrent operations per partition
751///
752/// This middleware implements the bulkhead pattern to prevent resource exhaustion
753/// by limiting the number of concurrent operations per partition/queue.
754///
755/// # Examples
756///
757/// ```
758/// use celers_kombu::BulkheadMiddleware;
759///
760/// // Create bulkhead with max 50 concurrent operations per partition
761/// let bulkhead = BulkheadMiddleware::new(50);
762///
763/// // Create with custom partition key extractor
764/// let bulkhead = BulkheadMiddleware::with_partition_fn(50, |msg| {
765///     msg.headers.extra.get("partition_key")
766///         .and_then(|v| v.as_str())
767///         .map(|s| s.to_string())
768///         .unwrap_or_else(|| "default".to_string())
769/// });
770/// ```
771#[derive(Clone)]
772pub struct BulkheadMiddleware {
773    max_concurrent: usize,
774    permits: std::sync::Arc<std::sync::Mutex<HashMap<String, usize>>>,
775    partition_fn: std::sync::Arc<dyn Fn(&Message) -> String + Send + Sync>,
776}
777
778impl BulkheadMiddleware {
779    /// Create a new bulkhead middleware with max concurrent operations
780    ///
781    /// # Arguments
782    ///
783    /// * `max_concurrent` - Maximum number of concurrent operations per partition
784    pub fn new(max_concurrent: usize) -> Self {
785        Self {
786            max_concurrent,
787            permits: std::sync::Arc::new(std::sync::Mutex::new(HashMap::new())),
788            partition_fn: std::sync::Arc::new(|msg| {
789                // Default: partition by task name
790                msg.headers.task.clone()
791            }),
792        }
793    }
794
795    /// Create with custom partition key extraction function
796    pub fn with_partition_fn<F>(max_concurrent: usize, partition_fn: F) -> Self
797    where
798        F: Fn(&Message) -> String + Send + Sync + 'static,
799    {
800        Self {
801            max_concurrent,
802            permits: std::sync::Arc::new(std::sync::Mutex::new(HashMap::new())),
803            partition_fn: std::sync::Arc::new(partition_fn),
804        }
805    }
806
807    /// Try to acquire a permit for the given partition
808    pub fn try_acquire(&self, partition: &str) -> bool {
809        let mut permits = self.permits.lock().unwrap();
810        let current = permits.entry(partition.to_string()).or_insert(0);
811        if *current < self.max_concurrent {
812            *current += 1;
813            true
814        } else {
815            false
816        }
817    }
818
819    /// Release a permit for the given partition
820    pub fn release(&self, partition: &str) {
821        let mut permits = self.permits.lock().unwrap();
822        if let Some(current) = permits.get_mut(partition) {
823            if *current > 0 {
824                *current -= 1;
825            }
826        }
827    }
828
829    /// Get current concurrent operations for a partition
830    pub fn current_operations(&self, partition: &str) -> usize {
831        self.permits
832            .lock()
833            .unwrap()
834            .get(partition)
835            .copied()
836            .unwrap_or(0)
837    }
838
839    /// Get total concurrent operations across all partitions
840    pub fn total_operations(&self) -> usize {
841        self.permits.lock().unwrap().values().sum()
842    }
843}
844
845#[async_trait]
846impl MessageMiddleware for BulkheadMiddleware {
847    async fn before_publish(&self, message: &mut Message) -> Result<()> {
848        let partition = (self.partition_fn)(message);
849        if !self.try_acquire(&partition) {
850            message
851                .headers
852                .extra
853                .insert("x-bulkhead-rejected".to_string(), serde_json::json!(true));
854            message.headers.extra.insert(
855                "x-bulkhead-partition".to_string(),
856                serde_json::json!(partition),
857            );
858            message.headers.extra.insert(
859                "x-bulkhead-current".to_string(),
860                serde_json::json!(self.max_concurrent),
861            );
862        } else {
863            message.headers.extra.insert(
864                "x-bulkhead-partition".to_string(),
865                serde_json::json!(partition),
866            );
867        }
868        Ok(())
869    }
870
871    async fn after_consume(&self, message: &mut Message) -> Result<()> {
872        let partition = (self.partition_fn)(message);
873        self.release(&partition);
874        Ok(())
875    }
876
877    fn name(&self) -> &str {
878        "bulkhead"
879    }
880}
881
882/// Priority boost middleware for dynamic priority adjustment
883///
884/// This middleware dynamically adjusts message priority based on configurable rules
885/// such as message age, retry count, or custom criteria.
886///
887/// # Examples
888///
889/// ```
890/// use celers_kombu::{PriorityBoostMiddleware, Priority};
891/// use std::time::Duration;
892///
893/// // Boost priority for messages older than 5 minutes
894/// let boost = PriorityBoostMiddleware::new()
895///     .with_age_boost(Duration::from_secs(300), Priority::High);
896///
897/// // Custom boost function
898/// let boost = PriorityBoostMiddleware::with_custom_fn(|msg, current_priority| {
899///     if msg.headers.retries.unwrap_or(0) > 3 {
900///         Priority::Highest
901///     } else {
902///         current_priority
903///     }
904/// });
905/// ```
906pub type PriorityBoostFn = std::sync::Arc<dyn Fn(&Message, Priority) -> Priority + Send + Sync>;
907
908#[derive(Clone)]
909pub struct PriorityBoostMiddleware {
910    age_threshold: Option<Duration>,
911    age_boost_priority: Priority,
912    retry_threshold: Option<u32>,
913    retry_boost_priority: Priority,
914    custom_fn: Option<PriorityBoostFn>,
915}
916
917impl PriorityBoostMiddleware {
918    /// Create a new priority boost middleware with defaults
919    pub fn new() -> Self {
920        Self {
921            age_threshold: None,
922            age_boost_priority: Priority::High,
923            retry_threshold: None,
924            retry_boost_priority: Priority::High,
925            custom_fn: None,
926        }
927    }
928
929    /// Boost priority for messages older than the specified duration
930    pub fn with_age_boost(mut self, threshold: Duration, priority: Priority) -> Self {
931        self.age_threshold = Some(threshold);
932        self.age_boost_priority = priority;
933        self
934    }
935
936    /// Boost priority for messages with retry count exceeding threshold
937    pub fn with_retry_boost(mut self, threshold: u32, priority: Priority) -> Self {
938        self.retry_threshold = Some(threshold);
939        self.retry_boost_priority = priority;
940        self
941    }
942
943    /// Create with custom priority boost function
944    pub fn with_custom_fn<F>(custom_fn: F) -> Self
945    where
946        F: Fn(&Message, Priority) -> Priority + Send + Sync + 'static,
947    {
948        Self {
949            age_threshold: None,
950            age_boost_priority: Priority::High,
951            retry_threshold: None,
952            retry_boost_priority: Priority::High,
953            custom_fn: Some(std::sync::Arc::new(custom_fn)),
954        }
955    }
956
957    /// Calculate boosted priority for a message
958    pub fn calculate_priority(&self, message: &Message, current_priority: Priority) -> Priority {
959        let mut priority = current_priority;
960
961        // Apply custom function if provided
962        if let Some(ref custom_fn) = self.custom_fn {
963            return custom_fn(message, priority);
964        }
965
966        // Check retry count
967        if let Some(retry_threshold) = self.retry_threshold {
968            if message.headers.retries.unwrap_or(0) >= retry_threshold {
969                priority = priority.max(self.retry_boost_priority);
970            }
971        }
972
973        // Check message age (using timestamp if available)
974        if let Some(age_threshold) = self.age_threshold {
975            if let Some(timestamp_value) = message.headers.extra.get("timestamp") {
976                if let Some(timestamp_secs) = timestamp_value.as_f64() {
977                    let msg_age = std::time::SystemTime::now()
978                        .duration_since(std::time::UNIX_EPOCH)
979                        .unwrap()
980                        .as_secs_f64()
981                        - timestamp_secs;
982                    if msg_age > age_threshold.as_secs_f64() {
983                        priority = priority.max(self.age_boost_priority);
984                    }
985                }
986            }
987        }
988
989        priority
990    }
991}
992
993impl Default for PriorityBoostMiddleware {
994    fn default() -> Self {
995        Self::new()
996    }
997}
998
999#[async_trait]
1000impl MessageMiddleware for PriorityBoostMiddleware {
1001    async fn before_publish(&self, message: &mut Message) -> Result<()> {
1002        // Get current priority from message headers
1003        let current_priority = message
1004            .headers
1005            .extra
1006            .get("priority")
1007            .and_then(|v| v.as_u64())
1008            .map(|p| Priority::from_u8(p as u8))
1009            .unwrap_or(Priority::Normal);
1010
1011        let boosted_priority = self.calculate_priority(message, current_priority);
1012
1013        if boosted_priority != current_priority {
1014            message.headers.extra.insert(
1015                "priority".to_string(),
1016                serde_json::json!(boosted_priority.as_u8()),
1017            );
1018            message
1019                .headers
1020                .extra
1021                .insert("x-priority-boosted".to_string(), serde_json::json!(true));
1022            message.headers.extra.insert(
1023                "x-original-priority".to_string(),
1024                serde_json::json!(current_priority.as_u8()),
1025            );
1026        }
1027        Ok(())
1028    }
1029
1030    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
1031        // No action needed on consume
1032        Ok(())
1033    }
1034
1035    fn name(&self) -> &str {
1036        "priority_boost"
1037    }
1038}