qail-pg 1.1.1

Rust PostgreSQL driver for typed AST queries with direct wire-protocol execution
Documentation
//! PgDriver — high-level async PostgreSQL driver combining the wire-protocol
//! encoder with connection management (connect, fetch, execute, copy, pipeline, txn, RLS).

use super::auth_types::*;
use super::builder::PgDriverBuilder;
use super::connection::PgConnection;
use super::pool;
use super::rls::RlsContext;
use super::types::*;

/// High-level PostgreSQL driver: pairs the pure wire-protocol encoder (Layer 2)
/// with the async I/O of a live connection (Layer 3).
pub struct PgDriver {
    /// The underlying wire-protocol connection that all queries go through.
    pub(super) connection: PgConnection,
    /// Row-level-security context for multi-tenant data isolation;
    /// `None` until one is installed.
    pub(super) rls_context: Option<RlsContext>,
}

impl PgDriver {
    /// Create a new driver that wraps an already-established connection.
    ///
    /// The driver starts with no RLS context (`rls_context` is `None`).
    pub fn new(connection: PgConnection) -> Self {
        Self {
            connection,
            rls_context: None,
        }
    }

    /// Builder pattern for ergonomic connection configuration.
    /// # Example
    /// ```ignore
    /// let driver = PgDriver::builder()
    ///     .host("localhost")
    ///     .port(5432)
    ///     .user("admin")
    ///     .database("mydb")
    ///     .password("secret")  // Optional
    ///     .connect()
    ///     .await?;
    /// ```
    pub fn builder() -> PgDriverBuilder {
        PgDriverBuilder::new()
    }

