//! pgpubsub 2.0.0
//!
//! Async PostgreSQL LISTEN/NOTIFY pub/sub client built on tokio-postgres.
use crate::tokio_postgres::{Config, MakeTlsConnect, NoTls, Socket};
use either::Either;

/// The source of connection settings used to reach PostgreSQL.
///
/// Either a connection string that is kept exactly as the caller gave it (key-value or URL
/// form), or a `tokio_postgres::Config` the caller already built.
#[derive(Clone)]
pub(crate) enum ConnectionParameters {
    /// A connection string, stored verbatim and not parsed at construction time.
    ConnectionStr(Box<str>),
    /// A caller-supplied `Config`, held in a `Box`.
    TokioPostgresConfig(Box<Config>),
}

/// Builder for [`PgPubSubOptions`].
///
/// Start from `new`, `from_connection_str` or `from_tokio_postgres_config`, adjust the
/// settings with the chained setters, then finish with `build` or `build_with_tls`.
/// Every setter consumes the builder and returns a new one, so discarding a returned
/// builder silently loses the setting — hence `#[must_use]`.
#[must_use = "a PgPubSubOptionsBuilder does nothing until `build` or `build_with_tls` is called"]
pub struct PgPubSubOptionsBuilder {
    /// Connection string or `Config` used to reach the database.
    pub(crate) connection_params: ConnectionParameters,
    /// Per-channel broadcast capacity (defaults to 32; the setter clamps it to at least 1).
    pub(crate) channel_capacity: usize,
    /// Whether notifications sent by this client are suppressed (defaults to `false`).
    pub(crate) suppress_own_notifications: bool,
}

/// Finished options produced by `PgPubSubOptionsBuilder::build` or `build_with_tls`.
///
/// Carries the connection parameters, the per-channel notification capacity, the
/// own-notification suppression flag and the TLS connector factory `T`.
#[derive(Clone)]
pub struct PgPubSubOptions<T>
where
    T: MakeTlsConnect<Socket> + Clone,
{
    /// Connection string or `Config` used to reach the database.
    pub(crate) connection_params: ConnectionParameters,
    /// Capacity of each broadcast channel that delivers notifications.
    pub(crate) channel_capacity: usize,
    /// Whether notifications sent by this client are suppressed.
    pub(crate) suppress_own_notifications: bool,
    /// The `MakeTlsConnect` implementation that was passed to `build_with_tls`.
    pub(crate) tls: T,
}

impl PgPubSubOptionsBuilder {
    /// Configuration for connecting to PostgreSQL with the given parameters.
    ///
    /// The four values are escaped and quoted as needed, then assembled into a key-value
    /// connection string. They may therefore be empty or contain whitespace, single
    /// quotes or backslashes.
    pub fn new(host: &str, dbname: &str, user: &str, password: &str) -> Self {
        let connection_str = Self::build_connection_string(host, dbname, user, password);
        Self::from_connection_str(&connection_str)
    }

    /// Configuration for connecting to PostgreSQL with the given connection string as-is.
    ///
    /// The string may use either of two formats:
    /// 1) `host=localhost dbname=name_of_database user=name_of_user password=user_password`
    /// 2) `postgresql:///mydb?user=user&host=/var/lib/postgresql`
    ///
    /// The string is stored verbatim; it is not parsed or validated here.
    ///
    /// See: <https://docs.rs/tokio-postgres/0.7.8/tokio_postgres/config/struct.Config.html>
    pub fn from_connection_str(connection_str: &str) -> Self {
        let cfg = ConnectionParameters::ConnectionStr(connection_str.into());
        Self::from_connection_params(cfg)
    }

    /// Configuration for connecting to PostgreSQL with the given `tokio_postgres::Config`.
    ///
    /// The `Config` type is re-exported by this lib as `pgpubsub::tokio_postgres::Config`.
    ///
    /// See: <https://docs.rs/tokio-postgres/0.7.8/tokio_postgres/config/struct.Config.html>
    pub fn from_tokio_postgres_config(config: Config) -> Self {
        let cfg = ConnectionParameters::TokioPostgresConfig(Box::new(config));
        Self::from_connection_params(cfg)
    }

