Skip to main content

amaters_sdk_rust/
connection_manager.rs

//! Connection state machine, multi-endpoint failover, and auto-reconnection
//!
//! This module provides:
//! - [`ConnectionState`] - A lock-free state machine for connection lifecycle
//! - [`EndpointList`] - Priority-ordered endpoint management with failover
//! - [`ReconnectConfig`] - Exponential backoff reconnection configuration
//! - [`ConnectionHealth`] - Health monitoring with periodic checks
//! - [`ConnectionManager`] - Orchestrates all of the above
9
10use crate::config::ClientConfig;
11use crate::error::{Result, SdkError};
12use parking_lot::{Mutex, RwLock};
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
15use std::time::{Duration, Instant};
16use tokio::sync::Notify;
17use tracing::{debug, error, info, warn};
18
// ---------------------------------------------------------------------------
// ConnectionState
// ---------------------------------------------------------------------------
22
/// Lifecycle state of a connection.
///
/// The explicit `u8` discriminants are the raw values stored in the
/// `AtomicU8`, so they must stay in sync with [`ConnectionState::from_u8`].
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// Not connected to any endpoint.
    Disconnected = 0,
    /// Currently attempting to establish a connection.
    Connecting = 1,
    /// Successfully connected and operational.
    Connected = 2,
    /// Connection was lost; attempting to re-establish.
    Reconnecting = 3,
    /// Terminal failure – manual intervention required.
    Failed = 4,
}

impl ConnectionState {
    /// Convert from `u8`, returning `None` for invalid values.
    ///
    /// This is the inverse of `state as u8` for every variant.
    fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(Self::Disconnected),
            1 => Some(Self::Connecting),
            2 => Some(Self::Connected),
            3 => Some(Self::Reconnecting),
            4 => Some(Self::Failed),
            _ => None,
        }
    }

    /// Human-readable label (identical to the variant name).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Disconnected => "Disconnected",
            Self::Connecting => "Connecting",
            Self::Connected => "Connected",
            Self::Reconnecting => "Reconnecting",
            Self::Failed => "Failed",
        }
    }
}

