1use async_trait::async_trait;
4use celers_protocol::Message;
5use std::collections::HashMap;
6use std::time::Duration;
7
8use crate::{BrokerError, MessageMiddleware, Priority, Result};
9
/// Middleware that annotates outgoing messages with batching hints.
///
/// It only records the configured batch size and timeout as extra headers on
/// publish; it does not group messages itself.
#[derive(Debug, Clone)]
pub struct BatchingMiddleware {
    batch_size: usize,
    batch_timeout_ms: u64,
}

impl BatchingMiddleware {
    /// Creates a middleware advertising `batch_size` messages per batch and a
    /// flush timeout of `batch_timeout_ms` milliseconds.
    pub fn new(batch_size: usize, batch_timeout_ms: u64) -> Self {
        BatchingMiddleware {
            batch_timeout_ms,
            batch_size,
        }
    }

    /// Creates a middleware with a batch size of 100 and a 5000 ms timeout.
    pub fn with_defaults() -> Self {
        const DEFAULT_BATCH_SIZE: usize = 100;
        const DEFAULT_BATCH_TIMEOUT_MS: u64 = 5000;
        Self::new(DEFAULT_BATCH_SIZE, DEFAULT_BATCH_TIMEOUT_MS)
    }
}
47
48#[async_trait]
49impl MessageMiddleware for BatchingMiddleware {
50 async fn before_publish(&self, message: &mut Message) -> Result<()> {
51 message.headers.extra.insert(
53 "batch-size-hint".to_string(),
54 serde_json::json!(self.batch_size),
55 );
56 message.headers.extra.insert(
57 "batch-timeout-ms".to_string(),
58 serde_json::json!(self.batch_timeout_ms),
59 );
60
61 message
63 .headers
64 .extra
65 .insert("batching-enabled".to_string(), serde_json::json!(true));
66
67 Ok(())
68 }
69
70 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
71 Ok(())
73 }
74
75 fn name(&self) -> &str {
76 "batching"
77 }
78}
79
80#[derive(Debug, Clone)]
93pub struct AuditMiddleware {
94 log_body: bool,
95}
96
97impl AuditMiddleware {
98 pub fn new(log_body: bool) -> Self {
104 Self { log_body }
105 }
106
107 pub fn with_body_logging() -> Self {
109 Self::new(true)
110 }
111
112 pub fn without_body_logging() -> Self {
114 Self::new(false)
115 }
116
117 fn create_audit_entry(&self, message: &Message, operation: &str) -> String {
118 let timestamp = std::time::SystemTime::now()
119 .duration_since(std::time::UNIX_EPOCH)
120 .unwrap()
121 .as_secs();
122
123 let body_info = if self.log_body {
124 format!("body_size={}", message.body.len())
125 } else {
126 "body=<redacted>".to_string()
127 };
128
129 format!(
130 "[AUDIT] timestamp={} operation={} task_id={} task_name={} {}",
131 timestamp,
132 operation,
133 message.task_id(),
134 message.task_name(),
135 body_info
136 )
137 }
138}
139
140#[async_trait]
141impl MessageMiddleware for AuditMiddleware {
142 async fn before_publish(&self, message: &mut Message) -> Result<()> {
143 let audit_entry = self.create_audit_entry(message, "PUBLISH");
144
145 message
147 .headers
148 .extra
149 .insert("audit-publish".to_string(), serde_json::json!(audit_entry));
150
151 let audit_id = uuid::Uuid::new_v4().to_string();
153 message
154 .headers
155 .extra
156 .insert("audit-id".to_string(), serde_json::json!(audit_id));
157
158 Ok(())
159 }
160
161 async fn after_consume(&self, message: &mut Message) -> Result<()> {
162 let audit_entry = self.create_audit_entry(message, "CONSUME");
163
164 message
166 .headers
167 .extra
168 .insert("audit-consume".to_string(), serde_json::json!(audit_entry));
169
170 Ok(())
171 }
172
173 fn name(&self) -> &str {
174 "audit"
175 }
176}
177
/// Middleware that stamps an absolute deadline on published messages and
/// flags messages consumed after that deadline has passed.
#[derive(Debug, Clone)]
pub struct DeadlineMiddleware {
    deadline_duration: Duration,
}

