1use async_trait::async_trait;
4use celers_protocol::Message;
5use std::time::Duration;
6
7use uuid::Uuid;
8
9use crate::{BrokerError, BrokerMetrics, MessageMiddleware, Result};
10
11pub struct ValidationMiddleware {
35 max_body_size: Option<usize>,
37 require_task_name: bool,
39}
40
41impl ValidationMiddleware {
42 pub fn new() -> Self {
44 Self {
45 max_body_size: Some(10 * 1024 * 1024), require_task_name: true,
47 }
48 }
49
50 pub fn with_max_body_size(mut self, size: usize) -> Self {
52 self.max_body_size = Some(size);
53 self
54 }
55
56 pub fn without_body_size_limit(mut self) -> Self {
58 self.max_body_size = None;
59 self
60 }
61
62 pub fn with_require_task_name(mut self, require: bool) -> Self {
64 self.require_task_name = require;
65 self
66 }
67
68 fn validate_message(&self, message: &Message) -> Result<()> {
69 if self.require_task_name && message.task_name().is_empty() {
71 return Err(BrokerError::Configuration(
72 "Task name cannot be empty".to_string(),
73 ));
74 }
75
76 if let Some(max_size) = self.max_body_size {
78 if message.body.len() > max_size {
79 return Err(BrokerError::Configuration(format!(
80 "Message body size {} exceeds maximum {}",
81 message.body.len(),
82 max_size
83 )));
84 }
85 }
86
87 Ok(())
88 }
89}
90
91impl Default for ValidationMiddleware {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97#[async_trait]
98impl MessageMiddleware for ValidationMiddleware {
99 async fn before_publish(&self, message: &mut Message) -> Result<()> {
100 self.validate_message(message)
101 }
102
103 async fn after_consume(&self, message: &mut Message) -> Result<()> {
104 self.validate_message(message)
105 }
106
107 fn name(&self) -> &str {
108 "validation"
109 }
110}
111
112pub struct LoggingMiddleware {
127 prefix: String,
128 log_body: bool,
129}
130
131impl LoggingMiddleware {
132 pub fn new(prefix: impl Into<String>) -> Self {
134 Self {
135 prefix: prefix.into(),
136 log_body: false,
137 }
138 }
139
140 pub fn with_body_logging(mut self) -> Self {
142 self.log_body = true;
143 self
144 }
145}
146
147#[async_trait]
148impl MessageMiddleware for LoggingMiddleware {
149 async fn before_publish(&self, message: &mut Message) -> Result<()> {
150 if self.log_body {
151 eprintln!(
152 "[{}] Publishing: task={}, id={}, body_size={}",
153 self.prefix,
154 message.task_name(),
155 message.task_id(),
156 message.body.len()
157 );
158 } else {
159 eprintln!(
160 "[{}] Publishing: task={}, id={}",
161 self.prefix,
162 message.task_name(),
163 message.task_id()
164 );
165 }
166 Ok(())
167 }
168
169 async fn after_consume(&self, message: &mut Message) -> Result<()> {
170 if self.log_body {
171 eprintln!(
172 "[{}] Consumed: task={}, id={}, body_size={}",
173 self.prefix,
174 message.task_name(),
175 message.task_id(),
176 message.body.len()
177 );
178 } else {
179 eprintln!(
180 "[{}] Consumed: task={}, id={}",
181 self.prefix,
182 message.task_name(),
183 message.task_id()
184 );
185 }
186 Ok(())
187 }
188
189 fn name(&self) -> &str {
190 "logging"
191 }
192}
193
194pub struct MetricsMiddleware {
210 metrics: std::sync::Arc<std::sync::Mutex<BrokerMetrics>>,
211}
212
213impl MetricsMiddleware {
214 pub fn new(metrics: std::sync::Arc<std::sync::Mutex<BrokerMetrics>>) -> Self {
216 Self { metrics }
217 }
218
219 pub fn get_metrics(&self) -> BrokerMetrics {
221 self.metrics.lock().unwrap().clone()
222 }
223}
224
225#[async_trait]
226impl MessageMiddleware for MetricsMiddleware {
227 async fn before_publish(&self, _message: &mut Message) -> Result<()> {
228 let mut metrics = self.metrics.lock().unwrap();
229 metrics.inc_published();
230 Ok(())
231 }
232
233 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
234 let mut metrics = self.metrics.lock().unwrap();
235 metrics.inc_consumed();
236 Ok(())
237 }
238
239 fn name(&self) -> &str {
240 "metrics"
241 }
242}
243
244pub struct RetryLimitMiddleware {
255 max_retries: u32,
256}
257
258impl RetryLimitMiddleware {
259 pub fn new(max_retries: u32) -> Self {
261 Self { max_retries }
262 }
263}
264
265#[async_trait]
266impl MessageMiddleware for RetryLimitMiddleware {
267 async fn before_publish(&self, _message: &mut Message) -> Result<()> {
268 Ok(())
270 }
271
272 async fn after_consume(&self, message: &mut Message) -> Result<()> {
273 let retries = message.headers.retries.unwrap_or(0);
275 if retries > self.max_retries {
276 return Err(BrokerError::Configuration(format!(
277 "Message exceeded maximum retries: {} > {}",
278 retries, self.max_retries
279 )));
280 }
281 Ok(())
282 }
283
284 fn name(&self) -> &str {
285 "retry_limit"
286 }
287}
288
289pub struct RateLimitingMiddleware {
300 max_rate: f64,
302 tokens: std::sync::Arc<std::sync::Mutex<TokenBucket>>,
304}
305
/// Token bucket holding up to `capacity` tokens and regaining `refill_rate`
/// tokens per second of elapsed time.
struct TokenBucket {
    /// Tokens currently available (fractional values allowed).
    tokens: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// When `tokens` was last brought up to date.
    last_refill: std::time::Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        let last_refill = std::time::Instant::now();
        TokenBucket {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill,
        }
    }

    /// Refills the bucket for the time elapsed since the previous call, then
    /// takes `tokens` if that many are available.
    ///
    /// Returns `true` if the tokens were taken. Otherwise returns `false` and
    /// the bucket keeps its (refilled) balance.
    fn try_consume(&mut self, tokens: f64) -> bool {
        let now = std::time::Instant::now();
        let gained = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.last_refill = now;
        self.tokens = self.capacity.min(self.tokens + gained);

        let enough = self.tokens >= tokens;
        if enough {
            self.tokens -= tokens;
        }
        enough
    }
}
344
345impl RateLimitingMiddleware {
346 pub fn new(max_rate: f64) -> Self {
352 Self {
353 max_rate,
354 tokens: std::sync::Arc::new(std::sync::Mutex::new(TokenBucket::new(
355 max_rate, max_rate,
356 ))),
357 }
358 }
359}
360
361#[async_trait]
362impl MessageMiddleware for RateLimitingMiddleware {
363 async fn before_publish(&self, _message: &mut Message) -> Result<()> {
364 let mut bucket = self.tokens.lock().unwrap();
366 if !bucket.try_consume(1.0) {
367 return Err(BrokerError::OperationFailed(format!(
368 "Rate limit exceeded: {} messages/sec",
369 self.max_rate
370 )));
371 }
372 Ok(())
373 }
374
375 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
376 Ok(())
378 }
379
380 fn name(&self) -> &str {
381 "rate_limit"
382 }
383}
384
385pub struct DeduplicationMiddleware {
399 seen_ids: std::sync::Arc<std::sync::Mutex<std::collections::HashSet<Uuid>>>,
401 max_cache_size: usize,
403}
404
405impl DeduplicationMiddleware {
406 pub fn new(max_cache_size: usize) -> Self {
412 Self {
413 seen_ids: std::sync::Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
414 max_cache_size,
415 }
416 }
417
418 pub fn with_default_cache() -> Self {
420 Self::new(10_000)
421 }
422}
423
424impl Default for DeduplicationMiddleware {
425 fn default() -> Self {
426 Self::with_default_cache()
427 }
428}
429
430#[async_trait]
431impl MessageMiddleware for DeduplicationMiddleware {
432 async fn before_publish(&self, _message: &mut Message) -> Result<()> {
433 Ok(())
435 }
436
437 async fn after_consume(&self, message: &mut Message) -> Result<()> {
438 let msg_id = message.task_id();
439 let mut seen = self.seen_ids.lock().unwrap();
440
441 if seen.contains(&msg_id) {
443 return Err(BrokerError::OperationFailed(format!(
444 "Duplicate message detected: {}",
445 msg_id
446 )));
447 }
448
449 seen.insert(msg_id);
451
452 if seen.len() > self.max_cache_size {
454 if let Some(&id) = seen.iter().next() {
456 seen.remove(&id);
457 }
458 }
459
460 Ok(())
461 }
462
463 fn name(&self) -> &str {
464 "deduplication"
465 }
466}
467
/// Middleware that compresses message bodies on publish.
///
/// Bodies shorter than `min_compress_size` bytes are left alone. The
/// compressed form is only kept when it is strictly smaller than the input.
///
/// NOTE(review): `after_consume` does not decompress, and `before_publish`
/// does not record whether a body was compressed. Consumers therefore
/// receive the (possibly) compressed bytes unchanged. Confirm this is intended.
#[cfg(feature = "compression")]
pub struct CompressionMiddleware {
    /// Codec used to compress bodies.
    compressor: celers_protocol::compression::Compressor,
    /// Bodies shorter than this many bytes are sent uncompressed.
    min_compress_size: usize,
}

