1use crate::time::Instant;
2use alloy_json_rpc::{RequestPacket, ResponsePacket};
3use core::time::Duration;
4use derive_more::{Deref, DerefMut};
5use futures::{stream::FuturesUnordered, StreamExt};
6use parking_lot::RwLock;
7use std::{
8 collections::{HashSet, VecDeque},
9 num::NonZeroUsize,
10 sync::Arc,
11 task::{Context, Poll},
12};
13use tower::{Layer, Service};
14use tracing::trace;
15
16use crate::{TransportError, TransportErrorKind, TransportFut};
17
/// Weight of the success-rate ("stability") component in a transport's score.
const STABILITY_WEIGHT: f64 = 0.7;
/// Weight of the latency component in a transport's score.
const LATENCY_WEIGHT: f64 = 0.3;
/// Maximum number of recent samples kept per transport in each rolling window.
const DEFAULT_SAMPLE_COUNT: usize = 10;
/// Number of highest-scoring transports used per request by `FallbackLayer::default`.
const DEFAULT_ACTIVE_TRANSPORT_COUNT: usize = 3;
23
/// A [`Service`] that dispatches each request to a ranked subset of inner transports.
///
/// Transports are ranked by a score derived from their recent success rate and latency.
/// Most requests are sent concurrently to the top-ranked transports, and the first success
/// wins. Requests that contain a method listed in `sequential_methods` are instead tried on
/// one transport at a time, in rank order.
#[derive(Debug, Clone)]
pub struct FallbackService<S> {
    /// Inner transports together with their shared rolling metrics, in constructor order.
    transports: Arc<Vec<ScoredTransport<S>>>,
    /// How many of the highest-scoring transports are used for a single request.
    active_transport_count: usize,
    /// Method names that force the sequential (one transport at a time) path.
    sequential_methods: Arc<HashSet<String>>,
}
39
40impl<S: Clone> FallbackService<S> {
41 pub fn new(transports: Vec<S>, active_transport_count: usize) -> Self {
49 Self::new_with_sequential_methods(
50 transports,
51 active_transport_count,
52 default_sequential_methods(),
53 )
54 }
55
56 pub fn new_with_sequential_methods(
64 transports: Vec<S>,
65 active_transport_count: usize,
66 sequential_methods: HashSet<String>,
67 ) -> Self {
68 let scored_transports = transports
69 .into_iter()
70 .enumerate()
71 .map(|(id, transport)| ScoredTransport::new(id, transport))
72 .collect::<Vec<_>>();
73
74 Self {
75 transports: Arc::new(scored_transports),
76 active_transport_count,
77 sequential_methods: Arc::new(sequential_methods),
78 }
79 }
80
81 pub fn append_sequential_method(mut self, sequential_method: impl Into<String>) -> Self {
83 let mut methods = Arc::unwrap_or_clone(self.sequential_methods);
84 methods.insert(sequential_method.into());
85 self.sequential_methods = Arc::new(methods);
86 self
87 }
88
89 pub fn with_sequential_methods(mut self, sequential_methods: HashSet<String>) -> Self {
92 self.sequential_methods = Arc::new(sequential_methods);
93 self
94 }
95
96 fn log_transport_rankings(&self) {
98 if !tracing::enabled!(tracing::Level::TRACE) {
99 return;
100 }
101
102 let mut ranked: Vec<(usize, f64, String)> =
104 self.transports.iter().map(|t| (t.id, t.score(), t.metrics_summary())).collect();
105
106 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
107
108 trace!("Current transport rankings:");
109 for (idx, (id, _score, summary)) in ranked.iter().enumerate() {
110 trace!(" #{}: Transport[{}] - {}", idx + 1, id, summary);
111 }
112 }
113
114 fn top_transports(&self) -> Vec<ScoredTransport<S>> {
117 let mut transports_clone = (*self.transports).clone();
119 transports_clone.sort_by(|a, b| b.cmp(a));
120 transports_clone.truncate(self.active_transport_count);
121 transports_clone
122 }
123}
124
125impl<S> FallbackService<S>
126where
127 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
128 + Send
129 + Clone
130 + 'static,
131{
132 async fn make_request(&self, req: RequestPacket) -> Result<ResponsePacket, TransportError> {
150 if req.method_names().any(|name| self.sequential_methods.contains(name)) {
154 return self.make_request_sequential(req).await;
155 }
156
157 let top_transports = self.top_transports();
160
161 let mut futures = FuturesUnordered::new();
163
164 for mut transport in top_transports {
166 let req_clone = req.clone();
167
168 let future = async move {
169 let start = Instant::now();
170 let result = transport.call(req_clone).await;
171 trace!(
172 "Transport[{}] completed: latency={:?}, status={}",
173 transport.id,
174 start.elapsed(),
175 if result.is_ok() { "success" } else { "fail" }
176 );
177
178 (result, transport, start.elapsed())
179 };
180
181 futures.push(future);
182 }
183
184 let mut last_error = None;
186
187 while let Some((result, transport, duration)) = futures.next().await {
188 match result {
189 Ok(response) => {
190 transport.track_success(duration);
192
193 self.log_transport_rankings();
194
195 return Ok(response);
196 }
197 Err(error) => {
198 transport.track_failure();
200
201 last_error = Some(error);
202 }
203 }
204 }
205
206 Err(last_error.unwrap_or_else(|| {
207 TransportErrorKind::custom_str("All transport futures failed to complete")
208 }))
209 }
210
211 async fn make_request_sequential(
219 &self,
220 req: RequestPacket,
221 ) -> Result<ResponsePacket, TransportError> {
222 trace!("Using sequential fallback for method with non-deterministic results");
223
224 let top_transports = self.top_transports();
226
227 let mut last_error = None;
228
229 for mut transport in top_transports {
231 let req_clone = req.clone();
232 let start = Instant::now();
233
234 trace!("Trying transport[{}] sequentially", transport.id);
235
236 match transport.call(req_clone).await {
237 Ok(response) => {
238 transport.track_success(start.elapsed());
240 trace!("Transport[{}] succeeded in {:?}", transport.id, start.elapsed());
241 self.log_transport_rankings();
242 return Ok(response);
243 }
244 Err(error) => {
245 transport.track_failure();
247 trace!("Transport[{}] failed: {:?}, trying next", transport.id, error);
248 last_error = Some(error);
249 }
250 }
251 }
252
253 Err(last_error.unwrap_or_else(|| {
255 TransportErrorKind::custom_str("All transports failed for sequential request")
256 }))
257 }
258}
259
260impl<S> Service<RequestPacket> for FallbackService<S>
261where
262 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
263 + Send
264 + Sync
265 + Clone
266 + 'static,
267{
268 type Response = ResponsePacket;
269 type Error = TransportError;
270 type Future = TransportFut<'static>;
271
272 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
273 Poll::Ready(Ok(()))
275 }
276
277 fn call(&mut self, req: RequestPacket) -> Self::Future {
278 let this = self.clone();
279 Box::pin(async move { this.make_request(req).await })
280 }
281}
282
/// A [`Layer`] that wraps a `Vec` of transports in a [`FallbackService`].
///
/// `FallbackLayer::default()` uses `DEFAULT_ACTIVE_TRANSPORT_COUNT` active transports and the
/// default sequential method set. The builder methods adjust either setting.
#[derive(Debug, Clone)]
pub struct FallbackLayer {
    /// Number of highest-scoring transports used for a single request.
    active_transport_count: usize,
    /// Method names whose requests are tried on one transport at a time.
    sequential_methods: HashSet<String>,
}
309
310impl FallbackLayer {
311 pub const fn with_active_transport_count(mut self, count: NonZeroUsize) -> Self {
313 self.active_transport_count = count.get();
314 self
315 }
316
317 pub fn with_sequential_method(mut self, method: impl Into<String>) -> Self {
322 self.sequential_methods.insert(method.into());
323 self
324 }
325
326 pub fn with_sequential_methods(mut self, methods: HashSet<String>) -> Self {
331 self.sequential_methods = methods;
332 self
333 }
334
335 pub fn without_sequential_methods(mut self) -> Self {
340 self.sequential_methods.clear();
341 self
342 }
343}
344
345impl<S> Layer<Vec<S>> for FallbackLayer
346where
347 S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
348 + Send
349 + Clone
350 + 'static,
351{
352 type Service = FallbackService<S>;
353
354 fn layer(&self, inner: Vec<S>) -> Self::Service {
355 FallbackService::new_with_sequential_methods(
356 inner,
357 self.active_transport_count,
358 self.sequential_methods.clone(),
359 )
360 }
361}
362
363impl Default for FallbackLayer {
364 fn default() -> Self {
365 Self {
366 active_transport_count: DEFAULT_ACTIVE_TRANSPORT_COUNT,
367 sequential_methods: default_sequential_methods(),
368 }
369 }
370}
371
/// An inner transport paired with an id and its rolling performance metrics.
///
/// Derefs (mutably) to the wrapped transport, so `Service` methods can be called on it
/// directly. Clones share the same `Arc<RwLock<TransportMetrics>>`, so an outcome tracked on
/// a clone is visible to every other clone and to the service that owns the original.
#[derive(Debug, Clone, Deref, DerefMut)]
struct ScoredTransport<S> {
    /// The wrapped transport; target of `Deref`/`DerefMut`.
    #[deref]
    #[deref_mut]
    transport: S,
    /// Index of this transport in the list given to the service constructor.
    id: usize,
    /// Shared rolling metrics from which `score` is computed.
    metrics: Arc<RwLock<TransportMetrics>>,
}
395
396impl<S> ScoredTransport<S> {
397 fn new(id: usize, transport: S) -> Self {
399 Self { id, transport, metrics: Arc::new(Default::default()) }
400 }
401
402 fn score(&self) -> f64 {
404 let metrics = self.metrics.read();
405 metrics.calculate_score()
406 }
407
408 fn metrics_summary(&self) -> String {
410 let metrics = self.metrics.read();
411 metrics.get_summary()
412 }
413
414 fn track_success(&self, duration: Duration) {
416 let mut metrics = self.metrics.write();
417 metrics.track_success(duration);
418 }
419
420 fn track_failure(&self) {
422 let mut metrics = self.metrics.write();
423 metrics.track_failure();
424 }
425}
426
427impl<S> PartialEq for ScoredTransport<S> {
428 fn eq(&self, other: &Self) -> bool {
429 self.score().eq(&other.score())
430 }
431}
432
433impl<S> Eq for ScoredTransport<S> {}
434
435#[expect(clippy::non_canonical_partial_ord_impl)]
436impl<S> PartialOrd for ScoredTransport<S> {
437 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
438 self.score().partial_cmp(&other.score())
439 }
440}
441
442impl<S> Ord for ScoredTransport<S> {
443 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
444 self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
445 }
446}
447
/// Rolling performance statistics for a single transport.
///
/// Only the two bounded sample windows feed into `calculate_score`; the counters and
/// timestamp are bookkeeping.
#[derive(Debug)]
struct TransportMetrics {
    /// Latencies of the most recent successful requests (at most `DEFAULT_SAMPLE_COUNT`).
    latencies: VecDeque<Duration>,
    /// Outcomes of the most recent requests, `true` = success (at most `DEFAULT_SAMPLE_COUNT`).
    successes: VecDeque<bool>,
    /// When the last outcome was recorded.
    last_update: Instant,
    /// Lifetime count of tracked requests; not used for scoring.
    total_requests: u64,
    /// Lifetime count of successful requests; not used for scoring.
    successful_requests: u64,
}
462
463impl TransportMetrics {
464 fn track_success(&mut self, duration: Duration) {
466 self.total_requests += 1;
467 self.successful_requests += 1;
468 self.last_update = Instant::now();
469
470 self.latencies.push_back(duration);
472 self.successes.push_back(true);
473
474 while self.latencies.len() > DEFAULT_SAMPLE_COUNT {
476 self.latencies.pop_front();
477 }
478 while self.successes.len() > DEFAULT_SAMPLE_COUNT {
479 self.successes.pop_front();
480 }
481 }
482
483 fn track_failure(&mut self) {
485 self.total_requests += 1;
486 self.last_update = Instant::now();
487
488 self.successes.push_back(false);
490
491 while self.successes.len() > DEFAULT_SAMPLE_COUNT {
493 self.successes.pop_front();
494 }
495 }
496
497 fn calculate_score(&self) -> f64 {
499 if self.successes.is_empty() {
501 return 0.0;
502 }
503
504 let success_count = self.successes.iter().filter(|&&s| s).count();
506 let stability_score = success_count as f64 / self.successes.len() as f64;
507
508 let latency_score = if !self.latencies.is_empty() {
510 let avg_latency = self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
511 / self.latencies.len() as f64;
512
513 1.0 / (1.0 + avg_latency)
515 } else {
516 0.0
517 };
518
519 (stability_score * STABILITY_WEIGHT) + (latency_score * LATENCY_WEIGHT)
521 }
522
523 fn get_summary(&self) -> String {
525 let success_rate = if !self.successes.is_empty() {
526 let success_count = self.successes.iter().filter(|&&s| s).count();
527 success_count as f64 / self.successes.len() as f64
528 } else {
529 0.0
530 };
531
532 let avg_latency = if !self.latencies.is_empty() {
533 self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
534 / self.latencies.len() as f64
535 } else {
536 0.0
537 };
538
539 format!(
540 "success_rate: {:.2}%, avg_latency: {:.2}ms, samples: {}, score: {:.4}",
541 success_rate * 100.0,
542 avg_latency * 1000.0,
543 self.successes.len(),
544 self.calculate_score()
545 )
546 }
547}
548
549impl Default for TransportMetrics {
550 fn default() -> Self {
551 Self {
552 latencies: VecDeque::new(),
553 successes: VecDeque::new(),
554 last_update: Instant::now(),
555 total_requests: 0,
556 successful_requests: 0,
557 }
558 }
559}
560
/// Default method names whose requests are tried on one transport at a time.
///
/// Presumably these are the `*Sync` send methods because different nodes may answer the
/// same submission differently (e.g. a result vs. "already known"). TODO: confirm the
/// rationale.
fn default_sequential_methods() -> HashSet<String> {
    const METHODS: [&str; 2] = ["eth_sendRawTransactionSync", "eth_sendTransactionSync"];
    METHODS.iter().map(|&method| method.to_owned()).collect()
}
584
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_json_rpc::{Id, Request, Response, ResponsePayload};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::time::{sleep, Duration};
    use tower::Service;

    /// Mock transport: sleeps for `delay`, then answers each request (or each batch entry)
    /// with a clone of the configured payload, echoing the request id.
    #[derive(Clone)]
    struct DelayedMockTransport {
        delay: Duration,
        // `None` makes every call fail with "No response configured".
        response: Arc<RwLock<Option<ResponsePayload>>>,
        // Shared across clones; bumped synchronously inside `call`, before the delay.
        call_count: Arc<AtomicUsize>,
    }

    impl DelayedMockTransport {
        fn new(delay: Duration, response: ResponsePayload) -> Self {
            Self {
                delay,
                response: Arc::new(RwLock::new(Some(response))),
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Number of times `call` has been invoked on this transport or any of its clones.
        fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl Service<RequestPacket> for DelayedMockTransport {
        type Response = ResponsePacket;
        type Error = TransportError;
        type Future = TransportFut<'static>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: RequestPacket) -> Self::Future {
            // Count at dispatch time, not completion, so cancelled attempts are still counted.
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let delay = self.delay;
            let response = self.response.clone();

            Box::pin(async move {
                sleep(delay).await;

                match req {
                    RequestPacket::Single(single) => {
                        let resp = response.read().clone().ok_or_else(|| {
                            TransportErrorKind::custom_str("No response configured")
                        })?;

                        Ok(ResponsePacket::Single(Response {
                            id: single.id().clone(),
                            payload: resp,
                        }))
                    }
                    RequestPacket::Batch(batch) => {
                        let resp = response.read().clone().ok_or_else(|| {
                            TransportErrorKind::custom_str("No response configured")
                        })?;

                        // Same payload for every entry, each tagged with its request id.
                        let responses = batch
                            .iter()
                            .map(|req| Response { id: req.id().clone(), payload: resp.clone() })
                            .collect();

                        Ok(ResponsePacket::Batch(responses))
                    }
                }
            })
        }
    }

    /// Builds a successful payload whose raw JSON is the string `data` (quoted).
    fn success_response(data: &str) -> ResponsePayload {
        let raw = serde_json::value::RawValue::from_string(format!("\"{}\"", data)).unwrap();
        ResponsePayload::Success(raw)
    }

    /// A default sequential method must only reach the first-ranked transport, even though
    /// the second one would answer faster.
    #[tokio::test]
    async fn test_non_deterministic_method_uses_sequential_fallback() {
        let transport_a = DelayedMockTransport::new(
            Duration::from_millis(50),
            success_response("0x1234567890abcdef"),
        );

        let transport_b = DelayedMockTransport::new(
            Duration::from_millis(10),
            success_response("already_known"),
        );

        let transports = vec![transport_a.clone(), transport_b.clone()];
        let mut fallback_service = FallbackService::new(transports, 2);

        let request = Request::new(
            "eth_sendRawTransactionSync",
            Id::Number(1),
            [serde_json::Value::String("0xabcdef".to_string())],
        );
        let serialized = request.serialize().unwrap();
        let request_packet = RequestPacket::Single(serialized);

        let start = std::time::Instant::now();
        let response = fallback_service.call(request_packet).await.unwrap();
        let elapsed = start.elapsed();

        let result = match response {
            ResponsePacket::Single(resp) => match resp.payload {
                ResponsePayload::Success(data) => data.get().to_string(),
                ResponsePayload::Failure(err) => panic!("Unexpected error: {:?}", err),
            },
            ResponsePacket::Batch(_) => panic!("Unexpected batch response"),
        };

        // With no history both transports score 0.0; the ranking keeps constructor order,
        // so A is tried first and, having succeeded, B is never contacted.
        assert_eq!(transport_a.call_count(), 1, "First transport should be called");
        assert_eq!(transport_b.call_count(), 0, "Second transport should NOT be called");

        assert_eq!(result, "\"0x1234567890abcdef\"");

        // A sleeps 50ms; the lower bound has slack for timer granularity.
        assert!(
            elapsed >= Duration::from_millis(40),
            "Should wait for first transport: {:?}",
            elapsed
        );
    }

    /// A method outside the sequential set is raced across both transports, and the
    /// faster one decides the latency.
    #[tokio::test]
    async fn test_deterministic_method_uses_parallel_execution() {
        let tx_hash = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef";

        let transport_a = DelayedMockTransport::new(
            Duration::from_millis(100),
            success_response(tx_hash),
        );

        let transport_b = DelayedMockTransport::new(
            Duration::from_millis(20),
            success_response(tx_hash),
        );

        let transports = vec![transport_a.clone(), transport_b.clone()];
        let mut fallback_service = FallbackService::new(transports, 2);

        let request = Request::new(
            "eth_sendRawTransaction",
            Id::Number(1),
            [serde_json::Value::String("0xabcdef".to_string())],
        );
        let serialized = request.serialize().unwrap();
        let request_packet = RequestPacket::Single(serialized);

        let start = std::time::Instant::now();
        let response = fallback_service.call(request_packet).await.unwrap();
        let elapsed = start.elapsed();

        let result = match response {
            ResponsePacket::Single(resp) => match resp.payload {
                ResponsePayload::Success(data) => data.get().to_string(),
                ResponsePayload::Failure(err) => panic!("Unexpected error: {:?}", err),
            },
            ResponsePacket::Batch(_) => panic!("Unexpected batch response"),
        };

        // Both were dispatched (the count is taken in `call`), although A's result is dropped.
        assert_eq!(transport_a.call_count(), 1, "Transport A should be called");
        assert_eq!(transport_b.call_count(), 1, "Transport B should be called");

        assert_eq!(result, format!("\"{}\"", tx_hash));

        // B answers in ~20ms; waiting on A (100ms) would blow this bound.
        assert!(
            elapsed < Duration::from_millis(50),
            "Should use parallel execution and return fast: {:?}",
            elapsed
        );
    }

    /// A batch containing even one sequential method takes the sequential path as a whole.
    #[tokio::test]
    async fn test_batch_with_any_sequential_method_uses_sequential_execution() {
        let tx_hash = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef";

        let transport_a =
            DelayedMockTransport::new(Duration::from_millis(10), success_response(tx_hash));

        let transport_b = DelayedMockTransport::new(
            Duration::from_millis(10),
            success_response("should_not_be_called"),
        );

        let transports = vec![transport_a.clone(), transport_b.clone()];
        let mut fallback_service = FallbackService::new(transports, 2);

        // `eth_blockNumber` alone would be raced; `eth_sendRawTransactionSync` forces sequential.
        let request1 = Request::new("eth_blockNumber", Id::Number(1), ());
        let request2 = Request::new(
            "eth_sendRawTransactionSync",
            Id::Number(2),
            [serde_json::Value::String("0xabcdef".to_string())],
        );

        let batch = vec![request1.serialize().unwrap(), request2.serialize().unwrap()];
        let request_packet = RequestPacket::Batch(batch);

        let start = std::time::Instant::now();
        let response = fallback_service.call(request_packet).await.unwrap();
        let elapsed = start.elapsed();

        assert_eq!(
            transport_a.call_count(),
            1,
            "Transport A should be called once (first in sequence)"
        );
        assert_eq!(
            transport_b.call_count(),
            0,
            "Transport B should NOT be called (transport A succeeded)"
        );

        match response {
            ResponsePacket::Batch(responses) => {
                assert_eq!(responses.len(), 2, "Should get 2 responses in batch");
                for resp in responses {
                    match resp.payload {
                        ResponsePayload::Success(_) => {}
                        ResponsePayload::Failure(err) => panic!("Unexpected error: {:?}", err),
                    }
                }
            }
            ResponsePacket::Single(_) => panic!("Expected batch response"),
        }

        assert!(
            elapsed < Duration::from_millis(50),
            "Sequential execution with fast first transport should be quick: {:?}",
            elapsed
        );
    }

    /// `with_sequential_methods` replaces the set, so a custom method becomes sequential.
    #[tokio::test]
    async fn test_custom_sequential_method() {
        let transport_a =
            DelayedMockTransport::new(Duration::from_millis(10), success_response("result_a"));

        let transport_b =
            DelayedMockTransport::new(Duration::from_millis(10), success_response("result_b"));

        let transports = vec![transport_a.clone(), transport_b.clone()];

        let custom_methods = ["my_custom_method".to_string()].into_iter().collect();
        let mut fallback_service =
            FallbackService::new(transports, 2).with_sequential_methods(custom_methods);

        let request = Request::new("my_custom_method", Id::Number(1), ());
        let serialized = request.serialize().unwrap();
        let request_packet = RequestPacket::Single(serialized);

        let start = std::time::Instant::now();
        let _response = fallback_service.call(request_packet).await.unwrap();
        let elapsed = start.elapsed();

        assert_eq!(
            transport_a.call_count(),
            1,
            "Transport A should be called once (sequential, first transport)"
        );
        assert_eq!(
            transport_b.call_count(),
            0,
            "Transport B should NOT be called (sequential mode, A succeeded)"
        );

        assert!(
            elapsed < Duration::from_millis(50),
            "Sequential execution with fast first transport: {:?}",
            elapsed
        );
    }
}