//! Transport middleware: request/response hooks, chain execution with
//! priorities and a shared per-request context, plus built-in middlewares for
//! logging, auth, retries, rate limiting, circuit breaking, metrics, and
//! compression.

use crate::error::Result;
use crate::shared::TransportMessage;
use crate::types::{JSONRPCRequest, JSONRPCResponse};
use async_trait::async_trait;
use dashmap::DashMap;
use parking_lot::RwLock;
use std::fmt;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

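/// Per-request state shared by every middleware in a chain: a request ID,
/// free-form metadata, shared performance counters, the start time, and an
/// optional message priority.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let ctx = MiddlewareContext::with_request_id("req-1".to_string());
/// ctx.set_metadata("method".to_string(), "initialize".to_string());
/// ctx.record_metric("queue_depth".to_string(), 3.0);
/// assert_eq!(ctx.get_metadata("method").as_deref(), Some("initialize"));
/// ```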
#[derive(Debug, Clone)]
pub struct MiddlewareContext {
    /// Identifier of the request being processed, if known.
    pub request_id: Option<String>,
    /// Arbitrary string metadata shared between middlewares.
    pub metadata: Arc<DashMap<String, String>>,
    /// Performance counters shared between middlewares.
    pub metrics: Arc<PerformanceMetrics>,
    /// When processing of this request started.
    pub start_time: Instant,
    /// Optional transport-level priority for the message.
    pub priority: Option<crate::shared::transport::MessagePriority>,
}

impl Default for MiddlewareContext {
    fn default() -> Self {
        Self {
            request_id: None,
            metadata: Arc::new(DashMap::new()),
            metrics: Arc::new(PerformanceMetrics::new()),
            start_time: Instant::now(),
            priority: None,
        }
    }
}

impl MiddlewareContext {
    /// Create a context tagged with a request ID.
    pub fn with_request_id(request_id: String) -> Self {
        Self {
            request_id: Some(request_id),
            ..Default::default()
        }
    }

    /// Store a metadata entry for later middlewares.
    pub fn set_metadata(&self, key: String, value: String) {
        self.metadata.insert(key, value);
    }

    /// Look up a metadata entry, cloning the value out of the map.
    pub fn get_metadata(&self, key: &str) -> Option<String> {
        self.metadata.get(key).map(|entry| entry.value().clone())
    }

    /// Record a named metric on the shared performance counters.
    pub fn record_metric(&self, name: String, value: f64) {
        self.metrics.record(name, value);
    }

    /// Time elapsed since the context was created.
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }
}

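/// Lock-free performance counters shared across a middleware chain: named
/// gauge-style metrics plus request, error, and cumulative-time counters.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let metrics = PerformanceMetrics::new();
/// metrics.inc_requests();
/// metrics.add_time(Duration::from_micros(250));
/// assert_eq!(metrics.request_count(), 1);
/// assert_eq!(metrics.average_time(), Duration::from_micros(250));
/// ```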
#[derive(Debug, Default)]
pub struct PerformanceMetrics {
    /// Named gauge-style metrics (last written value wins).
    metrics: DashMap<String, f64>,
    /// Number of requests processed.
    request_count: AtomicU64,
    /// Number of errors observed.
    error_count: AtomicU64,
    /// Total processing time in microseconds.
    total_time_us: AtomicU64,
}

impl PerformanceMetrics {
    pub fn new() -> Self {
        Self::default()
    }

    /// Record (or overwrite) a named metric value.
    pub fn record(&self, name: String, value: f64) {
        self.metrics.insert(name, value);
    }

    /// Read back a named metric value, if present.
    pub fn get(&self, name: &str) -> Option<f64> {
        self.metrics.get(name).map(|v| *v)
    }

    /// Increment the request counter.
    pub fn inc_requests(&self) {
        self.request_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Increment the error counter.
    pub fn inc_errors(&self) {
        self.error_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Add processing time to the cumulative total.
    pub fn add_time(&self, duration: Duration) {
        self.total_time_us
            .fetch_add(duration.as_micros() as u64, Ordering::Relaxed);
    }

    /// Total number of requests recorded.
    pub fn request_count(&self) -> u64 {
        self.request_count.load(Ordering::Relaxed)
    }

    /// Total number of errors recorded.
    pub fn error_count(&self) -> u64 {
        self.error_count.load(Ordering::Relaxed)
    }

    /// Average processing time per request, or zero if none were recorded.
    pub fn average_time(&self) -> Duration {
        let total_time = self.total_time_us.load(Ordering::Relaxed);
        let count = self.request_count.load(Ordering::Relaxed);
        if count > 0 {
            Duration::from_micros(total_time / count)
        } else {
            Duration::ZERO
        }
    }
}

/// Execution priority of a middleware; lower values run earlier in an
/// auto-sorted chain.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum MiddlewarePriority {
    /// Highest priority; runs first.
    Critical = 0,
    /// Runs before normal middleware (used by the rate limiter and circuit breaker).
    High = 1,
    /// Default priority.
    Normal = 2,
    /// Runs after normal middleware (used by the metrics middleware).
    Low = 3,
    /// Lowest priority; runs last.
    Lowest = 4,
}

impl Default for MiddlewarePriority {
    fn default() -> Self {
        Self::Normal
    }
}

/// Extended middleware interface with priorities, conditional execution, a
/// shared [`MiddlewareContext`], and chain/error lifecycle hooks. Every hook
/// has a no-op default, so implementors only override what they need.
#[async_trait]
pub trait AdvancedMiddleware: Send + Sync {
    /// Priority that determines this middleware's position in an auto-sorted chain.
    fn priority(&self) -> MiddlewarePriority {
        MiddlewarePriority::Normal
    }

    /// Human-readable name used for diagnostics.
    fn name(&self) -> &'static str {
        "unknown"
    }

    /// Whether this middleware should run for the given context.
    async fn should_execute(&self, _context: &MiddlewareContext) -> bool {
        true
    }

    /// Called for each outgoing request.
    async fn on_request_with_context(
        &self,
        request: &mut JSONRPCRequest,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let _ = (request, context);
        Ok(())
    }

    /// Called for each incoming response.
    async fn on_response_with_context(
        &self,
        response: &mut JSONRPCResponse,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let _ = (response, context);
        Ok(())
    }

    /// Called before a transport message is sent.
    async fn on_send_with_context(
        &self,
        message: &TransportMessage,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let _ = (message, context);
        Ok(())
    }

    /// Called after a transport message is received.
    async fn on_receive_with_context(
        &self,
        message: &TransportMessage,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let _ = (message, context);
        Ok(())
    }

    /// Called once before the chain starts processing a request.
    async fn on_chain_start(&self, _context: &MiddlewareContext) -> Result<()> {
        Ok(())
    }

    /// Called once after the chain finishes processing a request.
    async fn on_chain_complete(&self, _context: &MiddlewareContext) -> Result<()> {
        Ok(())
    }

    /// Called when any middleware in the chain returns an error.
    async fn on_error(
        &self,
        _error: &crate::error::Error,
        _context: &MiddlewareContext,
    ) -> Result<()> {
        Ok(())
    }
}

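/// Simpler, context-free middleware interface. All hooks default to no-ops.
///
/// A sketch of a custom implementation (illustrative only, not compiled as a
/// doctest; `TimingMiddleware` is a hypothetical name):
///
/// ```ignore
/// struct TimingMiddleware;
///
/// #[async_trait]
/// impl Middleware for TimingMiddleware {
///     async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
///         tracing::debug!("outgoing request: {}", request.method);
///         Ok(())
///     }
/// }
/// ```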
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Called for each outgoing request.
    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
        let _ = request;
        Ok(())
    }

    /// Called for each incoming response.
    async fn on_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
        let _ = response;
        Ok(())
    }

    /// Called before a transport message is sent.
    async fn on_send(&self, message: &TransportMessage) -> Result<()> {
        let _ = message;
        Ok(())
    }

    /// Called after a transport message is received.
    async fn on_receive(&self, message: &TransportMessage) -> Result<()> {
        let _ = message;
        Ok(())
    }
}

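/// Middleware chain for [`AdvancedMiddleware`] implementations: optionally
/// auto-sorted by priority, runs chain-start/complete hooks, notifies every
/// middleware's `on_error` on failure, and records timing on the context metrics.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest;
/// `request` is assumed to be a `JSONRPCRequest` built elsewhere):
///
/// ```ignore
/// let mut chain = EnhancedMiddlewareChain::new();
/// chain.add(Arc::new(RateLimitMiddleware::new(10, 10, Duration::from_secs(1))));
/// chain.add(Arc::new(MetricsMiddleware::new("demo".to_string())));
///
/// let context = MiddlewareContext::with_request_id("req-1".to_string());
/// chain.process_request_with_context(&mut request, &context).await?;
/// ```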
pub struct EnhancedMiddlewareChain {
    middlewares: Vec<Arc<dyn AdvancedMiddleware>>,
    auto_sort: bool,
}

impl fmt::Debug for EnhancedMiddlewareChain {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EnhancedMiddlewareChain")
            .field("count", &self.middlewares.len())
            .field("auto_sort", &self.auto_sort)
            .finish()
    }
}

impl Default for EnhancedMiddlewareChain {
    fn default() -> Self {
        Self::new()
    }
}

impl EnhancedMiddlewareChain {
    /// Create a chain that automatically sorts middlewares by priority.
    pub fn new() -> Self {
        Self {
            middlewares: Vec::new(),
            auto_sort: true,
        }
    }

    /// Create a chain that preserves insertion order.
    pub fn new_no_sort() -> Self {
        Self {
            middlewares: Vec::new(),
            auto_sort: false,
        }
    }

    /// Add a middleware, re-sorting by priority if auto-sort is enabled.
    pub fn add(&mut self, middleware: Arc<dyn AdvancedMiddleware>) {
        self.middlewares.push(middleware);
        if self.auto_sort {
            self.sort_by_priority();
        }
    }

    /// Sort middlewares so that lower priority values run first.
    pub fn sort_by_priority(&mut self) {
        self.middlewares.sort_by_key(|m| m.priority());
    }

    /// Number of middlewares in the chain.
    pub fn len(&self) -> usize {
        self.middlewares.len()
    }

    /// Whether the chain contains no middlewares.
    pub fn is_empty(&self) -> bool {
        self.middlewares.is_empty()
    }

    /// Run every applicable middleware's request hooks, surrounding them with
    /// the chain-start and chain-complete hooks and notifying `on_error` on failure.
    pub async fn process_request_with_context(
        &self,
        request: &mut JSONRPCRequest,
        context: &MiddlewareContext,
    ) -> Result<()> {
        context.metrics.inc_requests();
        let start_time = Instant::now();

        for middleware in &self.middlewares {
            if middleware.should_execute(context).await {
                middleware.on_chain_start(context).await?;
            }
        }

        for middleware in &self.middlewares {
            if middleware.should_execute(context).await {
                if let Err(e) = middleware.on_request_with_context(request, context).await {
                    context.metrics.inc_errors();
                    for m in &self.middlewares {
                        if m.should_execute(context).await {
                            let _ = m.on_error(&e, context).await;
                        }
                    }
                    return Err(e);
                }
            }
        }

        for middleware in &self.middlewares {
            if middleware.should_execute(context).await {
                middleware.on_chain_complete(context).await?;
            }
        }

        context.metrics.add_time(start_time.elapsed());
        Ok(())
    }

    /// Run every applicable middleware's response hooks in reverse order.
    pub async fn process_response_with_context(
        &self,
        response: &mut JSONRPCResponse,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let start_time = Instant::now();

        for middleware in self.middlewares.iter().rev() {
            if middleware.should_execute(context).await {
                if let Err(e) = middleware.on_response_with_context(response, context).await {
                    context.metrics.inc_errors();
                    for m in &self.middlewares {
                        if m.should_execute(context).await {
                            let _ = m.on_error(&e, context).await;
                        }
                    }
                    return Err(e);
                }
            }
        }

        context.metrics.add_time(start_time.elapsed());
        Ok(())
    }

    /// Run every applicable middleware's send hooks.
    pub async fn process_send_with_context(
        &self,
        message: &TransportMessage,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let start_time = Instant::now();

        for middleware in &self.middlewares {
            if middleware.should_execute(context).await {
                if let Err(e) = middleware.on_send_with_context(message, context).await {
                    context.metrics.inc_errors();
                    for m in &self.middlewares {
                        if m.should_execute(context).await {
                            let _ = m.on_error(&e, context).await;
                        }
                    }
                    return Err(e);
                }
            }
        }

        context.metrics.add_time(start_time.elapsed());
        Ok(())
    }

    /// Run every applicable middleware's receive hooks.
    pub async fn process_receive_with_context(
        &self,
        message: &TransportMessage,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let start_time = Instant::now();

        for middleware in &self.middlewares {
            if middleware.should_execute(context).await {
                if let Err(e) = middleware.on_receive_with_context(message, context).await {
                    context.metrics.inc_errors();
                    for m in &self.middlewares {
                        if m.should_execute(context).await {
                            let _ = m.on_error(&e, context).await;
                        }
                    }
                    return Err(e);
                }
            }
        }

        context.metrics.add_time(start_time.elapsed());
        Ok(())
    }

    /// Collect per-middleware metrics. The chain itself does not retain any
    /// (they live on each request's [`MiddlewareContext`]), so this currently
    /// returns an empty list.
    pub fn get_metrics(&self) -> Vec<Arc<PerformanceMetrics>> {
        Vec::new()
    }
}

/// Middleware chain for the simpler [`Middleware`] trait; runs middlewares in
/// insertion order and stops at the first error.
pub struct MiddlewareChain {
    middlewares: Vec<Arc<dyn Middleware>>,
}

impl fmt::Debug for MiddlewareChain {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MiddlewareChain")
            .field("count", &self.middlewares.len())
            .finish()
    }
}

impl Default for MiddlewareChain {
    fn default() -> Self {
        Self::new()
    }
}

impl MiddlewareChain {
    /// Create an empty chain.
    pub fn new() -> Self {
        Self {
            middlewares: Vec::new(),
        }
    }

    /// Append a middleware to the chain.
    pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
        self.middlewares.push(middleware);
    }

    /// Run every middleware's request hook.
    pub async fn process_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
        for middleware in &self.middlewares {
            middleware.on_request(request).await?;
        }
        Ok(())
    }

    /// Run every middleware's response hook.
    pub async fn process_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
        for middleware in &self.middlewares {
            middleware.on_response(response).await?;
        }
        Ok(())
    }

    /// Run every middleware's send hook.
    pub async fn process_send(&self, message: &TransportMessage) -> Result<()> {
        for middleware in &self.middlewares {
            middleware.on_send(message).await?;
        }
        Ok(())
    }

    /// Run every middleware's receive hook.
    pub async fn process_receive(&self, message: &TransportMessage) -> Result<()> {
        for middleware in &self.middlewares {
            middleware.on_receive(message).await?;
        }
        Ok(())
    }
}

/// Middleware that logs requests and responses at a configurable level.
#[derive(Debug)]
pub struct LoggingMiddleware {
    level: tracing::Level,
}

impl LoggingMiddleware {
    /// Create the middleware with the level to log at.
    pub fn new(level: tracing::Level) -> Self {
        Self { level }
    }
}

impl Default for LoggingMiddleware {
    fn default() -> Self {
        Self::new(tracing::Level::DEBUG)
    }
}

#[async_trait]
impl Middleware for LoggingMiddleware {
    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
        // `tracing::Level`'s levels are associated constants and cannot be used
        // as match patterns, so compare explicitly.
        if self.level == tracing::Level::TRACE {
            tracing::trace!("Sending request: {:?}", request);
        } else if self.level == tracing::Level::DEBUG {
            tracing::debug!("Sending request: {}", request.method);
        } else if self.level == tracing::Level::INFO {
            tracing::info!("Sending request: {}", request.method);
        } else if self.level == tracing::Level::WARN {
            tracing::warn!("Sending request: {}", request.method);
        } else {
            tracing::error!("Sending request: {}", request.method);
        }
        Ok(())
    }

    async fn on_response(&self, response: &mut JSONRPCResponse) -> Result<()> {
        if self.level == tracing::Level::TRACE {
            tracing::trace!("Received response: {:?}", response);
        } else if self.level == tracing::Level::DEBUG {
            tracing::debug!("Received response for: {:?}", response.id);
        } else if self.level == tracing::Level::INFO {
            tracing::info!("Received response");
        } else if self.level == tracing::Level::WARN {
            tracing::warn!("Received response");
        } else {
            tracing::error!("Received response");
        }
        Ok(())
    }
}

/// Middleware that attaches authentication to outgoing requests.
#[derive(Debug)]
pub struct AuthMiddleware {
    #[allow(dead_code)]
    auth_token: String,
}

impl AuthMiddleware {
    /// Create the middleware with the token to attach.
    pub fn new(auth_token: String) -> Self {
        Self { auth_token }
    }
}

#[async_trait]
impl Middleware for AuthMiddleware {
    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
        // The stored token is not yet injected into the request; this hook
        // currently only logs that authentication would be applied.
        tracing::debug!("Adding authentication to request: {}", request.method);
        Ok(())
    }
}

/// Middleware that carries retry configuration for requests.
#[derive(Debug)]
pub struct RetryMiddleware {
    max_retries: u32,
    #[allow(dead_code)]
    initial_delay_ms: u64,
    #[allow(dead_code)]
    max_delay_ms: u64,
}

impl RetryMiddleware {
    /// Create the middleware with a retry budget and backoff bounds (in milliseconds).
    pub fn new(max_retries: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_retries,
            initial_delay_ms,
            max_delay_ms,
        }
    }
}

impl Default for RetryMiddleware {
    fn default() -> Self {
        Self::new(3, 1000, 30000)
    }
}

#[async_trait]
impl Middleware for RetryMiddleware {
    async fn on_request(&self, request: &mut JSONRPCRequest) -> Result<()> {
        // The backoff fields are currently unused here; this hook only logs
        // the configured retry budget.
        tracing::debug!(
            "Request {} configured with max {} retries",
            request.method,
            self.max_retries
        );
        Ok(())
    }
}

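/// Token-bucket rate limiter implemented as an [`AdvancedMiddleware`]: the
/// bucket starts full, each request consumes one token, and `max_requests`
/// tokens are added back per `refill_duration` (capped at `bucket_size`).
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest;
/// `request` is assumed to be a `JSONRPCRequest` built elsewhere):
///
/// ```ignore
/// // Up to 5 requests per second, with a burst capacity of 10.
/// let limiter = RateLimitMiddleware::new(5, 10, Duration::from_secs(1));
/// let context = MiddlewareContext::default();
/// limiter.on_request_with_context(&mut request, &context).await?;
/// ```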
#[derive(Debug)]
pub struct RateLimitMiddleware {
    /// Tokens added back per refill interval.
    max_requests: u32,
    /// Maximum number of tokens the bucket can hold.
    bucket_size: u32,
    /// How often tokens are replenished.
    refill_duration: Duration,
    /// Tokens currently available.
    tokens: Arc<AtomicUsize>,
    /// When the bucket was last refilled.
    last_refill: Arc<RwLock<Instant>>,
}

impl RateLimitMiddleware {
    /// Create a limiter; the bucket starts full with `bucket_size` tokens.
    pub fn new(max_requests: u32, bucket_size: u32, refill_duration: Duration) -> Self {
        Self {
            max_requests,
            bucket_size,
            refill_duration,
            tokens: Arc::new(AtomicUsize::new(bucket_size as usize)),
            last_refill: Arc::new(RwLock::new(Instant::now())),
        }
    }

    /// Refill the bucket if a refill interval has elapsed, then try to take a token.
    fn check_rate_limit(&self) -> bool {
        let now = Instant::now();
        let mut last_refill = self.last_refill.write();
        let elapsed = now.duration_since(*last_refill);

        if elapsed >= self.refill_duration {
            let refill_count = (elapsed.as_millis() / self.refill_duration.as_millis()) as u32;
            let tokens_to_add = (refill_count * self.max_requests).min(self.bucket_size);

            self.tokens.store(
                (self.tokens.load(Ordering::Relaxed) + tokens_to_add as usize)
                    .min(self.bucket_size as usize),
                Ordering::Relaxed,
            );
            *last_refill = now;
        }

        // Atomically take a token if one is available.
        loop {
            let current = self.tokens.load(Ordering::Relaxed);
            if current == 0 {
                return false;
            }
            if self
                .tokens
                .compare_exchange_weak(current, current - 1, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
            {
                return true;
            }
        }
    }
}

#[async_trait]
impl AdvancedMiddleware for RateLimitMiddleware {
    fn name(&self) -> &'static str {
        "rate_limit"
    }

    fn priority(&self) -> MiddlewarePriority {
        MiddlewarePriority::High
    }

    async fn on_request_with_context(
        &self,
        request: &mut JSONRPCRequest,
        context: &MiddlewareContext,
    ) -> Result<()> {
        if !self.check_rate_limit() {
            tracing::warn!("Rate limit exceeded for request: {}", request.method);
            context.record_metric("rate_limit_exceeded".to_string(), 1.0);
            return Err(crate::error::Error::RateLimited);
        }

        tracing::debug!("Rate limit check passed for request: {}", request.method);
        context.record_metric("rate_limit_passed".to_string(), 1.0);
        Ok(())
    }
}

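/// Circuit breaker implemented as an [`AdvancedMiddleware`]: failures reported
/// via `on_error` are counted within `time_window`; once `failure_threshold`
/// is reached the circuit opens and requests are rejected until
/// `timeout_duration` has passed, after which the counters reset.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest;
/// `chain` is assumed to be an `EnhancedMiddlewareChain`):
///
/// ```ignore
/// // Open after 5 failures within 30s; allow traffic again after 60s.
/// let breaker = CircuitBreakerMiddleware::new(
///     5,
///     Duration::from_secs(30),
///     Duration::from_secs(60),
/// );
/// chain.add(Arc::new(breaker));
/// ```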
#[derive(Debug)]
pub struct CircuitBreakerMiddleware {
    /// Number of failures within the window that opens the circuit.
    failure_threshold: u32,
    /// Window in which failures are counted.
    time_window: Duration,
    /// How long the circuit stays open before traffic is allowed again.
    timeout_duration: Duration,
    /// Failures observed in the current window.
    failure_count: Arc<AtomicU64>,
    /// When the most recent failure was recorded.
    last_failure: Arc<RwLock<Option<Instant>>>,
    /// When the circuit was opened, if it is currently open.
    circuit_open_time: Arc<RwLock<Option<Instant>>>,
}

impl CircuitBreakerMiddleware {
    /// Create a circuit breaker with the given threshold, failure window, and open timeout.
    pub fn new(failure_threshold: u32, time_window: Duration, timeout_duration: Duration) -> Self {
        Self {
            failure_threshold,
            time_window,
            timeout_duration,
            failure_count: Arc::new(AtomicU64::new(0)),
            last_failure: Arc::new(RwLock::new(None)),
            circuit_open_time: Arc::new(RwLock::new(None)),
        }
    }

    fn should_allow_request(&self) -> bool {
        let now = Instant::now();

        // If the circuit is open, only allow traffic again once the timeout has passed.
        let open_time_value = *self.circuit_open_time.read();
        if let Some(open_time) = open_time_value {
            if now.duration_since(open_time) > self.timeout_duration {
                *self.circuit_open_time.write() = None;
                self.failure_count.store(0, Ordering::Relaxed);
                return true;
            }
            return false;
        }

        // Reset the failure count once the failure window has elapsed.
        let last_failure_value = *self.last_failure.read();
        if let Some(last_failure) = last_failure_value {
            if now.duration_since(last_failure) > self.time_window {
                self.failure_count.store(0, Ordering::Relaxed);
            }
        }

        self.failure_count.load(Ordering::Relaxed) < self.failure_threshold as u64
    }

    fn record_failure(&self) {
        let now = Instant::now();
        let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        *self.last_failure.write() = Some(now);

        if failures >= self.failure_threshold as u64 {
            *self.circuit_open_time.write() = Some(now);
            tracing::warn!("Circuit breaker opened due to {} failures", failures);
        }
    }
}

#[async_trait]
impl AdvancedMiddleware for CircuitBreakerMiddleware {
    fn name(&self) -> &'static str {
        "circuit_breaker"
    }

    fn priority(&self) -> MiddlewarePriority {
        MiddlewarePriority::High
    }

    async fn on_request_with_context(
        &self,
        request: &mut JSONRPCRequest,
        context: &MiddlewareContext,
    ) -> Result<()> {
        if !self.should_allow_request() {
            tracing::warn!(
                "Circuit breaker open, rejecting request: {}",
                request.method
            );
            context.record_metric("circuit_breaker_open".to_string(), 1.0);
            return Err(crate::error::Error::CircuitBreakerOpen);
        }

        context.record_metric("circuit_breaker_allowed".to_string(), 1.0);
        Ok(())
    }

    async fn on_error(
        &self,
        _error: &crate::error::Error,
        _context: &MiddlewareContext,
    ) -> Result<()> {
        self.record_failure();
        Ok(())
    }
}

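/// Metrics collector implemented as an [`AdvancedMiddleware`]: tracks per-method
/// request counts, cumulative durations, and error counts for a named service.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest;
/// `chain` is assumed to be an `EnhancedMiddlewareChain`):
///
/// ```ignore
/// let metrics = Arc::new(MetricsMiddleware::new("mcp-server".to_string()));
/// chain.add(metrics.clone());
/// // ... after some traffic:
/// let initialize_count = metrics.get_request_count("initialize");
/// let initialize_avg_us = metrics.get_average_duration("initialize");
/// ```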
#[derive(Debug)]
pub struct MetricsMiddleware {
    /// Service name attached to recorded metrics.
    service_name: String,
    /// Per-method request counts.
    request_counts: Arc<DashMap<String, AtomicU64>>,
    /// Per-method cumulative durations in microseconds.
    request_durations: Arc<DashMap<String, AtomicU64>>,
    /// Per-method error counts.
    error_counts: Arc<DashMap<String, AtomicU64>>,
}

impl MetricsMiddleware {
    /// Create a metrics middleware for the given service name.
    pub fn new(service_name: String) -> Self {
        Self {
            service_name,
            request_counts: Arc::new(DashMap::new()),
            request_durations: Arc::new(DashMap::new()),
            error_counts: Arc::new(DashMap::new()),
        }
    }

    /// Number of requests recorded for a method.
    pub fn get_request_count(&self, method: &str) -> u64 {
        self.request_counts
            .get(method)
            .map_or(0, |c| c.load(Ordering::Relaxed))
    }

    /// Number of errors recorded for a method.
    pub fn get_error_count(&self, method: &str) -> u64 {
        self.error_counts
            .get(method)
            .map_or(0, |c| c.load(Ordering::Relaxed))
    }

    /// Average request duration for a method in microseconds (0 if none recorded).
    pub fn get_average_duration(&self, method: &str) -> u64 {
        let total_duration = self
            .request_durations
            .get(method)
            .map_or(0, |d| d.load(Ordering::Relaxed));
        let count = self.get_request_count(method);
        if count > 0 {
            total_duration / count
        } else {
            0
        }
    }
}

#[async_trait]
impl AdvancedMiddleware for MetricsMiddleware {
    fn name(&self) -> &'static str {
        "metrics"
    }

    fn priority(&self) -> MiddlewarePriority {
        MiddlewarePriority::Low
    }

    async fn on_request_with_context(
        &self,
        request: &mut JSONRPCRequest,
        context: &MiddlewareContext,
    ) -> Result<()> {
        self.request_counts
            .entry(request.method.clone())
            .or_insert_with(|| AtomicU64::new(0))
            .fetch_add(1, Ordering::Relaxed);

        // Record the method name so the response and error hooks can attribute
        // durations and errors to it.
        context.set_metadata("method".to_string(), request.method.clone());
        context.set_metadata(
            "request_start_time".to_string(),
            context.start_time.elapsed().as_micros().to_string(),
        );
        context.set_metadata("service_name".to_string(), self.service_name.clone());

        tracing::debug!(
            "Metrics recorded for request: {} (service: {})",
            request.method,
            self.service_name
        );
        Ok(())
    }

    async fn on_response_with_context(
        &self,
        response: &mut JSONRPCResponse,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let duration_us = context.elapsed().as_micros() as u64;

        if let Some(method) = context.get_metadata("method") {
            self.request_durations
                .entry(method)
                .or_insert_with(|| AtomicU64::new(0))
                .fetch_add(duration_us, Ordering::Relaxed);
        }

        tracing::debug!(
            "Response metrics recorded for ID: {:?} ({}μs)",
            response.id,
            duration_us
        );
        Ok(())
    }

    async fn on_error(
        &self,
        error: &crate::error::Error,
        context: &MiddlewareContext,
    ) -> Result<()> {
        if let Some(method) = context.get_metadata("method") {
            self.error_counts
                .entry(method)
                .or_insert_with(|| AtomicU64::new(0))
                .fetch_add(1, Ordering::Relaxed);
        }

        tracing::warn!("Error recorded in metrics: {:?}", error);
        Ok(())
    }
}

/// Supported compression algorithms for outgoing messages.
#[derive(Debug, Clone, Copy)]
pub enum CompressionType {
    /// No compression.
    None,
    /// Gzip compression.
    Gzip,
    /// DEFLATE compression.
    Deflate,
}

/// Middleware that marks outgoing messages above a size threshold for compression.
#[derive(Debug)]
pub struct CompressionMiddleware {
    compression_type: CompressionType,
    min_size: usize,
}

impl CompressionMiddleware {
    /// Create the middleware with an algorithm and a minimum payload size in bytes.
    pub fn new(compression_type: CompressionType, min_size: usize) -> Self {
        Self {
            compression_type,
            min_size,
        }
    }

    fn should_compress(&self, content_size: usize) -> bool {
        content_size >= self.min_size && !matches!(self.compression_type, CompressionType::None)
    }
}

#[async_trait]
impl AdvancedMiddleware for CompressionMiddleware {
    fn name(&self) -> &'static str {
        "compression"
    }

    fn priority(&self) -> MiddlewarePriority {
        MiddlewarePriority::Normal
    }

    async fn on_send_with_context(
        &self,
        message: &TransportMessage,
        context: &MiddlewareContext,
    ) -> Result<()> {
        let serialized = serde_json::to_string(message).unwrap_or_default();
        let content_size = serialized.len();

        if self.should_compress(content_size) {
            // This hook only records compression metadata on the context; it
            // does not modify the message itself.
            context.set_metadata(
                "compression_type".to_string(),
                format!("{:?}", self.compression_type),
            );
            context.record_metric(
                "compression_original_size".to_string(),
                content_size as f64,
            );

            tracing::debug!("Compression applied to message of {} bytes", content_size);
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RequestId;

    #[tokio::test]
    async fn test_middleware_chain() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LoggingMiddleware::default()));

        let mut request = JSONRPCRequest {
            jsonrpc: "2.0".to_string(),
            id: RequestId::from(1i64),
            method: "test".to_string(),
            params: None,
        };

        assert!(chain.process_request(&mut request).await.is_ok());
    }
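
    // Usage sketch for the enhanced chain: middlewares of different priorities
    // are added in arbitrary order, a request is processed with a context, and
    // the context's shared counters reflect the processed request.
    #[tokio::test]
    async fn test_enhanced_middleware_chain_with_context() {
        let mut chain = EnhancedMiddlewareChain::new();
        chain.add(Arc::new(MetricsMiddleware::new("test".to_string())));
        chain.add(Arc::new(RateLimitMiddleware::new(
            10,
            10,
            Duration::from_secs(1),
        )));
        assert_eq!(chain.len(), 2);

        let context = MiddlewareContext::with_request_id("req-1".to_string());
        let mut request = JSONRPCRequest {
            jsonrpc: "2.0".to_string(),
            id: RequestId::from(1i64),
            method: "test".to_string(),
            params: None,
        };

        assert!(chain
            .process_request_with_context(&mut request, &context)
            .await
            .is_ok());
        assert_eq!(context.metrics.request_count(), 1);
    }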

    #[tokio::test]
    async fn test_auth_middleware() {
        let middleware = AuthMiddleware::new("test-token".to_string());

        let mut request = JSONRPCRequest {
            jsonrpc: "2.0".to_string(),
            id: RequestId::from(1i64),
            method: "test".to_string(),
            params: None,
        };

        assert!(middleware.on_request(&mut request).await.is_ok());
    }
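
    // Additional usage sketches for the advanced middleware types, written
    // against the APIs defined above.

    // Token-bucket sketch: with a two-token bucket and a long refill interval,
    // the third consecutive request should be rejected.
    #[tokio::test]
    async fn test_rate_limit_middleware_exhausts_bucket() {
        let limiter = RateLimitMiddleware::new(2, 2, Duration::from_secs(60));
        let context = MiddlewareContext::default();
        let mut request = JSONRPCRequest {
            jsonrpc: "2.0".to_string(),
            id: RequestId::from(1i64),
            method: "test".to_string(),
            params: None,
        };

        assert!(limiter
            .on_request_with_context(&mut request, &context)
            .await
            .is_ok());
        assert!(limiter
            .on_request_with_context(&mut request, &context)
            .await
            .is_ok());
        assert!(limiter
            .on_request_with_context(&mut request, &context)
            .await
            .is_err());
    }

    // Circuit-breaker sketch: after the failure threshold is reached via
    // `on_error`, subsequent requests are rejected while the circuit is open.
    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let breaker =
            CircuitBreakerMiddleware::new(2, Duration::from_secs(30), Duration::from_secs(60));
        let context = MiddlewareContext::default();
        let mut request = JSONRPCRequest {
            jsonrpc: "2.0".to_string(),
            id: RequestId::from(1i64),
            method: "test".to_string(),
            params: None,
        };

        assert!(breaker
            .on_request_with_context(&mut request, &context)
            .await
            .is_ok());

        // Any error variant works here; `RateLimited` is just a convenient stand-in.
        breaker
            .on_error(&crate::error::Error::RateLimited, &context)
            .await
            .unwrap();
        breaker
            .on_error(&crate::error::Error::RateLimited, &context)
            .await
            .unwrap();

        assert!(breaker
            .on_request_with_context(&mut request, &context)
            .await
            .is_err());
    }

    // Arithmetic check for the shared counters: two requests totalling 300µs
    // should average to 150µs.
    #[test]
    fn test_performance_metrics_average_time() {
        let metrics = PerformanceMetrics::new();
        metrics.inc_requests();
        metrics.inc_requests();
        metrics.add_time(Duration::from_micros(100));
        metrics.add_time(Duration::from_micros(200));
        assert_eq!(metrics.average_time(), Duration::from_micros(150));
        assert_eq!(metrics.error_count(), 0);
    }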
}