impl std::fmt::Display for ConnectionState {
    /// Writes the same label as [`ConnectionState::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
69
70/// Validate that a state transition is legal.
71///
72/// Legal transitions:
73///   Disconnected  → Connecting
74///   Connecting    → Connected | Failed
75///   Connected     → Reconnecting | Disconnected | Failed
76///   Reconnecting  → Connected | Failed
77///   Failed        → Disconnected (reset)
78fn is_valid_transition(from: ConnectionState, to: ConnectionState) -> bool {
79    use ConnectionState::*;
80    matches!(
81        (from, to),
82        (Disconnected, Connecting)
83            | (Connecting, Connected)
84            | (Connecting, Failed)
85            | (Connected, Reconnecting)
86            | (Connected, Disconnected)
87            | (Connected, Failed)
88            | (Reconnecting, Connected)
89            | (Reconnecting, Failed)
90            | (Failed, Disconnected)
91    )
92}
93
94/// Type alias for the state-change callback.
95pub type StateChangeCallback =
96    Arc<dyn Fn(ConnectionState, ConnectionState) + Send + Sync + 'static>;
97
98/// Atomic state holder with optional callback.
99#[derive(Clone)]
100pub struct AtomicConnectionState {
101    raw: Arc<AtomicU8>,
102    callback: Arc<RwLock<Option<StateChangeCallback>>>,
103}
104
105impl AtomicConnectionState {
106    /// Create with `Disconnected`.
107    pub fn new() -> Self {
108        Self {
109            raw: Arc::new(AtomicU8::new(ConnectionState::Disconnected as u8)),
110            callback: Arc::new(RwLock::new(None)),
111        }
112    }
113
114    /// Lock-free read.
115    pub fn get(&self) -> ConnectionState {
116        ConnectionState::from_u8(self.raw.load(Ordering::Acquire))
117            .unwrap_or(ConnectionState::Failed)
118    }
119
120    /// Attempt a state transition. Returns `Err` if the transition is invalid.
121    pub fn transition(&self, to: ConnectionState) -> Result<ConnectionState> {
122        let from = self.get();
123        if !is_valid_transition(from, to) {
124            return Err(SdkError::Connection(format!(
125                "invalid state transition: {} -> {}",
126                from, to
127            )));
128        }
129        self.raw.store(to as u8, Ordering::Release);
130        debug!("state transition: {} -> {}", from, to);
131
132        // Fire callback outside hot path – callback is expected to be fast.
133        if let Some(cb) = self.callback.read().as_ref() {
134            cb(from, to);
135        }
136
137        Ok(from)
138    }
139
140    /// Force-set state (bypasses transition validation). Use sparingly.
141    pub fn force_set(&self, state: ConnectionState) {
142        let prev = self.get();
143        self.raw.store(state as u8, Ordering::Release);
144        if let Some(cb) = self.callback.read().as_ref() {
145            cb(prev, state);
146        }
147    }
148
149    /// Register a callback invoked on every state change.
150    pub fn on_state_change<F>(&self, f: F)
151    where
152        F: Fn(ConnectionState, ConnectionState) + Send + Sync + 'static,
153    {
154        *self.callback.write() = Some(Arc::new(f));
155    }
156
157    /// Remove the state-change callback.
158    pub fn clear_callback(&self) {
159        *self.callback.write() = None;
160    }
161}
162
163impl Default for AtomicConnectionState {
164    fn default() -> Self {
165        Self::new()
166    }
167}
168
// ---------------------------------------------------------------------------
// EndpointList
// ---------------------------------------------------------------------------
172
/// A single endpoint with a priority (lower = higher priority).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EndpointEntry {
    /// The URL of the endpoint (e.g. `http://host:50051`).
    pub url: String,
    /// Priority value – lower numbers are tried first.
    pub priority: u32,
}
181
/// Tracks which endpoint is currently active.
#[derive(Debug, Clone)]
pub struct ActiveEndpoint {
    /// Index in the sorted list.
    pub index: usize,
    /// URL of the active endpoint.
    pub url: String,
    /// When this endpoint became active.
    pub connected_since: Instant,
}
192
193/// Priority-ordered list of endpoints with failover support.
194#[derive(Debug, Clone)]
195pub struct EndpointList {
196    /// Endpoints sorted by priority (ascending).
197    entries: Vec<EndpointEntry>,
198    /// Currently active endpoint, if any.
199    active: Option<ActiveEndpoint>,
200}
201
202impl EndpointList {
203    /// Create a new, empty list.
204    pub fn new() -> Self {
205        Self {
206            entries: Vec::new(),
207            active: None,
208        }
209    }
210
211    /// Create with a single primary endpoint.
212    pub fn with_primary(url: impl Into<String>) -> Self {
213        let mut list = Self::new();
214        list.add_endpoint(url, 0);
215        list
216    }
217
218    /// Add an endpoint with the given priority. Re-sorts internally.
219    pub fn add_endpoint(&mut self, url: impl Into<String>, priority: u32) {
220        let url_string = url.into();
221        // Avoid duplicates.
222        if self.entries.iter().any(|e| e.url == url_string) {
223            return;
224        }
225        self.entries.push(EndpointEntry {
226            url: url_string,
227            priority,
228        });
229        self.entries.sort_by_key(|e| e.priority);
230    }
231
232    /// Number of registered endpoints.
233    pub fn len(&self) -> usize {
234        self.entries.len()
235    }
236
237    /// Check if the list is empty.
238    pub fn is_empty(&self) -> bool {
239        self.entries.is_empty()
240    }
241
242    /// Iterate endpoints in priority order.
243    pub fn iter(&self) -> impl Iterator<Item = &EndpointEntry> {
244        self.entries.iter()
245    }
246
247    /// Get the next endpoint to try after the currently active one.
248    /// Wraps around if at the end of the list.
249    pub fn next_endpoint(&self) -> Option<&EndpointEntry> {
250        if self.entries.is_empty() {
251            return None;
252        }
253        let idx = match &self.active {
254            Some(active) => (active.index + 1) % self.entries.len(),
255            None => 0,
256        };
257        self.entries.get(idx)
258    }
259
260    /// Get the first (highest-priority) endpoint.
261    pub fn primary(&self) -> Option<&EndpointEntry> {
262        self.entries.first()
263    }
264
265    /// Mark an endpoint as active by index.
266    pub fn set_active(&mut self, index: usize) -> Result<()> {
267        let entry = self.entries.get(index).ok_or_else(|| {
268            SdkError::InvalidArgument(format!(
269                "endpoint index {} out of range (len={})",
270                index,
271                self.entries.len()
272            ))
273        })?;
274        self.active = Some(ActiveEndpoint {
275            index,
276            url: entry.url.clone(),
277            connected_since: Instant::now(),
278        });
279        Ok(())
280    }
281
282    /// Mark the endpoint with the given URL as active.
283    pub fn set_active_by_url(&mut self, url: &str) -> Result<()> {
284        let index = self
285            .entries
286            .iter()
287            .position(|e| e.url == url)
288            .ok_or_else(|| SdkError::InvalidArgument(format!("endpoint not found: {}", url)))?;
289        self.set_active(index)
290    }
291
292    /// Get the currently active endpoint.
293    pub fn active(&self) -> Option<&ActiveEndpoint> {
294        self.active.as_ref()
295    }
296
297    /// Clear the active endpoint.
298    pub fn clear_active(&mut self) {
299        self.active = None;
300    }
301
302    /// Perform a failover: activate the next endpoint in priority order.
303    /// Returns the newly active endpoint's URL, or `None` if the list is empty.
304    pub fn failover(&mut self) -> Option<String> {
305        if self.entries.is_empty() {
306            return None;
307        }
308        let next_idx = match &self.active {
309            Some(active) => (active.index + 1) % self.entries.len(),
310            None => 0,
311        };
312        let url = self.entries[next_idx].url.clone();
313        self.active = Some(ActiveEndpoint {
314            index: next_idx,
315            url: url.clone(),
316            connected_since: Instant::now(),
317        });
318        info!("failover to endpoint [{}]: {}", next_idx, url);
319        Some(url)
320    }
321}
322
323impl Default for EndpointList {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
// ---------------------------------------------------------------------------
// ReconnectConfig
// ---------------------------------------------------------------------------
332
/// Configuration for automatic reconnection with exponential backoff.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Maximum number of reconnection attempts before giving up.
    pub max_attempts: u32,
    /// Base delay between attempts.
    pub base_delay: Duration,
    /// Upper bound on the delay.
    pub max_delay: Duration,
    /// Multiplicative factor applied on each attempt.
    pub backoff_factor: f64,
    /// Whether to add jitter to prevent thundering herd.
    pub jitter: bool,
}

impl Default for ReconnectConfig {
    /// Defaults: 5 attempts, 1 s base delay, 30 s cap, factor 2.0, jitter on.
    fn default() -> Self {
        Self {
            max_attempts: 5,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_factor: 2.0,
            jitter: true,
        }
    }
}

impl ReconnectConfig {
    /// Create with defaults.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set max attempts.
    pub fn with_max_attempts(mut self, n: u32) -> Self {
        self.max_attempts = n;
        self
    }

    /// Set base delay.
    pub fn with_base_delay(mut self, d: Duration) -> Self {
        self.base_delay = d;
        self
    }

    /// Set max delay.
    pub fn with_max_delay(mut self, d: Duration) -> Self {
        self.max_delay = d;
        self
    }

