Skip to main content

qail_pg/driver/
connection.rs

1//! PostgreSQL Connection
2//!
3//! Low-level TCP connection with wire protocol handling.
4//! This is Layer 3 (async I/O).
5//!
6//! Methods are split across modules for easier maintenance:
7//! - `io.rs` - Core I/O (send, recv)
8//! - `query.rs` - Query execution
9//! - `transaction.rs` - Transaction control
10//! - `cursor.rs` - Streaming cursors
11//! - `copy.rs` - COPY protocol
12//! - `pipeline.rs` - High-performance pipelining
13//! - `cancel.rs` - Query cancellation
14
15use super::stream::PgStream;
16use super::{PgError, PgResult};
17use super::notification::Notification;
18use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
19use bytes::BytesMut;
20use lru::LruCache;
21use std::collections::{HashMap, VecDeque};
22use std::num::NonZeroUsize;
23use std::sync::Arc;
24use tokio::io::AsyncWriteExt;
25use tokio::net::TcpStream;
26
/// Initial capacity, in bytes, of the read and write buffers (64 KiB),
/// chosen for pipeline performance.
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;

/// Wire bytes of an SSLRequest message: big-endian Int32 length (8)
/// followed by big-endian Int32 request code 80877103.
const SSL_REQUEST: [u8; 8] = {
    let len = 8i32.to_be_bytes();
    let code = 80_877_103i32.to_be_bytes();
    [
        len[0], len[1], len[2], len[3], code[0], code[1], code[2], code[3],
    ]
};

/// CancelRequest protocol code (80877102).
pub(crate) const CANCEL_REQUEST_CODE: i32 = 80_877_102;

/// Default upper bound for TCP connect plus the PostgreSQL handshake.
///
/// Guards against a Slowloris-style stall where a malicious server accepts
/// the TCP connection but never responds.
pub(crate) const DEFAULT_CONNECT_TIMEOUT: std::time::Duration =
    std::time::Duration::from_secs(10);