#[cfg(feature = "compression")]
impl CompressionMiddleware {
    /// Size threshold used by [`CompressionMiddleware::new`]: 1 KiB.
    const DEFAULT_MIN_COMPRESS_SIZE: usize = 1024;

    /// Creates a compressor of the given type with a 1 KiB size threshold.
    pub fn new(compression_type: celers_protocol::compression::CompressionType) -> Self {
        let compressor = celers_protocol::compression::Compressor::new(compression_type);
        Self {
            compressor,
            min_compress_size: Self::DEFAULT_MIN_COMPRESS_SIZE,
        }
    }

    /// Sets the minimum body size (in bytes) that is worth compressing.
    pub fn with_min_size(self, size: usize) -> Self {
        Self {
            min_compress_size: size,
            ..self
        }
    }

    /// Sets the compression level passed to the underlying compressor.
    pub fn with_level(self, level: u32) -> Self {
        let Self {
            compressor,
            min_compress_size,
        } = self;
        Self {
            compressor: compressor.with_level(level),
            min_compress_size,
        }
    }
}

#[cfg(feature = "compression")]
#[async_trait]
impl MessageMiddleware for CompressionMiddleware {
    /// Compresses the body in place when it is large enough and compression helps.
    ///
    /// # Errors
    /// Compressor failures are reported as `BrokerError::Serialization`.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        if message.body.len() < self.min_compress_size {
            return Ok(());
        }

        let compressed = self
            .compressor
            .compress(&message.body)
            .map_err(|e| BrokerError::Serialization(e.to_string()))?;

        // Keep the original bytes when compression does not shrink them.
        if compressed.len() < message.body.len() {
            message.body = compressed;
        }
        Ok(())
    }

    /// Passes the body through untouched (no decompression is performed).
    async fn after_consume(&self, _message: &mut Message) -> Result<()> {
        Ok(())
    }

    fn name(&self) -> &str {
        "compression"
    }
}
550
/// Middleware built around a keyed message signer.
///
/// NOTE(review): as written, `before_publish` computes a signature over the
/// body but does not attach it to the message, and `after_consume` recomputes
/// a signature without comparing it to anything. The only observable effect
/// is that a signer failure surfaces as `BrokerError::OperationFailed`; no
/// tampering is detected. Confirm whether attaching/verifying is still TODO.
#[cfg(feature = "signing")]
pub struct SigningMiddleware {
    /// Signer built from the key passed to [`SigningMiddleware::new`].
    signer: celers_protocol::auth::MessageSigner,
}

