bittensor_rs/connect/
pool.rs

1//! Connection pool for managing multiple blockchain connections with automatic failover
2
3use crate::connect::health::{ConnectionPoolTrait, HealthChecker};
4use crate::error::{BittensorError, RetryConfig};
5use crate::retry::ExponentialBackoff;
6use futures::future::join_all;
7use std::sync::Arc;
8use std::time::Duration;
9use subxt::{OnlineClient, PolkadotConfig};
10use tokio::sync::RwLock;
11use tracing::{debug, error, info, warn};
12
13/// Type alias for chain client
14type ChainClient = OnlineClient<PolkadotConfig>;
15
16/// Connection pool managing multiple blockchain connections with automatic failover
17#[derive(Debug, Clone)]
18pub struct ConnectionPool {
19    endpoints: Arc<Vec<String>>,
20    connections: Arc<RwLock<Vec<Arc<ChainClient>>>>,
21    health_checker: Arc<HealthChecker>,
22    #[doc(hidden)]
23    pub max_connections: usize,
24    #[doc(hidden)]
25    pub retry_config: RetryConfig,
26}
27
28impl ConnectionPool {
29    /// Creates a new connection pool
30    ///
31    /// # Arguments
32    /// * `endpoints` - List of WebSocket endpoints to connect to
33    /// * `max_connections` - Maximum number of concurrent connections to maintain
34    pub fn new(endpoints: Vec<String>, max_connections: usize) -> Self {
35        Self {
36            endpoints: Arc::new(endpoints),
37            connections: Arc::new(RwLock::new(Vec::new())),
38            health_checker: Arc::new(HealthChecker::default()),
39            max_connections,
40            retry_config: RetryConfig::network(),
41        }
42    }
43
44    /// Initialize the connection pool with at least one working connection
45    pub async fn initialize(&self) -> Result<(), BittensorError> {
46        let mut connections = Vec::with_capacity(self.max_connections);
47        let endpoints_to_try = self
48            .endpoints
49            .iter()
50            .take(self.max_connections)
51            .collect::<Vec<_>>();
52
53        if endpoints_to_try.is_empty() {
54            return Err(BittensorError::ConfigError {
55                field: "endpoints".to_string(),
56                message: "No endpoints configured".to_string(),
57            });
58        }
59
60        // Try to establish connections in parallel
61        let connection_futures = endpoints_to_try
62            .iter()
63            .map(|endpoint| self.create_connection(endpoint));
64
65        let results = join_all(connection_futures).await;
66
67        for (endpoint, result) in endpoints_to_try.into_iter().zip(results) {
68            match result {
69                Ok(client) => {
70                    info!("Successfully connected to {}", endpoint);
71                    connections.push(Arc::new(client));
72                }
73                Err(e) => {
74                    warn!("Failed to connect to {}: {}", endpoint, e);
75                }
76            }
77        }
78
79        if connections.is_empty() {
80            error!("Failed to establish any connections to chain endpoints");
81            return Err(BittensorError::NetworkError {
82                message: "Failed to establish any connections".to_string(),
83            });
84        }
85
86        info!(
87            "Initialized connection pool with {} connections",
88            connections.len()
89        );
90        *self.connections.write().await = connections;
91        Ok(())
92    }
93
94    /// Get a healthy client from the pool, reconnecting if necessary
95    pub async fn get_healthy_client(&self) -> Result<Arc<ChainClient>, BittensorError> {
96        // Fast path: check existing connections
97        {
98            let connections = self.connections.read().await;
99            for conn in connections.iter() {
100                if self.health_checker.is_healthy(conn).await {
101                    return Ok(Arc::clone(conn));
102                }
103            }
104        }
105
106        // Slow path: all connections unhealthy, trigger reconnection
107        warn!("All connections unhealthy, attempting reconnection");
108        self.reconnect_with_backoff().await
109    }
110
111    /// Reconnect to endpoints with exponential backoff
112    pub async fn reconnect_with_backoff(&self) -> Result<Arc<ChainClient>, BittensorError> {
113        let mut backoff = ExponentialBackoff::new(self.retry_config.clone());
114        let mut last_error = None;
115
116        while let Some(delay) = backoff.next_delay() {
117            debug!("Waiting {:?} before reconnection attempt", delay);
118            tokio::time::sleep(delay).await;
119
120            match self.try_reconnect().await {
121                Ok(client) => {
122                    info!("Successfully reconnected to chain");
123                    return Ok(client);
124                }
125                Err(e) => {
126                    warn!("Reconnection attempt {} failed: {}", backoff.attempts(), e);
127                    last_error = Some(e);
128                }
129            }
130        }
131
132        Err(last_error.unwrap_or_else(|| BittensorError::NetworkError {
133            message: "Failed to reconnect after maximum attempts".to_string(),
134        }))
135    }
136
137    /// Attempt to reconnect to any available endpoint
138    async fn try_reconnect(&self) -> Result<Arc<ChainClient>, BittensorError> {
139        // Try endpoints in order of priority
140        for endpoint in self.endpoints.iter() {
141            match self.create_connection(endpoint).await {
142                Ok(client) => {
143                    let client_arc = Arc::new(client);
144
145                    // Update connection pool atomically
146                    let mut connections = self.connections.write().await;
147                    connections.clear();
148                    connections.push(Arc::clone(&client_arc));
149
150                    return Ok(client_arc);
151                }
152                Err(e) => {
153                    debug!("Failed to connect to {}: {}", endpoint, e);
154                }
155            }
156        }
157
158        Err(BittensorError::NetworkError {
159            message: "Failed to connect to any endpoint".to_string(),
160        })
161    }
162
163    /// Create a new connection to the specified endpoint
164    async fn create_connection(&self, endpoint: &str) -> Result<ChainClient, BittensorError> {
165        let timeout_duration = Duration::from_secs(30);
166
167        let is_insecure = endpoint.starts_with("ws://") || endpoint.starts_with("http://");
168
169        let result = if is_insecure {
170            debug!("Using insecure connection for endpoint: {}", endpoint);
171            tokio::time::timeout(
172                timeout_duration,
173                OnlineClient::<PolkadotConfig>::from_insecure_url(endpoint),
174            )
175            .await
176        } else {
177            tokio::time::timeout(
178                timeout_duration,
179                OnlineClient::<PolkadotConfig>::from_url(endpoint),
180            )
181            .await
182        };
183
184        result
185            .map_err(|_| BittensorError::RpcTimeoutError {
186                message: format!("Connection to {} timed out", endpoint),
187                timeout: timeout_duration,
188            })?
189            .map_err(|e| BittensorError::RpcConnectionError {
190                message: format!("Failed to connect to {}: {}", endpoint, e),
191            })
192    }
193
194    /// Get the current number of healthy connections
195    pub async fn healthy_connection_count(&self) -> usize {
196        let connections = self.connections.read().await;
197        let mut count = 0;
198
199        for conn in connections.iter() {
200            if self.health_checker.is_healthy(conn).await {
201                count += 1;
202            }
203        }
204
205        count
206    }
207
208    /// Force refresh all connections
209    pub async fn refresh_connections(&self) -> Result<(), BittensorError> {
210        info!("Refreshing all connections");
211        self.initialize().await
212    }
213
214    /// Get total number of connections (healthy and unhealthy)
215    pub async fn total_connections(&self) -> usize {
216        self.connections.read().await.len()
217    }
218}
219
220/// Builder pattern for better ergonomics
221pub struct ConnectionPoolBuilder {
222    endpoints: Vec<String>,
223    max_connections: usize,
224    retry_config: Option<RetryConfig>,
225    health_checker: Option<HealthChecker>,
226}
227
228impl ConnectionPoolBuilder {
229    pub fn new(endpoints: Vec<String>) -> Self {
230        Self {
231            endpoints,
232            max_connections: 3,
233            retry_config: None,
234            health_checker: None,
235        }
236    }
237
238    pub fn max_connections(mut self, max: usize) -> Self {
239        self.max_connections = max;
240        self
241    }
242
243    pub fn retry_config(mut self, config: RetryConfig) -> Self {
244        self.retry_config = Some(config);
245        self
246    }
247
248    pub fn health_checker(mut self, checker: HealthChecker) -> Self {
249        self.health_checker = Some(checker);
250        self
251    }
252
253    pub fn build(self) -> ConnectionPool {
254        let mut pool = ConnectionPool::new(self.endpoints, self.max_connections);
255
256        if let Some(config) = self.retry_config {
257            pool.retry_config = config;
258        }
259
260        if let Some(checker) = self.health_checker {
261            pool.health_checker = Arc::new(checker);
262        }
263
264        pool
265    }
266}
267
268// Implement the trait for health checking
269#[async_trait::async_trait]
270impl ConnectionPoolTrait for ConnectionPool {
271    async fn connections(&self) -> Arc<RwLock<Vec<Arc<ChainClient>>>> {
272        Arc::clone(&self.connections)
273    }
274
275    async fn reconnect_with_backoff(&self) -> Result<Arc<ChainClient>, BittensorError> {
276        ConnectionPool::reconnect_with_backoff(self).await
277    }
278
279    async fn get_healthy_client(&self) -> Result<Arc<ChainClient>, BittensorError> {
280        ConnectionPool::get_healthy_client(self).await
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Two-attempt, millisecond-scale retry policy so tests that target unreachable
    /// endpoints do not wait out the production network backoff.
    fn fast_retry_config() -> RetryConfig {
        RetryConfig {
            max_attempts: 2,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(20),
            backoff_multiplier: 1.5,
            jitter: false,
        }
    }

    /// Builds a pool over `endpoints` that uses `fast_retry_config`.
    fn fast_pool(endpoints: Vec<String>, max_connections: usize) -> ConnectionPool {
        ConnectionPoolBuilder::new(endpoints)
            .max_connections(max_connections)
            .retry_config(fast_retry_config())
            .build()
    }

    async fn setup_mock_server() -> MockServer {
        MockServer::start().await
    }

    #[tokio::test]
    async fn test_connection_pool_creation() {
        let endpoints = vec!["wss://test.endpoint:443".to_string()];
        let pool = ConnectionPool::new(endpoints, 3);

        assert_eq!(pool.endpoints.len(), 1);
        assert_eq!(pool.max_connections, 3);
    }

    #[tokio::test]
    async fn test_connection_pool_builder() {
        let endpoints = vec!["wss://test.endpoint:443".to_string()];
        let pool = ConnectionPoolBuilder::new(endpoints)
            .max_connections(5)
            .retry_config(fast_retry_config())
            .build();

        assert_eq!(pool.endpoints.len(), 1);
        assert_eq!(pool.max_connections, 5);
        // The builder must install the supplied retry policy, not the default one.
        assert_eq!(pool.retry_config.max_attempts, 2);
        assert_eq!(pool.retry_config.initial_delay, Duration::from_millis(10));
    }

    #[tokio::test]
    async fn test_empty_endpoints_initialization() {
        let pool = ConnectionPool::new(vec![], 3);
        let result = pool.initialize().await;

        assert!(result.is_err());
        if let Err(BittensorError::ConfigError { field, .. }) = result {
            assert_eq!(field, "endpoints");
        } else {
            panic!("Expected ConfigError");
        }
    }

    #[tokio::test]
    async fn test_zero_max_connections_initialization() {
        // A zero limit is a configuration problem and must be rejected before any network I/O.
        let pool = ConnectionPool::new(vec!["wss://test.endpoint:443".to_string()], 0);
        let result = pool.initialize().await;

        assert!(matches!(result, Err(BittensorError::ConfigError { .. })));
    }

    #[tokio::test]
    async fn test_connection_pool_initialization_with_mock() {
        let mock_server = setup_mock_server().await;

        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&mock_server)
            .await;

        // Note: This test would need actual WebSocket mocking which is complex
        // For real testing, we'd need to mock the subxt client properly
        let endpoints = vec![format!("ws://{}", mock_server.address())];
        let pool = ConnectionPool::new(endpoints, 1);

        // This will fail as we can't easily mock WebSocket connections
        // In production, you'd use integration tests or more sophisticated mocking
        let result = pool.initialize().await;
        assert!(result.is_err()); // Expected as we can't mock WS properly
    }

    #[tokio::test]
    async fn test_healthy_connection_count() {
        let pool = ConnectionPool::new(vec!["wss://test.endpoint:443".to_string()], 3);
        let count = pool.healthy_connection_count().await;
        assert_eq!(count, 0); // No connections established yet
    }

    #[tokio::test]
    async fn test_total_connections() {
        let pool = ConnectionPool::new(vec!["wss://test.endpoint:443".to_string()], 3);
        let count = pool.total_connections().await;
        assert_eq!(count, 0); // No connections established yet
    }

    #[tokio::test]
    async fn test_get_healthy_client_no_connections() {
        // With no connections the pool falls back to reconnecting; use a fast retry policy so
        // the test does not wait out the production backoff schedule.
        let pool = fast_pool(vec!["wss://invalid.endpoint:443".to_string()], 1);
        let result = pool.get_healthy_client().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_reconnect_with_backoff() {
        let pool = fast_pool(vec!["wss://invalid.endpoint:443".to_string()], 1);

        let result = pool.reconnect_with_backoff().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_multiple_endpoints_fallback() {
        let endpoints = vec![
            "wss://invalid1.endpoint:443".to_string(),
            "wss://invalid2.endpoint:443".to_string(),
            "wss://invalid3.endpoint:443".to_string(),
        ];

        let pool = ConnectionPool::new(endpoints, 3);
        let result = pool.try_reconnect().await;
        assert!(result.is_err()); // All endpoints are invalid
    }

    #[tokio::test]
    async fn test_create_connection_timeout() {
        let pool = ConnectionPool::new(vec!["wss://10.255.255.1:443".to_string()], 1);

        // This IP should not be routable, causing a timeout or connection error
        let result = pool.create_connection("wss://10.255.255.1:443").await;
        assert!(result.is_err());

        match result {
            Err(BittensorError::RpcTimeoutError { .. })
            | Err(BittensorError::RpcConnectionError { .. }) => {
                // Expected - either timeout or connection error is acceptable in CI environments
            }
            Err(e) => {
                panic!(
                    "Expected RpcTimeoutError or RpcConnectionError, got: {:?}",
                    e
                );
            }
            Ok(_) => panic!("Expected error but got Ok"),
        }
    }
}