use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::net::TcpListener;
use tokio::sync::Semaphore;
use tracing::{error, info, warn};
use super::session::Session;
use crate::{EmbeddedDatabase, Error};
/// Number of concurrent sessions a [`PgServer`] admits when built with [`PgServer::new`].
const DEFAULT_MAX_CONNECTIONS: usize = 256;

/// A TCP server that speaks the PostgreSQL wire protocol on top of an [`EmbeddedDatabase`].
///
/// Every accepted connection is driven by a [`Session`] on its own Tokio task, subject to a
/// cap on how many sessions may be active at once.
pub struct PgServer {
    /// Address handed to `TcpListener::bind`, e.g. `"127.0.0.1:5432"`.
    address: String,
    /// Database shared by every session.
    db: Arc<EmbeddedDatabase>,
    /// Counter from which each new session draws its id.
    next_session_id: Arc<AtomicU32>,
    /// Holds one permit per session slot; an accepted connection must take one to proceed.
    connection_limiter: Arc<Semaphore>,
    /// The limit the semaphore was created with (used in log messages).
    max_connections: usize,
    /// Idle timeout, in seconds, configured on each new session.
    idle_timeout_secs: u64,
}
impl PgServer {
pub fn new(address: impl Into<String>, db: Arc<EmbeddedDatabase>) -> Self {
Self {
address: address.into(),
db,
next_session_id: Arc::new(AtomicU32::new(1)),
connection_limiter: Arc::new(Semaphore::new(DEFAULT_MAX_CONNECTIONS)),
max_connections: DEFAULT_MAX_CONNECTIONS,
idle_timeout_secs: 300, }
}
pub fn with_max_connections(address: impl Into<String>, db: Arc<EmbeddedDatabase>, max_connections: usize) -> Self {
Self {
address: address.into(),
db,
next_session_id: Arc::new(AtomicU32::new(1)),
connection_limiter: Arc::new(Semaphore::new(max_connections)),
max_connections,
idle_timeout_secs: 300,
}
}
pub fn with_idle_timeout(mut self, secs: u64) -> Self {
self.idle_timeout_secs = secs;
self
}
pub async fn run(self) -> Result<(), Error> {
let listener = TcpListener::bind(&self.address)
.await
.map_err(|e| Error::protocol(format!("Failed to bind to {}: {}", self.address, e)))?;
info!("HeliosDB PostgreSQL server listening on {} (max_connections: {})", self.address, self.max_connections);
let parts: Vec<&str> = self.address.split(':').collect();
let host = parts.first().unwrap_or(&"localhost");
let port = parts.get(1).unwrap_or(&"5432");
info!("Connect with: psql -h {} -p {}", host, port);
loop {
let (stream, addr) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
error!("Failed to accept connection: {}", e);
continue;
}
};
let permit = match Arc::clone(&self.connection_limiter).try_acquire_owned() {
Ok(permit) => permit,
Err(_) => {
warn!("Connection limit reached ({}), rejecting {}", self.max_connections, addr);
drop(stream);
continue;
}
};
let session_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
info!("New connection from {} (session {})", addr, session_id);
let session = Session::new(Arc::clone(&self.db), session_id)
.with_idle_timeout(self.idle_timeout_secs);
tokio::spawn(async move {
let _permit = permit;
if let Err(e) = session.handle_connection(stream).await {
error!("Session {} error: {}", session_id, e);
}
info!("Session {} ended", session_id);
});
}
}
pub async fn run_with_shutdown<F>(self, shutdown: F) -> Result<(), Error>
where
F: std::future::Future<Output = ()> + Send + 'static,
{
let listener = TcpListener::bind(&self.address)
.await
.map_err(|e| Error::protocol(format!("Failed to bind to {}: {}", self.address, e)))?;
info!("HeliosDB PostgreSQL server listening on {} (max_connections: {})", self.address, self.max_connections);
let parts: Vec<&str> = self.address.split(':').collect();
let host = parts.first().unwrap_or(&"localhost");
let port = parts.get(1).unwrap_or(&"5432");
info!("Connect with: psql -h {} -p {}", host, port);
tokio::pin!(shutdown);
loop {
tokio::select! {
result = listener.accept() => {
let (stream, addr) = match result {
Ok(conn) => conn,
Err(e) => {
error!("Failed to accept connection: {}", e);
continue;
}
};
let permit = match Arc::clone(&self.connection_limiter).try_acquire_owned() {
Ok(permit) => permit,
Err(_) => {
warn!("Connection limit reached ({}), rejecting {}", self.max_connections, addr);
drop(stream);
continue;
}
};
let session_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
info!("New connection from {} (session {})", addr, session_id);
let session = Session::new(Arc::clone(&self.db), session_id)
.with_idle_timeout(self.idle_timeout_secs);
tokio::spawn(async move {
let _permit = permit;
if let Err(e) = session.handle_connection(stream).await {
error!("Session {} error: {}", session_id, e);
}
info!("Session {} ended", session_id);
});
}
() = &mut shutdown => {
info!("Shutdown signal received, stopping server");
break;
}
}
}
info!("Server stopped");
Ok(())
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Fresh in-memory database for a test.
    fn test_db() -> Arc<EmbeddedDatabase> {
        Arc::new(EmbeddedDatabase::new_in_memory().unwrap())
    }

    #[tokio::test]
    async fn test_server_creation() {
        let server = PgServer::new("127.0.0.1:0", test_db());
        assert_eq!(server.address, "127.0.0.1:0");
        assert_eq!(server.max_connections, DEFAULT_MAX_CONNECTIONS);
        assert_eq!(
            server.connection_limiter.available_permits(),
            DEFAULT_MAX_CONNECTIONS
        );
        assert_eq!(server.idle_timeout_secs, 300);
    }

    #[tokio::test]
    async fn test_custom_limits() {
        let server =
            PgServer::with_max_connections("127.0.0.1:0", test_db(), 8).with_idle_timeout(42);
        assert_eq!(server.max_connections, 8);
        assert_eq!(server.connection_limiter.available_permits(), 8);
        assert_eq!(server.idle_timeout_secs, 42);
    }

    #[tokio::test]
    async fn test_shutdown_stops_server() {
        let server = PgServer::new("127.0.0.1:0", test_db());
        // An already-completed shutdown future makes the accept loop exit on its first pass.
        server.run_with_shutdown(async {}).await.unwrap();
    }

    #[tokio::test]
    async fn test_bind_failure_is_reported() {
        // No port, so address resolution fails immediately without any DNS lookup.
        let server = PgServer::new("not-an-address", test_db());
        assert!(server.run_with_shutdown(async {}).await.is_err());
    }
}