use std::path::PathBuf;
use std::time::Duration;
use crate::lsn::Lsn;
/// TLS policy for the server connection.
///
/// The variant names mirror libpq's `sslmode` settings (minus `allow`); the
/// predicate methods below define exactly what each mode demands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// TLS is not used (by name); this is the default mode.
    #[default]
    Disable,
    /// TLS is optional: [`SslMode::requires_tls`] is `false` for this mode.
    Prefer,
    /// TLS is mandatory, but neither the certificate nor the hostname is verified.
    Require,
    /// TLS is mandatory and the server certificate is verified; the hostname is not.
    VerifyCa,
    /// TLS is mandatory and both the server certificate and hostname are verified.
    VerifyFull,
}

impl SslMode {
    /// Returns `true` for the modes that mandate TLS:
    /// `Require`, `VerifyCa` and `VerifyFull`.
    #[inline]
    pub fn requires_tls(&self) -> bool {
        match self {
            SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => true,
            SslMode::Disable | SslMode::Prefer => false,
        }
    }

    /// Returns `true` for the modes that check the server certificate:
    /// `VerifyCa` and `VerifyFull`.
    #[inline]
    pub fn verifies_certificate(&self) -> bool {
        match self {
            SslMode::VerifyCa | SslMode::VerifyFull => true,
            SslMode::Disable | SslMode::Prefer | SslMode::Require => false,
        }
    }

    /// Returns `true` only for `VerifyFull`, the single mode that also
    /// checks the server hostname.
    #[inline]
    pub fn verifies_hostname(&self) -> bool {
        *self == SslMode::VerifyFull
    }
}
/// TLS settings for the server connection.
///
/// The derived `Default` is TLS disabled with every optional path/name unset.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TlsConfig {
    /// Negotiation policy; defaults to [`SslMode::Disable`].
    pub mode: SslMode,
    /// Optional CA certificate file (PEM, going by the field name); presumably
    /// only consulted by the `Verify*` modes — confirm at the connector.
    pub ca_pem_path: Option<PathBuf>,
    /// Optional hostname override for SNI; `None` presumably means the
    /// connection host is used — confirm at the connector.
    pub sni_hostname: Option<String>,
    /// Client certificate for mutual TLS; see [`TlsConfig::is_mtls`].
    pub client_cert_pem_path: Option<PathBuf>,
    /// Client private key for mutual TLS; see [`TlsConfig::is_mtls`].
    pub client_key_pem_path: Option<PathBuf>,
}

impl TlsConfig {
    /// Shared constructor: `mode` plus an optional CA path, all else default.
    fn from_mode_and_ca(mode: SslMode, ca_pem_path: Option<PathBuf>) -> Self {
        Self {
            mode,
            ca_pem_path,
            ..Self::default()
        }
    }

    /// TLS turned off; identical to [`TlsConfig::default`].
    pub fn disabled() -> Self {
        Self::default()
    }

    /// [`SslMode::Require`] with no CA, SNI override or client certificate.
    pub fn require() -> Self {
        Self::from_mode_and_ca(SslMode::Require, None)
    }

    /// [`SslMode::VerifyCa`] using `ca_path` (if any) as the CA file.
    pub fn verify_ca(ca_path: Option<PathBuf>) -> Self {
        Self::from_mode_and_ca(SslMode::VerifyCa, ca_path)
    }

    /// [`SslMode::VerifyFull`] using `ca_path` (if any) as the CA file.
    pub fn verify_full(ca_path: Option<PathBuf>) -> Self {
        Self::from_mode_and_ca(SslMode::VerifyFull, ca_path)
    }

    /// Builder: sets the SNI hostname override, replacing any previous one.
    pub fn with_sni_hostname(self, hostname: impl Into<String>) -> Self {
        Self {
            sni_hostname: Some(hostname.into()),
            ..self
        }
    }

    /// Builder: sets both the client certificate and private key paths,
    /// which makes [`TlsConfig::is_mtls`] return `true`.
    pub fn with_client_cert(
        self,
        cert_path: impl Into<PathBuf>,
        key_path: impl Into<PathBuf>,
    ) -> Self {
        Self {
            client_cert_pem_path: Some(cert_path.into()),
            client_key_pem_path: Some(key_path.into()),
            ..self
        }
    }