#[cfg(feature = "signing")]
impl SigningMiddleware {
    /// Creates a signing middleware from a raw key.
    pub fn new(key: &[u8]) -> Self {
        let signer = celers_protocol::auth::MessageSigner::new(key);
        Self { signer }
    }

    /// Signs the message body and discards the signature.
    ///
    /// # Errors
    /// Signer failures are reported as `BrokerError::OperationFailed`.
    fn sign_body(&self, message: &Message) -> Result<()> {
        match self.signer.sign(&message.body) {
            Ok(_signature) => Ok(()),
            Err(e) => Err(BrokerError::OperationFailed(format!(
                "signing failed: {}",
                e
            ))),
        }
    }
}

#[cfg(feature = "signing")]
#[async_trait]
impl MessageMiddleware for SigningMiddleware {
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        self.sign_body(message)
    }

    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        self.sign_body(message)
    }

    fn name(&self) -> &str {
        "signing"
    }
}
621
/// Middleware that encrypts message bodies on publish and decrypts them on consume.
///
/// Layout of an encrypted body: the nonce (`NONCE_SIZE` bytes) followed by
/// the ciphertext.
#[cfg(feature = "encryption")]
pub struct EncryptionMiddleware {
    /// Cipher used in both directions.
    encryptor: celers_protocol::crypto::MessageEncryptor,
}

