1use crate::config::ClientConfig;
11use crate::error::{Result, SdkError};
12use parking_lot::{Mutex, RwLock};
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
15use std::time::{Duration, Instant};
16use tokio::sync::Notify;
17use tracing::{debug, error, info, warn};
18
/// Lifecycle state of a client connection.
///
/// The discriminants are pinned with `#[repr(u8)]` so that a state can be
/// stored in an `AtomicU8` and decoded again with `ConnectionState::from_u8`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// Initial state of a freshly created `AtomicConnectionState`.
    Disconnected = 0,
    /// Entered by `ConnectionManager::connect` before endpoints are tried.
    Connecting = 1,
    /// Entered once an endpoint has accepted the connection.
    Connected = 2,
    /// Entered while a reconnect or failover is in progress.
    Reconnecting = 3,
    /// Entered when connecting, reconnecting or failing over did not succeed.
    Failed = 4,
}

impl ConnectionState {
    /// Every variant, stored at the index equal to its discriminant.
    const BY_DISCRIMINANT: [ConnectionState; 5] = [
        ConnectionState::Disconnected,
        ConnectionState::Connecting,
        ConnectionState::Connected,
        ConnectionState::Reconnecting,
        ConnectionState::Failed,
    ];

    /// Decodes a raw discriminant; returns `None` for values outside `0..=4`.
    fn from_u8(v: u8) -> Option<Self> {
        Self::BY_DISCRIMINANT.get(usize::from(v)).copied()
    }

    /// Returns the variant name as a static string (e.g. `"Connected"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            ConnectionState::Disconnected => "Disconnected",
            ConnectionState::Connecting => "Connecting",
            ConnectionState::Connected => "Connected",
            ConnectionState::Reconnecting => "Reconnecting",
            ConnectionState::Failed => "Failed",
        }
    }
}

