use crate::config::PostgresConfig;
use crate::error::{Error, Result};
use std::collections::HashMap;
use std::net::IpAddr;
use tracing::warn;
/// Parses a human-readable memory size such as `"256MB"`, `"1GB"` or `"512kB"` into bytes.
///
/// The numeric part must be an unsigned integer. The optional unit suffix is matched
/// case-insensitively against `B`, `KB`, `MB`, `GB` and `TB` using binary (1024-based)
/// multipliers; a bare number is interpreted as bytes. Surrounding whitespace, and
/// whitespace between the number and the unit, is ignored.
///
/// # Errors
///
/// Returns [`Error::Configuration`] when the input has no numeric part, the number does
/// not parse as a `u64`, the unit is not recognised, or the resulting byte count does not
/// fit in a `u64`.
pub fn parse_memory_size(size_str: &str) -> Result<u64> {
    let size_str = size_str.trim();

    // The unit starts at the first alphabetic character; everything before it is the number.
    let split_pos = size_str
        .char_indices()
        .find(|(_, c)| c.is_alphabetic())
        .map(|(i, _)| i)
        .unwrap_or(size_str.len());
    if split_pos == 0 {
        return Err(Error::Configuration(format!(
            "Invalid memory size format: '{}' (no numeric value)",
            size_str
        )));
    }

    let (num_part, unit_part) = size_str.split_at(split_pos);
    let value: u64 = num_part.trim().parse().map_err(|_| {
        Error::Configuration(format!(
            "Invalid memory size value: '{}' is not a valid number",
            num_part
        ))
    })?;

    let unit = unit_part.trim().to_uppercase();
    let multiplier: u64 = match unit.as_str() {
        "" | "B" => 1,
        "KB" => 1024,
        "MB" => 1024 * 1024,
        "GB" => 1024 * 1024 * 1024,
        "TB" => 1024 * 1024 * 1024 * 1024,
        _ => {
            return Err(Error::Configuration(format!(
                "Invalid memory unit: '{}'. Valid units are: B, kB, MB, GB, TB",
                unit_part
            )));
        }
    };

    // The input comes from user configuration: a plain `*` would panic in debug builds
    // and silently wrap in release builds for values such as "16777216TB".
    value.checked_mul(multiplier).ok_or_else(|| {
        Error::Configuration(format!(
            "Memory size '{}' is too large (exceeds {} bytes)",
            size_str,
            u64::MAX
        ))
    })
}
pub fn validate_memory_size(size_str: &str, min_bytes: u64, max_bytes: u64) -> Result<u64> {
let bytes = parse_memory_size(size_str)?;
if bytes < min_bytes {
return Err(Error::Configuration(format!(
"Memory size '{}' ({} bytes) is below minimum required ({} bytes)",
size_str, bytes, min_bytes
)));
}
if bytes > max_bytes {
return Err(Error::Configuration(format!(
"Memory size '{}' ({} bytes) exceeds maximum allowed ({} bytes)",
size_str, bytes, max_bytes
)));
}
Ok(bytes)
}
pub fn validate_cidr(address: &str) -> Result<()> {
let address = address.trim();
if address == "*" || address.is_empty() {
return Ok(());
}
if let Some((ip_part, prefix_part)) = address.split_once('/') {
let _ip: IpAddr = ip_part.parse().map_err(|_| {
Error::Configuration(format!(
"Invalid IP address in CIDR notation: '{}'",
ip_part
))
})?;
let prefix: u8 = prefix_part.parse().map_err(|_| {
Error::Configuration(format!(
"Invalid CIDR prefix length: '{}' is not a valid number",
prefix_part
))
})?;
let max_prefix = if ip_part.contains(':') { 128 } else { 32 };
if prefix > max_prefix {
return Err(Error::Configuration(format!(
"Invalid CIDR prefix length: {} exceeds maximum of {}",
prefix, max_prefix
)));
}
} else {
let _ip: IpAddr = address
.parse()
.map_err(|_| Error::Configuration(format!("Invalid IP address: '{}'", address)))?;
}
Ok(())
}
pub fn validate_listen_addresses(addresses: &str) -> Result<()> {
if addresses.trim().is_empty() {
return Err(Error::Configuration(
"listen_addresses cannot be empty".to_string(),
));
}
for addr in addresses.split(',') {
validate_cidr(addr.trim())?;
}
Ok(())
}
/// Inspects memory, checkpoint and port settings of `config` for values that are legal but
/// likely undesirable. Returns one human-readable warning per finding.
///
/// Memory settings (`shared_buffers`, `work_mem`, `maintenance_work_mem`,
/// `effective_cache_size`) that are unset or fail [`parse_memory_size`] are skipped
/// silently by this function.
///
/// # Errors
///
/// Currently never returns an error; every path yields `Ok(warnings)`.
pub fn check_conflicting_settings(config: &PostgresConfig) -> Result<Vec<String>> {
    // Pairs a raw memory setting with its parsed size in bytes, or `None` when the setting
    // is unset or unparseable. Keeping the raw text lets warnings quote it without
    // re-unwrapping the config field.
    fn memory_setting(raw: Option<&str>) -> Option<(&str, u64)> {
        let raw = raw?;
        parse_memory_size(raw).ok().map(|bytes| (raw, bytes))
    }

    const MB: u64 = 1024 * 1024;
    const GB: u64 = 1024 * MB;

    let shared_buffers = memory_setting(config.shared_buffers.as_deref());
    let work_mem = memory_setting(config.work_mem.as_deref());
    let maintenance_work_mem = memory_setting(config.maintenance_work_mem.as_deref());
    let effective_cache_size = memory_setting(config.effective_cache_size.as_deref());

    let mut warnings = Vec::new();

    if let Some((raw, bytes)) = shared_buffers {
        if bytes < 128 * MB {
            warnings.push(format!(
                "shared_buffers ({}) is very low. Consider at least 128MB",
                raw
            ));
        }
        if bytes > 16 * GB {
            warnings.push(format!(
                "shared_buffers ({}) is very high. Values above 16GB rarely provide additional benefit",
                raw
            ));
        }
    }

    if let (Some((raw, bytes)), Some(max_conn)) = (work_mem, config.max_connections) {
        // Saturate rather than overflow: a very large work_mem times many connections
        // would otherwise panic in debug builds and wrap in release builds.
        let total_work_mem = bytes.saturating_mul(max_conn as u64);
        if total_work_mem > 8 * GB {
            warnings.push(format!(
                "work_mem ({}) * max_connections ({}) = {}MB total. This may exceed available RAM",
                raw,
                max_conn,
                total_work_mem / MB
            ));
        }
    }

    if let Some((raw, bytes)) = maintenance_work_mem {
        if bytes < 64 * MB {
            warnings.push(format!(
                "maintenance_work_mem ({}) is low. Consider at least 64MB for better maintenance operations",
                raw
            ));
        }
        if bytes > 2 * GB {
            warnings.push(format!(
                "maintenance_work_mem ({}) is very high. Values above 2GB rarely help",
                raw
            ));
        }
    }

    if let (Some((ecs_raw, ecs)), Some((sb_raw, sb))) = (effective_cache_size, shared_buffers)
        && ecs < sb
    {
        warnings.push(format!(
            "effective_cache_size ({}) should be larger than shared_buffers ({})",
            ecs_raw, sb_raw
        ));
    }

    if let Some(target) = config.checkpoint_completion_target
        && target > 0.9
    {
        warnings.push(format!(
            "checkpoint_completion_target ({}) is very high. Values above 0.9 may cause checkpoint spikes",
            target
        ));
    }

    if config.port < 1024 {
        warnings.push(format!(
            "Port {} requires root privileges. Consider using ports >= 1024",
            config.port
        ));
    }

    Ok(warnings)
}
/// Produces soft warnings about `max_connections`.
///
/// It warns when the value is below 10 or above 1000. It also warns when an estimated
/// connection overhead of 10 MiB per connection exceeds 10 GiB. Returns an empty list
/// when `max_connections` is unset.
///
/// # Errors
///
/// Currently never returns an error; every path yields `Ok(..)`.
pub fn validate_resource_limits(config: &PostgresConfig) -> Result<Vec<String>> {
    // Per-connection memory overhead assumed by the estimate below (10 MiB).
    const ASSUMED_BYTES_PER_CONNECTION: u64 = 10 * 1024 * 1024;
    const GIB: u64 = 1024 * 1024 * 1024;

    let Some(max_conn) = config.max_connections else {
        return Ok(Vec::new());
    };

    let mut findings = Vec::new();
    if max_conn < 10 {
        findings.push(format!(
            "max_connections ({}) is very low. Most applications need at least 10-20 connections",
            max_conn
        ));
    }
    if max_conn > 1000 {
        findings.push(format!(
            "max_connections ({}) is very high. Consider using connection pooling (pgBouncer, pgPool)",
            max_conn
        ));
    }

    let estimated_overhead = max_conn as u64 * ASSUMED_BYTES_PER_CONNECTION;
    if estimated_overhead > 10 * GIB {
        findings.push(format!(
            "max_connections ({}) implies {}GB connection overhead. This may exceed system limits",
            max_conn,
            estimated_overhead / GIB
        ));
    }
    Ok(findings)
}
/// Runs every check in this module against `config` and returns the collected warnings.
///
/// The steps run in this order, and the first hard error aborts:
/// 1. `config.validate()`;
/// 2. range checks via [`validate_memory_size`] on `shared_buffers` (1MB..=128GB),
///    `work_mem` (64kB..=10GB) and `maintenance_work_mem` (1MB..=10GB), when set;
/// 3. [`validate_listen_addresses`] on `listen_addresses`;
/// 4. soft checks from [`check_conflicting_settings`] followed by
///    [`validate_resource_limits`].
///
/// Every warning is also logged with `warn!` before being returned.
///
/// # Errors
///
/// Returns the first error produced by any of the steps above.
pub fn validate_comprehensive(config: &PostgresConfig) -> Result<Vec<String>> {
    const KB: u64 = 1024;
    const MB: u64 = 1024 * KB;
    const GB: u64 = 1024 * MB;

    config.validate()?;

    if let Some(shared_buffers) = &config.shared_buffers {
        validate_memory_size(shared_buffers, MB, 128 * GB)?;
    }
    if let Some(work_mem) = &config.work_mem {
        validate_memory_size(work_mem, 64 * KB, 10 * GB)?;
    }
    if let Some(maint_mem) = &config.maintenance_work_mem {
        validate_memory_size(maint_mem, MB, 10 * GB)?;
    }
    validate_listen_addresses(&config.listen_addresses)?;

    let mut all_warnings = check_conflicting_settings(config)?;
    all_warnings.extend(validate_resource_limits(config)?);

    for warning in &all_warnings {
        warn!("Configuration warning: {}", warning);
    }
    Ok(all_warnings)
}
/// Produces a starting set of PostgreSQL settings for a machine with `total_ram_mb`
/// megabytes of RAM and `cpu_cores` CPU cores, tuned for `workload`.
///
/// Keys are PostgreSQL parameter names. Values are strings, and memory values carry an
/// `MB` suffix. The heuristics are:
///
/// * `shared_buffers`: 1/4 of RAM, clamped to 128MB..=16GB.
/// * `effective_cache_size`: 1/2, 2/3 or 3/4 of RAM depending on `workload`, but never
///   below `shared_buffers`. This avoids the "effective_cache_size should be larger than
///   shared_buffers" warning, and a `0MB` value on tiny inputs.
/// * `work_mem` and `max_connections`: fixed per workload.
/// * `maintenance_work_mem`: 1/20 of RAM, clamped to 64MB..=2GB.
/// * `max_worker_processes`: the core count, but at least 8.
/// * `max_parallel_workers_per_gather`: a quarter of the cores, clamped to 2..=4.
/// * `checkpoint_completion_target`, `wal_buffers`, `random_page_cost`: fixed values.
pub fn auto_tune(
    total_ram_mb: u64,
    cpu_cores: u32,
    workload: WorkloadType,
) -> HashMap<String, String> {
    let mut config = HashMap::new();

    let shared_buffers_mb = (total_ram_mb / 4).clamp(128, 16 * 1024);
    config.insert(
        "shared_buffers".to_string(),
        format!("{}MB", shared_buffers_mb),
    );

    let ecs_fraction_mb = match workload {
        WorkloadType::Web => total_ram_mb / 2,
        WorkloadType::Mixed | WorkloadType::Oltp => (total_ram_mb * 2) / 3,
        WorkloadType::DataWarehouse => (total_ram_mb * 3) / 4,
    };
    // shared_buffers is floored at 128MB, so on small machines the RAM fraction above can
    // fall below it (or reach 0); keep effective_cache_size >= shared_buffers.
    let ecs_mb = ecs_fraction_mb.max(shared_buffers_mb);
    config.insert("effective_cache_size".to_string(), format!("{}MB", ecs_mb));

    let work_mem_mb = match workload {
        WorkloadType::Web => 4,
        WorkloadType::Mixed => 16,
        WorkloadType::DataWarehouse => 64,
        WorkloadType::Oltp => 8,
    };
    config.insert("work_mem".to_string(), format!("{}MB", work_mem_mb));

    let maint_work_mem_mb = (total_ram_mb / 20).clamp(64, 2048);
    config.insert(
        "maintenance_work_mem".to_string(),
        format!("{}MB", maint_work_mem_mb),
    );

    let max_connections = match workload {
        WorkloadType::Web => 200,
        WorkloadType::Mixed => 100,
        WorkloadType::DataWarehouse => 20,
        WorkloadType::Oltp => 300,
    };
    config.insert("max_connections".to_string(), max_connections.to_string());

    let max_workers = cpu_cores.max(8);
    config.insert("max_worker_processes".to_string(), max_workers.to_string());

    let parallel_workers = (cpu_cores / 4).clamp(2, 4);
    config.insert(
        "max_parallel_workers_per_gather".to_string(),
        parallel_workers.to_string(),
    );

    config.insert(
        "checkpoint_completion_target".to_string(),
        "0.9".to_string(),
    );
    config.insert("wal_buffers".to_string(), "16MB".to_string());
    // Identical for every workload type.
    config.insert("random_page_cost".to_string(), "1.1".to_string());

    config
}