impl DeadlineMiddleware {
    /// Creates a middleware whose deadline is `deadline_duration` after publish time.
    pub fn new(deadline_duration: Duration) -> Self {
        DeadlineMiddleware { deadline_duration }
    }

    /// Returns the configured time budget between publish and deadline.
    pub fn deadline_duration(&self) -> Duration {
        let DeadlineMiddleware { deadline_duration } = self;
        *deadline_duration
    }
}
209
210#[async_trait]
211impl MessageMiddleware for DeadlineMiddleware {
212 async fn before_publish(&self, message: &mut Message) -> Result<()> {
213 let now = std::time::SystemTime::now()
215 .duration_since(std::time::UNIX_EPOCH)
216 .unwrap()
217 .as_secs();
218 let deadline = now + self.deadline_duration.as_secs();
219
220 message
221 .headers
222 .extra
223 .insert("x-deadline".to_string(), serde_json::json!(deadline));
224
225 Ok(())
226 }
227
228 async fn after_consume(&self, message: &mut Message) -> Result<()> {
229 if let Some(deadline_value) = message.headers.extra.get("x-deadline") {
231 if let Some(deadline) = deadline_value.as_u64() {
232 let now = std::time::SystemTime::now()
233 .duration_since(std::time::UNIX_EPOCH)
234 .unwrap()
235 .as_secs();
236
237 if now > deadline {
238 message
240 .headers
241 .extra
242 .insert("x-deadline-exceeded".to_string(), serde_json::json!(true));
243 }
244 }
245 }
246
247 Ok(())
248 }
249
250 fn name(&self) -> &str {
251 "deadline"
252 }
253}
254
/// Middleware that enforces an allow-list of message content types.
///
/// An empty allow-list accepts every content type. Messages published with an
/// empty content type get `default_content_type` filled in.
#[derive(Debug, Clone)]
pub struct ContentTypeMiddleware {
    allowed_content_types: Vec<String>,
    default_content_type: String,
}

impl ContentTypeMiddleware {
    /// Creates a middleware accepting `allowed_content_types` (empty = accept
    /// all), with `"application/json"` as the default content type.
    pub fn new(allowed_content_types: Vec<String>) -> Self {
        Self {
            allowed_content_types,
            default_content_type: "application/json".to_string(),
        }
    }

    /// Replaces the default content type applied to messages that have none.
    pub fn with_default(mut self, content_type: String) -> Self {
        self.default_content_type = content_type;
        self
    }