impl std::fmt::Display for ConnectionState {
    /// Writes the same text as [`ConnectionState::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.as_str();
        f.write_str(name)
    }
}
69
70fn is_valid_transition(from: ConnectionState, to: ConnectionState) -> bool {
79 use ConnectionState::*;
80 matches!(
81 (from, to),
82 (Disconnected, Connecting)
83 | (Connecting, Connected)
84 | (Connecting, Failed)
85 | (Connected, Reconnecting)
86 | (Connected, Disconnected)
87 | (Connected, Failed)
88 | (Reconnecting, Connected)
89 | (Reconnecting, Failed)
90 | (Failed, Disconnected)
91 )
92}
93
94pub type StateChangeCallback =
96 Arc<dyn Fn(ConnectionState, ConnectionState) + Send + Sync + 'static>;
97
98#[derive(Clone)]
100pub struct AtomicConnectionState {
101 raw: Arc<AtomicU8>,
102 callback: Arc<RwLock<Option<StateChangeCallback>>>,
103}
104
105impl AtomicConnectionState {
106 pub fn new() -> Self {
108 Self {
109 raw: Arc::new(AtomicU8::new(ConnectionState::Disconnected as u8)),
110 callback: Arc::new(RwLock::new(None)),
111 }
112 }
113
114 pub fn get(&self) -> ConnectionState {
116 ConnectionState::from_u8(self.raw.load(Ordering::Acquire))
117 .unwrap_or(ConnectionState::Failed)
118 }
119
120 pub fn transition(&self, to: ConnectionState) -> Result<ConnectionState> {
122 let from = self.get();
123 if !is_valid_transition(from, to) {
124 return Err(SdkError::Connection(format!(
125 "invalid state transition: {} -> {}",
126 from, to
127 )));
128 }
129 self.raw.store(to as u8, Ordering::Release);
130 debug!("state transition: {} -> {}", from, to);
131
132 if let Some(cb) = self.callback.read().as_ref() {
134 cb(from, to);
135 }
136
137 Ok(from)
138 }
139
140 pub fn force_set(&self, state: ConnectionState) {
142 let prev = self.get();
143 self.raw.store(state as u8, Ordering::Release);
144 if let Some(cb) = self.callback.read().as_ref() {
145 cb(prev, state);
146 }
147 }
148
149 pub fn on_state_change<F>(&self, f: F)
151 where
152 F: Fn(ConnectionState, ConnectionState) + Send + Sync + 'static,
153 {
154 *self.callback.write() = Some(Arc::new(f));
155 }
156
157 pub fn clear_callback(&self) {
159 *self.callback.write() = None;
160 }
161}
162
163impl Default for AtomicConnectionState {
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
/// One configured server endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EndpointEntry {
    /// Endpoint URL; `EndpointList::add_endpoint` ignores a URL that is
    /// already present.
    pub url: String,
    /// Sort key used by `EndpointList`: entries are kept in ascending
    /// `priority` order, so the lowest value comes first.
    pub priority: u32,
}
181
/// The endpoint an `EndpointList` currently treats as active.
#[derive(Debug, Clone)]
pub struct ActiveEndpoint {
    /// Position of the endpoint in the list's entries when it was marked active.
    pub index: usize,
    /// Copy of the endpoint's URL.
    pub url: String,
    /// `Instant::now()` captured when the endpoint was marked active.
    pub connected_since: Instant,
}
192
193#[derive(Debug, Clone)]
195pub struct EndpointList {
196 entries: Vec<EndpointEntry>,
198 active: Option<ActiveEndpoint>,
200}
201
202impl EndpointList {
203 pub fn new() -> Self {
205 Self {
206 entries: Vec::new(),
207 active: None,
208 }
209 }
210
211 pub fn with_primary(url: impl Into<String>) -> Self {
213 let mut list = Self::new();
214 list.add_endpoint(url, 0);
215 list
216 }
217
218 pub fn add_endpoint(&mut self, url: impl Into<String>, priority: u32) {
220 let url_string = url.into();
221 if self.entries.iter().any(|e| e.url == url_string) {
223 return;
224 }
225 self.entries.push(EndpointEntry {
226 url: url_string,
227 priority,
228 });
229 self.entries.sort_by_key(|e| e.priority);
230 }
231
232 pub fn len(&self) -> usize {
234 self.entries.len()
235 }
236
237 pub fn is_empty(&self) -> bool {
239 self.entries.is_empty()
240 }
241
242 pub fn iter(&self) -> impl Iterator<Item = &EndpointEntry> {
244 self.entries.iter()
245 }
246
247 pub fn next_endpoint(&self) -> Option<&EndpointEntry> {
250 if self.entries.is_empty() {
251 return None;
252 }
253 let idx = match &self.active {
254 Some(active) => (active.index + 1) % self.entries.len(),
255 None => 0,
256 };
257 self.entries.get(idx)
258 }
259
260 pub fn primary(&self) -> Option<&EndpointEntry> {
262 self.entries.first()
263 }
264
265 pub fn set_active(&mut self, index: usize) -> Result<()> {
267 let entry = self.entries.get(index).ok_or_else(|| {
268 SdkError::InvalidArgument(format!(
269 "endpoint index {} out of range (len={})",
270 index,
271 self.entries.len()
272 ))
273 })?;
274 self.active = Some(ActiveEndpoint {
275 index,
276 url: entry.url.clone(),
277 connected_since: Instant::now(),
278 });
279 Ok(())
280 }
281
282 pub fn set_active_by_url(&mut self, url: &str) -> Result<()> {
284 let index = self
285 .entries
286 .iter()
287 .position(|e| e.url == url)
288 .ok_or_else(|| SdkError::InvalidArgument(format!("endpoint not found: {}", url)))?;
289 self.set_active(index)
290 }
291
292 pub fn active(&self) -> Option<&ActiveEndpoint> {
294 self.active.as_ref()
295 }
296
297 pub fn clear_active(&mut self) {
299 self.active = None;
300 }
301
302 pub fn failover(&mut self) -> Option<String> {
305 if self.entries.is_empty() {
306 return None;
307 }
308 let next_idx = match &self.active {
309 Some(active) => (active.index + 1) % self.entries.len(),
310 None => 0,
311 };
312 let url = self.entries[next_idx].url.clone();
313 self.active = Some(ActiveEndpoint {
314 index: next_idx,
315 url: url.clone(),
316 connected_since: Instant::now(),
317 });
318 info!("failover to endpoint [{}]: {}", next_idx, url);
319 Some(url)
320 }
321}
322
323impl Default for EndpointList {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
/// Backoff settings for reconnect attempts.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Number of reconnect attempts before giving up.
    pub max_attempts: u32,
    /// Delay used for attempt `0`, before any backoff is applied.
    pub base_delay: Duration,
    /// Upper bound on the delay returned by [`ReconnectConfig::delay_for_attempt`].
    pub max_delay: Duration,
    /// Multiplier applied once per attempt (`base_delay * factor^attempt`).
    pub backoff_factor: f64,
    /// When `true`, the delay is scaled by a deterministic, attempt-dependent
    /// factor between 0.75 and 1.15.
    pub jitter: bool,
}

impl Default for ReconnectConfig {
    /// 5 attempts, 1 s base delay, 30 s cap, factor 2.0, jitter enabled.
    fn default() -> Self {
        Self {
            max_attempts: 5,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_factor: 2.0,
            jitter: true,
        }
    }
}

impl ReconnectConfig {
    /// Creates a configuration with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of reconnect attempts.
    pub fn with_max_attempts(mut self, n: u32) -> Self {
        self.max_attempts = n;
        self
    }

    /// Sets the delay used for the first attempt.
    pub fn with_base_delay(mut self, d: Duration) -> Self {
        self.base_delay = d;
        self
    }

    /// Sets the upper bound for any computed delay.
    pub fn with_max_delay(mut self, d: Duration) -> Self {
        self.max_delay = d;
        self
    }

    /// Sets the exponential backoff multiplier.
    pub fn with_backoff_factor(mut self, f: f64) -> Self {
        self.backoff_factor = f;
        self
    }