    /// Set backoff factor.
    pub fn with_backoff_factor(mut self, f: f64) -> Self {
        self.backoff_factor = f;
        self
    }

    /// Calculate delay for the given attempt (0-based).
    ///
    /// The delay is `base_delay * backoff_factor^attempt`, capped at
    /// `max_delay`. When `jitter` is enabled the capped value is scaled by a
    /// deterministic factor in `0.75..=1.15` derived from `attempt % 5`, and
    /// capped again so the result never exceeds `max_delay`.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        let max_ms = self.max_delay.as_millis() as f64;
        let base_ms = self.base_delay.as_millis() as f64;

        // `powi` takes an `i32`; clamp instead of casting so a very large
        // attempt number cannot wrap to a negative exponent and shrink the
        // delay.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        // `f64::min` returns the non-NaN operand, so an overflowed (infinite
        // or NaN) product resolves to the cap.
        let backed_off = (base_ms * self.backoff_factor.powi(exponent)).min(max_ms);

        let ms = if self.jitter {
            // Deterministic jitter: scale by 0.75, 0.85, 0.95, 1.05 or 1.15
            // depending on the attempt number.
            let jitter_frac = 0.75 + f64::from(attempt % 5) * 0.1;
            // Re-apply the cap: factors above 1.0 must not push the delay
            // past `max_delay`.
            (backed_off * jitter_frac).min(max_ms)
        } else {
            backed_off
        };
        // Float-to-int `as` saturates, so negative values become 0.
        Duration::from_millis(ms as u64)
    }
}
405
// ---------------------------------------------------------------------------
// ConnectionHealth
// ---------------------------------------------------------------------------
409
/// Snapshot of connection health.
///
/// The `Default` value has no recorded check, no latency, zero failures and
/// `is_healthy == false`.
#[derive(Debug, Clone, Default)]
pub struct ConnectionHealth {
    /// Timestamp of the most recent health check.
    pub last_check: Option<Instant>,
    /// Measured round-trip latency in milliseconds.
    pub latency_ms: Option<u64>,
    /// Number of consecutive health-check failures.
    pub consecutive_failures: u32,
    /// Whether the connection is currently considered healthy.
    pub is_healthy: bool,
}

impl ConnectionHealth {
    /// Record a successful health check.
    ///
    /// Stores the latency, clears the failure streak and marks the
    /// connection healthy.
    pub fn record_success(&mut self, latency_ms: u64) {
        self.last_check = Some(Instant::now());
        self.latency_ms = Some(latency_ms);
        self.consecutive_failures = 0;
        self.is_healthy = true;
    }

    /// Record a failed health check.
    ///
    /// Extends the failure streak and marks the connection unhealthy. The
    /// last measured latency is left as-is.
    pub fn record_failure(&mut self) {
        self.last_check = Some(Instant::now());
        // Saturate rather than overflow (which panics in debug builds) on a
        // pathologically long failure streak.
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        self.is_healthy = false;
    }

