use std::{fmt, num::NonZeroU32, time::Duration};
use thiserror::Error;
/// Acquire timeout used by [`ConfigBuilder::build`] when none was supplied (5 seconds).
pub const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(5);
/// Idle timeout used by [`ConfigBuilder::build`] when none was supplied (5 minutes).
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Maximum connection lifetime used by [`ConfigBuilder::build`] when none was supplied
/// (30 minutes).
pub const DEFAULT_MAX_LIFETIME: Duration = Duration::from_secs(30 * 60);
/// Maximum pool size used by [`ConfigBuilder::build`] when none was supplied (10).
pub const DEFAULT_MAX_CONNECTIONS: NonZeroU32 = {
    // `10` is a non-zero literal, so the `else` arm can never run.
    let Some(ten) = NonZeroU32::new(10) else {
        unreachable!()
    };
    ten
};
/// Minimum pool size used by [`ConfigBuilder::build`] when none was supplied (0).
pub const DEFAULT_MIN_CONNECTIONS: u32 = 0;
#[derive(Clone, Copy, Debug, Eq, Error, PartialEq)]
pub enum ConfigBuilderError {
#[error(
"Postgres min connections ({min_connections}) cannot exceed max connections ({max_connections})"
)]
MinConnectionsExceedsMaxConnections {
min_connections: u32,
max_connections: NonZeroU32,
},
}
/// A Postgres password whose `Debug` output never reveals the secret.
///
/// Any struct that derives `Debug` while holding a `Password` prints `<redacted>` in its place.
#[derive(Clone, PartialEq, Eq)]
pub struct Password(String);

impl Password {
    /// Wraps `password`, which may be anything convertible into a `String` (including `""`).
    #[must_use]
    pub fn new(password: impl Into<String>) -> Self {
        let secret: String = password.into();
        Self(secret)
    }

    /// Returns the plain-text password.
    #[must_use]
    pub const fn as_str(&self) -> &str {
        let Self(secret) = self;
        secret.as_str()
    }
}

impl fmt::Debug for Password {
    /// Writes the fixed placeholder `<redacted>`, whatever the password contents are.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("<redacted>")
    }
}
/// Builder for [`Config`].
///
/// Every setting starts unset. When [`ConfigBuilder::build`] runs:
/// - connection settings (host, port, username, password, database, application name and
///   statement cache capacity) that were never set stay `None`;
/// - pool settings that were never set fall back to the `DEFAULT_*` constants.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ConfigBuilder {
    host: Option<String>,
    port: Option<u16>,
    username: Option<String>,
    password: Option<Password>,
    database: Option<String>,
    application_name: Option<String>,
    acquire_timeout: Option<Duration>,
    idle_timeout: Option<Duration>,
    max_lifetime: Option<Duration>,
    max_connections: Option<NonZeroU32>,
    min_connections: Option<u32>,
    statement_cache_capacity: Option<usize>,
}

impl ConfigBuilder {
    /// Creates a builder with every setting unset. This is equivalent to `ConfigBuilder::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the Postgres host.
    #[must_use]
    pub fn with_host(self, host: impl Into<String>) -> Self {
        Self {
            host: Some(host.into()),
            ..self
        }
    }

    /// Sets the Postgres port.
    #[must_use]
    pub const fn with_port(self, port: u16) -> Self {
        let mut builder = self;
        builder.port = Some(port);
        builder
    }

    /// Sets the Postgres username.
    #[must_use]
    pub fn with_username(self, username: impl Into<String>) -> Self {
        Self {
            username: Some(username.into()),
            ..self
        }
    }

    /// Sets the Postgres password. An empty string is kept as an explicit empty password.
    #[must_use]
    pub fn with_password(self, password: impl Into<String>) -> Self {
        Self {
            password: Some(Password::new(password)),
            ..self
        }
    }

    /// Sets the Postgres database name.
    #[must_use]
    pub fn with_database(self, database: impl Into<String>) -> Self {
        Self {
            database: Some(database.into()),
            ..self
        }
    }

    /// Sets the application name.
    #[must_use]
    pub fn with_application_name(self, application_name: impl Into<String>) -> Self {
        Self {
            application_name: Some(application_name.into()),
            ..self
        }
    }

    /// Overrides [`DEFAULT_ACQUIRE_TIMEOUT`].
    #[must_use]
    pub const fn with_acquire_timeout(self, acquire_timeout: Duration) -> Self {
        let mut builder = self;
        builder.acquire_timeout = Some(acquire_timeout);
        builder
    }