    /// Computes the delay to wait before reconnect attempt number `attempt`
    /// (zero-based).
    ///
    /// The delay is `base_delay * backoff_factor^attempt`, clamped to
    /// `max_delay`. With `jitter` enabled it is then scaled by
    /// `0.75 + 0.1 * (attempt % 5)` and clamped again, so the result never
    /// exceeds `max_delay`. The result has millisecond resolution.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // A zero base stays zero; without this guard `0 * inf` (huge attempt)
        // is NaN, and `NaN.min(cap)` would yield the cap instead.
        if self.base_delay.is_zero() {
            return Duration::ZERO;
        }

        let base_ms = self.base_delay.as_millis() as f64;
        let cap_ms = self.max_delay.as_millis() as f64;

        // `attempt as i32` would wrap negative above `i32::MAX` and turn the
        // growth into shrinkage; saturate the exponent instead.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let raw = base_ms * self.backoff_factor.powi(exponent);
        let clamped = raw.min(cap_ms);

        let ms = if self.jitter {
            let jitter_frac = 0.75 + f64::from(attempt % 5) * 0.1;
            // The factor can be > 1.0, so clamp once more to honour `max_delay`.
            (clamped * jitter_frac).min(cap_ms)
        } else {
            clamped
        };
        // Float-to-int `as` saturates and maps NaN to 0, so this cannot panic.
        Duration::from_millis(ms as u64)
    }
}
405
/// Snapshot of the most recent health-check results.
///
/// The default value has no recorded check, no latency, zero failures and
/// `is_healthy == false`.
#[derive(Debug, Clone, Default)]
pub struct ConnectionHealth {
    /// When the last success or failure was recorded, if ever.
    pub last_check: Option<Instant>,
    /// Latency reported by the last successful check, in milliseconds.
    pub latency_ms: Option<u64>,
    /// Failures recorded since the last success (or since the last reset).
    pub consecutive_failures: u32,
    /// `true` after a recorded success, `false` after a recorded failure.
    pub is_healthy: bool,
}

impl ConnectionHealth {
    /// Records a successful check: stores the latency, clears the failure
    /// streak and marks the connection healthy.
    pub fn record_success(&mut self, latency_ms: u64) {
        self.last_check = Some(Instant::now());
        self.latency_ms = Some(latency_ms);
        self.consecutive_failures = 0;
        self.is_healthy = true;
    }

    /// Records a failed check: extends the failure streak and marks the
    /// connection unhealthy. The last known latency is kept.
    pub fn record_failure(&mut self) {
        self.last_check = Some(Instant::now());
        // Saturate rather than overflow (which panics in debug builds).
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        self.is_healthy = false;
    }

