use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
use crate::client::{ClientConfig, LanceClient};
use crate::error::{ClientError, Result};
use crate::tls::TlsClientConfig;
/// Configuration for [`ConnectionPool`].
///
/// Start from [`ConnectionPoolConfig::new`] (equivalent to `Default`) and adjust with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Upper bound on concurrently checked-out connections (size of the pool's semaphore).
    pub max_connections: usize,
    /// Number of connections the pool tries to open eagerly when it is created.
    pub min_idle: usize,
    /// Copied into each new client's `ClientConfig::connect_timeout`.
    pub connect_timeout: Duration,
    /// How long `ConnectionPool::get` waits for a free slot before failing with a timeout.
    pub acquire_timeout: Duration,
    /// Period of the background health-check task; zero disables the task.
    pub health_check_interval: Duration,
    /// Age after which a pooled connection is discarded; zero means "no limit".
    pub max_lifetime: Duration,
    /// Idle time after which a pooled connection is discarded; zero means "no limit".
    pub idle_timeout: Duration,
    /// NOTE(review): not read by the pool code visible in this file — confirm its consumer.
    pub auto_reconnect: bool,
    /// NOTE(review): not read by the pool code visible in this file — confirm its consumer.
    pub max_reconnect_attempts: u32,
    /// NOTE(review): not read by the pool code visible in this file — confirm its consumer.
    pub reconnect_base_delay: Duration,
    /// NOTE(review): not read by the pool code visible in this file — confirm its consumer.
    pub reconnect_max_delay: Duration,
    /// When set, pool connections are opened with `LanceClient::connect_tls`.
    pub tls_config: Option<TlsClientConfig>,
}
impl Default for ConnectionPoolConfig {
    /// Ten connections with one pre-warmed idle connection.
    ///
    /// Connect, acquire and health-check intervals are 30 s. Connections live at most one
    /// hour and are dropped after five idle minutes. Auto-reconnect is on, with five attempts
    /// backing off from 100 ms up to 30 s. TLS is disabled.
    fn default() -> Self {
        const THIRTY_SECONDS: Duration = Duration::from_secs(30);
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);

        Self {
            tls_config: None,
            max_connections: 10,
            min_idle: 1,
            connect_timeout: THIRTY_SECONDS,
            acquire_timeout: THIRTY_SECONDS,
            health_check_interval: THIRTY_SECONDS,
            max_lifetime: ONE_HOUR,
            idle_timeout: FIVE_MINUTES,
            auto_reconnect: true,
            max_reconnect_attempts: 5,
            reconnect_base_delay: Duration::from_millis(100),
            reconnect_max_delay: THIRTY_SECONDS,
        }
    }
}
impl ConnectionPoolConfig {
    /// Creates a configuration with the default values (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of concurrently checked-out connections.
    pub fn with_max_connections(mut self, n: usize) -> Self {
        self.max_connections = n;
        self
    }

    /// Sets how many connections are opened eagerly when the pool is created.
    pub fn with_min_idle(mut self, n: usize) -> Self {
        self.min_idle = n;
        self
    }

