1use async_trait::async_trait;
4use celers_protocol::Message;
5use std::collections::HashMap;
6use std::time::Duration;
7
8use uuid::Uuid;
9
10use crate::{BrokerError, MessageMiddleware, Result};
11
/// Middleware that classifies error messages by substring matching and
/// decides whether a further retry is allowed.
///
/// Matching is case-insensitive: both the error text and every pattern are
/// lower-cased before comparison.
pub struct ErrorClassificationMiddleware {
    /// Substrings that mark an error as transient.
    transient_patterns: Vec<String>,
    /// Substrings that mark an error as permanent.
    permanent_patterns: Vec<String>,
    /// Retry budget applied to transient and unknown errors.
    max_transient_retries: u32,
    /// Retry budget applied to permanent errors.
    max_permanent_retries: u32,
}

impl ErrorClassificationMiddleware {
    /// Builds a middleware with the built-in pattern lists, a budget of 10
    /// retries for transient errors and 1 for permanent errors.
    pub fn new() -> Self {
        let transient = ["timeout", "connection", "network", "unavailable"];
        let permanent = ["validation", "schema", "invalid", "forbidden"];

        Self {
            transient_patterns: transient.iter().map(|word| String::from(*word)).collect(),
            permanent_patterns: permanent.iter().map(|word| String::from(*word)).collect(),
            max_transient_retries: 10,
            max_permanent_retries: 1,
        }
    }

    /// Appends one more substring that identifies a transient error.
    pub fn with_transient_pattern(mut self, pattern: &str) -> Self {
        self.transient_patterns.push(String::from(pattern));
        self
    }

    /// Appends one more substring that identifies a permanent error.
    pub fn with_permanent_pattern(mut self, pattern: &str) -> Self {
        self.permanent_patterns.push(String::from(pattern));
        self
    }

    /// Replaces the retry budget used for transient (and unknown) errors.
    pub fn with_max_transient_retries(mut self, max_retries: u32) -> Self {
        self.max_transient_retries = max_retries;
        self
    }

    /// Replaces the retry budget used for permanent errors.
    pub fn with_max_permanent_retries(mut self, max_retries: u32) -> Self {
        self.max_permanent_retries = max_retries;
        self
    }

    /// Classifies `error_msg`.
    ///
    /// Permanent patterns are consulted first, so a message matching both
    /// lists is reported as [`ErrorClass::Permanent`]. A message matching
    /// neither list is [`ErrorClass::Unknown`].
    pub fn classify_error(&self, error_msg: &str) -> ErrorClass {
        let haystack = error_msg.to_lowercase();
        let matches_any = |patterns: &[String]| {
            patterns
                .iter()
                .any(|needle| haystack.contains(needle.to_lowercase().as_str()))
        };

        if matches_any(&self.permanent_patterns) {
            ErrorClass::Permanent
        } else if matches_any(&self.transient_patterns) {
            ErrorClass::Transient
        } else {
            ErrorClass::Unknown
        }
    }

    /// Returns `true` while `current_retries` is still below the budget for
    /// the class of `error_msg`. Unknown errors share the transient budget.
    pub fn should_retry(&self, error_msg: &str, current_retries: u32) -> bool {
        let budget = match self.classify_error(error_msg) {
            ErrorClass::Permanent => self.max_permanent_retries,
            ErrorClass::Transient | ErrorClass::Unknown => self.max_transient_retries,
        };
        current_retries < budget
    }
}

