use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
/// Settings describing a PostgreSQL instance.
///
/// A value can be assembled with [`PostgresConfig::builder`]. The constraints
/// mentioned on the fields below are the ones checked by
/// [`PostgresConfig::validate`].
///
/// NOTE(review): the derived `Debug` and `Serialize` impls cover every field,
/// so `password` appears in their output — confirm that is intended before
/// logging or persisting a value of this type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PostgresConfig {
/// Version string; must be non-empty and consist only of ASCII digits and
/// `.`. It is interpolated into the path returned by `config_dir`.
pub version: String,
/// Database name; must be non-empty.
pub database_name: String,
/// User name; must be non-empty.
pub username: String,
/// Password; must be non-empty.
pub password: String,
/// Listen address specification; must be accepted by
/// `crate::validation::validate_listen_addresses`.
pub listen_addresses: String,
/// Port number; `0` is rejected.
pub port: u16,
/// Optional `max_connections` value (presumably the server's connection
/// limit). `validate` does not inspect it.
pub max_connections: Option<u32>,
/// Optional memory-size string; when present it must be accepted by
/// `crate::validation::parse_memory_size`.
pub shared_buffers: Option<String>,
/// Optional memory-size string; when present it must be accepted by
/// `crate::validation::parse_memory_size`.
pub effective_cache_size: Option<String>,
/// Optional memory-size string; when present it must be accepted by
/// `crate::validation::parse_memory_size`.
pub work_mem: Option<String>,
/// Optional memory-size string; when present it must be accepted by
/// `crate::validation::parse_memory_size`.
pub maintenance_work_mem: Option<String>,
/// Optional memory-size string; when present it must be accepted by
/// `crate::validation::parse_memory_size`.
pub wal_buffers: Option<String>,
/// Optional fraction; when present it must lie within `0.0..=1.0`.
pub checkpoint_completion_target: Option<f32>,
/// SSL flag (presumably toggles SSL on the server — confirm against the
/// code that consumes this struct).
pub ssl: bool,
/// Free-form key/value settings. `validate` does not inspect them.
pub extra_config: std::collections::HashMap<String, String>,
}
impl PostgresConfig {
pub fn builder() -> PostgresConfigBuilder {
PostgresConfigBuilder::default()
}
pub fn validate(&self) -> Result<()> {
if self.version.is_empty() {
return Err(Error::MissingConfig("version".to_string()));
}
if !self.version.chars().all(|c| c.is_ascii_digit() || c == '.') {
return Err(Error::InvalidVersion(self.version.clone()));
}
if self.database_name.is_empty() {
return Err(Error::MissingConfig("database_name".to_string()));
}
if self.username.is_empty() {
return Err(Error::MissingConfig("username".to_string()));
}
if self.password.is_empty() {
return Err(Error::MissingConfig("password".to_string()));
}
if self.port == 0 {
return Err(Error::invalid_config("port", self.port.to_string()));
}
if let Some(target) = self.checkpoint_completion_target
&& !(0.0..=1.0).contains(&target)
{
return Err(Error::invalid_config(
"checkpoint_completion_target",
target.to_string(),
));
}
if let Some(ref shared_buffers) = self.shared_buffers {
crate::validation::parse_memory_size(shared_buffers)?;
}
if let Some(ref work_mem) = self.work_mem {
crate::validation::parse_memory_size(work_mem)?;
}
if let Some(ref maintenance_work_mem) = self.maintenance_work_mem {
crate::validation::parse_memory_size(maintenance_work_mem)?;
}
if let Some(ref effective_cache_size) = self.effective_cache_size {
crate::validation::parse_memory_size(effective_cache_size)?;
}
if let Some(ref wal_buffers) = self.wal_buffers {
crate::validation::parse_memory_size(wal_buffers)?;
}
crate::validation::validate_listen_addresses(&self.listen_addresses)?;
Ok(())
}
pub fn config_dir(&self) -> String {
format!("/etc/postgresql/{}/main", self.version)
}
pub fn postgresql_conf_path(&self) -> String {
format!("{}/postgresql.conf", self.config_dir())
}
pub fn pg_hba_conf_path(&self) -> String {
format!("{}/pg_hba.conf", self.config_dir())
}
}
/// Step-by-step builder for [`PostgresConfig`].
///
/// The derived `Default` leaves every optional field as `None` and
/// `extra_config` empty. `build` fills in defaults for some unset fields and
/// fails for others, as noted per field.
#[derive(Debug, Default)]
pub struct PostgresConfigBuilder {
/// Version string; `build` falls back to `"15"` when unset.
version: Option<String>,
/// Database name; `build` fails when unset.
database_name: Option<String>,
/// User name; `build` fails when unset.
username: Option<String>,
/// Password; `build` fails when unset.
password: Option<String>,
/// Listen address specification; `build` falls back to `"0.0.0.0/0"` when unset.
listen_addresses: Option<String>,
/// Port number; `build` falls back to `5432` when unset.
port: Option<u16>,
/// Passed through to the config unchanged.
max_connections: Option<u32>,
/// Passed through to the config unchanged.
shared_buffers: Option<String>,
/// Passed through to the config unchanged.
effective_cache_size: Option<String>,
/// Passed through to the config unchanged.
work_mem: Option<String>,
/// Passed through to the config unchanged.
maintenance_work_mem: Option<String>,
/// Passed through to the config unchanged.
wal_buffers: Option<String>,
/// Passed through to the config unchanged.
checkpoint_completion_target: Option<f32>,
/// SSL flag; `build` falls back to `false` when unset.
ssl: Option<bool>,
/// Key/value pairs collected by `add_config`; passed through unchanged.
extra_config: std::collections::HashMap<String, String>,
}
impl PostgresConfigBuilder {
pub fn version(mut self, version: impl Into<String>) -> Self {
self.version = Some(version.into());
self
}
pub fn database_name(mut self, name: impl Into<String>) -> Self {
self.database_name = Some(name.into());
self
}
pub fn username(mut self, username: impl Into<String>) -> Self {
self.username = Some(username.into());
self
}
pub fn password(mut self, password: impl Into<String>) -> Self {
self.password = Some(password.into());
self
}
pub fn listen_addresses(mut self, addresses: impl Into<String>) -> Self {
self.listen_addresses = Some(addresses.into());
self
}
pub fn port(mut self, port: u16) -> Self {
self.port = Some(port);
self
}
pub fn max_connections(mut self, max: u32) -> Self {
self.max_connections = Some(max);
self
}
pub fn shared_buffers(mut self, size: impl Into<String>) -> Self {
self.shared_buffers = Some(size.into());
self
}
pub fn effective_cache_size(mut self, size: impl Into<String>) -> Self {
self.effective_cache_size = Some(size.into());
self
}
pub fn work_mem(mut self, size: impl Into<String>) -> Self {
self.work_mem = Some(size.into());
self
}
pub fn maintenance_work_mem(mut self, size: impl Into<String>) -> Self {
self.maintenance_work_mem = Some(size.into());
self
}
pub fn wal_buffers(mut self, size: impl Into<String>) -> Self {
self.wal_buffers = Some(size.into());
self
}
pub fn checkpoint_completion_target(mut self, target: f32) -> Self {
self.checkpoint_completion_target = Some(target);
self
}
pub fn ssl(mut self, enabled: bool) -> Self {
self.ssl = Some(enabled);
self
}
pub fn add_config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.extra_config.insert(key.into(), value.into());
self
}
pub fn build(self) -> Result<PostgresConfig> {
let config = PostgresConfig {
version: self.version.unwrap_or_else(|| "15".to_string()),
database_name: self
.database_name
.ok_or_else(|| Error::MissingConfig("database_name".to_string()))?,
username: self
.username
.ok_or_else(|| Error::MissingConfig("username".to_string()))?,
password: self
.password
.ok_or_else(|| Error::MissingConfig("password".to_string()))?,
listen_addresses: self
.listen_addresses
.unwrap_or_else(|| "0.0.0.0/0".to_string()),
port: self.port.unwrap_or(5432),
max_connections: self.max_connections,
shared_buffers: self.shared_buffers,
effective_cache_size: self.effective_cache_size,
work_mem: self.work_mem,
maintenance_work_mem: self.maintenance_work_mem,
wal_buffers: self.wal_buffers,
checkpoint_completion_target: self.checkpoint_completion_target,
ssl: self.ssl.unwrap_or(false),
extra_config: self.extra_config,
};
config.validate()?;
Ok(config)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder with the three mandatory fields already supplied.
    fn with_credentials(db: &str, user: &str, pass: &str) -> PostgresConfigBuilder {
        PostgresConfig::builder()
            .database_name(db)
            .username(user)
            .password(pass)
    }

    /// Only the mandatory fields are given; everything else takes its default.
    #[test]
    fn test_builder_minimal() {
        let built = with_credentials("test_db", "test_user", "test_pass").build();
        let config = built.unwrap();

        assert_eq!(config.version, "15");
        assert_eq!(config.database_name, "test_db");
        assert_eq!(config.username, "test_user");
        assert_eq!(config.password, "test_pass");
        assert_eq!(config.port, 5432);
        assert!(!config.ssl);
    }

    /// Explicitly supplied values override the defaults.
    #[test]
    fn test_builder_full() {
        let builder = with_credentials("prod_db", "prod_user", "secure_pass")
            .version("14")
            .listen_addresses("10.0.0.0/8")
            .port(5433)
            .max_connections(200)
            .shared_buffers("512MB")
            .effective_cache_size("2GB")
            .ssl(true)
            .add_config("log_statement", "all");
        let config = builder.build().unwrap();

        assert_eq!(config.version, "14");
        assert_eq!(config.port, 5433);
        assert_eq!(config.max_connections, Some(200));
        assert!(config.ssl);
        assert_eq!(
            config.extra_config.get("log_statement").map(String::as_str),
            Some("all")
        );
    }

    /// Leaving out any of database name, user name or password fails the build.
    #[test]
    fn test_missing_required_fields() {
        let incomplete_builders = [
            PostgresConfig::builder(),
            PostgresConfig::builder().database_name("db"),
            PostgresConfig::builder().database_name("db").username("user"),
        ];
        for incomplete in incomplete_builders {
            assert!(incomplete.build().is_err());
        }
    }

    /// A version containing characters other than digits and dots is rejected.
    #[test]
    fn test_invalid_version() {
        let outcome = with_credentials("db", "user", "pass")
            .version("invalid-version")
            .build();
        assert!(matches!(outcome, Err(Error::InvalidVersion(_))));
    }

    /// The path helpers are derived from the version string.
    #[test]
    fn test_config_paths() {
        let config = with_credentials("db", "user", "pass")
            .version("15")
            .build()
            .unwrap();

        assert_eq!(config.config_dir(), "/etc/postgresql/15/main");
        assert_eq!(
            config.postgresql_conf_path(),
            "/etc/postgresql/15/main/postgresql.conf"
        );
        assert_eq!(
            config.pg_hba_conf_path(),
            "/etc/postgresql/15/main/pg_hba.conf"
        );
    }
}