Skip to main content

qail_pg/driver/
connection.rs

1//! PostgreSQL Connection
2//!
3//! Low-level TCP connection with wire protocol handling.
4//! This is Layer 3 (async I/O).
5//!
6//! Methods are split across modules for easier maintenance:
7//! - `io.rs` - Core I/O (send, recv)
8//! - `query.rs` - Query execution
9//! - `transaction.rs` - Transaction control
10//! - `cursor.rs` - Streaming cursors
11//! - `copy.rs` - COPY protocol
12//! - `pipeline.rs` - High-performance pipelining
13//! - `cancel.rs` - Query cancellation
14
15use super::stream::PgStream;
16use super::{PgError, PgResult};
17use super::notification::Notification;
18use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
19use bytes::BytesMut;
20use lru::LruCache;
21use std::collections::{HashMap, VecDeque};
22use std::num::NonZeroUsize;
23use std::sync::Arc;
24use tokio::io::AsyncWriteExt;
25use tokio::net::TcpStream;
26
/// Maximum number of entries in each connection's statement LRU cache.
const STMT_CACHE_CAPACITY: NonZeroUsize = match NonZeroUsize::new(100) {
    Some(capacity) => capacity,
    None => panic!("statement cache capacity must be non-zero"),
};

/// Starting capacity, in bytes, of the per-connection read and write buffers
/// (64 KiB, chosen for pipeline throughput).
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;

/// Wire bytes of an SSLRequest packet: big-endian length 8, then big-endian
/// request code 80877103 (`0x04D2_162F`, i.e. 1234 << 16 | 5679).
const SSL_REQUEST: [u8; 8] = [0x00, 0x00, 0x00, 0x08, 0x04, 0xD2, 0x16, 0x2F];

/// Request code that identifies a CancelRequest packet: 1234 in the high 16 bits
/// and 5678 in the low 16 bits (= 80877102).
pub(crate) const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678;

/// Upper bound on TCP connect plus the PostgreSQL startup handshake.
///
/// Guards against a server (or Slowloris-style attacker) that accepts the TCP
/// connection but never answers.
pub(crate) const DEFAULT_CONNECT_TIMEOUT: std::time::Duration =
    std::time::Duration::from_secs(10);