    /// Overrides [`DEFAULT_IDLE_TIMEOUT`].
    #[must_use]
    pub const fn with_idle_timeout(self, idle_timeout: Duration) -> Self {
        let mut builder = self;
        builder.idle_timeout = Some(idle_timeout);
        builder
    }

    /// Overrides [`DEFAULT_MAX_LIFETIME`].
    #[must_use]
    pub const fn with_max_lifetime(self, max_lifetime: Duration) -> Self {
        let mut builder = self;
        builder.max_lifetime = Some(max_lifetime);
        builder
    }

    /// Overrides [`DEFAULT_MAX_CONNECTIONS`].
    #[must_use]
    pub const fn with_max_connections(self, max_connections: NonZeroU32) -> Self {
        let mut builder = self;
        builder.max_connections = Some(max_connections);
        builder
    }

    /// Overrides [`DEFAULT_MIN_CONNECTIONS`].
    ///
    /// The value is validated against the maximum only when [`ConfigBuilder::build`] runs.
    #[must_use]
    pub const fn with_min_connections(self, min_connections: u32) -> Self {
        let mut builder = self;
        builder.min_connections = Some(min_connections);
        builder
    }

    /// Sets the statement cache capacity.
    #[must_use]
    pub const fn with_statement_cache_capacity(self, capacity: usize) -> Self {
        let mut builder = self;
        builder.statement_cache_capacity = Some(capacity);
        builder
    }

    /// Consumes the builder and produces a [`Config`].
    ///
    /// Unset pool settings take the `DEFAULT_*` constants. Unset connection settings stay `None`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigBuilderError::MinConnectionsExceedsMaxConnections`] when the effective
    /// minimum connection count is greater than the effective maximum. "Effective" means after
    /// defaults have been applied.
    pub fn build(self) -> Result<Config, ConfigBuilderError> {
        // Exhaustive destructuring: adding a builder field without handling it here fails to compile.
        let Self {
            host,
            port,
            username,
            password,
            database,
            application_name,
            acquire_timeout,
            idle_timeout,
            max_lifetime,
            max_connections,
            min_connections,
            statement_cache_capacity,
        } = self;

        let max_connections = max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS);
        let min_connections = min_connections.unwrap_or(DEFAULT_MIN_CONNECTIONS);

        if min_connections <= max_connections.get() {
            Ok(Config {
                host,
                port,
                username,
                password,
                database,
                application_name,
                statement_cache_capacity,
                acquire_timeout: acquire_timeout.unwrap_or(DEFAULT_ACQUIRE_TIMEOUT),
                idle_timeout: idle_timeout.unwrap_or(DEFAULT_IDLE_TIMEOUT),
                max_lifetime: max_lifetime.unwrap_or(DEFAULT_MAX_LIFETIME),
                max_connections,
                min_connections,
            })
        } else {
            Err(ConfigBuilderError::MinConnectionsExceedsMaxConnections {
                min_connections,
                max_connections,
            })
        }
    }
}
/// Postgres connection and pool settings produced by [`ConfigBuilder::build`].
///
/// Connection settings that were never supplied are `None`. Pool settings always hold a value,
/// either supplied or defaulted. [`ConfigBuilder::build`] rejects configurations where
/// `min_connections() > max_connections().get()`. The derived `Debug` output shows the password
/// as `<redacted>`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Config {
    host: Option<String>,
    port: Option<u16>,
    username: Option<String>,
    password: Option<Password>,
    database: Option<String>,
    application_name: Option<String>,
    statement_cache_capacity: Option<usize>,
    acquire_timeout: Duration,
    idle_timeout: Duration,
    max_lifetime: Duration,
    max_connections: NonZeroU32,
    min_connections: u32,
}

impl Config {
    /// Returns a new [`ConfigBuilder`] with every setting unset.
    #[must_use]
    pub fn builder() -> ConfigBuilder {
        ConfigBuilder::new()
    }

    /// Postgres host, if one was supplied.
    #[must_use]
    pub const fn host(&self) -> Option<&str> {
        // `Option::as_deref` is not `const`, so the `String` is borrowed by hand.
        match &self.host {
            Some(host) => Some(host.as_str()),
            None => None,
        }
    }

    /// Postgres port, if one was supplied.
    #[must_use]
    pub const fn port(&self) -> Option<u16> {
        self.port
    }

    /// Postgres username, if one was supplied.
    #[must_use]
    pub const fn username(&self) -> Option<&str> {
        match &self.username {
            Some(username) => Some(username.as_str()),
            None => None,
        }
    }

    /// Postgres password, if one was supplied. An explicitly supplied empty password is preserved.
    #[must_use]
    pub const fn password(&self) -> Option<&Password> {
        // `Option::as_ref` is `const`, so no hand-written identity match is needed.
        self.password.as_ref()
    }