    /// Returns `true` if `content_type` is allowed.
    ///
    /// The comparison is exact and case-sensitive.
    pub fn is_allowed(&self, content_type: &str) -> bool {
        // Compare as `&str` instead of allocating a `String` on every check.
        self.allowed_content_types.is_empty()
            || self
                .allowed_content_types
                .iter()
                .any(|allowed| allowed == content_type)
    }
}
298
299#[async_trait]
300impl MessageMiddleware for ContentTypeMiddleware {
301 async fn before_publish(&self, message: &mut Message) -> Result<()> {
302 if message.content_type.is_empty() {
304 message.content_type = self.default_content_type.clone();
305 }
306
307 if !self.is_allowed(&message.content_type) {
309 return Err(BrokerError::Configuration(format!(
310 "Content type '{}' is not allowed. Allowed types: {:?}",
311 message.content_type, self.allowed_content_types
312 )));
313 }
314
315 Ok(())
316 }
317
318 async fn after_consume(&self, message: &mut Message) -> Result<()> {
319 if !self.is_allowed(&message.content_type) {
321 message.headers.extra.insert(
322 "x-content-type-warning".to_string(),
323 serde_json::json!(format!("Unexpected content type: {}", message.content_type)),
324 );
325 }
326
327 Ok(())
328 }
329
330 fn name(&self) -> &str {
331 "content_type"
332 }
333}
334
335pub struct RoutingKeyMiddleware {
352 key_generator: Box<dyn Fn(&Message) -> String + Send + Sync>,
353}
354
355impl RoutingKeyMiddleware {
356 pub fn new<F>(key_generator: F) -> Self
358 where
359 F: Fn(&Message) -> String + Send + Sync + 'static,
360 {
361 Self {
362 key_generator: Box::new(key_generator),
363 }
364 }
365
366 pub fn from_task_name() -> Self {
368 Self::new(|msg| format!("tasks.{}", msg.headers.task))
369 }
370
371 pub fn from_task_and_priority() -> Self {
373 Self::new(|msg| {
374 let priority = msg
375 .headers
376 .extra
377 .get("priority")
378 .and_then(|v| v.as_u64())
379 .unwrap_or(0);
380 format!("tasks.{}.priority_{}", msg.headers.task, priority)
381 })
382 }
383}
384
385#[async_trait]
386impl MessageMiddleware for RoutingKeyMiddleware {
387 async fn before_publish(&self, message: &mut Message) -> Result<()> {
388 let routing_key = (self.key_generator)(message);
389 message
390 .headers
391 .extra
392 .insert("x-routing-key".to_string(), serde_json::json!(routing_key));
393
394 Ok(())
395 }
396
397 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
398 Ok(())
400 }
401
402 fn name(&self) -> &str {
403 "routing_key"
404 }
405}
406
407pub struct IdempotencyMiddleware {
423 processed_ids: std::sync::Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
424 max_cache_size: usize,
425}
426
427impl IdempotencyMiddleware {
428 pub fn new(max_cache_size: usize) -> Self {
430 Self {
431 processed_ids: std::sync::Arc::new(std::sync::Mutex::new(
432 std::collections::HashSet::new(),
433 )),
434 max_cache_size,
435 }
436 }
437
438 pub fn with_default_cache() -> Self {
440 Self::new(10000)
441 }
442
443 pub fn is_processed(&self, message_id: &str) -> bool {
445 self.processed_ids.lock().unwrap().contains(message_id)
446 }
447
448 pub fn mark_processed(&self, message_id: String) {
450 let mut cache = self.processed_ids.lock().unwrap();
451
452 if cache.len() >= self.max_cache_size {
454 let to_remove = self.max_cache_size / 5;
455 let ids_to_remove: Vec<String> = cache.iter().take(to_remove).cloned().collect();
456 for id in ids_to_remove {
457 cache.remove(&id);
458 }
459 }
460
461 cache.insert(message_id);
462 }
463
464 pub fn clear(&self) {
466 self.processed_ids.lock().unwrap().clear();
467 }
468
469 pub fn cache_size(&self) -> usize {
471 self.processed_ids.lock().unwrap().len()
472 }
473}
474
475#[async_trait]
476impl MessageMiddleware for IdempotencyMiddleware {
477 async fn before_publish(&self, message: &mut Message) -> Result<()> {
478 let idempotency_key = format!("{}:{}", message.headers.id, message.headers.task);
480 message.headers.extra.insert(
481 "x-idempotency-key".to_string(),
482 serde_json::json!(idempotency_key),
483 );
484 Ok(())
485 }
486
487 async fn after_consume(&self, message: &mut Message) -> Result<()> {
488 let idempotency_key = message
490 .headers
491 .extra
492 .get("x-idempotency-key")
493 .and_then(|v| v.as_str())
494 .map(|s| s.to_string())
495 .unwrap_or_else(|| {
496 format!("{}:{}", message.headers.id, message.headers.task)
498 });
499
500 if self.is_processed(&idempotency_key) {
501 message
503 .headers
504 .extra
505 .insert("x-already-processed".to_string(), serde_json::json!(true));
506 } else {
507 self.mark_processed(idempotency_key.clone());
509 message
510 .headers
511 .extra
512 .insert("x-already-processed".to_string(), serde_json::json!(false));
513 }
514
515 Ok(())
516 }
517
518 fn name(&self) -> &str {
519 "idempotency"
520 }
521}
522
523pub struct BackoffMiddleware {
543 initial_delay: Duration,
544 max_delay: Duration,
545 multiplier: f64,
546}
547
548impl BackoffMiddleware {
549 pub fn new(initial_delay: Duration, max_delay: Duration, multiplier: f64) -> Self {
557 Self {
558 initial_delay,
559 max_delay,
560 multiplier,
561 }
562 }
563
564 pub fn with_defaults() -> Self {
568 Self::new(Duration::from_secs(1), Duration::from_secs(300), 2.0)
569 }
570
571 fn calculate_delay(&self, retry_count: u32) -> Duration {
573 let delay_secs =
574 self.initial_delay.as_secs_f64() * self.multiplier.powi(retry_count as i32);
575 let delay = Duration::from_secs_f64(delay_secs.min(self.max_delay.as_secs_f64()));
576
577 let jitter = (delay.as_secs_f64() * 0.25 * rand::random::<f64>()).round() as u64;
579 delay + Duration::from_secs(jitter)
580 }
581}
582
583#[async_trait]
584impl MessageMiddleware for BackoffMiddleware {
585 async fn before_publish(&self, _message: &mut Message) -> Result<()> {
586 Ok(())
588 }
589
590 async fn after_consume(&self, message: &mut Message) -> Result<()> {
591 let retry_count = message
593 .headers
594 .extra
595 .get("retries")
596 .and_then(|v| v.as_u64())
597 .unwrap_or(0) as u32;
598
599 let backoff_delay = self.calculate_delay(retry_count);
600
601 message.headers.extra.insert(
602 "x-backoff-delay".to_string(),
603 serde_json::json!(backoff_delay.as_millis() as u64),
604 );
605
606 message.headers.extra.insert(
607 "x-next-retry-at".to_string(),
608 serde_json::json!((std::time::SystemTime::now() + backoff_delay)
609 .duration_since(std::time::UNIX_EPOCH)
610 .unwrap()
611 .as_secs()),
612 );
613
614 Ok(())
615 }
616
617 fn name(&self) -> &str {
618 "backoff"
619 }
620}
621
622pub struct CachingMiddleware {
638 cache: std::sync::Arc<std::sync::Mutex<CacheMap>>,
639 max_entries: usize,
640 ttl: Duration,
641}
642
643type CacheMap = std::collections::HashMap<String, (Vec<u8>, std::time::Instant)>;
644
645impl CachingMiddleware {
646 pub fn new(max_entries: usize, ttl: Duration) -> Self {
653 Self {
654 cache: std::sync::Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
655 max_entries,
656 ttl,
657 }
658 }
659
660 pub fn with_defaults() -> Self {
664 Self::new(10_000, Duration::from_secs(3600))
665 }
666
667 fn cache_key(&self, message: &Message) -> String {
669 format!("{}:{}", message.headers.id, message.headers.task)
671 }
672
673 pub fn get_cached(&self, message: &Message) -> Option<Vec<u8>> {
675 let key = self.cache_key(message);
676 let mut cache = self.cache.lock().unwrap();
677
678 if let Some((result, timestamp)) = cache.get(&key) {
679 if timestamp.elapsed() < self.ttl {
680 return Some(result.clone());
681 }
682 cache.remove(&key);
684 }
685 None
686 }
687
688 pub fn store_result(&self, message: &Message, result: Vec<u8>) {
690 let key = self.cache_key(message);
691 let mut cache = self.cache.lock().unwrap();
692
693 if cache.len() >= self.max_entries {
695 let to_remove = cache.len() / 5; let mut entries: Vec<_> = cache.iter().map(|(k, v)| (k.clone(), v.1)).collect();
697 entries.sort_by_key(|(_, timestamp)| *timestamp);
698
699 for (key, _) in entries.iter().take(to_remove) {
700 cache.remove(key);
701 }
702 }
703
704 cache.insert(key, (result, std::time::Instant::now()));
705 }
706
707 pub fn clear(&self) {
709 self.cache.lock().unwrap().clear();
710 }
711
712 pub fn cache_size(&self) -> usize {
714 self.cache.lock().unwrap().len()
715 }
716}
717
718#[async_trait]
719impl MessageMiddleware for CachingMiddleware {
720 async fn before_publish(&self, _message: &mut Message) -> Result<()> {
721 Ok(())
723 }
724
725 async fn after_consume(&self, message: &mut Message) -> Result<()> {
726 if let Some(cached_result) = self.get_cached(message) {
728 message
729 .headers
730 .extra
731 .insert("x-cache-hit".to_string(), serde_json::json!(true));
732 message.headers.extra.insert(
733 "x-cached-result-size".to_string(),
734 serde_json::json!(cached_result.len()),
735 );
736 } else {
737 message
738 .headers
739 .extra
740 .insert("x-cache-hit".to_string(), serde_json::json!(false));
741 }
742 Ok(())
743 }
744
745 fn name(&self) -> &str {
746 "caching"
747 }
748}
749
750#[derive(Clone)]
772pub struct BulkheadMiddleware {
773 max_concurrent: usize,
774 permits: std::sync::Arc<std::sync::Mutex<HashMap<String, usize>>>,
775 partition_fn: std::sync::Arc<dyn Fn(&Message) -> String + Send + Sync>,
776}
777
778impl BulkheadMiddleware {
779 pub fn new(max_concurrent: usize) -> Self {
785 Self {
786 max_concurrent,
787 permits: std::sync::Arc::new(std::sync::Mutex::new(HashMap::new())),
788 partition_fn: std::sync::Arc::new(|msg| {
789 msg.headers.task.clone()
791 }),
792 }
793 }
794
795 pub fn with_partition_fn<F>(max_concurrent: usize, partition_fn: F) -> Self
797 where
798 F: Fn(&Message) -> String + Send + Sync + 'static,
799 {
800 Self {
801 max_concurrent,
802 permits: std::sync::Arc::new(std::sync::Mutex::new(HashMap::new())),
803 partition_fn: std::sync::Arc::new(partition_fn),
804 }
805 }
806
807 pub fn try_acquire(&self, partition: &str) -> bool {
809 let mut permits = self.permits.lock().unwrap();
810 let current = permits.entry(partition.to_string()).or_insert(0);
811 if *current < self.max_concurrent {
812 *current += 1;
813 true
814 } else {
815 false
816 }
817 }
818
819 pub fn release(&self, partition: &str) {
821 let mut permits = self.permits.lock().unwrap();
822 if let Some(current) = permits.get_mut(partition) {
823 if *current > 0 {
824 *current -= 1;
825 }
826 }
827 }
828
829 pub fn current_operations(&self, partition: &str) -> usize {
831 self.permits
832 .lock()
833 .unwrap()
834 .get(partition)
835 .copied()
836 .unwrap_or(0)
837 }
838
839 pub fn total_operations(&self) -> usize {
841 self.permits.lock().unwrap().values().sum()
842 }
843}
844
845#[async_trait]
846impl MessageMiddleware for BulkheadMiddleware {
847 async fn before_publish(&self, message: &mut Message) -> Result<()> {
848 let partition = (self.partition_fn)(message);
849 if !self.try_acquire(&partition) {
850 message
851 .headers
852 .extra
853 .insert("x-bulkhead-rejected".to_string(), serde_json::json!(true));
854 message.headers.extra.insert(
855 "x-bulkhead-partition".to_string(),
856 serde_json::json!(partition),
857 );
858 message.headers.extra.insert(
859 "x-bulkhead-current".to_string(),
860 serde_json::json!(self.max_concurrent),
861 );
862 } else {
863 message.headers.extra.insert(
864 "x-bulkhead-partition".to_string(),
865 serde_json::json!(partition),
866 );
867 }
868 Ok(())
869 }
870
871 async fn after_consume(&self, message: &mut Message) -> Result<()> {
872 let partition = (self.partition_fn)(message);
873 self.release(&partition);
874 Ok(())
875 }
876
877 fn name(&self) -> &str {
878 "bulkhead"
879 }
880}
881
882pub type PriorityBoostFn = std::sync::Arc<dyn Fn(&Message, Priority) -> Priority + Send + Sync>;
907
908#[derive(Clone)]
909pub struct PriorityBoostMiddleware {
910 age_threshold: Option<Duration>,
911 age_boost_priority: Priority,
912 retry_threshold: Option<u32>,
913 retry_boost_priority: Priority,
914 custom_fn: Option<PriorityBoostFn>,
915}
916
917impl PriorityBoostMiddleware {
918 pub fn new() -> Self {
920 Self {
921 age_threshold: None,
922 age_boost_priority: Priority::High,
923 retry_threshold: None,
924 retry_boost_priority: Priority::High,
925 custom_fn: None,
926 }
927 }
928
929 pub fn with_age_boost(mut self, threshold: Duration, priority: Priority) -> Self {
931 self.age_threshold = Some(threshold);
932 self.age_boost_priority = priority;
933 self
934 }
935
936 pub fn with_retry_boost(mut self, threshold: u32, priority: Priority) -> Self {
938 self.retry_threshold = Some(threshold);
939 self.retry_boost_priority = priority;
940 self
941 }
942
943 pub fn with_custom_fn<F>(custom_fn: F) -> Self
945 where
946 F: Fn(&Message, Priority) -> Priority + Send + Sync + 'static,
947 {
948 Self {
949 age_threshold: None,
950 age_boost_priority: Priority::High,
951 retry_threshold: None,
952 retry_boost_priority: Priority::High,
953 custom_fn: Some(std::sync::Arc::new(custom_fn)),
954 }
955 }
956
957 pub fn calculate_priority(&self, message: &Message, current_priority: Priority) -> Priority {
959 let mut priority = current_priority;
960
961 if let Some(ref custom_fn) = self.custom_fn {
963 return custom_fn(message, priority);
964 }
965
966 if let Some(retry_threshold) = self.retry_threshold {
968 if message.headers.retries.unwrap_or(0) >= retry_threshold {
969 priority = priority.max(self.retry_boost_priority);
970 }
971 }
972
973 if let Some(age_threshold) = self.age_threshold {
975 if let Some(timestamp_value) = message.headers.extra.get("timestamp") {
976 if let Some(timestamp_secs) = timestamp_value.as_f64() {
977 let msg_age = std::time::SystemTime::now()
978 .duration_since(std::time::UNIX_EPOCH)
979 .unwrap()
980 .as_secs_f64()
981 - timestamp_secs;
982 if msg_age > age_threshold.as_secs_f64() {
983 priority = priority.max(self.age_boost_priority);
984 }
985 }
986 }
987 }
988
989 priority
990 }
991}
992
993impl Default for PriorityBoostMiddleware {
994 fn default() -> Self {
995 Self::new()
996 }
997}
998
999#[async_trait]
1000impl MessageMiddleware for PriorityBoostMiddleware {
1001 async fn before_publish(&self, message: &mut Message) -> Result<()> {
1002 let current_priority = message
1004 .headers
1005 .extra
1006 .get("priority")
1007 .and_then(|v| v.as_u64())
1008 .map(|p| Priority::from_u8(p as u8))
1009 .unwrap_or(Priority::Normal);
1010
1011 let boosted_priority = self.calculate_priority(message, current_priority);
1012
1013 if boosted_priority != current_priority {
1014 message.headers.extra.insert(
1015 "priority".to_string(),
1016 serde_json::json!(boosted_priority.as_u8()),
1017 );
1018 message
1019 .headers
1020 .extra
1021 .insert("x-priority-boosted".to_string(), serde_json::json!(true));
1022 message.headers.extra.insert(
1023 "x-original-priority".to_string(),
1024 serde_json::json!(current_priority.as_u8()),
1025 );
1026 }
1027 Ok(())
1028 }
1029
1030 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
1031 Ok(())
1033 }
1034
1035 fn name(&self) -> &str {
1036 "priority_boost"
1037 }
1038}