pub mod analyze;
pub mod partitioned_index;
use crate::Config;
use crate::config::{Endpoint, SslMode};
/// sqlx-specific connection settings carried in the `sqlx` field of [`Config`].
///
/// Every field is optional. A `None` field is simply not applied when the
/// config is converted with [`Config::to_sqlx_connect_options`], so the
/// corresponding `PgConnectOptions` value stays at whatever sqlx started with.
#[derive(Default, Debug, Clone, Eq, PartialEq)]
pub struct Settings {
    /// Capacity of the connection's statement cache, if overridden.
    pub statement_cache_capacity: Option<usize>,
    /// Log level used for every executed statement, if overridden.
    pub log_statements: Option<log::LevelFilter>,
    /// Log level plus duration threshold for slow-statement logging, if overridden.
    pub log_slow_statements: Option<(log::LevelFilter, std::time::Duration)>,
}
/// Reasons a [`Config`] cannot be expressed as sqlx `PgConnectOptions`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OptionsError {
    /// The config sets a field that has no sqlx equivalent; `field_name` names it.
    UnsupportedFeature { field_name: String },
    /// The config selects `SslRootCert::System`, but sqlx expects a file path
    /// for `ssl_root_cert`.
    SslRootCertSystemNotSupported,
}

impl std::error::Error for OptionsError {}

