use crate::tokio_postgres::{Config, MakeTlsConnect, NoTls, Socket};
use either::Either;
#[derive(Clone)]
pub(crate) enum ConnectionParameters {
ConnectionStr(Box<str>),
TokioPostgresConfig(Box<Config>),
}
pub struct PgPubSubOptionsBuilder {
pub(crate) connection_params: ConnectionParameters,
pub(crate) channel_capacity: usize,
pub(crate) suppress_own_notifications: bool,
}
#[derive(Clone)]
pub struct PgPubSubOptions<T: MakeTlsConnect<Socket> + Clone> {
pub(crate) connection_params: ConnectionParameters,
pub(crate) channel_capacity: usize,
pub(crate) suppress_own_notifications: bool,
pub(crate) tls: T,
}
impl PgPubSubOptionsBuilder {
pub fn new(host: &str, dbname: &str, user: &str, password: &str) -> Self {
let connection_str = Self::build_connection_string(host, dbname, user, password);
Self::from_connection_str(&connection_str)
}
pub fn from_connection_str(connection_str: &str) -> Self {
let cfg = ConnectionParameters::ConnectionStr(connection_str.into());
Self::from_connection_params(cfg)
}
pub fn from_tokio_postgres_config(config: Config) -> Self {
let cfg = ConnectionParameters::TokioPostgresConfig(Box::new(config));
Self::from_connection_params(cfg)
}
fn from_connection_params(connection_params: ConnectionParameters) -> Self {
Self {
connection_params,
channel_capacity: 32,
suppress_own_notifications: false,
}
}
pub fn channel_capacity(self, channel_capacity: usize) -> Self {
Self {
channel_capacity,
..self
}
}
pub fn suppress_own_notifications(self, suppress_own_notifications: bool) -> Self {
Self {
suppress_own_notifications,
..self
}
}
fn build_connection_string(host: &str, dbname: &str, user: &str, password: &str) -> String {
format!(
"host={host} dbname={dbname} user={user} password={password}",
host = LibpqValue::from_str(host),
dbname = LibpqValue::from_str(dbname),
user = LibpqValue::from_str(user),
password = LibpqValue::from_str(password),
)
}
pub fn build(self) -> PgPubSubOptions<NoTls> {
self.build_with_tls(NoTls)
}
pub fn build_with_tls<T: MakeTlsConnect<Socket> + Clone>(self, tls: T) -> PgPubSubOptions<T> {
PgPubSubOptions {
connection_params: self.connection_params,
channel_capacity: self.channel_capacity,
tls,
suppress_own_notifications: self.suppress_own_notifications,
}
}
}
struct LibpqValue<'a> {
value: Either<String, &'a str>,
}
impl<'a> LibpqValue<'a> {
#[inline(never)]
fn from_str(input: &'a str) -> Self {
let mut contains_spaces = false;
let mut escape_count = 0;
for ch in input.chars() {
contains_spaces |= ch == ' ';
escape_count += (ch == '\\' || ch == '\'') as usize;
}
if escape_count == 0 && !contains_spaces {
return LibpqValue {
value: Either::Right(input),
};
}
let output_len = input
.len()
.checked_add(escape_count)
.and_then(|len| len.checked_add(2 * contains_spaces as usize))
.expect("Escaped String will exceed the maximum length");
let mut output = String::with_capacity(output_len);
let quote = contains_spaces || input.is_empty();
if quote {
output.push('\''); }
if escape_count == 0 {
output.push_str(input);
} else {
for ch in input.chars() {
if ch == '\\' || ch == '\'' {
output.push('\\');
output.push(ch);
} else {
output.push(ch);
}
}
}
if quote {
output.push('\''); }
debug_assert_eq!(output.len(), output_len);
LibpqValue {
value: Either::Left(output),
}
}
pub fn as_str(&'a self) -> &'a str {
match &self.value {
Either::Left(s) => s,
Either::Right(s) => s,
}
}
}
impl From<LibpqValue<'_>> for String {
fn from(v: LibpqValue<'_>) -> Self {
match v.value {
Either::Left(s) => s,
Either::Right(s) => s.to_owned(),
}
}
}
impl std::fmt::Display for LibpqValue<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{s}", s = self.as_str())
}
}
#[cfg(test)]
mod test {
use std::ptr;
use super::*;
#[test]
fn lib_pq_value_fmt_unescaped_and_unquoted_shares_memory() {
let input = "secret123";
let v = LibpqValue::from_str(input);
let output = v.as_str();
assert!(ptr::eq(input, output))
}
#[test]
fn lib_pq_value_fmt_escaped() {
let input = r#"secret\123"#;
let v = LibpqValue::from_str(input);
let output = v.as_str();
assert!(!ptr::eq(input, output));
assert_eq!(output, r#"secret\\123"#);
}
#[test]
fn lib_pq_value_fmt_quoted() {
let input = "secret 123";
let v = LibpqValue::from_str(input);
let output = v.as_str();
assert!(!ptr::eq(input, output));
assert_eq!(output, "'secret 123'");
}
#[test]
fn format_libpq_string() {
let host = r#"\\PGHOST\"#;
let dbname = "databasename";
let user = "user";
let password = r#"1j( \'9f"#;
let con_str = PgPubSubOptionsBuilder::build_connection_string(host, dbname, user, password);
let expected = r#"host=\\\\PGHOST\\ dbname=databasename user=user password='1j( \\\'9f'"#;
assert_eq!(con_str, expected);
}
}