bittensor_rs/connect/state.rs

//! Connection state management for observability and debugging

use crate::config::BittensorConfig;
use crate::error::BittensorError;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use subxt::{OnlineClient, PolkadotConfig};
use tokio::sync::RwLock;
use tokio::time::Instant;
use tracing::{debug, info, warn};

type ChainClient = OnlineClient<PolkadotConfig>;

/// Connection state for monitoring and debugging
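///
/// Typical lifecycle: `Uninitialized` -> `Reconnecting` -> `Connected`,
/// dropping to `Failed` when an attempt errors and returning to
/// `Reconnecting` once the retry delay has elapsed.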
#[derive(Debug, Clone)]
pub enum ConnectionState {
    /// Successfully connected
    Connected { since: Instant, endpoint: String },
    /// Currently attempting to reconnect
    Reconnecting {
        attempts: u32,
        since: Instant,
        last_error: Option<String>,
    },
    /// Connection failed
    Failed {
        error: String,
        at: Instant,
        consecutive_failures: u32,
    },
    /// Not yet initialized
    Uninitialized,
}

impl ConnectionState {
    /// Check if the connection is healthy
    pub fn is_healthy(&self) -> bool {
        matches!(self, ConnectionState::Connected { .. })
    }

    /// Get a human-readable status message
    pub fn status_message(&self) -> String {
        match self {
            ConnectionState::Connected { since, endpoint } => {
                format!("Connected to {} (uptime: {:?})", endpoint, since.elapsed())
            }
            ConnectionState::Reconnecting {
                attempts,
                since,
                last_error,
            } => {
                let error_msg = last_error.as_deref().unwrap_or("unknown");
                format!(
                    "Reconnecting (attempt {}, elapsed: {:?}, last error: {})",
                    attempts,
                    since.elapsed(),
                    error_msg
                )
            }
            ConnectionState::Failed {
                error,
                at,
                consecutive_failures,
            } => {
                format!(
                    "Failed {} times (last: {:?} ago): {}",
                    consecutive_failures,
                    at.elapsed(),
                    error
                )
            }
            ConnectionState::Uninitialized => "Not initialized".to_string(),
        }
    }
}

/// Manages connection lifecycle and state transitions
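///
/// A minimal usage sketch (the `config` value is illustrative):
///
/// ```ignore
/// let manager = ConnectionManager::new(config);
/// manager.connect().await?;
/// let client = manager.get_client().await?;
/// ```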
pub struct ConnectionManager {
    state: Arc<RwLock<ConnectionState>>,
    client: Arc<RwLock<Option<Arc<ChainClient>>>>,
    config: BittensorConfig,
    metrics: Arc<ConnectionMetrics>,
    #[doc(hidden)]
    pub max_consecutive_failures: u32,
}

impl ConnectionManager {
    pub fn new(config: BittensorConfig) -> Self {
        Self {
            state: Arc::new(RwLock::new(ConnectionState::Uninitialized)),
            client: Arc::new(RwLock::new(None)),
            config,
            metrics: Arc::new(ConnectionMetrics::new()),
            max_consecutive_failures: 10,
        }
    }

    /// Establish initial connection
    pub async fn connect(&self) -> Result<(), BittensorError> {
        self.update_state(ConnectionState::Reconnecting {
            attempts: 1,
            since: Instant::now(),
            last_error: None,
        })
        .await;

        match self.establish_connection().await {
            Ok((client, endpoint)) => {
                *self.client.write().await = Some(Arc::new(client));

                self.update_state(ConnectionState::Connected {
                    since: Instant::now(),
                    endpoint: endpoint.clone(),
                })
                .await;

                self.metrics.record_connection_success();
                info!("Successfully connected to {}", endpoint);
                Ok(())
            }
            Err(e) => {
                let error_msg = e.to_string();

                self.update_state(ConnectionState::Failed {
                    error: error_msg.clone(),
                    at: Instant::now(),
                    consecutive_failures: 1,
                })
                .await;

                self.metrics.record_connection_failure();
                Err(e)
            }
        }
    }

    /// Get client with automatic reconnection
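    ///
    /// Behaviour by state:
    /// - `Connected`: returns the cached client.
    /// - `Reconnecting`: errors unless the attempt has been stuck for more
    ///   than 30 seconds, in which case a fresh reconnect is started.
    /// - `Failed`: retries once the backoff delay has elapsed; otherwise
    ///   errors with the remaining wait time.
    /// - `Uninitialized`: connects first, then retries this call.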
    pub async fn get_client(&self) -> Result<Arc<ChainClient>, BittensorError> {
        let state = self.state.read().await.clone();

        match state {
            ConnectionState::Connected { .. } => {
                // Fast path: already connected
                self.client.read().await.as_ref().cloned().ok_or_else(|| {
                    BittensorError::ServiceUnavailable {
                        message: "Client not initialized despite connected state".to_string(),
                    }
                })
            }
            ConnectionState::Reconnecting {
                attempts, since, ..
            } => {
                // Wait for ongoing reconnection or trigger new one
                if since.elapsed() > Duration::from_secs(30) {
                    drop(state);
                    self.reconnect_with_backoff().await
                } else {
                    Err(BittensorError::ServiceUnavailable {
                        message: format!("Reconnecting (attempt {})", attempts),
                    })
                }
            }
            ConnectionState::Failed {
                at,
                consecutive_failures,
                ..
            } => {
                // Retry after a delay
                let retry_delay = self.calculate_retry_delay(consecutive_failures);

                if at.elapsed() > retry_delay {
                    drop(state);
                    self.reconnect_with_backoff().await
                } else {
                    Err(BittensorError::ServiceUnavailable {
                        message: format!(
                            "Connection failed, retry in {:?}",
                            retry_delay.saturating_sub(at.elapsed())
                        ),
                    })
                }
            }
            ConnectionState::Uninitialized => {
                drop(state);
                self.connect().await?;
                Box::pin(self.get_client()).await
            }
        }
    }

    /// Reconnect with exponential backoff
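    ///
    /// Gives up with the last error after three attempts within a single
    /// call, or with a `NetworkError` once the consecutive-failure count
    /// exceeds `max_consecutive_failures`.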
    #[doc(hidden)]
    pub async fn reconnect_with_backoff(&self) -> Result<Arc<ChainClient>, BittensorError> {
        let mut attempts = 0u32;
        let mut consecutive_failures = self.get_consecutive_failures().await;

        loop {
            attempts += 1;
            consecutive_failures += 1;

            if consecutive_failures > self.max_consecutive_failures {
                return Err(BittensorError::NetworkError {
                    message: format!(
                        "Maximum consecutive failures ({}) exceeded",
                        self.max_consecutive_failures
                    ),
                });
            }

            self.update_state(ConnectionState::Reconnecting {
                attempts,
                since: Instant::now(),
                last_error: None,
            })
            .await;

            match self.establish_connection().await {
                Ok((client, endpoint)) => {
                    let client_arc = Arc::new(client);
                    *self.client.write().await = Some(Arc::clone(&client_arc));

                    self.update_state(ConnectionState::Connected {
                        since: Instant::now(),
                        endpoint,
                    })
                    .await;

                    self.metrics.record_connection_success();
                    return Ok(client_arc);
                }
                Err(e) => {
                    let error_msg = e.to_string();
                    warn!("Reconnection attempt {} failed: {}", attempts, error_msg);

                    self.update_state(ConnectionState::Failed {
                        error: error_msg,
                        at: Instant::now(),
                        consecutive_failures,
                    })
                    .await;

                    self.metrics.record_connection_failure();

                    if attempts >= 3 {
                        return Err(e);
                    }

                    let delay = self.calculate_retry_delay(attempts);
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }

    /// Establish connection to any available endpoint
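    ///
    /// Endpoints from `BittensorConfig::get_chain_endpoints` are tried in
    /// order, each with a 30-second timeout. `ws://` and `http://` URLs are
    /// treated as insecure and connected via `from_insecure_url`.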
    async fn establish_connection(&self) -> Result<(ChainClient, String), BittensorError> {
        let endpoints = self.config.get_chain_endpoints();

        for (idx, endpoint) in endpoints.iter().enumerate() {
            debug!(
                "Trying endpoint {}/{}: {}",
                idx + 1,
                endpoints.len(),
                endpoint
            );

            let timeout_duration = Duration::from_secs(30);

            let is_insecure = endpoint.starts_with("ws://") || endpoint.starts_with("http://");

            let result = if is_insecure {
                debug!("Using insecure connection for endpoint: {}", endpoint);
                tokio::time::timeout(
                    timeout_duration,
                    OnlineClient::<PolkadotConfig>::from_insecure_url(endpoint),
                )
                .await
            } else {
                tokio::time::timeout(
                    timeout_duration,
                    OnlineClient::<PolkadotConfig>::from_url(endpoint),
                )
                .await
            };

            match result {
                Ok(Ok(client)) => {
                    info!("Successfully connected to {}", endpoint);
                    return Ok((client, endpoint.to_string()));
                }
                Ok(Err(e)) => {
                    warn!("Failed to connect to {}: {}", endpoint, e);
                }
                Err(_) => {
                    warn!(
                        "Connection to {} timed out after {:?}",
                        endpoint, timeout_duration
                    );
                }
            }

            // Small delay between endpoint attempts
            if idx < endpoints.len() - 1 {
                tokio::time::sleep(Duration::from_millis(500)).await;
            }
        }

        Err(BittensorError::NetworkError {
            message: "Failed to connect to any endpoint".to_string(),
        })
    }

    /// Update connection state
    #[doc(hidden)]
    pub async fn update_state(&self, new_state: ConnectionState) {
        *self.state.write().await = new_state;
    }

    /// Get current consecutive failures count
    async fn get_consecutive_failures(&self) -> u32 {
        match &*self.state.read().await {
            ConnectionState::Failed {
                consecutive_failures,
                ..
            } => *consecutive_failures,
            _ => 0,
        }
    }

    /// Calculate retry delay based on attempt number
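    ///
    /// The delay doubles with each attempt, starting at one second and
    /// capped at one minute: `min(1s * 2^(attempt - 1), 60s)`, i.e.
    /// 1s, 2s, 4s, 8s, ..., 60s.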
    fn calculate_retry_delay(&self, attempt: u32) -> Duration {
        let base_delay = Duration::from_secs(1);
        let max_delay = Duration::from_secs(60);

        let exponential_delay = base_delay * 2u32.pow(attempt.saturating_sub(1));
        exponential_delay.min(max_delay)
    }

    /// Get current connection state
    pub async fn get_state(&self) -> ConnectionState {
        self.state.read().await.clone()
    }

    /// Get connection metrics
    pub fn metrics(&self) -> ConnectionMetricsSnapshot {
        self.metrics.snapshot()
    }

    /// Force reconnection
    pub async fn force_reconnect(&self) -> Result<(), BittensorError> {
        info!("Forcing reconnection");
        self.update_state(ConnectionState::Uninitialized).await;
        self.connect().await
    }

    /// Check if connected
    pub async fn is_connected(&self) -> bool {
        self.state.read().await.is_healthy()
    }
}

/// Connection metrics for monitoring
struct ConnectionMetrics {
    success_count: AtomicU64,
    failure_count: AtomicU64,
    total_reconnects: AtomicU64,
    last_success: Arc<RwLock<Option<Instant>>>,
    last_failure: Arc<RwLock<Option<Instant>>>,
}

impl ConnectionMetrics {
    fn new() -> Self {
        Self {
            success_count: AtomicU64::new(0),
            failure_count: AtomicU64::new(0),
            total_reconnects: AtomicU64::new(0),
            last_success: Arc::new(RwLock::new(None)),
            last_failure: Arc::new(RwLock::new(None)),
        }
    }

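    // Note: the record_* methods below write their timestamps from a
    // spawned task, so they must be called from within a Tokio runtime.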
    fn record_connection_success(&self) {
        self.success_count.fetch_add(1, Ordering::Relaxed);
        let last_success = Arc::clone(&self.last_success);
        tokio::spawn(async move {
            *last_success.write().await = Some(Instant::now());
        });
    }

    fn record_connection_failure(&self) {
        self.failure_count.fetch_add(1, Ordering::Relaxed);
        self.total_reconnects.fetch_add(1, Ordering::Relaxed);
        let last_failure = Arc::clone(&self.last_failure);
        tokio::spawn(async move {
            *last_failure.write().await = Some(Instant::now());
        });
    }

    fn snapshot(&self) -> ConnectionMetricsSnapshot {
        ConnectionMetricsSnapshot {
            success_count: self.success_count.load(Ordering::Relaxed),
            failure_count: self.failure_count.load(Ordering::Relaxed),
            total_reconnects: self.total_reconnects.load(Ordering::Relaxed),
            success_rate: self.calculate_success_rate(),
        }
    }

    fn calculate_success_rate(&self) -> f64 {
        let successes = self.success_count.load(Ordering::Relaxed) as f64;
        let failures = self.failure_count.load(Ordering::Relaxed) as f64;
        let total = successes + failures;

        if total == 0.0 {
            100.0
        } else {
            (successes / total) * 100.0
        }
    }
}

/// Snapshot of connection metrics
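///
/// `success_rate` is a percentage in `[0.0, 100.0]`; it reports 100.0 when
/// no attempts have been recorded yet.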
#[derive(Debug, Clone)]
pub struct ConnectionMetricsSnapshot {
    pub success_count: u64,
    pub failure_count: u64,
    pub total_reconnects: u64,
    pub success_rate: f64,
}

#[cfg(test)]
mod tests {
    use super::*;

    fn test_config() -> BittensorConfig {
        // Use a bogus endpoint and the "local" network so no default
        // fallbacks are appended. This ensures the connection attempt
        // fails even in environments with network access.
        BittensorConfig {
            network: "local".to_string(),
            chain_endpoint: Some("wss://test.endpoint:443".to_string()),
            wallet_name: "test_wallet".to_string(),
            hotkey_name: "test_hotkey".to_string(),
            netuid: 1,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_connection_state_initialization() {
        let manager = ConnectionManager::new(test_config());
        let state = manager.get_state().await;
        assert!(matches!(state, ConnectionState::Uninitialized));
    }

    #[tokio::test]
    async fn test_connection_state_is_healthy() {
        let state = ConnectionState::Connected {
            since: Instant::now(),
            endpoint: "test".to_string(),
        };
        assert!(state.is_healthy());

        let state = ConnectionState::Failed {
            error: "error".to_string(),
            at: Instant::now(),
            consecutive_failures: 1,
        };
        assert!(!state.is_healthy());

        let state = ConnectionState::Reconnecting {
            attempts: 1,
            since: Instant::now(),
            last_error: None,
        };
        assert!(!state.is_healthy());

        let state = ConnectionState::Uninitialized;
        assert!(!state.is_healthy());
    }

    #[tokio::test]
    async fn test_status_message() {
        let state = ConnectionState::Connected {
            since: Instant::now(),
            endpoint: "wss://test:443".to_string(),
        };
        let msg = state.status_message();
        assert!(msg.contains("Connected to wss://test:443"));

        let state = ConnectionState::Failed {
            error: "connection refused".to_string(),
            at: Instant::now(),
            consecutive_failures: 3,
        };
        let msg = state.status_message();
        assert!(msg.contains("Failed 3 times"));
        assert!(msg.contains("connection refused"));

        let state = ConnectionState::Reconnecting {
            attempts: 2,
            since: Instant::now(),
            last_error: Some("timeout".to_string()),
        };
        let msg = state.status_message();
        assert!(msg.contains("attempt 2"));
        assert!(msg.contains("timeout"));

        let state = ConnectionState::Uninitialized;
        assert_eq!(state.status_message(), "Not initialized");
    }

    #[tokio::test]
    async fn test_calculate_retry_delay() {
        let manager = ConnectionManager::new(test_config());

        let delay1 = manager.calculate_retry_delay(1);
        assert_eq!(delay1, Duration::from_secs(1));

        let delay2 = manager.calculate_retry_delay(2);
        assert_eq!(delay2, Duration::from_secs(2));

        let delay3 = manager.calculate_retry_delay(3);
        assert_eq!(delay3, Duration::from_secs(4));

        let delay4 = manager.calculate_retry_delay(4);
        assert_eq!(delay4, Duration::from_secs(8));

        // Test max delay cap
        let delay_max = manager.calculate_retry_delay(10);
        assert_eq!(delay_max, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_get_consecutive_failures() {
        let manager = ConnectionManager::new(test_config());

        // Initially should be 0
        let failures = manager.get_consecutive_failures().await;
        assert_eq!(failures, 0);

        // Update to failed state
        manager
            .update_state(ConnectionState::Failed {
                error: "test".to_string(),
                at: Instant::now(),
                consecutive_failures: 5,
            })
            .await;

        let failures = manager.get_consecutive_failures().await;
        assert_eq!(failures, 5);

        // Update to connected state
        manager
            .update_state(ConnectionState::Connected {
                since: Instant::now(),
                endpoint: "test".to_string(),
            })
            .await;

        let failures = manager.get_consecutive_failures().await;
        assert_eq!(failures, 0);
    }

    #[tokio::test]
    async fn test_is_connected() {
        let manager = ConnectionManager::new(test_config());

        // Initially not connected
        assert!(!manager.is_connected().await);

        // Update to connected state
        manager
            .update_state(ConnectionState::Connected {
                since: Instant::now(),
                endpoint: "test".to_string(),
            })
            .await;

        assert!(manager.is_connected().await);
    }

    #[tokio::test]
    async fn test_metrics_calculation() {
        let metrics = ConnectionMetrics::new();

        // Initial state
        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.success_count, 0);
        assert_eq!(snapshot.failure_count, 0);
        assert_eq!(snapshot.total_reconnects, 0);
        assert_eq!(snapshot.success_rate, 100.0);

        // Record some successes and failures
        metrics.success_count.store(7, Ordering::Relaxed);
        metrics.failure_count.store(3, Ordering::Relaxed);
        metrics.total_reconnects.store(3, Ordering::Relaxed);

        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.success_count, 7);
        assert_eq!(snapshot.failure_count, 3);
        assert_eq!(snapshot.total_reconnects, 3);
        assert!((snapshot.success_rate - 70.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_connection_manager_get_client_uninitialized() {
        let manager = ConnectionManager::new(test_config());

        // Getting client when uninitialized should try to connect
        let result = manager.get_client().await;
        assert!(result.is_err()); // Will fail as we don't have a real endpoint
    }
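
    #[tokio::test]
    async fn test_force_reconnect_fails_with_bogus_endpoint() {
        // Companion to the test above: this relies on the bogus endpoint
        // from test_config() being unreachable, so forcing a reconnect
        // should reset state and then fail to connect.
        let manager = ConnectionManager::new(test_config());
        assert!(manager.force_reconnect().await.is_err());
    }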

    #[tokio::test]
    async fn test_max_consecutive_failures() {
        let mut manager = ConnectionManager::new(test_config());
        manager.max_consecutive_failures = 2;

        // Set high consecutive failures
        manager
            .update_state(ConnectionState::Failed {
                error: "test".to_string(),
                at: Instant::now(),
                consecutive_failures: 3,
            })
            .await;

        // Should fail due to max consecutive failures
        let result = manager.reconnect_with_backoff().await;
        assert!(result.is_err());

        if let Err(BittensorError::NetworkError { message }) = result {
            assert!(message.contains("Maximum consecutive failures"));
        } else {
            panic!("Expected NetworkError with max failures message");
        }
    }

    #[tokio::test]
    async fn test_state_transitions() {
        let manager = ConnectionManager::new(test_config());

        // Uninitialized -> Reconnecting
        manager
            .update_state(ConnectionState::Reconnecting {
                attempts: 1,
                since: Instant::now(),
                last_error: None,
            })
            .await;

        let state = manager.get_state().await;
        assert!(matches!(state, ConnectionState::Reconnecting { .. }));

        // Reconnecting -> Failed
        manager
            .update_state(ConnectionState::Failed {
                error: "error".to_string(),
                at: Instant::now(),
                consecutive_failures: 1,
            })
            .await;

        let state = manager.get_state().await;
        assert!(matches!(state, ConnectionState::Failed { .. }));

        // Failed -> Connected
        manager
            .update_state(ConnectionState::Connected {
                since: Instant::now(),
                endpoint: "endpoint".to_string(),
            })
            .await;

        let state = manager.get_state().await;
        assert!(matches!(state, ConnectionState::Connected { .. }));
    }
}