#[cfg(feature = "encryption")]
impl EncryptionMiddleware {
    /// Creates an encryption middleware from a raw key.
    ///
    /// # Errors
    /// Returns `BrokerError::Configuration` if the encryptor rejects `key`.
    pub fn new(key: &[u8]) -> Result<Self> {
        match celers_protocol::crypto::MessageEncryptor::new(key) {
            Ok(encryptor) => Ok(Self { encryptor }),
            Err(e) => Err(BrokerError::Configuration(e.to_string())),
        }
    }
}

#[cfg(feature = "encryption")]
#[async_trait]
impl MessageMiddleware for EncryptionMiddleware {
    /// Replaces the body with `nonce || ciphertext`.
    ///
    /// # Errors
    /// Encryption failures are reported as `BrokerError::Serialization`.
    async fn before_publish(&self, message: &mut Message) -> Result<()> {
        let (ciphertext, nonce) = self
            .encryptor
            .encrypt(&message.body)
            .map_err(|e| BrokerError::Serialization(e.to_string()))?;

        let mut framed = nonce.to_vec();
        framed.extend_from_slice(&ciphertext);
        message.body = framed;
        Ok(())
    }

    /// Splits the body into nonce and ciphertext and replaces it with the plaintext.
    ///
    /// # Errors
    /// Returns `BrokerError::Serialization` if the body is shorter than a
    /// nonce or decryption fails.
    async fn after_consume(&self, message: &mut Message) -> Result<()> {
        const NONCE_LEN: usize = celers_protocol::crypto::NONCE_SIZE;

        if message.body.len() < NONCE_LEN {
            return Err(BrokerError::Serialization(
                "Message too short to contain nonce".to_string(),
            ));
        }

        let plaintext = {
            let (nonce_bytes, ciphertext) = message.body.split_at(NONCE_LEN);
            self.encryptor
                .decrypt(ciphertext, nonce_bytes)
                .map_err(|e| BrokerError::Serialization(e.to_string()))?
        };
        message.body = plaintext;
        Ok(())
    }