    /// Postgres database name, if one was supplied.
    #[must_use]
    pub const fn database(&self) -> Option<&str> {
        match &self.database {
            Some(database) => Some(database.as_str()),
            None => None,
        }
    }

    /// Application name, if one was supplied.
    #[must_use]
    pub const fn application_name(&self) -> Option<&str> {
        match &self.application_name {
            Some(application_name) => Some(application_name.as_str()),
            None => None,
        }
    }

    /// Acquire timeout: the supplied value, or [`DEFAULT_ACQUIRE_TIMEOUT`].
    #[must_use]
    pub const fn acquire_timeout(&self) -> Duration {
        self.acquire_timeout
    }

    /// Idle timeout: the supplied value, or [`DEFAULT_IDLE_TIMEOUT`].
    #[must_use]
    pub const fn idle_timeout(&self) -> Duration {
        self.idle_timeout
    }

    /// Maximum connection lifetime: the supplied value, or [`DEFAULT_MAX_LIFETIME`].
    #[must_use]
    pub const fn max_lifetime(&self) -> Duration {
        self.max_lifetime
    }

    /// Maximum pool size: the supplied value, or [`DEFAULT_MAX_CONNECTIONS`].
    #[must_use]
    pub const fn max_connections(&self) -> NonZeroU32 {
        self.max_connections
    }

    /// Minimum pool size: the supplied value, or [`DEFAULT_MIN_CONNECTIONS`].
    ///
    /// [`ConfigBuilder::build`] ensures this never exceeds `max_connections().get()`.
    #[must_use]
    pub const fn min_connections(&self) -> u32 {
        self.min_connections
    }