impl Default for ErrorClassificationMiddleware {
    /// Same configuration as [`ErrorClassificationMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Outcome of [`ErrorClassificationMiddleware::classify_error`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
    /// The message matched a transient pattern (and no permanent one).
    Transient,
    /// The message matched a permanent pattern.
    Permanent,
    /// The message matched no configured pattern.
    Unknown,
}
127
128#[async_trait]
129impl MessageMiddleware for ErrorClassificationMiddleware {
130 async fn before_publish(&self, message: &mut Message) -> Result<()> {
131 if let Some(error_value) = message.headers.extra.get("error") {
133 if let Some(error_msg) = error_value.as_str() {
134 let error_class = self.classify_error(error_msg);
135 let should_retry =
136 self.should_retry(error_msg, message.headers.retries.unwrap_or(0));
137
138 message.headers.extra.insert(
139 "x-error-class".to_string(),
140 serde_json::json!(match error_class {
141 ErrorClass::Transient => "transient",
142 ErrorClass::Permanent => "permanent",
143 ErrorClass::Unknown => "unknown",
144 }),
145 );
146
147 message.headers.extra.insert(
148 "x-should-retry".to_string(),
149 serde_json::json!(should_retry),
150 );
151
152 if !should_retry {
153 message.headers.extra.insert(
154 "x-max-retries-exceeded".to_string(),
155 serde_json::json!(true),
156 );
157 }
158 }
159 }
160 Ok(())
161 }
162
163 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
164 Ok(())
166 }
167
168 fn name(&self) -> &str {
169 "error_classification"
170 }
171}
172
173pub struct CorrelationMiddleware {
188 header_name: String,
189}
190
191impl CorrelationMiddleware {
192 pub fn new() -> Self {
194 Self {
195 header_name: "x-correlation-id".to_string(),
196 }
197 }
198
199 pub fn with_header_name(header_name: &str) -> Self {
201 Self {
202 header_name: header_name.to_string(),
203 }
204 }
205
206 fn get_or_generate_correlation_id(&self, message: &Message) -> String {
207 message
208 .headers
209 .extra
210 .get(&self.header_name)
211 .and_then(|v| v.as_str())
212 .map(|s| s.to_string())
213 .unwrap_or_else(|| Uuid::new_v4().to_string())
214 }
215}
216
217impl Default for CorrelationMiddleware {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
223#[async_trait]
224impl MessageMiddleware for CorrelationMiddleware {
225 async fn before_publish(&self, message: &mut Message) -> Result<()> {
226 let correlation_id = self.get_or_generate_correlation_id(message);
227 message
228 .headers
229 .extra
230 .insert(self.header_name.clone(), serde_json::json!(correlation_id));
231 Ok(())
232 }
233
234 async fn after_consume(&self, message: &mut Message) -> Result<()> {
235 let correlation_id = self.get_or_generate_correlation_id(message);
237 message
238 .headers
239 .extra
240 .insert(self.header_name.clone(), serde_json::json!(correlation_id));
241 Ok(())
242 }
243
244 fn name(&self) -> &str {
245 "correlation"
246 }
247}
248
/// Rate-limiting middleware built around a refillable token counter.
///
/// Tokens accrue at `max_rate` per second up to `burst_size`; the trait
/// implementation spends one token per published message.
pub struct ThrottlingMiddleware {
    /// Tokens added per second.
    pub(crate) max_rate: f64,
    /// Upper bound on the number of stored tokens.
    pub(crate) burst_size: usize,
    /// Fraction in `[0, 1]`; backpressure is reported once the bucket's fill
    /// ratio drops below `1 - backpressure_threshold`.
    pub(crate) backpressure_threshold: f64,
    /// Instant of the most recent refill.
    last_refill: std::sync::Mutex<std::time::Instant>,
    /// Tokens currently available.
    available_tokens: std::sync::Mutex<f64>,
}

impl ThrottlingMiddleware {
    /// Largest delay `calculate_delay` will report. Chosen so that its
    /// millisecond count still fits in a `u64`.
    const MAX_DELAY: Duration = Duration::from_millis(u64::MAX);

    /// Creates a throttle allowing `max_rate` messages per second.
    ///
    /// The burst size defaults to twice `max_rate` (truncated to `usize`),
    /// the backpressure threshold to `0.8`, and the bucket starts with
    /// `max_rate` tokens.
    pub fn new(max_rate: f64) -> Self {
        Self {
            max_rate,
            burst_size: (max_rate * 2.0) as usize,
            backpressure_threshold: 0.8,
            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
            available_tokens: std::sync::Mutex::new(max_rate),
        }
    }

    /// Overrides the maximum number of tokens the bucket may hold.
    pub fn with_burst_size(mut self, size: usize) -> Self {
        self.burst_size = size;
        self
    }

    /// Overrides the backpressure threshold; the value is clamped to `[0, 1]`.
    pub fn with_backpressure_threshold(mut self, threshold: f64) -> Self {
        self.backpressure_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Credits the tokens earned since the last refill, capped at
    /// `burst_size`, and records the refill time.
    fn refill_tokens(&self) {
        // Lock order (last_refill, then available_tokens) is the only place
        // both locks are held together.
        let mut last_refill = self.last_refill.lock().unwrap();
        let mut tokens = self.available_tokens.lock().unwrap();

        let now = std::time::Instant::now();
        let elapsed = now.duration_since(*last_refill).as_secs_f64();

        let new_tokens = elapsed * self.max_rate;
        *tokens = (*tokens + new_tokens).min(self.burst_size as f64);
        *last_refill = now;
    }

    /// Returns how long a caller would have to wait for one full token.
    ///
    /// Returns a zero duration when a token is already available. When the
    /// wait cannot be expressed (zero, negative or NaN rate, or a wait
    /// longer than `MAX_DELAY`) the result saturates at `MAX_DELAY` rather
    /// than panicking inside `Duration::from_secs_f64`.
    fn calculate_delay(&self) -> Duration {
        self.refill_tokens();
        let tokens = *self.available_tokens.lock().unwrap();

        if tokens >= 1.0 {
            return Duration::from_millis(0);
        }

        let wait_secs = (1.0 - tokens) / self.max_rate;
        let representable = wait_secs.is_finite()
            && wait_secs >= 0.0
            && wait_secs < Self::MAX_DELAY.as_secs_f64();

        if representable {
            Duration::from_secs_f64(wait_secs)
        } else {
            Self::MAX_DELAY
        }
    }

    /// Returns `true` when the bucket's fill ratio is below
    /// `1 - backpressure_threshold`.
    ///
    /// A bucket with `burst_size == 0` can never hold a token, so its fill
    /// ratio is treated as `0.0` instead of the NaN that `0 / 0` would give.
    fn should_apply_backpressure(&self) -> bool {
        self.refill_tokens();
        let tokens = *self.available_tokens.lock().unwrap();

        let fill_ratio = if self.burst_size == 0 {
            0.0
        } else {
            tokens / self.burst_size as f64
        };
        fill_ratio < (1.0 - self.backpressure_threshold)
    }
}
326
327#[async_trait]
328impl MessageMiddleware for ThrottlingMiddleware {
329 async fn before_publish(&self, message: &mut Message) -> Result<()> {
330 let delay = self.calculate_delay();
331
332 if delay > Duration::from_millis(0) {
333 message.headers.extra.insert(
334 "x-throttle-delay-ms".to_string(),
335 serde_json::json!(delay.as_millis()),
336 );
337 }
338
339 if self.should_apply_backpressure() {
340 message
341 .headers
342 .extra
343 .insert("x-backpressure-active".to_string(), serde_json::json!(true));
344 }
345
346 let mut tokens = self.available_tokens.lock().unwrap();
348 if *tokens >= 1.0 {
349 *tokens -= 1.0;
350 }
351
352 Ok(())
353 }
354
355 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
356 Ok(())
358 }
359
360 fn name(&self) -> &str {
361 "throttling"
362 }
363}
364
/// Circuit breaker that counts failures inside a sliding time window.
pub struct CircuitBreakerMiddleware {
    /// Number of in-window failures at which the circuit counts as open.
    pub(crate) failure_threshold: usize,
    /// Length of the sliding window.
    window: Duration,
    /// Timestamps of the failures that are still inside the window.
    failures: std::sync::Mutex<Vec<std::time::Instant>>,
}

impl CircuitBreakerMiddleware {
    /// Creates a breaker that opens once `failure_threshold` failures have
    /// been recorded within `window`.
    pub fn new(failure_threshold: usize, window: Duration) -> Self {
        Self {
            failure_threshold,
            window,
            failures: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Locks the failure log and drops every entry whose age is not strictly
    /// less than `window`. Returns the guard together with the instant that
    /// was used as "now".
    fn prune_expired(
        &self,
    ) -> (
        std::sync::MutexGuard<'_, Vec<std::time::Instant>>,
        std::time::Instant,
    ) {
        let mut log = self.failures.lock().unwrap();
        let now = std::time::Instant::now();
        log.retain(|&at| now.duration_since(at) < self.window);
        (log, now)
    }

    /// Records a failure at the current instant.
    fn record_failure(&self) {
        let (mut log, now) = self.prune_expired();
        log.push(now);
    }

    /// Returns `true` when at least `failure_threshold` failures are inside
    /// the window.
    fn is_circuit_open(&self) -> bool {
        let (log, _) = self.prune_expired();
        log.len() >= self.failure_threshold
    }

    /// Returns the number of failures currently inside the window.
    fn get_failure_count(&self) -> usize {
        let (log, _) = self.prune_expired();
        log.len()
    }
}
423
424#[async_trait]
425impl MessageMiddleware for CircuitBreakerMiddleware {
426 async fn before_publish(&self, message: &mut Message) -> Result<()> {
427 if self.is_circuit_open() {
428 message.headers.extra.insert(
429 "x-circuit-breaker-open".to_string(),
430 serde_json::json!(true),
431 );
432 message.headers.extra.insert(
433 "x-circuit-breaker-failures".to_string(),
434 serde_json::json!(self.get_failure_count()),
435 );
436 return Err(BrokerError::OperationFailed(
437 "Circuit breaker is open".to_string(),
438 ));
439 }
440 Ok(())
441 }
442
443 async fn after_consume(&self, message: &mut Message) -> Result<()> {
444 if message.headers.extra.contains_key("error") {
446 self.record_failure();
447 }
448
449 message.headers.extra.insert(
451 "x-circuit-breaker-failures".to_string(),
452 serde_json::json!(self.get_failure_count()),
453 );
454
455 Ok(())
456 }
457
458 fn name(&self) -> &str {
459 "circuit_breaker"
460 }
461}
462
463pub struct SchemaValidationMiddleware {
479 pub(crate) required_fields: Vec<String>,
480 pub(crate) max_field_count: Option<usize>,
481 min_body_size: Option<usize>,
482 pub(crate) max_body_size: Option<usize>,
483}
484
485impl SchemaValidationMiddleware {
486 pub fn new() -> Self {
488 Self {
489 required_fields: Vec::new(),
490 max_field_count: None,
491 min_body_size: None,
492 max_body_size: None,
493 }
494 }
495
496 pub fn with_required_field(mut self, field: impl Into<String>) -> Self {
498 self.required_fields.push(field.into());
499 self
500 }
501
502 pub fn with_max_field_count(mut self, count: usize) -> Self {
504 self.max_field_count = Some(count);
505 self
506 }
507
508 pub fn with_min_body_size(mut self, size: usize) -> Self {
510 self.min_body_size = Some(size);
511 self
512 }
513
514 pub fn with_max_body_size(mut self, size: usize) -> Self {
516 self.max_body_size = Some(size);
517 self
518 }
519
520 fn validate_message(&self, message: &Message) -> Result<()> {
521 for field in &self.required_fields {
523 if !message.headers.extra.contains_key(field) {
524 return Err(BrokerError::Configuration(format!(
525 "Missing required field: {}",
526 field
527 )));
528 }
529 }
530
531 if let Some(max) = self.max_field_count {
533 if message.headers.extra.len() > max {
534 return Err(BrokerError::Configuration(format!(
535 "Too many fields: {} > {}",
536 message.headers.extra.len(),
537 max
538 )));
539 }
540 }
541
542 let body_len = message.body.len();
544 if let Some(min) = self.min_body_size {
545 if body_len < min {
546 return Err(BrokerError::Configuration(format!(
547 "Body too small: {} < {}",
548 body_len, min
549 )));
550 }
551 }
552 if let Some(max) = self.max_body_size {
553 if body_len > max {
554 return Err(BrokerError::Configuration(format!(
555 "Body too large: {} > {}",
556 body_len, max
557 )));
558 }
559 }
560
561 Ok(())
562 }
563}
564
565impl Default for SchemaValidationMiddleware {
566 fn default() -> Self {
567 Self::new()
568 }
569}
570
571#[async_trait]
572impl MessageMiddleware for SchemaValidationMiddleware {
573 async fn before_publish(&self, message: &mut Message) -> Result<()> {
574 self.validate_message(message)?;
575 message
576 .headers
577 .extra
578 .insert("x-schema-validated".to_string(), serde_json::json!(true));
579 Ok(())
580 }
581
582 async fn after_consume(&self, message: &mut Message) -> Result<()> {
583 self.validate_message(message)
584 }
585
586 fn name(&self) -> &str {
587 "schema_validation"
588 }
589}
590
591pub struct MessageEnrichmentMiddleware {
608 pub(crate) hostname: Option<String>,
609 pub(crate) environment: Option<String>,
610 pub(crate) version: Option<String>,
611 pub(crate) add_timestamp: bool,
612 custom_metadata: HashMap<String, serde_json::Value>,
613}
614
615impl MessageEnrichmentMiddleware {
616 pub fn new() -> Self {
618 Self {
619 hostname: None,
620 environment: None,
621 version: None,
622 add_timestamp: false,
623 custom_metadata: HashMap::new(),
624 }
625 }
626
627 pub fn with_hostname(mut self, hostname: impl Into<String>) -> Self {
629 self.hostname = Some(hostname.into());
630 self
631 }
632
633 pub fn with_environment(mut self, environment: impl Into<String>) -> Self {
635 self.environment = Some(environment.into());
636 self
637 }
638
639 pub fn with_version(mut self, version: impl Into<String>) -> Self {
641 self.version = Some(version.into());
642 self
643 }
644
645 pub fn with_add_timestamp(mut self, add: bool) -> Self {
647 self.add_timestamp = add;
648 self
649 }
650
651 pub fn with_custom_metadata(
653 mut self,
654 key: impl Into<String>,
655 value: serde_json::Value,
656 ) -> Self {
657 self.custom_metadata.insert(key.into(), value);
658 self
659 }
660
661 fn enrich_message(&self, message: &mut Message) {
662 if let Some(ref hostname) = self.hostname {
663 message.headers.extra.insert(
664 "x-enrichment-hostname".to_string(),
665 serde_json::json!(hostname),
666 );
667 }
668
669 if let Some(ref environment) = self.environment {
670 message.headers.extra.insert(
671 "x-enrichment-environment".to_string(),
672 serde_json::json!(environment),
673 );
674 }
675
676 if let Some(ref version) = self.version {
677 message.headers.extra.insert(
678 "x-enrichment-version".to_string(),
679 serde_json::json!(version),
680 );
681 }
682
683 if self.add_timestamp {
684 let timestamp = std::time::SystemTime::now()
685 .duration_since(std::time::UNIX_EPOCH)
686 .unwrap()
687 .as_secs();
688 message.headers.extra.insert(
689 "x-enrichment-timestamp".to_string(),
690 serde_json::json!(timestamp),
691 );
692 }
693
694 for (key, value) in &self.custom_metadata {
695 message
696 .headers
697 .extra
698 .insert(format!("x-enrichment-{}", key), value.clone());
699 }
700 }
701}
702
703impl Default for MessageEnrichmentMiddleware {
704 fn default() -> Self {
705 Self::new()
706 }
707}
708
709#[async_trait]
710impl MessageMiddleware for MessageEnrichmentMiddleware {
711 async fn before_publish(&self, message: &mut Message) -> Result<()> {
712 self.enrich_message(message);
713 Ok(())
714 }
715
716 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
717 Ok(())
719 }
720
721 fn name(&self) -> &str {
722 "message_enrichment"
723 }
724}
725
/// How the delay before retry number `n` (zero-based) is derived from the
/// base delay.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryStrategy {
    /// `base * 2^n`.
    Exponential,
    /// `base * (n + 1)`.
    Linear,
    /// `base * fib(n)` with `fib(0) = fib(1) = 1`.
    Fibonacci,
    /// Always `base`.
    Fixed,
}

/// Middleware that computes a retry delay for a message and enforces a
/// maximum retry count.
pub struct RetryStrategyMiddleware {
    /// Delay growth rule.
    strategy: RetryStrategy,
    /// Base delay in milliseconds.
    base_delay_ms: u64,
    /// Upper bound on any computed delay, in milliseconds.
    max_delay_ms: u64,
    /// Retry count at which publishing is rejected.
    max_retries: u32,
}

impl RetryStrategyMiddleware {
    /// Creates a middleware with a 1 s base delay, a 300 s delay cap and at
    /// most 5 retries.
    pub fn new(strategy: RetryStrategy) -> Self {
        Self {
            strategy,
            base_delay_ms: 1000,
            max_delay_ms: 300_000,
            max_retries: 5,
        }
    }

    /// Converts a duration to whole milliseconds, saturating at `u64::MAX`
    /// instead of silently wrapping.
    fn saturating_millis(delay: Duration) -> u64 {
        delay.as_millis().min(u128::from(u64::MAX)) as u64
    }

    /// Sets the base delay (millisecond resolution).
    pub fn with_base_delay(mut self, delay: Duration) -> Self {
        self.base_delay_ms = Self::saturating_millis(delay);
        self
    }

    /// Sets the cap applied to every computed delay (millisecond resolution).
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay_ms = Self::saturating_millis(delay);
        self
    }

    /// Sets the retry count at which publishing is rejected.
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Returns the delay in milliseconds for zero-based retry `retry_count`,
    /// capped at `max_delay_ms`.
    ///
    /// All arithmetic saturates, so large retry counts yield the cap rather
    /// than overflowing (a panic in debug builds, a wrapped and therefore
    /// tiny delay in release builds).
    fn calculate_delay(&self, retry_count: u32) -> u64 {
        let multiplier = match self.strategy {
            RetryStrategy::Exponential => 2_u64.saturating_pow(retry_count),
            // A `u32` widened to `u64` always has room for the `+ 1`.
            RetryStrategy::Linear => u64::from(retry_count) + 1,
            RetryStrategy::Fibonacci => self.fibonacci(retry_count as usize),
            RetryStrategy::Fixed => 1,
        };

        self.base_delay_ms
            .saturating_mul(multiplier)
            .min(self.max_delay_ms)
    }

    /// Returns the `n`-th Fibonacci number with `fib(0) = fib(1) = 1`,
    /// saturating at `u64::MAX`.
    fn fibonacci(&self, n: usize) -> u64 {
        let (mut prev, mut curr) = (1u64, 1u64);
        for _ in 2..=n {
            let next = prev.saturating_add(curr);
            prev = curr;
            curr = next;
            // Once saturated the value cannot change; stop early so a huge
            // `n` does not spin through billions of iterations.
            if curr == u64::MAX {
                break;
            }
        }
        curr
    }
}