    /// Sets the per-connection connect timeout.
    pub fn with_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets how long `get()` may wait for a free pool slot.
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = timeout;
        self
    }

    /// Sets the health-check period in **seconds**.
    ///
    /// Unlike the other setters, this one takes whole seconds, not a `Duration`. Passing `0`
    /// disables the background health-check task.
    pub fn with_health_check_interval(mut self, secs: u64) -> Self {
        self.health_check_interval = Duration::from_secs(secs);
        self
    }

    /// Sets the maximum age of a pooled connection (`Duration` zero = unlimited).
    pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
        self.max_lifetime = lifetime;
        self
    }

    /// Sets how long a connection may sit idle before it is discarded (zero = unlimited).
    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Enables or disables automatic reconnection.
    pub fn with_auto_reconnect(mut self, enabled: bool) -> Self {
        self.auto_reconnect = enabled;
        self
    }

    /// Sets the maximum number of reconnect attempts.
    pub fn with_max_reconnect_attempts(mut self, attempts: u32) -> Self {
        self.max_reconnect_attempts = attempts;
        self
    }

    /// Sets the initial reconnect back-off delay.
    ///
    /// Without this setter, `reconnect_base_delay` could only be changed through direct
    /// field access.
    pub fn with_reconnect_base_delay(mut self, delay: Duration) -> Self {
        self.reconnect_base_delay = delay;
        self
    }

    /// Sets the upper bound on the reconnect back-off delay.
    pub fn with_reconnect_max_delay(mut self, delay: Duration) -> Self {
        self.reconnect_max_delay = delay;
        self
    }

    /// Makes the pool open its connections over TLS with `tls_config`.
    pub fn with_tls(mut self, tls_config: TlsClientConfig) -> Self {
        self.tls_config = Some(tls_config);
        self
    }
}
/// Point-in-time copy of a pool's counters, as returned by `ConnectionPool::stats`.
///
/// Values are cumulative since pool creation, except `active_connections` and
/// `idle_connections`, which are current gauges.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections successfully opened by the pool.
    pub connections_created: u64,
    /// Connections discarded (expired, idle too long, failed health check, marked unhealthy,
    /// or cleared by `close`).
    pub connections_closed: u64,
    /// Connections currently checked out via `PooledClient`.
    pub active_connections: u64,
    /// Connections currently waiting in the idle queue.
    pub idle_connections: u64,
    /// Calls to `ConnectionPool::get`.
    pub acquire_attempts: u64,
    /// `get` calls that returned a client.
    pub acquire_successes: u64,
    /// `get` calls that failed while waiting for a slot.
    pub acquire_failures: u64,
    /// Pings issued by the health-check task that returned an error.
    pub health_check_failures: u64,
    /// NOTE(review): never incremented by the code visible in this file.
    pub reconnect_attempts: u64,
}
/// Atomic counters backing [`PoolStats`].
///
/// Shared through `Arc` by the pool, its checked-out `PooledClient`s and the health-check
/// task. Every access uses `Ordering::Relaxed`: these are statistics, not synchronization.
#[derive(Debug, Default)]
struct PoolMetrics {
    connections_created: AtomicU64,
    connections_closed: AtomicU64,
    active_connections: AtomicU64,
    idle_connections: AtomicU64,
    acquire_attempts: AtomicU64,
    acquire_successes: AtomicU64,
    acquire_failures: AtomicU64,
    health_check_failures: AtomicU64,
    reconnect_attempts: AtomicU64,
}
impl PoolMetrics {
fn snapshot(&self) -> PoolStats {
PoolStats {
connections_created: self.connections_created.load(Ordering::Relaxed),
connections_closed: self.connections_closed.load(Ordering::Relaxed),
active_connections: self.active_connections.load(Ordering::Relaxed),
idle_connections: self.idle_connections.load(Ordering::Relaxed),
acquire_attempts: self.acquire_attempts.load(Ordering::Relaxed),
acquire_successes: self.acquire_successes.load(Ordering::Relaxed),
acquire_failures: self.acquire_failures.load(Ordering::Relaxed),
health_check_failures: self.health_check_failures.load(Ordering::Relaxed),
reconnect_attempts: self.reconnect_attempts.load(Ordering::Relaxed),
}
}
}
/// A client held by the pool, plus the timestamps used for lifetime/idle eviction.
struct PooledConnection {
    client: LanceClient,
    /// When the connection was opened; compared against `max_lifetime`.
    created_at: Instant,
    /// Last time the connection was handed out or returned; compared against `idle_timeout`.
    last_used: Instant,
}
impl PooledConnection {
    /// Wraps a freshly connected client; both timestamps start at the current instant.
    fn new(client: LanceClient) -> Self {
        let opened = Instant::now();
        Self {
            client,
            created_at: opened,
            last_used: opened,
        }
    }

    /// Returns `true` once the connection is older than `max_lifetime`.
    ///
    /// A zero `max_lifetime` disables the check.
    fn is_expired(&self, max_lifetime: Duration) -> bool {
        !max_lifetime.is_zero() && self.created_at.elapsed() > max_lifetime
    }