    /// Statement cache capacity, if one was supplied.
    #[must_use]
    pub const fn statement_cache_capacity(&self) -> Option<usize> {
        self.statement_cache_capacity
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `ConfigBuilder::build`, the `Config` accessors and password redaction.
    use super::*;

    #[test]
    fn build_without_connection_overrides_uses_no_fake_connection_values() {
        let config = build_config(test_builder());
        assert_eq!(
            config.host(),
            None,
            "config must not invent a Postgres host"
        );
        assert_eq!(
            config.port(),
            None,
            "config must not invent a Postgres port"
        );
        assert_eq!(
            config.username(),
            None,
            "config must not invent a Postgres username"
        );
        assert_eq!(
            config.password().map(Password::as_str),
            None,
            "config must not invent a Postgres password"
        );
        assert_eq!(
            config.database(),
            None,
            "config must not invent a Postgres database"
        );
        assert_eq!(
            config.application_name(),
            None,
            "config must not invent an application name"
        );
        assert_eq!(
            config.statement_cache_capacity(),
            None,
            "config must not invent a statement cache capacity"
        );
    }

    #[test]
    fn build_uses_default_pool_values() {
        let config = build_config(test_builder());
        assert_eq!(
            config.acquire_timeout(),
            DEFAULT_ACQUIRE_TIMEOUT,
            "config must default the acquire timeout"
        );
        assert_eq!(
            config.idle_timeout(),
            DEFAULT_IDLE_TIMEOUT,
            "config must default the idle timeout"
        );
        assert_eq!(
            config.max_lifetime(),
            DEFAULT_MAX_LIFETIME,
            "config must default the max lifetime"
        );
        assert_eq!(
            config.max_connections(),
            DEFAULT_MAX_CONNECTIONS,
            "config must default the max connections"
        );
        assert_eq!(
            config.min_connections(),
            DEFAULT_MIN_CONNECTIONS,
            "config must default the min connections"
        );
    }

    #[test]
    fn build_uses_overridden_connection_values() {
        let host = "localhost";
        let port = 15432;
        let username = "recomp";
        let password = "";
        let database = "recomp_development";
        let application_name = "recomp";
        let config = build_config(
            test_builder()
                .with_host(host)
                .with_port(port)
                .with_username(username)
                .with_password(password)
                .with_database(database)
                .with_application_name(application_name),
        );
        assert_eq!(
            config.host(),
            Some(host),
            "config must preserve the supplied Postgres host"
        );
        assert_eq!(
            config.port(),
            Some(port),
            "config must preserve the supplied Postgres port"
        );
        assert_eq!(
            config.username(),
            Some(username),
            "config must preserve the supplied Postgres username"
        );
        assert_eq!(
            config.password().map(Password::as_str),
            Some(password),
            "config must preserve an explicitly supplied empty Postgres password"
        );
        assert_eq!(
            config.database(),
            Some(database),
            "config must preserve the supplied Postgres database"
        );
        assert_eq!(
            config.application_name(),
            Some(application_name),
            "config must preserve the supplied application name"
        );
    }

    #[test]
    fn build_uses_overridden_pool_values() {
        let acquire_timeout = Duration::from_secs(2);
        let idle_timeout = Duration::from_secs(3);
        let max_lifetime = Duration::from_secs(4);
        let max_connections = non_zero_u32(20);
        let min_connections = 5;
        let statement_cache_capacity = 256;
        let config = build_config(
            test_builder()
                .with_acquire_timeout(acquire_timeout)
                .with_idle_timeout(idle_timeout)
                .with_max_lifetime(max_lifetime)
                .with_max_connections(max_connections)
                .with_min_connections(min_connections)
                .with_statement_cache_capacity(statement_cache_capacity),
        );
        assert_eq!(
            config.acquire_timeout(),
            acquire_timeout,
            "config must preserve the supplied acquire timeout"
        );
        assert_eq!(
            config.idle_timeout(),
            idle_timeout,
            "config must preserve the supplied idle timeout"
        );
        assert_eq!(
            config.max_lifetime(),
            max_lifetime,
            "config must preserve the supplied max lifetime"
        );
        assert_eq!(
            config.max_connections(),
            max_connections,
            "config must preserve the supplied max connections"
        );
        assert_eq!(
            config.min_connections(),
            min_connections,
            "config must preserve the supplied min connections"
        );
        assert_eq!(
            config.statement_cache_capacity(),
            Some(statement_cache_capacity),
            "config must preserve the supplied statement cache capacity"
        );
    }

    #[test]
    fn debug_redacts_password() {
        let config = build_config(test_builder().with_password("secret-password"));
        let debug = format!("{config:?}");
        assert!(
            debug.contains("<redacted>"),
            "debug output must show that the password was redacted: {debug}"
        );
        assert!(
            !debug.contains("secret-password"),
            "debug output must not contain the Postgres password: {debug}"
        );
    }

    #[test]
    fn build_rejects_min_connections_over_max_connections() {
        let max_connections = non_zero_u32(2);
        let min_connections = 3;
        let error = match test_builder()
            .with_max_connections(max_connections)
            .with_min_connections(min_connections)
            .build()
        {
            Ok(config) => panic!(
                "min connections above max connections must be rejected, got config: {config:?}"
            ),
            Err(error) => error,
        };
        assert_eq!(
            error,
            ConfigBuilderError::MinConnectionsExceedsMaxConnections {
                min_connections,
                max_connections,
            },
            "builder must report the invalid pool size relationship"
        );
    }

    // Boundary case: the check is strictly "min > max", so equal values must be accepted.
    #[test]
    fn build_accepts_min_connections_equal_to_max_connections() {
        let max_connections = non_zero_u32(4);
        let min_connections = max_connections.get();
        let config = build_config(
            test_builder()
                .with_max_connections(max_connections)
                .with_min_connections(min_connections),
        );
        assert_eq!(
            config.min_connections(),
            min_connections,
            "min connections equal to max connections must be accepted"
        );
        assert_eq!(
            config.max_connections(),
            max_connections,
            "max connections must be preserved when equal to min connections"
        );
    }

    // The validation must also apply when only the minimum is set and the maximum is defaulted.
    #[test]
    fn build_rejects_min_connections_over_default_max_connections() {
        let min_connections = DEFAULT_MAX_CONNECTIONS.get() + 1;
        let error = match test_builder().with_min_connections(min_connections).build() {
            Ok(config) => panic!(
                "min connections above the default max connections must be rejected, got config: {config:?}"
            ),
            Err(error) => error,
        };
        assert_eq!(
            error,
            ConfigBuilderError::MinConnectionsExceedsMaxConnections {
                min_connections,
                max_connections: DEFAULT_MAX_CONNECTIONS,
            },
            "builder must validate min connections against the defaulted max connections"
        );
    }

    /// Returns an empty builder for a test to customise.
    fn test_builder() -> ConfigBuilder {
        Config::builder()
    }

    /// Builds `builder`, failing the test with the builder error if validation rejects it.
    fn build_config(builder: ConfigBuilder) -> Config {
        match builder.build() {
            Ok(config) => config,
            Err(error) => panic!("test config must build: {error}"),
        }
    }

    /// Converts a test literal to `NonZeroU32`, panicking if it is zero.
    const fn non_zero_u32(value: u32) -> NonZeroU32 {
        match NonZeroU32::new(value) {
            Some(value) => value,
            None => panic!("test value must be non-zero"),
        }
    }
}