1use std::collections::VecDeque;
40use std::net::SocketAddr;
41use std::sync::Arc;
42use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
43use std::time::{Duration, Instant};
44
45use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
46
47use crate::client::{ClientConfig, LanceClient};
48use crate::error::{ClientError, Result};
49use crate::tls::TlsClientConfig;
50
51#[derive(Debug, Clone)]
53pub struct ConnectionPoolConfig {
54 pub max_connections: usize,
56 pub min_idle: usize,
58 pub connect_timeout: Duration,
60 pub acquire_timeout: Duration,
62 pub health_check_interval: Duration,
64 pub max_lifetime: Duration,
66 pub idle_timeout: Duration,
68 pub auto_reconnect: bool,
70 pub max_reconnect_attempts: u32,
72 pub reconnect_base_delay: Duration,
74 pub reconnect_max_delay: Duration,
76 pub tls_config: Option<TlsClientConfig>,
78}
79
80impl Default for ConnectionPoolConfig {
81 fn default() -> Self {
82 Self {
83 max_connections: 10,
84 min_idle: 1,
85 connect_timeout: Duration::from_secs(30),
86 acquire_timeout: Duration::from_secs(30),
87 health_check_interval: Duration::from_secs(30),
88 max_lifetime: Duration::from_secs(3600), idle_timeout: Duration::from_secs(300), auto_reconnect: true,
91 max_reconnect_attempts: 5,
92 reconnect_base_delay: Duration::from_millis(100),
93 reconnect_max_delay: Duration::from_secs(30),
94 tls_config: None,
95 }
96 }
97}
98
99impl ConnectionPoolConfig {
100 pub fn new() -> Self {
102 Self::default()
103 }
104
105 pub fn with_max_connections(mut self, n: usize) -> Self {
107 self.max_connections = n;
108 self
109 }
110
111 pub fn with_min_idle(mut self, n: usize) -> Self {
113 self.min_idle = n;
114 self
115 }
116
117 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
119 self.connect_timeout = timeout;
120 self
121 }
122
123 pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
125 self.acquire_timeout = timeout;
126 self
127 }
128
129 pub fn with_health_check_interval(mut self, secs: u64) -> Self {
131 self.health_check_interval = Duration::from_secs(secs);
132 self
133 }
134
135 pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
137 self.max_lifetime = lifetime;
138 self
139 }
140
141 pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
143 self.idle_timeout = timeout;
144 self
145 }
146
147 pub fn with_auto_reconnect(mut self, enabled: bool) -> Self {
149 self.auto_reconnect = enabled;
150 self
151 }
152
153 pub fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
155 self.max_reconnect_attempts = attempts;
156 self
157 }
158
159 pub fn with_tls(mut self, tls_config: TlsClientConfig) -> Self {
161 self.tls_config = Some(tls_config);
162 self
163 }
164}
165
/// Point-in-time copy of a pool's counters, as returned by `ConnectionPool::stats`.
///
/// Every field is a plain number captured when the snapshot was taken; the value does not
/// update afterwards.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections opened successfully since the pool was created.
    pub connections_created: u64,
    /// Connections the pool has discarded (expired, unhealthy, or dropped on close).
    pub connections_closed: u64,
    /// Connections currently checked out by callers.
    pub active_connections: u64,
    /// Connections currently waiting in the pool.
    pub idle_connections: u64,
    /// Number of times a caller asked the pool for a connection.
    pub acquire_attempts: u64,
    /// Number of acquire requests that yielded a connection.
    pub acquire_successes: u64,
    /// Number of acquire requests that ended in an error.
    pub acquire_failures: u64,
    /// Number of health-check pings that failed.
    pub health_check_failures: u64,
    /// Reconnect attempt counter.
    /// NOTE(review): never incremented by the code visible in this file.
    pub reconnect_attempts: u64,
}
188
/// Live counters shared between the pool, its health-check task and checked-out clients.
///
/// Each field is the atomic counterpart of the same-named field in [`PoolStats`]; all of them
/// start at zero.
#[derive(Debug, Default)]
struct PoolMetrics {
    /// Connections opened successfully.
    connections_created: AtomicU64,
    /// Connections discarded by the pool.
    connections_closed: AtomicU64,
    /// Connections currently checked out.
    active_connections: AtomicU64,
    /// Connections currently sitting in the idle queue.
    idle_connections: AtomicU64,
    /// Calls made to acquire a connection.
    acquire_attempts: AtomicU64,
    /// Acquire calls that returned a connection.
    acquire_successes: AtomicU64,
    /// Acquire calls that returned an error.
    acquire_failures: AtomicU64,
    /// Health-check pings that failed.
    health_check_failures: AtomicU64,
    /// Reconnect attempts. NOTE(review): never incremented by the code visible in this file.
    reconnect_attempts: AtomicU64,
}
202
203impl PoolMetrics {
204 fn snapshot(&self) -> PoolStats {
205 PoolStats {
206 connections_created: self.connections_created.load(Ordering::Relaxed),
207 connections_closed: self.connections_closed.load(Ordering::Relaxed),
208 active_connections: self.active_connections.load(Ordering::Relaxed),
209 idle_connections: self.idle_connections.load(Ordering::Relaxed),
210 acquire_attempts: self.acquire_attempts.load(Ordering::Relaxed),
211 acquire_successes: self.acquire_successes.load(Ordering::Relaxed),
212 acquire_failures: self.acquire_failures.load(Ordering::Relaxed),
213 health_check_failures: self.health_check_failures.load(Ordering::Relaxed),
214 reconnect_attempts: self.reconnect_attempts.load(Ordering::Relaxed),
215 }
216 }
217}
218
219struct PooledConnection {
221 client: LanceClient,
222 created_at: Instant,
223 last_used: Instant,
224}
225
226impl PooledConnection {
227 fn new(client: LanceClient) -> Self {
228 let now = Instant::now();
229 Self {
230 client,
231 created_at: now,
232 last_used: now,
233 }
234 }
235
236 fn is_expired(&self, max_lifetime: Duration) -> bool {
237 if max_lifetime.is_zero() {
238 return false;
239 }
240 self.created_at.elapsed() > max_lifetime
241 }
242
243 fn is_idle_too_long(&self, idle_timeout: Duration) -> bool {
244 if idle_timeout.is_zero() {
245 return false;
246 }
247 self.last_used.elapsed() > idle_timeout
248 }
249}
250
251pub struct ConnectionPool {
253 addr: String,
254 config: ConnectionPoolConfig,
255 connections: Arc<Mutex<VecDeque<PooledConnection>>>,
256 semaphore: Arc<Semaphore>,
257 metrics: Arc<PoolMetrics>,
258 running: Arc<AtomicBool>,
259}
260
261impl ConnectionPool {
262 pub async fn new(addr: &str, config: ConnectionPoolConfig) -> Result<Self> {
267 let pool = Self {
268 addr: addr.to_string(),
269 config: config.clone(),
270 connections: Arc::new(Mutex::new(VecDeque::new())),
271 semaphore: Arc::new(Semaphore::new(config.max_connections)),
272 metrics: Arc::new(PoolMetrics::default()),
273 running: Arc::new(AtomicBool::new(true)),
274 };
275
276 for _ in 0..config.min_idle {
278 if let Ok(conn) = pool.create_connection().await {
279 let mut connections = pool.connections.lock().await;
280 connections.push_back(conn);
281 pool.metrics
282 .idle_connections
283 .fetch_add(1, Ordering::Relaxed);
284 }
285 }
286
287 if !config.health_check_interval.is_zero() {
289 let pool_clone = ConnectionPool {
290 addr: pool.addr.clone(),
291 config: pool.config.clone(),
292 connections: pool.connections.clone(),
293 semaphore: pool.semaphore.clone(),
294 metrics: pool.metrics.clone(),
295 running: pool.running.clone(),
296 };
297 tokio::spawn(async move {
298 pool_clone.health_check_task().await;
299 });
300 }
301
302 Ok(pool)
303 }
304
305 pub async fn get(&self) -> Result<PooledClient> {
307 self.metrics
308 .acquire_attempts
309 .fetch_add(1, Ordering::Relaxed);
310
311 let permit = tokio::time::timeout(
313 self.config.acquire_timeout,
314 self.semaphore.clone().acquire_owned(),
315 )
316 .await
317 .map_err(|_| {
318 self.metrics
319 .acquire_failures
320 .fetch_add(1, Ordering::Relaxed);
321 ClientError::Timeout
322 })?
323 .map_err(|_| {
324 self.metrics
325 .acquire_failures
326 .fetch_add(1, Ordering::Relaxed);
327 ClientError::ConnectionClosed
328 })?;
329
330 let conn = {
332 let mut connections = self.connections.lock().await;
333 loop {
334 match connections.pop_front() {
335 Some(conn) => {
336 self.metrics
337 .idle_connections
338 .fetch_sub(1, Ordering::Relaxed);
339
340 if conn.is_expired(self.config.max_lifetime)
342 || conn.is_idle_too_long(self.config.idle_timeout)
343 {
344 self.metrics
345 .connections_closed
346 .fetch_add(1, Ordering::Relaxed);
347 continue;
348 }
349 break Some(conn);
350 },
351 None => break None,
352 }
353 }
354 };
355
356 let conn = match conn {
357 Some(mut c) => {
358 c.last_used = Instant::now();
359 c
360 },
361 None => {
362 self.create_connection().await?
364 },
365 };
366
367 self.metrics
368 .active_connections
369 .fetch_add(1, Ordering::Relaxed);
370 self.metrics
371 .acquire_successes
372 .fetch_add(1, Ordering::Relaxed);
373
374 Ok(PooledClient {
375 conn: Some(conn),
376 pool: self.connections.clone(),
377 metrics: self.metrics.clone(),
378 permit: Some(permit),
379 config: self.config.clone(),
380 })
381 }
382
383 async fn create_connection(&self) -> Result<PooledConnection> {
385 let mut client_config = ClientConfig::new(&self.addr);
386 client_config.connect_timeout = self.config.connect_timeout;
387
388 let client = match &self.config.tls_config {
389 Some(tls_config) => LanceClient::connect_tls(client_config, tls_config.clone()).await?,
390 None => LanceClient::connect(client_config).await?,
391 };
392 self.metrics
393 .connections_created
394 .fetch_add(1, Ordering::Relaxed);
395
396 Ok(PooledConnection::new(client))
397 }
398
399 pub fn stats(&self) -> PoolStats {
401 self.metrics.snapshot()
402 }
403
404 pub async fn close(&self) {
406 self.running.store(false, Ordering::Relaxed);
407
408 let mut connections = self.connections.lock().await;
409 let count = connections.len() as u64;
410 connections.clear();
411 self.metrics
412 .connections_closed
413 .fetch_add(count, Ordering::Relaxed);
414 self.metrics.idle_connections.store(0, Ordering::Relaxed);
415 }
416
417 async fn health_check_task(&self) {
419 let mut interval = tokio::time::interval(self.config.health_check_interval);
420
421 while self.running.load(Ordering::Relaxed) {
422 interval.tick().await;
423
424 let mut to_check = {
426 let mut connections = self.connections.lock().await;
427 std::mem::take(&mut *connections)
428 };
429
430 let mut healthy = VecDeque::new();
431 let _initial_count = to_check.len();
432
433 for mut conn in to_check.drain(..) {
434 if conn.is_expired(self.config.max_lifetime) {
436 self.metrics
437 .connections_closed
438 .fetch_add(1, Ordering::Relaxed);
439 continue;
440 }
441
442 match conn.client.ping().await {
444 Ok(_) => {
445 conn.last_used = Instant::now();
446 healthy.push_back(conn);
447 },
448 Err(_) => {
449 self.metrics
450 .health_check_failures
451 .fetch_add(1, Ordering::Relaxed);
452 self.metrics
453 .connections_closed
454 .fetch_add(1, Ordering::Relaxed);
455 },
456 }
457 }
458
459 {
461 let mut connections = self.connections.lock().await;
462 connections.extend(healthy);
463 self.metrics
464 .idle_connections
465 .store(connections.len() as u64, Ordering::Relaxed);
466 }
467 }
468 }
469}
470
471pub struct PooledClient {
473 conn: Option<PooledConnection>,
474 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
475 metrics: Arc<PoolMetrics>,
476 #[allow(dead_code)]
477 permit: Option<OwnedSemaphorePermit>,
478 #[allow(dead_code)]
479 config: ConnectionPoolConfig,
480}
481
482impl PooledClient {
483 pub fn client(&mut self) -> Result<&mut LanceClient> {
485 match self.conn.as_mut() {
486 Some(conn) => Ok(&mut conn.client),
487 None => Err(ClientError::ConnectionClosed),
488 }
489 }
490
491 pub async fn ping(&mut self) -> Result<Duration> {
493 if let Some(ref mut conn) = self.conn {
494 conn.client.ping().await
495 } else {
496 Err(ClientError::ConnectionClosed)
497 }
498 }
499
500 pub fn mark_unhealthy(&mut self) {
502 self.conn = None;
503 self.metrics
504 .connections_closed
505 .fetch_add(1, Ordering::Relaxed);
506 }
507}
508
509impl Drop for PooledClient {
510 fn drop(&mut self) {
511 if let Some(mut conn) = self.conn.take() {
512 conn.last_used = Instant::now();
513
514 let pool = self.pool.clone();
516 let metrics = self.metrics.clone();
517
518 tokio::spawn(async move {
519 let mut connections = pool.lock().await;
520 connections.push_back(conn);
521 metrics.active_connections.fetch_sub(1, Ordering::Relaxed);
522 metrics.idle_connections.fetch_add(1, Ordering::Relaxed);
523 });
524 } else {
525 self.metrics
526 .active_connections
527 .fetch_sub(1, Ordering::Relaxed);
528 }
529
530 }
532}
533
534pub struct ReconnectingClient {
537 addr: String,
538 config: ClientConfig,
539 tls_config: Option<TlsClientConfig>,
540 client: Option<LanceClient>,
541 reconnect_attempts: u32,
542 max_attempts: u32,
543 base_delay: Duration,
544 max_delay: Duration,
545 leader_addr: Option<SocketAddr>,
547 follow_leader: bool,
549}
550
551impl ReconnectingClient {
552 pub async fn connect(addr: &str) -> Result<Self> {
557 let config = ClientConfig::new(addr);
558 let client = LanceClient::connect(config.clone()).await?;
559
560 Ok(Self {
561 addr: addr.to_string(),
562 config,
563 tls_config: None,
564 client: Some(client),
565 reconnect_attempts: 0,
566 max_attempts: 5,
567 base_delay: Duration::from_millis(100),
568 max_delay: Duration::from_secs(30),
569 leader_addr: None,
570 follow_leader: true,
571 })
572 }
573
574 pub async fn connect_tls(addr: &str, tls_config: TlsClientConfig) -> Result<Self> {
579 let config = ClientConfig::new(addr);
580 let client = LanceClient::connect_tls(config.clone(), tls_config.clone()).await?;
581
582 Ok(Self {
583 addr: addr.to_string(),
584 config,
585 tls_config: Some(tls_config),
586 client: Some(client),
587 reconnect_attempts: 0,
588 max_attempts: 5,
589 base_delay: Duration::from_millis(100),
590 max_delay: Duration::from_secs(30),
591 leader_addr: None,
592 follow_leader: true,
593 })
594 }
595
596 pub fn with_max_attempts(mut self, attempts: u32) -> Self {
598 self.max_attempts = attempts;
599 self
600 }
601
602 pub fn with_follow_leader(mut self, follow: bool) -> Self {
604 self.follow_leader = follow;
605 self
606 }
607
608 pub fn original_addr(&self) -> &str {
610 &self.addr
611 }
612
613 pub fn leader_addr(&self) -> Option<SocketAddr> {
615 self.leader_addr
616 }
617
618 pub fn set_leader_addr(&mut self, addr: SocketAddr) {
620 self.leader_addr = Some(addr);
621 if self.follow_leader {
622 self.config.addr = addr.to_string();
624 }
625 }
626
627 pub fn reconnect_attempts(&self) -> u32 {
629 self.reconnect_attempts
630 }
631
632 pub async fn client(&mut self) -> Result<&mut LanceClient> {
634 if self.client.is_none() {
635 self.reconnect().await?;
636 }
637 self.client.as_mut().ok_or(ClientError::ConnectionClosed)
638 }
639
640 async fn reconnect(&mut self) -> Result<()> {
642 let mut attempts = 0;
643
644 loop {
645 attempts += 1;
646 self.reconnect_attempts += 1;
647
648 let result = match &self.tls_config {
649 Some(tls) => LanceClient::connect_tls(self.config.clone(), tls.clone()).await,
650 None => LanceClient::connect(self.config.clone()).await,
651 };
652
653 match result {
654 Ok(client) => {
655 self.client = Some(client);
656 return Ok(());
657 },
658 Err(e) => {
659 if self.max_attempts > 0 && attempts >= self.max_attempts {
660 return Err(e);
661 }
662
663 let delay = self.base_delay * 2u32.saturating_pow(attempts - 1);
665 let delay = delay.min(self.max_delay);
666
667 tokio::time::sleep(delay).await;
668 },
669 }
670 }
671 }
672
673 pub async fn execute<F, T>(&mut self, op: F) -> Result<T>
675 where
676 F: Fn(
677 &mut LanceClient,
678 )
679 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send + '_>>,
680 {
681 loop {
682 let client = self.client().await?;
683
684 match op(client).await {
685 Ok(result) => return Ok(result),
686 Err(ClientError::ConnectionClosed) | Err(ClientError::ConnectionFailed(_)) => {
687 self.client = None;
688 },
690 Err(e) => return Err(e),
691 }
692 }
693 }
694
695 pub fn mark_failed(&mut self) {
697 self.client = None;
698 }
699}
700
701pub struct ClusterClient {
707 nodes: Vec<SocketAddr>,
709 primary: Option<SocketAddr>,
711 config: ClientConfig,
713 tls_config: Option<TlsClientConfig>,
715 client: Option<LanceClient>,
717 last_discovery: Option<Instant>,
719 discovery_interval: Duration,
721}
722
723impl ClusterClient {
724 pub async fn connect(seed_addrs: &[&str]) -> Result<Self> {
728 let nodes: Vec<SocketAddr> = seed_addrs.iter().filter_map(|s| s.parse().ok()).collect();
729
730 if nodes.is_empty() {
731 return Err(ClientError::ProtocolError(
732 "No valid seed addresses".to_string(),
733 ));
734 }
735
736 let config = ClientConfig::new(nodes[0].to_string());
737 let mut cluster = Self {
738 nodes,
739 primary: None,
740 config,
741 tls_config: None,
742 client: None,
743 last_discovery: None,
744 discovery_interval: Duration::from_secs(60),
745 };
746
747 cluster.discover_cluster().await?;
748 Ok(cluster)
749 }
750
751 pub async fn connect_tls(seed_addrs: &[&str], tls_config: TlsClientConfig) -> Result<Self> {
753 let nodes: Vec<SocketAddr> = seed_addrs.iter().filter_map(|s| s.parse().ok()).collect();
754
755 if nodes.is_empty() {
756 return Err(ClientError::ProtocolError(
757 "No valid seed addresses".to_string(),
758 ));
759 }
760
761 let config = ClientConfig::new(nodes[0].to_string()).with_tls(tls_config.clone());
762 let mut cluster = Self {
763 nodes,
764 primary: None,
765 config,
766 tls_config: Some(tls_config),
767 client: None,
768 last_discovery: None,
769 discovery_interval: Duration::from_secs(60),
770 };
771
772 cluster.discover_cluster().await?;
773 Ok(cluster)
774 }
775
776 pub fn with_discovery_interval(mut self, interval: Duration) -> Self {
778 self.discovery_interval = interval;
779 self
780 }
781
782 async fn discover_cluster(&mut self) -> Result<()> {
784 for &node in &self.nodes.clone() {
785 let mut config = self.config.clone();
786 config.addr = node.to_string();
787
788 match LanceClient::connect(config).await {
789 Ok(mut client) => {
790 match client.get_cluster_status().await {
791 Ok(status) => {
792 self.primary = status.leader_id.map(|id| {
793 status
795 .peer_states
796 .get(&id)
797 .and_then(|s| s.parse().ok())
798 .unwrap_or(node)
799 });
800 self.last_discovery = Some(Instant::now());
801
802 if let Some(primary_addr) = self.primary {
804 self.config.addr = primary_addr.to_string();
805 self.client =
806 Some(LanceClient::connect(self.config.clone()).await?);
807 } else {
808 self.client = Some(client);
809 }
810 return Ok(());
811 },
812 Err(_) => {
813 self.client = Some(client);
815 self.primary = Some(node);
816 self.last_discovery = Some(Instant::now());
817 return Ok(());
818 },
819 }
820 },
821 Err(_) => continue,
822 }
823 }
824
825 Err(ClientError::ConnectionFailed(std::io::Error::new(
826 std::io::ErrorKind::NotConnected,
827 "Could not connect to any cluster node",
828 )))
829 }
830
831 pub async fn client(&mut self) -> Result<&mut LanceClient> {
833 let needs_refresh = self
835 .last_discovery
836 .map(|t| t.elapsed() > self.discovery_interval)
837 .unwrap_or(true);
838
839 if needs_refresh || self.client.is_none() {
840 self.discover_cluster().await?;
841 }
842
843 self.client.as_mut().ok_or(ClientError::ConnectionClosed)
844 }
845
846 pub fn primary(&self) -> Option<SocketAddr> {
848 self.primary
849 }
850
851 pub fn nodes(&self) -> &[SocketAddr] {
853 &self.nodes
854 }
855
856 pub fn tls_config(&self) -> Option<&TlsClientConfig> {
858 self.tls_config.as_ref()
859 }
860
861 pub fn is_tls_enabled(&self) -> bool {
863 self.tls_config.is_some()
864 }
865
866 pub async fn refresh(&mut self) -> Result<()> {
868 self.discover_cluster().await
869 }
870}
871
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `new()` must hand out the documented defaults.
    #[test]
    fn test_pool_config_defaults() {
        let defaults = ConnectionPoolConfig::new();

        assert_eq!(defaults.max_connections, 10);
        assert_eq!(defaults.min_idle, 1);
        assert!(defaults.auto_reconnect);
    }

    /// Each builder call must land in its own field.
    #[test]
    fn test_pool_config_builder() {
        let built = ConnectionPoolConfig::new()
            .with_max_connections(20)
            .with_min_idle(5)
            .with_health_check_interval(60)
            .with_auto_reconnect(false);

        assert_eq!(built.max_connections, 20);
        assert_eq!(built.min_idle, 5);
        assert_eq!(built.health_check_interval, Duration::from_secs(60));
        assert!(!built.auto_reconnect);
    }

    /// A default snapshot starts with zeroed counters.
    #[test]
    fn test_pool_stats_default() {
        let snapshot = PoolStats::default();

        assert_eq!(snapshot.connections_created, 0);
        assert_eq!(snapshot.active_connections, 0);
    }

    /// Exercises the elapsed-time comparison used for expiry; it does not build a
    /// `PooledConnection` because that needs a real client.
    #[test]
    fn test_pooled_connection_expiry() {
        let max_lifetime = Duration::from_millis(10);
        let created_at = Instant::now();

        std::thread::sleep(Duration::from_millis(20));

        assert!(created_at.elapsed() > max_lifetime);
    }

    /// Mirrors the leader-following rule of `set_leader_addr` on plain values; it does not
    /// build a `ReconnectingClient` because that needs a live server.
    #[test]
    fn test_reconnecting_client_leader_addr() {
        let addr: SocketAddr = "127.0.0.1:1992".parse().unwrap();
        let leader: SocketAddr = "127.0.0.1:1993".parse().unwrap();

        let follow_leader = true;
        let leader_addr: Option<SocketAddr> = Some(leader);
        let config_addr = if follow_leader { leader } else { addr };

        assert_eq!(leader_addr, Some(leader));
        assert_eq!(config_addr, leader);
    }

    /// The reconnect-related builder methods must store their arguments.
    #[test]
    fn test_connection_pool_config_auto_reconnect() {
        let built = ConnectionPoolConfig::new()
            .with_auto_reconnect(true)
            .with_max_reconnect_attempts(10);

        assert!(built.auto_reconnect);
        assert_eq!(built.max_reconnect_attempts, 10);
    }
}