    /// Restores the default (never checked, not healthy) state.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
444
/// Owns the connection state machine, the endpoint list, health tracking and
/// the background health-check task for one client.
pub struct ConnectionManager {
    /// Client settings; the timeout and keep-alive values are applied when an
    /// endpoint is dialed.
    config: ClientConfig,
    /// Configured endpoints plus the active marker; shared with the
    /// health-check task.
    endpoints: Arc<RwLock<EndpointList>>,
    /// Attempt limit and backoff used by `reconnect_loop`.
    reconnect_config: ReconnectConfig,
    /// Current connection state and the state-change callback slot.
    state: AtomicConnectionState,
    /// Latest health information; shared with the health-check task.
    health: Arc<RwLock<ConnectionHealth>>,
    /// Pause between two background health checks (30 s unless overridden).
    health_check_interval: Duration,
    /// Whether `reconnect_loop` and health-triggered reconnects are allowed.
    auto_reconnect_enabled: Arc<AtomicBool>,
    /// Signalled on `disconnect` and on drop to wake the reconnect loop and the
    /// health-check task.
    cancel: Arc<Notify>,
    /// Join handles of spawned health-check tasks; aborted on drop.
    _task_handles: Arc<Mutex<Vec<tokio::task::JoinHandle<()>>>>,
}
473
474impl ConnectionManager {
475 pub fn new(
477 config: ClientConfig,
478 endpoints: EndpointList,
479 reconnect_config: ReconnectConfig,
480 ) -> Self {
481 Self {
482 config,
483 endpoints: Arc::new(RwLock::new(endpoints)),
484 reconnect_config,
485 state: AtomicConnectionState::new(),
486 health: Arc::new(RwLock::new(ConnectionHealth::default())),
487 health_check_interval: Duration::from_secs(30),
488 auto_reconnect_enabled: Arc::new(AtomicBool::new(true)),
489 cancel: Arc::new(Notify::new()),
490 _task_handles: Arc::new(Mutex::new(Vec::new())),
491 }
492 }
493
494 pub fn with_primary(config: ClientConfig) -> Self {
496 let addr = config.server_addr.clone();
497 let endpoints = EndpointList::with_primary(addr);
498 Self::new(config, endpoints, ReconnectConfig::default())
499 }
500
501 pub fn with_health_check_interval(mut self, interval: Duration) -> Self {
503 self.health_check_interval = interval;
504 self
505 }
506
507 pub fn on_state_change<F>(&self, f: F)
509 where
510 F: Fn(ConnectionState, ConnectionState) + Send + Sync + 'static,
511 {
512 self.state.on_state_change(f);
513 }
514
515 pub fn state(&self) -> ConnectionState {
519 self.state.get()
520 }
521
522 pub fn health(&self) -> ConnectionHealth {
524 self.health.read().clone()
525 }
526
527 pub fn active_endpoint(&self) -> Option<String> {
529 self.endpoints.read().active().map(|a| a.url.clone())
530 }
531
532 pub fn endpoints(&self) -> EndpointList {
534 self.endpoints.read().clone()
535 }
536
537 pub fn config(&self) -> &ClientConfig {
539 &self.config
540 }
541
542 pub async fn connect(&self) -> Result<()> {
550 self.state.transition(ConnectionState::Connecting)?;
551
552 let endpoints: Vec<EndpointEntry> = {
553 let list = self.endpoints.read();
554 list.iter().cloned().collect()
555 };
556
557 if endpoints.is_empty() {
558 self.state.force_set(ConnectionState::Failed);
559 return Err(SdkError::Configuration(
560 "no endpoints configured".to_string(),
561 ));
562 }
563
564 for (idx, ep) in endpoints.iter().enumerate() {
565 info!("trying endpoint [{}] {}", idx, ep.url);
566 match self.try_connect_endpoint(&ep.url).await {
567 Ok(()) => {
568 self.endpoints.write().set_active(idx)?;
569 self.state.transition(ConnectionState::Connected)?;
570 self.health.write().record_success(0);
571 info!("connected to {}", ep.url);
572 self.maybe_spawn_health_check();
573 return Ok(());
574 }
575 Err(e) => {
576 warn!("endpoint {} failed: {}", ep.url, e);
577 continue;
578 }
579 }
580 }
581
582 self.state.force_set(ConnectionState::Failed);
583 Err(SdkError::Connection("all endpoints failed".to_string()))
584 }
585
586 pub fn disconnect(&self) {
588 info!("disconnecting");
589 self.cancel.notify_waiters();
590 self.endpoints.write().clear_active();
591 self.health.write().reset();
592
593 let current = self.state.get();
595 match current {
596 ConnectionState::Connected => {
597 let _ = self.state.transition(ConnectionState::Disconnected);
598 }
599 ConnectionState::Failed => {
600 let _ = self.state.transition(ConnectionState::Disconnected);
601 }
602 _ => {
603 self.state.force_set(ConnectionState::Disconnected);
604 }
605 }
606 }
607
608 pub async fn failover(&self) -> Result<String> {
610 let url = {
611 let mut list = self.endpoints.write();
612 list.failover().ok_or_else(|| {
613 SdkError::Connection("no endpoints available for failover".to_string())
614 })?
615 };
616
617 let current = self.state.get();
619 if current == ConnectionState::Connected {
620 self.state.transition(ConnectionState::Reconnecting)?;
621 }
622
623 match self.try_connect_endpoint(&url).await {
624 Ok(()) => {
625 if self.state.get() == ConnectionState::Reconnecting {
627 self.state.transition(ConnectionState::Connected)?;
628 }
629 self.health.write().record_success(0);
630 info!("failover successful to {}", url);
631 Ok(url)
632 }
633 Err(e) => {
634 self.state.force_set(ConnectionState::Failed);
635 Err(SdkError::Connection(format!(
636 "failover to {} failed: {}",
637 url, e
638 )))
639 }
640 }
641 }
642
643 pub fn enable_auto_reconnect(&self) {
647 self.auto_reconnect_enabled.store(true, Ordering::Release);
648 debug!("auto-reconnect enabled");
649 }
650
651 pub fn disable_auto_reconnect(&self) {
653 self.auto_reconnect_enabled.store(false, Ordering::Release);
654 debug!("auto-reconnect disabled");
655 }
656
657 pub fn is_auto_reconnect_enabled(&self) -> bool {
659 self.auto_reconnect_enabled.load(Ordering::Acquire)
660 }
661
662 pub async fn reconnect_loop(&self) -> Result<()> {
666 if !self.auto_reconnect_enabled.load(Ordering::Acquire) {
667 return Err(SdkError::Connection(
668 "auto-reconnect is disabled".to_string(),
669 ));
670 }
671
672 let current = self.state.get();
674 if current == ConnectionState::Connected {
675 self.state.transition(ConnectionState::Reconnecting)?;
676 } else if current != ConnectionState::Reconnecting {
677 self.state.force_set(ConnectionState::Reconnecting);
679 }
680
681 let endpoints: Vec<EndpointEntry> = {
682 let list = self.endpoints.read();
683 list.iter().cloned().collect()
684 };
685
686 for attempt in 0..self.reconnect_config.max_attempts {
687 if !self.auto_reconnect_enabled.load(Ordering::Acquire) {
688 warn!("auto-reconnect disabled during reconnect loop");
689 return Err(SdkError::Connection(
690 "auto-reconnect disabled during loop".to_string(),
691 ));
692 }
693
694 let delay = self.reconnect_config.delay_for_attempt(attempt);
695 info!(
696 "reconnect attempt {}/{} – waiting {:?}",
697 attempt + 1,
698 self.reconnect_config.max_attempts,
699 delay
700 );
701
702 tokio::select! {
703 _ = tokio::time::sleep(delay) => {}
704 _ = self.cancel.notified() => {
705 info!("reconnect loop cancelled");
706 return Err(SdkError::Connection("reconnect cancelled".to_string()));
707 }
708 }
709
710 for (idx, ep) in endpoints.iter().enumerate() {
712 match self.try_connect_endpoint(&ep.url).await {
713 Ok(()) => {
714 if let Err(e) = self.endpoints.write().set_active(idx) {
715 warn!("failed to set active endpoint: {}", e);
716 }
717 self.state.transition(ConnectionState::Connected)?;
718 self.health.write().record_success(0);
719 info!("reconnected to {}", ep.url);
720 return Ok(());
721 }
722 Err(e) => {
723 debug!("reconnect to {} failed: {}", ep.url, e);
724 }
725 }
726 }
727
728 self.health.write().record_failure();
729 }
730
731 self.state.force_set(ConnectionState::Failed);
732 Err(SdkError::Connection(format!(
733 "reconnect failed after {} attempts",
734 self.reconnect_config.max_attempts
735 )))
736 }
737
738 pub async fn check_health(&self) -> Result<()> {
742 let url = self.active_endpoint().ok_or_else(|| {
743 SdkError::Connection("no active endpoint to health-check".to_string())
744 })?;
745
746 let start = Instant::now();
747 match self.try_connect_endpoint(&url).await {
748 Ok(()) => {
749 let latency = start.elapsed().as_millis() as u64;
750 self.health.write().record_success(latency);
751 debug!("health check OK – {}ms", latency);
752 Ok(())
753 }
754 Err(e) => {
755 self.health.write().record_failure();
756 let failures = self.health.read().consecutive_failures;
757 warn!("health check failed ({} consecutive): {}", failures, e);
758 if failures >= 3 && self.is_auto_reconnect_enabled() {
760 error!(
761 "triggering reconnect after {} consecutive health-check failures",
762 failures
763 );
764 let _ = self.reconnect_loop().await;
766 }
767 Err(SdkError::Connection(format!("health check failed: {}", e)))
768 }
769 }
770 }
771
772 async fn try_connect_endpoint(&self, url: &str) -> Result<()> {
776 use tonic::transport::Endpoint;
777
778 let mut endpoint = Endpoint::from_shared(url.to_string())
779 .map_err(|e| SdkError::Configuration(format!("invalid endpoint url: {}", e)))?;
780
781 endpoint = endpoint
782 .timeout(self.config.request_timeout)
783 .connect_timeout(self.config.connect_timeout);
784
785 if self.config.keep_alive {
786 endpoint = endpoint
787 .keep_alive_timeout(self.config.keep_alive_timeout)
788 .http2_keep_alive_interval(self.config.keep_alive_interval);
789 }
790
791 let _channel = tokio::time::timeout(self.config.connect_timeout, endpoint.connect())
792 .await
793 .map_err(|_| {
794 SdkError::Timeout(format!(
795 "endpoint {} connect timeout after {:?}",
796 url, self.config.connect_timeout
797 ))
798 })?
799 .map_err(SdkError::Transport)?;
800
801 Ok(())
802 }
803
804 fn maybe_spawn_health_check(&self) {
806 let interval = self.health_check_interval;
807 let health = Arc::clone(&self.health);
808 let state = self.state.clone();
809 let cancel = Arc::clone(&self.cancel);
810 let auto_reconnect = Arc::clone(&self.auto_reconnect_enabled);
811 let endpoints = Arc::clone(&self.endpoints);
812 let config = self.config.clone();
813
814 let handle = tokio::spawn(async move {
815 loop {
816 tokio::select! {
817 _ = tokio::time::sleep(interval) => {}
818 _ = cancel.notified() => {
819 debug!("health-check task cancelled");
820 return;
821 }
822 }
823
824 if state.get() != ConnectionState::Connected {
826 continue;
827 }
828
829 let url = {
830 let list = endpoints.read();
831 list.active().map(|a| a.url.clone())
832 };
833
834 let url = match url {
835 Some(u) => u,
836 None => continue,
837 };
838
839 let start = Instant::now();
840 let result = {
841 use tonic::transport::Endpoint;
842 let endpoint = match Endpoint::from_shared(url.clone()) {
843 Ok(ep) => ep
844 .timeout(config.request_timeout)
845 .connect_timeout(config.connect_timeout),
846 Err(_) => continue,
847 };
848 tokio::time::timeout(config.connect_timeout, endpoint.connect()).await
849 };
850
851 match result {
852 Ok(Ok(_)) => {
853 let latency = start.elapsed().as_millis() as u64;
854 health.write().record_success(latency);
855 }
856 _ => {
857 health.write().record_failure();
858 let failures = health.read().consecutive_failures;
859 if failures >= 3 && auto_reconnect.load(Ordering::Acquire) {
860 warn!(
861 "health-check task: {} consecutive failures, signalling reconnect",
862 failures
863 );
864 state.force_set(ConnectionState::Reconnecting);
866 }
867 }
868 }
869 }
870 });
871
872 self._task_handles.lock().push(handle);
873 }
874}
875
876impl Drop for ConnectionManager {
877 fn drop(&mut self) {
878 self.cancel.notify_waiters();
880 for handle in self._task_handles.lock().iter() {
881 handle.abort();
882 }
883 }
884}
885
// Unit tests for the state machine, endpoint list, reconnect backoff, health
// tracking and the parts of `ConnectionManager` that can run without a server.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- AtomicConnectionState / ConnectionState ----