    /// Returns `true` only when both a client certificate and a client key
    /// are configured; a lone certificate or key does not count.
    #[inline]
    pub fn is_mtls(&self) -> bool {
        matches!(
            (&self.client_cert_pem_path, &self.client_key_pem_path),
            (Some(_), Some(_))
        )
    }
}
/// Everything needed to open a logical replication stream: connection target
/// and credentials, TLS policy, slot/publication names, the LSN window, and
/// client-side timing/buffering knobs.
///
/// `Debug` is implemented by hand rather than derived so that `password` is
/// redacted. A derived impl would print the secret verbatim into any log line
/// that formats the config with `{:?}`.
#[derive(Clone, PartialEq, Eq)]
pub struct ReplicationConfig {
    /// Server host name or address, or, when it starts with `/`, the
    /// directory that contains the server's Unix-domain socket.
    pub host: String,
    /// Server port; for Unix sockets it also selects the `.s.PGSQL.<port>` file.
    pub port: u16,
    /// Role name used to authenticate.
    pub user: String,
    /// Password for `user`; never shown by the `Debug` output.
    pub password: String,
    /// Database to connect to.
    pub database: String,
    /// TLS settings for the connection.
    pub tls: TlsConfig,
    /// Replication slot name.
    pub slot: String,
    /// Publication name.
    pub publication: String,
    /// LSN at which streaming starts.
    pub start_lsn: Lsn,
    /// Optional LSN at which streaming should stop (`None` = no stop point).
    pub stop_at_lsn: Option<Lsn>,
    /// Interval between status updates (presumably standby status feedback
    /// to the server; confirm at the use site).
    pub status_interval: Duration,
    /// Idle wake-up interval (presumably how long the reader waits without
    /// traffic before waking; confirm at the use site).
    pub idle_wakeup_interval: Duration,
    /// Event buffer capacity (presumably a channel bound; confirm at the use site).
    pub buffer_events: usize,
}

impl std::fmt::Debug for ReplicationConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReplicationConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            // Redacted with the same `***` mask that `display_connection` uses.
            .field("password", &"***")
            .field("database", &self.database)
            .field("tls", &self.tls)
            .field("slot", &self.slot)
            .field("publication", &self.publication)
            .field("start_lsn", &self.start_lsn)
            .field("stop_at_lsn", &self.stop_at_lsn)
            .field("status_interval", &self.status_interval)
            .field("idle_wakeup_interval", &self.idle_wakeup_interval)
            .field("buffer_events", &self.buffer_events)
            .finish()
    }
}
impl Default for ReplicationConfig {
    /// Defaults aimed at a local server: user/password `postgres`/`postgres`
    /// on `127.0.0.1:5432`, database `postgres`, slot `slot`, publication
    /// `pub`, TLS disabled, streaming from LSN 0 with no stop point, 10 s
    /// status and wake-up intervals, and an 8192-event buffer.
    fn default() -> Self {
        let ten_seconds = Duration::from_secs(10);
        Self {
            host: String::from("127.0.0.1"),
            port: 5432,
            user: String::from("postgres"),
            password: String::from("postgres"),
            database: String::from("postgres"),
            tls: TlsConfig::disabled(),
            slot: String::from("slot"),
            publication: String::from("pub"),
            start_lsn: Lsn(0),
            stop_at_lsn: None,
            status_interval: ten_seconds,
            idle_wakeup_interval: ten_seconds,
            buffer_events: 8_192,
        }
    }
}
impl ReplicationConfig {
pub fn new(
host: impl Into<String>,
user: impl Into<String>,
password: impl Into<String>,
database: impl Into<String>,
slot: impl Into<String>,
publication: impl Into<String>,
) -> Self {
Self {
host: host.into(),
user: user.into(),
password: password.into(),
database: database.into(),
slot: slot.into(),
publication: publication.into(),
..Default::default()
}
}
#[inline]
pub fn is_unix_socket(&self) -> bool {
self.host.starts_with('/')
}
pub fn unix_socket_path(&self) -> std::path::PathBuf {
assert!(
self.is_unix_socket(),
"unix_socket_path() called but host is not a socket directory: {:?}",
self.host
);
std::path::Path::new(&self.host).join(format!(".s.PGSQL.{}", self.port))
}
pub fn unix(
socket_dir: impl Into<String>,
port: u16,
user: impl Into<String>,
password: impl Into<String>,
database: impl Into<String>,
slot: impl Into<String>,
publication: impl Into<String>,
) -> Self {
Self {
host: socket_dir.into(),
port,
user: user.into(),
password: password.into(),
database: database.into(),
tls: TlsConfig::disabled(),
slot: slot.into(),
publication: publication.into(),
..Default::default()
}
}
pub fn with_port(mut self, port: u16) -> Self {
self.port = port;
self
}
pub fn with_tls(mut self, tls: TlsConfig) -> Self {
self.tls = tls;
self
}
pub fn with_start_lsn(mut self, lsn: Lsn) -> Self {
self.start_lsn = lsn;
self
}
pub fn with_stop_lsn(mut self, lsn: Lsn) -> Self {
self.stop_at_lsn = Some(lsn);
self
}
pub fn with_status_interval(mut self, interval: Duration) -> Self {
self.status_interval = interval;
self
}
pub fn with_wakeup_interval(mut self, timeout: Duration) -> Self {
self.idle_wakeup_interval = timeout;
self
}
pub fn with_buffer_size(mut self, size: usize) -> Self {
self.buffer_events = size;
self
}
pub fn display_connection(&self) -> String {
if self.is_unix_socket() {
format!(
"postgresql://{}:***@[{}]:{}/{}",
self.user,
self.unix_socket_path().display(),
self.port,
self.database
)
} else {
format!(
"postgresql://{}:***@{}:{}/{}",
self.user, self.host, self.port, self.database
)
}
}
}