1use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8
9use crate::error::TransportError;
10use crate::metrics::ProviderMetrics;
11use crate::policy::{CircuitBreaker, CircuitBreakerConfig};
12use crate::request::{JsonRpcRequest, JsonRpcResponse};
13use crate::transport::{HealthStatus, RpcTransport};
14
15#[derive(Debug, Clone)]
17pub struct ProviderPoolConfig {
18 pub circuit_breaker: CircuitBreakerConfig,
20 pub request_timeout: Duration,
22}
23
24impl Default for ProviderPoolConfig {
25 fn default() -> Self {
26 Self {
27 circuit_breaker: CircuitBreakerConfig::default(),
28 request_timeout: Duration::from_secs(30),
29 }
30 }
31}
32
/// One provider in the pool: its transport plus per-provider bookkeeping.
struct ProviderSlot {
    /// Transport that requests routed to this slot are forwarded to.
    transport: Arc<dyn RpcTransport>,
    /// Per-provider circuit breaker. It is checked when selecting a slot and
    /// updated from request outcomes in `send`.
    circuit: CircuitBreaker,
    /// Per-provider metrics; `None` when the pool was built with `new`.
    metrics: Option<Arc<ProviderMetrics>>,
}

/// A round-robin pool of RPC providers, each guarded by its own circuit breaker.
///
/// The pool itself implements `RpcTransport`, so it can stand in for a single
/// transport.
pub struct ProviderPool {
    /// Providers, in the order they were passed to the constructor.
    slots: Vec<ProviderSlot>,
    /// Counter bumped once per slot selection to pick the starting slot.
    /// `fetch_add` wraps on overflow.
    cursor: AtomicUsize,
    /// Pool-wide settings (circuit-breaker template and request timeout).
    config: ProviderPoolConfig,
}
48
49impl ProviderPool {
50 pub fn new(transports: Vec<Arc<dyn RpcTransport>>, config: ProviderPoolConfig) -> Self {
52 let slots = transports
53 .into_iter()
54 .map(|t| ProviderSlot {
55 transport: t,
56 circuit: CircuitBreaker::new(config.circuit_breaker.clone()),
57 metrics: None,
58 })
59 .collect();
60 Self {
61 slots,
62 cursor: AtomicUsize::new(0),
63 config,
64 }
65 }
66
67 pub fn new_with_metrics(
69 transports: Vec<Arc<dyn RpcTransport>>,
70 config: ProviderPoolConfig,
71 ) -> Self {
72 let slots = transports
73 .into_iter()
74 .map(|t| {
75 let m = Arc::new(ProviderMetrics::new(t.url()));
76 ProviderSlot {
77 transport: t,
78 circuit: CircuitBreaker::new(config.circuit_breaker.clone()),
79 metrics: Some(m),
80 }
81 })
82 .collect();
83 Self {
84 slots,
85 cursor: AtomicUsize::new(0),
86 config,
87 }
88 }
89
90 pub fn len(&self) -> usize {
92 self.slots.len()
93 }
94
95 pub fn is_empty(&self) -> bool {
97 self.slots.is_empty()
98 }
99
100 pub fn health_summary(&self) -> Vec<(String, HealthStatus, String)> {
102 self.slots
103 .iter()
104 .map(|s| {
105 let url = s.transport.url().to_string();
106 let health = s.transport.health();
107 let circuit = s.circuit.state().to_string();
108 (url, health, circuit)
109 })
110 .collect()
111 }
112
113 pub fn healthy_count(&self) -> usize {
115 self.slots.iter().filter(|s| s.circuit.is_allowed()).count()
116 }
117
118 pub fn metrics(&self) -> Vec<crate::metrics::MetricsSnapshot> {
120 self.slots
121 .iter()
122 .filter_map(|s| s.metrics.as_ref().map(|m| m.snapshot()))
123 .collect()
124 }
125
126 pub fn health_report(&self) -> Vec<serde_json::Value> {
132 self.slots
133 .iter()
134 .map(|s| {
135 let mut report = serde_json::json!({
136 "url": s.transport.url(),
137 "health": s.transport.health().to_string(),
138 "circuit": s.circuit.state().to_string(),
139 });
140 if let Some(ref m) = s.metrics {
141 let snap = m.snapshot();
142 let obj = report.as_object_mut().unwrap();
143 obj.insert(
144 "total_requests".into(),
145 serde_json::json!(snap.total_requests),
146 );
147 obj.insert(
148 "successful_requests".into(),
149 serde_json::json!(snap.successful_requests),
150 );
151 obj.insert(
152 "failed_requests".into(),
153 serde_json::json!(snap.failed_requests),
154 );
155 obj.insert("success_rate".into(), serde_json::json!(snap.success_rate));
156 obj.insert(
157 "avg_latency_ms".into(),
158 serde_json::json!(snap.avg_latency_ms),
159 );
160 obj.insert(
161 "rate_limit_hits".into(),
162 serde_json::json!(snap.rate_limit_hits),
163 );
164 obj.insert(
165 "circuit_open_count".into(),
166 serde_json::json!(snap.circuit_open_count),
167 );
168 }
169 report
170 })
171 .collect()
172 }
173
174 fn next_slot(&self) -> Option<&ProviderSlot> {
177 if self.slots.is_empty() {
178 return None;
179 }
180 let start = self.cursor.fetch_add(1, Ordering::Relaxed) % self.slots.len();
181 for i in 0..self.slots.len() {
182 let idx = (start + i) % self.slots.len();
183 let slot = &self.slots[idx];
184 if slot.circuit.is_allowed() {
185 return Some(slot);
186 }
187 }
188 None
189 }
190}
191
192#[async_trait]
193impl RpcTransport for ProviderPool {
194 async fn send(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
195 let slot = self.next_slot().ok_or(TransportError::AllProvidersDown)?;
196
197 let timeout = self.config.request_timeout;
198 let start = std::time::Instant::now();
199 let result = tokio::time::timeout(timeout, slot.transport.send(req))
200 .await
201 .map_err(|_| TransportError::Timeout {
202 ms: timeout.as_millis() as u64,
203 })?;
204
205 match result {
206 Ok(resp) => {
207 slot.circuit.record_success();
208 if let Some(ref m) = slot.metrics {
209 m.record_success(start.elapsed());
210 }
211 Ok(resp)
212 }
213 Err(e) if e.is_retryable() => {
214 slot.circuit.record_failure();
215 if let Some(ref m) = slot.metrics {
216 m.record_failure();
217 }
218 Err(e)
219 }
220 Err(e) => {
221 if let Some(ref m) = slot.metrics {
222 m.record_failure();
223 }
224 Err(e)
225 }
226 }
227 }
228
229 fn health(&self) -> HealthStatus {
230 let healthy_count = self.slots.iter().filter(|s| s.circuit.is_allowed()).count();
231 match healthy_count {
232 0 => HealthStatus::Unhealthy,
233 n if n == self.slots.len() => HealthStatus::Healthy,
234 _ => HealthStatus::Degraded,
235 }
236 }
237
238 fn url(&self) -> &str {
239 "pool"
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::RpcId;

    /// Transport stub that either always succeeds with `"0x1"` or always fails
    /// with an HTTP error.
    struct MockTransport {
        url: String,
        should_fail: bool,
    }

    #[async_trait]
    impl RpcTransport for MockTransport {
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            if self.should_fail {
                Err(TransportError::Http("mock error".into()))
            } else {
                Ok(JsonRpcResponse {
                    jsonrpc: "2.0".into(),
                    id: RpcId::Number(1),
                    result: Some(serde_json::Value::String("0x1".into())),
                    error: None,
                })
            }
        }
        fn url(&self) -> &str {
            &self.url
        }
    }

    fn mock(url: &str, fail: bool) -> Arc<dyn RpcTransport> {
        Arc::new(MockTransport {
            url: url.to_string(),
            should_fail: fail,
        })
    }

    #[test]
    fn pool_len() {
        let pool = ProviderPool::new(
            vec![mock("https://a.com", false), mock("https://b.com", false)],
            ProviderPoolConfig::default(),
        );
        assert_eq!(pool.len(), 2);
        assert!(!pool.is_empty());
    }

    #[test]
    fn health_all_healthy() {
        let pool = ProviderPool::new(
            vec![mock("https://a.com", false)],
            ProviderPoolConfig::default(),
        );
        assert_eq!(pool.health(), HealthStatus::Healthy);
        assert_eq!(pool.healthy_count(), 1);
    }

    #[test]
    fn health_all_down() {
        let pool = ProviderPool::new(vec![], ProviderPoolConfig::default());
        assert!(pool.is_empty());
        assert!(pool.next_slot().is_none());
        assert_eq!(pool.healthy_count(), 0);
        assert_eq!(pool.health(), HealthStatus::Unhealthy);
    }

    #[test]
    fn next_slot_round_robins() {
        let pool = ProviderPool::new(
            vec![mock("https://a.com", false), mock("https://b.com", false)],
            ProviderPoolConfig::default(),
        );
        let urls: Vec<&str> = (0..4)
            .map(|_| {
                pool.next_slot()
                    .expect("fresh circuits allow traffic")
                    .transport
                    .url()
            })
            .collect();
        assert_eq!(
            urls,
            ["https://a.com", "https://b.com", "https://a.com", "https://b.com"]
        );
    }

    #[test]
    fn metrics_only_with_metrics_constructor() {
        let plain = ProviderPool::new(
            vec![mock("https://a.com", false)],
            ProviderPoolConfig::default(),
        );
        assert!(plain.metrics().is_empty());

        let tracked = ProviderPool::new_with_metrics(
            vec![mock("https://a.com", false), mock("https://b.com", false)],
            ProviderPoolConfig::default(),
        );
        assert_eq!(tracked.metrics().len(), 2);
    }

    #[test]
    fn health_report_and_summary_list_each_provider() {
        let pool = ProviderPool::new(
            vec![mock("https://a.com", false), mock("https://b.com", false)],
            ProviderPoolConfig::default(),
        );

        let report = pool.health_report();
        assert_eq!(report.len(), 2);
        assert_eq!(report[0]["url"], "https://a.com");
        assert_eq!(report[1]["url"], "https://b.com");
        // Metric fields only appear for pools built with `new_with_metrics`.
        assert!(report[0].get("total_requests").is_none());

        let summary = pool.health_summary();
        assert_eq!(summary.len(), 2);
        assert_eq!(summary[0].0, "https://a.com");
        assert_eq!(summary[1].0, "https://b.com");
    }
}