    #[test]
    fn test_state_initial() {
        let s = AtomicConnectionState::new();
        assert_eq!(s.get(), ConnectionState::Disconnected);
    }

    // Walks one full valid path through the state machine.
    #[test]
    fn test_valid_transitions() {
        let s = AtomicConnectionState::new();

        assert!(s.transition(ConnectionState::Connecting).is_ok());
        assert_eq!(s.get(), ConnectionState::Connecting);

        assert!(s.transition(ConnectionState::Connected).is_ok());
        assert_eq!(s.get(), ConnectionState::Connected);

        assert!(s.transition(ConnectionState::Reconnecting).is_ok());
        assert_eq!(s.get(), ConnectionState::Reconnecting);

        assert!(s.transition(ConnectionState::Connected).is_ok());
        assert_eq!(s.get(), ConnectionState::Connected);

        assert!(s.transition(ConnectionState::Disconnected).is_ok());
        assert_eq!(s.get(), ConnectionState::Disconnected);
    }

    // Disconnected -> Connected skips Connecting and must be rejected.
    #[test]
    fn test_invalid_transition() {
        let s = AtomicConnectionState::new();
        assert!(s.transition(ConnectionState::Connected).is_err());
    }

    #[test]
    fn test_failed_to_disconnected() {
        let s = AtomicConnectionState::new();
        s.force_set(ConnectionState::Failed);
        assert_eq!(s.get(), ConnectionState::Failed);
        assert!(s.transition(ConnectionState::Disconnected).is_ok());
        assert_eq!(s.get(), ConnectionState::Disconnected);
    }

