use crate::{Result, Error, EmbeddedDatabase};
use super::handler::PgConnectionHandler;
use super::auth::{AuthManager, AuthMethod};
use super::ssl::{SslConfig, SslNegotiator, SslMode, SecureConnection};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Semaphore;
use std::sync::Arc;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// Standard PostgreSQL TCP port.
const DEFAULT_PG_PORT: u16 = 5432;

/// Default listen address: the unspecified IPv4 address (`0.0.0.0`, i.e. all
/// IPv4 interfaces) on [`DEFAULT_PG_PORT`].
///
/// This is not a loopback address, so pairing it with `AuthMethod::Trust`
/// is rejected by `PgServer::new`.
const DEFAULT_PG_ADDRESS: SocketAddr =
    SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), DEFAULT_PG_PORT);
/// Listener settings for a [`PgServer`].
#[derive(Debug, Clone)]
pub struct PgServerConfig {
    /// Socket address the server binds to in `PgServer::serve`.
    pub address: SocketAddr,
    /// Authentication method used to build the default `AuthManager` in `PgServer::new`.
    /// When a custom manager is passed to `PgServer::with_auth_manager`, the
    /// trust/loopback check uses that manager's method instead of this field.
    pub auth_method: AuthMethod,
    /// Maximum number of concurrently served connections; sockets accepted
    /// beyond this limit are closed immediately.
    pub max_connections: usize,
    /// TLS settings. `None` means no negotiator is built and incoming
    /// SSLRequests are answered with `N`.
    pub ssl_config: Option<SslConfig>,
}
impl Default for PgServerConfig {
    /// Returns a config that has:
    /// - address `DEFAULT_PG_ADDRESS` (`0.0.0.0:5432`);
    /// - `AuthMethod::Trust`;
    /// - a limit of 100 concurrent connections;
    /// - TLS disabled.
    ///
    /// `PgServer::new` rejects `Trust` on non-loopback addresses. Before building
    /// a server from this default, change either the address or the auth method.
    fn default() -> Self {
        let max_connections = 100;
        Self {
            ssl_config: None,
            max_connections,
            auth_method: AuthMethod::Trust,
            address: DEFAULT_PG_ADDRESS,
        }
    }
}
impl PgServerConfig {
    /// Creates a config bound to `address`; every other field takes its
    /// `Default` value.
    pub fn with_address(address: SocketAddr) -> Self {
        let mut config = Self::default();
        config.address = address;
        config
    }

    /// Replaces the authentication method.
    pub fn with_auth_method(self, method: AuthMethod) -> Self {
        Self {
            auth_method: method,
            ..self
        }
    }

