use crate::config::BittensorConfig;
use crate::error::BittensorError;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use subxt::{OnlineClient, PolkadotConfig};
use tokio::sync::RwLock;
use tokio::time::Instant;
use tracing::{debug, info, warn};

type ChainClient = OnlineClient<PolkadotConfig>;

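/// Current state of the connection to the chain endpoint.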
#[derive(Debug, Clone)]
pub enum ConnectionState {
    Connected { since: Instant, endpoint: String },
    Reconnecting {
        attempts: u32,
        since: Instant,
        last_error: Option<String>,
    },
    Failed {
        error: String,
        at: Instant,
        consecutive_failures: u32,
    },
    Uninitialized,
}

impl ConnectionState {
    pub fn is_healthy(&self) -> bool {
        matches!(self, ConnectionState::Connected { .. })
    }

    pub fn status_message(&self) -> String {
        match self {
            ConnectionState::Connected { since, endpoint } => {
                format!("Connected to {} (uptime: {:?})", endpoint, since.elapsed())
            }
            ConnectionState::Reconnecting {
                attempts,
                since,
                last_error,
            } => {
                let error_msg = last_error.as_deref().unwrap_or("unknown");
                format!(
                    "Reconnecting (attempt {}, elapsed: {:?}, last error: {})",
                    attempts,
                    since.elapsed(),
                    error_msg
                )
            }
            ConnectionState::Failed {
                error,
                at,
                consecutive_failures,
            } => {
                format!(
                    "Failed {} times (last: {:?} ago): {}",
                    consecutive_failures,
                    at.elapsed(),
                    error
                )
            }
            ConnectionState::Uninitialized => "Not initialized".to_string(),
        }
    }
}

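/// Manages the chain client connection: establishes it, tracks its state,
/// reconnects with exponential backoff, and records basic metrics.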
pub struct ConnectionManager {
    state: Arc<RwLock<ConnectionState>>,
    client: Arc<RwLock<Option<Arc<ChainClient>>>>,
    config: BittensorConfig,
    metrics: Arc<ConnectionMetrics>,
    #[doc(hidden)]
    pub max_consecutive_failures: u32,
}

impl ConnectionManager {
    pub fn new(config: BittensorConfig) -> Self {
        Self {
            state: Arc::new(RwLock::new(ConnectionState::Uninitialized)),
            client: Arc::new(RwLock::new(None)),
            config,
            metrics: Arc::new(ConnectionMetrics::new()),
            max_consecutive_failures: 10,
        }
    }

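    /// Establishes the initial connection, updating state and metrics on both
    /// success and failure.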
    pub async fn connect(&self) -> Result<(), BittensorError> {
        self.update_state(ConnectionState::Reconnecting {
            attempts: 1,
            since: Instant::now(),
            last_error: None,
        })
        .await;

        match self.establish_connection().await {
            Ok((client, endpoint)) => {
                *self.client.write().await = Some(Arc::new(client));

                self.update_state(ConnectionState::Connected {
                    since: Instant::now(),
                    endpoint: endpoint.clone(),
                })
                .await;

                self.metrics.record_connection_success();
                info!("Successfully connected to {}", endpoint);
                Ok(())
            }
            Err(e) => {
                let error_msg = e.to_string();

                self.update_state(ConnectionState::Failed {
                    error: error_msg.clone(),
                    at: Instant::now(),
                    consecutive_failures: 1,
                })
                .await;

                self.metrics.record_connection_failure();
                Err(e)
            }
        }
    }

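    /// Returns the current client, triggering a reconnect when the state
    /// allows one; otherwise returns a `ServiceUnavailable` error describing
    /// when a retry will be possible.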
    pub async fn get_client(&self) -> Result<Arc<ChainClient>, BittensorError> {
        // Work on a clone of the state so the read lock is not held across
        // the awaits below.
        let state = self.state.read().await.clone();

        match state {
            ConnectionState::Connected { .. } => {
                self.client.read().await.as_ref().cloned().ok_or_else(|| {
                    BittensorError::ServiceUnavailable {
                        message: "Client not initialized despite connected state".to_string(),
                    }
                })
            }
            ConnectionState::Reconnecting {
                attempts, since, ..
            } => {
                if since.elapsed() > Duration::from_secs(30) {
                    drop(state);
                    self.reconnect_with_backoff().await
                } else {
                    Err(BittensorError::ServiceUnavailable {
                        message: format!("Reconnecting (attempt {})", attempts),
                    })
                }
            }
            ConnectionState::Failed {
                at,
                consecutive_failures,
                ..
            } => {
                let retry_delay = self.calculate_retry_delay(consecutive_failures);

                if at.elapsed() > retry_delay {
                    drop(state);
                    self.reconnect_with_backoff().await
                } else {
                    Err(BittensorError::ServiceUnavailable {
                        message: format!(
                            "Connection failed, retry in {:?}",
                            retry_delay.saturating_sub(at.elapsed())
                        ),
                    })
                }
            }
            ConnectionState::Uninitialized => {
                drop(state);
                self.connect().await?;
                Box::pin(self.get_client()).await
            }
        }
    }

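    /// Retries the connection with exponential backoff, giving up after three
    /// attempts or once `max_consecutive_failures` is exceeded.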
    #[doc(hidden)]
    pub async fn reconnect_with_backoff(&self) -> Result<Arc<ChainClient>, BittensorError> {
        let mut attempts = 0u32;
        let mut consecutive_failures = self.get_consecutive_failures().await;

        loop {
            attempts += 1;
            consecutive_failures += 1;

            if consecutive_failures > self.max_consecutive_failures {
                return Err(BittensorError::NetworkError {
                    message: format!(
                        "Maximum consecutive failures ({}) exceeded",
                        self.max_consecutive_failures
                    ),
                });
            }

            self.update_state(ConnectionState::Reconnecting {
                attempts,
                since: Instant::now(),
                last_error: None,
            })
            .await;

            match self.establish_connection().await {
                Ok((client, endpoint)) => {
                    let client_arc = Arc::new(client);
                    *self.client.write().await = Some(Arc::clone(&client_arc));

                    self.update_state(ConnectionState::Connected {
                        since: Instant::now(),
                        endpoint,
                    })
                    .await;

                    self.metrics.record_connection_success();
                    return Ok(client_arc);
                }
                Err(e) => {
                    let error_msg = e.to_string();
                    warn!("Reconnection attempt {} failed: {}", attempts, error_msg);

                    self.update_state(ConnectionState::Failed {
                        error: error_msg,
                        at: Instant::now(),
                        consecutive_failures,
                    })
                    .await;

                    self.metrics.record_connection_failure();

                    if attempts >= 3 {
                        return Err(e);
                    }

                    let delay = self.calculate_retry_delay(attempts);
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }

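    /// Tries each configured endpoint in order, using an insecure connection
    /// for `ws://`/`http://` URLs and a 30-second timeout per attempt.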
    async fn establish_connection(&self) -> Result<(ChainClient, String), BittensorError> {
        let endpoints = self.config.get_chain_endpoints();

        for (idx, endpoint) in endpoints.iter().enumerate() {
            debug!(
                "Trying endpoint {}/{}: {}",
                idx + 1,
                endpoints.len(),
                endpoint
            );

            let timeout_duration = Duration::from_secs(30);

            let is_insecure = endpoint.starts_with("ws://") || endpoint.starts_with("http://");

            let result = if is_insecure {
                debug!("Using insecure connection for endpoint: {}", endpoint);
                tokio::time::timeout(
                    timeout_duration,
                    OnlineClient::<PolkadotConfig>::from_insecure_url(endpoint),
                )
                .await
            } else {
                tokio::time::timeout(
                    timeout_duration,
                    OnlineClient::<PolkadotConfig>::from_url(endpoint),
                )
                .await
            };

            match result {
                Ok(Ok(client)) => {
                    info!("Successfully connected to {}", endpoint);
                    return Ok((client, endpoint.to_string()));
                }
                Ok(Err(e)) => {
                    warn!("Failed to connect to {}: {}", endpoint, e);
                }
                Err(_) => {
                    warn!(
                        "Connection to {} timed out after {:?}",
                        endpoint, timeout_duration
                    );
                }
            }

            if idx < endpoints.len() - 1 {
                tokio::time::sleep(Duration::from_millis(500)).await;
            }
        }

        Err(BittensorError::NetworkError {
            message: "Failed to connect to any endpoint".to_string(),
        })
    }

    #[doc(hidden)]
    pub async fn update_state(&self, new_state: ConnectionState) {
        *self.state.write().await = new_state;
    }

    async fn get_consecutive_failures(&self) -> u32 {
        match &*self.state.read().await {
            ConnectionState::Failed {
                consecutive_failures,
                ..
            } => *consecutive_failures,
            _ => 0,
        }
    }

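    /// Exponential backoff delay: 1s, 2s, 4s, ... capped at 60 seconds.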
    fn calculate_retry_delay(&self, attempt: u32) -> Duration {
        let base_delay = Duration::from_secs(1);
        let max_delay = Duration::from_secs(60);

        // Clamp the exponent so the multiplication cannot overflow; 2^6 seconds
        // already exceeds the 60-second cap, so small attempts are unaffected.
        let exponent = attempt.saturating_sub(1).min(6);
        let exponential_delay = base_delay * 2u32.pow(exponent);
        exponential_delay.min(max_delay)
    }

    pub async fn get_state(&self) -> ConnectionState {
        self.state.read().await.clone()
    }

    pub fn metrics(&self) -> ConnectionMetricsSnapshot {
        self.metrics.snapshot()
    }

    pub async fn force_reconnect(&self) -> Result<(), BittensorError> {
        info!("Forcing reconnection");
        self.update_state(ConnectionState::Uninitialized).await;
        self.connect().await
    }

    pub async fn is_connected(&self) -> bool {
        self.state.read().await.is_healthy()
    }
}

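/// Internal counters tracking connection successes and failures.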
struct ConnectionMetrics {
    success_count: AtomicU64,
    failure_count: AtomicU64,
    total_reconnects: AtomicU64,
    last_success: Arc<RwLock<Option<Instant>>>,
    last_failure: Arc<RwLock<Option<Instant>>>,
}

impl ConnectionMetrics {
    fn new() -> Self {
        Self {
            success_count: AtomicU64::new(0),
            failure_count: AtomicU64::new(0),
            total_reconnects: AtomicU64::new(0),
            last_success: Arc::new(RwLock::new(None)),
            last_failure: Arc::new(RwLock::new(None)),
        }
    }

    fn record_connection_success(&self) {
        self.success_count.fetch_add(1, Ordering::Relaxed);
        // This method is synchronous, so the timestamp behind the async lock is
        // written from a spawned task instead of being awaited here.
        let last_success = Arc::clone(&self.last_success);
        tokio::spawn(async move {
            *last_success.write().await = Some(Instant::now());
        });
    }

    fn record_connection_failure(&self) {
        self.failure_count.fetch_add(1, Ordering::Relaxed);
        self.total_reconnects.fetch_add(1, Ordering::Relaxed);
        let last_failure = Arc::clone(&self.last_failure);
        tokio::spawn(async move {
            *last_failure.write().await = Some(Instant::now());
        });
    }

    fn snapshot(&self) -> ConnectionMetricsSnapshot {
        ConnectionMetricsSnapshot {
            success_count: self.success_count.load(Ordering::Relaxed),
            failure_count: self.failure_count.load(Ordering::Relaxed),
            total_reconnects: self.total_reconnects.load(Ordering::Relaxed),
            success_rate: self.calculate_success_rate(),
        }
    }

    fn calculate_success_rate(&self) -> f64 {
        let successes = self.success_count.load(Ordering::Relaxed) as f64;
        let failures = self.failure_count.load(Ordering::Relaxed) as f64;
        let total = successes + failures;

        if total == 0.0 {
            100.0
        } else {
            (successes / total) * 100.0
        }
    }
}

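/// Point-in-time view of the connection metrics.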
#[derive(Debug, Clone)]
pub struct ConnectionMetricsSnapshot {
    pub success_count: u64,
    pub failure_count: u64,
    pub total_reconnects: u64,
    pub success_rate: f64,
}

#[cfg(test)]
mod tests {
    use super::*;

    fn test_config() -> BittensorConfig {
        BittensorConfig {
            network: "local".to_string(),
            chain_endpoint: Some("wss://test.endpoint:443".to_string()),
            wallet_name: "test_wallet".to_string(),
            hotkey_name: "test_hotkey".to_string(),
            netuid: 1,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_connection_state_initialization() {
        let manager = ConnectionManager::new(test_config());
        let state = manager.get_state().await;
        assert!(matches!(state, ConnectionState::Uninitialized));
    }

    #[tokio::test]
    async fn test_connection_state_is_healthy() {
        let state = ConnectionState::Connected {
            since: Instant::now(),
            endpoint: "test".to_string(),
        };
        assert!(state.is_healthy());

        let state = ConnectionState::Failed {
            error: "error".to_string(),
            at: Instant::now(),
            consecutive_failures: 1,
        };
        assert!(!state.is_healthy());

        let state = ConnectionState::Reconnecting {
            attempts: 1,
            since: Instant::now(),
            last_error: None,
        };
        assert!(!state.is_healthy());

        let state = ConnectionState::Uninitialized;
        assert!(!state.is_healthy());
    }

    #[tokio::test]
    async fn test_status_message() {
        let state = ConnectionState::Connected {
            since: Instant::now(),
            endpoint: "wss://test:443".to_string(),
        };
        let msg = state.status_message();
        assert!(msg.contains("Connected to wss://test:443"));

        let state = ConnectionState::Failed {
            error: "connection refused".to_string(),
            at: Instant::now(),
            consecutive_failures: 3,
        };
        let msg = state.status_message();
        assert!(msg.contains("Failed 3 times"));
        assert!(msg.contains("connection refused"));

        let state = ConnectionState::Reconnecting {
            attempts: 2,
            since: Instant::now(),
            last_error: Some("timeout".to_string()),
        };
        let msg = state.status_message();
        assert!(msg.contains("attempt 2"));
        assert!(msg.contains("timeout"));

        let state = ConnectionState::Uninitialized;
        assert_eq!(state.status_message(), "Not initialized");
    }

    #[tokio::test]
    async fn test_calculate_retry_delay() {
        let manager = ConnectionManager::new(test_config());

        let delay1 = manager.calculate_retry_delay(1);
        assert_eq!(delay1, Duration::from_secs(1));

        let delay2 = manager.calculate_retry_delay(2);
        assert_eq!(delay2, Duration::from_secs(2));

        let delay3 = manager.calculate_retry_delay(3);
        assert_eq!(delay3, Duration::from_secs(4));

        let delay4 = manager.calculate_retry_delay(4);
        assert_eq!(delay4, Duration::from_secs(8));

        let delay_max = manager.calculate_retry_delay(10);
        assert_eq!(delay_max, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_get_consecutive_failures() {
        let manager = ConnectionManager::new(test_config());

        let failures = manager.get_consecutive_failures().await;
        assert_eq!(failures, 0);

        manager
            .update_state(ConnectionState::Failed {
                error: "test".to_string(),
                at: Instant::now(),
                consecutive_failures: 5,
            })
            .await;

        let failures = manager.get_consecutive_failures().await;
        assert_eq!(failures, 5);

        manager
            .update_state(ConnectionState::Connected {
                since: Instant::now(),
                endpoint: "test".to_string(),
            })
            .await;

        let failures = manager.get_consecutive_failures().await;
        assert_eq!(failures, 0);
    }

    #[tokio::test]
    async fn test_is_connected() {
        let manager = ConnectionManager::new(test_config());

        assert!(!manager.is_connected().await);

        manager
            .update_state(ConnectionState::Connected {
                since: Instant::now(),
                endpoint: "test".to_string(),
            })
            .await;

        assert!(manager.is_connected().await);
    }

    #[tokio::test]
    async fn test_metrics_calculation() {
        let metrics = ConnectionMetrics::new();

        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.success_count, 0);
        assert_eq!(snapshot.failure_count, 0);
        assert_eq!(snapshot.total_reconnects, 0);
        assert_eq!(snapshot.success_rate, 100.0);

        metrics.success_count.store(7, Ordering::Relaxed);
        metrics.failure_count.store(3, Ordering::Relaxed);
        metrics.total_reconnects.store(3, Ordering::Relaxed);

        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.success_count, 7);
        assert_eq!(snapshot.failure_count, 3);
        assert_eq!(snapshot.total_reconnects, 3);
        assert!((snapshot.success_rate - 70.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_connection_manager_get_client_uninitialized() {
        let manager = ConnectionManager::new(test_config());

        // The test endpoint is not reachable, so the implicit connect fails.
        let result = manager.get_client().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_max_consecutive_failures() {
        let mut manager = ConnectionManager::new(test_config());
        manager.max_consecutive_failures = 2;

        manager
            .update_state(ConnectionState::Failed {
                error: "test".to_string(),
                at: Instant::now(),
                consecutive_failures: 3,
            })
            .await;

        let result = manager.reconnect_with_backoff().await;
        assert!(result.is_err());

        if let Err(BittensorError::NetworkError { message }) = result {
            assert!(message.contains("Maximum consecutive failures"));
        } else {
            panic!("Expected NetworkError with max failures message");
        }
    }

    #[tokio::test]
    async fn test_state_transitions() {
        let manager = ConnectionManager::new(test_config());

        manager
            .update_state(ConnectionState::Reconnecting {
                attempts: 1,
                since: Instant::now(),
                last_error: None,
            })
            .await;

        let state = manager.get_state().await;
        assert!(matches!(state, ConnectionState::Reconnecting { .. }));

        manager
            .update_state(ConnectionState::Failed {
                error: "error".to_string(),
                at: Instant::now(),
                consecutive_failures: 1,
            })
            .await;

        let state = manager.get_state().await;
        assert!(matches!(state, ConnectionState::Failed { .. }));

        manager
            .update_state(ConnectionState::Connected {
                since: Instant::now(),
                endpoint: "endpoint".to_string(),
            })
            .await;

        let state = manager.get_state().await;
        assert!(matches!(state, ConnectionState::Connected { .. }));
    }
}