    // The callback receives (previous, new) for every successful transition.
    #[test]
    fn test_state_callback() {
        let s = AtomicConnectionState::new();
        let transitions = Arc::new(Mutex::new(Vec::new()));
        let t_clone = Arc::clone(&transitions);
        s.on_state_change(move |from, to| {
            t_clone.lock().push((from, to));
        });

        let _ = s.transition(ConnectionState::Connecting);
        let _ = s.transition(ConnectionState::Connected);

        let recorded = transitions.lock();
        assert_eq!(recorded.len(), 2);
        assert_eq!(
            recorded[0],
            (ConnectionState::Disconnected, ConnectionState::Connecting)
        );
        assert_eq!(
            recorded[1],
            (ConnectionState::Connecting, ConnectionState::Connected)
        );
    }

    #[test]
    fn test_state_display() {
        assert_eq!(ConnectionState::Connected.to_string(), "Connected");
        assert_eq!(ConnectionState::Failed.as_str(), "Failed");
    }

    // ---- EndpointList ----

    // Entries come back in ascending priority order regardless of insertion order.
    #[test]
    fn test_endpoint_list_priority_ordering() {
        let mut list = EndpointList::new();
        list.add_endpoint("http://c:50051", 20);
        list.add_endpoint("http://a:50051", 0);
        list.add_endpoint("http://b:50051", 10);

        let urls: Vec<&str> = list.iter().map(|e| e.url.as_str()).collect();
        assert_eq!(
            urls,
            vec!["http://a:50051", "http://b:50051", "http://c:50051"]
        );
    }