    fn name(&self) -> &str {
        "encryption"
    }
}
704
/// Middleware that carries a processing timeout for published messages.
pub struct TimeoutMiddleware {
    /// Timeout attached to messages.
    timeout: Duration,
}

impl TimeoutMiddleware {
    /// Creates a middleware configured with `timeout`.
    pub fn new(timeout: Duration) -> Self {
        TimeoutMiddleware { timeout }
    }

    /// Returns the configured timeout.
    pub fn timeout(&self) -> Duration {
        let Self { timeout } = self;
        *timeout
    }
}
731
732#[async_trait]
733impl MessageMiddleware for TimeoutMiddleware {
734 async fn before_publish(&self, message: &mut Message) -> Result<()> {
735 message.headers.extra.insert(
737 "x-timeout-ms".to_string(),
738 serde_json::Value::Number((self.timeout.as_millis() as u64).into()),
739 );
740 Ok(())
741 }
742
743 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
744 Ok(())
747 }
748
749 fn name(&self) -> &str {
750 "timeout"
751 }
752}
753
754pub struct FilterMiddleware {
768 predicate: Box<dyn Fn(&Message) -> bool + Send + Sync>,
769}
770
771impl FilterMiddleware {
772 pub fn new<F>(predicate: F) -> Self
774 where
775 F: Fn(&Message) -> bool + Send + Sync + 'static,
776 {
777 Self {
778 predicate: Box::new(predicate),
779 }
780 }
781
782 pub fn matches(&self, message: &Message) -> bool {
784 (self.predicate)(message)
785 }
786}
787
788#[async_trait]
789impl MessageMiddleware for FilterMiddleware {
790 async fn before_publish(&self, _message: &mut Message) -> Result<()> {
791 Ok(())
793 }
794
795 async fn after_consume(&self, message: &mut Message) -> Result<()> {
796 if !self.matches(message) {
797 return Err(BrokerError::Configuration(
798 "Message filtered out by predicate".to_string(),
799 ));
800 }
801 Ok(())
802 }
803
804 fn name(&self) -> &str {
805 "filter"
806 }
807}
808
/// Middleware that lets through only a fraction of consumed messages.
///
/// Sampling is deterministic. With rate `r`, `ceil(n * r)` of the first `n`
/// messages are accepted (up to floating-point rounding), spread evenly.
/// For example, `r = 0.5` accepts the 1st, 3rd, 5th, ... message.
pub struct SamplingMiddleware {
    /// Fraction of messages to accept, clamped to `[0.0, 1.0]`.
    sample_rate: f64,
    /// Number of sampling decisions made so far.
    counter: std::sync::atomic::AtomicU64,
}

impl SamplingMiddleware {
    /// Creates a sampler accepting `sample_rate` of messages.
    ///
    /// Values outside `[0.0, 1.0]` are clamped. A NaN rate accepts nothing.
    pub fn new(sample_rate: f64) -> Self {
        Self {
            sample_rate: sample_rate.clamp(0.0, 1.0),
            counter: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Returns the effective (clamped) sample rate.
    pub fn sample_rate(&self) -> f64 {
        self.sample_rate
    }

    /// Decides whether the next message is sampled.
    ///
    /// Message `n` (0-based) is accepted when `ceil((n + 1) * rate)` exceeds
    /// `ceil(n * rate)`, i.e. when the running quota of accepted messages
    /// grows by one. Comparing the raw counter against `u64::MAX * rate`
    /// does not work: the counter starts at zero, so it stays below that
    /// threshold and every message would be accepted for any non-zero rate.
    fn should_sample(&self) -> bool {
        let n = self
            .counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed) as f64;
        let quota_before = (n * self.sample_rate).ceil();
        let quota_after = ((n + 1.0) * self.sample_rate).ceil();
        quota_after > quota_before
    }
}
857
858#[async_trait]
859impl MessageMiddleware for SamplingMiddleware {
860 async fn before_publish(&self, _message: &mut Message) -> Result<()> {
861 Ok(())
863 }
864
865 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
866 if !self.should_sample() {
867 return Err(BrokerError::Configuration(
868 "Message filtered out by sampling".to_string(),
869 ));
870 }
871 Ok(())
872 }
873
874 fn name(&self) -> &str {
875 "sampling"
876 }
877}
878
/// Middleware that rewrites message bodies with a user-supplied function.
pub struct TransformationMiddleware {
    /// Maps the current body to the new body.
    transform_fn: Box<dyn Fn(Vec<u8>) -> Vec<u8> + Send + Sync>,
}