    /// Shared constructor. It applies the defaults: channel capacity 32, and own
    /// notifications not suppressed.
    fn from_connection_params(connection_params: ConnectionParameters) -> Self {
        Self {
            connection_params,
            channel_capacity: 32,
            suppress_own_notifications: false,
        }
    }

    /// Sets the per-channel capacity of the broadcast channel that delivers notifications.
    /// Defaults to 32.
    ///
    /// A capacity of 0 is treated as 1 (`tokio`'s broadcast channel panics on a zero
    /// capacity, and a subscription that can never receive anything is meaningless).
    pub fn channel_capacity(self, channel_capacity: usize) -> Self {
        Self {
            channel_capacity: channel_capacity.max(1),
            ..self
        }
    }

    /// Sets whether notifications sent by this client itself are suppressed.
    ///
    /// - `true`: own notifications are suppressed.
    /// - `false`: own notifications are received like any other.
    ///
    /// Defaults to `false`, so own notifications are received.
    pub fn suppress_own_notifications(self, suppress_own_notifications: bool) -> Self {
        Self {
            suppress_own_notifications,
            ..self
        }
    }

    /// Builds a key-value connection string from individual parameters.
    fn build_connection_string(host: &str, dbname: &str, user: &str, password: &str) -> String {
        // This format consists of key-value pairs separated by whitespace. Each value is
        // rendered by `LibpqValue` as follows:
        // - An empty value is written as ''.
        // - A value containing any whitespace is wrapped in single quotes.
        // - Every `\` and `'` is backslash-escaped.
        // https://docs.rs/tokio-postgres/0.7.8/tokio_postgres/config/struct.Config.html
        format!(
            "host={host} dbname={dbname} user={user} password={password}",
            host = LibpqValue::from_str(host),
            dbname = LibpqValue::from_str(dbname),
            user = LibpqValue::from_str(user),
            password = LibpqValue::from_str(password),
        )
    }

    /// Builds the options without TLS, using `NoTls`.
    pub fn build(self) -> PgPubSubOptions<NoTls> {
        self.build_with_tls(NoTls)
    }

    /// Build with the given Tls option given as a `tokio_postgres::tls::MakeTlsConnect<Socket>`.
    /// All useful options here should be re-exported by this crate, for example
    /// `pgpubsub::tokio_postgres::NoTls`.
    pub fn build_with_tls<T: MakeTlsConnect<Socket> + Clone>(self, tls: T) -> PgPubSubOptions<T> {
        PgPubSubOptions {
            connection_params: self.connection_params,
            channel_capacity: self.channel_capacity,
            tls,
            suppress_own_notifications: self.suppress_own_notifications,
        }
    }
}

/// One value rendered for the libpq key-value connection-string syntax.
///
/// The value is one of two variants:
/// - `Right`: the caller's input, borrowed unchanged, when no quoting or escaping was needed.
/// - `Left`: a newly allocated escaped and/or quoted copy.
struct LibpqValue<'a> {
    value: Either<String, &'a str>,
}

