1use crate::connect::health::{ConnectionPoolTrait, HealthChecker};
4use crate::error::{BittensorError, RetryConfig};
5use crate::retry::ExponentialBackoff;
6use futures::future::join_all;
7use std::sync::Arc;
8use std::time::Duration;
9use subxt::{OnlineClient, PolkadotConfig};
10use tokio::sync::RwLock;
11use tracing::{debug, error, info, warn};
12
/// Concrete subxt client type used for every pooled chain connection.
type ChainClient = OnlineClient<PolkadotConfig>;

/// Pool of subxt chain clients over a fixed, ordered list of endpoints.
///
/// All shared state lives behind `Arc`, so cloning the pool is cheap and every
/// clone observes (and mutates) the same connection list.
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    /// Endpoint URLs; initialization and reconnection walk them in this order.
    endpoints: Arc<Vec<String>>,
    /// Currently established clients. Replaced wholesale by `initialize`, and
    /// cleared down to a single client by a successful reconnection.
    connections: Arc<RwLock<Vec<Arc<ChainClient>>>>,
    /// Consulted before a client is handed out or counted as healthy.
    health_checker: Arc<HealthChecker>,
    /// Upper bound on how many endpoints `initialize` connects to.
    #[doc(hidden)]
    pub max_connections: usize,
    /// Backoff policy driving `reconnect_with_backoff`.
    #[doc(hidden)]
    pub retry_config: RetryConfig,
}
27
28impl ConnectionPool {
29 pub fn new(endpoints: Vec<String>, max_connections: usize) -> Self {
35 Self {
36 endpoints: Arc::new(endpoints),
37 connections: Arc::new(RwLock::new(Vec::new())),
38 health_checker: Arc::new(HealthChecker::default()),
39 max_connections,
40 retry_config: RetryConfig::network(),
41 }
42 }
43
44 pub async fn initialize(&self) -> Result<(), BittensorError> {
46 let mut connections = Vec::with_capacity(self.max_connections);
47 let endpoints_to_try = self
48 .endpoints
49 .iter()
50 .take(self.max_connections)
51 .collect::<Vec<_>>();
52
53 if endpoints_to_try.is_empty() {
54 return Err(BittensorError::ConfigError {
55 field: "endpoints".to_string(),
56 message: "No endpoints configured".to_string(),
57 });
58 }
59
60 let connection_futures = endpoints_to_try
62 .iter()
63 .map(|endpoint| self.create_connection(endpoint));
64
65 let results = join_all(connection_futures).await;
66
67 for (endpoint, result) in endpoints_to_try.into_iter().zip(results) {
68 match result {
69 Ok(client) => {
70 info!("Successfully connected to {}", endpoint);
71 connections.push(Arc::new(client));
72 }
73 Err(e) => {
74 warn!("Failed to connect to {}: {}", endpoint, e);
75 }
76 }
77 }
78
79 if connections.is_empty() {
80 error!("Failed to establish any connections to chain endpoints");
81 return Err(BittensorError::NetworkError {
82 message: "Failed to establish any connections".to_string(),
83 });
84 }
85
86 info!(
87 "Initialized connection pool with {} connections",
88 connections.len()
89 );
90 *self.connections.write().await = connections;
91 Ok(())
92 }
93
94 pub async fn get_healthy_client(&self) -> Result<Arc<ChainClient>, BittensorError> {
96 {
98 let connections = self.connections.read().await;
99 for conn in connections.iter() {
100 if self.health_checker.is_healthy(conn).await {
101 return Ok(Arc::clone(conn));
102 }
103 }
104 }
105
106 warn!("All connections unhealthy, attempting reconnection");
108 self.reconnect_with_backoff().await
109 }
110
111 pub async fn reconnect_with_backoff(&self) -> Result<Arc<ChainClient>, BittensorError> {
113 let mut backoff = ExponentialBackoff::new(self.retry_config.clone());
114 let mut last_error = None;
115
116 while let Some(delay) = backoff.next_delay() {
117 debug!("Waiting {:?} before reconnection attempt", delay);
118 tokio::time::sleep(delay).await;
119
120 match self.try_reconnect().await {
121 Ok(client) => {
122 info!("Successfully reconnected to chain");
123 return Ok(client);
124 }
125 Err(e) => {
126 warn!("Reconnection attempt {} failed: {}", backoff.attempts(), e);
127 last_error = Some(e);
128 }
129 }
130 }
131
132 Err(last_error.unwrap_or_else(|| BittensorError::NetworkError {
133 message: "Failed to reconnect after maximum attempts".to_string(),
134 }))
135 }
136
137 async fn try_reconnect(&self) -> Result<Arc<ChainClient>, BittensorError> {
139 for endpoint in self.endpoints.iter() {
141 match self.create_connection(endpoint).await {
142 Ok(client) => {
143 let client_arc = Arc::new(client);
144
145 let mut connections = self.connections.write().await;
147 connections.clear();
148 connections.push(Arc::clone(&client_arc));
149
150 return Ok(client_arc);
151 }
152 Err(e) => {
153 debug!("Failed to connect to {}: {}", endpoint, e);
154 }
155 }
156 }
157
158 Err(BittensorError::NetworkError {
159 message: "Failed to connect to any endpoint".to_string(),
160 })
161 }
162
163 async fn create_connection(&self, endpoint: &str) -> Result<ChainClient, BittensorError> {
165 let timeout_duration = Duration::from_secs(30);
166
167 let is_insecure = endpoint.starts_with("ws://") || endpoint.starts_with("http://");
168
169 let result = if is_insecure {
170 debug!("Using insecure connection for endpoint: {}", endpoint);
171 tokio::time::timeout(
172 timeout_duration,
173 OnlineClient::<PolkadotConfig>::from_insecure_url(endpoint),
174 )
175 .await
176 } else {
177 tokio::time::timeout(
178 timeout_duration,
179 OnlineClient::<PolkadotConfig>::from_url(endpoint),
180 )
181 .await
182 };
183
184 result
185 .map_err(|_| BittensorError::RpcTimeoutError {
186 message: format!("Connection to {} timed out", endpoint),
187 timeout: timeout_duration,
188 })?
189 .map_err(|e| BittensorError::RpcConnectionError {
190 message: format!("Failed to connect to {}: {}", endpoint, e),
191 })
192 }
193
194 pub async fn healthy_connection_count(&self) -> usize {
196 let connections = self.connections.read().await;
197 let mut count = 0;
198
199 for conn in connections.iter() {
200 if self.health_checker.is_healthy(conn).await {
201 count += 1;
202 }
203 }
204
205 count
206 }
207
208 pub async fn refresh_connections(&self) -> Result<(), BittensorError> {
210 info!("Refreshing all connections");
211 self.initialize().await
212 }
213
214 pub async fn total_connections(&self) -> usize {
216 self.connections.read().await.len()
217 }
218}
219
220pub struct ConnectionPoolBuilder {
222 endpoints: Vec<String>,
223 max_connections: usize,
224 retry_config: Option<RetryConfig>,
225 health_checker: Option<HealthChecker>,
226}
227
228impl ConnectionPoolBuilder {
229 pub fn new(endpoints: Vec<String>) -> Self {
230 Self {
231 endpoints,
232 max_connections: 3,
233 retry_config: None,
234 health_checker: None,
235 }
236 }
237
238 pub fn max_connections(mut self, max: usize) -> Self {
239 self.max_connections = max;
240 self
241 }
242
243 pub fn retry_config(mut self, config: RetryConfig) -> Self {
244 self.retry_config = Some(config);
245 self
246 }
247
248 pub fn health_checker(mut self, checker: HealthChecker) -> Self {
249 self.health_checker = Some(checker);
250 self
251 }
252
253 pub fn build(self) -> ConnectionPool {
254 let mut pool = ConnectionPool::new(self.endpoints, self.max_connections);
255
256 if let Some(config) = self.retry_config {
257 pool.retry_config = config;
258 }
259
260 if let Some(checker) = self.health_checker {
261 pool.health_checker = Arc::new(checker);
262 }
263
264 pool
265 }
266}
267
/// Exposes [`ConnectionPool`] through `ConnectionPoolTrait` by forwarding to
/// the pool's inherent methods of the same names.
#[async_trait::async_trait]
impl ConnectionPoolTrait for ConnectionPool {
    /// Returns a shared handle to the pool's connection list.
    ///
    /// Only the outer `Arc` is cloned; callers see and mutate the same list.
    async fn connections(&self) -> Arc<RwLock<Vec<Arc<ChainClient>>>> {
        Arc::clone(&self.connections)
    }