impl Default for RetryStrategyMiddleware {
    /// Exponential backoff with the defaults from
    /// [`RetryStrategyMiddleware::new`].
    fn default() -> Self {
        Self::new(RetryStrategy::Exponential)
    }
}
838
839#[async_trait]
840impl MessageMiddleware for RetryStrategyMiddleware {
841 async fn before_publish(&self, message: &mut Message) -> Result<()> {
842 let retry_count = message
844 .headers
845 .extra
846 .get("x-retry-count")
847 .and_then(|v| v.as_u64())
848 .unwrap_or(0) as u32;
849
850 if retry_count >= self.max_retries {
852 return Err(BrokerError::OperationFailed(format!(
853 "Max retries ({}) exceeded",
854 self.max_retries
855 )));
856 }
857
858 let delay_ms = self.calculate_delay(retry_count);
860 message
861 .headers
862 .extra
863 .insert("x-retry-delay-ms".to_string(), serde_json::json!(delay_ms));
864 message.headers.extra.insert(
865 "x-retry-strategy".to_string(),
866 serde_json::json!(format!("{:?}", self.strategy)),
867 );
868
869 Ok(())
870 }
871
872 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
873 Ok(())
875 }
876
877 fn name(&self) -> &str {
878 "retry_strategy"
879 }
880}
881
882pub struct TenantIsolationMiddleware {
897 required: bool,
898 tenant_header: String,
899 allowed_tenants: Option<Vec<String>>,
900}
901
902impl TenantIsolationMiddleware {
903 pub fn new() -> Self {
905 Self {
906 required: true,
907 tenant_header: "x-tenant-id".to_string(),
908 allowed_tenants: None,
909 }
910 }
911
912 pub fn with_required_tenant(mut self, required: bool) -> Self {
914 self.required = required;
915 self
916 }
917
918 pub fn with_tenant_header(mut self, header: impl Into<String>) -> Self {
920 self.tenant_header = header.into();
921 self
922 }
923
924 pub fn with_allowed_tenants(mut self, tenants: Vec<String>) -> Self {
926 self.allowed_tenants = Some(tenants);
927 self
928 }
929
930 fn validate_tenant(&self, tenant_id: Option<&str>) -> Result<()> {
931 if self.required && tenant_id.is_none() {
933 return Err(BrokerError::Configuration(format!(
934 "Missing required tenant header: {}",
935 self.tenant_header
936 )));
937 }
938
939 if let (Some(tenant), Some(allowed)) = (tenant_id, &self.allowed_tenants) {
941 if !allowed.contains(&tenant.to_string()) {
942 return Err(BrokerError::Configuration(format!(
943 "Tenant '{}' not in allowed list",
944 tenant
945 )));
946 }
947 }
948
949 Ok(())
950 }
951}
952
953impl Default for TenantIsolationMiddleware {
954 fn default() -> Self {
955 Self::new()
956 }
957}
958
959#[async_trait]
960impl MessageMiddleware for TenantIsolationMiddleware {
961 async fn before_publish(&self, message: &mut Message) -> Result<()> {
962 let tenant_id = message
963 .headers
964 .extra
965 .get(&self.tenant_header)
966 .and_then(|v| v.as_str());
967
968 self.validate_tenant(tenant_id)?;
969
970 message
972 .headers
973 .extra
974 .insert("x-tenant-validated".to_string(), serde_json::json!(true));
975
976 Ok(())
977 }
978
979 async fn after_consume(&self, message: &mut Message) -> Result<()> {
980 let tenant_id = message
981 .headers
982 .extra
983 .get(&self.tenant_header)
984 .and_then(|v| v.as_str());
985
986 self.validate_tenant(tenant_id)
987 }
988
989 fn name(&self) -> &str {
990 "tenant_isolation"
991 }
992}
993
994#[derive(Debug, Clone)]
1008pub struct PartitioningMiddleware {
1009 partition_count: usize,
1010 partition_header: String,
1011 partition_key_fn: Option<String>, }
1013
1014impl PartitioningMiddleware {
1015 pub fn new(partition_count: usize) -> Self {
1017 Self {
1018 partition_count: partition_count.max(1),
1019 partition_header: "x-partition-id".to_string(),
1020 partition_key_fn: None,
1021 }
1022 }
1023
1024 pub fn with_partition_header(mut self, header: impl Into<String>) -> Self {
1026 self.partition_header = header.into();
1027 self
1028 }
1029
1030 pub fn with_partition_key_field(mut self, field: impl Into<String>) -> Self {
1032 self.partition_key_fn = Some(field.into());
1033 self
1034 }
1035
1036 pub fn partition_count(&self) -> usize {
1038 self.partition_count
1039 }
1040
1041 fn calculate_partition(&self, message: &Message) -> usize {
1042 use std::collections::hash_map::DefaultHasher;
1043 use std::hash::{Hash, Hasher};
1044
1045 let task_id_str = message.headers.id.to_string();
1047 let key = if let Some(field) = &self.partition_key_fn {
1048 message
1049 .headers
1050 .extra
1051 .get(field)
1052 .and_then(|v| v.as_str())
1053 .unwrap_or(&task_id_str)
1054 } else {
1055 &task_id_str
1057 };
1058
1059 let mut hasher = DefaultHasher::new();
1061 key.hash(&mut hasher);
1062 let hash = hasher.finish();
1063
1064 (hash % self.partition_count as u64) as usize
1065 }
1066}
1067
1068impl Default for PartitioningMiddleware {
1069 fn default() -> Self {
1070 Self::new(4) }
1072}
1073
1074#[async_trait]
1075impl MessageMiddleware for PartitioningMiddleware {
1076 async fn before_publish(&self, message: &mut Message) -> Result<()> {
1077 let partition_id = self.calculate_partition(message);
1078
1079 message.headers.extra.insert(
1081 self.partition_header.clone(),
1082 serde_json::json!(partition_id),
1083 );
1084
1085 message.headers.extra.insert(
1087 "x-partition-count".to_string(),
1088 serde_json::json!(self.partition_count),
1089 );
1090
1091 Ok(())
1092 }
1093
1094 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
1095 Ok(())
1097 }
1098
1099 fn name(&self) -> &str {
1100 "partitioning"
1101 }
1102}
1103
/// Middleware that derives a timeout from a percentile of recorded
/// durations, falling back to a base timeout when nothing is recorded.
#[derive(Debug, Clone)]
pub struct AdaptiveTimeoutMiddleware {
    /// Timeout returned while no samples are available.
    base_timeout: Duration,
    /// Lower bound applied to a sample-derived timeout.
    min_timeout: Duration,
    /// Upper bound applied to a sample-derived timeout.
    max_timeout: Duration,
    /// Recorded durations in milliseconds.
    // NOTE(review): nothing in this block ever pushes to `samples`; confirm
    // whether another part of the file is meant to populate it.
    samples: Vec<u64>,
    /// Intended cap on the number of stored samples; not read in this block.
    #[allow(dead_code)]
    max_samples: usize,
    /// Percentile of the samples to use, in `[0, 1]`.
    percentile: f64,
}

