use std::collections::HashMap;
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::{RwLock, Semaphore};
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
use crate::config::Config;
use crate::database::Storage;
use crate::protocol::Connection;
use crate::protocol::shared_catalog::SharedCatalog;
/// Point-in-time summary of the connection manager's counters.
///
/// Values of this type are produced by `ConnectionManager::get_stats`.
#[derive(Debug, Clone)]
pub struct ConnectionStats {
    /// Number of connection ids handed out so far.
    pub total_connections: usize,
    /// Connections currently being served.
    pub active_connections: usize,
    /// Connections whose handler returned an error.
    pub failed_connections: usize,
    /// Subset of failed connections whose error text mentioned "timeout".
    pub timeout_connections: usize,
    /// Mean age of the connections that were still tracked when the
    /// snapshot was taken (zero when none were tracked).
    pub avg_connection_duration: Duration,
}
/// Bookkeeping record kept for every connection that is currently tracked.
#[derive(Debug)]
struct ConnectionInfo {
    /// Peer address as supplied by the caller of `handle_connection`.
    pub client_addr: String,
    /// When the connection was registered with the manager.
    pub started_at: Instant,
    /// Last time the manager recorded activity for this connection.
    pub last_activity: Instant,
}
/// Admits, tracks and supervises client connections.
///
/// Concurrency is bounded by a semaphore, every live connection is recorded
/// in a shared map, and a handful of atomic counters feed `get_stats`.
pub struct ConnectionManager {
    /// Server configuration (protocol, connection limit, timeout).
    config: Arc<Config>,
    /// Storage backend handed to every `Connection`.
    storage: Arc<Storage>,
    /// Currently tracked connections, keyed by connection id.
    connections: Arc<RwLock<HashMap<usize, ConnectionInfo>>>,
    /// Source of connection ids; also reported as the total connection count.
    connection_counter: AtomicUsize,
    /// Number of connections currently inside `handle_connection`.
    active_connections: AtomicUsize,
    /// Number of connections whose handler returned an error.
    failed_connections: AtomicUsize,
    /// Number of failed connections whose error text contained "timeout".
    timeout_connections: AtomicUsize,
    /// Limits how many connections may be served at the same time.
    connection_semaphore: Arc<Semaphore>,
    /// Catalog shared between connections; only populated for the Postgres
    /// protocol.
    shared_catalog: Option<SharedCatalog>,
}
impl Clone for ConnectionManager {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
storage: self.storage.clone(),
connections: self.connections.clone(),
connection_counter: AtomicUsize::new(self.connection_counter.load(Ordering::SeqCst)),
active_connections: AtomicUsize::new(self.active_connections.load(Ordering::SeqCst)),
failed_connections: AtomicUsize::new(self.failed_connections.load(Ordering::SeqCst)),
timeout_connections: AtomicUsize::new(self.timeout_connections.load(Ordering::SeqCst)),
connection_semaphore: self.connection_semaphore.clone(),
shared_catalog: self.shared_catalog.clone(),
}
}
}
impl ConnectionManager {
/// Builds a manager with empty tracking state and zeroed counters.
///
/// When the configured protocol is Postgres a shared catalog is created
/// up front; any error from that step is returned to the caller. The
/// concurrency limit comes from `config.max_connections` and falls back
/// to 1000 when unset.
pub async fn new(config: Arc<Config>, storage: Arc<Storage>) -> crate::Result<Self> {
    let permit_limit = config.max_connections.unwrap_or(1000);

    let shared_catalog = match config.protocol {
        crate::config::Protocol::Postgres => {
            let catalog =
                crate::protocol::shared_catalog::create_shared_catalog(Arc::clone(&storage))
                    .await?;
            Some(catalog)
        }
        _ => None,
    };

    let manager = Self {
        config,
        storage,
        connections: Arc::new(RwLock::new(HashMap::new())),
        connection_counter: AtomicUsize::new(0),
        active_connections: AtomicUsize::new(0),
        failed_connections: AtomicUsize::new(0),
        timeout_connections: AtomicUsize::new(0),
        connection_semaphore: Arc::new(Semaphore::new(permit_limit)),
        shared_catalog,
    };
    Ok(manager)
}
/// Serves a single client socket from admission to teardown.
///
/// Steps, in order:
/// 1. wait up to 30 seconds for a semaphore permit (error out otherwise);
/// 2. apply socket options, logging a warning if that fails;
/// 3. allocate a connection id, register the connection and bump the
///    active counter;
/// 4. run the protocol handler;
/// 5. unregister, decrement the active counter and release the permit;
/// 6. log the outcome and update the failure/timeout counters.
///
/// A failure is counted as a timeout when its rendered message contains
/// the substring "timeout".
///
/// # Errors
/// Returns a `Database` error when no permit could be obtained, otherwise
/// whatever the protocol handler returned.
pub async fn handle_connection(
    &self,
    mut stream: TcpStream,
    client_addr: String,
) -> crate::Result<()> {
    let admission = timeout(Duration::from_secs(30), self.connection_semaphore.acquire()).await;
    let permit = match admission {
        // The 30 second wait elapsed without a free slot.
        Err(_) => {
            error!("Timeout acquiring connection permit for {}", client_addr);
            return Err(crate::YamlBaseError::Database {
                message: "Connection pool timeout".to_string(),
            });
        }
        // `acquire` itself reported an error.
        Ok(Err(_)) => {
            error!("Failed to acquire connection permit");
            return Err(crate::YamlBaseError::Database {
                message: "Connection pool exhausted".to_string(),
            });
        }
        Ok(Ok(granted)) => granted,
    };

    // Socket tuning is best effort: a failure is logged, not propagated.
    if let Err(e) = self.configure_tcp_socket(&mut stream).await {
        warn!("Failed to configure TCP socket options: {}", e);
    }

    let connection_id = self.connection_counter.fetch_add(1, Ordering::SeqCst);
    let started = Instant::now();
    self.connections.write().await.insert(
        connection_id,
        ConnectionInfo {
            client_addr: client_addr.clone(),
            started_at: started,
            last_activity: started,
        },
    );
    self.active_connections.fetch_add(1, Ordering::SeqCst);
    info!(
        "Connection {} from {} established",
        connection_id, client_addr
    );

    let outcome = self
        .handle_connection_with_recovery(stream, connection_id, client_addr)
        .await;

    // Tear down bookkeeping before reporting so the slot is freed promptly.
    self.connections.write().await.remove(&connection_id);
    self.active_connections.fetch_sub(1, Ordering::SeqCst);
    drop(permit);

    if let Err(e) = &outcome {
        self.failed_connections.fetch_add(1, Ordering::SeqCst);
        if e.to_string().contains("timeout") {
            self.timeout_connections.fetch_add(1, Ordering::SeqCst);
        }
        error!("Connection {} failed: {}", connection_id, e);
    } else {
        let duration = started.elapsed();
        info!(
            "Connection {} closed normally after {:?}",
            connection_id, duration
        );
    }

    outcome
}
async fn configure_tcp_socket(&self, stream: &mut TcpStream) -> crate::Result<()> {
#[cfg(unix)]
{
use std::mem::size_of;
use std::os::unix::io::AsRawFd;
let fd = stream.as_raw_fd();
let nodelay: libc::c_int = 1;
unsafe {
if libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_NODELAY,
&nodelay as *const _ as *const libc::c_void,
size_of::<libc::c_int>() as libc::socklen_t,
) != 0
{
return Err(crate::YamlBaseError::Database {
message: "Failed to set TCP_NODELAY".to_string(),
});
}
}
let keepalive: libc::c_int = 1;
unsafe {
if libc::setsockopt(
fd,
libc::SOL_SOCKET,
libc::SO_KEEPALIVE,
&keepalive as *const _ as *const libc::c_void,
size_of::<libc::c_int>() as libc::socklen_t,
) != 0
{
return Err(crate::YamlBaseError::Database {
message: "Failed to set SO_KEEPALIVE".to_string(),
});
}
}
}
#[cfg(windows)]
{
}
#[cfg(target_os = "linux")]
{
use std::mem::size_of;
use std::os::unix::io::AsRawFd;
let fd = stream.as_raw_fd();
let keepalive_time: libc::c_int = 60;
unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_KEEPIDLE,
&keepalive_time as *const _ as *const libc::c_void,
size_of::<libc::c_int>() as libc::socklen_t,
);
}
let keepalive_interval: libc::c_int = 10;
unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_KEEPINTVL,
&keepalive_interval as *const _ as *const libc::c_void,
size_of::<libc::c_int>() as libc::socklen_t,
);
}
let keepalive_count: libc::c_int = 6;
unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_KEEPCNT,
&keepalive_count as *const _ as *const libc::c_void,
size_of::<libc::c_int>() as libc::socklen_t,
);
}
}
debug!("TCP socket configured with stability options");
Ok(())
}
/// Runs the protocol handler for one connection under an overall deadline.
///
/// The deadline is `config.connection_timeout`, defaulting to 30 seconds,
/// and covers the whole call to `Connection::handle`. The connection's
/// activity timestamp is refreshed once, just before the handler starts.
///
/// # Errors
/// Propagates the handler's error, or returns a `Database` error whose
/// message starts with "Connection timeout after" when the deadline hits.
async fn handle_connection_with_recovery(
    &self,
    stream: TcpStream,
    connection_id: usize,
    client_addr: String,
) -> crate::Result<()> {
    let deadline = self
        .config
        .connection_timeout
        .unwrap_or(Duration::from_secs(30));

    // Postgres connections share one catalog; other protocols build a plain one.
    let connection = match &self.shared_catalog {
        Some(catalog) => Connection::new_with_shared_catalog_ref(
            self.config.clone(),
            self.storage.clone(),
            catalog.clone(),
        ),
        None => Connection::new(self.config.clone(), self.storage.clone()),
    };

    let session = async {
        self.update_connection_activity(connection_id).await;
        connection.handle(stream).await
    };

    timeout(deadline, session).await.unwrap_or_else(|_| {
        warn!(
            "Connection {} from {} timed out after {:?}",
            connection_id, client_addr, deadline
        );
        Err(crate::YamlBaseError::Database {
            message: format!("Connection timeout after {:?}", deadline),
        })
    })
}
async fn update_connection_activity(&self, connection_id: usize) {
let mut connections = self.connections.write().await;
if let Some(conn_info) = connections.get_mut(&connection_id) {
conn_info.last_activity = Instant::now();
}
}
pub async fn get_stats(&self) -> ConnectionStats {
let connections = self.connections.read().await;
let active = self.active_connections.load(Ordering::SeqCst);
let total = self.connection_counter.load(Ordering::SeqCst);
let failed = self.failed_connections.load(Ordering::SeqCst);
let timeouts = self.timeout_connections.load(Ordering::SeqCst);
let now = Instant::now();
let total_duration: Duration = connections
.values()
.map(|conn| now.duration_since(conn.started_at))
.sum();
let avg_duration = if !connections.is_empty() {
total_duration / connections.len() as u32
} else {
Duration::from_secs(0)
};
ConnectionStats {
total_connections: total,
active_connections: active,
failed_connections: failed,
timeout_connections: timeouts,
avg_connection_duration: avg_duration,
}
}
/// Drops tracking entries whose last recorded activity is more than
/// 30 minutes old.
///
/// The staleness check and the removal happen under a single write lock,
/// so an entry that is refreshed concurrently cannot be evicted based on
/// an outdated observation.
///
/// Only the bookkeeping entry is removed here; this method does not close
/// the underlying socket or touch the active-connection counter.
pub async fn cleanup_stale_connections(&self) {
    let idle_timeout = Duration::from_secs(1800);

    let mut connections = self.connections.write().await;
    // Sample the clock after the lock is held so waiting for the lock does
    // not skew the idle computation.
    let now = Instant::now();

    connections.retain(|id, conn_info| {
        let idle_for = now.duration_since(conn_info.last_activity);
        if idle_for <= idle_timeout {
            return true;
        }
        warn!(
            "Connection {} from {} is idle for {:?}, marking for cleanup",
            id, conn_info.client_addr, idle_for
        );
        info!("Cleaned up stale connection {}", id);
        false
    });
}
/// Spawns a background task that periodically reports on tracked
/// connections and returns its join handle.
///
/// Every 60 seconds the task logs how many connections are tracked (plus
/// one debug line per connection). Every 300 seconds it logs a cleanup
/// notice; no entries are actually removed by this task. tokio's
/// `interval` completes its first tick immediately, so both branches fire
/// once right after the task starts.
///
/// The task loops forever; abort the returned handle to stop it.
pub fn start_monitoring(&self) -> tokio::task::JoinHandle<()> {
    // The map is already behind an `Arc`; share it directly instead of
    // wrapping it in a second one.
    let tracked = Arc::clone(&self.connections);
    let stats_interval = Duration::from_secs(60);
    let cleanup_interval = Duration::from_secs(300);

    tokio::spawn(async move {
        let mut stats_timer = tokio::time::interval(stats_interval);
        let mut cleanup_timer = tokio::time::interval(cleanup_interval);
        loop {
            tokio::select! {
                _ = stats_timer.tick() => {
                    let connections = tracked.read().await;
                    info!("Connection pool status: {} active connections", connections.len());
                    // One timestamp for the whole report keeps the ages comparable.
                    let now = Instant::now();
                    for (id, conn_info) in connections.iter() {
                        debug!("Connection {}: {} (active for {:?})",
                            id, conn_info.client_addr,
                            now.duration_since(conn_info.started_at));
                    }
                }
                _ = cleanup_timer.tick() => {
                    debug!("Connection cleanup cycle (would cleanup stale connections)");
                }
            }
        }
    })
}
}