    /// Forwards to the inherent `ConnectionPool::reconnect_with_backoff`.
    // `ConnectionPool::name(self)` resolves to the inherent method (inherent
    // associated items take precedence over trait items for type paths).
    // Changing this call form could recurse into this trait method instead.
    async fn reconnect_with_backoff(&self) -> Result<Arc<ChainClient>, BittensorError> {
        ConnectionPool::reconnect_with_backoff(self).await
    }

    /// Forwards to the inherent `ConnectionPool::get_healthy_client`
    /// (same resolution caveat as above).
    async fn get_healthy_client(&self) -> Result<Arc<ChainClient>, BittensorError> {
        ConnectionPool::get_healthy_client(self).await
    }
}
283
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a fresh wiremock server.
    async fn setup_mock_server() -> MockServer {
        MockServer::start().await
    }

    /// Single, near-instant retry policy.
    ///
    /// Tests that fall through to `reconnect_with_backoff` use it so they don't
    /// sit through the production `RetryConfig::network()` backoff schedule.
    fn fast_retry_config() -> RetryConfig {
        RetryConfig {
            max_attempts: 1,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(1),
            backoff_multiplier: 1.0,
            jitter: false,
        }
    }

    #[tokio::test]
    async fn test_connection_pool_creation() {
        let endpoints = vec!["wss://test.endpoint:443".to_string()];
        let pool = ConnectionPool::new(endpoints, 3);

        assert_eq!(pool.endpoints.len(), 1);
        assert_eq!(pool.max_connections, 3);
    }

    #[tokio::test]
    async fn test_connection_pool_builder() {
        let endpoints = vec!["wss://test.endpoint:443".to_string()];
        let pool = ConnectionPoolBuilder::new(endpoints)
            .max_connections(5)
            .retry_config(RetryConfig::transient())
            .build();

        assert_eq!(pool.endpoints.len(), 1);
        assert_eq!(pool.max_connections, 5);

        // The retry override must actually reach the built pool.
        let expected = RetryConfig::transient();
        assert_eq!(pool.retry_config.max_attempts, expected.max_attempts);
        assert_eq!(pool.retry_config.initial_delay, expected.initial_delay);
        assert_eq!(pool.retry_config.max_delay, expected.max_delay);
    }

    #[tokio::test]
    async fn test_connection_pool_builder_defaults() {
        let pool = ConnectionPoolBuilder::new(vec!["wss://test.endpoint:443".to_string()]).build();

        assert_eq!(pool.endpoints.len(), 1);
        assert_eq!(pool.max_connections, 3);
    }

    #[tokio::test]
    async fn test_empty_endpoints_initialization() {
        let pool = ConnectionPool::new(vec![], 3);
        let result = pool.initialize().await;

        match result {
            Err(BittensorError::ConfigError { field, .. }) => assert_eq!(field, "endpoints"),
            _ => panic!("Expected ConfigError"),
        }
    }

    #[tokio::test]
    async fn test_connection_pool_initialization_with_mock() {
        let mock_server = setup_mock_server().await;

        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&mock_server)
            .await;

        // The mock only answers plain HTTP, not a subxt RPC handshake, so
        // initialization is expected to fail.
        let endpoints = vec![format!("ws://{}", mock_server.address())];
        let pool = ConnectionPool::new(endpoints, 1);

        let result = pool.initialize().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_healthy_connection_count() {
        let pool = ConnectionPool::new(vec!["wss://test.endpoint:443".to_string()], 3);
        assert_eq!(pool.healthy_connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_total_connections() {
        let pool = ConnectionPool::new(vec!["wss://test.endpoint:443".to_string()], 3);
        assert_eq!(pool.total_connections().await, 0);
    }

    #[tokio::test]
    async fn test_get_healthy_client_no_connections() {
        let mut pool = ConnectionPool::new(vec!["wss://invalid.endpoint:443".to_string()], 1);
        // With no connections this falls through to reconnect_with_backoff;
        // keep that path short.
        pool.retry_config = fast_retry_config();

        let result = pool.get_healthy_client().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_reconnect_with_backoff() {
        let mut pool = ConnectionPool::new(vec!["wss://invalid.endpoint:443".to_string()], 1);
        pool.retry_config = RetryConfig {
            max_attempts: 2,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(20),
            backoff_multiplier: 1.5,
            jitter: false,
        };

        let result = pool.reconnect_with_backoff().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_multiple_endpoints_fallback() {
        let endpoints = vec![
            "wss://invalid1.endpoint:443".to_string(),
            "wss://invalid2.endpoint:443".to_string(),
            "wss://invalid3.endpoint:443".to_string(),
        ];

        let pool = ConnectionPool::new(endpoints, 3);
        let result = pool.try_reconnect().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_create_connection_timeout() {
        // NOTE: 10.255.255.1 is expected to be unreachable. This test can
        // therefore block for the whole connection timeout used by
        // `create_connection` (currently 30 s).
        let pool = ConnectionPool::new(vec!["wss://10.255.255.1:443".to_string()], 1);

        match pool.create_connection("wss://10.255.255.1:443").await {
            Err(
                BittensorError::RpcTimeoutError { .. } | BittensorError::RpcConnectionError { .. },
            ) => {}
            Err(e) => panic!(
                "Expected RpcTimeoutError or RpcConnectionError, got: {:?}",
                e
            ),
            Ok(_) => panic!("Expected error but got Ok"),
        }
    }
}
423}