use core::fmt;
use std::{
num::{NonZeroU32, NonZeroUsize},
str::FromStr,
time::Duration,
};
use backon::{BackoffBuilder as _, ConstantBuilder, Retryable as _};
use secrecy::{ExposeSecret as _, SecretString};
use serde::{Deserialize, Deserializer, de};
use sqlx::{Executor as _, PgPool, postgres::PgPoolOptions};
/// A PostgreSQL schema name that has been validated to be non-empty and to consist
/// solely of ASCII letters, ASCII digits and `_`.
///
/// The restricted alphabet excludes `"`, so the value can be embedded inside a
/// double-quoted SQL identifier without any further escaping.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct SanitizedSchema(String);

/// Error produced when a string is not an acceptable [`SanitizedSchema`].
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct SanitizedSchemaParserError;

impl core::error::Error for SanitizedSchemaParserError {}

impl fmt::Display for SanitizedSchemaParserError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("schema must contain only ASCII alphanumeric and '_' and must not be empty")
    }
}

/// Returns `true` when `candidate` is non-empty and every byte is in `[A-Za-z0-9_]`.
///
/// Checking bytes is equivalent to checking chars here: every byte of a multi-byte
/// UTF-8 sequence is >= 0x80 and therefore never ASCII alphanumeric.
fn is_valid_schema_name(candidate: &str) -> bool {
    !candidate.is_empty()
        && candidate
            .bytes()
            .all(|byte| byte == b'_' || byte.is_ascii_alphanumeric())
}

impl FromStr for SanitizedSchema {
    type Err = SanitizedSchemaParserError;

    /// Validates `s` and copies it into a new [`SanitizedSchema`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if is_valid_schema_name(s) {
            Ok(Self(String::from(s)))
        } else {
            Err(SanitizedSchemaParserError)
        }
    }
}

impl TryFrom<String> for SanitizedSchema {
    type Error = SanitizedSchemaParserError;

    /// Same validation as the [`FromStr`] impl, but keeps the allocation of `value`.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        is_valid_schema_name(&value)
            .then(|| Self(value))
            .ok_or(SanitizedSchemaParserError)
    }
}

impl fmt::Display for SanitizedSchema {
    /// Writes the raw (unquoted) schema name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0.as_str())
    }
}
impl<'de> Deserialize<'de> for SanitizedSchema {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
String::deserialize(deserializer)?
.parse()
.map_err(de::Error::custom)
}
}
/// Settings for building a Postgres connection pool bound to a single schema.
///
/// Only `connection_string` and `schema` are required when deserializing; every other
/// field falls back to the matching `PostgresConfig::default_*` function. Duration
/// fields are (de)serialized with `humantime_serde` (e.g. `"90s"`).
#[derive(Clone, Debug, Deserialize)]
#[non_exhaustive]
pub struct PostgresConfig {
    /// Connection URL; wrapped in `SecretString` and only exposed when connecting.
    pub connection_string: SecretString,
    /// Schema that every newly opened pool connection is switched to.
    pub schema: SanitizedSchema,
    /// Maximum number of connections the pool keeps open.
    #[serde(default = "PostgresConfig::default_max_connections")]
    pub max_connections: NonZeroU32,
    /// How long acquiring a connection from the pool may take before failing.
    #[serde(default = "PostgresConfig::default_acquire_timeout", with = "humantime_serde")]
    pub acquire_timeout: Duration,
    /// Acquire duration after which the pool treats an acquisition as slow.
    #[serde(default = "PostgresConfig::default_slow_acquire_threshold", with = "humantime_serde")]
    pub slow_acquire_threshold: Duration,
    /// Maximum number of retries when creating the pool fails with a retryable error.
    #[serde(default = "PostgresConfig::default_max_retries")]
    pub max_retries: NonZeroUsize,
    /// Constant wait between pool-creation attempts.
    #[serde(default = "PostgresConfig::default_retry_delay", with = "humantime_serde")]
    pub retry_delay: Duration,
}
impl PostgresConfig {
    /// Default pool size: 4 connections.
    fn default_max_connections() -> NonZeroU32 {
        NonZeroU32::new(4).expect("Is non-zero")
    }

    /// Default acquire timeout: two minutes.
    fn default_acquire_timeout() -> Duration {
        Duration::from_secs(2 * 60)
    }

    /// Default slow-acquire threshold: 90 seconds.
    fn default_slow_acquire_threshold() -> Duration {
        Duration::from_secs(90)
    }

    /// Default number of pool-creation retries: 20.
    fn default_max_retries() -> NonZeroUsize {
        NonZeroUsize::new(20).expect("Is non-zero")
    }

