Skip to main content

chainrpc_core/
pool.rs

1//! Multi-provider failover pool with round-robin selection and health tracking.
2
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8
9use crate::error::TransportError;
10use crate::metrics::ProviderMetrics;
11use crate::policy::{CircuitBreaker, CircuitBreakerConfig};
12use crate::request::{JsonRpcRequest, JsonRpcResponse};
13use crate::transport::{HealthStatus, RpcTransport};
14
15/// Configuration for the provider pool.
16#[derive(Debug, Clone)]
17pub struct ProviderPoolConfig {
18    /// Circuit breaker config shared across all providers.
19    pub circuit_breaker: CircuitBreakerConfig,
20    /// Timeout per individual request.
21    pub request_timeout: Duration,
22}
23
24impl Default for ProviderPoolConfig {
25    fn default() -> Self {
26        Self {
27            circuit_breaker: CircuitBreakerConfig::default(),
28            request_timeout: Duration::from_secs(30),
29        }
30    }
31}
32
33struct ProviderSlot {
34    transport: Arc<dyn RpcTransport>,
35    circuit: CircuitBreaker,
36    metrics: Option<Arc<ProviderMetrics>>,
37}
38
39/// Round-robin provider pool with per-provider circuit breakers.
40///
41/// Automatically skips unhealthy (circuit-open) providers and falls
42/// back to the next available one.
43pub struct ProviderPool {
44    slots: Vec<ProviderSlot>,
45    cursor: AtomicUsize,
46    config: ProviderPoolConfig,
47}
48
49impl ProviderPool {
50    /// Build a pool from a list of transports.
51    pub fn new(transports: Vec<Arc<dyn RpcTransport>>, config: ProviderPoolConfig) -> Self {
52        let slots = transports
53            .into_iter()
54            .map(|t| ProviderSlot {
55                transport: t,
56                circuit: CircuitBreaker::new(config.circuit_breaker.clone()),
57                metrics: None,
58            })
59            .collect();
60        Self {
61            slots,
62            cursor: AtomicUsize::new(0),
63            config,
64        }
65    }
66
67    /// Build a pool with per-provider metrics automatically created.
68    pub fn new_with_metrics(
69        transports: Vec<Arc<dyn RpcTransport>>,
70        config: ProviderPoolConfig,
71    ) -> Self {
72        let slots = transports
73            .into_iter()
74            .map(|t| {
75                let m = Arc::new(ProviderMetrics::new(t.url()));
76                ProviderSlot {
77                    transport: t,
78                    circuit: CircuitBreaker::new(config.circuit_breaker.clone()),
79                    metrics: Some(m),
80                }
81            })
82            .collect();
83        Self {
84            slots,
85            cursor: AtomicUsize::new(0),
86            config,
87        }
88    }
89
90    /// Number of providers in the pool.
91    pub fn len(&self) -> usize {
92        self.slots.len()
93    }
94
95    /// Returns `true` if the pool has no providers.
96    pub fn is_empty(&self) -> bool {
97        self.slots.is_empty()
98    }
99
100    /// Returns summary of each provider's health.
101    pub fn health_summary(&self) -> Vec<(String, HealthStatus, String)> {
102        self.slots
103            .iter()
104            .map(|s| {
105                let url = s.transport.url().to_string();
106                let health = s.transport.health();
107                let circuit = s.circuit.state().to_string();
108                (url, health, circuit)
109            })
110            .collect()
111    }
112
113    /// Number of providers whose circuit breaker allows requests.
114    pub fn healthy_count(&self) -> usize {
115        self.slots.iter().filter(|s| s.circuit.is_allowed()).count()
116    }
117
118    /// Return metrics snapshots for all providers that have metrics enabled.
119    pub fn metrics(&self) -> Vec<crate::metrics::MetricsSnapshot> {
120        self.slots
121            .iter()
122            .filter_map(|s| s.metrics.as_ref().map(|m| m.snapshot()))
123            .collect()
124    }
125
126    /// Detailed health report for each provider as JSON-serializable values.
127    ///
128    /// When per-provider metrics are available the report includes
129    /// additional fields such as `total_requests`, `success_rate`, and
130    /// `avg_latency_ms`.
131    pub fn health_report(&self) -> Vec<serde_json::Value> {
132        self.slots
133            .iter()
134            .map(|s| {
135                let mut report = serde_json::json!({
136                    "url": s.transport.url(),
137                    "health": s.transport.health().to_string(),
138                    "circuit": s.circuit.state().to_string(),
139                });
140                if let Some(ref m) = s.metrics {
141                    let snap = m.snapshot();
142                    let obj = report.as_object_mut().unwrap();
143                    obj.insert(
144                        "total_requests".into(),
145                        serde_json::json!(snap.total_requests),
146                    );
147                    obj.insert(
148                        "successful_requests".into(),
149                        serde_json::json!(snap.successful_requests),
150                    );
151                    obj.insert(
152                        "failed_requests".into(),
153                        serde_json::json!(snap.failed_requests),
154                    );
155                    obj.insert("success_rate".into(), serde_json::json!(snap.success_rate));
156                    obj.insert(
157                        "avg_latency_ms".into(),
158                        serde_json::json!(snap.avg_latency_ms),
159                    );
160                    obj.insert(
161                        "rate_limit_hits".into(),
162                        serde_json::json!(snap.rate_limit_hits),
163                    );
164                    obj.insert(
165                        "circuit_open_count".into(),
166                        serde_json::json!(snap.circuit_open_count),
167                    );
168                }
169                report
170            })
171            .collect()
172    }
173
174    /// Find the next available (circuit-closed/half-open) slot, starting
175    /// from the round-robin cursor.
176    fn next_slot(&self) -> Option<&ProviderSlot> {
177        if self.slots.is_empty() {
178            return None;
179        }
180        let start = self.cursor.fetch_add(1, Ordering::Relaxed) % self.slots.len();
181        for i in 0..self.slots.len() {
182            let idx = (start + i) % self.slots.len();
183            let slot = &self.slots[idx];
184            if slot.circuit.is_allowed() {
185                return Some(slot);
186            }
187        }
188        None
189    }
190}
191
192#[async_trait]
193impl RpcTransport for ProviderPool {
194    async fn send(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
195        let slot = self.next_slot().ok_or(TransportError::AllProvidersDown)?;
196
197        let timeout = self.config.request_timeout;
198        let start = std::time::Instant::now();
199        let result = tokio::time::timeout(timeout, slot.transport.send(req))
200            .await
201            .map_err(|_| TransportError::Timeout {
202                ms: timeout.as_millis() as u64,
203            })?;
204
205        match result {
206            Ok(resp) => {
207                slot.circuit.record_success();
208                if let Some(ref m) = slot.metrics {
209                    m.record_success(start.elapsed());
210                }
211                Ok(resp)
212            }
213            Err(e) if e.is_retryable() => {
214                slot.circuit.record_failure();
215                if let Some(ref m) = slot.metrics {
216                    m.record_failure();
217                }
218                Err(e)
219            }
220            Err(e) => {
221                if let Some(ref m) = slot.metrics {
222                    m.record_failure();
223                }
224                Err(e)
225            }
226        }
227    }
228
229    fn health(&self) -> HealthStatus {
230        let healthy_count = self.slots.iter().filter(|s| s.circuit.is_allowed()).count();
231        match healthy_count {
232            0 => HealthStatus::Unhealthy,
233            n if n == self.slots.len() => HealthStatus::Healthy,
234            _ => HealthStatus::Degraded,
235        }
236    }
237
238    fn url(&self) -> &str {
239        "pool"
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::RpcId;

    /// Transport stub that either succeeds with `"0x1"` or fails with an HTTP error.
    struct MockTransport {
        url: String,
        should_fail: bool,
    }

    #[async_trait]
    impl RpcTransport for MockTransport {
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            if self.should_fail {
                Err(TransportError::Http("mock error".into()))
            } else {
                Ok(JsonRpcResponse {
                    jsonrpc: "2.0".into(),
                    id: RpcId::Number(1),
                    result: Some(serde_json::Value::String("0x1".into())),
                    error: None,
                })
            }
        }
        fn url(&self) -> &str {
            &self.url
        }
    }

    fn mock(url: &str, fail: bool) -> Arc<dyn RpcTransport> {
        Arc::new(MockTransport {
            url: url.to_string(),
            should_fail: fail,
        })
    }

    fn two_mocks() -> Vec<Arc<dyn RpcTransport>> {
        vec![mock("https://a.com", false), mock("https://b.com", false)]
    }

    #[test]
    fn pool_len() {
        let pool = ProviderPool::new(two_mocks(), ProviderPoolConfig::default());
        assert_eq!(pool.len(), 2);
        assert!(!pool.is_empty());
    }

    #[test]
    fn health_all_healthy() {
        let pool = ProviderPool::new(
            vec![mock("https://a.com", false)],
            ProviderPoolConfig::default(),
        );
        assert_eq!(pool.health(), HealthStatus::Healthy);
    }

    #[test]
    fn health_all_down() {
        let pool = ProviderPool::new(vec![], ProviderPoolConfig::default());
        assert!(pool.is_empty());
        // No providers → next_slot returns None, so `send` yields AllProvidersDown.
        assert!(pool.next_slot().is_none());
        assert_eq!(pool.healthy_count(), 0);
        assert_eq!(pool.health(), HealthStatus::Unhealthy);
    }

    #[test]
    fn next_slot_round_robins() {
        let pool = ProviderPool::new(two_mocks(), ProviderPoolConfig::default());
        let urls: Vec<&str> = (0..4)
            .map(|_| pool.next_slot().expect("closed circuits admit").transport.url())
            .collect();
        assert_eq!(
            urls,
            ["https://a.com", "https://b.com", "https://a.com", "https://b.com"]
        );
    }

    #[test]
    fn metrics_only_when_enabled() {
        let plain = ProviderPool::new(two_mocks(), ProviderPoolConfig::default());
        assert!(plain.metrics().is_empty());
        assert!(plain.health_report()[0].get("total_requests").is_none());

        let tracked = ProviderPool::new_with_metrics(two_mocks(), ProviderPoolConfig::default());
        assert_eq!(tracked.metrics().len(), 2);
        let report = tracked.health_report();
        assert_eq!(report.len(), 2);
        assert_eq!(report[0]["url"], "https://a.com");
        assert!(report[0].get("total_requests").is_some());
    }
}
303}