use anyhow::Result;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::sync::Notify;
use tokio_postgres::{connect, Client, NoTls};
use tracing::{debug, error, warn};
use crate::config::PoolConfig;
use crate::errors::{MCPError, Result as MCPResult};
/// A bounded pool of `tokio_postgres` clients shared as `Arc<Client>`.
///
/// Callers check connections out with `acquire` and hand them back with `release`.
pub struct ConnectionPool {
    /// Number of connections the pool currently accounts for: idle, checked out, or in the
    /// middle of being opened. Compared against `config.max_size` before opening more.
    active_connections: AtomicU32,
    /// Live connections waiting to be checked out (unbounded multi-producer/multi-consumer queue).
    idle_connections: crossbeam::queue::SegQueue<Arc<Client>>,
    /// Signalled on every release so tasks blocked in `acquire` can try again.
    notify: Notify,
    /// Sizing limits and the acquire wait timeout.
    config: PoolConfig,
    /// Connection string used when new connections have to be opened lazily.
    connection_string: String,
}
impl ConnectionPool {
pub async fn new(connection_string: &str, config: PoolConfig) -> Result<Self> {
debug!("Creating connection pool with config: {:?}", config);
let idle_queue = crossbeam::queue::SegQueue::new();
let mut created = 0u32;
for _ in 0..config.min_size {
match connect(connection_string, NoTls).await {
Ok((client, connection)) => {
tokio::spawn(async move {
if let Err(e) = connection.await {
error!("Connection error: {}", e);
}
});
idle_queue.push(Arc::new(client));
created += 1;
}
Err(e) => {
warn!("Failed to create initial connection: {}", e);
}
}
}
if created == 0 {
return Err(anyhow::anyhow!(
"Failed to establish any database connection. Check DATABASE_URL and ensure PostgreSQL is running."
));
}
Ok(Self {
config,
connection_string: connection_string.to_string(),
idle_connections: idle_queue,
active_connections: AtomicU32::new(created),
notify: Notify::new(),
})
}
pub async fn acquire(&self) -> MCPResult<Arc<Client>> {
loop {
if let Some(conn) = self.idle_connections.pop() {
if is_connection_alive(&conn) {
return Ok(conn);
}
self.active_connections.fetch_sub(1, Ordering::Relaxed);
continue;
}
let prev = self.active_connections.fetch_add(1, Ordering::Relaxed);
if prev < self.config.max_size {
match connect(&self.connection_string, NoTls).await {
Ok((client, connection)) => {
tokio::spawn(async move {
if let Err(e) = connection.await {
error!("Lazy connection error: {}", e);
}
});
return Ok(Arc::new(client));
}
Err(e) => {
error!("Failed to create lazy connection: {}", e);
self.active_connections.fetch_sub(1, Ordering::Relaxed);
continue;
}
}
} else {
self.active_connections.fetch_sub(1, Ordering::Relaxed);
tokio::time::timeout(self.config.queue_timeout, self.notify.notified())
.await
.map_err(|_| MCPError::PoolError("Connection pool exhausted".into()))?;
}
}
}
pub fn release(&self, conn: Arc<Client>) {
if is_connection_alive(&conn) {
self.idle_connections.push(conn);
} else {
self.active_connections.fetch_sub(1, Ordering::Relaxed);
}
self.notify.notify_one();
debug!("Connection released back to pool");
}
pub fn active_count(&self) -> u32 {
self.active_connections.load(Ordering::Relaxed)
}
pub fn max_size(&self) -> u32 {
self.config.max_size
}
}
/// Liveness check applied before a connection is handed out or put back on the idle queue.
///
/// A client counts as usable as long as it has not been closed.
fn is_connection_alive(conn: &Client) -> bool {
    let closed = conn.is_closed();
    !closed
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A hand-built `PoolConfig` keeps its upper bound at or above its lower bound.
    #[test]
    fn test_config() {
        let queue_timeout = Duration::from_secs(10);
        let (min_size, max_size) = (2, 10);
        let cfg = PoolConfig {
            min_size,
            max_size,
            queue_timeout,
        };
        assert!(cfg.max_size >= cfg.min_size);
    }
}