    /// Default wait between pool-creation attempts: 5 seconds.
    fn default_retry_delay() -> Duration {
        Duration::from_millis(5_000)
    }

    /// Builds a config from the two required values, filling every other field with
    /// the same defaults that deserialization uses.
    #[must_use]
    pub fn with_default_values(connection_string: SecretString, schema: SanitizedSchema) -> Self {
        let max_connections = Self::default_max_connections();
        let acquire_timeout = Self::default_acquire_timeout();
        let slow_acquire_threshold = Self::default_slow_acquire_threshold();
        let max_retries = Self::default_max_retries();
        let retry_delay = Self::default_retry_delay();
        Self {
            connection_string,
            schema,
            max_connections,
            acquire_timeout,
            slow_acquire_threshold,
            max_retries,
            retry_delay,
        }
    }
}
#[must_use]
#[inline]
fn schema_connect_with_create(schema: &SanitizedSchema) -> String {
format!(
r#"
CREATE SCHEMA IF NOT EXISTS "{schema}";
SET search_path TO "{schema}";
"#
)
}
/// Builds the per-connection SQL that sets the session's `search_path` to `schema`.
///
/// It does not create the schema. The name is wrapped in double quotes; a
/// [`SanitizedSchema`] cannot contain `"`, so no escaping is required.
#[must_use]
fn schema_connect(schema: &SanitizedSchema) -> String {
    format!(
        r#"
SET search_path TO "{schema}";
"#
    )
}
/// Whether pool connections should create the configured schema when it is missing.
///
/// Consumed by `pg_pool_with_schema`: `Yes` runs `CREATE SCHEMA IF NOT EXISTS` before
/// setting the `search_path` on each new connection; `No` only sets the `search_path`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(clippy::exhaustive_enums, reason = "Is a boolean switch")]
pub enum CreateSchema {
    /// Create the schema (if it does not exist yet) on every new connection.
    Yes,
    /// Do not create the schema; only select it via `search_path`.
    No,
}
/// Creates a Postgres connection pool in which every newly opened connection is bound
/// to `config.schema`.
///
/// With [`CreateSchema::Yes`] each new connection first runs `CREATE SCHEMA IF NOT EXISTS`
/// for the schema; with [`CreateSchema::No`] it only sets the `search_path`. Pool sizing
/// and timeouts come from `config`. Pool creation is retried with a constant
/// `config.retry_delay` between attempts, bounded by `config.max_retries`, as long as
/// `is_retryable_error` accepts the failure.
///
/// # Errors
///
/// Returns the [`sqlx::Error`] of the final attempt when the error is not retryable or
/// the retries are exhausted.
pub async fn pg_pool_with_schema(
    config: &PostgresConfig,
    create_schema: CreateSchema,
) -> Result<PgPool, sqlx::Error> {
    // SQL executed on each freshly opened connection before the pool hands it out.
    let init_sql = match create_schema {
        CreateSchema::Yes => schema_connect_with_create(&config.schema),
        CreateSchema::No => schema_connect(&config.schema),
    };

    let retry_policy = ConstantBuilder::new()
        .with_delay(config.retry_delay)
        .with_max_times(config.max_retries.get())
        .build();

    let options = PgPoolOptions::new()
        .max_connections(config.max_connections.get())
        .acquire_timeout(config.acquire_timeout)
        .acquire_slow_threshold(config.slow_acquire_threshold)
        .after_connect(move |conn, _meta| {
            // Each invocation hands its own copy of the statement to the boxed future.
            let sql = init_sql.clone();
            Box::pin(async move {
                match conn.execute(sql.as_str()).await {
                    Ok(_) => Ok(()),
                    Err(e) => {
                        tracing::error!("error in after_connect: {:?}", e);
                        Err(e)
                    }
                }
            })
        });

    // One pool-creation attempt; the options are cloned because `connect` consumes them.
    let connect_once = || {
        options
            .clone()
            .connect(config.connection_string.expose_secret())
    };

    connect_once
        .retry(retry_policy)
        .sleep(tokio::time::sleep)
        .when(is_retryable_error)
        .notify(|e, duration| {
            tracing::warn!("Failed to create pool: {e:?}. Retry after {duration:?}");
        })
        .await
}
#[inline]
fn is_retryable_error(e: &sqlx::Error) -> bool {
matches!(
e,
sqlx::Error::PoolTimedOut
| sqlx::Error::Io(_)
| sqlx::Error::Tls(_)
| sqlx::Error::Protocol(_)
| sqlx::Error::AnyDriverError(_)
| sqlx::Error::WorkerCrashed
)
}