    /// Reset to default.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
444
// ---------------------------------------------------------------------------
// ConnectionManager
// ---------------------------------------------------------------------------
448
449/// Central orchestrator for connection lifecycle management.
450///
451/// Wraps [`ClientConfig`], [`EndpointList`], [`ReconnectConfig`],
452/// [`AtomicConnectionState`] and [`ConnectionHealth`].
453pub struct ConnectionManager {
454    /// Client configuration.
455    config: ClientConfig,
456    /// Priority-ordered endpoints.
457    endpoints: Arc<RwLock<EndpointList>>,
458    /// Reconnection settings.
459    reconnect_config: ReconnectConfig,
460    /// Lock-free state machine.
461    state: AtomicConnectionState,
462    /// Health status.
463    health: Arc<RwLock<ConnectionHealth>>,
464    /// Health-check interval.
465    health_check_interval: Duration,
466    /// Toggle for the background reconnect task.
467    auto_reconnect_enabled: Arc<AtomicBool>,
468    /// Cancellation signal for background tasks.
469    cancel: Arc<Notify>,
470    /// Handles to spawned background tasks.
471    _task_handles: Arc<Mutex<Vec<tokio::task::JoinHandle<()>>>>,
472}
473
474impl ConnectionManager {
475    /// Create a new manager from config and endpoints.
476    pub fn new(
477        config: ClientConfig,
478        endpoints: EndpointList,
479        reconnect_config: ReconnectConfig,
480    ) -> Self {
481        Self {
482            config,
483            endpoints: Arc::new(RwLock::new(endpoints)),
484            reconnect_config,
485            state: AtomicConnectionState::new(),
486            health: Arc::new(RwLock::new(ConnectionHealth::default())),
487            health_check_interval: Duration::from_secs(30),
488            auto_reconnect_enabled: Arc::new(AtomicBool::new(true)),
489            cancel: Arc::new(Notify::new()),
490            _task_handles: Arc::new(Mutex::new(Vec::new())),
491        }
492    }
493
494    /// Convenience: create with a single primary endpoint.
495    pub fn with_primary(config: ClientConfig) -> Self {
496        let addr = config.server_addr.clone();
497        let endpoints = EndpointList::with_primary(addr);
498        Self::new(config, endpoints, ReconnectConfig::default())
499    }
500
501    /// Set the health-check interval.
502    pub fn with_health_check_interval(mut self, interval: Duration) -> Self {
503        self.health_check_interval = interval;
504        self
505    }
506
507    /// Set a state-change callback.
508    pub fn on_state_change<F>(&self, f: F)
509    where
510        F: Fn(ConnectionState, ConnectionState) + Send + Sync + 'static,
511    {
512        self.state.on_state_change(f);
513    }
514
515    // -- state accessors -----------------------------------------------------
516
517    /// Current connection state (lock-free).
518    pub fn state(&self) -> ConnectionState {
519        self.state.get()
520    }
521
522    /// Snapshot of connection health.
523    pub fn health(&self) -> ConnectionHealth {
524        self.health.read().clone()
525    }
526
527    /// URL of the currently active endpoint, if any.
528    pub fn active_endpoint(&self) -> Option<String> {
529        self.endpoints.read().active().map(|a| a.url.clone())
530    }
531
532    /// Reference to endpoint list (read-locked).
533    pub fn endpoints(&self) -> EndpointList {
534        self.endpoints.read().clone()
535    }
536
537    /// Client config reference.
538    pub fn config(&self) -> &ClientConfig {
539        &self.config
540    }
541
542    // -- lifecycle -----------------------------------------------------------
543
544    /// Initiate a connection attempt.
545    ///
546    /// Walks the endpoint list in priority order. On the first success the
547    /// state transitions to `Connected`. If all endpoints fail, the state
548    /// moves to `Failed`.
549    pub async fn connect(&self) -> Result<()> {
550        self.state.transition(ConnectionState::Connecting)?;
551
552        let endpoints: Vec<EndpointEntry> = {
553            let list = self.endpoints.read();
554            list.iter().cloned().collect()
555        };
556
557        if endpoints.is_empty() {
558            self.state.force_set(ConnectionState::Failed);
559            return Err(SdkError::Configuration(
560                "no endpoints configured".to_string(),
561            ));
562        }
563
564        for (idx, ep) in endpoints.iter().enumerate() {
565            info!("trying endpoint [{}] {}", idx, ep.url);
566            match self.try_connect_endpoint(&ep.url).await {
567                Ok(()) => {
568                    self.endpoints.write().set_active(idx)?;
569                    self.state.transition(ConnectionState::Connected)?;
570                    self.health.write().record_success(0);
571                    info!("connected to {}", ep.url);
572                    self.maybe_spawn_health_check();
573                    return Ok(());
574                }
575                Err(e) => {
576                    warn!("endpoint {} failed: {}", ep.url, e);
577                    continue;
578                }
579            }
580        }
581
582        self.state.force_set(ConnectionState::Failed);
583        Err(SdkError::Connection("all endpoints failed".to_string()))
584    }
585
586    /// Cleanly disconnect and reset state.
587    pub fn disconnect(&self) {
588        info!("disconnecting");
589        self.cancel.notify_waiters();
590        self.endpoints.write().clear_active();
591        self.health.write().reset();
592
593        // Transition to Disconnected if currently connected.
594        let current = self.state.get();
595        match current {
596            ConnectionState::Connected => {
597                let _ = self.state.transition(ConnectionState::Disconnected);
598            }
599            ConnectionState::Failed => {
600                let _ = self.state.transition(ConnectionState::Disconnected);
601            }
602            _ => {
603                self.state.force_set(ConnectionState::Disconnected);
604            }
605        }
606    }
607
608    /// Manually trigger a failover to the next endpoint.
609    pub async fn failover(&self) -> Result<String> {
610        let url = {
611            let mut list = self.endpoints.write();
612            list.failover().ok_or_else(|| {
613                SdkError::Connection("no endpoints available for failover".to_string())
614            })?
615        };
616
617        // If we are currently connected, mark reconnecting first.
618        let current = self.state.get();
619        if current == ConnectionState::Connected {
620            self.state.transition(ConnectionState::Reconnecting)?;
621        }
622
623        match self.try_connect_endpoint(&url).await {
624            Ok(()) => {
625                // If we were reconnecting, transition back to connected.
626                if self.state.get() == ConnectionState::Reconnecting {
627                    self.state.transition(ConnectionState::Connected)?;
628                }
629                self.health.write().record_success(0);
630                info!("failover successful to {}", url);
631                Ok(url)
632            }
633            Err(e) => {
634                self.state.force_set(ConnectionState::Failed);
635                Err(SdkError::Connection(format!(
636                    "failover to {} failed: {}",
637                    url, e
638                )))
639            }
640        }
641    }
642
643    // -- auto-reconnect ------------------------------------------------------
644
645    /// Enable automatic reconnection.
646    pub fn enable_auto_reconnect(&self) {
647        self.auto_reconnect_enabled.store(true, Ordering::Release);
648        debug!("auto-reconnect enabled");
649    }
650
651    /// Disable automatic reconnection.
652    pub fn disable_auto_reconnect(&self) {
653        self.auto_reconnect_enabled.store(false, Ordering::Release);
654        debug!("auto-reconnect disabled");
655    }
656
657    /// Whether auto-reconnect is currently enabled.
658    pub fn is_auto_reconnect_enabled(&self) -> bool {
659        self.auto_reconnect_enabled.load(Ordering::Acquire)
660    }
661
662    /// Run the reconnection loop (usually spawned as a background task).
663    /// Tries endpoints with exponential backoff until success or
664    /// `max_attempts` is exhausted.
665    pub async fn reconnect_loop(&self) -> Result<()> {
666        if !self.auto_reconnect_enabled.load(Ordering::Acquire) {
667            return Err(SdkError::Connection(
668                "auto-reconnect is disabled".to_string(),
669            ));
670        }
671
672        // Must be in a state that allows reconnecting.
673        let current = self.state.get();
674        if current == ConnectionState::Connected {
675            self.state.transition(ConnectionState::Reconnecting)?;
676        } else if current != ConnectionState::Reconnecting {
677            // Force to reconnecting if in a broken state.
678            self.state.force_set(ConnectionState::Reconnecting);
679        }
680
681        let endpoints: Vec<EndpointEntry> = {
682            let list = self.endpoints.read();
683            list.iter().cloned().collect()
684        };
685
686        for attempt in 0..self.reconnect_config.max_attempts {
687            if !self.auto_reconnect_enabled.load(Ordering::Acquire) {
688                warn!("auto-reconnect disabled during reconnect loop");
689                return Err(SdkError::Connection(
690                    "auto-reconnect disabled during loop".to_string(),
691                ));
692            }
693
694            let delay = self.reconnect_config.delay_for_attempt(attempt);
695            info!(
696                "reconnect attempt {}/{} – waiting {:?}",
697                attempt + 1,
698                self.reconnect_config.max_attempts,
699                delay
700            );
701
702            tokio::select! {
703                _ = tokio::time::sleep(delay) => {}
704                _ = self.cancel.notified() => {
705                    info!("reconnect loop cancelled");
706                    return Err(SdkError::Connection("reconnect cancelled".to_string()));
707                }
708            }
709
710            // Try each endpoint.
711            for (idx, ep) in endpoints.iter().enumerate() {
712                match self.try_connect_endpoint(&ep.url).await {
713                    Ok(()) => {
714                        if let Err(e) = self.endpoints.write().set_active(idx) {
715                            warn!("failed to set active endpoint: {}", e);
716                        }
717                        self.state.transition(ConnectionState::Connected)?;
718                        self.health.write().record_success(0);
719                        info!("reconnected to {}", ep.url);
720                        return Ok(());
721                    }
722                    Err(e) => {
723                        debug!("reconnect to {} failed: {}", ep.url, e);
724                    }
725                }
726            }
727
728            self.health.write().record_failure();
729        }
730
731        self.state.force_set(ConnectionState::Failed);
732        Err(SdkError::Connection(format!(
733            "reconnect failed after {} attempts",
734            self.reconnect_config.max_attempts
735        )))
736    }
737
738    // -- health check --------------------------------------------------------
739
740    /// Run a single health check against the active endpoint.
741    pub async fn check_health(&self) -> Result<()> {
742        let url = self.active_endpoint().ok_or_else(|| {
743            SdkError::Connection("no active endpoint to health-check".to_string())
744        })?;
745
746        let start = Instant::now();
747        match self.try_connect_endpoint(&url).await {
748            Ok(()) => {
749                let latency = start.elapsed().as_millis() as u64;
750                self.health.write().record_success(latency);
751                debug!("health check OK – {}ms", latency);
752                Ok(())
753            }
754            Err(e) => {
755                self.health.write().record_failure();
756                let failures = self.health.read().consecutive_failures;
757                warn!("health check failed ({} consecutive): {}", failures, e);
758                // Trigger reconnect after 3 consecutive failures.
759                if failures >= 3 && self.is_auto_reconnect_enabled() {
760                    error!(
761                        "triggering reconnect after {} consecutive health-check failures",
762                        failures
763                    );
764                    // Don't propagate reconnect errors from health check.
765                    let _ = self.reconnect_loop().await;
766                }
767                Err(SdkError::Connection(format!("health check failed: {}", e)))
768            }
769        }
770    }
771
772    // -- internal helpers ----------------------------------------------------
773
774    /// Attempt a tonic connection to `url`. Does not modify state.
775    async fn try_connect_endpoint(&self, url: &str) -> Result<()> {
776        use tonic::transport::Endpoint;
777
778        let mut endpoint = Endpoint::from_shared(url.to_string())
779            .map_err(|e| SdkError::Configuration(format!("invalid endpoint url: {}", e)))?;
780
781        endpoint = endpoint
782            .timeout(self.config.request_timeout)
783            .connect_timeout(self.config.connect_timeout);
784
785        if self.config.keep_alive {
786            endpoint = endpoint
787                .keep_alive_timeout(self.config.keep_alive_timeout)
788                .http2_keep_alive_interval(self.config.keep_alive_interval);
789        }
790
791        let _channel = tokio::time::timeout(self.config.connect_timeout, endpoint.connect())
792            .await
793            .map_err(|_| {
794                SdkError::Timeout(format!(
795                    "endpoint {} connect timeout after {:?}",
796                    url, self.config.connect_timeout
797                ))
798            })?
799            .map_err(SdkError::Transport)?;
800
801        Ok(())
802    }
803
804    /// Spawn a periodic health-check task if not already running.
805    fn maybe_spawn_health_check(&self) {
806        let interval = self.health_check_interval;
807        let health = Arc::clone(&self.health);
808        let state = self.state.clone();
809        let cancel = Arc::clone(&self.cancel);
810        let auto_reconnect = Arc::clone(&self.auto_reconnect_enabled);
811        let endpoints = Arc::clone(&self.endpoints);
812        let config = self.config.clone();
813
814        let handle = tokio::spawn(async move {
815            loop {
816                tokio::select! {
817                    _ = tokio::time::sleep(interval) => {}
818                    _ = cancel.notified() => {
819                        debug!("health-check task cancelled");
820                        return;
821                    }
822                }
823
824                // Only check if we are connected.
825                if state.get() != ConnectionState::Connected {
826                    continue;
827                }
828
829                let url = {
830                    let list = endpoints.read();
831                    list.active().map(|a| a.url.clone())
832                };
833
834                let url = match url {
835                    Some(u) => u,
836                    None => continue,
837                };
838
839                let start = Instant::now();
840                let result = {
841                    use tonic::transport::Endpoint;
842                    let endpoint = match Endpoint::from_shared(url.clone()) {
843                        Ok(ep) => ep
844                            .timeout(config.request_timeout)
845                            .connect_timeout(config.connect_timeout),
846                        Err(_) => continue,
847                    };
848                    tokio::time::timeout(config.connect_timeout, endpoint.connect()).await
849                };
850
851                match result {
852                    Ok(Ok(_)) => {
853                        let latency = start.elapsed().as_millis() as u64;
854                        health.write().record_success(latency);
855                    }
856                    _ => {
857                        health.write().record_failure();
858                        let failures = health.read().consecutive_failures;
859                        if failures >= 3 && auto_reconnect.load(Ordering::Acquire) {
860                            warn!(
861                                "health-check task: {} consecutive failures, signalling reconnect",
862                                failures
863                            );
864                            // Signal – the next connect attempt will handle it.
865                            state.force_set(ConnectionState::Reconnecting);
866                        }
867                    }
868                }
869            }
870        });
871
872        self._task_handles.lock().push(handle);
873    }
874}
875
876impl Drop for ConnectionManager {
877    fn drop(&mut self) {
878        // Signal cancellation to all background tasks.
879        self.cancel.notify_waiters();
880        for handle in self._task_handles.lock().iter() {
881            handle.abort();
882        }
883    }
884}
885
886// ---------------------------------------------------------------------------
887// Tests
888// ---------------------------------------------------------------------------
889
890#[cfg(test)]
891mod tests {
892    use super::*;
893
894    // -- ConnectionState transitions -----------------------------------------
895
896    #[test]
897    fn test_state_initial() {
898        let s = AtomicConnectionState::new();
899        assert_eq!(s.get(), ConnectionState::Disconnected);
900    }
901
902    #[test]
903    fn test_valid_transitions() {
904        let s = AtomicConnectionState::new();
905
906        // Disconnected -> Connecting
907        assert!(s.transition(ConnectionState::Connecting).is_ok());
908        assert_eq!(s.get(), ConnectionState::Connecting);
909
910        // Connecting -> Connected
911        assert!(s.transition(ConnectionState::Connected).is_ok());
912        assert_eq!(s.get(), ConnectionState::Connected);
913
914        // Connected -> Reconnecting
915        assert!(s.transition(ConnectionState::Reconnecting).is_ok());
916        assert_eq!(s.get(), ConnectionState::Reconnecting);
917
918        // Reconnecting -> Connected
919        assert!(s.transition(ConnectionState::Connected).is_ok());
920        assert_eq!(s.get(), ConnectionState::Connected);
921
922        // Connected -> Disconnected
923        assert!(s.transition(ConnectionState::Disconnected).is_ok());
924        assert_eq!(s.get(), ConnectionState::Disconnected);
925    }
926
927    #[test]
928    fn test_invalid_transition() {
929        let s = AtomicConnectionState::new();
930        // Disconnected -> Connected (must go via Connecting)
931        assert!(s.transition(ConnectionState::Connected).is_err());
932    }
933
934    #[test]
935    fn test_failed_to_disconnected() {
936        let s = AtomicConnectionState::new();
937        s.force_set(ConnectionState::Failed);
938        assert_eq!(s.get(), ConnectionState::Failed);
939        // Failed -> Disconnected (reset)
940        assert!(s.transition(ConnectionState::Disconnected).is_ok());
941        assert_eq!(s.get(), ConnectionState::Disconnected);
942    }
943
944    #[test]
945    fn test_state_callback() {
946        let s = AtomicConnectionState::new();
947        let transitions = Arc::new(Mutex::new(Vec::new()));
948        let t_clone = Arc::clone(&transitions);
949        s.on_state_change(move |from, to| {
950            t_clone.lock().push((from, to));
951        });
952
953        let _ = s.transition(ConnectionState::Connecting);
954        let _ = s.transition(ConnectionState::Connected);
955
956        let recorded = transitions.lock();
957        assert_eq!(recorded.len(), 2);
958        assert_eq!(
959            recorded[0],
960            (ConnectionState::Disconnected, ConnectionState::Connecting)
961        );
962        assert_eq!(
963            recorded[1],
964            (ConnectionState::Connecting, ConnectionState::Connected)
965        );
966    }
967
968    #[test]
969    fn test_state_display() {
970        assert_eq!(ConnectionState::Connected.to_string(), "Connected");
971        assert_eq!(ConnectionState::Failed.as_str(), "Failed");
972    }
973
974    // -- EndpointList --------------------------------------------------------
975
976    #[test]
977    fn test_endpoint_list_priority_ordering() {
978        let mut list = EndpointList::new();
979        list.add_endpoint("http://c:50051", 20);
980        list.add_endpoint("http://a:50051", 0);
981        list.add_endpoint("http://b:50051", 10);
982
983        let urls: Vec<&str> = list.iter().map(|e| e.url.as_str()).collect();
984        assert_eq!(
985            urls,
986            vec!["http://a:50051", "http://b:50051", "http://c:50051"]
987        );
988    }
989
990    #[test]
991    fn test_endpoint_list_no_duplicates() {
992        let mut list = EndpointList::new();
993        list.add_endpoint("http://a:50051", 0);
994        list.add_endpoint("http://a:50051", 10);
995        assert_eq!(list.len(), 1);
996    }
997
998    #[test]
999    fn test_endpoint_list_primary() {
1000        let list = EndpointList::with_primary("http://primary:50051");
1001        assert_eq!(
1002            list.primary().map(|e| e.url.as_str()),
1003            Some("http://primary:50051")
1004        );
1005    }
1006
1007    #[test]
1008    fn test_endpoint_failover() {
1009        let mut list = EndpointList::new();
1010        list.add_endpoint("http://a:50051", 0);
1011        list.add_endpoint("http://b:50051", 10);
1012        list.add_endpoint("http://c:50051", 20);
1013
1014        // No active yet – failover picks first.
1015        let url = list.failover();
1016        assert_eq!(url, Some("http://a:50051".to_string()));
1017
1018        // Now at 0, failover picks 1.
1019        let url = list.failover();
1020        assert_eq!(url, Some("http://b:50051".to_string()));
1021
1022        // At 1, failover picks 2.
1023        let url = list.failover();
1024        assert_eq!(url, Some("http://c:50051".to_string()));
1025
1026        // At 2, wraps around to 0.
1027        let url = list.failover();
1028        assert_eq!(url, Some("http://a:50051".to_string()));
1029    }
1030
1031    #[test]
1032    fn test_endpoint_set_active_by_url() {
1033        let mut list = EndpointList::new();
1034        list.add_endpoint("http://a:50051", 0);
1035        list.add_endpoint("http://b:50051", 10);
1036
1037        assert!(list.set_active_by_url("http://b:50051").is_ok());
1038        assert_eq!(
1039            list.active().map(|a| a.url.as_str()),
1040            Some("http://b:50051")
1041        );
1042
1043        // Non-existent URL.
1044        assert!(list.set_active_by_url("http://z:50051").is_err());
1045    }
1046
1047    #[test]
1048    fn test_endpoint_empty_failover() {
1049        let mut list = EndpointList::new();
1050        assert!(list.failover().is_none());
1051    }
1052
1053    #[test]
1054    fn test_endpoint_clear_active() {
1055        let mut list = EndpointList::with_primary("http://a:50051");
1056        list.set_active(0).expect("set_active should succeed");
1057        assert!(list.active().is_some());
1058        list.clear_active();
1059        assert!(list.active().is_none());
1060    }
1061
1062    // -- ReconnectConfig -----------------------------------------------------
1063
1064    #[test]
1065    fn test_reconnect_config_defaults() {
1066        let cfg = ReconnectConfig::default();
1067        assert_eq!(cfg.max_attempts, 5);
1068        assert_eq!(cfg.base_delay, Duration::from_secs(1));
1069        assert_eq!(cfg.max_delay, Duration::from_secs(30));
1070        assert!((cfg.backoff_factor - 2.0).abs() < f64::EPSILON);
1071        assert!(cfg.jitter);
1072    }
1073
1074    #[test]
1075    fn test_reconnect_backoff_no_jitter() {
1076        let cfg = ReconnectConfig {
1077            max_attempts: 5,
1078            base_delay: Duration::from_secs(1),
1079            max_delay: Duration::from_secs(30),
1080            backoff_factor: 2.0,
1081            jitter: false,
1082        };
1083
1084        assert_eq!(cfg.delay_for_attempt(0), Duration::from_secs(1)); // 1 * 2^0 = 1
1085        assert_eq!(cfg.delay_for_attempt(1), Duration::from_secs(2)); // 1 * 2^1 = 2
1086        assert_eq!(cfg.delay_for_attempt(2), Duration::from_secs(4)); // 1 * 2^2 = 4
1087        assert_eq!(cfg.delay_for_attempt(3), Duration::from_secs(8)); // 1 * 2^3 = 8
1088        assert_eq!(cfg.delay_for_attempt(4), Duration::from_secs(16)); // 1 * 2^4 = 16
1089    }
1090
1091    #[test]
1092    fn test_reconnect_backoff_clamped() {
1093        let cfg = ReconnectConfig {
1094            max_attempts: 10,
1095            base_delay: Duration::from_secs(1),
1096            max_delay: Duration::from_secs(10),
1097            backoff_factor: 2.0,
1098            jitter: false,
1099        };
1100
1101        // 2^5 = 32 > 10, so clamped to 10.
1102        assert_eq!(cfg.delay_for_attempt(5), Duration::from_secs(10));
1103        assert_eq!(cfg.delay_for_attempt(8), Duration::from_secs(10));
1104    }
1105
1106    #[test]
1107    fn test_reconnect_backoff_with_jitter() {
1108        let cfg = ReconnectConfig::default(); // jitter = true
1109
1110        let d0 = cfg.delay_for_attempt(0);
1111        let d1 = cfg.delay_for_attempt(1);
1112        // With jitter, delays should still increase overall.
1113        // d0 base = 1000ms, d1 base = 2000ms.
1114        assert!(d1 > d0, "d1={:?} should be > d0={:?}", d1, d0);
1115    }
1116
1117    #[test]
1118    fn test_reconnect_builder() {
1119        let cfg = ReconnectConfig::new()
1120            .with_max_attempts(10)
1121            .with_base_delay(Duration::from_millis(500))
1122            .with_max_delay(Duration::from_secs(60))
1123            .with_backoff_factor(3.0);
1124
1125        assert_eq!(cfg.max_attempts, 10);
1126        assert_eq!(cfg.base_delay, Duration::from_millis(500));
1127        assert_eq!(cfg.max_delay, Duration::from_secs(60));
1128        assert!((cfg.backoff_factor - 3.0).abs() < f64::EPSILON);
1129    }
1130
1131    // -- ConnectionHealth ----------------------------------------------------
1132
1133    #[test]
1134    fn test_health_default() {
1135        let h = ConnectionHealth::default();
1136        assert!(!h.is_healthy);
1137        assert_eq!(h.consecutive_failures, 0);
1138        assert!(h.last_check.is_none());
1139        assert!(h.latency_ms.is_none());
1140    }
1141
1142    #[test]
1143    fn test_health_success() {
1144        let mut h = ConnectionHealth::default();
1145        h.record_success(42);
1146        assert!(h.is_healthy);
1147        assert_eq!(h.latency_ms, Some(42));
1148        assert_eq!(h.consecutive_failures, 0);
1149        assert!(h.last_check.is_some());
1150    }
1151
1152    #[test]
1153    fn test_health_failure_counter() {
1154        let mut h = ConnectionHealth::default();
1155        h.record_failure();
1156        h.record_failure();
1157        h.record_failure();
1158        assert_eq!(h.consecutive_failures, 3);
1159        assert!(!h.is_healthy);
1160
1161        // A success resets the counter.
1162        h.record_success(10);
1163        assert_eq!(h.consecutive_failures, 0);
1164        assert!(h.is_healthy);
1165    }
1166
1167    #[test]
1168    fn test_health_reset() {
1169        let mut h = ConnectionHealth::default();
1170        h.record_success(5);
1171        h.record_failure();
1172        h.reset();
1173        assert!(h.last_check.is_none());
1174        assert!(!h.is_healthy);
1175        assert_eq!(h.consecutive_failures, 0);
1176    }
1177
1178    // -- ConnectionManager ---------------------------------------------------
1179
1180    #[test]
1181    fn test_manager_initial_state() {
1182        let mgr = ConnectionManager::with_primary(ClientConfig::default());
1183        assert_eq!(mgr.state(), ConnectionState::Disconnected);
1184    }
1185
1186    #[test]
1187    fn test_manager_disconnect_cleans_up() {
1188        let mgr = ConnectionManager::with_primary(ClientConfig::default());
1189        // Force to connected state for testing.
1190        mgr.state.force_set(ConnectionState::Connected);
1191        mgr.endpoints
1192            .write()
1193            .set_active(0)
1194            .expect("set_active should succeed");
1195
1196        mgr.disconnect();
1197
1198        assert_eq!(mgr.state(), ConnectionState::Disconnected);
1199        assert!(mgr.active_endpoint().is_none());
1200        assert!(!mgr.health().is_healthy);
1201    }
1202
1203    #[test]
1204    fn test_manager_auto_reconnect_toggle() {
1205        let mgr = ConnectionManager::with_primary(ClientConfig::default());
1206        assert!(mgr.is_auto_reconnect_enabled());
1207
1208        mgr.disable_auto_reconnect();
1209        assert!(!mgr.is_auto_reconnect_enabled());
1210
1211        mgr.enable_auto_reconnect();
1212        assert!(mgr.is_auto_reconnect_enabled());
1213    }
1214
1215    #[test]
1216    fn test_manager_health_check_interval() {
1217        let mgr = ConnectionManager::with_primary(ClientConfig::default())
1218            .with_health_check_interval(Duration::from_secs(10));
1219        assert_eq!(mgr.health_check_interval, Duration::from_secs(10));
1220    }
1221
1222    #[test]
1223    fn test_manager_endpoints_access() {
1224        let mut eps = EndpointList::new();
1225        eps.add_endpoint("http://a:50051", 0);
1226        eps.add_endpoint("http://b:50051", 10);
1227
1228        let mgr = ConnectionManager::new(ClientConfig::default(), eps, ReconnectConfig::default());
1229
1230        let list = mgr.endpoints();
1231        assert_eq!(list.len(), 2);
1232        assert_eq!(
1233            list.primary().map(|e| e.url.as_str()),
1234            Some("http://a:50051")
1235        );
1236    }
1237
1238    #[tokio::test]
1239    async fn test_manager_connect_no_endpoints() {
1240        let mgr = ConnectionManager::new(
1241            ClientConfig::default(),
1242            EndpointList::new(),
1243            ReconnectConfig::default(),
1244        );
1245
1246        let result = mgr.connect().await;
1247        assert!(result.is_err());
1248        assert_eq!(mgr.state(), ConnectionState::Failed);
1249    }
1250
1251    #[tokio::test]
1252    async fn test_manager_connect_unreachable_endpoint() {
1253        // Use a non-routable address so the connect attempt fails quickly.
1254        let config = ClientConfig::new("http://192.0.2.1:1")
1255            .with_connect_timeout(Duration::from_millis(100));
1256
1257        let eps = EndpointList::with_primary("http://192.0.2.1:1");
1258        let mgr = ConnectionManager::new(config, eps, ReconnectConfig::default());
1259
1260        let result = mgr.connect().await;
1261        assert!(result.is_err());
1262        assert_eq!(mgr.state(), ConnectionState::Failed);
1263    }
1264
1265    #[tokio::test]
1266    async fn test_manager_reconnect_disabled() {
1267        let mgr = ConnectionManager::with_primary(ClientConfig::default());
1268        mgr.disable_auto_reconnect();
1269        mgr.state.force_set(ConnectionState::Connected);
1270
1271        let result = mgr.reconnect_loop().await;
1272        assert!(result.is_err());
1273    }
1274
1275    #[test]
1276    fn test_state_from_u8_invalid() {
1277        assert!(ConnectionState::from_u8(255).is_none());
1278        assert!(ConnectionState::from_u8(5).is_none());
1279    }
1280
1281    #[test]
1282    fn test_endpoint_next_no_active() {
1283        let mut list = EndpointList::new();
1284        list.add_endpoint("http://a:50051", 0);
1285        list.add_endpoint("http://b:50051", 10);
1286
1287        // No active → returns first.
1288        let next = list.next_endpoint();
1289        assert_eq!(next.map(|e| e.url.as_str()), Some("http://a:50051"));
1290    }
1291
1292    #[test]
1293    fn test_endpoint_next_with_active() {
1294        let mut list = EndpointList::new();
1295        list.add_endpoint("http://a:50051", 0);
1296        list.add_endpoint("http://b:50051", 10);
1297        list.set_active(0).expect("set_active should succeed");
1298
1299        let next = list.next_endpoint();
1300        assert_eq!(next.map(|e| e.url.as_str()), Some("http://b:50051"));
1301    }
1302
1303    #[test]
1304    fn test_manager_state_change_callback() {
1305        let mgr = ConnectionManager::with_primary(ClientConfig::default());
1306        let states = Arc::new(Mutex::new(Vec::new()));
1307        let s_clone = Arc::clone(&states);
1308
1309        mgr.on_state_change(move |from, to| {
1310            s_clone.lock().push((from, to));
1311        });
1312
1313        mgr.state.force_set(ConnectionState::Connecting);
1314        mgr.state.force_set(ConnectionState::Connected);
1315
1316        let recorded = states.lock();
1317        assert_eq!(recorded.len(), 2);
1318    }
1319}