42
/// Client-side TLS material for mutual TLS (client certificate authentication).
#[derive(Clone)]
pub struct TlsConfig {
    /// PEM-encoded client certificate (chain).
    pub client_cert_pem: Vec<u8>,
    /// PEM-encoded client private key.
    pub client_key_pem: Vec<u8>,
    /// Optional PEM-encoded CA bundle used to verify the server; when `None`,
    /// the platform's native root certificates are used instead.
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl TlsConfig {
    /// Build a [`TlsConfig`] by reading PEM files from disk.
    ///
    /// Files are read in order: certificate, key, then (if given) CA bundle.
    ///
    /// # Errors
    /// Returns the first I/O error encountered while reading any of the files.
    pub fn from_files(
        cert_path: impl AsRef<std::path::Path>,
        key_path: impl AsRef<std::path::Path>,
        ca_path: Option<impl AsRef<std::path::Path>>,
    ) -> std::io::Result<Self> {
        let client_cert_pem = std::fs::read(cert_path)?;
        let client_key_pem = std::fs::read(key_path)?;
        let ca_cert_pem = match ca_path {
            Some(path) => Some(std::fs::read(path)?),
            None => None,
        };

        Ok(Self {
            client_cert_pem,
            client_key_pem,
            ca_cert_pem,
        })
    }
}
68
/// A raw PostgreSQL connection.
///
/// Owns one transport stream plus the per-connection buffers and caches used by the
/// protocol methods. Instances are created by the `connect*` constructors, which
/// perform the startup/authentication handshake before returning.
pub struct PgConnection {
    /// Transport the protocol runs over: `PgStream::Tcp`, `PgStream::Tls`, or (Unix
    /// only) `PgStream::Unix`, depending on which `connect*` function built it.
    pub(crate) stream: PgStream,
    /// Byte buffer preallocated to `BUFFER_CAPACITY`; presumably holds incoming,
    /// not-yet-parsed backend data (TODO confirm against io.rs).
    pub(crate) buffer: BytesMut,
    /// Outgoing buffer, preallocated to `BUFFER_CAPACITY` (64 KiB).
    pub(crate) write_buf: BytesMut,
    /// Scratch buffer (512 bytes initially); presumably reused for SQL text encoding.
    pub(crate) sql_buf: BytesMut,
    /// Reusable parameter list (capacity 16); `None` presumably encodes SQL NULL — confirm.
    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
    /// Prepared statement name → SQL text for statements prepared on this connection.
    pub(crate) prepared_statements: HashMap<String, String>,
    /// LRU of statement hash → prepared statement name (capacity `STMT_CACHE_CAPACITY`).
    /// `evict_prepared_if_full` pops from here and removes the matching name from
    /// `prepared_statements`.
    pub(crate) stmt_cache: LruCache<u64, String>,
    /// Cache of column metadata (RowDescription) per statement hash.
    /// PostgreSQL only sends RowDescription after Parse, not on subsequent Bind+Execute.
    /// This cache ensures by-name column access works even for cached prepared statements.
    pub(crate) column_info_cache: HashMap<u64, Arc<super::ColumnInfo>>,
    /// Backend process id from BackendKeyData (0 until startup receives it);
    /// presumably used with `secret_key` to build CancelRequest packets.
    pub(crate) process_id: i32,
    /// Backend secret key from BackendKeyData (0 until startup receives it).
    pub(crate) secret_key: i32,
    /// Buffer for asynchronous LISTEN/NOTIFY notifications.
    /// Populated by `recv()` when it encounters NotificationResponse messages.
    pub(crate) notifications: VecDeque<Notification>,
}
88
89impl PgConnection {
90    /// Connect to PostgreSQL server without authentication (trust mode).
91    ///
92    /// # Arguments
93    ///
94    /// * `host` — PostgreSQL server hostname or IP.
95    /// * `port` — TCP port (typically 5432).
96    /// * `user` — PostgreSQL role name.
97    /// * `database` — Target database name.
98    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
99        Self::connect_with_password(host, port, user, database, None).await
100    }
101
102    /// Connect to PostgreSQL server with optional password authentication.
103    /// Includes a default 10-second timeout covering TCP connect + handshake.
104    pub async fn connect_with_password(
105        host: &str,
106        port: u16,
107        user: &str,
108        database: &str,
109        password: Option<&str>,
110    ) -> PgResult<Self> {
111        tokio::time::timeout(
112            DEFAULT_CONNECT_TIMEOUT,
113            Self::connect_with_password_inner(host, port, user, database, password),
114        )
115        .await
116        .map_err(|_| PgError::Connection(format!(
117            "Connection timeout after {:?} (TCP connect + handshake)",
118            DEFAULT_CONNECT_TIMEOUT
119        )))?
120    }
121
122    /// Inner connection logic without timeout wrapper.
123    async fn connect_with_password_inner(
124        host: &str,
125        port: u16,
126        user: &str,
127        database: &str,
128        password: Option<&str>,
129    ) -> PgResult<Self> {
130        let addr = format!("{}:{}", host, port);
131        let tcp_stream = TcpStream::connect(&addr).await?;
132
133        // Disable Nagle's algorithm for lower latency
134        tcp_stream.set_nodelay(true)?;
135
136        let mut conn = Self {
137            stream: PgStream::Tcp(tcp_stream),
138            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
139            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), // 64KB write buffer
140            sql_buf: BytesMut::with_capacity(512),
141            params_buf: Vec::with_capacity(16), // SQL encoding buffer
142            prepared_statements: HashMap::new(),
143            stmt_cache: LruCache::new(STMT_CACHE_CAPACITY),
144            column_info_cache: HashMap::new(),
145            process_id: 0,
146            secret_key: 0,
147            notifications: VecDeque::new(),
148        };
149
150        conn.send(FrontendMessage::Startup {
151            user: user.to_string(),
152            database: database.to_string(),
153        })
154        .await?;
155
156        conn.handle_startup(user, password).await?;
157
158        Ok(conn)
159    }
160
161    /// Connect to PostgreSQL server with TLS encryption.
162    /// Includes a default 10-second timeout covering TCP connect + TLS + handshake.
163    pub async fn connect_tls(
164        host: &str,
165        port: u16,
166        user: &str,
167        database: &str,
168        password: Option<&str>,
169    ) -> PgResult<Self> {
170        tokio::time::timeout(
171            DEFAULT_CONNECT_TIMEOUT,
172            Self::connect_tls_inner(host, port, user, database, password),
173        )
174        .await
175        .map_err(|_| PgError::Connection(format!(
176            "TLS connection timeout after {:?}",
177            DEFAULT_CONNECT_TIMEOUT
178        )))?
179    }
180
181    /// Inner TLS connection logic without timeout wrapper.
182    async fn connect_tls_inner(
183        host: &str,
184        port: u16,
185        user: &str,
186        database: &str,
187        password: Option<&str>,
188    ) -> PgResult<Self> {
189        use tokio::io::AsyncReadExt;
190        use tokio_rustls::TlsConnector;
191        use tokio_rustls::rustls::ClientConfig;
192        use tokio_rustls::rustls::pki_types::ServerName;
193
194        let addr = format!("{}:{}", host, port);
195        let mut tcp_stream = TcpStream::connect(&addr).await?;
196
197        // Send SSLRequest
198        tcp_stream.write_all(&SSL_REQUEST).await?;
199
200        // Read response
201        let mut response = [0u8; 1];
202        tcp_stream.read_exact(&mut response).await?;
203
204        if response[0] != b'S' {
205            return Err(PgError::Connection(
206                "Server does not support TLS".to_string(),
207            ));
208        }
209
210        // TLS handshake
211        let certs = rustls_native_certs::load_native_certs();
212        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
213        for cert in certs.certs {
214            let _ = root_cert_store.add(cert);
215        }
216
217        let config = ClientConfig::builder()
218            .with_root_certificates(root_cert_store)
219            .with_no_client_auth();
220
221        let connector = TlsConnector::from(Arc::new(config));
222        let server_name = ServerName::try_from(host.to_string())
223            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
224
225        let tls_stream = connector
226            .connect(server_name, tcp_stream)
227            .await
228            .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
229
230        let mut conn = Self {
231            stream: PgStream::Tls(tls_stream),
232            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
233            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
234            sql_buf: BytesMut::with_capacity(512),
235            params_buf: Vec::with_capacity(16),
236            prepared_statements: HashMap::new(),
237            stmt_cache: LruCache::new(STMT_CACHE_CAPACITY),
238            column_info_cache: HashMap::new(),
239            process_id: 0,
240            secret_key: 0,
241            notifications: VecDeque::new(),
242        };
243
244        conn.send(FrontendMessage::Startup {
245            user: user.to_string(),
246            database: database.to_string(),
247        })
248        .await?;
249
250        conn.handle_startup(user, password).await?;
251
252        Ok(conn)
253    }
254
255    /// Connect with mutual TLS (client certificate authentication).
256    /// # Arguments
257    /// * `host` - PostgreSQL server hostname
258    /// * `port` - PostgreSQL server port
259    /// * `user` - Database user
260    /// * `database` - Database name
261    /// * `config` - TLS configuration with client cert/key
262    /// # Example
263    /// ```ignore
264    /// let config = TlsConfig {
265    ///     client_cert_pem: include_bytes!("client.crt").to_vec(),
266    ///     client_key_pem: include_bytes!("client.key").to_vec(),
267    ///     ca_cert_pem: Some(include_bytes!("ca.crt").to_vec()),
268    /// };
269    /// let conn = PgConnection::connect_mtls("localhost", 5432, "user", "db", config).await?;
270    /// ```
271    pub async fn connect_mtls(
272        host: &str,
273        port: u16,
274        user: &str,
275        database: &str,
276        config: TlsConfig,
277    ) -> PgResult<Self> {
278        use tokio::io::AsyncReadExt;
279        use tokio_rustls::TlsConnector;
280        use tokio_rustls::rustls::{
281            ClientConfig,
282            pki_types::{CertificateDer, ServerName},
283        };
284
285        let addr = format!("{}:{}", host, port);
286        let mut tcp_stream = TcpStream::connect(&addr).await?;
287
288        // Send SSLRequest
289        tcp_stream.write_all(&SSL_REQUEST).await?;
290
291        // Read response
292        let mut response = [0u8; 1];
293        tcp_stream.read_exact(&mut response).await?;
294
295        if response[0] != b'S' {
296            return Err(PgError::Connection(
297                "Server does not support TLS".to_string(),
298            ));
299        }
300
301        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
302
303        if let Some(ca_pem) = &config.ca_cert_pem {
304            let certs = rustls_pemfile::certs(&mut ca_pem.as_slice())
305                .filter_map(|r| r.ok())
306                .collect::<Vec<_>>();
307            for cert in certs {
308                let _ = root_cert_store.add(cert);
309            }
310        } else {
311            // Use system certs
312            let certs = rustls_native_certs::load_native_certs();
313            for cert in certs.certs {
314                let _ = root_cert_store.add(cert);
315            }
316        }
317
318        let client_certs: Vec<CertificateDer<'static>> =
319            rustls_pemfile::certs(&mut config.client_cert_pem.as_slice())
320                .filter_map(|r| r.ok())
321                .collect();
322
323        let client_key = rustls_pemfile::private_key(&mut config.client_key_pem.as_slice())
324            .map_err(|e| PgError::Connection(format!("Invalid client key: {:?}", e)))?
325            .ok_or_else(|| PgError::Connection("No private key found in PEM".to_string()))?;
326
327        let tls_config = ClientConfig::builder()
328            .with_root_certificates(root_cert_store)
329            .with_client_auth_cert(client_certs, client_key)
330            .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
331
332        let connector = TlsConnector::from(Arc::new(tls_config));
333        let server_name = ServerName::try_from(host.to_string())
334            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
335
336        let tls_stream = connector
337            .connect(server_name, tcp_stream)
338            .await
339            .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
340
341        let mut conn = Self {
342            stream: PgStream::Tls(tls_stream),
343            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
344            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
345            sql_buf: BytesMut::with_capacity(512),
346            params_buf: Vec::with_capacity(16),
347            prepared_statements: HashMap::new(),
348            stmt_cache: LruCache::new(STMT_CACHE_CAPACITY),
349            column_info_cache: HashMap::new(),
350            process_id: 0,
351            secret_key: 0,
352            notifications: VecDeque::new(),
353        };
354
355        conn.send(FrontendMessage::Startup {
356            user: user.to_string(),
357            database: database.to_string(),
358        })
359        .await?;
360
361        // mTLS typically uses cert auth, no password needed
362        conn.handle_startup(user, None).await?;
363
364        Ok(conn)
365    }
366
367    /// Connect to PostgreSQL server via Unix domain socket.
368    #[cfg(unix)]
369    pub async fn connect_unix(
370        socket_path: &str,
371        user: &str,
372        database: &str,
373        password: Option<&str>,
374    ) -> PgResult<Self> {
375        use tokio::net::UnixStream;
376
377        let unix_stream = UnixStream::connect(socket_path).await?;
378
379        let mut conn = Self {
380            stream: PgStream::Unix(unix_stream),
381            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
382            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
383            sql_buf: BytesMut::with_capacity(512),
384            params_buf: Vec::with_capacity(16),
385            prepared_statements: HashMap::new(),
386            stmt_cache: LruCache::new(STMT_CACHE_CAPACITY),
387            column_info_cache: HashMap::new(),
388            process_id: 0,
389            secret_key: 0,
390            notifications: VecDeque::new(),
391        };
392
393        conn.send(FrontendMessage::Startup {
394            user: user.to_string(),
395            database: database.to_string(),
396        })
397        .await?;
398
399        conn.handle_startup(user, password).await?;
400
401        Ok(conn)
402    }
403
404    /// Handle startup sequence (auth + params).
405    async fn handle_startup(&mut self, user: &str, password: Option<&str>) -> PgResult<()> {
406        let mut scram_client: Option<ScramClient> = None;
407
408        loop {
409            let msg = self.recv().await?;
410            match msg {
411                BackendMessage::AuthenticationOk => {}
412                BackendMessage::AuthenticationMD5Password(_salt) => {
413                    return Err(PgError::Auth(
414                        "MD5 auth not supported. Use SCRAM-SHA-256.".to_string(),
415                    ));
416                }
417                BackendMessage::AuthenticationSASL(mechanisms) => {
418                    let password = password.ok_or_else(|| {
419                        PgError::Auth("Password required for SCRAM authentication".to_string())
420                    })?;
421
422                    if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
423                        return Err(PgError::Auth(format!(
424                            "Server doesn't support SCRAM-SHA-256. Available: {:?}",
425                            mechanisms
426                        )));
427                    }
428
429                    let client = ScramClient::new(user, password);
430                    let first_message = client.client_first_message();
431
432                    self.send(FrontendMessage::SASLInitialResponse {
433                        mechanism: "SCRAM-SHA-256".to_string(),
434                        data: first_message,
435                    })
436                    .await?;
437
438                    scram_client = Some(client);
439                }
440                BackendMessage::AuthenticationSASLContinue(server_data) => {
441                    let client = scram_client.as_mut().ok_or_else(|| {
442                        PgError::Auth("Received SASL Continue without SASL init".to_string())
443                    })?;
444
445                    let final_message = client
446                        .process_server_first(&server_data)
447                        .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
448
449                    self.send(FrontendMessage::SASLResponse(final_message))
450                        .await?;
451                }
452                BackendMessage::AuthenticationSASLFinal(server_signature) => {
453                    if let Some(client) = scram_client.as_ref() {
454                        client.verify_server_final(&server_signature).map_err(|e| {
455                            PgError::Auth(format!("Server verification failed: {}", e))
456                        })?;
457                    }
458                }
459                BackendMessage::ParameterStatus { .. } => {}
460                BackendMessage::BackendKeyData {
461                    process_id,
462                    secret_key,
463                } => {
464                    self.process_id = process_id;
465                    self.secret_key = secret_key;
466                }
467                BackendMessage::ReadyForQuery(TransactionStatus::Idle)
468                | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
469                | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
470                    return Ok(());
471                }
472                BackendMessage::ErrorResponse(err) => {
473                    return Err(PgError::Connection(err.message));
474                }
475                _ => {}
476            }
477        }
478    }
479
480    /// Gracefully close the connection by sending a Terminate message.
481    /// This tells the server we're done and allows proper cleanup.
482    pub async fn close(mut self) -> PgResult<()> {
483        use crate::protocol::PgEncoder;
484        
485        // Send Terminate packet ('X')
486        let terminate = PgEncoder::encode_terminate();
487        self.stream.write_all(&terminate).await?;
488        self.stream.flush().await?;
489        
490        Ok(())
491    }
492
493    /// Maximum prepared statements per connection before LRU eviction kicks in.
494    ///
495    /// This prevents memory spikes from dynamic batch filters generating
496    /// thousands of unique SQL shapes within a single request. Using LRU
497    /// eviction instead of nuclear `.clear()` preserves hot statements.
498    pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
499
500    /// Evict the least-recently-used prepared statement if at capacity.
501    ///
502    /// Called before every new statement registration to enforce
503    /// `MAX_PREPARED_PER_CONN`. Both `stmt_cache` (LRU ordering) and
504    /// `prepared_statements` (name→SQL map) are kept in sync.
505    pub(crate) fn evict_prepared_if_full(&mut self) {
506        if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
507            // Pop the LRU entry from the cache
508            if let Some((_hash, evicted_name)) = self.stmt_cache.pop_lru() {
509                self.prepared_statements.remove(&evicted_name);
510            } else {
511                // stmt_cache is empty but prepared_statements is full —
512                // shouldn't happen in normal flow, but handle defensively
513                // by clearing the oldest entry from the HashMap.
514                if let Some(key) = self.prepared_statements.keys().next().cloned() {
515                    self.prepared_statements.remove(&key);
516                }
517            }
518        }
519    }
520}
521
522/// Drop implementation sends Terminate packet if possible.
523/// This ensures proper cleanup even without explicit close() call.
524impl Drop for PgConnection {
525    fn drop(&mut self) {
526        // Try to send Terminate packet synchronously using try_write
527        // This is best-effort - if it fails, TCP RST will handle cleanup
528        let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
529        
530        match &mut self.stream {
531            PgStream::Tcp(tcp) => {
532                // try_write is non-blocking
533                let _ = tcp.try_write(&terminate);
534            }
535            PgStream::Tls(_) => {
536                // TLS requires async write which we can't do in Drop.
537                // The TCP connection close will still notify the server.
538                // For graceful TLS shutdown, use connection.close() explicitly.
539            }
540            #[cfg(unix)]
541            PgStream::Unix(unix) => {
542                let _ = unix.try_write(&terminate);
543            }
544        }
545    }
546}
547
/// Extract the row count from a command tag.
///
/// The count is the last whitespace-separated token of the tag, e.g. `"INSERT 0 5"` → 5
/// and `"UPDATE 3"` → 3. Tags whose last token is not a non-negative integer
/// (e.g. `"BEGIN"`) and empty tags yield 0.
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    match tag.split_whitespace().next_back().map(str::parse::<u64>) {
        Some(Ok(count)) => count,
        _ => 0,
    }
}