    /// Returns `true` once the connection has been unused for longer than `idle_timeout`.
    ///
    /// A zero `idle_timeout` disables the check.
    fn is_idle_too_long(&self, idle_timeout: Duration) -> bool {
        !idle_timeout.is_zero() && self.last_used.elapsed() > idle_timeout
    }
}
/// Async pool of [`LanceClient`] connections to a single server address.
///
/// Checked-out connections are bounded by a semaphore sized to
/// `ConnectionPoolConfig::max_connections`. Idle connections wait in a FIFO queue.
pub struct ConnectionPool {
    /// Address every new connection dials.
    addr: String,
    config: ConnectionPoolConfig,
    /// Idle connections; popped from the front on checkout, pushed to the back on return.
    connections: Arc<Mutex<VecDeque<PooledConnection>>>,
    /// One permit is held by each checked-out `PooledClient`.
    semaphore: Arc<Semaphore>,
    metrics: Arc<PoolMetrics>,
    /// Cleared by `close()`; the health-check task polls it to decide when to stop.
    running: Arc<AtomicBool>,
}
impl ConnectionPool {
pub async fn new(addr: &str, config: ConnectionPoolConfig) -> Result<Self> {
let pool = Self {
addr: addr.to_string(),
config: config.clone(),
connections: Arc::new(Mutex::new(VecDeque::new())),
semaphore: Arc::new(Semaphore::new(config.max_connections)),
metrics: Arc::new(PoolMetrics::default()),
running: Arc::new(AtomicBool::new(true)),
};
for _ in 0..config.min_idle {
if let Ok(conn) = pool.create_connection().await {
let mut connections = pool.connections.lock().await;
connections.push_back(conn);
pool.metrics
.idle_connections
.fetch_add(1, Ordering::Relaxed);
}
}
if !config.health_check_interval.is_zero() {
let pool_clone = ConnectionPool {
addr: pool.addr.clone(),
config: pool.config.clone(),
connections: pool.connections.clone(),
semaphore: pool.semaphore.clone(),
metrics: pool.metrics.clone(),
running: pool.running.clone(),
};
tokio::spawn(async move {
pool_clone.health_check_task().await;
});
}
Ok(pool)
}
pub async fn get(&self) -> Result<PooledClient> {
self.metrics
.acquire_attempts
.fetch_add(1, Ordering::Relaxed);
let permit = tokio::time::timeout(
self.config.acquire_timeout,
self.semaphore.clone().acquire_owned(),
)
.await
.map_err(|_| {
self.metrics
.acquire_failures
.fetch_add(1, Ordering::Relaxed);
ClientError::Timeout
})?
.map_err(|_| {
self.metrics
.acquire_failures
.fetch_add(1, Ordering::Relaxed);
ClientError::ConnectionClosed
})?;
let conn = {
let mut connections = self.connections.lock().await;
loop {
match connections.pop_front() {
Some(conn) => {
self.metrics
.idle_connections
.fetch_sub(1, Ordering::Relaxed);
if conn.is_expired(self.config.max_lifetime)
|| conn.is_idle_too_long(self.config.idle_timeout)
{
self.metrics
.connections_closed
.fetch_add(1, Ordering::Relaxed);
continue;
}
break Some(conn);
},
None => break None,
}
}
};
let conn = match conn {
Some(mut c) => {
c.last_used = Instant::now();
c
},
None => {
self.create_connection().await?
},
};
self.metrics
.active_connections
.fetch_add(1, Ordering::Relaxed);
self.metrics
.acquire_successes
.fetch_add(1, Ordering::Relaxed);
Ok(PooledClient {
conn: Some(conn),
pool: self.connections.clone(),
metrics: self.metrics.clone(),
permit: Some(permit),
config: self.config.clone(),
})
}
async fn create_connection(&self) -> Result<PooledConnection> {
let mut client_config = ClientConfig::new(&self.addr);
client_config.connect_timeout = self.config.connect_timeout;
let client = match &self.config.tls_config {
Some(tls_config) => LanceClient::connect_tls(client_config, tls_config.clone()).await?,
None => LanceClient::connect(client_config).await?,
};
self.metrics
.connections_created
.fetch_add(1, Ordering::Relaxed);
Ok(PooledConnection::new(client))
}
pub fn stats(&self) -> PoolStats {
self.metrics.snapshot()
}
pub async fn close(&self) {
self.running.store(false, Ordering::Relaxed);
let mut connections = self.connections.lock().await;
let count = connections.len() as u64;
connections.clear();
self.metrics
.connections_closed
.fetch_add(count, Ordering::Relaxed);
self.metrics.idle_connections.store(0, Ordering::Relaxed);
}
async fn health_check_task(&self) {
let mut interval = tokio::time::interval(self.config.health_check_interval);
while self.running.load(Ordering::Relaxed) {
interval.tick().await;
let mut to_check = {
let mut connections = self.connections.lock().await;
std::mem::take(&mut *connections)
};
let mut healthy = VecDeque::new();
let _initial_count = to_check.len();
for mut conn in to_check.drain(..) {
if conn.is_expired(self.config.max_lifetime) {
self.metrics
.connections_closed
.fetch_add(1, Ordering::Relaxed);
continue;
}
match conn.client.ping().await {
Ok(_) => {
conn.last_used = Instant::now();
healthy.push_back(conn);
},
Err(_) => {
self.metrics
.health_check_failures
.fetch_add(1, Ordering::Relaxed);
self.metrics
.connections_closed
.fetch_add(1, Ordering::Relaxed);
},
}
}
{
let mut connections = self.connections.lock().await;
connections.extend(healthy);
self.metrics
.idle_connections
.store(connections.len() as u64, Ordering::Relaxed);
}
}
}
}
/// A connection checked out of a [`ConnectionPool`].
///
/// It is returned to the pool's idle queue when dropped, unless `mark_unhealthy` was called.
pub struct PooledClient {
    /// `None` after `mark_unhealthy` (or once `Drop` has taken it).
    conn: Option<PooledConnection>,
    /// The owning pool's idle queue, where the connection goes back on drop.
    pool: Arc<Mutex<VecDeque<PooledConnection>>>,
    metrics: Arc<PoolMetrics>,
    // Held so this checkout keeps occupying one of the pool's `max_connections` semaphore
    // slots; the slot is freed when the permit is dropped.
    #[allow(dead_code)]
    permit: Option<OwnedSemaphorePermit>,
    // NOTE(review): copied from the pool in `get()` but never read in this file.
    #[allow(dead_code)]
    config: ConnectionPoolConfig,
}
impl PooledClient {
    /// Borrows the underlying client.
    ///
    /// # Errors
    /// `ClientError::ConnectionClosed` after [`PooledClient::mark_unhealthy`].
    pub fn client(&mut self) -> Result<&mut LanceClient> {
        self.conn
            .as_mut()
            .map(|conn| &mut conn.client)
            .ok_or(ClientError::ConnectionClosed)
    }