    // Adding the same URL twice keeps a single entry.
    #[test]
    fn test_endpoint_list_no_duplicates() {
        let mut list = EndpointList::new();
        list.add_endpoint("http://a:50051", 0);
        list.add_endpoint("http://a:50051", 10);
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_endpoint_list_primary() {
        let list = EndpointList::with_primary("http://primary:50051");
        assert_eq!(
            list.primary().map(|e| e.url.as_str()),
            Some("http://primary:50051")
        );
    }

    // With nothing active the first failover picks the first entry, then the
    // rotation wraps around after the last one.
    #[test]
    fn test_endpoint_failover() {
        let mut list = EndpointList::new();
        list.add_endpoint("http://a:50051", 0);
        list.add_endpoint("http://b:50051", 10);
        list.add_endpoint("http://c:50051", 20);

        let url = list.failover();
        assert_eq!(url, Some("http://a:50051".to_string()));

        let url = list.failover();
        assert_eq!(url, Some("http://b:50051".to_string()));

        let url = list.failover();
        assert_eq!(url, Some("http://c:50051".to_string()));

        let url = list.failover();
        assert_eq!(url, Some("http://a:50051".to_string()));
    }

    #[test]
    fn test_endpoint_set_active_by_url() {
        let mut list = EndpointList::new();
        list.add_endpoint("http://a:50051", 0);
        list.add_endpoint("http://b:50051", 10);

        assert!(list.set_active_by_url("http://b:50051").is_ok());
        assert_eq!(
            list.active().map(|a| a.url.as_str()),
            Some("http://b:50051")
        );

        // Unknown URL is rejected.
        assert!(list.set_active_by_url("http://z:50051").is_err());
    }

    #[test]
    fn test_endpoint_empty_failover() {
        let mut list = EndpointList::new();
        assert!(list.failover().is_none());
    }

    #[test]
    fn test_endpoint_clear_active() {
        let mut list = EndpointList::with_primary("http://a:50051");
        list.set_active(0).expect("set_active should succeed");
        assert!(list.active().is_some());
        list.clear_active();
        assert!(list.active().is_none());
    }

    // ---- ReconnectConfig ----

    #[test]
    fn test_reconnect_config_defaults() {
        let cfg = ReconnectConfig::default();
        assert_eq!(cfg.max_attempts, 5);
        assert_eq!(cfg.base_delay, Duration::from_secs(1));
        assert_eq!(cfg.max_delay, Duration::from_secs(30));
        assert!((cfg.backoff_factor - 2.0).abs() < f64::EPSILON);
        assert!(cfg.jitter);
    }

    // Without jitter the delay doubles per attempt: 1s, 2s, 4s, 8s, 16s.
    #[test]
    fn test_reconnect_backoff_no_jitter() {
        let cfg = ReconnectConfig {
            max_attempts: 5,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_factor: 2.0,
            jitter: false,
        };

        assert_eq!(cfg.delay_for_attempt(0), Duration::from_secs(1));
        assert_eq!(cfg.delay_for_attempt(1), Duration::from_secs(2));
        assert_eq!(cfg.delay_for_attempt(2), Duration::from_secs(4));
        assert_eq!(cfg.delay_for_attempt(3), Duration::from_secs(8));
        assert_eq!(cfg.delay_for_attempt(4), Duration::from_secs(16));
    }

    // 2^5 s and 2^8 s both exceed the 10 s cap.
    #[test]
    fn test_reconnect_backoff_clamped() {
        let cfg = ReconnectConfig {
            max_attempts: 10,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(10),
            backoff_factor: 2.0,
            jitter: false,
        };

        assert_eq!(cfg.delay_for_attempt(5), Duration::from_secs(10));
        assert_eq!(cfg.delay_for_attempt(8), Duration::from_secs(10));
    }

    // Default config has jitter enabled; the delay must still grow from
    // attempt 0 to attempt 1.
    #[test]
    fn test_reconnect_backoff_with_jitter() {
        let cfg = ReconnectConfig::default();
        let d0 = cfg.delay_for_attempt(0);
        let d1 = cfg.delay_for_attempt(1);
        assert!(d1 > d0, "d1={:?} should be > d0={:?}", d1, d0);
    }

    #[test]
    fn test_reconnect_builder() {
        let cfg = ReconnectConfig::new()
            .with_max_attempts(10)
            .with_base_delay(Duration::from_millis(500))
            .with_max_delay(Duration::from_secs(60))
            .with_backoff_factor(3.0);

        assert_eq!(cfg.max_attempts, 10);
        assert_eq!(cfg.base_delay, Duration::from_millis(500));
        assert_eq!(cfg.max_delay, Duration::from_secs(60));
        assert!((cfg.backoff_factor - 3.0).abs() < f64::EPSILON);
    }

    // ---- ConnectionHealth ----

    #[test]
    fn test_health_default() {
        let h = ConnectionHealth::default();
        assert!(!h.is_healthy);
        assert_eq!(h.consecutive_failures, 0);
        assert!(h.last_check.is_none());
        assert!(h.latency_ms.is_none());
    }

    #[test]
    fn test_health_success() {
        let mut h = ConnectionHealth::default();
        h.record_success(42);
        assert!(h.is_healthy);
        assert_eq!(h.latency_ms, Some(42));
        assert_eq!(h.consecutive_failures, 0);
        assert!(h.last_check.is_some());
    }

    #[test]
    fn test_health_failure_counter() {
        let mut h = ConnectionHealth::default();
        h.record_failure();
        h.record_failure();
        h.record_failure();
        assert_eq!(h.consecutive_failures, 3);
        assert!(!h.is_healthy);

        // A success clears the failure streak.
        h.record_success(10);
        assert_eq!(h.consecutive_failures, 0);
        assert!(h.is_healthy);
    }

    #[test]
    fn test_health_reset() {
        let mut h = ConnectionHealth::default();
        h.record_success(5);
        h.record_failure();
        h.reset();
        assert!(h.last_check.is_none());
        assert!(!h.is_healthy);
        assert_eq!(h.consecutive_failures, 0);
    }

    // ---- ConnectionManager ----

    #[test]
    fn test_manager_initial_state() {
        let mgr = ConnectionManager::with_primary(ClientConfig::default());
        assert_eq!(mgr.state(), ConnectionState::Disconnected);
    }

    #[test]
    fn test_manager_disconnect_cleans_up() {
        let mgr = ConnectionManager::with_primary(ClientConfig::default());
        // Simulate an established connection without touching the network.
        mgr.state.force_set(ConnectionState::Connected);
        mgr.endpoints
            .write()
            .set_active(0)
            .expect("set_active should succeed");

        mgr.disconnect();

        assert_eq!(mgr.state(), ConnectionState::Disconnected);
        assert!(mgr.active_endpoint().is_none());
        assert!(!mgr.health().is_healthy);
    }

    #[test]
    fn test_manager_auto_reconnect_toggle() {
        let mgr = ConnectionManager::with_primary(ClientConfig::default());
        assert!(mgr.is_auto_reconnect_enabled());

        mgr.disable_auto_reconnect();
        assert!(!mgr.is_auto_reconnect_enabled());

        mgr.enable_auto_reconnect();
        assert!(mgr.is_auto_reconnect_enabled());
    }

    #[test]
    fn test_manager_health_check_interval() {
        let mgr = ConnectionManager::with_primary(ClientConfig::default())
            .with_health_check_interval(Duration::from_secs(10));
        assert_eq!(mgr.health_check_interval, Duration::from_secs(10));
    }

    #[test]
    fn test_manager_endpoints_access() {
        let mut eps = EndpointList::new();
        eps.add_endpoint("http://a:50051", 0);
        eps.add_endpoint("http://b:50051", 10);

        let mgr = ConnectionManager::new(ClientConfig::default(), eps, ReconnectConfig::default());

        let list = mgr.endpoints();
        assert_eq!(list.len(), 2);
        assert_eq!(
            list.primary().map(|e| e.url.as_str()),
            Some("http://a:50051")
        );
    }

    // An empty endpoint list makes `connect` fail and leaves the state `Failed`.
    #[tokio::test]
    async fn test_manager_connect_no_endpoints() {
        let mgr = ConnectionManager::new(
            ClientConfig::default(),
            EndpointList::new(),
            ReconnectConfig::default(),
        );

        let result = mgr.connect().await;
        assert!(result.is_err());
        assert_eq!(mgr.state(), ConnectionState::Failed);
    }

    // Dials an address that is expected not to answer, with a short connect
    // timeout so the test stays fast.
    #[tokio::test]
    async fn test_manager_connect_unreachable_endpoint() {
        let config = ClientConfig::new("http://192.0.2.1:1")
            .with_connect_timeout(Duration::from_millis(100));

        let eps = EndpointList::with_primary("http://192.0.2.1:1");
        let mgr = ConnectionManager::new(config, eps, ReconnectConfig::default());

        let result = mgr.connect().await;
        assert!(result.is_err());
        assert_eq!(mgr.state(), ConnectionState::Failed);
    }

    // With auto-reconnect disabled the loop refuses to start.
    #[tokio::test]
    async fn test_manager_reconnect_disabled() {
        let mgr = ConnectionManager::with_primary(ClientConfig::default());
        mgr.disable_auto_reconnect();
        mgr.state.force_set(ConnectionState::Connected);

        let result = mgr.reconnect_loop().await;
        assert!(result.is_err());
    }

    #[test]
    fn test_state_from_u8_invalid() {
        assert!(ConnectionState::from_u8(255).is_none());
        assert!(ConnectionState::from_u8(5).is_none());
    }

    // With nothing active, `next_endpoint` is the first entry.
    #[test]
    fn test_endpoint_next_no_active() {
        let mut list = EndpointList::new();
        list.add_endpoint("http://a:50051", 0);
        list.add_endpoint("http://b:50051", 10);

        let next = list.next_endpoint();
        assert_eq!(next.map(|e| e.url.as_str()), Some("http://a:50051"));
    }

    // With entry 0 active, `next_endpoint` is entry 1.
    #[test]
    fn test_endpoint_next_with_active() {
        let mut list = EndpointList::new();
        list.add_endpoint("http://a:50051", 0);
        list.add_endpoint("http://b:50051", 10);
        list.set_active(0).expect("set_active should succeed");

        let next = list.next_endpoint();
        assert_eq!(next.map(|e| e.url.as_str()), Some("http://b:50051"));
    }

    // `force_set` also fires the callback registered through the manager.
    #[test]
    fn test_manager_state_change_callback() {
        let mgr = ConnectionManager::with_primary(ClientConfig::default());
        let states = Arc::new(Mutex::new(Vec::new()));
        let s_clone = Arc::clone(&states);

        mgr.on_state_change(move |from, to| {
            s_clone.lock().push((from, to));
        });

        mgr.state.force_set(ConnectionState::Connecting);
        mgr.state.force_set(ConnectionState::Connected);

        let recorded = states.lock();
        assert_eq!(recorded.len(), 2);
    }
}