impl TransformationMiddleware {
    /// Creates a middleware that applies `transform_fn` to message bodies.
    pub fn new<F>(transform_fn: F) -> Self
    where
        F: Fn(Vec<u8>) -> Vec<u8> + Send + Sync + 'static,
    {
        let boxed: Box<dyn Fn(Vec<u8>) -> Vec<u8> + Send + Sync> = Box::new(transform_fn);
        Self {
            transform_fn: boxed,
        }
    }

    /// Applies the configured function to `body` and returns the new body.
    fn transform(&self, body: Vec<u8>) -> Vec<u8> {
        let apply = &self.transform_fn;
        apply(body)
    }
}
913
914#[async_trait]
915impl MessageMiddleware for TransformationMiddleware {
916 async fn before_publish(&self, message: &mut Message) -> Result<()> {
917 let transformed = self.transform(message.body.clone());
919 message.body = transformed;
920 Ok(())
921 }
922
923 async fn after_consume(&self, message: &mut Message) -> Result<()> {
924 let transformed = self.transform(message.body.clone());
926 message.body = transformed;
927 Ok(())
928 }
929
930 fn name(&self) -> &str {
931 "transformation"
932 }
933}
934
/// Middleware that propagates simple tracing metadata through message headers.
#[derive(Debug, Clone)]
pub struct TracingMiddleware {
    /// Service name recorded in the tracing headers.
    service_name: String,
}

impl TracingMiddleware {
    /// Creates a tracing middleware for the named service.
    pub fn new(service_name: impl Into<String>) -> Self {
        let service_name = service_name.into();
        TracingMiddleware { service_name }
    }
}
960
961#[async_trait]
962impl MessageMiddleware for TracingMiddleware {
963 async fn before_publish(&self, message: &mut Message) -> Result<()> {
964 if !message.headers.extra.contains_key("trace-id") {
966 let trace_id = uuid::Uuid::new_v4().to_string();
967 message
968 .headers
969 .extra
970 .insert("trace-id".to_string(), serde_json::json!(trace_id));
971 }
972
973 message.headers.extra.insert(
975 "service-name".to_string(),
976 serde_json::json!(self.service_name.clone()),
977 );
978
979 let span_id = uuid::Uuid::new_v4().to_string();
981 message
982 .headers
983 .extra
984 .insert("span-id".to_string(), serde_json::json!(span_id));
985
986 message.headers.extra.insert(
988 "trace-timestamp".to_string(),
989 serde_json::json!(std::time::SystemTime::now()
990 .duration_since(std::time::UNIX_EPOCH)
991 .unwrap()
992 .as_millis()),
993 );
994
995 Ok(())
996 }
997
998 async fn after_consume(&self, message: &mut Message) -> Result<()> {
999 if let Some(trace_id) = message.headers.extra.get("trace-id").cloned() {
1001 message.headers.extra.insert(
1003 "consumer-service".to_string(),
1004 serde_json::json!(self.service_name.clone()),
1005 );
1006 message
1007 .headers
1008 .extra
1009 .insert("trace-id-consumed".to_string(), trace_id);
1010 }
1011 Ok(())
1012 }
1013
1014 fn name(&self) -> &str {
1015 "tracing"
1016 }
1017}