    /// Pings the server through this connection and returns the client's reported duration.
    ///
    /// # Errors
    /// `ClientError::ConnectionClosed` after [`PooledClient::mark_unhealthy`], otherwise
    /// whatever the ping itself returns.
    pub async fn ping(&mut self) -> Result<Duration> {
        match self.conn.as_mut() {
            Some(conn) => conn.client.ping().await,
            None => Err(ClientError::ConnectionClosed),
        }
    }

    /// Discards the connection so it is not returned to the pool on drop.
    ///
    /// Idempotent: only the first call counts the connection as closed.
    pub fn mark_unhealthy(&mut self) {
        if self.conn.take().is_some() {
            self.metrics
                .connections_closed
                .fetch_add(1, Ordering::Relaxed);
        }
    }
}
impl Drop for PooledClient {
fn drop(&mut self) {
if let Some(mut conn) = self.conn.take() {
conn.last_used = Instant::now();
let pool = self.pool.clone();
let metrics = self.metrics.clone();
tokio::spawn(async move {
let mut connections = pool.lock().await;
connections.push_back(conn);
metrics.active_connections.fetch_sub(1, Ordering::Relaxed);
metrics.idle_connections.fetch_add(1, Ordering::Relaxed);
});
} else {
self.metrics
.active_connections
.fetch_sub(1, Ordering::Relaxed);
}
}
}
/// Wrapper around a single [`LanceClient`] that re-dials with exponential back-off when the
/// connection is lost.
pub struct ReconnectingClient {
    /// Address the client was created with (returned by `original_addr`).
    addr: String,
    /// Template cloned for each reconnect attempt.
    config: ClientConfig,
    /// When set, reconnects use `LanceClient::connect_tls`.
    tls_config: Option<TlsClientConfig>,
    /// `None` means the next `client()` call must reconnect first.
    client: Option<LanceClient>,
    /// Cumulative number of dial attempts made by `reconnect`; never reset.
    reconnect_attempts: u32,
    /// Retry budget; `0` means unlimited.
    max_attempts: u32,
    /// Back-off for attempt `n` is `base_delay * 2^(n-1)`, capped at `max_delay`.
    base_delay: Duration,
    max_delay: Duration,
    /// Last address passed to `set_leader_addr`.
    leader_addr: Option<SocketAddr>,
    /// When true, `set_leader_addr` also rewrites `config.addr`.
    follow_leader: bool,
}
impl ReconnectingClient {
pub async fn connect(addr: &str) -> Result<Self> {
let config = ClientConfig::new(addr);
let client = LanceClient::connect(config.clone()).await?;
Ok(Self {
addr: addr.to_string(),
config,
tls_config: None,
client: Some(client),
reconnect_attempts: 0,
max_attempts: 5,
base_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(30),
leader_addr: None,
follow_leader: true,
})
}
pub fn from_existing(client: LanceClient, addr: &str) -> Self {
let config = client.config().clone();
Self {
addr: addr.to_string(),
config,
tls_config: None,
client: Some(client),
reconnect_attempts: 0,
max_attempts: 0, base_delay: Duration::from_millis(500),
max_delay: Duration::from_secs(30),
leader_addr: None,
follow_leader: true,
}
}
pub async fn connect_tls(addr: &str, tls_config: TlsClientConfig) -> Result<Self> {
let config = ClientConfig::new(addr);
let client = LanceClient::connect_tls(config.clone(), tls_config.clone()).await?;
Ok(Self {
addr: addr.to_string(),
config,
tls_config: Some(tls_config),
client: Some(client),
reconnect_attempts: 0,
max_attempts: 5,
base_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(30),
leader_addr: None,
follow_leader: true,
})
}
pub fn with_max_attempts(mut self, attempts: u32) -> Self {
self.max_attempts = attempts;
self
}
pub fn with_unlimited_retries(mut self) -> Self {
self.max_attempts = 0;
self
}
pub fn with_base_delay(mut self, delay: Duration) -> Self {
self.base_delay = delay;
self
}
pub fn with_max_delay(mut self, delay: Duration) -> Self {
self.max_delay = delay;
self
}
pub fn with_follow_leader(mut self, follow: bool) -> Self {
self.follow_leader = follow;
self
}
pub fn original_addr(&self) -> &str {
&self.addr
}
pub fn leader_addr(&self) -> Option<SocketAddr> {
self.leader_addr
}
pub fn set_leader_addr(&mut self, addr: SocketAddr) {
self.leader_addr = Some(addr);
if self.follow_leader {
self.config.addr = addr.to_string();
}
}
pub fn reconnect_attempts(&self) -> u32 {
self.reconnect_attempts
}
pub async fn client(&mut self) -> Result<&mut LanceClient> {
if self.client.is_none() {
self.reconnect().await?;
}
self.client.as_mut().ok_or(ClientError::ConnectionClosed)
}
pub async fn reconnect(&mut self) -> Result<()> {
let mut attempts = 0;
loop {
attempts += 1;
self.reconnect_attempts += 1;
let mut config = self.config.clone();
config.addr = self.addr.clone();
let result = match &self.tls_config {
Some(tls) => LanceClient::connect_tls(config, tls.clone()).await,
None => LanceClient::connect(config).await,
};
match result {
Ok(client) => {
self.client = Some(client);
return Ok(());
},
Err(e) => {
if self.max_attempts > 0 && attempts >= self.max_attempts {
return Err(e);
}
let delay = self.base_delay * 2u32.saturating_pow(attempts - 1);
let delay = delay.min(self.max_delay);
tokio::time::sleep(delay).await;
},
}
}
}
pub async fn execute<F, T>(&mut self, op: F) -> Result<T>
where
F: Fn(
&mut LanceClient,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send + '_>>,
{
let mut attempts = 0u32;
loop {
let client = self.client().await?;
match op(client).await {
Ok(result) => return Ok(result),
Err(e) if e.is_retryable() => {
attempts += 1;
if self.max_attempts > 0 && attempts >= self.max_attempts {
return Err(e);
}
self.client = None;
let delay = self.base_delay * 2u32.saturating_pow(attempts.saturating_sub(1));
let delay = delay.min(self.max_delay);
tokio::time::sleep(delay).await;
},
Err(e) => return Err(e),
}
}
}
pub fn mark_failed(&mut self) {
self.client = None;
}
}
/// Client that discovers the cluster leader from a list of seed nodes and keeps a connection
/// to it, re-running discovery periodically.
pub struct ClusterClient {
    /// Seed addresses that parsed successfully as `SocketAddr`.
    nodes: Vec<SocketAddr>,
    /// Address chosen by the last discovery: the leader if one was reported, or the
    /// responding node when its cluster status could not be read.
    primary: Option<SocketAddr>,
    /// Connection template; `addr` is rewritten to the primary after discovery.
    config: ClientConfig,
    /// TLS settings supplied through `connect_tls`, if any.
    tls_config: Option<TlsClientConfig>,
    client: Option<LanceClient>,
    /// When discovery last succeeded; `client()` re-discovers once `discovery_interval` elapses.
    last_discovery: Option<Instant>,
    discovery_interval: Duration,
}
impl ClusterClient {
pub async fn connect(seed_addrs: &[&str]) -> Result<Self> {
let nodes: Vec<SocketAddr> = seed_addrs.iter().filter_map(|s| s.parse().ok()).collect();
if nodes.is_empty() {
return Err(ClientError::ProtocolError(
"No valid seed addresses".to_string(),
));
}
let config = ClientConfig::new(nodes[0].to_string());
let mut cluster = Self {
nodes,
primary: None,
config,
tls_config: None,
client: None,
last_discovery: None,
discovery_interval: Duration::from_secs(60),
};
cluster.discover_cluster().await?;
Ok(cluster)
}
pub async fn connect_tls(seed_addrs: &[&str], tls_config: TlsClientConfig) -> Result<Self> {
let nodes: Vec<SocketAddr> = seed_addrs.iter().filter_map(|s| s.parse().ok()).collect();
if nodes.is_empty() {
return Err(ClientError::ProtocolError(
"No valid seed addresses".to_string(),
));
}
let config = ClientConfig::new(nodes[0].to_string()).with_tls(tls_config.clone());
let mut cluster = Self {
nodes,
primary: None,
config,
tls_config: Some(tls_config),
client: None,
last_discovery: None,
discovery_interval: Duration::from_secs(60),
};
cluster.discover_cluster().await?;
Ok(cluster)
}
pub fn with_discovery_interval(mut self, interval: Duration) -> Self {
self.discovery_interval = interval;
self
}
async fn discover_cluster(&mut self) -> Result<()> {
for &node in &self.nodes.clone() {
let mut config = self.config.clone();
config.addr = node.to_string();
match LanceClient::connect(config).await {
Ok(mut client) => {
match client.get_cluster_status().await {
Ok(status) => {
self.primary = status.leader_id.map(|id| {
status
.peer_states
.get(&id)
.and_then(|s| s.parse().ok())
.unwrap_or(node)
});
self.last_discovery = Some(Instant::now());
if let Some(primary_addr) = self.primary {
self.config.addr = primary_addr.to_string();
self.client =
Some(LanceClient::connect(self.config.clone()).await?);
} else {
self.client = Some(client);
}
return Ok(());
},
Err(_) => {
self.client = Some(client);
self.primary = Some(node);
self.last_discovery = Some(Instant::now());
return Ok(());
},
}
},
Err(_) => continue,
}
}
Err(ClientError::ConnectionFailed(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"Could not connect to any cluster node",
)))
}
pub async fn client(&mut self) -> Result<&mut LanceClient> {
let needs_refresh = self
.last_discovery
.map(|t| t.elapsed() > self.discovery_interval)
.unwrap_or(true);
if needs_refresh || self.client.is_none() {
self.discover_cluster().await?;
}
self.client.as_mut().ok_or(ClientError::ConnectionClosed)
}
pub fn primary(&self) -> Option<SocketAddr> {
self.primary
}
pub fn nodes(&self) -> &[SocketAddr] {
&self.nodes
}
pub fn tls_config(&self) -> Option<&TlsClientConfig> {
self.tls_config.as_ref()
}
pub fn is_tls_enabled(&self) -> bool {
self.tls_config.is_some()
}
pub async fn refresh(&mut self) -> Result<()> {
self.discover_cluster().await
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_defaults() {
        let config = ConnectionPoolConfig::new();
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.min_idle, 1);
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.max_lifetime, Duration::from_secs(3600));
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.health_check_interval, Duration::from_secs(30));
        assert!(config.tls_config.is_none());
    }

    #[test]
    fn test_pool_config_builder() {
        let config = ConnectionPoolConfig::new()
            .with_max_connections(20)
            .with_min_idle(5)
            .with_health_check_interval(60)
            .with_auto_reconnect(false);
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.min_idle, 5);
        assert_eq!(config.health_check_interval, Duration::from_secs(60));
        assert!(!config.auto_reconnect);
    }

    #[test]
    fn test_pool_config_duration_builders() {
        let config = ConnectionPoolConfig::new()
            .with_connect_timeout(Duration::from_secs(5))
            .with_acquire_timeout(Duration::from_secs(7))
            .with_max_lifetime(Duration::from_secs(0))
            .with_idle_timeout(Duration::from_secs(11));
        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.acquire_timeout, Duration::from_secs(7));
        assert!(config.max_lifetime.is_zero());
        assert_eq!(config.idle_timeout, Duration::from_secs(11));
    }

    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();
        assert_eq!(stats.connections_created, 0);
        assert_eq!(stats.active_connections, 0);
    }

    #[test]
    fn test_pool_metrics_snapshot() {
        let metrics = PoolMetrics::default();
        metrics.connections_created.fetch_add(3, Ordering::Relaxed);
        metrics.active_connections.fetch_add(2, Ordering::Relaxed);
        metrics.acquire_failures.fetch_add(1, Ordering::Relaxed);
        let stats = metrics.snapshot();
        assert_eq!(stats.connections_created, 3);
        assert_eq!(stats.active_connections, 2);
        assert_eq!(stats.acquire_failures, 1);
        assert_eq!(stats.connections_closed, 0);
        assert_eq!(stats.idle_connections, 0);
    }

    #[test]
    fn test_pooled_connection_expiry() {
        // Back-date the timestamp instead of sleeping, so the test is instantaneous.
        let max_lifetime = Duration::from_millis(10);
        if let Some(created_at) = Instant::now().checked_sub(Duration::from_millis(20)) {
            assert!(created_at.elapsed() > max_lifetime);
        }
    }

    // NOTE(review): this only mirrors the leader-selection rule locally; exercising
    // `ReconnectingClient::set_leader_addr` needs a live `LanceClient`.
    #[test]
    fn test_reconnecting_client_leader_addr() {
        let addr: SocketAddr = "127.0.0.1:1992".parse().unwrap();
        let leader: SocketAddr = "127.0.0.1:1993".parse().unwrap();
        let follow_leader = true;
        let mut config_addr = addr;
        let leader_addr: Option<SocketAddr> = Some(leader);
        if follow_leader {
            config_addr = leader;
        }
        assert_eq!(leader_addr, Some(leader));
        assert_eq!(config_addr, leader);
    }

    #[test]
    fn test_connection_pool_config_auto_reconnect() {
        let config = ConnectionPoolConfig::new()
            .with_auto_reconnect(true)
            .with_max_reconnect_attempts(10);
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 10);
    }
}