Skip to main content

lnc_client/
connection.rs

1//! Connection Management and Resilience
2//!
3//! Provides connection pooling, automatic reconnection, and resilience features
4//! for LANCE client connections.
5//!
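//! # Features
//!
//! - **Connection Pool**: Manage multiple connections to a LANCE server ([`ConnectionPool`])
//! - **Auto-Reconnect**: Automatically reconnect on connection failures ([`ReconnectingClient`])
//! - **Health Checking**: Periodic ping-based health checks of idle pooled connections
//! - **Backoff**: Exponential backoff between reconnection attempts
//! - **Cluster Discovery**: Locate the cluster leader from a list of seed nodes ([`ClusterClient`])
//! - **Circuit Breaker**: Prevent cascading failures (not implemented in this module)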
13//!
14//! # Example
15//!
16//! ```rust,no_run
17//! use lnc_client::{ConnectionPool, ConnectionPoolConfig};
18//!
19//! #[tokio::main]
20//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
21//!     let pool = ConnectionPool::new(
22//!         "127.0.0.1:1992",
23//!         ConnectionPoolConfig::new()
24//!             .with_max_connections(10)
25//!             .with_health_check_interval(30),
26//!     ).await?;
27//!
28//!     // Get a connection from the pool
29//!     let mut conn = pool.get().await?;
30//!     
31//!     // Use the connection
32//!     conn.ping().await?;
33//!     
34//!     // Connection is returned to pool when dropped
35//!     Ok(())
36//! }
37//! ```
38
39use std::collections::VecDeque;
40use std::net::SocketAddr;
41use std::sync::Arc;
42use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
43use std::time::{Duration, Instant};
44
45use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
46
47use crate::client::{ClientConfig, LanceClient};
48use crate::error::{ClientError, Result};
49use crate::tls::TlsClientConfig;
50
/// Configuration for a [`ConnectionPool`].
///
/// Start from [`ConnectionPoolConfig::new`] (or `Default`) and adjust with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Maximum number of connections that may be checked out of the pool at the
    /// same time (enforced with a semaphore; idle connections hold no permit)
    pub max_connections: usize,
    /// Minimum number of idle connections to maintain; this many connections are
    /// opened eagerly when the pool is created
    pub min_idle: usize,
    /// Timeout for establishing a single new connection
    pub connect_timeout: Duration,
    /// Maximum time to wait for a connection from the pool
    pub acquire_timeout: Duration,
    /// Interval between background health checks of idle connections (0 = disabled)
    pub health_check_interval: Duration,
    /// Maximum connection lifetime (0 = unlimited)
    pub max_lifetime: Duration,
    /// How long a connection may sit unused before it is closed (0 = never)
    pub idle_timeout: Duration,
    /// Enable automatic reconnection
    // NOTE(review): this field and the three `*reconnect*` fields below are not
    // read by `ConnectionPool` anywhere in this module.
    pub auto_reconnect: bool,
    /// Maximum reconnection attempts (0 = unlimited)
    pub max_reconnect_attempts: u32,
    /// Base delay for exponential backoff
    pub reconnect_base_delay: Duration,
    /// Maximum delay for exponential backoff
    pub reconnect_max_delay: Duration,
    /// TLS configuration (None = plain TCP)
    pub tls_config: Option<TlsClientConfig>,
}
79
80impl Default for ConnectionPoolConfig {
81    fn default() -> Self {
82        Self {
83            max_connections: 10,
84            min_idle: 1,
85            connect_timeout: Duration::from_secs(30),
86            acquire_timeout: Duration::from_secs(30),
87            health_check_interval: Duration::from_secs(30),
88            max_lifetime: Duration::from_secs(3600), // 1 hour
89            idle_timeout: Duration::from_secs(300),  // 5 minutes
90            auto_reconnect: true,
91            max_reconnect_attempts: 5,
92            reconnect_base_delay: Duration::from_millis(100),
93            reconnect_max_delay: Duration::from_secs(30),
94            tls_config: None,
95        }
96    }
97}
98
99impl ConnectionPoolConfig {
100    /// Create a new connection pool configuration with defaults
101    pub fn new() -> Self {
102        Self::default()
103    }
104
105    /// Set maximum connections
106    pub fn with_max_connections(mut self, n: usize) -> Self {
107        self.max_connections = n;
108        self
109    }
110
111    /// Set minimum idle connections
112    pub fn with_min_idle(mut self, n: usize) -> Self {
113        self.min_idle = n;
114        self
115    }
116
117    /// Set connection timeout
118    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
119        self.connect_timeout = timeout;
120        self
121    }
122
123    /// Set acquire timeout
124    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
125        self.acquire_timeout = timeout;
126        self
127    }
128
129    /// Set health check interval (seconds)
130    pub fn with_health_check_interval(mut self, secs: u64) -> Self {
131        self.health_check_interval = Duration::from_secs(secs);
132        self
133    }
134
135    /// Set maximum connection lifetime
136    pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
137        self.max_lifetime = lifetime;
138        self
139    }
140
141    /// Set idle timeout
142    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
143        self.idle_timeout = timeout;
144        self
145    }
146
147    /// Enable or disable auto-reconnect
148    pub fn with_auto_reconnect(mut self, enabled: bool) -> Self {
149        self.auto_reconnect = enabled;
150        self
151    }
152
153    /// Set maximum reconnection attempts
154    pub fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
155        self.max_reconnect_attempts = attempts;
156        self
157    }
158
159    /// Set TLS configuration for encrypted connections
160    pub fn with_tls(mut self, tls_config: TlsClientConfig) -> Self {
161        self.tls_config = Some(tls_config);
162        self
163    }
164}
165
/// Connection pool statistics, as returned by `ConnectionPool::stats`.
///
/// Each counter is read independently, so while the pool is in use the fields
/// are not guaranteed to be mutually consistent.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Total connections created
    pub connections_created: u64,
    /// Total connections closed
    pub connections_closed: u64,
    /// Current active connections (in use)
    pub active_connections: u64,
    /// Current idle connections (available)
    pub idle_connections: u64,
    /// Total acquire attempts
    pub acquire_attempts: u64,
    /// Successful acquires
    pub acquire_successes: u64,
    /// Failed acquires (timeout, error)
    pub acquire_failures: u64,
    /// Health check failures (failed pings of idle connections)
    pub health_check_failures: u64,
    /// Reconnection attempts
    // NOTE(review): nothing in this module increments the backing counter.
    pub reconnect_attempts: u64,
}
188
/// Internal pool metrics using atomics.
///
/// Shared (via `Arc`) between the pool, its health-check task and every
/// `PooledClient`; all updates use `Ordering::Relaxed`. Field meanings mirror
/// [`PoolStats`], which is produced by `snapshot`.
#[derive(Debug, Default)]
struct PoolMetrics {
    connections_created: AtomicU64,
    connections_closed: AtomicU64,
    active_connections: AtomicU64,
    idle_connections: AtomicU64,
    acquire_attempts: AtomicU64,
    acquire_successes: AtomicU64,
    acquire_failures: AtomicU64,
    health_check_failures: AtomicU64,
    reconnect_attempts: AtomicU64,
}
202
203impl PoolMetrics {
204    fn snapshot(&self) -> PoolStats {
205        PoolStats {
206            connections_created: self.connections_created.load(Ordering::Relaxed),
207            connections_closed: self.connections_closed.load(Ordering::Relaxed),
208            active_connections: self.active_connections.load(Ordering::Relaxed),
209            idle_connections: self.idle_connections.load(Ordering::Relaxed),
210            acquire_attempts: self.acquire_attempts.load(Ordering::Relaxed),
211            acquire_successes: self.acquire_successes.load(Ordering::Relaxed),
212            acquire_failures: self.acquire_failures.load(Ordering::Relaxed),
213            health_check_failures: self.health_check_failures.load(Ordering::Relaxed),
214            reconnect_attempts: self.reconnect_attempts.load(Ordering::Relaxed),
215        }
216    }
217}
218
/// Pooled connection wrapper: a live client plus the timestamps used for
/// `max_lifetime` and `idle_timeout` eviction.
struct PooledConnection {
    /// The underlying LANCE client.
    client: LanceClient,
    /// When the connection was opened (checked against `max_lifetime`).
    created_at: Instant,
    /// When the connection was last handed out or returned (checked against `idle_timeout`).
    last_used: Instant,
}
225
226impl PooledConnection {
227    fn new(client: LanceClient) -> Self {
228        let now = Instant::now();
229        Self {
230            client,
231            created_at: now,
232            last_used: now,
233        }
234    }
235
236    fn is_expired(&self, max_lifetime: Duration) -> bool {
237        if max_lifetime.is_zero() {
238            return false;
239        }
240        self.created_at.elapsed() > max_lifetime
241    }
242
243    fn is_idle_too_long(&self, idle_timeout: Duration) -> bool {
244        if idle_timeout.is_zero() {
245            return false;
246        }
247        self.last_used.elapsed() > idle_timeout
248    }
249}
250
/// Connection pool for managing LANCE client connections.
///
/// At most `max_connections` connections can be checked out at once (one
/// semaphore permit each); returned connections wait in an idle queue for reuse.
pub struct ConnectionPool {
    /// Server address every pooled connection is opened against.
    addr: String,
    /// Pool configuration (limits, timeouts, TLS).
    config: ConnectionPoolConfig,
    /// Idle connections; checkout pops from the front, returns push to the back.
    connections: Arc<Mutex<VecDeque<PooledConnection>>>,
    /// One permit per connection that may be checked out concurrently.
    semaphore: Arc<Semaphore>,
    /// Shared counters exposed through `stats()`.
    metrics: Arc<PoolMetrics>,
    /// Cleared by `close()`; the health-check loop stops when it is false.
    running: Arc<AtomicBool>,
}
260
261impl ConnectionPool {
262    /// Create a new connection pool
263    ///
264    /// The address can be either an IP:port (e.g., "127.0.0.1:1992") or
265    /// a hostname:port (e.g., "lance.example.com:1992").
266    pub async fn new(addr: &str, config: ConnectionPoolConfig) -> Result<Self> {
267        let pool = Self {
268            addr: addr.to_string(),
269            config: config.clone(),
270            connections: Arc::new(Mutex::new(VecDeque::new())),
271            semaphore: Arc::new(Semaphore::new(config.max_connections)),
272            metrics: Arc::new(PoolMetrics::default()),
273            running: Arc::new(AtomicBool::new(true)),
274        };
275
276        // Pre-populate with minimum idle connections
277        for _ in 0..config.min_idle {
278            if let Ok(conn) = pool.create_connection().await {
279                let mut connections = pool.connections.lock().await;
280                connections.push_back(conn);
281                pool.metrics
282                    .idle_connections
283                    .fetch_add(1, Ordering::Relaxed);
284            }
285        }
286
287        // Start health check task if enabled
288        if !config.health_check_interval.is_zero() {
289            let pool_clone = ConnectionPool {
290                addr: pool.addr.clone(),
291                config: pool.config.clone(),
292                connections: pool.connections.clone(),
293                semaphore: pool.semaphore.clone(),
294                metrics: pool.metrics.clone(),
295                running: pool.running.clone(),
296            };
297            tokio::spawn(async move {
298                pool_clone.health_check_task().await;
299            });
300        }
301
302        Ok(pool)
303    }
304
305    /// Get a connection from the pool
306    pub async fn get(&self) -> Result<PooledClient> {
307        self.metrics
308            .acquire_attempts
309            .fetch_add(1, Ordering::Relaxed);
310
311        // Acquire permit with timeout
312        let permit = tokio::time::timeout(
313            self.config.acquire_timeout,
314            self.semaphore.clone().acquire_owned(),
315        )
316        .await
317        .map_err(|_| {
318            self.metrics
319                .acquire_failures
320                .fetch_add(1, Ordering::Relaxed);
321            ClientError::Timeout
322        })?
323        .map_err(|_| {
324            self.metrics
325                .acquire_failures
326                .fetch_add(1, Ordering::Relaxed);
327            ClientError::ConnectionClosed
328        })?;
329
330        // Try to get an existing connection
331        let conn = {
332            let mut connections = self.connections.lock().await;
333            loop {
334                match connections.pop_front() {
335                    Some(conn) => {
336                        self.metrics
337                            .idle_connections
338                            .fetch_sub(1, Ordering::Relaxed);
339
340                        // Check if connection is still valid
341                        if conn.is_expired(self.config.max_lifetime)
342                            || conn.is_idle_too_long(self.config.idle_timeout)
343                        {
344                            self.metrics
345                                .connections_closed
346                                .fetch_add(1, Ordering::Relaxed);
347                            continue;
348                        }
349                        break Some(conn);
350                    },
351                    None => break None,
352                }
353            }
354        };
355
356        let conn = match conn {
357            Some(mut c) => {
358                c.last_used = Instant::now();
359                c
360            },
361            None => {
362                // Create a new connection
363                self.create_connection().await?
364            },
365        };
366
367        self.metrics
368            .active_connections
369            .fetch_add(1, Ordering::Relaxed);
370        self.metrics
371            .acquire_successes
372            .fetch_add(1, Ordering::Relaxed);
373
374        Ok(PooledClient {
375            conn: Some(conn),
376            pool: self.connections.clone(),
377            metrics: self.metrics.clone(),
378            permit: Some(permit),
379            config: self.config.clone(),
380        })
381    }
382
383    /// Create a new connection
384    async fn create_connection(&self) -> Result<PooledConnection> {
385        let mut client_config = ClientConfig::new(&self.addr);
386        client_config.connect_timeout = self.config.connect_timeout;
387
388        let client = match &self.config.tls_config {
389            Some(tls_config) => LanceClient::connect_tls(client_config, tls_config.clone()).await?,
390            None => LanceClient::connect(client_config).await?,
391        };
392        self.metrics
393            .connections_created
394            .fetch_add(1, Ordering::Relaxed);
395
396        Ok(PooledConnection::new(client))
397    }
398
399    /// Get pool statistics
400    pub fn stats(&self) -> PoolStats {
401        self.metrics.snapshot()
402    }
403
404    /// Close the pool
405    pub async fn close(&self) {
406        self.running.store(false, Ordering::Relaxed);
407
408        let mut connections = self.connections.lock().await;
409        let count = connections.len() as u64;
410        connections.clear();
411        self.metrics
412            .connections_closed
413            .fetch_add(count, Ordering::Relaxed);
414        self.metrics.idle_connections.store(0, Ordering::Relaxed);
415    }
416
417    /// Health check task
418    async fn health_check_task(&self) {
419        let mut interval = tokio::time::interval(self.config.health_check_interval);
420
421        while self.running.load(Ordering::Relaxed) {
422            interval.tick().await;
423
424            // Get all connections for health check
425            let mut to_check = {
426                let mut connections = self.connections.lock().await;
427                std::mem::take(&mut *connections)
428            };
429
430            let mut healthy = VecDeque::new();
431            let _initial_count = to_check.len();
432
433            for mut conn in to_check.drain(..) {
434                // Check expiry
435                if conn.is_expired(self.config.max_lifetime) {
436                    self.metrics
437                        .connections_closed
438                        .fetch_add(1, Ordering::Relaxed);
439                    continue;
440                }
441
442                // Ping to check health
443                match conn.client.ping().await {
444                    Ok(_) => {
445                        conn.last_used = Instant::now();
446                        healthy.push_back(conn);
447                    },
448                    Err(_) => {
449                        self.metrics
450                            .health_check_failures
451                            .fetch_add(1, Ordering::Relaxed);
452                        self.metrics
453                            .connections_closed
454                            .fetch_add(1, Ordering::Relaxed);
455                    },
456                }
457            }
458
459            // Return healthy connections
460            {
461                let mut connections = self.connections.lock().await;
462                connections.extend(healthy);
463                self.metrics
464                    .idle_connections
465                    .store(connections.len() as u64, Ordering::Relaxed);
466            }
467        }
468    }
469}
470
/// RAII wrapper for a connection checked out of a [`ConnectionPool`].
///
/// Dropping it hands the connection back to the pool's idle queue (unless it was
/// discarded with [`PooledClient::mark_unhealthy`]) and releases its semaphore permit.
pub struct PooledClient {
    /// The checked-out connection; `None` once it has been discarded.
    conn: Option<PooledConnection>,
    /// Shared idle queue the connection is returned to on drop.
    pool: Arc<Mutex<VecDeque<PooledConnection>>>,
    /// Shared pool counters, updated on drop and by `mark_unhealthy`.
    metrics: Arc<PoolMetrics>,
    // Held only for its RAII effect: releasing it frees one of the pool's
    // `max_connections` slots.
    #[allow(dead_code)]
    permit: Option<OwnedSemaphorePermit>,
    // NOTE(review): stored but not read anywhere in this module.
    #[allow(dead_code)]
    config: ConnectionPoolConfig,
}
481
482impl PooledClient {
483    /// Get a reference to the underlying client
484    pub fn client(&mut self) -> Result<&mut LanceClient> {
485        match self.conn.as_mut() {
486            Some(conn) => Ok(&mut conn.client),
487            None => Err(ClientError::ConnectionClosed),
488        }
489    }
490
491    /// Ping the server
492    pub async fn ping(&mut self) -> Result<Duration> {
493        if let Some(ref mut conn) = self.conn {
494            conn.client.ping().await
495        } else {
496            Err(ClientError::ConnectionClosed)
497        }
498    }
499
500    /// Mark the connection as unhealthy (don't return to pool)
501    pub fn mark_unhealthy(&mut self) {
502        self.conn = None;
503        self.metrics
504            .connections_closed
505            .fetch_add(1, Ordering::Relaxed);
506    }
507}
508
509impl Drop for PooledClient {
510    fn drop(&mut self) {
511        if let Some(mut conn) = self.conn.take() {
512            conn.last_used = Instant::now();
513
514            // Return to pool
515            let pool = self.pool.clone();
516            let metrics = self.metrics.clone();
517
518            tokio::spawn(async move {
519                let mut connections = pool.lock().await;
520                connections.push_back(conn);
521                metrics.active_connections.fetch_sub(1, Ordering::Relaxed);
522                metrics.idle_connections.fetch_add(1, Ordering::Relaxed);
523            });
524        } else {
525            self.metrics
526                .active_connections
527                .fetch_sub(1, Ordering::Relaxed);
528        }
529
530        // Permit is released when dropped
531    }
532}
533
/// Client with automatic reconnection (exponential backoff) and leader redirection support.
///
/// Wraps a single `LanceClient`. After the connection is lost it is re-established
/// on demand by `client()` or `execute()`.
pub struct ReconnectingClient {
    /// Address the client was originally created with.
    addr: String,
    /// Config used for (re)connecting; its `addr` is redirected to the leader when following.
    config: ClientConfig,
    /// TLS settings; `None` means plain TCP.
    tls_config: Option<TlsClientConfig>,
    /// Current connection; `None` after a failure until the next reconnect.
    client: Option<LanceClient>,
    /// Total connection attempts made by reconnects over this client's lifetime.
    reconnect_attempts: u32,
    /// Maximum attempts per reconnect (0 = unlimited).
    max_attempts: u32,
    /// Delay after the first failed attempt; doubled after each further failure.
    base_delay: Duration,
    /// Upper bound on the backoff delay.
    max_delay: Duration,
    /// Current leader address (for redirection)
    leader_addr: Option<SocketAddr>,
    /// Whether to follow leader redirects
    follow_leader: bool,
}
550
551impl ReconnectingClient {
552    /// Create a new reconnecting client
553    ///
554    /// The address can be either an IP:port (e.g., "127.0.0.1:1992") or
555    /// a hostname:port (e.g., "lance.example.com:1992").
556    pub async fn connect(addr: &str) -> Result<Self> {
557        let config = ClientConfig::new(addr);
558        let client = LanceClient::connect(config.clone()).await?;
559
560        Ok(Self {
561            addr: addr.to_string(),
562            config,
563            tls_config: None,
564            client: Some(client),
565            reconnect_attempts: 0,
566            max_attempts: 5,
567            base_delay: Duration::from_millis(100),
568            max_delay: Duration::from_secs(30),
569            leader_addr: None,
570            follow_leader: true,
571        })
572    }
573
574    /// Create a new reconnecting client with TLS
575    ///
576    /// The address can be either an IP:port (e.g., "127.0.0.1:1992") or
577    /// a hostname:port (e.g., "lance.example.com:1992").
578    pub async fn connect_tls(addr: &str, tls_config: TlsClientConfig) -> Result<Self> {
579        let config = ClientConfig::new(addr);
580        let client = LanceClient::connect_tls(config.clone(), tls_config.clone()).await?;
581
582        Ok(Self {
583            addr: addr.to_string(),
584            config,
585            tls_config: Some(tls_config),
586            client: Some(client),
587            reconnect_attempts: 0,
588            max_attempts: 5,
589            base_delay: Duration::from_millis(100),
590            max_delay: Duration::from_secs(30),
591            leader_addr: None,
592            follow_leader: true,
593        })
594    }
595
596    /// Set maximum reconnection attempts
597    pub fn with_max_attempts(mut self, attempts: u32) -> Self {
598        self.max_attempts = attempts;
599        self
600    }
601
602    /// Enable or disable automatic leader following
603    pub fn with_follow_leader(mut self, follow: bool) -> Self {
604        self.follow_leader = follow;
605        self
606    }
607
608    /// Get the original connection address
609    pub fn original_addr(&self) -> &str {
610        &self.addr
611    }
612
613    /// Get the current leader address if known
614    pub fn leader_addr(&self) -> Option<SocketAddr> {
615        self.leader_addr
616    }
617
618    /// Update the known leader address (called when redirect received)
619    pub fn set_leader_addr(&mut self, addr: SocketAddr) {
620        self.leader_addr = Some(addr);
621        if self.follow_leader {
622            // Update config to connect to leader on next reconnect
623            self.config.addr = addr.to_string();
624        }
625    }
626
627    /// Get total reconnection attempts made
628    pub fn reconnect_attempts(&self) -> u32 {
629        self.reconnect_attempts
630    }
631
632    /// Get a reference to the underlying client, reconnecting if needed
633    pub async fn client(&mut self) -> Result<&mut LanceClient> {
634        if self.client.is_none() {
635            self.reconnect().await?;
636        }
637        self.client.as_mut().ok_or(ClientError::ConnectionClosed)
638    }
639
640    /// Attempt to reconnect with exponential backoff
641    async fn reconnect(&mut self) -> Result<()> {
642        let mut attempts = 0;
643
644        loop {
645            attempts += 1;
646            self.reconnect_attempts += 1;
647
648            let result = match &self.tls_config {
649                Some(tls) => LanceClient::connect_tls(self.config.clone(), tls.clone()).await,
650                None => LanceClient::connect(self.config.clone()).await,
651            };
652
653            match result {
654                Ok(client) => {
655                    self.client = Some(client);
656                    return Ok(());
657                },
658                Err(e) => {
659                    if self.max_attempts > 0 && attempts >= self.max_attempts {
660                        return Err(e);
661                    }
662
663                    // Calculate backoff delay
664                    let delay = self.base_delay * 2u32.saturating_pow(attempts - 1);
665                    let delay = delay.min(self.max_delay);
666
667                    tokio::time::sleep(delay).await;
668                },
669            }
670        }
671    }
672
673    /// Execute an operation with automatic reconnection on failure
674    pub async fn execute<F, T>(&mut self, op: F) -> Result<T>
675    where
676        F: Fn(
677            &mut LanceClient,
678        )
679            -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send + '_>>,
680    {
681        loop {
682            let client = self.client().await?;
683
684            match op(client).await {
685                Ok(result) => return Ok(result),
686                Err(ClientError::ConnectionClosed) | Err(ClientError::ConnectionFailed(_)) => {
687                    self.client = None;
688                    // Will reconnect on next iteration
689                },
690                Err(e) => return Err(e),
691            }
692        }
693    }
694
695    /// Mark connection as failed
696    pub fn mark_failed(&mut self) {
697        self.client = None;
698    }
699}
700
/// Cluster-aware client with automatic node discovery
///
/// Discovers cluster nodes and maintains connections for high availability.
/// Note: Write routing to the leader is handled server-side via transparent
/// forwarding - clients can send writes to ANY node.
pub struct ClusterClient {
    /// Known cluster nodes (taken from the seed addresses; discovery does not
    /// modify this list)
    nodes: Vec<SocketAddr>,
    /// Primary node for this client (for connection affinity, not write routing)
    primary: Option<SocketAddr>,
    /// Client configuration template; `addr` is overwritten per node when connecting
    config: ClientConfig,
    /// TLS configuration (None = plain TCP)
    tls_config: Option<TlsClientConfig>,
    /// Active client connection
    client: Option<LanceClient>,
    /// Last successful discovery time
    last_discovery: Option<Instant>,
    /// Discovery refresh interval; `client()` rediscovers once this has elapsed
    discovery_interval: Duration,
}
722
723impl ClusterClient {
724    /// Create a new cluster client with seed nodes
725    ///
726    /// Seed addresses can be either IP:port or hostname:port format.
727    pub async fn connect(seed_addrs: &[&str]) -> Result<Self> {
728        let nodes: Vec<SocketAddr> = seed_addrs.iter().filter_map(|s| s.parse().ok()).collect();
729
730        if nodes.is_empty() {
731            return Err(ClientError::ProtocolError(
732                "No valid seed addresses".to_string(),
733            ));
734        }
735
736        let config = ClientConfig::new(nodes[0].to_string());
737        let mut cluster = Self {
738            nodes,
739            primary: None,
740            config,
741            tls_config: None,
742            client: None,
743            last_discovery: None,
744            discovery_interval: Duration::from_secs(60),
745        };
746
747        cluster.discover_cluster().await?;
748        Ok(cluster)
749    }
750
751    /// Create a new cluster client with TLS
752    pub async fn connect_tls(seed_addrs: &[&str], tls_config: TlsClientConfig) -> Result<Self> {
753        let nodes: Vec<SocketAddr> = seed_addrs.iter().filter_map(|s| s.parse().ok()).collect();
754
755        if nodes.is_empty() {
756            return Err(ClientError::ProtocolError(
757                "No valid seed addresses".to_string(),
758            ));
759        }
760
761        let config = ClientConfig::new(nodes[0].to_string()).with_tls(tls_config.clone());
762        let mut cluster = Self {
763            nodes,
764            primary: None,
765            config,
766            tls_config: Some(tls_config),
767            client: None,
768            last_discovery: None,
769            discovery_interval: Duration::from_secs(60),
770        };
771
772        cluster.discover_cluster().await?;
773        Ok(cluster)
774    }
775
776    /// Set the discovery refresh interval
777    pub fn with_discovery_interval(mut self, interval: Duration) -> Self {
778        self.discovery_interval = interval;
779        self
780    }
781
782    /// Discover cluster topology from any available node
783    async fn discover_cluster(&mut self) -> Result<()> {
784        for &node in &self.nodes.clone() {
785            let mut config = self.config.clone();
786            config.addr = node.to_string();
787
788            match LanceClient::connect(config).await {
789                Ok(mut client) => {
790                    match client.get_cluster_status().await {
791                        Ok(status) => {
792                            self.primary = status.leader_id.map(|id| {
793                                // Try to find node in peer_states or use first node
794                                status
795                                    .peer_states
796                                    .get(&id)
797                                    .and_then(|s| s.parse().ok())
798                                    .unwrap_or(node)
799                            });
800                            self.last_discovery = Some(Instant::now());
801
802                            // Connect to primary if found
803                            if let Some(primary_addr) = self.primary {
804                                self.config.addr = primary_addr.to_string();
805                                self.client =
806                                    Some(LanceClient::connect(self.config.clone()).await?);
807                            } else {
808                                self.client = Some(client);
809                            }
810                            return Ok(());
811                        },
812                        Err(_) => {
813                            // Single-node mode or cluster not available
814                            self.client = Some(client);
815                            self.primary = Some(node);
816                            self.last_discovery = Some(Instant::now());
817                            return Ok(());
818                        },
819                    }
820                },
821                Err(_) => continue,
822            }
823        }
824
825        Err(ClientError::ConnectionFailed(std::io::Error::new(
826            std::io::ErrorKind::NotConnected,
827            "Could not connect to any cluster node",
828        )))
829    }
830
831    /// Get a client connection, refreshing discovery if needed
832    pub async fn client(&mut self) -> Result<&mut LanceClient> {
833        // Check if discovery refresh is needed
834        let needs_refresh = self
835            .last_discovery
836            .map(|t| t.elapsed() > self.discovery_interval)
837            .unwrap_or(true);
838
839        if needs_refresh || self.client.is_none() {
840            self.discover_cluster().await?;
841        }
842
843        self.client.as_mut().ok_or(ClientError::ConnectionClosed)
844    }
845
846    /// Get the current primary node address
847    pub fn primary(&self) -> Option<SocketAddr> {
848        self.primary
849    }
850
851    /// Get all known cluster nodes
852    pub fn nodes(&self) -> &[SocketAddr] {
853        &self.nodes
854    }
855
856    /// Get the TLS configuration if set
857    pub fn tls_config(&self) -> Option<&TlsClientConfig> {
858        self.tls_config.as_ref()
859    }
860
861    /// Check if TLS is enabled
862    pub fn is_tls_enabled(&self) -> bool {
863        self.tls_config.is_some()
864    }
865
866    /// Force a cluster discovery refresh
867    pub async fn refresh(&mut self) -> Result<()> {
868        self.discover_cluster().await
869    }
870}
871
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_defaults() {
        let config = ConnectionPoolConfig::new();

        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_idle, 1);
        assert!(config.auto_reconnect);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.health_check_interval, Duration::from_secs(30));
        assert_eq!(config.max_lifetime, Duration::from_secs(3600));
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.reconnect_base_delay, Duration::from_millis(100));
        assert_eq!(config.reconnect_max_delay, Duration::from_secs(30));
        assert!(config.tls_config.is_none());
    }

    #[test]
    fn test_pool_config_builder() {
        let config = ConnectionPoolConfig::new()
            .with_max_connections(20)
            .with_min_idle(5)
            .with_health_check_interval(60)
            .with_auto_reconnect(false);

        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_idle, 5);
        assert_eq!(config.health_check_interval, Duration::from_secs(60));
        assert!(!config.auto_reconnect);
    }

    #[test]
    fn test_pool_config_timeout_builders() {
        let config = ConnectionPoolConfig::new()
            .with_connect_timeout(Duration::from_secs(5))
            .with_acquire_timeout(Duration::from_secs(2))
            .with_max_lifetime(Duration::ZERO)
            .with_idle_timeout(Duration::from_secs(60))
            .with_health_check_interval(0);

        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.acquire_timeout, Duration::from_secs(2));
        assert!(config.max_lifetime.is_zero());
        assert_eq!(config.idle_timeout, Duration::from_secs(60));
        assert!(config.health_check_interval.is_zero());
    }

    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();

        assert_eq!(stats.connections_created, 0);
        assert_eq!(stats.active_connections, 0);
    }

    #[test]
    fn test_pool_metrics_snapshot() {
        let metrics = PoolMetrics::default();
        metrics.connections_created.fetch_add(3, Ordering::Relaxed);
        metrics.connections_closed.fetch_add(1, Ordering::Relaxed);
        metrics.active_connections.fetch_add(2, Ordering::Relaxed);
        metrics.acquire_attempts.fetch_add(4, Ordering::Relaxed);
        metrics.acquire_successes.fetch_add(3, Ordering::Relaxed);
        metrics.acquire_failures.fetch_add(1, Ordering::Relaxed);

        let stats = metrics.snapshot();

        assert_eq!(stats.connections_created, 3);
        assert_eq!(stats.connections_closed, 1);
        assert_eq!(stats.active_connections, 2);
        assert_eq!(stats.idle_connections, 0);
        assert_eq!(stats.acquire_attempts, 4);
        assert_eq!(stats.acquire_successes, 3);
        assert_eq!(stats.acquire_failures, 1);
        assert_eq!(stats.health_check_failures, 0);
        assert_eq!(stats.reconnect_attempts, 0);
    }

    #[test]
    fn test_pooled_client_without_connection() {
        // A handle whose connection was already discarded: no server needed.
        let semaphore = Arc::new(Semaphore::new(1));
        let permit = Arc::clone(&semaphore).try_acquire_owned().unwrap();
        let metrics = Arc::new(PoolMetrics::default());
        metrics.active_connections.store(1, Ordering::Relaxed);

        let mut client = PooledClient {
            conn: None,
            pool: Arc::new(Mutex::new(VecDeque::new())),
            metrics: Arc::clone(&metrics),
            permit: Some(permit),
            config: ConnectionPoolConfig::new(),
        };

        assert_eq!(semaphore.available_permits(), 0);
        assert!(matches!(client.client(), Err(ClientError::ConnectionClosed)));

        // Dropping releases the permit and the active-connection slot.
        drop(client);
        assert_eq!(semaphore.available_permits(), 1);
        assert_eq!(metrics.active_connections.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_pooled_connection_expiry() {
        // A PooledConnection needs a live server, so exercise the same
        // elapsed-vs-lifetime comparison on a backdated Instant instead of sleeping.
        let max_lifetime = Duration::from_millis(10);
        if let Some(created_at) = Instant::now().checked_sub(Duration::from_millis(20)) {
            assert!(created_at.elapsed() > max_lifetime);
        }
    }

    #[test]
    fn test_reconnecting_client_leader_addr() {
        // Test leader address tracking (without actual connection)
        let addr: SocketAddr = "127.0.0.1:1992".parse().unwrap();
        let leader: SocketAddr = "127.0.0.1:1993".parse().unwrap();

        // Simulate leader address update logic
        let follow_leader = true;
        let mut config_addr = addr;

        // Set leader - simulates set_leader_addr behavior
        let leader_addr: Option<SocketAddr> = Some(leader);
        if follow_leader {
            config_addr = leader;
        }

        assert_eq!(leader_addr, Some(leader));
        assert_eq!(config_addr, leader);
    }

    #[test]
    fn test_connection_pool_config_auto_reconnect() {
        let config = ConnectionPoolConfig::new()
            .with_auto_reconnect(true)
            .with_max_reconnect_attempts(10);

        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 10);
    }
}