1use std::collections::VecDeque;
40use std::net::SocketAddr;
41use std::sync::Arc;
42use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
43use std::time::{Duration, Instant};
44
45use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
46
47use crate::client::{ClientConfig, LanceClient};
48use crate::error::{ClientError, Result};
49use crate::tls::TlsClientConfig;
50
51#[derive(Debug, Clone)]
53pub struct ConnectionPoolConfig {
54 pub max_connections: usize,
56 pub min_idle: usize,
58 pub connect_timeout: Duration,
60 pub acquire_timeout: Duration,
62 pub health_check_interval: Duration,
64 pub max_lifetime: Duration,
66 pub idle_timeout: Duration,
68 pub auto_reconnect: bool,
70 pub max_reconnect_attempts: u32,
72 pub reconnect_base_delay: Duration,
74 pub reconnect_max_delay: Duration,
76 pub tls_config: Option<TlsClientConfig>,
78}
79
80impl Default for ConnectionPoolConfig {
81 fn default() -> Self {
82 Self {
83 max_connections: 10,
84 min_idle: 1,
85 connect_timeout: Duration::from_secs(30),
86 acquire_timeout: Duration::from_secs(30),
87 health_check_interval: Duration::from_secs(30),
88 max_lifetime: Duration::from_secs(3600), idle_timeout: Duration::from_secs(300), auto_reconnect: true,
91 max_reconnect_attempts: 5,
92 reconnect_base_delay: Duration::from_millis(100),
93 reconnect_max_delay: Duration::from_secs(30),
94 tls_config: None,
95 }
96 }
97}
98
99impl ConnectionPoolConfig {
100 pub fn new() -> Self {
102 Self::default()
103 }
104
105 pub fn with_max_connections(mut self, n: usize) -> Self {
107 self.max_connections = n;
108 self
109 }
110
111 pub fn with_min_idle(mut self, n: usize) -> Self {
113 self.min_idle = n;
114 self
115 }
116
117 pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
119 self.connect_timeout = timeout;
120 self
121 }
122
123 pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
125 self.acquire_timeout = timeout;
126 self
127 }
128
129 pub fn with_health_check_interval(mut self, secs: u64) -> Self {
131 self.health_check_interval = Duration::from_secs(secs);
132 self
133 }
134
135 pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
137 self.max_lifetime = lifetime;
138 self
139 }
140
141 pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
143 self.idle_timeout = timeout;
144 self
145 }
146
147 pub fn with_auto_reconnect(mut self, enabled: bool) -> Self {
149 self.auto_reconnect = enabled;
150 self
151 }
152
153 pub fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
155 self.max_reconnect_attempts = attempts;
156 self
157 }
158
159 pub fn with_tls(mut self, tls_config: TlsClientConfig) -> Self {
161 self.tls_config = Some(tls_config);
162 self
163 }
164}
165
/// Plain-value copy of a pool's counters at one point in time.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections opened by the pool so far.
    pub connections_created: u64,
    /// Connections discarded (expired, idle too long, unhealthy, or closed).
    pub connections_closed: u64,
    /// Connections currently checked out.
    pub active_connections: u64,
    /// Connections currently waiting in the idle queue.
    pub idle_connections: u64,
    /// Calls to `get`.
    pub acquire_attempts: u64,
    /// `get` calls that returned a connection.
    pub acquire_successes: u64,
    /// `get` calls that returned an error.
    pub acquire_failures: u64,
    /// Idle connections that failed a health-check ping.
    pub health_check_failures: u64,
    /// Reconnect attempts recorded for the pool.
    pub reconnect_attempts: u64,
}
188
/// Lock-free counters shared (via `Arc`) between the pool, its checked-out
/// clients and the health-check task; see `PoolStats` for field meanings.
#[derive(Debug, Default)]
struct PoolMetrics {
    connections_created: AtomicU64,
    connections_closed: AtomicU64,
    active_connections: AtomicU64,
    idle_connections: AtomicU64,
    acquire_attempts: AtomicU64,
    acquire_successes: AtomicU64,
    acquire_failures: AtomicU64,
    health_check_failures: AtomicU64,
    reconnect_attempts: AtomicU64,
}
202
203impl PoolMetrics {
204 fn snapshot(&self) -> PoolStats {
205 PoolStats {
206 connections_created: self.connections_created.load(Ordering::Relaxed),
207 connections_closed: self.connections_closed.load(Ordering::Relaxed),
208 active_connections: self.active_connections.load(Ordering::Relaxed),
209 idle_connections: self.idle_connections.load(Ordering::Relaxed),
210 acquire_attempts: self.acquire_attempts.load(Ordering::Relaxed),
211 acquire_successes: self.acquire_successes.load(Ordering::Relaxed),
212 acquire_failures: self.acquire_failures.load(Ordering::Relaxed),
213 health_check_failures: self.health_check_failures.load(Ordering::Relaxed),
214 reconnect_attempts: self.reconnect_attempts.load(Ordering::Relaxed),
215 }
216 }
217}
218
219struct PooledConnection {
221 client: LanceClient,
222 created_at: Instant,
223 last_used: Instant,
224}
225
226impl PooledConnection {
227 fn new(client: LanceClient) -> Self {
228 let now = Instant::now();
229 Self {
230 client,
231 created_at: now,
232 last_used: now,
233 }
234 }
235
236 fn is_expired(&self, max_lifetime: Duration) -> bool {
237 if max_lifetime.is_zero() {
238 return false;
239 }
240 self.created_at.elapsed() > max_lifetime
241 }
242
243 fn is_idle_too_long(&self, idle_timeout: Duration) -> bool {
244 if idle_timeout.is_zero() {
245 return false;
246 }
247 self.last_used.elapsed() > idle_timeout
248 }
249}
250
251pub struct ConnectionPool {
253 addr: SocketAddr,
254 config: ConnectionPoolConfig,
255 connections: Arc<Mutex<VecDeque<PooledConnection>>>,
256 semaphore: Arc<Semaphore>,
257 metrics: Arc<PoolMetrics>,
258 running: Arc<AtomicBool>,
259}
260
261impl ConnectionPool {
262 pub async fn new(addr: &str, config: ConnectionPoolConfig) -> Result<Self> {
264 let socket_addr: SocketAddr = addr
265 .parse()
266 .map_err(|e| ClientError::ProtocolError(format!("Invalid address: {}", e)))?;
267
268 let pool = Self {
269 addr: socket_addr,
270 config: config.clone(),
271 connections: Arc::new(Mutex::new(VecDeque::new())),
272 semaphore: Arc::new(Semaphore::new(config.max_connections)),
273 metrics: Arc::new(PoolMetrics::default()),
274 running: Arc::new(AtomicBool::new(true)),
275 };
276
277 for _ in 0..config.min_idle {
279 if let Ok(conn) = pool.create_connection().await {
280 let mut connections = pool.connections.lock().await;
281 connections.push_back(conn);
282 pool.metrics
283 .idle_connections
284 .fetch_add(1, Ordering::Relaxed);
285 }
286 }
287
288 if !config.health_check_interval.is_zero() {
290 let pool_clone = ConnectionPool {
291 addr: pool.addr,
292 config: pool.config.clone(),
293 connections: pool.connections.clone(),
294 semaphore: pool.semaphore.clone(),
295 metrics: pool.metrics.clone(),
296 running: pool.running.clone(),
297 };
298 tokio::spawn(async move {
299 pool_clone.health_check_task().await;
300 });
301 }
302
303 Ok(pool)
304 }
305
306 pub async fn get(&self) -> Result<PooledClient> {
308 self.metrics
309 .acquire_attempts
310 .fetch_add(1, Ordering::Relaxed);
311
312 let permit = tokio::time::timeout(
314 self.config.acquire_timeout,
315 self.semaphore.clone().acquire_owned(),
316 )
317 .await
318 .map_err(|_| {
319 self.metrics
320 .acquire_failures
321 .fetch_add(1, Ordering::Relaxed);
322 ClientError::Timeout
323 })?
324 .map_err(|_| {
325 self.metrics
326 .acquire_failures
327 .fetch_add(1, Ordering::Relaxed);
328 ClientError::ConnectionClosed
329 })?;
330
331 let conn = {
333 let mut connections = self.connections.lock().await;
334 loop {
335 match connections.pop_front() {
336 Some(conn) => {
337 self.metrics
338 .idle_connections
339 .fetch_sub(1, Ordering::Relaxed);
340
341 if conn.is_expired(self.config.max_lifetime)
343 || conn.is_idle_too_long(self.config.idle_timeout)
344 {
345 self.metrics
346 .connections_closed
347 .fetch_add(1, Ordering::Relaxed);
348 continue;
349 }
350 break Some(conn);
351 },
352 None => break None,
353 }
354 }
355 };
356
357 let conn = match conn {
358 Some(mut c) => {
359 c.last_used = Instant::now();
360 c
361 },
362 None => {
363 self.create_connection().await?
365 },
366 };
367
368 self.metrics
369 .active_connections
370 .fetch_add(1, Ordering::Relaxed);
371 self.metrics
372 .acquire_successes
373 .fetch_add(1, Ordering::Relaxed);
374
375 Ok(PooledClient {
376 conn: Some(conn),
377 pool: self.connections.clone(),
378 metrics: self.metrics.clone(),
379 permit: Some(permit),
380 config: self.config.clone(),
381 })
382 }
383
384 async fn create_connection(&self) -> Result<PooledConnection> {
386 let mut client_config = ClientConfig::new(self.addr);
387 client_config.connect_timeout = self.config.connect_timeout;
388
389 let client = match &self.config.tls_config {
390 Some(tls_config) => LanceClient::connect_tls(client_config, tls_config.clone()).await?,
391 None => LanceClient::connect(client_config).await?,
392 };
393 self.metrics
394 .connections_created
395 .fetch_add(1, Ordering::Relaxed);
396
397 Ok(PooledConnection::new(client))
398 }
399
400 pub fn stats(&self) -> PoolStats {
402 self.metrics.snapshot()
403 }
404
405 pub async fn close(&self) {
407 self.running.store(false, Ordering::Relaxed);
408
409 let mut connections = self.connections.lock().await;
410 let count = connections.len() as u64;
411 connections.clear();
412 self.metrics
413 .connections_closed
414 .fetch_add(count, Ordering::Relaxed);
415 self.metrics.idle_connections.store(0, Ordering::Relaxed);
416 }
417
418 async fn health_check_task(&self) {
420 let mut interval = tokio::time::interval(self.config.health_check_interval);
421
422 while self.running.load(Ordering::Relaxed) {
423 interval.tick().await;
424
425 let mut to_check = {
427 let mut connections = self.connections.lock().await;
428 std::mem::take(&mut *connections)
429 };
430
431 let mut healthy = VecDeque::new();
432 let _initial_count = to_check.len();
433
434 for mut conn in to_check.drain(..) {
435 if conn.is_expired(self.config.max_lifetime) {
437 self.metrics
438 .connections_closed
439 .fetch_add(1, Ordering::Relaxed);
440 continue;
441 }
442
443 match conn.client.ping().await {
445 Ok(_) => {
446 conn.last_used = Instant::now();
447 healthy.push_back(conn);
448 },
449 Err(_) => {
450 self.metrics
451 .health_check_failures
452 .fetch_add(1, Ordering::Relaxed);
453 self.metrics
454 .connections_closed
455 .fetch_add(1, Ordering::Relaxed);
456 },
457 }
458 }
459
460 {
462 let mut connections = self.connections.lock().await;
463 connections.extend(healthy);
464 self.metrics
465 .idle_connections
466 .store(connections.len() as u64, Ordering::Relaxed);
467 }
468 }
469 }
470}
471
472pub struct PooledClient {
474 conn: Option<PooledConnection>,
475 pool: Arc<Mutex<VecDeque<PooledConnection>>>,
476 metrics: Arc<PoolMetrics>,
477 #[allow(dead_code)]
478 permit: Option<OwnedSemaphorePermit>,
479 #[allow(dead_code)]
480 config: ConnectionPoolConfig,
481}
482
483impl PooledClient {
484 pub fn client(&mut self) -> Result<&mut LanceClient> {
486 match self.conn.as_mut() {
487 Some(conn) => Ok(&mut conn.client),
488 None => Err(ClientError::ConnectionClosed),
489 }
490 }
491
492 pub async fn ping(&mut self) -> Result<Duration> {
494 if let Some(ref mut conn) = self.conn {
495 conn.client.ping().await
496 } else {
497 Err(ClientError::ConnectionClosed)
498 }
499 }
500
501 pub fn mark_unhealthy(&mut self) {
503 self.conn = None;
504 self.metrics
505 .connections_closed
506 .fetch_add(1, Ordering::Relaxed);
507 }
508}
509
510impl Drop for PooledClient {
511 fn drop(&mut self) {
512 if let Some(mut conn) = self.conn.take() {
513 conn.last_used = Instant::now();
514
515 let pool = self.pool.clone();
517 let metrics = self.metrics.clone();
518
519 tokio::spawn(async move {
520 let mut connections = pool.lock().await;
521 connections.push_back(conn);
522 metrics.active_connections.fetch_sub(1, Ordering::Relaxed);
523 metrics.idle_connections.fetch_add(1, Ordering::Relaxed);
524 });
525 } else {
526 self.metrics
527 .active_connections
528 .fetch_sub(1, Ordering::Relaxed);
529 }
530
531 }
533}
534
535pub struct ReconnectingClient {
538 addr: String,
539 config: ClientConfig,
540 tls_config: Option<TlsClientConfig>,
541 client: Option<LanceClient>,
542 reconnect_attempts: u32,
543 max_attempts: u32,
544 base_delay: Duration,
545 max_delay: Duration,
546 leader_addr: Option<SocketAddr>,
548 follow_leader: bool,
550}
551
552impl ReconnectingClient {
553 pub async fn connect(addr: &str) -> Result<Self> {
555 let socket_addr: SocketAddr = addr
556 .parse()
557 .map_err(|e| ClientError::ProtocolError(format!("Invalid address: {}", e)))?;
558
559 let config = ClientConfig::new(socket_addr);
560 let client = LanceClient::connect(config.clone()).await?;
561
562 Ok(Self {
563 addr: addr.to_string(),
564 config,
565 tls_config: None,
566 client: Some(client),
567 reconnect_attempts: 0,
568 max_attempts: 5,
569 base_delay: Duration::from_millis(100),
570 max_delay: Duration::from_secs(30),
571 leader_addr: None,
572 follow_leader: true,
573 })
574 }
575
576 pub async fn connect_tls(addr: &str, tls_config: TlsClientConfig) -> Result<Self> {
578 let socket_addr: SocketAddr = addr
579 .parse()
580 .map_err(|e| ClientError::ProtocolError(format!("Invalid address: {}", e)))?;
581
582 let config = ClientConfig::new(socket_addr);
583 let client = LanceClient::connect_tls(config.clone(), tls_config.clone()).await?;
584
585 Ok(Self {
586 addr: addr.to_string(),
587 config,
588 tls_config: Some(tls_config),
589 client: Some(client),
590 reconnect_attempts: 0,
591 max_attempts: 5,
592 base_delay: Duration::from_millis(100),
593 max_delay: Duration::from_secs(30),
594 leader_addr: None,
595 follow_leader: true,
596 })
597 }
598
599 pub fn with_max_attempts(mut self, attempts: u32) -> Self {
601 self.max_attempts = attempts;
602 self
603 }
604
605 pub fn with_follow_leader(mut self, follow: bool) -> Self {
607 self.follow_leader = follow;
608 self
609 }
610
611 pub fn original_addr(&self) -> &str {
613 &self.addr
614 }
615
616 pub fn leader_addr(&self) -> Option<SocketAddr> {
618 self.leader_addr
619 }
620
621 pub fn set_leader_addr(&mut self, addr: SocketAddr) {
623 self.leader_addr = Some(addr);
624 if self.follow_leader {
625 self.config.addr = addr;
627 }
628 }
629
630 pub fn reconnect_attempts(&self) -> u32 {
632 self.reconnect_attempts
633 }
634
635 pub async fn client(&mut self) -> Result<&mut LanceClient> {
637 if self.client.is_none() {
638 self.reconnect().await?;
639 }
640 self.client.as_mut().ok_or(ClientError::ConnectionClosed)
641 }
642
643 async fn reconnect(&mut self) -> Result<()> {
645 let mut attempts = 0;
646
647 loop {
648 attempts += 1;
649 self.reconnect_attempts += 1;
650
651 let result = match &self.tls_config {
652 Some(tls) => LanceClient::connect_tls(self.config.clone(), tls.clone()).await,
653 None => LanceClient::connect(self.config.clone()).await,
654 };
655
656 match result {
657 Ok(client) => {
658 self.client = Some(client);
659 return Ok(());
660 },
661 Err(e) => {
662 if self.max_attempts > 0 && attempts >= self.max_attempts {
663 return Err(e);
664 }
665
666 let delay = self.base_delay * 2u32.saturating_pow(attempts - 1);
668 let delay = delay.min(self.max_delay);
669
670 tokio::time::sleep(delay).await;
671 },
672 }
673 }
674 }
675
676 pub async fn execute<F, T>(&mut self, op: F) -> Result<T>
678 where
679 F: Fn(
680 &mut LanceClient,
681 )
682 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send + '_>>,
683 {
684 loop {
685 let client = self.client().await?;
686
687 match op(client).await {
688 Ok(result) => return Ok(result),
689 Err(ClientError::ConnectionClosed) | Err(ClientError::ConnectionFailed(_)) => {
690 self.client = None;
691 },
693 Err(e) => return Err(e),
694 }
695 }
696 }
697
698 pub fn mark_failed(&mut self) {
700 self.client = None;
701 }
702}
703
704pub struct ClusterClient {
710 nodes: Vec<SocketAddr>,
712 primary: Option<SocketAddr>,
714 config: ClientConfig,
716 tls_config: Option<TlsClientConfig>,
718 client: Option<LanceClient>,
720 last_discovery: Option<Instant>,
722 discovery_interval: Duration,
724}
725
726impl ClusterClient {
727 pub async fn connect(seed_addrs: &[&str]) -> Result<Self> {
729 let nodes: Vec<SocketAddr> = seed_addrs.iter().filter_map(|s| s.parse().ok()).collect();
730
731 if nodes.is_empty() {
732 return Err(ClientError::ProtocolError(
733 "No valid seed addresses".to_string(),
734 ));
735 }
736
737 let config = ClientConfig::new(nodes[0]);
738 let mut cluster = Self {
739 nodes,
740 primary: None,
741 config,
742 tls_config: None,
743 client: None,
744 last_discovery: None,
745 discovery_interval: Duration::from_secs(60),
746 };
747
748 cluster.discover_cluster().await?;
749 Ok(cluster)
750 }
751
752 pub async fn connect_tls(seed_addrs: &[&str], tls_config: TlsClientConfig) -> Result<Self> {
754 let nodes: Vec<SocketAddr> = seed_addrs.iter().filter_map(|s| s.parse().ok()).collect();
755
756 if nodes.is_empty() {
757 return Err(ClientError::ProtocolError(
758 "No valid seed addresses".to_string(),
759 ));
760 }
761
762 let config = ClientConfig::new(nodes[0]).with_tls(tls_config.clone());
763 let mut cluster = Self {
764 nodes,
765 primary: None,
766 config,
767 tls_config: Some(tls_config),
768 client: None,
769 last_discovery: None,
770 discovery_interval: Duration::from_secs(60),
771 };
772
773 cluster.discover_cluster().await?;
774 Ok(cluster)
775 }
776
777 pub fn with_discovery_interval(mut self, interval: Duration) -> Self {
779 self.discovery_interval = interval;
780 self
781 }
782
783 async fn discover_cluster(&mut self) -> Result<()> {
785 for &node in &self.nodes.clone() {
786 let mut config = self.config.clone();
787 config.addr = node;
788
789 match LanceClient::connect(config).await {
790 Ok(mut client) => {
791 match client.get_cluster_status().await {
792 Ok(status) => {
793 self.primary = status.leader_id.map(|id| {
794 status
796 .peer_states
797 .get(&id)
798 .and_then(|s| s.parse().ok())
799 .unwrap_or(node)
800 });
801 self.last_discovery = Some(Instant::now());
802
803 if let Some(primary_addr) = self.primary {
805 self.config.addr = primary_addr;
806 self.client =
807 Some(LanceClient::connect(self.config.clone()).await?);
808 } else {
809 self.client = Some(client);
810 }
811 return Ok(());
812 },
813 Err(_) => {
814 self.client = Some(client);
816 self.primary = Some(node);
817 self.last_discovery = Some(Instant::now());
818 return Ok(());
819 },
820 }
821 },
822 Err(_) => continue,
823 }
824 }
825
826 Err(ClientError::ConnectionFailed(std::io::Error::new(
827 std::io::ErrorKind::NotConnected,
828 "Could not connect to any cluster node",
829 )))
830 }
831
832 pub async fn client(&mut self) -> Result<&mut LanceClient> {
834 let needs_refresh = self
836 .last_discovery
837 .map(|t| t.elapsed() > self.discovery_interval)
838 .unwrap_or(true);
839
840 if needs_refresh || self.client.is_none() {
841 self.discover_cluster().await?;
842 }
843
844 self.client.as_mut().ok_or(ClientError::ConnectionClosed)
845 }
846
847 pub fn primary(&self) -> Option<SocketAddr> {
849 self.primary
850 }
851
852 pub fn nodes(&self) -> &[SocketAddr] {
854 &self.nodes
855 }
856
857 pub fn tls_config(&self) -> Option<&TlsClientConfig> {
859 self.tls_config.as_ref()
860 }
861
862 pub fn is_tls_enabled(&self) -> bool {
864 self.tls_config.is_some()
865 }
866
867 pub async fn refresh(&mut self) -> Result<()> {
869 self.discover_cluster().await
870 }
871}
872
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_defaults() {
        let config = ConnectionPoolConfig::new();

        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_idle, 1);
        assert!(config.auto_reconnect);
        assert_eq!(config.max_lifetime, Duration::from_secs(3600));
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert!(config.tls_config.is_none());
    }

    #[test]
    fn test_pool_config_builder() {
        let config = ConnectionPoolConfig::new()
            .with_max_connections(20)
            .with_min_idle(5)
            .with_health_check_interval(60)
            .with_auto_reconnect(false);

        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_idle, 5);
        assert_eq!(config.health_check_interval, Duration::from_secs(60));
        assert!(!config.auto_reconnect);

        // Zero disables the background health check.
        let disabled = ConnectionPoolConfig::new().with_health_check_interval(0);
        assert!(disabled.health_check_interval.is_zero());
    }

    #[test]
    fn test_pool_config_duration_builders() {
        let config = ConnectionPoolConfig::new()
            .with_connect_timeout(Duration::from_secs(1))
            .with_acquire_timeout(Duration::from_secs(2))
            .with_max_lifetime(Duration::from_secs(3))
            .with_idle_timeout(Duration::from_secs(4));

        assert_eq!(config.connect_timeout, Duration::from_secs(1));
        assert_eq!(config.acquire_timeout, Duration::from_secs(2));
        assert_eq!(config.max_lifetime, Duration::from_secs(3));
        assert_eq!(config.idle_timeout, Duration::from_secs(4));
    }

    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();

        assert_eq!(stats.connections_created, 0);
        assert_eq!(stats.active_connections, 0);
    }

    #[test]
    fn test_pool_metrics_snapshot() {
        let metrics = PoolMetrics::default();
        metrics.connections_created.fetch_add(3, Ordering::Relaxed);
        metrics.acquire_failures.fetch_add(1, Ordering::Relaxed);

        let stats = metrics.snapshot();

        assert_eq!(stats.connections_created, 3);
        assert_eq!(stats.acquire_failures, 1);
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.idle_connections, 0);
    }

    #[test]
    fn test_pooled_connection_expiry() {
        let max_lifetime = Duration::from_millis(10);
        // Back-date the timestamp instead of sleeping, so the test never blocks.
        // `checked_sub` can fail very early after boot on some platforms.
        if let Some(created_at) = Instant::now().checked_sub(Duration::from_millis(20)) {
            assert!(created_at.elapsed() > max_lifetime);
        }
    }

    #[test]
    fn test_reconnecting_client_leader_addr() {
        let addr: SocketAddr = "127.0.0.1:1992".parse().unwrap();
        let leader: SocketAddr = "127.0.0.1:1993".parse().unwrap();

        let follow_leader = true;
        let mut config_addr = addr;

        let leader_addr: Option<SocketAddr> = Some(leader);
        if follow_leader {
            config_addr = leader;
        }

        assert_eq!(leader_addr, Some(leader));
        assert_eq!(config_addr, leader);
    }

    #[test]
    fn test_connection_pool_config_auto_reconnect() {
        let config = ConnectionPoolConfig::new()
            .with_auto_reconnect(true)
            .with_max_reconnect_attempts(10);

        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 10);
    }
}