use crate::time::Instant;
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use core::time::Duration;
use derive_more::{Deref, DerefMut};
use futures::{stream::FuturesUnordered, StreamExt};
use parking_lot::RwLock;
use std::{
    collections::{HashSet, VecDeque},
    num::NonZeroUsize,
    sync::Arc,
    task::{Context, Poll},
};
use tower::{Layer, Service};
use tracing::trace;

use crate::{TransportError, TransportErrorKind, TransportFut};

/// Weight of the success-rate component in a transport's score.
const STABILITY_WEIGHT: f64 = 0.7;
/// Weight of the latency component in a transport's score.
const LATENCY_WEIGHT: f64 = 0.3;
/// Number of recent samples retained per transport for scoring.
const DEFAULT_SAMPLE_COUNT: usize = 10;
/// Default number of top-ranked transports queried for each request.
const DEFAULT_ACTIVE_TRANSPORT_COUNT: usize = 3;

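/// A [`Service`] that distributes requests across multiple transports and
/// falls back between them.
///
/// Transports are scored on their recent success rate and latency. For normal
/// requests, the top `active_transport_count` transports are raced in parallel
/// and the first successful response wins; requests containing a method marked
/// as sequential are instead tried one transport at a time.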
#[derive(Debug, Clone)]
pub struct FallbackService<S> {
    /// The scored transports, in construction order.
    transports: Arc<Vec<ScoredTransport<S>>>,
    /// How many of the highest-scoring transports are queried per request.
    active_transport_count: usize,
    /// RPC methods that are dispatched sequentially rather than raced.
    sequential_methods: Arc<HashSet<String>>,
}

impl<S: Clone> FallbackService<S> {
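    /// Creates a new fallback service from a list of transports.
    ///
    /// Up to `active_transport_count` of the highest-scoring transports are
    /// queried for each request, and the default sequential methods are used.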
    pub fn new(transports: Vec<S>, active_transport_count: usize) -> Self {
        Self::new_with_sequential_methods(
            transports,
            active_transport_count,
            default_sequential_methods(),
        )
    }

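    /// Creates a new fallback service with a custom set of methods that are
    /// executed sequentially across transports instead of being raced in
    /// parallel.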
    pub fn new_with_sequential_methods(
        transports: Vec<S>,
        active_transport_count: usize,
        sequential_methods: HashSet<String>,
    ) -> Self {
        let scored_transports = transports
            .into_iter()
            .enumerate()
            .map(|(id, transport)| ScoredTransport::new(id, transport))
            .collect::<Vec<_>>();

        Self {
            transports: Arc::new(scored_transports),
            active_transport_count,
            sequential_methods: Arc::new(sequential_methods),
        }
    }

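    /// Adds a method to the set of sequentially-executed methods.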
    pub fn append_sequential_method(mut self, sequential_method: impl Into<String>) -> Self {
        let mut methods = Arc::unwrap_or_clone(self.sequential_methods);
        methods.insert(sequential_method.into());
        self.sequential_methods = Arc::new(methods);
        self
    }

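    /// Replaces the set of sequentially-executed methods.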
    pub fn with_sequential_methods(mut self, sequential_methods: HashSet<String>) -> Self {
        self.sequential_methods = Arc::new(sequential_methods);
        self
    }

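    /// Logs the current transport rankings at TRACE level, highest score first.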
    fn log_transport_rankings(&self) {
        if !tracing::enabled!(tracing::Level::TRACE) {
            return;
        }

        let mut ranked: Vec<(usize, f64, String)> =
            self.transports.iter().map(|t| (t.id, t.score(), t.metrics_summary())).collect();

        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        trace!("Current transport rankings:");
        for (idx, (id, _score, summary)) in ranked.iter().enumerate() {
            trace!("  #{}: Transport[{}] - {}", idx + 1, id, summary);
        }
    }
}

impl<S> FallbackService<S>
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + Clone
        + 'static,
{
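    /// Dispatches a request using the fallback strategy.
    ///
    /// If the request contains any method registered as sequential, it is
    /// forwarded to [`Self::make_request_sequential`]. Otherwise the
    /// highest-scoring transports are raced in parallel and the first
    /// successful response is returned.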
    async fn make_request(&self, req: RequestPacket) -> Result<ResponsePacket, TransportError> {
        // Methods with non-deterministic results must not be raced in parallel.
        if req.method_names().any(|name| self.sequential_methods.contains(name)) {
            return self.make_request_sequential(req).await;
        }

        // Snapshot the highest-scoring transports.
        let top_transports = {
            let mut transports_clone = (*self.transports).clone();
            transports_clone.sort_by(|a, b| b.cmp(a));
            transports_clone.truncate(self.active_transport_count);
            transports_clone
        };

        let mut futures = FuturesUnordered::new();

        for mut transport in top_transports {
            let req_clone = req.clone();

            let future = async move {
                let start = Instant::now();
                let result = transport.call(req_clone).await;
                trace!(
                    "Transport[{}] completed: latency={:?}, status={}",
                    transport.id,
                    start.elapsed(),
                    if result.is_ok() { "success" } else { "fail" }
                );

                (result, transport, start.elapsed())
            };

            futures.push(future);
        }

        let mut last_error = None;

        // Return the first successful response; record metrics either way.
        while let Some((result, transport, duration)) = futures.next().await {
            match result {
                Ok(response) => {
                    transport.track_success(duration);

                    self.log_transport_rankings();

                    return Ok(response);
                }
                Err(error) => {
                    transport.track_failure();

                    last_error = Some(error);
                }
            }
        }

        Err(last_error.unwrap_or_else(|| {
            TransportErrorKind::custom_str("All transport futures failed to complete")
        }))
    }

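    /// Dispatches a request to one transport at a time, in score order,
    /// returning the first successful response.
    ///
    /// This is used for methods whose results are not deterministic across
    /// providers, where racing transports could surface conflicting answers.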
    async fn make_request_sequential(
        &self,
        req: RequestPacket,
    ) -> Result<ResponsePacket, TransportError> {
        trace!("Using sequential fallback for method with non-deterministic results");

        // Snapshot the highest-scoring transports.
        let top_transports = {
            let mut transports_clone = (*self.transports).clone();
            transports_clone.sort_by(|a, b| b.cmp(a));
            transports_clone.truncate(self.active_transport_count);
            transports_clone
        };

        let mut last_error = None;

        for mut transport in top_transports {
            let req_clone = req.clone();
            let start = Instant::now();

            trace!("Trying transport[{}] sequentially", transport.id);

            match transport.call(req_clone).await {
                Ok(response) => {
                    transport.track_success(start.elapsed());
                    trace!("Transport[{}] succeeded in {:?}", transport.id, start.elapsed());
                    self.log_transport_rankings();
                    return Ok(response);
                }
                Err(error) => {
                    transport.track_failure();
                    trace!("Transport[{}] failed: {:?}, trying next", transport.id, error);
                    last_error = Some(error);
                }
            }
        }

        Err(last_error.unwrap_or_else(|| {
            TransportErrorKind::custom_str("All transports failed for sequential request")
        }))
    }
}

impl<S> Service<RequestPacket> for FallbackService<S>
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + Sync
        + Clone
        + 'static,
{
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // The fallback service itself is always ready to accept a request.
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: RequestPacket) -> Self::Future {
        let this = self.clone();
        Box::pin(async move { this.make_request(req).await })
    }
}

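/// A [`Layer`] that produces a [`FallbackService`] from a `Vec` of transports.
///
/// The number of transports raced in parallel and the set of sequentially
/// executed methods can be configured before the layer is applied.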
#[derive(Debug, Clone)]
pub struct FallbackLayer {
    /// How many of the highest-scoring transports are queried per request.
    active_transport_count: usize,
    /// RPC methods that are dispatched sequentially rather than raced.
    sequential_methods: HashSet<String>,
}

impl FallbackLayer {
    /// Sets the number of transports to race in parallel for each request.
    pub const fn with_active_transport_count(mut self, count: NonZeroUsize) -> Self {
        self.active_transport_count = count.get();
        self
    }

    /// Adds a method to the set of sequentially-executed methods.
    pub fn with_sequential_method(mut self, method: impl Into<String>) -> Self {
        self.sequential_methods.insert(method.into());
        self
    }

    /// Replaces the set of sequentially-executed methods.
    pub fn with_sequential_methods(mut self, methods: HashSet<String>) -> Self {
        self.sequential_methods = methods;
        self
    }

    /// Clears the set of sequentially-executed methods, so every request is
    /// raced in parallel.
    pub fn without_sequential_methods(mut self) -> Self {
        self.sequential_methods.clear();
        self
    }
}

impl<S> Layer<Vec<S>> for FallbackLayer
where
    S: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
        + Send
        + Clone
        + 'static,
{
    type Service = FallbackService<S>;

    fn layer(&self, inner: Vec<S>) -> Self::Service {
        FallbackService::new_with_sequential_methods(
            inner,
            self.active_transport_count,
            self.sequential_methods.clone(),
        )
    }
}

impl Default for FallbackLayer {
    fn default() -> Self {
        Self {
            active_transport_count: DEFAULT_ACTIVE_TRANSPORT_COUNT,
            sequential_methods: default_sequential_methods(),
        }
    }
}

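/// A transport paired with an identifier and rolling performance metrics used
/// to rank it against the other transports.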
#[derive(Debug, Clone, Deref, DerefMut)]
struct ScoredTransport<S> {
    /// The wrapped transport.
    #[deref]
    #[deref_mut]
    transport: S,
    /// Index of this transport in the original transport list.
    id: usize,
    /// Shared rolling metrics for this transport.
    metrics: Arc<RwLock<TransportMetrics>>,
}

impl<S> ScoredTransport<S> {
    /// Creates a new scored transport with empty metrics.
    fn new(id: usize, transport: S) -> Self {
        Self { id, transport, metrics: Arc::new(Default::default()) }
    }

    /// Returns the transport's current score.
    fn score(&self) -> f64 {
        let metrics = self.metrics.read();
        metrics.calculate_score()
    }

    /// Returns a human-readable summary of the transport's metrics.
    fn metrics_summary(&self) -> String {
        let metrics = self.metrics.read();
        metrics.get_summary()
    }

    /// Records a successful request and its latency.
    fn track_success(&self, duration: Duration) {
        let mut metrics = self.metrics.write();
        metrics.track_success(duration);
    }

    /// Records a failed request.
    fn track_failure(&self) {
        let mut metrics = self.metrics.write();
        metrics.track_failure();
    }
}

impl<S> PartialEq for ScoredTransport<S> {
    fn eq(&self, other: &Self) -> bool {
        self.score().eq(&other.score())
    }
}

impl<S> Eq for ScoredTransport<S> {}

#[expect(clippy::non_canonical_partial_ord_impl)]
impl<S> PartialOrd for ScoredTransport<S> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        self.score().partial_cmp(&other.score())
    }
}

impl<S> Ord for ScoredTransport<S> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
    }
}

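/// Rolling performance metrics for a single transport.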
#[derive(Debug)]
struct TransportMetrics {
    /// Latencies of the most recent successful requests.
    latencies: VecDeque<Duration>,
    /// Outcomes (success/failure) of the most recent requests.
    successes: VecDeque<bool>,
    /// When the metrics were last updated.
    last_update: Instant,
    /// Total number of requests sent through this transport.
    total_requests: u64,
    /// Total number of successful requests.
    successful_requests: u64,
}

impl TransportMetrics {
    /// Records a successful request and its latency.
    fn track_success(&mut self, duration: Duration) {
        self.total_requests += 1;
        self.successful_requests += 1;
        self.last_update = Instant::now();

        self.latencies.push_back(duration);
        self.successes.push_back(true);

        // Keep only the most recent samples.
        while self.latencies.len() > DEFAULT_SAMPLE_COUNT {
            self.latencies.pop_front();
        }
        while self.successes.len() > DEFAULT_SAMPLE_COUNT {
            self.successes.pop_front();
        }
    }

    /// Records a failed request.
    fn track_failure(&mut self) {
        self.total_requests += 1;
        self.last_update = Instant::now();

        self.successes.push_back(false);

        // Keep only the most recent samples.
        while self.successes.len() > DEFAULT_SAMPLE_COUNT {
            self.successes.pop_front();
        }
    }

    /// Computes the weighted score: `STABILITY_WEIGHT` times the recent
    /// success rate plus `LATENCY_WEIGHT` times a latency factor in `(0, 1]`
    /// that shrinks as the average latency grows.
    fn calculate_score(&self) -> f64 {
        if self.successes.is_empty() {
            return 0.0;
        }

        let success_count = self.successes.iter().filter(|&&s| s).count();
        let stability_score = success_count as f64 / self.successes.len() as f64;

        let latency_score = if !self.latencies.is_empty() {
            let avg_latency = self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
                / self.latencies.len() as f64;

            // Map the average latency (seconds) into (0, 1]; lower latency scores higher.
            1.0 / (1.0 + avg_latency)
        } else {
            0.0
        };

        (stability_score * STABILITY_WEIGHT) + (latency_score * LATENCY_WEIGHT)
    }

    /// Returns a human-readable summary of the current metrics.
    fn get_summary(&self) -> String {
        let success_rate = if !self.successes.is_empty() {
            let success_count = self.successes.iter().filter(|&&s| s).count();
            success_count as f64 / self.successes.len() as f64
        } else {
            0.0
        };

        let avg_latency = if !self.latencies.is_empty() {
            self.latencies.iter().map(|d| d.as_secs_f64()).sum::<f64>()
                / self.latencies.len() as f64
        } else {
            0.0
        };

        format!(
            "success_rate: {:.2}%, avg_latency: {:.2}ms, samples: {}, score: {:.4}",
            success_rate * 100.0,
            avg_latency * 1000.0,
            self.successes.len(),
            self.calculate_score()
        )
    }
}

impl Default for TransportMetrics {
    fn default() -> Self {
        Self {
            latencies: VecDeque::new(),
            successes: VecDeque::new(),
            last_update: Instant::now(),
            total_requests: 0,
            successful_requests: 0,
        }
    }
}

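/// Returns the default set of sequentially-executed methods: submission
/// methods whose results can differ between providers (for example, a second
/// provider may report "already known" for a transaction the first provider
/// accepted).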
fn default_sequential_methods() -> HashSet<String> {
    ["eth_sendRawTransactionSync".to_string(), "eth_sendTransactionSync".to_string()]
        .into_iter()
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloy_json_rpc::{Id, Request, Response, ResponsePayload};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::time::{sleep, Duration};
    use tower::Service;

    /// A mock transport that answers every request with a fixed payload after
    /// a configurable delay, counting how many times it was called.
    #[derive(Clone)]
    struct DelayedMockTransport {
        delay: Duration,
        response: Arc<RwLock<Option<ResponsePayload>>>,
        call_count: Arc<AtomicUsize>,
    }

    impl DelayedMockTransport {
        fn new(delay: Duration, response: ResponsePayload) -> Self {
            Self {
                delay,
                response: Arc::new(RwLock::new(Some(response))),
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl Service<RequestPacket> for DelayedMockTransport {
        type Response = ResponsePacket;
        type Error = TransportError;
        type Future = TransportFut<'static>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: RequestPacket) -> Self::Future {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let delay = self.delay;
            let response = self.response.clone();

            Box::pin(async move {
                sleep(delay).await;

                match req {
                    RequestPacket::Single(single) => {
                        let resp = response.read().clone().ok_or_else(|| {
                            TransportErrorKind::custom_str("No response configured")
                        })?;

                        Ok(ResponsePacket::Single(Response {
                            id: single.id().clone(),
                            payload: resp,
                        }))
                    }
                    RequestPacket::Batch(batch) => {
                        let resp = response.read().clone().ok_or_else(|| {
                            TransportErrorKind::custom_str("No response configured")
                        })?;

                        let responses = batch
                            .iter()
                            .map(|req| Response { id: req.id().clone(), payload: resp.clone() })
                            .collect();

                        Ok(ResponsePacket::Batch(responses))
                    }
                }
            })
        }
    }

    /// Builds a successful JSON-RPC payload containing the given string.
    fn success_response(data: &str) -> ResponsePayload {
        let raw = serde_json::value::RawValue::from_string(format!("\"{}\"", data)).unwrap();
        ResponsePayload::Success(raw)
    }

    #[tokio::test]
    async fn test_non_deterministic_method_uses_sequential_fallback() {
        // Transport A is slower but first in ranking order; transport B would
        // answer faster, but with a conflicting "already_known" result.
        let transport_a = DelayedMockTransport::new(
            Duration::from_millis(50),
            success_response("0x1234567890abcdef"),
        );

        let transport_b = DelayedMockTransport::new(
            Duration::from_millis(10),
            success_response("already_known"),
        );

        let transports = vec![transport_a.clone(), transport_b.clone()];
        let mut fallback_service = FallbackService::new(transports, 2);

        let request = Request::new(
            "eth_sendRawTransactionSync",
            Id::Number(1),
            [serde_json::Value::String("0xabcdef".to_string())],
        );
        let serialized = request.serialize().unwrap();
        let request_packet = RequestPacket::Single(serialized);

        let start = std::time::Instant::now();
        let response = fallback_service.call(request_packet).await.unwrap();
        let elapsed = start.elapsed();

        let result = match response {
            ResponsePacket::Single(resp) => match resp.payload {
                ResponsePayload::Success(data) => data.get().to_string(),
                ResponsePayload::Failure(err) => panic!("Unexpected error: {:?}", err),
            },
            ResponsePacket::Batch(_) => panic!("Unexpected batch response"),
        };

        // Only the first transport should have been queried.
        assert_eq!(transport_a.call_count(), 1, "First transport should be called");
        assert_eq!(transport_b.call_count(), 0, "Second transport should NOT be called");

        assert_eq!(result, "\"0x1234567890abcdef\"");

        assert!(
            elapsed >= Duration::from_millis(40),
            "Should wait for first transport: {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_deterministic_method_uses_parallel_execution() {
        let tx_hash = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef";

        // Both transports return the same hash; B is much faster than A.
        let transport_a = DelayedMockTransport::new(
            Duration::from_millis(100),
            success_response(tx_hash),
        );

        let transport_b = DelayedMockTransport::new(
            Duration::from_millis(20),
            success_response(tx_hash),
        );

        let transports = vec![transport_a.clone(), transport_b.clone()];
        let mut fallback_service = FallbackService::new(transports, 2);

        let request = Request::new(
            "eth_sendRawTransaction",
            Id::Number(1),
            [serde_json::Value::String("0xabcdef".to_string())],
        );
        let serialized = request.serialize().unwrap();
        let request_packet = RequestPacket::Single(serialized);

        let start = std::time::Instant::now();
        let response = fallback_service.call(request_packet).await.unwrap();
        let elapsed = start.elapsed();

        let result = match response {
            ResponsePacket::Single(resp) => match resp.payload {
                ResponsePayload::Success(data) => data.get().to_string(),
                ResponsePayload::Failure(err) => panic!("Unexpected error: {:?}", err),
            },
            ResponsePacket::Batch(_) => panic!("Unexpected batch response"),
        };

        // Both transports were raced, and the fast one's response was returned.
        assert_eq!(transport_a.call_count(), 1, "Transport A should be called");
        assert_eq!(transport_b.call_count(), 1, "Transport B should be called");

        assert_eq!(result, format!("\"{}\"", tx_hash));

        assert!(
            elapsed < Duration::from_millis(50),
            "Should use parallel execution and return fast: {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_batch_with_any_sequential_method_uses_sequential_execution() {
        let tx_hash = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef";

        let transport_a =
            DelayedMockTransport::new(Duration::from_millis(10), success_response(tx_hash));

        let transport_b = DelayedMockTransport::new(
            Duration::from_millis(10),
            success_response("should_not_be_called"),
        );

        let transports = vec![transport_a.clone(), transport_b.clone()];
        let mut fallback_service = FallbackService::new(transports, 2);

        // A batch that mixes a deterministic method with a sequential one.
        let request1 = Request::new("eth_blockNumber", Id::Number(1), ());
        let request2 = Request::new(
            "eth_sendRawTransactionSync",
            Id::Number(2),
            [serde_json::Value::String("0xabcdef".to_string())],
        );

        let batch = vec![request1.serialize().unwrap(), request2.serialize().unwrap()];
        let request_packet = RequestPacket::Batch(batch);

        let start = std::time::Instant::now();
        let response = fallback_service.call(request_packet).await.unwrap();
        let elapsed = start.elapsed();

        assert_eq!(
            transport_a.call_count(),
            1,
            "Transport A should be called once (first in sequence)"
        );
        assert_eq!(
            transport_b.call_count(),
            0,
            "Transport B should NOT be called (transport A succeeded)"
        );

        match response {
            ResponsePacket::Batch(responses) => {
                assert_eq!(responses.len(), 2, "Should get 2 responses in batch");
                for resp in responses {
                    match resp.payload {
                        ResponsePayload::Success(_) => {}
                        ResponsePayload::Failure(err) => panic!("Unexpected error: {:?}", err),
                    }
                }
            }
            ResponsePacket::Single(_) => panic!("Expected batch response"),
        }

        assert!(
            elapsed < Duration::from_millis(50),
            "Sequential execution with fast first transport should be quick: {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_custom_sequential_method() {
        let transport_a =
            DelayedMockTransport::new(Duration::from_millis(10), success_response("result_a"));

        let transport_b =
            DelayedMockTransport::new(Duration::from_millis(10), success_response("result_b"));

        let transports = vec![transport_a.clone(), transport_b.clone()];

        // Replace the default sequential methods with a custom one.
        let custom_methods = ["my_custom_method".to_string()].into_iter().collect();
        let mut fallback_service =
            FallbackService::new(transports, 2).with_sequential_methods(custom_methods);

        let request = Request::new("my_custom_method", Id::Number(1), ());
        let serialized = request.serialize().unwrap();
        let request_packet = RequestPacket::Single(serialized);

        let start = std::time::Instant::now();
        let _response = fallback_service.call(request_packet).await.unwrap();
        let elapsed = start.elapsed();

        assert_eq!(
            transport_a.call_count(),
            1,
            "Transport A should be called once (sequential, first transport)"
        );
        assert_eq!(
            transport_b.call_count(),
            0,
            "Transport B should NOT be called (sequential mode, A succeeded)"
        );

        assert!(
            elapsed < Duration::from_millis(50),
            "Sequential execution with fast first transport: {:?}",
            elapsed
        );
    }
}