impl<'a> LibpqValue<'a> {
    /// Renders `input` as one value of a libpq key-value connection string.
    ///
    /// The rendering rules are:
    /// - An empty input becomes `''`, an explicit empty value.
    /// - If the input contains any whitespace character, it is wrapped in single quotes.
    /// - Every `\` and `'` is prefixed with a backslash.
    /// - If none of these apply, the input is borrowed as-is and nothing is allocated.
    ///
    /// # Panics
    ///
    /// Panics if the escaped length would overflow `usize`.
    fn from_str(input: &'a str) -> Self {
        // Without quotes, libpq would treat a missing value as "key absent" rather than
        // as an explicit empty string.
        if input.is_empty() {
            return LibpqValue {
                value: Either::Left("''".to_owned()),
            };
        }

        let must_escape = |c: char| c == '\\' || c == '\'';

        // The parser treats any whitespace (not only ' ') as a key-value separator, so
        // every kind of whitespace forces the value into quotes.
        let quoted = input.chars().any(char::is_whitespace);
        let escapes = input.chars().filter(|&c| must_escape(c)).count();

        if !quoted && escapes == 0 {
            return LibpqValue {
                value: Either::Right(input),
            };
        }

        let quote_bytes = if quoted { 2 } else { 0 };
        let expected_len = input
            .len()
            .checked_add(escapes)
            .and_then(|n| n.checked_add(quote_bytes))
            .expect("Escaped String will exceed the maximum length");

        let mut rendered = String::with_capacity(expected_len);
        if quoted {
            rendered.push('\'');
        }
        for c in input.chars() {
            if must_escape(c) {
                rendered.push('\\');
            }
            rendered.push(c);
        }
        if quoted {
            rendered.push('\'');
        }

        debug_assert_eq!(rendered.len(), expected_len);

        LibpqValue {
            value: Either::Left(rendered),
        }
    }

    /// Returns the rendered value, whether it was borrowed or newly allocated.
    fn as_str(&self) -> &str {
        match self.value {
            Either::Left(ref owned) => owned.as_str(),
            Either::Right(borrowed) => borrowed,
        }
    }
}

impl std::fmt::Display for LibpqValue<'_> {
    /// Writes the rendered value exactly as it should appear after `key=`.
    ///
    /// Uses `write_str` directly, so width, fill and precision flags are ignored.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = self.as_str();
        formatter.write_str(rendered)
    }
}

#[cfg(test)]
mod test {
    use std::ptr;

    use super::*;

    #[test]
    fn lib_pq_value_fmt_unescaped_and_unquoted_shares_memory() {
        let input = "secret123";
        let v = LibpqValue::from_str(input);
        let output = v.as_str();
        // input and output should share the same data if there was no escaping or quoting.
        assert!(ptr::eq(input, output))
    }