    /// Replaces the maximum number of concurrent connections.
    pub fn with_max_connections(self, max: usize) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }

    /// Enables TLS with the given settings.
    pub fn with_ssl(self, ssl_config: SslConfig) -> Self {
        Self {
            ssl_config: Some(ssl_config),
            ..self
        }
    }

    /// Enables TLS in `SslMode::Allow` using the development certificate pair
    /// `certs/server.crt` / `certs/server.key`. These are relative paths, presumably
    /// resolved against the working directory; confirm in `SslConfig`.
    ///
    /// The current implementation never returns `Err`.
    pub fn with_ssl_test(self) -> Result<Self> {
        let test_ssl = SslConfig::new(SslMode::Allow, "certs/server.crt", "certs/server.key");
        Ok(self.with_ssl(test_ssl))
    }
}
/// PostgreSQL wire-protocol server backed by an [`EmbeddedDatabase`].
pub struct PgServer {
    /// Settings the server was constructed with.
    config: PgServerConfig,
    /// Database handle cloned into every connection task.
    database: Arc<EmbeddedDatabase>,
    /// Authenticator cloned into every connection task.
    auth_manager: Arc<AuthManager>,
    /// TLS negotiator; `None` exactly when `config.ssl_config` was `None`.
    ssl_negotiator: Option<Arc<SslNegotiator>>,
    /// Holds `config.max_connections` permits; each live connection owns one.
    connection_limiter: Arc<Semaphore>,
}
impl PgServer {
fn enforce_trust_loopback_only(config: &PgServerConfig) -> Result<()> {
if matches!(config.auth_method, AuthMethod::Trust) && !config.address.ip().is_loopback() {
return Err(Error::authentication(format!(
"AuthMethod::Trust is only permitted on loopback (127.0.0.1, ::1) listeners; \
binding to {} requires a non-trust auth method (password, scram-sha-256). \
To start anyway on a non-loopback address, switch the auth method or bind to 127.0.0.1.",
config.address
)));
}
Ok(())
}
pub fn new(config: PgServerConfig, database: Arc<EmbeddedDatabase>) -> Result<Self> {
Self::enforce_trust_loopback_only(&config)?;
let auth_manager = Arc::new(
AuthManager::new(config.auth_method)
.with_default_users()
);
let ssl_negotiator = if let Some(ref ssl_config) = config.ssl_config {
Some(Arc::new(SslNegotiator::new(ssl_config.clone())?))
} else {
None
};
let connection_limiter = Arc::new(Semaphore::new(config.max_connections));
Ok(Self {
config,
database,
auth_manager,
ssl_negotiator,
connection_limiter,
})
}
pub fn with_auth_manager(
config: PgServerConfig,
database: Arc<EmbeddedDatabase>,
auth_manager: AuthManager,
) -> Result<Self> {
let effective_method = auth_manager.method();
if matches!(effective_method, AuthMethod::Trust) && !config.address.ip().is_loopback() {
return Err(Error::authentication(format!(
"AuthMethod::Trust is only permitted on loopback (127.0.0.1, ::1) listeners; \
binding to {} requires a non-trust auth method (password, scram-sha-256).",
config.address
)));
}
let ssl_negotiator = if let Some(ref ssl_config) = config.ssl_config {
Some(Arc::new(SslNegotiator::new(ssl_config.clone())?))
} else {
None
};
let connection_limiter = Arc::new(Semaphore::new(config.max_connections));
Ok(Self {
config,
database,
auth_manager: Arc::new(auth_manager),
ssl_negotiator,
connection_limiter,
})
}
pub async fn serve(&self) -> Result<()> {
let listener = TcpListener::bind(self.config.address).await
.map_err(|e| Error::network(format!("Failed to bind to {}: {}", self.config.address, e)))?;
let ssl_enabled = self.ssl_negotiator.is_some();
tracing::info!(
"PostgreSQL server listening on {} (auth: {:?}, ssl: {})",
self.config.address,
self.config.auth_method,
if ssl_enabled { "enabled" } else { "disabled" }
);
loop {
match listener.accept().await {
Ok((stream, addr)) => {
if let Err(e) = stream.set_nodelay(true) {
tracing::warn!("Failed to set TCP_NODELAY for {}: {}", addr, e);
}
let permit = match Arc::clone(&self.connection_limiter).try_acquire_owned() {
Ok(permit) => permit,
Err(_) => {
tracing::warn!("Connection limit reached ({}), rejecting {}", self.config.max_connections, addr);
drop(stream);
continue;
}
};
tracing::debug!("Accepted connection from {}", addr);
let database = Arc::clone(&self.database);
let auth_manager = Arc::clone(&self.auth_manager);
let ssl_negotiator = self.ssl_negotiator.clone();
tokio::spawn(async move {
let _permit = permit;
if let Err(e) = Self::handle_connection(stream, database, auth_manager, ssl_negotiator).await {
tracing::error!("Connection error from {}: {}", addr, e);
}
});
}
Err(e) => {
tracing::error!("Failed to accept connection: {}", e);
}
}
}
}
async fn handle_connection(
mut stream: TcpStream,
database: Arc<EmbeddedDatabase>,
auth_manager: Arc<AuthManager>,
ssl_negotiator: Option<Arc<SslNegotiator>>,
) -> Result<()> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut len_buf = [0u8; 4];
stream.read_exact(&mut len_buf).await
.map_err(|e| Error::network(format!("Failed to read message length: {}", e)))?;
let mut code_buf = [0u8; 4];
stream.read_exact(&mut code_buf).await
.map_err(|e| Error::network(format!("Failed to read request code: {}", e)))?;
let code = i32::from_be_bytes(code_buf);
let is_ssl_request = code == super::ssl::SSL_REQUEST_CODE;
if let Some(negotiator) = ssl_negotiator {
if is_ssl_request {
let ssl_accepted = negotiator.negotiate(&mut stream, true).await?;
if ssl_accepted {
if let Some(acceptor) = negotiator.acceptor() {
tracing::debug!("Upgrading connection to TLS");
let tls_stream = acceptor.accept(stream).await
.map_err(|e| Error::network(format!("TLS handshake failed: {}", e)))?;
let secure_conn = SecureConnection::Tls(tls_stream);
let mut handler = PgConnectionHandler::new_with_stream(
secure_conn,
database,
auth_manager,
None );
return handler.handle().await;
}
} else if negotiator.is_required() {
return Err(Error::network("SSL is required but was rejected"));
}
} else if negotiator.is_required() {
return Err(Error::network("SSL is required but no SSL request was received"));
}
} else if is_ssl_request {
tracing::debug!("SSL request received but SSL is not configured, sending rejection");
stream.write_all(b"N").await
.map_err(|e| Error::network(format!("Failed to send SSL rejection: {}", e)))?;
stream.flush().await
.map_err(|e| Error::network(format!("Failed to flush stream: {}", e)))?;
let secure_conn = SecureConnection::Plain(stream);
let mut handler = PgConnectionHandler::new_with_stream(
secure_conn,
database,
auth_manager,
None
);
return handler.handle().await;
}
let mut initial_data = Vec::with_capacity(8);
initial_data.extend_from_slice(&len_buf);
initial_data.extend_from_slice(&code_buf);
let secure_conn = SecureConnection::Plain(stream);
let mut handler = PgConnectionHandler::new_with_stream(
secure_conn,
database,
auth_manager,
Some(&initial_data)
);
handler.handle().await
}
pub fn config(&self) -> &PgServerConfig {
&self.config
}
}
/// Fluent builder for [`PgServer`].
pub struct PgServerBuilder {
    /// Configuration accumulated by the setter methods.
    config: PgServerConfig,
    /// Optional custom authenticator. When set, `build` goes through
    /// `PgServer::with_auth_manager` instead of `PgServer::new`.
    auth_manager: Option<AuthManager>,
}
impl PgServerBuilder {
    /// Starts from `PgServerConfig::default()` with no custom auth manager.
    pub fn new() -> Self {
        Self {
            auth_manager: None,
            config: PgServerConfig::default(),
        }
    }

