Skip to main content

qail_pg/driver/
core.rs

//! PgDriver — high-level async PostgreSQL driver that combines the wire-protocol
//! encoder with connection management (connect, fetch, execute, copy, pipeline,
//! transactions, RLS).
3
4use super::auth_types::*;
5use super::builder::PgDriverBuilder;
6use super::connection::PgConnection;
7use super::pool;
8use super::rls::RlsContext;
9use super::types::*;
10
11/// Combines the pure encoder (Layer 2) with async I/O (Layer 3).
12pub struct PgDriver {
13    pub(super) connection: PgConnection,
14    /// Current RLS context, if set. Used for multi-tenant data isolation.
15    pub(super) rls_context: Option<RlsContext>,
16}
17
18impl PgDriver {
19    /// Create a new driver with an existing connection.
20    pub fn new(connection: PgConnection) -> Self {
21        Self {
22            connection,
23            rls_context: None,
24        }
25    }
26
27    /// Builder pattern for ergonomic connection configuration.
28    /// # Example
29    /// ```ignore
30    /// let driver = PgDriver::builder()
31    ///     .host("localhost")
32    ///     .port(5432)
33    ///     .user("admin")
34    ///     .database("mydb")
35    ///     .password("secret")  // Optional
36    ///     .connect()
37    ///     .await?;
38    /// ```
39    pub fn builder() -> PgDriverBuilder {
40        PgDriverBuilder::new()
41    }
42
43    /// Connect to PostgreSQL and create a driver (trust mode, no password).
44    ///
45    /// # Arguments
46    ///
47    /// * `host` — PostgreSQL server hostname or IP.
48    /// * `port` — TCP port (typically 5432).
49    /// * `user` — PostgreSQL role name.
50    /// * `database` — Target database name.
51    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
52        let connection = PgConnection::connect(host, port, user, database).await?;
53        Ok(Self::new(connection))
54    }
55
56    /// Connect to PostgreSQL with password authentication.
57    /// Supports server-requested auth flow: cleartext, MD5, or SCRAM-SHA-256.
58    pub async fn connect_with_password(
59        host: &str,
60        port: u16,
61        user: &str,
62        database: &str,
63        password: &str,
64    ) -> PgResult<Self> {
65        let connection =
66            PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
67        Ok(Self::new(connection))
68    }
69
70    /// Connect with explicit security options.
71    pub async fn connect_with_options(
72        host: &str,
73        port: u16,
74        user: &str,
75        database: &str,
76        password: Option<&str>,
77        options: ConnectOptions,
78    ) -> PgResult<Self> {
79        let connection =
80            PgConnection::connect_with_options(host, port, user, database, password, options)
81                .await?;
82        Ok(Self::new(connection))
83    }
84
85    /// Connect in logical replication mode (`replication=database`).
86    ///
87    /// This enables replication commands such as `IDENTIFY_SYSTEM` and
88    /// `CREATE_REPLICATION_SLOT`.
89    pub async fn connect_logical_replication(
90        host: &str,
91        port: u16,
92        user: &str,
93        database: &str,
94        password: Option<&str>,
95    ) -> PgResult<Self> {
96        let options = ConnectOptions::default().with_logical_replication();
97        Self::connect_with_options(host, port, user, database, password, options).await
98    }
99
100    /// Connect with explicit options and force logical replication mode.
101    pub async fn connect_logical_replication_with_options(
102        host: &str,
103        port: u16,
104        user: &str,
105        database: &str,
106        password: Option<&str>,
107        options: ConnectOptions,
108    ) -> PgResult<Self> {
109        Self::connect_with_options(
110            host,
111            port,
112            user,
113            database,
114            password,
115            options.with_logical_replication(),
116        )
117        .await
118    }
119
120    /// Connect using DATABASE_URL environment variable.
121    ///
122    /// Parses the URL format: `postgresql://user:password@host:port/database`
123    /// or `postgres://user:password@host:port/database`
124    ///
125    /// # Example
126    /// ```ignore
127    /// // Set DATABASE_URL=postgresql://user:pass@localhost:5432/mydb
128    /// let driver = PgDriver::connect_env().await?;
129    /// ```
130    pub async fn connect_env() -> PgResult<Self> {
131        let url = std::env::var("DATABASE_URL").map_err(|_| {
132            PgError::Connection("DATABASE_URL environment variable not set".to_string())
133        })?;
134        Self::connect_url(&url).await
135    }
136
137    /// Connect using a PostgreSQL connection URL.
138    ///
139    /// Parses the URL format: `postgresql://user:password@host:port/database?params`
140    /// or `postgres://user:password@host:port/database?params`
141    ///
142    /// Supports all enterprise query params (sslmode, auth_mode, gss_provider,
143    /// channel_binding, etc.) — same set as `PoolConfig::from_qail_config`.
144    ///
145    /// # Example
146    /// ```ignore
147    /// let driver = PgDriver::connect_url("postgresql://user:pass@localhost:5432/mydb?sslmode=require").await?;
148    /// ```
149    pub async fn connect_url(url: &str) -> PgResult<Self> {
150        let (host, port, user, database, password) = Self::parse_database_url(url)?;
151
152        // Parse enterprise query params using the shared helper from pool.rs.
153        let mut pool_cfg = pool::PoolConfig::new(&host, port, &user, &database);
154        if let Some(pw) = &password {
155            pool_cfg = pool_cfg.password(pw);
156        }
157        if let Some((_, query)) = url.split_once('?') {
158            pool::apply_url_query_params(&mut pool_cfg, query, &host)?;
159        }
160
161        let mut opts = ConnectOptions {
162            tls_mode: pool_cfg.tls_mode,
163            gss_enc_mode: pool_cfg.gss_enc_mode,
164            tls_ca_cert_pem: pool_cfg.tls_ca_cert_pem,
165            mtls: pool_cfg.mtls,
166            gss_token_provider: pool_cfg.gss_token_provider,
167            gss_token_provider_ex: pool_cfg.gss_token_provider_ex,
168            auth: pool_cfg.auth_settings,
169            startup_params: Vec::new(),
170        };
171
172        // Startup parameters not owned by PoolConfig parser.
173        if let Some((_, query)) = url.split_once('?') {
174            for pair in query.split('&') {
175                let mut kv = pair.splitn(2, '=');
176                let key = kv.next().unwrap_or_default().trim();
177                let value = kv.next().unwrap_or_default().trim();
178                if key.eq_ignore_ascii_case("replication") {
179                    let replication_mode = if value.eq_ignore_ascii_case("database") {
180                        "database"
181                    } else if value.eq_ignore_ascii_case("true")
182                        || value.eq_ignore_ascii_case("on")
183                        || value == "1"
184                    {
185                        // Canonicalize legacy truthy values to PostgreSQL's
186                        // logical-replication mode value.
187                        "database"
188                    } else {
189                        return Err(PgError::Connection(format!(
190                            "Invalid replication startup mode '{}': expected database|true|on|1",
191                            value
192                        )));
193                    };
194                    opts = opts.with_startup_param("replication", replication_mode);
195                }
196            }
197        }
198
199        Self::connect_with_options(&host, port, &user, &database, password.as_deref(), opts).await
200    }
201
202    /// Parse a PostgreSQL connection URL into components.
203    ///
204    /// Format: `postgresql://user:password@host:port/database`
205    /// or `postgres://user:password@host:port/database`
206    ///
207    /// URL percent-encoding is automatically decoded for user and password.
208    pub(crate) fn parse_database_url(
209        url: &str,
210    ) -> PgResult<(String, u16, String, String, Option<String>)> {
211        let after_scheme = if let Some(rest) = url.strip_prefix("postgres://") {
212            rest
213        } else if let Some(rest) = url.strip_prefix("postgresql://") {
214            rest
215        } else {
216            return Err(PgError::Connection(
217                "Invalid DATABASE_URL: expected postgres:// or postgresql://".to_string(),
218            ));
219        };
220
221        // Split into auth@host parts
222        let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
223            (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
224        } else {
225            (None, after_scheme)
226        };
227
228        // Parse auth (user:password)
229        let (user, password) = if let Some(auth) = auth_part {
230            if auth.is_empty() {
231                return Err(PgError::Connection(
232                    "Invalid DATABASE_URL: missing user".to_string(),
233                ));
234            }
235            let parts: Vec<&str> = auth.splitn(2, ':').collect();
236            if parts.len() == 2 {
237                // URL-decode both user and password
238                let user = Self::percent_decode(parts[0])?;
239                if user.is_empty() {
240                    return Err(PgError::Connection(
241                        "Invalid DATABASE_URL: missing user".to_string(),
242                    ));
243                }
244                (user, Some(Self::percent_decode(parts[1])?))
245            } else {
246                let user = Self::percent_decode(parts[0])?;
247                if user.is_empty() {
248                    return Err(PgError::Connection(
249                        "Invalid DATABASE_URL: missing user".to_string(),
250                    ));
251                }
252                (user, None)
253            }
254        } else {
255            ("postgres".to_string(), None)
256        };
257
258        // Parse host:port/database (strip query string if present)
259        let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
260            let raw_db = &host_db_part[slash_pos + 1..];
261            // Strip ?query params — they're handled separately by connect_url
262            let db = Self::percent_decode(raw_db.split('?').next().unwrap_or(raw_db))?;
263            (&host_db_part[..slash_pos], db)
264        } else {
265            return Err(PgError::Connection(
266                "Invalid DATABASE_URL: missing database name".to_string(),
267            ));
268        };
269
270        // Parse host:port
271        let (host, port) = if host_port.starts_with('[') {
272            let end = host_port.find(']').ok_or_else(|| {
273                PgError::Connection("Invalid DATABASE_URL: malformed IPv6 host".to_string())
274            })?;
275            let host = &host_port[..=end];
276            if host == "[]" {
277                return Err(PgError::Connection(
278                    "Invalid DATABASE_URL: missing host".to_string(),
279                ));
280            }
281            let suffix = &host_port[end + 1..];
282            let port = if suffix.is_empty() {
283                5432
284            } else if let Some(port_str) = suffix.strip_prefix(':') {
285                Self::parse_database_url_port(port_str)?
286            } else {
287                return Err(PgError::Connection(
288                    "Invalid DATABASE_URL: malformed IPv6 host".to_string(),
289                ));
290            };
291            (host.to_string(), port)
292        } else if let Some(colon_pos) = host_port.rfind(':') {
293            let port_str = &host_port[colon_pos + 1..];
294            let host = &host_port[..colon_pos];
295            if host.is_empty() {
296                return Err(PgError::Connection(
297                    "Invalid DATABASE_URL: missing host".to_string(),
298                ));
299            }
300            let port = Self::parse_database_url_port(port_str)?;
301            (host.to_string(), port)
302        } else {
303            if host_port.is_empty() {
304                return Err(PgError::Connection(
305                    "Invalid DATABASE_URL: missing host".to_string(),
306                ));
307            }
308            (host_port.to_string(), 5432) // Default PostgreSQL port
309        };
310
311        Ok((host, port, user, database, password))
312    }
313
314    fn parse_database_url_port(port_str: &str) -> PgResult<u16> {
315        if port_str.is_empty() {
316            return Err(PgError::Connection(
317                "Invalid DATABASE_URL: missing port after ':'".to_string(),
318            ));
319        }
320        let port = port_str
321            .parse::<u16>()
322            .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
323        if port == 0 {
324            return Err(PgError::Connection(
325                "Invalid port: 0 (expected 1..=65535)".to_string(),
326            ));
327        }
328        Ok(port)
329    }
330
331    /// Decode URL percent-encoded string.
332    /// Handles common encodings: %20 (space), %2B (+), %3D (=), %40 (@), %2F (/), etc.
333    pub(crate) fn percent_decode(s: &str) -> PgResult<String> {
334        fn hex_value(byte: u8) -> Option<u8> {
335            match byte {
336                b'0'..=b'9' => Some(byte - b'0'),
337                b'a'..=b'f' => Some(byte - b'a' + 10),
338                b'A'..=b'F' => Some(byte - b'A' + 10),
339                _ => None,
340            }
341        }
342
343        let bytes = s.as_bytes();
344        let mut decoded = Vec::with_capacity(bytes.len());
345        let mut i = 0;
346
347        while i < bytes.len() {
348            if bytes[i] == b'%'
349                && i + 2 < bytes.len()
350                && let (Some(hi), Some(lo)) = (hex_value(bytes[i + 1]), hex_value(bytes[i + 2]))
351            {
352                decoded.push((hi << 4) | lo);
353                i += 3;
354            } else {
355                decoded.push(bytes[i]);
356                i += 1;
357            }
358        }
359
360        String::from_utf8(decoded).map_err(|_| {
361            PgError::Connection(
362                "Invalid DATABASE_URL percent-encoding: decoded value is not UTF-8".to_string(),
363            )
364        })
365    }
366
367    /// Connect to PostgreSQL with a connection timeout.
368    /// If the connection cannot be established within the timeout, returns an error.
369    /// # Example
370    /// ```ignore
371    /// use std::time::Duration;
372    /// let driver = PgDriver::connect_with_timeout(
373    ///     "localhost", 5432, "user", "db", "password",
374    ///     Duration::from_secs(5)
375    /// ).await?;
376    /// ```
377    pub async fn connect_with_timeout(
378        host: &str,
379        port: u16,
380        user: &str,
381        database: &str,
382        password: &str,
383        timeout: std::time::Duration,
384    ) -> PgResult<Self> {
385        tokio::time::timeout(
386            timeout,
387            Self::connect_with_password(host, port, user, database, password),
388        )
389        .await
390        .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
391    }
392    /// Clear the prepared statement cache.
393    /// Frees memory by removing all cached statements.
394    /// Note: Statements remain on the PostgreSQL server until connection closes.
395    pub fn clear_cache(&mut self) {
396        self.connection.clear_prepared_statement_state();
397    }
398
399    /// Get cache statistics.
400    /// Returns (current_size, max_capacity).
401    pub fn cache_stats(&self) -> (usize, usize) {
402        (
403            self.connection.stmt_cache.len(),
404            self.connection.stmt_cache.cap().get(),
405        )
406    }
407}