39
/// TLS configuration for mutual TLS (client certificate authentication).
///
/// Everything is stored as raw PEM bytes; nothing is parsed until a
/// connection is opened with it.
#[derive(Clone)]
pub struct TlsConfig {
    /// Client certificate (chain) in PEM format.
    pub client_cert_pem: Vec<u8>,
    /// Client private key in PEM format.
    pub client_key_pem: Vec<u8>,
    /// Optional CA certificate(s) for server verification; the system trust
    /// store is used when `None`.
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl TlsConfig {
    /// Create a new TLS config by reading the PEM files at the given paths.
    ///
    /// `ca_path = None` leaves `ca_cert_pem` empty so system certificates are used.
    ///
    /// # Errors
    /// Returns an I/O error if any file cannot be read. The error keeps the
    /// original [`std::io::ErrorKind`], but its message is prefixed with the
    /// offending path so the caller can tell which of the files failed.
    pub fn from_files(
        cert_path: impl AsRef<std::path::Path>,
        key_path: impl AsRef<std::path::Path>,
        ca_path: Option<impl AsRef<std::path::Path>>,
    ) -> std::io::Result<Self> {
        Ok(Self {
            client_cert_pem: Self::read_with_path_context(cert_path.as_ref())?,
            client_key_pem: Self::read_with_path_context(key_path.as_ref())?,
            ca_cert_pem: ca_path
                .map(|p| Self::read_with_path_context(p.as_ref()))
                .transpose()?,
        })
    }

    /// Read a whole file, annotating any I/O error with the path that failed.
    fn read_with_path_context(path: &std::path::Path) -> std::io::Result<Vec<u8>> {
        std::fs::read(path)
            .map_err(|e| std::io::Error::new(e.kind(), format!("{}: {}", path.display(), e)))
    }
}
65
66/// A raw PostgreSQL connection.
67pub struct PgConnection {
68    pub(crate) stream: PgStream,
69    pub(crate) buffer: BytesMut,
70    pub(crate) write_buf: BytesMut,
71    pub(crate) sql_buf: BytesMut,
72    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
73    pub(crate) prepared_statements: HashMap<String, String>,
74    pub(crate) stmt_cache: LruCache<u64, String>,
75    /// Cache of column metadata (RowDescription) per statement hash.
76    /// PostgreSQL only sends RowDescription after Parse, not on subsequent Bind+Execute.
77    /// This cache ensures by-name column access works even for cached prepared statements.
78    pub(crate) column_info_cache: HashMap<u64, Arc<super::ColumnInfo>>,
79    pub(crate) process_id: i32,
80    pub(crate) secret_key: i32,
81    /// Buffer for asynchronous LISTEN/NOTIFY notifications.
82    /// Populated by `recv()` when it encounters NotificationResponse messages.
83    pub(crate) notifications: VecDeque<Notification>,
84}
85
86impl PgConnection {
87    /// Connect to PostgreSQL server without authentication (trust mode).
88    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
89        Self::connect_with_password(host, port, user, database, None).await
90    }
91
92    /// Connect to PostgreSQL server with optional password authentication.
93    /// Includes a default 10-second timeout covering TCP connect + handshake.
94    pub async fn connect_with_password(
95        host: &str,
96        port: u16,
97        user: &str,
98        database: &str,
99        password: Option<&str>,
100    ) -> PgResult<Self> {
101        tokio::time::timeout(
102            DEFAULT_CONNECT_TIMEOUT,
103            Self::connect_with_password_inner(host, port, user, database, password),
104        )
105        .await
106        .map_err(|_| PgError::Connection(format!(
107            "Connection timeout after {:?} (TCP connect + handshake)",
108            DEFAULT_CONNECT_TIMEOUT
109        )))?
110    }
111
112    /// Inner connection logic without timeout wrapper.
113    async fn connect_with_password_inner(
114        host: &str,
115        port: u16,
116        user: &str,
117        database: &str,
118        password: Option<&str>,
119    ) -> PgResult<Self> {
120        let addr = format!("{}:{}", host, port);
121        let tcp_stream = TcpStream::connect(&addr).await?;
122
123        // Disable Nagle's algorithm for lower latency
124        tcp_stream.set_nodelay(true)?;
125
126        let mut conn = Self {
127            stream: PgStream::Tcp(tcp_stream),
128            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
129            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), // 64KB write buffer
130            sql_buf: BytesMut::with_capacity(512),
131            params_buf: Vec::with_capacity(16), // SQL encoding buffer
132            prepared_statements: HashMap::new(),
133            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
134            column_info_cache: HashMap::new(),
135            process_id: 0,
136            secret_key: 0,
137            notifications: VecDeque::new(),
138        };
139
140        conn.send(FrontendMessage::Startup {
141            user: user.to_string(),
142            database: database.to_string(),
143        })
144        .await?;
145
146        conn.handle_startup(user, password).await?;
147
148        Ok(conn)
149    }
150
151    /// Connect to PostgreSQL server with TLS encryption.
152    /// Includes a default 10-second timeout covering TCP connect + TLS + handshake.
153    pub async fn connect_tls(
154        host: &str,
155        port: u16,
156        user: &str,
157        database: &str,
158        password: Option<&str>,
159    ) -> PgResult<Self> {
160        tokio::time::timeout(
161            DEFAULT_CONNECT_TIMEOUT,
162            Self::connect_tls_inner(host, port, user, database, password),
163        )
164        .await
165        .map_err(|_| PgError::Connection(format!(
166            "TLS connection timeout after {:?}",
167            DEFAULT_CONNECT_TIMEOUT
168        )))?
169    }
170
171    /// Inner TLS connection logic without timeout wrapper.
172    async fn connect_tls_inner(
173        host: &str,
174        port: u16,
175        user: &str,
176        database: &str,
177        password: Option<&str>,
178    ) -> PgResult<Self> {
179        use tokio::io::AsyncReadExt;
180        use tokio_rustls::TlsConnector;
181        use tokio_rustls::rustls::ClientConfig;
182        use tokio_rustls::rustls::pki_types::ServerName;
183
184        let addr = format!("{}:{}", host, port);
185        let mut tcp_stream = TcpStream::connect(&addr).await?;
186
187        // Send SSLRequest
188        tcp_stream.write_all(&SSL_REQUEST).await?;
189
190        // Read response
191        let mut response = [0u8; 1];
192        tcp_stream.read_exact(&mut response).await?;
193
194        if response[0] != b'S' {
195            return Err(PgError::Connection(
196                "Server does not support TLS".to_string(),
197            ));
198        }
199
200        // TLS handshake
201        let certs = rustls_native_certs::load_native_certs();
202        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
203        for cert in certs.certs {
204            let _ = root_cert_store.add(cert);
205        }
206
207        let config = ClientConfig::builder()
208            .with_root_certificates(root_cert_store)
209            .with_no_client_auth();
210
211        let connector = TlsConnector::from(Arc::new(config));
212        let server_name = ServerName::try_from(host.to_string())
213            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
214
215        let tls_stream = connector
216            .connect(server_name, tcp_stream)
217            .await
218            .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
219
220        let mut conn = Self {
221            stream: PgStream::Tls(tls_stream),
222            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
223            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
224            sql_buf: BytesMut::with_capacity(512),
225            params_buf: Vec::with_capacity(16),
226            prepared_statements: HashMap::new(),
227            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
228            column_info_cache: HashMap::new(),
229            process_id: 0,
230            secret_key: 0,
231            notifications: VecDeque::new(),
232        };
233
234        conn.send(FrontendMessage::Startup {
235            user: user.to_string(),
236            database: database.to_string(),
237        })
238        .await?;
239
240        conn.handle_startup(user, password).await?;
241
242        Ok(conn)
243    }
244
245    /// Connect with mutual TLS (client certificate authentication).
246    /// # Arguments
247    /// * `host` - PostgreSQL server hostname
248    /// * `port` - PostgreSQL server port
249    /// * `user` - Database user
250    /// * `database` - Database name
251    /// * `config` - TLS configuration with client cert/key
252    /// # Example
253    /// ```ignore
254    /// let config = TlsConfig {
255    ///     client_cert_pem: include_bytes!("client.crt").to_vec(),
256    ///     client_key_pem: include_bytes!("client.key").to_vec(),
257    ///     ca_cert_pem: Some(include_bytes!("ca.crt").to_vec()),
258    /// };
259    /// let conn = PgConnection::connect_mtls("localhost", 5432, "user", "db", config).await?;
260    /// ```
261    pub async fn connect_mtls(
262        host: &str,
263        port: u16,
264        user: &str,
265        database: &str,
266        config: TlsConfig,
267    ) -> PgResult<Self> {
268        use tokio::io::AsyncReadExt;
269        use tokio_rustls::TlsConnector;
270        use tokio_rustls::rustls::{
271            ClientConfig,
272            pki_types::{CertificateDer, ServerName},
273        };
274
275        let addr = format!("{}:{}", host, port);
276        let mut tcp_stream = TcpStream::connect(&addr).await?;
277
278        // Send SSLRequest
279        tcp_stream.write_all(&SSL_REQUEST).await?;
280
281        // Read response
282        let mut response = [0u8; 1];
283        tcp_stream.read_exact(&mut response).await?;
284
285        if response[0] != b'S' {
286            return Err(PgError::Connection(
287                "Server does not support TLS".to_string(),
288            ));
289        }
290
291        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
292
293        if let Some(ca_pem) = &config.ca_cert_pem {
294            let certs = rustls_pemfile::certs(&mut ca_pem.as_slice())
295                .filter_map(|r| r.ok())
296                .collect::<Vec<_>>();
297            for cert in certs {
298                let _ = root_cert_store.add(cert);
299            }
300        } else {
301            // Use system certs
302            let certs = rustls_native_certs::load_native_certs();
303            for cert in certs.certs {
304                let _ = root_cert_store.add(cert);
305            }
306        }
307
308        let client_certs: Vec<CertificateDer<'static>> =
309            rustls_pemfile::certs(&mut config.client_cert_pem.as_slice())
310                .filter_map(|r| r.ok())
311                .collect();
312
313        let client_key = rustls_pemfile::private_key(&mut config.client_key_pem.as_slice())
314            .map_err(|e| PgError::Connection(format!("Invalid client key: {:?}", e)))?
315            .ok_or_else(|| PgError::Connection("No private key found in PEM".to_string()))?;
316
317        let tls_config = ClientConfig::builder()
318            .with_root_certificates(root_cert_store)
319            .with_client_auth_cert(client_certs, client_key)
320            .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
321
322        let connector = TlsConnector::from(Arc::new(tls_config));
323        let server_name = ServerName::try_from(host.to_string())
324            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
325
326        let tls_stream = connector
327            .connect(server_name, tcp_stream)
328            .await
329            .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
330
331        let mut conn = Self {
332            stream: PgStream::Tls(tls_stream),
333            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
334            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
335            sql_buf: BytesMut::with_capacity(512),
336            params_buf: Vec::with_capacity(16),
337            prepared_statements: HashMap::new(),
338            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
339            column_info_cache: HashMap::new(),
340            process_id: 0,
341            secret_key: 0,
342            notifications: VecDeque::new(),
343        };
344
345        conn.send(FrontendMessage::Startup {
346            user: user.to_string(),
347            database: database.to_string(),
348        })
349        .await?;
350
351        // mTLS typically uses cert auth, no password needed
352        conn.handle_startup(user, None).await?;
353
354        Ok(conn)
355    }
356
357    /// Connect to PostgreSQL server via Unix domain socket.
358    #[cfg(unix)]
359    pub async fn connect_unix(
360        socket_path: &str,
361        user: &str,
362        database: &str,
363        password: Option<&str>,
364    ) -> PgResult<Self> {
365        use tokio::net::UnixStream;
366
367        let unix_stream = UnixStream::connect(socket_path).await?;
368
369        let mut conn = Self {
370            stream: PgStream::Unix(unix_stream),
371            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
372            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
373            sql_buf: BytesMut::with_capacity(512),
374            params_buf: Vec::with_capacity(16),
375            prepared_statements: HashMap::new(),
376            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
377            column_info_cache: HashMap::new(),
378            process_id: 0,
379            secret_key: 0,
380            notifications: VecDeque::new(),
381        };
382
383        conn.send(FrontendMessage::Startup {
384            user: user.to_string(),
385            database: database.to_string(),
386        })
387        .await?;
388
389        conn.handle_startup(user, password).await?;
390
391        Ok(conn)
392    }
393
394    /// Handle startup sequence (auth + params).
395    async fn handle_startup(&mut self, user: &str, password: Option<&str>) -> PgResult<()> {
396        let mut scram_client: Option<ScramClient> = None;
397
398        loop {
399            let msg = self.recv().await?;
400            match msg {
401                BackendMessage::AuthenticationOk => {}
402                BackendMessage::AuthenticationMD5Password(_salt) => {
403                    return Err(PgError::Auth(
404                        "MD5 auth not supported. Use SCRAM-SHA-256.".to_string(),
405                    ));
406                }
407                BackendMessage::AuthenticationSASL(mechanisms) => {
408                    let password = password.ok_or_else(|| {
409                        PgError::Auth("Password required for SCRAM authentication".to_string())
410                    })?;
411
412                    if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
413                        return Err(PgError::Auth(format!(
414                            "Server doesn't support SCRAM-SHA-256. Available: {:?}",
415                            mechanisms
416                        )));
417                    }
418
419                    let client = ScramClient::new(user, password);
420                    let first_message = client.client_first_message();
421
422                    self.send(FrontendMessage::SASLInitialResponse {
423                        mechanism: "SCRAM-SHA-256".to_string(),
424                        data: first_message,
425                    })
426                    .await?;
427
428                    scram_client = Some(client);
429                }
430                BackendMessage::AuthenticationSASLContinue(server_data) => {
431                    let client = scram_client.as_mut().ok_or_else(|| {
432                        PgError::Auth("Received SASL Continue without SASL init".to_string())
433                    })?;
434
435                    let final_message = client
436                        .process_server_first(&server_data)
437                        .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
438
439                    self.send(FrontendMessage::SASLResponse(final_message))
440                        .await?;
441                }
442                BackendMessage::AuthenticationSASLFinal(server_signature) => {
443                    if let Some(client) = scram_client.as_ref() {
444                        client.verify_server_final(&server_signature).map_err(|e| {
445                            PgError::Auth(format!("Server verification failed: {}", e))
446                        })?;
447                    }
448                }
449                BackendMessage::ParameterStatus { .. } => {}
450                BackendMessage::BackendKeyData {
451                    process_id,
452                    secret_key,
453                } => {
454                    self.process_id = process_id;
455                    self.secret_key = secret_key;
456                }
457                BackendMessage::ReadyForQuery(TransactionStatus::Idle)
458                | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
459                | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
460                    return Ok(());
461                }
462                BackendMessage::ErrorResponse(err) => {
463                    return Err(PgError::Connection(err.message));
464                }
465                _ => {}
466            }
467        }
468    }
469
470    /// Gracefully close the connection by sending a Terminate message.
471    /// This tells the server we're done and allows proper cleanup.
472    pub async fn close(mut self) -> PgResult<()> {
473        use crate::protocol::PgEncoder;
474        
475        // Send Terminate packet ('X')
476        let terminate = PgEncoder::encode_terminate();
477        self.stream.write_all(&terminate).await?;
478        self.stream.flush().await?;
479        
480        Ok(())
481    }
482
483    /// Maximum prepared statements per connection before LRU eviction kicks in.
484    ///
485    /// This prevents memory spikes from dynamic batch filters generating
486    /// thousands of unique SQL shapes within a single request. Using LRU
487    /// eviction instead of nuclear `.clear()` preserves hot statements.
488    pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
489
490    /// Evict the least-recently-used prepared statement if at capacity.
491    ///
492    /// Called before every new statement registration to enforce
493    /// `MAX_PREPARED_PER_CONN`. Both `stmt_cache` (LRU ordering) and
494    /// `prepared_statements` (name→SQL map) are kept in sync.
495    pub(crate) fn evict_prepared_if_full(&mut self) {
496        if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
497            // Pop the LRU entry from the cache
498            if let Some((_hash, evicted_name)) = self.stmt_cache.pop_lru() {
499                self.prepared_statements.remove(&evicted_name);
500            } else {
501                // stmt_cache is empty but prepared_statements is full —
502                // shouldn't happen in normal flow, but handle defensively
503                // by clearing the oldest entry from the HashMap.
504                if let Some(key) = self.prepared_statements.keys().next().cloned() {
505                    self.prepared_statements.remove(&key);
506                }
507            }
508        }
509    }
510}
511
512/// Drop implementation sends Terminate packet if possible.
513/// This ensures proper cleanup even without explicit close() call.
514impl Drop for PgConnection {
515    fn drop(&mut self) {
516        // Try to send Terminate packet synchronously using try_write
517        // This is best-effort - if it fails, TCP RST will handle cleanup
518        let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
519        
520        match &mut self.stream {
521            PgStream::Tcp(tcp) => {
522                // try_write is non-blocking
523                let _ = tcp.try_write(&terminate);
524            }
525            PgStream::Tls(_) => {
526                // TLS requires async write which we can't do in Drop.
527                // The TCP connection close will still notify the server.
528                // For graceful TLS shutdown, use connection.close() explicitly.
529            }
530            #[cfg(unix)]
531            PgStream::Unix(unix) => {
532                let _ = unix.try_write(&terminate);
533            }
534        }
535    }
536}
537
/// Extract the affected-row count from a command tag such as `"INSERT 0 5"`.
///
/// The count is taken from the last whitespace-separated token
/// (`"UPDATE 12"` → 12). An empty tag, or one whose last token is not an
/// unsigned integer (e.g. `"CREATE TABLE"`), yields 0.
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    match tag.split_whitespace().next_back() {
        Some(last_token) => last_token.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}