impl std::fmt::Display for OptionsError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed message: no interpolation needed.
            Self::SslRootCertSystemNotSupported => formatter.write_str(
                "`SslRootCert::System` is not supported by sqlx, which expects a file path for `ssl_root_cert`",
            ),
            Self::UnsupportedFeature { field_name } => write!(
                formatter,
                "`pg_client::Config` specifies `{field_name}`, but sqlx's `PgConnectOptions` does not support that feature"
            ),
        }
    }
}
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
#[error("Failed to create SQLx connect options")]
Options(#[from] OptionsError),
#[error("Failed to connect to database")]
Connect(#[source] sqlx::Error),
#[error("Failed to close database connection")]
Close(#[source] sqlx::Error),
}
impl From<&SslMode> for sqlx::postgres::PgSslMode {
    /// Maps each [`SslMode`] variant onto the identically named sqlx variant.
    fn from(mode: &SslMode) -> Self {
        use sqlx::postgres::PgSslMode as Pg;
        // Matching on the place `*mode` with binding-free patterns needs no move or copy.
        match *mode {
            SslMode::Disable => Pg::Disable,
            SslMode::Allow => Pg::Allow,
            SslMode::Prefer => Pg::Prefer,
            SslMode::Require => Pg::Require,
            SslMode::VerifyCa => Pg::VerifyCa,
            SslMode::VerifyFull => Pg::VerifyFull,
        }
    }
}
impl Config {
    /// Translates this configuration into sqlx [`PgConnectOptions`](sqlx::postgres::PgConnectOptions).
    ///
    /// Starts from `PgConnectOptions::default_without_env()` and applies the
    /// database, endpoint, SSL settings, session credentials and the optional
    /// sqlx [`Settings`] of `self` on top of it.
    ///
    /// # Errors
    ///
    /// - [`OptionsError::UnsupportedFeature`] if a network endpoint sets
    ///   `channel_binding`, which sqlx cannot express.
    /// - [`OptionsError::SslRootCertSystemNotSupported`] if `ssl_root_cert` is
    ///   `SslRootCert::System`; sqlx only accepts a certificate file.
    pub fn to_sqlx_connect_options(
        &self,
    ) -> Result<sqlx::postgres::PgConnectOptions, OptionsError> {
        let mut options = sqlx::postgres::PgConnectOptions::default_without_env();
        options = options.database(self.session.database.as_str());
        match &self.endpoint {
            Endpoint::Network {
                host,
                channel_binding,
                host_addr,
                port,
            } => {
                // Reject unsupported settings before doing any other work.
                if channel_binding.is_some() {
                    return Err(OptionsError::UnsupportedFeature {
                        field_name: "channel_binding".to_string(),
                    });
                }
                options = options.host(&host.pg_env_value());
                if let Some(port) = port {
                    options = options.port(port.into());
                }
                if let Some(host_addr) = host_addr {
                    options = options.host_addr(&host_addr.to_string());
                }
            }
            Endpoint::SocketPath(path) => {
                // sqlx treats a `host` starting with `/` as a Unix-socket
                // directory, but `host` only takes `&str`. A path that is not
                // valid UTF-8 goes through `socket` (which takes any path)
                // instead of panicking on user-supplied configuration.
                options = match path.to_str() {
                    Some(utf8_path) => options.host(utf8_path),
                    None => options.socket(path),
                };
            }
        }
        options = options.ssl_mode((&self.ssl_mode).into());
        options = options.username(self.session.user.as_str());
        if let Some(application_name) = &self.session.application_name {
            options = options.application_name(application_name.as_str());
        }
        if let Some(password) = &self.session.password {
            options = options.password(password.as_str());
        }
        if let Some(ssl_root_cert) = &self.ssl_root_cert {
            match ssl_root_cert {
                crate::config::SslRootCert::File(path) => {
                    // sqlx's `ssl_root_cert` accepts `impl AsRef<Path>`, so the
                    // path is forwarded as-is; no UTF-8 conversion (or panic) needed.
                    options = options.ssl_root_cert(path);
                }
                crate::config::SslRootCert::System => {
                    return Err(OptionsError::SslRootCertSystemNotSupported);
                }
            }
        }
        if let Some(capacity) = self.sqlx.statement_cache_capacity {
            options = options.statement_cache_capacity(capacity);
        }
        if let Some(level) = self.sqlx.log_statements {
            options = sqlx::ConnectOptions::log_statements(options, level);
        }
        if let Some((level, duration)) = self.sqlx.log_slow_statements {
            options = sqlx::ConnectOptions::log_slow_statements(options, level, duration);
        }
        Ok(options)
    }

    /// Opens one sqlx connection, runs `action` on it once, then closes it.
    ///
    /// `action` only needs to be callable once, so any `AsyncFnOnce` (including
    /// every `AsyncFnMut`/`AsyncFn`) is accepted. The connection is closed
    /// explicitly after `action` finishes. If closing fails, that error is
    /// returned and the action's result is discarded.
    ///
    /// # Errors
    ///
    /// - [`ConnectionError::Options`] if the config cannot be converted
    ///   (see [`Config::to_sqlx_connect_options`]).
    /// - [`ConnectionError::Connect`] if the connection cannot be opened.
    /// - [`ConnectionError::Close`] if closing the connection fails.
    pub async fn with_sqlx_connection<T, F: AsyncFnOnce(&mut sqlx::postgres::PgConnection) -> T>(
        &self,
        action: F,
    ) -> Result<T, ConnectionError> {
        let config = self.to_sqlx_connect_options()?;
        let mut connection = sqlx::ConnectOptions::connect(&config)
            .await
            .map_err(ConnectionError::Connect)?;
        let result = action(&mut connection).await;
        sqlx::Connection::close(connection)
            .await
            .map_err(ConnectionError::Close)?;
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{Endpoint, Host, Port, SslMode, SslRootCert};
    use crate::{Database, User};
    use std::str::FromStr;

    const TEST_DATABASE: Database = Database::from_static_or_panic("some-database");
    const TEST_USER: User = User::from_static_or_panic("some-user");

    /// Builds a TCP `localhost:5432` config with SSL disabled and the given sqlx settings.
    fn test_config(sqlx: Settings) -> Config {
        let endpoint = Endpoint::Network {
            host: Host::from_str("localhost").unwrap(),
            channel_binding: None,
            host_addr: None,
            port: Some(Port::new(5432)),
        };
        let session = crate::config::Session {
            application_name: None,
            database: TEST_DATABASE,
            password: None,
            user: TEST_USER,
        };
        Config {
            endpoint,
            session,
            ssl_mode: SslMode::Disable,
            ssl_root_cert: None,
            sqlx,
        }
    }

    /// Converts `test_config(sqlx)` into sqlx options and renders them with `Debug`,
    /// which is the only way these tests can observe the resulting fields.
    fn rendered_options(sqlx: Settings) -> String {
        let options = test_config(sqlx).to_sqlx_connect_options().unwrap();
        format!("{options:?}")
    }

    #[test]
    fn test_statement_cache_capacity_default() {
        let debug = rendered_options(Settings::default());
        assert!(
            debug.contains("statement_cache_capacity: 100"),
            "Expected default statement_cache_capacity of 100, got: {debug}"
        );
    }

    #[test]
    fn test_statement_cache_capacity_override() {
        let debug = rendered_options(Settings {
            statement_cache_capacity: Some(42),
            ..Settings::default()
        });
        assert!(
            debug.contains("statement_cache_capacity: 42"),
            "Expected statement_cache_capacity of 42, got: {debug}"
        );
    }

    #[test]
    fn test_log_statements_override() {
        let debug = rendered_options(Settings {
            log_statements: Some(log::LevelFilter::Off),
            ..Settings::default()
        });
        assert!(
            debug.contains("statements_level: Off"),
            "Expected statements_level: Off, got: {debug}"
        );
    }

    #[test]
    fn test_log_slow_statements_override() {
        let threshold = std::time::Duration::from_secs(5);
        let debug = rendered_options(Settings {
            log_slow_statements: Some((log::LevelFilter::Warn, threshold)),
            ..Settings::default()
        });
        assert!(
            debug.contains("slow_statements_level: Warn"),
            "Expected slow_statements_level: Warn, got: {debug}"
        );
        assert!(
            debug.contains("slow_statements_duration: 5s"),
            "Expected slow_statements_duration: 5s, got: {debug}"
        );
    }

    #[test]
    fn test_ssl_root_cert_system_not_supported() {
        let result = Config {
            ssl_mode: SslMode::VerifyFull,
            ssl_root_cert: Some(SslRootCert::System),
            ..test_config(Settings::default())
        }
        .to_sqlx_connect_options();
        assert!(matches!(
            result,
            Err(OptionsError::SslRootCertSystemNotSupported)
        ));
    }
}