/// Workload profile used by [`auto_tune`] to choose memory and connection settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadType {
    /// `auto_tune`: 4MB `work_mem`, 200 connections, cache sized at 1/2 of RAM.
    Web,
    /// `auto_tune`: 16MB `work_mem`, 100 connections, cache sized at 2/3 of RAM.
    Mixed,
    /// `auto_tune`: 64MB `work_mem`, 20 connections, cache sized at 3/4 of RAM.
    DataWarehouse,
    /// `auto_tune`: 8MB `work_mem`, 300 connections, cache sized at 2/3 of RAM.
    Oltp,
}
#[cfg(test)]
mod tests {
    use super::*;

    const MB: u64 = 1024 * 1024;
    const GB: u64 = 1024 * MB;

    #[test]
    fn test_parse_memory_size() {
        assert_eq!(parse_memory_size("256MB").unwrap(), 256 * 1024 * 1024);
        assert_eq!(parse_memory_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(parse_memory_size("512kB").unwrap(), 512 * 1024);
        assert_eq!(
            parse_memory_size("2TB").unwrap(),
            2 * 1024 * 1024 * 1024 * 1024
        );
        assert_eq!(parse_memory_size("100").unwrap(), 100);
        assert_eq!(parse_memory_size("256mb").unwrap(), 256 * 1024 * 1024);
        assert_eq!(parse_memory_size("1gb").unwrap(), 1024 * 1024 * 1024);
        assert!(parse_memory_size("").is_err());
        assert!(parse_memory_size("MB").is_err());
        assert!(parse_memory_size("256XB").is_err());
        assert!(parse_memory_size("abc").is_err());
    }

    #[test]
    fn test_parse_memory_size_whitespace_and_malformed_numbers() {
        // Whitespace around the value and between number and unit is ignored.
        assert_eq!(parse_memory_size("  64 MB ").unwrap(), 64 * MB);
        assert_eq!(parse_memory_size("8 b").unwrap(), 8);
        // Only unsigned integers are accepted.
        assert!(parse_memory_size("1.5GB").is_err());
        assert!(parse_memory_size("-1MB").is_err());
    }

    #[test]
    fn test_validate_memory_size() {
        let min = 100 * MB;
        let max = 10 * GB;
        assert!(validate_memory_size("256MB", min, max).is_ok());
        assert!(validate_memory_size("1GB", min, max).is_ok());
        assert!(validate_memory_size("50MB", min, max).is_err());
        assert!(validate_memory_size("20GB", min, max).is_err());
    }

    #[test]
    fn test_validate_memory_size_bounds_are_inclusive() {
        let min = 100 * MB;
        let max = 10 * GB;
        assert_eq!(validate_memory_size("100MB", min, max).unwrap(), 100 * MB);
        assert_eq!(validate_memory_size("10GB", min, max).unwrap(), 10 * GB);
        // Parse failures are propagated.
        assert!(validate_memory_size("lots", min, max).is_err());
    }

    #[test]
    fn test_validate_cidr() {
        assert!(validate_cidr("192.168.1.100").is_ok());
        assert!(validate_cidr("10.0.0.1").is_ok());
        assert!(validate_cidr("::1").is_ok());
        assert!(validate_cidr("192.168.1.0/24").is_ok());
        assert!(validate_cidr("10.0.0.0/8").is_ok());
        assert!(validate_cidr("0.0.0.0/0").is_ok());
        assert!(validate_cidr("::/0").is_ok());
        assert!(validate_cidr("*").is_ok());
        assert!(validate_cidr("999.999.999.999").is_err());
        assert!(validate_cidr("192.168.1.0/33").is_err());
        assert!(validate_cidr("invalid").is_err());
    }

    #[test]
    fn test_validate_cidr_prefix_limits_and_malformed_parts() {
        assert!(validate_cidr("").is_ok());
        assert!(validate_cidr("  10.0.0.1  ").is_ok());
        assert!(validate_cidr("192.168.1.0/32").is_ok());
        assert!(validate_cidr("::/128").is_ok());
        assert!(validate_cidr("::/129").is_err());
        assert!(validate_cidr("10.0.0.0/256").is_err());
        assert!(validate_cidr("10.0.0.0/abc").is_err());
        assert!(validate_cidr("not-an-ip/24").is_err());
    }

    #[test]
    fn test_validate_listen_addresses() {
        assert!(validate_listen_addresses("192.168.1.100").is_ok());
        assert!(validate_listen_addresses("192.168.1.100,10.0.0.1").is_ok());
        assert!(validate_listen_addresses("*").is_ok());
        assert!(validate_listen_addresses("0.0.0.0/0").is_ok());
        assert!(validate_listen_addresses("").is_err());
        assert!(validate_listen_addresses("invalid").is_err());
    }

    #[test]
    fn test_validate_listen_addresses_trims_entries_and_checks_each() {
        assert!(validate_listen_addresses(" 10.0.0.1 ,  ::1 ").is_ok());
        assert!(validate_listen_addresses("10.0.0.1,invalid").is_err());
        assert!(validate_listen_addresses("   ").is_err());
    }

    #[test]
    fn test_auto_tune() {
        let tuned = auto_tune(16384, 8, WorkloadType::Web);
        assert!(tuned.contains_key("shared_buffers"));
        assert!(tuned.contains_key("effective_cache_size"));
        assert!(tuned.contains_key("work_mem"));
        assert!(tuned.contains_key("max_connections"));
        let shared_buffers = tuned.get("shared_buffers").unwrap();
        assert!(shared_buffers.contains("MB") || shared_buffers.contains("GB"));
    }

    #[test]
    fn test_auto_tune_exact_values() {
        let tuned = auto_tune(16384, 8, WorkloadType::Web);
        assert_eq!(tuned["shared_buffers"], "4096MB");
        assert_eq!(tuned["effective_cache_size"], "8192MB");
        assert_eq!(tuned["work_mem"], "4MB");
        assert_eq!(tuned["maintenance_work_mem"], "819MB");
        assert_eq!(tuned["max_connections"], "200");

        // Parallel workers per gather are capped at 4 even on large machines.
        let tuned = auto_tune(8192, 64, WorkloadType::Mixed);
        assert_eq!(tuned["max_parallel_workers_per_gather"], "4");
        assert_eq!(tuned["max_worker_processes"], "64");
    }

    #[test]
    fn test_auto_tune_covers_every_workload() {
        let workloads = [
            WorkloadType::Web,
            WorkloadType::Mixed,
            WorkloadType::DataWarehouse,
            WorkloadType::Oltp,
        ];
        let keys = [
            "shared_buffers",
            "effective_cache_size",
            "work_mem",
            "maintenance_work_mem",
            "max_connections",
            "max_worker_processes",
            "max_parallel_workers_per_gather",
            "checkpoint_completion_target",
            "wal_buffers",
            "random_page_cost",
        ];
        for workload in workloads {
            let tuned = auto_tune(8192, 16, workload);
            for key in keys {
                assert!(tuned.contains_key(key), "{:?} is missing {}", workload, key);
            }
        }
    }
}