    /// Sets the address the server will bind to.
    pub fn address(mut self, addr: SocketAddr) -> Self {
        self.config.address = addr;
        self
    }

    /// Sets the auth method used when no custom auth manager is supplied.
    pub fn auth_method(mut self, method: AuthMethod) -> Self {
        self.config.auth_method = method;
        self
    }

    /// Sets the maximum number of concurrent connections.
    pub fn max_connections(mut self, max: usize) -> Self {
        self.config.max_connections = max;
        self
    }

    /// Supplies a custom authenticator; `build` then uses
    /// `PgServer::with_auth_manager`.
    pub fn auth_manager(mut self, manager: AuthManager) -> Self {
        self.auth_manager = Some(manager);
        self
    }

    /// Enables TLS with the given settings.
    pub fn ssl_config(mut self, ssl_config: SslConfig) -> Self {
        self.config.ssl_config = Some(ssl_config);
        self
    }

    /// Enables TLS in `SslMode::Allow` with the development certificate pair
    /// `certs/server.crt` / `certs/server.key`.
    pub fn ssl_test(self) -> Self {
        self.ssl_config(SslConfig::new(
            SslMode::Allow,
            "certs/server.crt",
            "certs/server.key",
        ))
    }

    /// Consumes the builder and constructs the server.
    ///
    /// # Errors
    ///
    /// Propagates errors from `PgServer::with_auth_manager` or `PgServer::new`.
    pub fn build(self, database: Arc<EmbeddedDatabase>) -> Result<PgServer> {
        let Self { config, auth_manager } = self;
        match auth_manager {
            Some(manager) => PgServer::with_auth_manager(config, database, manager),
            None => PgServer::new(config, database),
        }
    }
}
impl Default for PgServerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
// Panicking on unexpected values is the intended failure mode inside tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    fn memory_db() -> Arc<EmbeddedDatabase> {
        Arc::new(EmbeddedDatabase::new_in_memory().unwrap())
    }

    #[test]
    fn test_config_default() {
        let config = PgServerConfig::default();
        assert_eq!(config.address.port(), 5432);
        assert!(config.address.ip().is_unspecified());
        assert_eq!(config.auth_method, AuthMethod::Trust);
        assert_eq!(config.max_connections, 100);
        assert!(config.ssl_config.is_none());
    }

    #[test]
    fn test_config_builder() {
        let addr: SocketAddr = "127.0.0.1:15432".parse().unwrap();
        let config = PgServerConfig::with_address(addr)
            .with_auth_method(AuthMethod::CleartextPassword)
            .with_max_connections(50);
        assert_eq!(config.address, addr);
        assert_eq!(config.auth_method, AuthMethod::CleartextPassword);
        assert_eq!(config.max_connections, 50);
    }

    #[test]
    fn test_server_builder() {
        let addr: SocketAddr = "127.0.0.1:15432".parse().unwrap();
        let server = PgServerBuilder::new()
            .address(addr)
            .auth_method(AuthMethod::Trust)
            .max_connections(25)
            .build(memory_db())
            .unwrap();
        assert_eq!(server.config().address, addr);
        assert_eq!(server.config().max_connections, 25);
    }

    #[test]
    fn test_ssl_config() {
        let config = PgServerConfig::default();
        assert!(config.ssl_config.is_none());
        let ssl_config = SslConfig::new(SslMode::Require, "cert.pem", "key.pem");
        let config_with_ssl = PgServerConfig::default().with_ssl(ssl_config);
        assert!(config_with_ssl.ssl_config.is_some());
    }

    #[test]
    fn test_trust_rejected_on_non_loopback() {
        // The default config is 0.0.0.0 + Trust, which must be refused.
        assert!(PgServer::new(PgServerConfig::default(), memory_db()).is_err());
        assert!(PgServerBuilder::new().build(memory_db()).is_err());
    }

    #[test]
    fn test_trust_allowed_on_ipv6_loopback() {
        let addr: SocketAddr = "[::1]:15432".parse().unwrap();
        let config = PgServerConfig::with_address(addr).with_auth_method(AuthMethod::Trust);
        assert!(PgServer::new(config, memory_db()).is_ok());
    }

    #[test]
    fn test_non_trust_allowed_on_non_loopback() {
        let config = PgServerConfig::default().with_auth_method(AuthMethod::CleartextPassword);
        assert!(PgServer::new(config, memory_db()).is_ok());
    }

    #[test]
    fn test_custom_manager_method_governs_loopback_check() {
        // A Trust manager on 0.0.0.0 is refused even if the config says password.
        let config = PgServerConfig::default().with_auth_method(AuthMethod::CleartextPassword);
        let trust_manager = AuthManager::new(AuthMethod::Trust).with_default_users();
        assert!(PgServer::with_auth_manager(config, memory_db(), trust_manager).is_err());

        // A password manager on 0.0.0.0 is accepted even if the config says Trust.
        let password_manager =
            AuthManager::new(AuthMethod::CleartextPassword).with_default_users();
        assert!(
            PgServer::with_auth_manager(PgServerConfig::default(), memory_db(), password_manager)
                .is_ok()
        );
    }
}