impl AdaptiveTimeoutMiddleware {
    /// Creates a middleware with a 1 s minimum, a maximum of five times
    /// `base_timeout`, room for 100 samples and the 95th percentile.
    pub fn new(base_timeout: Duration) -> Self {
        Self {
            base_timeout,
            min_timeout: Duration::from_secs(1),
            max_timeout: base_timeout.mul_f64(5.0),
            samples: Vec::new(),
            max_samples: 100,
            percentile: 0.95,
        }
    }

    /// Sets the lower bound for sample-derived timeouts.
    pub fn with_min_timeout(mut self, timeout: Duration) -> Self {
        self.min_timeout = timeout;
        self
    }

    /// Sets the upper bound for sample-derived timeouts.
    pub fn with_max_timeout(mut self, timeout: Duration) -> Self {
        self.max_timeout = timeout;
        self
    }

    /// Sets the percentile; the value is clamped to `[0, 1]`.
    pub fn with_percentile(mut self, percentile: f64) -> Self {
        self.percentile = percentile.clamp(0.0, 1.0);
        self
    }

    /// Returns `true` when at least one sample is stored.
    pub fn has_samples(&self) -> bool {
        !self.samples.is_empty()
    }

    /// Computes the timeout to advertise.
    ///
    /// With no samples this is `base_timeout`, returned as-is. Otherwise the
    /// configured percentile of the samples is taken, enlarged by 20 %, and
    /// limited to `[min_timeout, max_timeout]`.
    ///
    /// If the bounds are inverted (`min_timeout > max_timeout`, which the
    /// defaults produce for any base timeout under 200 ms) the upper bound
    /// wins. `Ord::clamp` would panic on such bounds, so the limits are
    /// applied with `max` followed by `min`.
    pub fn calculate_adaptive_timeout(&self) -> Duration {
        if self.samples.is_empty() {
            return self.base_timeout;
        }

        let mut sorted_samples = self.samples.clone();
        sorted_samples.sort_unstable();

        // Truncating index into the sorted samples, kept inside the vector.
        let index = ((sorted_samples.len() as f64 * self.percentile) as usize)
            .min(sorted_samples.len() - 1);
        let timeout_ms = sorted_samples[index];

        // 20 % headroom above the observed percentile.
        let buffered_ms = (timeout_ms as f64 * 1.2) as u64;
        let timeout = Duration::from_millis(buffered_ms);

        timeout.max(self.min_timeout).min(self.max_timeout)
    }
}

impl Default for AdaptiveTimeoutMiddleware {
    /// A middleware with a 30 s base timeout.
    fn default() -> Self {
        Self::new(Duration::from_secs(30))
    }
}
1193
1194#[async_trait]
1195impl MessageMiddleware for AdaptiveTimeoutMiddleware {
1196 async fn before_publish(&self, message: &mut Message) -> Result<()> {
1197 let timeout = self.calculate_adaptive_timeout();
1198
1199 message.headers.extra.insert(
1201 "x-adaptive-timeout".to_string(),
1202 serde_json::json!(timeout.as_millis() as u64),
1203 );
1204
1205 message.headers.extra.insert(
1207 "x-timeout-percentile".to_string(),
1208 serde_json::json!(self.percentile),
1209 );
1210
1211 Ok(())
1212 }
1213
1214 async fn after_consume(&self, _message: &mut Message) -> Result<()> {
1215 Ok(())
1218 }
1219
1220 fn name(&self) -> &str {
1221 "adaptive_timeout"
1222 }
1223}