use crate::config::ClientConfig;
use crate::error::{Result, SdkError};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Instant;
use tokio::time::timeout;
use tonic::transport::{Channel, Endpoint};
use tracing::{debug, info, warn};
#[derive(Clone)]
pub struct Connection {
channel: Channel,
created_at: Instant,
last_used: Arc<RwLock<Instant>>,
}
impl Connection {
fn new(channel: Channel) -> Self {
let now = Instant::now();
Self {
channel,
created_at: now,
last_used: Arc::new(RwLock::new(now)),
}
}
pub fn channel(&self) -> &Channel {
*self.last_used.write() = Instant::now();
&self.channel
}
fn is_idle(&self, idle_timeout: std::time::Duration) -> bool {
self.last_used.read().elapsed() > idle_timeout
}
fn age(&self) -> std::time::Duration {
self.created_at.elapsed()
}
}
pub struct ConnectionPool {
config: Arc<ClientConfig>,
connections: DashMap<usize, Connection>,
next_id: Arc<parking_lot::Mutex<usize>>,
}
impl ConnectionPool {
pub fn new(config: ClientConfig) -> Self {
Self {
config: Arc::new(config),
connections: DashMap::new(),
next_id: Arc::new(parking_lot::Mutex::new(0)),
}
}
pub async fn get(&self) -> Result<Connection> {
for entry in self.connections.iter() {
let conn = entry.value();
if !conn.is_idle(self.config.idle_timeout) {
debug!("Reusing connection {}", entry.key());
return Ok(conn.clone());
}
}
self.cleanup_idle();
if self.connections.len() >= self.config.max_connections {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
for entry in self.connections.iter() {
let conn = entry.value();
if !conn.is_idle(self.config.idle_timeout) {
return Ok(conn.clone());
}
}
return Err(SdkError::Connection(
"connection pool exhausted".to_string(),
));
}
self.create_connection().await
}
async fn create_connection(&self) -> Result<Connection> {
info!("Creating new connection to {}", self.config.server_addr);
let mut endpoint = Endpoint::from_shared(self.config.server_addr.clone())
.map_err(|e| SdkError::Configuration(format!("invalid server address: {}", e)))?;
endpoint = endpoint
.timeout(self.config.request_timeout)
.connect_timeout(self.config.connect_timeout);
if self.config.keep_alive {
endpoint = endpoint
.keep_alive_timeout(self.config.keep_alive_timeout)
.http2_keep_alive_interval(self.config.keep_alive_interval);
}
if self.config.tls_enabled {
if let Some(tls_config) = &self.config.tls_config {
let mut client_tls = tonic::transport::ClientTlsConfig::new();
if let Some(domain) = &tls_config.domain_name {
client_tls = client_tls.domain_name(domain.clone());
}
if let Some(ca_path) = &tls_config.ca_cert_path {
let ca_pem = std::fs::read(ca_path).map_err(|e| {
SdkError::Configuration(format!(
"failed to read CA certificate at {}: {}",
ca_path.display(),
e
))
})?;
let ca_cert = tonic::transport::Certificate::from_pem(ca_pem);
client_tls = client_tls.ca_certificate(ca_cert);
}
if let (Some(cert_path), Some(key_path)) =
(&tls_config.client_cert_path, &tls_config.client_key_path)
{
let cert_pem = std::fs::read(cert_path).map_err(|e| {
SdkError::Configuration(format!(
"failed to read client certificate at {}: {}",
cert_path.display(),
e
))
})?;
let key_pem = std::fs::read(key_path).map_err(|e| {
SdkError::Configuration(format!(
"failed to read client key at {}: {}",
key_path.display(),
e
))
})?;
let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem);
client_tls = client_tls.identity(identity);
}
endpoint = endpoint.tls_config(client_tls).map_err(|e| {
SdkError::Configuration(format!("failed to configure TLS: {}", e))
})?;
debug!("TLS configured successfully");
}
}
let channel = timeout(self.config.connect_timeout, endpoint.connect())
.await
.map_err(|_| {
SdkError::Timeout(format!(
"connection timeout after {:?}",
self.config.connect_timeout
))
})?
.map_err(SdkError::Transport)?;
let conn = Connection::new(channel);
let id = {
let mut next = self.next_id.lock();
let id = *next;
*next += 1;
id
};
self.connections.insert(id, conn.clone());
info!("Connection {} created successfully", id);
Ok(conn)
}
fn cleanup_idle(&self) {
let mut to_remove = Vec::new();
for entry in self.connections.iter() {
if entry.value().is_idle(self.config.idle_timeout) {
to_remove.push(*entry.key());
}
}
for id in to_remove {
if let Some((_, conn)) = self.connections.remove(&id) {
warn!("Removing idle connection {} (age: {:?})", id, conn.age());
}
}
}
pub fn close_all(&self) {
info!("Closing all connections ({})", self.connections.len());
self.connections.clear();
}
pub fn stats(&self) -> PoolStats {
let total = self.connections.len();
let mut idle = 0;
for entry in self.connections.iter() {
if entry.value().is_idle(self.config.idle_timeout) {
idle += 1;
}
}
PoolStats {
total_connections: total,
active_connections: total - idle,
idle_connections: idle,
max_connections: self.config.max_connections,
}
}
}
#[derive(Debug, Clone)]
pub struct PoolStats {
pub total_connections: usize,
pub active_connections: usize,
pub idle_connections: usize,
pub max_connections: usize,
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[test]
fn test_connection_idle() {
let now = Instant::now();
let last_used = Arc::new(RwLock::new(now));
std::thread::sleep(Duration::from_millis(10));
let elapsed = last_used.read().elapsed();
assert!(elapsed >= Duration::from_millis(10));
}
#[test]
fn test_pool_stats() {
let config = ClientConfig::default();
let pool = ConnectionPool::new(config);
let stats = pool.stats();
assert_eq!(stats.total_connections, 0);
assert_eq!(stats.active_connections, 0);
assert_eq!(stats.max_connections, 10);
}
#[test]
fn test_tls_config_construction() {
use crate::config::TlsConfig;
let tls = TlsConfig::new().with_domain_name("example.com");
let config = ClientConfig::new("https://example.com:50051").with_tls(tls);
assert!(config.tls_enabled);
assert!(config.tls_config.is_some());
let tls_cfg = config
.tls_config
.as_ref()
.expect("tls_config should be Some");
assert_eq!(tls_cfg.domain_name, Some("example.com".to_string()));
assert!(tls_cfg.ca_cert_path.is_none());
assert!(tls_cfg.client_cert_path.is_none());
assert!(tls_cfg.client_key_path.is_none());
}
#[test]
fn test_tls_config_with_mtls_paths() {
use crate::config::TlsConfig;
let tls = TlsConfig::new()
.with_ca_cert("/path/to/ca.pem")
.with_client_cert("/path/to/client.pem", "/path/to/client.key")
.with_domain_name("db.example.com");
let config = ClientConfig::new("https://db.example.com:50051").with_tls(tls);
assert!(config.tls_enabled);
let tls_cfg = config
.tls_config
.as_ref()
.expect("tls_config should be Some");
assert_eq!(
tls_cfg.ca_cert_path.as_ref().map(|p| p.to_str()),
Some(Some("/path/to/ca.pem"))
);
assert_eq!(
tls_cfg.client_cert_path.as_ref().map(|p| p.to_str()),
Some(Some("/path/to/client.pem"))
);
assert_eq!(
tls_cfg.client_key_path.as_ref().map(|p| p.to_str()),
Some(Some("/path/to/client.key"))
);
}
}