    /// Connect to PostgreSQL and create a driver (trust mode, no password).
    ///
    /// # Arguments
    ///
    /// * `host` — PostgreSQL server hostname or IP.
    /// * `port` — TCP port (typically 5432).
    /// * `user` — PostgreSQL role name.
    /// * `database` — Target database name.
    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
        let connection = PgConnection::connect(host, port, user, database).await?;
        Ok(Self::new(connection))
    }

    /// Connect to PostgreSQL with password authentication.
    /// Supports server-requested auth flow: cleartext, MD5, or SCRAM-SHA-256.
    pub async fn connect_with_password(
        host: &str,
        port: u16,
        user: &str,
        database: &str,
        password: &str,
    ) -> PgResult<Self> {
        let connection =
            PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
        Ok(Self::new(connection))
    }

    /// Connect with explicit security options (TLS/GSS/auth settings and
    /// extra startup parameters carried by `options`).
    pub async fn connect_with_options(
        host: &str,
        port: u16,
        user: &str,
        database: &str,
        password: Option<&str>,
        options: ConnectOptions,
    ) -> PgResult<Self> {
        let connection =
            PgConnection::connect_with_options(host, port, user, database, password, options)
                .await?;
        Ok(Self::new(connection))
    }

    /// Connect in logical replication mode (`replication=database`).
    ///
    /// This enables replication commands such as `IDENTIFY_SYSTEM` and
    /// `CREATE_REPLICATION_SLOT`.
    pub async fn connect_logical_replication(
        host: &str,
        port: u16,
        user: &str,
        database: &str,
        password: Option<&str>,
    ) -> PgResult<Self> {
        let options = ConnectOptions::default().with_logical_replication();
        Self::connect_with_options(host, port, user, database, password, options).await
    }

    /// Connect with explicit options and force logical replication mode.
    pub async fn connect_logical_replication_with_options(
        host: &str,
        port: u16,
        user: &str,
        database: &str,
        password: Option<&str>,
        options: ConnectOptions,
    ) -> PgResult<Self> {
        Self::connect_with_options(
            host,
            port,
            user,
            database,
            password,
            options.with_logical_replication(),
        )
        .await
    }

    /// Connect using DATABASE_URL environment variable.
    ///
    /// Parses the URL format: `postgresql://user:password@host:port/database`
    /// or `postgres://user:password@host:port/database`
    ///
    /// # Errors
    ///
    /// Returns `PgError::Connection` when `DATABASE_URL` is unset or is not
    /// valid UTF-8, plus any error from [`PgDriver::connect_url`].
    ///
    /// # Example
    /// ```ignore
    /// // Set DATABASE_URL=postgresql://user:pass@localhost:5432/mydb
    /// let driver = PgDriver::connect_env().await?;
    /// ```
    pub async fn connect_env() -> PgResult<Self> {
        // Distinguish "missing" from "present but not UTF-8" so the error
        // message points at the real problem.
        let url = std::env::var("DATABASE_URL").map_err(|err| match err {
            std::env::VarError::NotPresent => {
                PgError::Connection("DATABASE_URL environment variable not set".to_string())
            }
            std::env::VarError::NotUnicode(_) => PgError::Connection(
                "DATABASE_URL environment variable is not valid UTF-8".to_string(),
            ),
        })?;
        Self::connect_url(&url).await
    }

    /// Connect using a PostgreSQL connection URL.
    ///
    /// Parses the URL format: `postgresql://user:password@host:port/database?params`
    /// or `postgres://user:password@host:port/database?params`
    ///
    /// Supports all enterprise query params (sslmode, auth_mode, gss_provider,
    /// channel_binding, etc.) — same set as `PoolConfig::from_qail_config`.
    /// Additionally accepts `replication=database|true|on|1`, which is always
    /// sent to the server as the `replication=database` startup parameter.
    ///
    /// # Example
    /// ```ignore
    /// let driver = PgDriver::connect_url("postgresql://user:pass@localhost:5432/mydb?sslmode=require").await?;
    /// ```
    pub async fn connect_url(url: &str) -> PgResult<Self> {
        /// Validate a `replication=` value and return the startup-parameter
        /// value to send. Truthy legacy values are canonicalized to PostgreSQL's
        /// logical-replication mode value.
        fn replication_startup_mode(value: &str) -> PgResult<&'static str> {
            let is_logical = value.eq_ignore_ascii_case("database")
                || value.eq_ignore_ascii_case("true")
                || value.eq_ignore_ascii_case("on")
                || value == "1";
            if is_logical {
                Ok("database")
            } else {
                Err(PgError::Connection(format!(
                    "Invalid replication startup mode '{}': expected database|true|on|1",
                    value
                )))
            }
        }

        let (host, port, user, database, password) = Self::parse_database_url(url)?;

        // Everything after the first '?' is the query string. `split_once`
        // keeps any later '?' inside parameter values instead of truncating
        // the query at the second '?'.
        let query = url.split_once('?').map(|(_, q)| q);

        // Parse enterprise query params using the shared helper from pool.rs.
        let mut pool_cfg = pool::PoolConfig::new(&host, port, &user, &database);
        if let Some(pw) = &password {
            pool_cfg = pool_cfg.password(pw);
        }
        if let Some(query) = query {
            pool::apply_url_query_params(&mut pool_cfg, query, &host)?;
        }

        let mut opts = ConnectOptions {
            tls_mode: pool_cfg.tls_mode,
            gss_enc_mode: pool_cfg.gss_enc_mode,
            tls_ca_cert_pem: pool_cfg.tls_ca_cert_pem,
            mtls: pool_cfg.mtls,
            gss_token_provider: pool_cfg.gss_token_provider,
            gss_token_provider_ex: pool_cfg.gss_token_provider_ex,
            auth: pool_cfg.auth_settings,
            startup_params: Vec::new(),
        };

        // Startup parameters not owned by PoolConfig parser.
        if let Some(query) = query {
            for pair in query.split('&') {
                // A bare `key` (no '=') has an empty value.
                let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
                if key.trim().eq_ignore_ascii_case("replication") {
                    let mode = replication_startup_mode(value.trim())?;
                    opts = opts.with_startup_param("replication", mode);
                }
            }
        }

        Self::connect_with_options(&host, port, &user, &database, password.as_deref(), opts).await
    }

    /// Parse a PostgreSQL connection URL into components.
    ///
    /// Format: `postgresql://user:password@host:port/database`
    /// or `postgres://user:password@host:port/database`
    ///
    /// Returns `(host, port, user, database, password)`. The port defaults to
    /// 5432 when omitted. IPv6 hosts are returned with their brackets
    /// (e.g. `[::1]`). Any `?query` suffix is ignored here; `connect_url`
    /// handles it.
    ///
    /// URL percent-encoding is automatically decoded for user, password and
    /// database name.
    ///
    /// # Errors
    ///
    /// Returns `PgError::Connection` for an unknown scheme, a missing or empty
    /// user, a missing database path, a missing host, a malformed IPv6 host,
    /// or an invalid port.
    pub(crate) fn parse_database_url(
        url: &str,
    ) -> PgResult<(String, u16, String, String, Option<String>)> {
        let after_scheme = url
            .strip_prefix("postgres://")
            .or_else(|| url.strip_prefix("postgresql://"))
            .ok_or_else(|| {
                PgError::Connection(
                    "Invalid DATABASE_URL: expected postgres:// or postgresql://".to_string(),
                )
            })?;

        // Locate the userinfo/host separator. Prefer the last '@' before the
        // query string so an '@' inside a query value (e.g.
        // `application_name=svc@host`) is not mistaken for it. Fall back to
        // the last '@' anywhere to stay lenient with an unencoded '?' in a
        // password.
        let query_start = after_scheme.find('?').unwrap_or(after_scheme.len());
        let at_pos = after_scheme[..query_start]
            .rfind('@')
            .or_else(|| after_scheme.rfind('@'));
        let (auth_part, host_db_part) = match at_pos {
            Some(at) => (Some(&after_scheme[..at]), &after_scheme[at + 1..]),
            None => (None, after_scheme),
        };

        // Parse auth (user[:password]); both halves are URL-decoded.
        let Some(auth) = auth_part else {
            return Err(PgError::Connection(
                "Invalid DATABASE_URL: missing user".to_string(),
            ));
        };
        let (user, password) = match auth.split_once(':') {
            Some((raw_user, raw_pw)) => (
                Self::percent_decode(raw_user),
                Some(Self::percent_decode(raw_pw)),
            ),
            None => (Self::percent_decode(auth), None),
        };
        // The server refuses a startup packet without a user name; reject early.
        if user.is_empty() {
            return Err(PgError::Connection(
                "Invalid DATABASE_URL: missing user".to_string(),
            ));
        }

        // Parse host:port/database (strip query string if present).
        let Some((host_port, raw_db)) = host_db_part.split_once('/') else {
            return Err(PgError::Connection(
                "Invalid DATABASE_URL: missing database name".to_string(),
            ));
        };
        // Strip ?query params — they're handled separately by connect_url.
        let raw_db = raw_db.split_once('?').map_or(raw_db, |(db, _)| db);
        let database = Self::percent_decode(raw_db);

        // Parse host:port
        let (host, port) = if host_port.starts_with('[') {
            let end = host_port.find(']').ok_or_else(|| {
                PgError::Connection("Invalid DATABASE_URL: malformed IPv6 host".to_string())
            })?;
            let host = &host_port[..=end];
            if host == "[]" {
                return Err(PgError::Connection(
                    "Invalid DATABASE_URL: missing host".to_string(),
                ));
            }
            let suffix = &host_port[end + 1..];
            let port = if suffix.is_empty() {
                5432
            } else if let Some(port_str) = suffix.strip_prefix(':') {
                Self::parse_database_url_port(port_str)?
            } else {
                return Err(PgError::Connection(
                    "Invalid DATABASE_URL: malformed IPv6 host".to_string(),
                ));
            };
            (host.to_string(), port)
        } else if let Some((host, port_str)) = host_port.rsplit_once(':') {
            if host.is_empty() {
                return Err(PgError::Connection(
                    "Invalid DATABASE_URL: missing host".to_string(),
                ));
            }
            let port = Self::parse_database_url_port(port_str)?;
            (host.to_string(), port)
        } else {
            if host_port.is_empty() {
                return Err(PgError::Connection(
                    "Invalid DATABASE_URL: missing host".to_string(),
                ));
            }
            (host_port.to_string(), 5432) // Default PostgreSQL port
        };

        Ok((host, port, user, database, password))
    }

    /// Parse the port component of a connection URL, rejecting an empty
    /// string, non-numeric or out-of-range values, and port 0.
    fn parse_database_url_port(port_str: &str) -> PgResult<u16> {
        if port_str.is_empty() {
            return Err(PgError::Connection(
                "Invalid DATABASE_URL: missing port after ':'".to_string(),
            ));
        }
        let port = port_str
            .parse::<u16>()
            .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
        if port == 0 {
            return Err(PgError::Connection(
                "Invalid port: 0 (expected 1..=65535)".to_string(),
            ));
        }
        Ok(port)
    }

    /// Decode URL percent-encoded string.
    /// Handles common encodings: %20 (space), %2B (+), %3D (=), %40 (@), %2F (/), etc.
    ///
    /// A `%` that is not followed by two hex digits is copied through unchanged.
    /// `+` is not translated to a space. The decoded bytes are converted to
    /// UTF-8 lossily: invalid sequences become U+FFFD.
    pub(crate) fn percent_decode(s: &str) -> String {
        fn hex_value(byte: u8) -> Option<u8> {
            match byte {
                b'0'..=b'9' => Some(byte - b'0'),
                b'a'..=b'f' => Some(byte - b'a' + 10),
                b'A'..=b'F' => Some(byte - b'A' + 10),
                _ => None,
            }
        }

        let bytes = s.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut i = 0;

        while i < bytes.len() {
            // `i + 2 < len` guarantees both hex digits are in bounds.
            if bytes[i] == b'%'
                && i + 2 < bytes.len()
                && let (Some(hi), Some(lo)) = (hex_value(bytes[i + 1]), hex_value(bytes[i + 2]))
            {
                decoded.push((hi << 4) | lo);
                i += 3;
            } else {
                decoded.push(bytes[i]);
                i += 1;
            }
        }

        String::from_utf8_lossy(&decoded).into_owned()
    }

    /// Connect to PostgreSQL with a connection timeout.
    /// If the connection cannot be established within the timeout, returns an error.
    ///
    /// # Errors
    ///
    /// Returns `PgError::Timeout` if `timeout` elapses first, otherwise any
    /// error from [`PgDriver::connect_with_password`].
    ///
    /// # Example
    /// ```ignore
    /// use std::time::Duration;
    /// let driver = PgDriver::connect_with_timeout(
    ///     "localhost", 5432, "user", "db", "password",
    ///     Duration::from_secs(5)
    /// ).await?;
    /// ```
    pub async fn connect_with_timeout(
        host: &str,
        port: u16,
        user: &str,
        database: &str,
        password: &str,
        timeout: std::time::Duration,
    ) -> PgResult<Self> {
        tokio::time::timeout(
            timeout,
            Self::connect_with_password(host, port, user, database, password),
        )
        .await
        .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
    }

    /// Clear the prepared statement cache.
    /// Frees memory by removing all cached statements.
    /// Note: Statements remain on the PostgreSQL server until connection closes.
    pub fn clear_cache(&mut self) {
        self.connection.clear_prepared_statement_state();
    }

    /// Get cache statistics.
    /// Returns (current_size, max_capacity).
    pub fn cache_stats(&self) -> (usize, usize) {
        (
            self.connection.stmt_cache.len(),
            self.connection.stmt_cache.cap().get(),
        )
    }
}