Skip to main content

qail_pg/driver/
core.rs

1//! PgDriver — high-level async PostgreSQL driver combining the wire-protocol
2//! encoder with connection management (connect, fetch, execute, copy, pipeline, txn, RLS).
3
4use super::auth_types::*;
5use super::builder::PgDriverBuilder;
6use super::connection::PgConnection;
7use super::pool;
8use super::rls::RlsContext;
9use super::types::*;
10
11/// Combines the pure encoder (Layer 2) with async I/O (Layer 3).
12pub struct PgDriver {
13    #[allow(dead_code)]
14    pub(super) connection: PgConnection,
15    /// Current RLS context, if set. Used for multi-tenant data isolation.
16    pub(super) rls_context: Option<RlsContext>,
17}
18
19impl PgDriver {
20    /// Create a new driver with an existing connection.
21    pub fn new(connection: PgConnection) -> Self {
22        Self {
23            connection,
24            rls_context: None,
25        }
26    }
27
28    /// Builder pattern for ergonomic connection configuration.
29    /// # Example
30    /// ```ignore
31    /// let driver = PgDriver::builder()
32    ///     .host("localhost")
33    ///     .port(5432)
34    ///     .user("admin")
35    ///     .database("mydb")
36    ///     .password("secret")  // Optional
37    ///     .connect()
38    ///     .await?;
39    /// ```
40    pub fn builder() -> PgDriverBuilder {
41        PgDriverBuilder::new()
42    }
43
44    /// Connect to PostgreSQL and create a driver (trust mode, no password).
45    ///
46    /// # Arguments
47    ///
48    /// * `host` — PostgreSQL server hostname or IP.
49    /// * `port` — TCP port (typically 5432).
50    /// * `user` — PostgreSQL role name.
51    /// * `database` — Target database name.
52    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
53        let connection = PgConnection::connect(host, port, user, database).await?;
54        Ok(Self::new(connection))
55    }
56
57    /// Connect to PostgreSQL with password authentication.
58    /// Supports server-requested auth flow: cleartext, MD5, or SCRAM-SHA-256.
59    pub async fn connect_with_password(
60        host: &str,
61        port: u16,
62        user: &str,
63        database: &str,
64        password: &str,
65    ) -> PgResult<Self> {
66        let connection =
67            PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
68        Ok(Self::new(connection))
69    }
70
71    /// Connect with explicit security options.
72    pub async fn connect_with_options(
73        host: &str,
74        port: u16,
75        user: &str,
76        database: &str,
77        password: Option<&str>,
78        options: ConnectOptions,
79    ) -> PgResult<Self> {
80        let connection =
81            PgConnection::connect_with_options(host, port, user, database, password, options)
82                .await?;
83        Ok(Self::new(connection))
84    }
85
86    /// Connect in logical replication mode (`replication=database`).
87    ///
88    /// This enables replication commands such as `IDENTIFY_SYSTEM` and
89    /// `CREATE_REPLICATION_SLOT`.
90    pub async fn connect_logical_replication(
91        host: &str,
92        port: u16,
93        user: &str,
94        database: &str,
95        password: Option<&str>,
96    ) -> PgResult<Self> {
97        let options = ConnectOptions::default().with_logical_replication();
98        Self::connect_with_options(host, port, user, database, password, options).await
99    }
100
101    /// Connect with explicit options and force logical replication mode.
102    pub async fn connect_logical_replication_with_options(
103        host: &str,
104        port: u16,
105        user: &str,
106        database: &str,
107        password: Option<&str>,
108        options: ConnectOptions,
109    ) -> PgResult<Self> {
110        Self::connect_with_options(
111            host,
112            port,
113            user,
114            database,
115            password,
116            options.with_logical_replication(),
117        )
118        .await
119    }
120
121    /// Connect using DATABASE_URL environment variable.
122    ///
123    /// Parses the URL format: `postgresql://user:password@host:port/database`
124    /// or `postgres://user:password@host:port/database`
125    ///
126    /// # Example
127    /// ```ignore
128    /// // Set DATABASE_URL=postgresql://user:pass@localhost:5432/mydb
129    /// let driver = PgDriver::connect_env().await?;
130    /// ```
131    pub async fn connect_env() -> PgResult<Self> {
132        let url = std::env::var("DATABASE_URL").map_err(|_| {
133            PgError::Connection("DATABASE_URL environment variable not set".to_string())
134        })?;
135        Self::connect_url(&url).await
136    }
137
138    /// Connect using a PostgreSQL connection URL.
139    ///
140    /// Parses the URL format: `postgresql://user:password@host:port/database?params`
141    /// or `postgres://user:password@host:port/database?params`
142    ///
143    /// Supports all enterprise query params (sslmode, auth_mode, gss_provider,
144    /// channel_binding, etc.) — same set as `PoolConfig::from_qail_config`.
145    ///
146    /// # Example
147    /// ```ignore
148    /// let driver = PgDriver::connect_url("postgresql://user:pass@localhost:5432/mydb?sslmode=require").await?;
149    /// ```
150    pub async fn connect_url(url: &str) -> PgResult<Self> {
151        let (host, port, user, database, password) = Self::parse_database_url(url)?;
152
153        // Parse enterprise query params using the shared helper from pool.rs.
154        let mut pool_cfg = pool::PoolConfig::new(&host, port, &user, &database);
155        if let Some(pw) = &password {
156            pool_cfg = pool_cfg.password(pw);
157        }
158        if let Some(query) = url.split('?').nth(1) {
159            pool::apply_url_query_params(&mut pool_cfg, query, &host)?;
160        }
161
162        let mut opts = ConnectOptions {
163            tls_mode: pool_cfg.tls_mode,
164            gss_enc_mode: pool_cfg.gss_enc_mode,
165            tls_ca_cert_pem: pool_cfg.tls_ca_cert_pem,
166            mtls: pool_cfg.mtls,
167            gss_token_provider: pool_cfg.gss_token_provider,
168            gss_token_provider_ex: pool_cfg.gss_token_provider_ex,
169            auth: pool_cfg.auth_settings,
170            startup_params: Vec::new(),
171        };
172
173        // Startup parameters not owned by PoolConfig parser.
174        if let Some(query) = url.split('?').nth(1) {
175            for pair in query.split('&') {
176                let mut kv = pair.splitn(2, '=');
177                let key = kv.next().unwrap_or_default().trim();
178                let value = kv.next().unwrap_or_default().trim();
179                if key.eq_ignore_ascii_case("replication") {
180                    let replication_mode = if value.eq_ignore_ascii_case("database") {
181                        "database"
182                    } else if value.eq_ignore_ascii_case("true")
183                        || value.eq_ignore_ascii_case("on")
184                        || value == "1"
185                    {
186                        // Canonicalize legacy truthy values to PostgreSQL's
187                        // logical-replication mode value.
188                        "database"
189                    } else {
190                        return Err(PgError::Connection(format!(
191                            "Invalid replication startup mode '{}': expected database|true|on|1",
192                            value
193                        )));
194                    };
195                    opts = opts.with_startup_param("replication", replication_mode);
196                }
197            }
198        }
199
200        Self::connect_with_options(&host, port, &user, &database, password.as_deref(), opts).await
201    }
202
203    /// Parse a PostgreSQL connection URL into components.
204    ///
205    /// Format: `postgresql://user:password@host:port/database`
206    /// or `postgres://user:password@host:port/database`
207    ///
208    /// URL percent-encoding is automatically decoded for user and password.
209    pub(crate) fn parse_database_url(
210        url: &str,
211    ) -> PgResult<(String, u16, String, String, Option<String>)> {
212        // Remove scheme (postgresql:// or postgres://)
213        let after_scheme = url.split("://").nth(1).ok_or_else(|| {
214            PgError::Connection("Invalid DATABASE_URL: missing scheme".to_string())
215        })?;
216
217        // Split into auth@host parts
218        let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
219            (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
220        } else {
221            (None, after_scheme)
222        };
223
224        // Parse auth (user:password)
225        let (user, password) = if let Some(auth) = auth_part {
226            let parts: Vec<&str> = auth.splitn(2, ':').collect();
227            if parts.len() == 2 {
228                // URL-decode both user and password
229                (
230                    Self::percent_decode(parts[0]),
231                    Some(Self::percent_decode(parts[1])),
232                )
233            } else {
234                (Self::percent_decode(parts[0]), None)
235            }
236        } else {
237            return Err(PgError::Connection(
238                "Invalid DATABASE_URL: missing user".to_string(),
239            ));
240        };
241
242        // Parse host:port/database (strip query string if present)
243        let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
244            let raw_db = &host_db_part[slash_pos + 1..];
245            // Strip ?query params — they're handled separately by connect_url
246            let db = raw_db.split('?').next().unwrap_or(raw_db).to_string();
247            (&host_db_part[..slash_pos], db)
248        } else {
249            return Err(PgError::Connection(
250                "Invalid DATABASE_URL: missing database name".to_string(),
251            ));
252        };
253
254        // Parse host:port
255        let (host, port) = if let Some(colon_pos) = host_port.rfind(':') {
256            let port_str = &host_port[colon_pos + 1..];
257            let port = port_str
258                .parse::<u16>()
259                .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
260            (host_port[..colon_pos].to_string(), port)
261        } else {
262            (host_port.to_string(), 5432) // Default PostgreSQL port
263        };
264
265        Ok((host, port, user, database, password))
266    }
267
268    /// Decode URL percent-encoded string.
269    /// Handles common encodings: %20 (space), %2B (+), %3D (=), %40 (@), %2F (/), etc.
270    pub(crate) fn percent_decode(s: &str) -> String {
271        let mut result = String::with_capacity(s.len());
272        let mut chars = s.chars().peekable();
273
274        while let Some(c) = chars.next() {
275            if c == '%' {
276                // Try to parse next two chars as hex
277                let hex: String = chars.by_ref().take(2).collect();
278                if hex.len() == 2
279                    && let Ok(byte) = u8::from_str_radix(&hex, 16)
280                {
281                    result.push(byte as char);
282                    continue;
283                }
284                // If parsing failed, keep original
285                result.push('%');
286                result.push_str(&hex);
287            } else if c == '+' {
288                // '+' often represents space in query strings (form encoding)
289                // But in path components, keep as-is. PostgreSQL URLs use path encoding.
290                result.push('+');
291            } else {
292                result.push(c);
293            }
294        }
295
296        result
297    }
298
299    /// Connect to PostgreSQL with a connection timeout.
300    /// If the connection cannot be established within the timeout, returns an error.
301    /// # Example
302    /// ```ignore
303    /// use std::time::Duration;
304    /// let driver = PgDriver::connect_with_timeout(
305    ///     "localhost", 5432, "user", "db", "password",
306    ///     Duration::from_secs(5)
307    /// ).await?;
308    /// ```
309    pub async fn connect_with_timeout(
310        host: &str,
311        port: u16,
312        user: &str,
313        database: &str,
314        password: &str,
315        timeout: std::time::Duration,
316    ) -> PgResult<Self> {
317        tokio::time::timeout(
318            timeout,
319            Self::connect_with_password(host, port, user, database, password),
320        )
321        .await
322        .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
323    }
324    /// Clear the prepared statement cache.
325    /// Frees memory by removing all cached statements.
326    /// Note: Statements remain on the PostgreSQL server until connection closes.
327    pub fn clear_cache(&mut self) {
328        self.connection.clear_prepared_statement_state();
329    }
330
331    /// Get cache statistics.
332    /// Returns (current_size, max_capacity).
333    pub fn cache_stats(&self) -> (usize, usize) {
334        (
335            self.connection.stmt_cache.len(),
336            self.connection.stmt_cache.cap().get(),
337        )
338    }
339}