    #[test]
    fn lib_pq_value_fmt_escaped() {
        let input = r#"secret\123"#;
        let v = LibpqValue::from_str(input);
        let output = v.as_str();
        // input and output do not share the same data since the string had to be escaped.
        assert!(!ptr::eq(input, output));
        // The backslash should have been escaped so that there are now double backslashes.
        assert_eq!(output, r#"secret\\123"#);
    }

    #[test]
    fn lib_pq_value_fmt_quoted() {
        let input = "secret 123";
        let v = LibpqValue::from_str(input);
        let output = v.as_str();
        // input and output do not share the same data since the string had to be escaped.
        assert!(!ptr::eq(input, output));
        // The string should be enclosed in single quotes.
        assert_eq!(output, "'secret 123'");
    }

    #[test]
    fn lib_pq_value_quotes_all_whitespace_not_just_spaces() {
        // Tabs and newlines are also key-value separators for the libpq-style parser;
        // leaving them unquoted would split the value.
        assert_eq!(LibpqValue::from_str("a\tb").as_str(), "'a\tb'");
        assert_eq!(LibpqValue::from_str("a\nb").as_str(), "'a\nb'");
        assert_eq!(LibpqValue::from_str("a\u{a0}b").as_str(), "'a\u{a0}b'");
    }

    #[test]
    fn lib_pq_value_fmt_empty_is_quoted() {
        let v = LibpqValue::from_str("");
        // Empty values must be wrapped in '' so libpq parses them as explicit empty rather
        // than treating the key as missing.
        assert_eq!(v.as_str(), "''");
    }

    #[test]
    fn lib_pq_value_leading_quote_is_escaped_not_quoted() {
        // A value that merely *starts* with `'` (and has no whitespace) is not wrapped in
        // quotes, so only the escape stops the parser from entering quoted-value mode.
        assert_eq!(LibpqValue::from_str("'abc").as_str(), r"\'abc");
    }

    #[test]
    fn zero_channel_capacity_is_clamped_to_one() {
        // tokio's broadcast::channel panics on capacity 0; the clamp keeps that panic
        // from surfacing later, deep inside the first listen() call.
        let builder = PgPubSubOptionsBuilder::new("h", "d", "u", "p").channel_capacity(0);
        assert_eq!(builder.channel_capacity, 1);
    }

    #[test]
    fn nonzero_channel_capacity_is_kept() {
        let builder = PgPubSubOptionsBuilder::new("h", "d", "u", "p").channel_capacity(16);
        assert_eq!(builder.channel_capacity, 16);
    }

    #[test]
    fn builder_defaults_and_suppress_setter() {
        let builder = PgPubSubOptionsBuilder::new("h", "d", "u", "p");
        assert_eq!(builder.channel_capacity, 32);
        assert!(!builder.suppress_own_notifications);
        let builder = builder.suppress_own_notifications(true);
        assert!(builder.suppress_own_notifications);
    }

    #[test]
    fn from_connection_str_keeps_string_verbatim() {
        let raw = "postgresql:///mydb?user=user&host=/var/lib/postgresql";
        let builder = PgPubSubOptionsBuilder::from_connection_str(raw);
        assert!(matches!(
            builder.connection_params,
            ConnectionParameters::ConnectionStr(ref s) if &**s == raw
        ));
    }

    #[test]
    fn build_carries_over_builder_settings() {
        let options = PgPubSubOptionsBuilder::new("h", "d", "u", "p")
            .channel_capacity(7)
            .suppress_own_notifications(true)
            .build();
        assert_eq!(options.channel_capacity, 7);
        assert!(options.suppress_own_notifications);
    }

    #[test]
    fn format_libpq_string() {
        let host = r#"\\PGHOST\"#;
        let dbname = "databasename";
        let user = "user";
        let password = r#"1j( \'9f"#;
        let con_str = PgPubSubOptionsBuilder::build_connection_string(host, dbname, user, password);

        let expected = r#"host=\\\\PGHOST\\ dbname=databasename user=user password='1j( \\\'9f'"#;

        assert_eq!(con_str, expected);
    }

    /// The authoritative check: whatever we build must survive a round-trip through the
    /// actual `tokio_postgres::Config` parser, including whitespace, quotes, and
    /// backslashes in the values.
    #[test]
    fn built_connection_string_round_trips_through_config() {
        let host = "some host";
        let dbname = "data base";
        let user = "user'name";
        let password = "we ird\tpass\\'word\nx";
        let con_str = PgPubSubOptionsBuilder::build_connection_string(host, dbname, user, password);

        let config: Config = con_str.parse().expect("config should parse");
        assert_eq!(
            config.get_hosts(),
            &[tokio_postgres::config::Host::Tcp(host.to_owned())]
        );
        assert_eq!(config.get_dbname(), Some(dbname));
        assert_eq!(config.get_user(), Some(user));
        assert_eq!(config.get_password(), Some(password.as_bytes()));
    }

    /// Unquoted values that only need escaping must also round-trip. This covers a leading
    /// `'` (which must not open a quoted value) and a trailing backslash.
    #[test]
    fn escaped_but_unquoted_values_round_trip_through_config() {
        let con_str = PgPubSubOptionsBuilder::build_connection_string("h", "db", "'abc", "p\\");
        let config: Config = con_str.parse().expect("config should parse");
        assert_eq!(config.get_user(), Some("'abc"));
        assert_eq!(config.get_password(), Some(&b"p\\"[..]));
    }

    #[test]
    fn empty_password_round_trips_through_config() {
        let con_str = PgPubSubOptionsBuilder::build_connection_string("h", "db", "u", "");
        let config: Config = con_str.parse().expect("config should parse");
        assert_eq!(config.get_password(), Some(&b""[..]));
    }
}