Skip to main content

qail_pg/driver/
connection.rs

1//! PostgreSQL Connection
2//!
3//! Low-level TCP connection with wire protocol handling.
4//! This is Layer 3 (async I/O).
5//!
6//! Methods are split across modules for easier maintenance:
7//! - `io.rs` - Core I/O (send, recv)
8//! - `query.rs` - Query execution
9//! - `transaction.rs` - Transaction control
10//! - `cursor.rs` - Streaming cursors
11//! - `copy.rs` - COPY protocol
12//! - `pipeline.rs` - High-performance pipelining
13//! - `cancel.rs` - Query cancellation
14
15use super::stream::PgStream;
16use super::{PgError, PgResult};
17use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
18use bytes::BytesMut;
19use lru::LruCache;
20use std::collections::HashMap;
21use std::num::NonZeroUsize;
22use std::sync::Arc;
23use tokio::io::AsyncWriteExt;
24use tokio::net::TcpStream;
25
/// Starting capacity, in bytes, of the connection's read and write buffers.
/// 64 KiB leaves room for pipelined traffic without early reallocation.
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;

/// Complete SSLRequest packet: an Int32 length of 8 followed by the Int32
/// request code 80877103, both big-endian.
const SSL_REQUEST: [u8; 8] = {
    let length = 8i32.to_be_bytes();
    let code = 80_877_103i32.to_be_bytes();
    [
        length[0], length[1], length[2], length[3], code[0], code[1], code[2], code[3],
    ]
};

/// Request code carried by a CancelRequest packet.
pub(crate) const CANCEL_REQUEST_CODE: i32 = 80_877_102;
34
/// Client-side TLS material for mutual-TLS (client certificate) connections.
///
/// Every field holds raw PEM bytes exactly as read from disk or embedded in
/// the binary; parsing happens only when a connection is opened.
#[derive(Clone)]
pub struct TlsConfig {
    /// PEM-encoded client certificate(s) presented to the server.
    pub client_cert_pem: Vec<u8>,
    /// PEM-encoded private key belonging to `client_cert_pem`.
    pub client_key_pem: Vec<u8>,
    /// PEM-encoded CA certificate(s) used to verify the server. When `None`,
    /// the operating system's trust store is used instead.
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl TlsConfig {
    /// Build a `TlsConfig` by reading PEM files from disk.
    ///
    /// The files are read in this order: certificate, key, then the optional CA.
    ///
    /// # Errors
    /// Returns the first I/O error encountered while reading any of the files.
    pub fn from_files(
        cert_path: impl AsRef<std::path::Path>,
        key_path: impl AsRef<std::path::Path>,
        ca_path: Option<impl AsRef<std::path::Path>>,
    ) -> std::io::Result<Self> {
        let client_cert_pem = std::fs::read(cert_path)?;
        let client_key_pem = std::fs::read(key_path)?;
        let ca_cert_pem = match ca_path {
            Some(path) => Some(std::fs::read(path)?),
            None => None,
        };

        Ok(Self {
            client_cert_pem,
            client_key_pem,
            ca_cert_pem,
        })
    }
}
60
61/// A raw PostgreSQL connection.
62pub struct PgConnection {
63    pub(crate) stream: PgStream,
64    pub(crate) buffer: BytesMut,
65    pub(crate) write_buf: BytesMut,
66    pub(crate) sql_buf: BytesMut,
67    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
68    pub(crate) prepared_statements: HashMap<String, String>,
69    pub(crate) stmt_cache: LruCache<u64, String>,
70    /// Cache of column metadata (RowDescription) per statement hash.
71    /// PostgreSQL only sends RowDescription after Parse, not on subsequent Bind+Execute.
72    /// This cache ensures by-name column access works even for cached prepared statements.
73    pub(crate) column_info_cache: HashMap<u64, Arc<super::ColumnInfo>>,
74    pub(crate) process_id: i32,
75    pub(crate) secret_key: i32,
76}
77
78impl PgConnection {
79    /// Connect to PostgreSQL server without authentication (trust mode).
80    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
81        Self::connect_with_password(host, port, user, database, None).await
82    }
83
84    /// Connect to PostgreSQL server with optional password authentication.
85    pub async fn connect_with_password(
86        host: &str,
87        port: u16,
88        user: &str,
89        database: &str,
90        password: Option<&str>,
91    ) -> PgResult<Self> {
92        let addr = format!("{}:{}", host, port);
93        let tcp_stream = TcpStream::connect(&addr).await?;
94
95        // Disable Nagle's algorithm for lower latency
96        tcp_stream.set_nodelay(true)?;
97
98        let mut conn = Self {
99            stream: PgStream::Tcp(tcp_stream),
100            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
101            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), // 64KB write buffer
102            sql_buf: BytesMut::with_capacity(512),
103            params_buf: Vec::with_capacity(16), // SQL encoding buffer
104            prepared_statements: HashMap::new(),
105            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
106            column_info_cache: HashMap::new(),
107            process_id: 0,
108            secret_key: 0,
109        };
110
111        conn.send(FrontendMessage::Startup {
112            user: user.to_string(),
113            database: database.to_string(),
114        })
115        .await?;
116
117        conn.handle_startup(user, password).await?;
118
119        Ok(conn)
120    }
121
122    /// Connect to PostgreSQL server with TLS encryption.
123    pub async fn connect_tls(
124        host: &str,
125        port: u16,
126        user: &str,
127        database: &str,
128        password: Option<&str>,
129    ) -> PgResult<Self> {
130        use tokio::io::AsyncReadExt;
131        use tokio_rustls::TlsConnector;
132        use tokio_rustls::rustls::ClientConfig;
133        use tokio_rustls::rustls::pki_types::ServerName;
134
135        let addr = format!("{}:{}", host, port);
136        let mut tcp_stream = TcpStream::connect(&addr).await?;
137
138        // Send SSLRequest
139        tcp_stream.write_all(&SSL_REQUEST).await?;
140
141        // Read response
142        let mut response = [0u8; 1];
143        tcp_stream.read_exact(&mut response).await?;
144
145        if response[0] != b'S' {
146            return Err(PgError::Connection(
147                "Server does not support TLS".to_string(),
148            ));
149        }
150
151        // TLS handshake
152        let certs = rustls_native_certs::load_native_certs();
153        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
154        for cert in certs.certs {
155            let _ = root_cert_store.add(cert);
156        }
157
158        let config = ClientConfig::builder()
159            .with_root_certificates(root_cert_store)
160            .with_no_client_auth();
161
162        let connector = TlsConnector::from(Arc::new(config));
163        let server_name = ServerName::try_from(host.to_string())
164            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
165
166        let tls_stream = connector
167            .connect(server_name, tcp_stream)
168            .await
169            .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
170
171        let mut conn = Self {
172            stream: PgStream::Tls(tls_stream),
173            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
174            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
175            sql_buf: BytesMut::with_capacity(512),
176            params_buf: Vec::with_capacity(16),
177            prepared_statements: HashMap::new(),
178            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
179            column_info_cache: HashMap::new(),
180            process_id: 0,
181            secret_key: 0,
182        };
183
184        conn.send(FrontendMessage::Startup {
185            user: user.to_string(),
186            database: database.to_string(),
187        })
188        .await?;
189
190        conn.handle_startup(user, password).await?;
191
192        Ok(conn)
193    }
194
195    /// Connect with mutual TLS (client certificate authentication).
196    /// # Arguments
197    /// * `host` - PostgreSQL server hostname
198    /// * `port` - PostgreSQL server port
199    /// * `user` - Database user
200    /// * `database` - Database name
201    /// * `config` - TLS configuration with client cert/key
202    /// # Example
203    /// ```ignore
204    /// let config = TlsConfig {
205    ///     client_cert_pem: include_bytes!("client.crt").to_vec(),
206    ///     client_key_pem: include_bytes!("client.key").to_vec(),
207    ///     ca_cert_pem: Some(include_bytes!("ca.crt").to_vec()),
208    /// };
209    /// let conn = PgConnection::connect_mtls("localhost", 5432, "user", "db", config).await?;
210    /// ```
211    pub async fn connect_mtls(
212        host: &str,
213        port: u16,
214        user: &str,
215        database: &str,
216        config: TlsConfig,
217    ) -> PgResult<Self> {
218        use tokio::io::AsyncReadExt;
219        use tokio_rustls::TlsConnector;
220        use tokio_rustls::rustls::{
221            ClientConfig,
222            pki_types::{CertificateDer, ServerName},
223        };
224
225        let addr = format!("{}:{}", host, port);
226        let mut tcp_stream = TcpStream::connect(&addr).await?;
227
228        // Send SSLRequest
229        tcp_stream.write_all(&SSL_REQUEST).await?;
230
231        // Read response
232        let mut response = [0u8; 1];
233        tcp_stream.read_exact(&mut response).await?;
234
235        if response[0] != b'S' {
236            return Err(PgError::Connection(
237                "Server does not support TLS".to_string(),
238            ));
239        }
240
241        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
242
243        if let Some(ca_pem) = &config.ca_cert_pem {
244            let certs = rustls_pemfile::certs(&mut ca_pem.as_slice())
245                .filter_map(|r| r.ok())
246                .collect::<Vec<_>>();
247            for cert in certs {
248                let _ = root_cert_store.add(cert);
249            }
250        } else {
251            // Use system certs
252            let certs = rustls_native_certs::load_native_certs();
253            for cert in certs.certs {
254                let _ = root_cert_store.add(cert);
255            }
256        }
257
258        let client_certs: Vec<CertificateDer<'static>> =
259            rustls_pemfile::certs(&mut config.client_cert_pem.as_slice())
260                .filter_map(|r| r.ok())
261                .collect();
262
263        let client_key = rustls_pemfile::private_key(&mut config.client_key_pem.as_slice())
264            .map_err(|e| PgError::Connection(format!("Invalid client key: {:?}", e)))?
265            .ok_or_else(|| PgError::Connection("No private key found in PEM".to_string()))?;
266
267        let tls_config = ClientConfig::builder()
268            .with_root_certificates(root_cert_store)
269            .with_client_auth_cert(client_certs, client_key)
270            .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
271
272        let connector = TlsConnector::from(Arc::new(tls_config));
273        let server_name = ServerName::try_from(host.to_string())
274            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
275
276        let tls_stream = connector
277            .connect(server_name, tcp_stream)
278            .await
279            .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
280
281        let mut conn = Self {
282            stream: PgStream::Tls(tls_stream),
283            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
284            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
285            sql_buf: BytesMut::with_capacity(512),
286            params_buf: Vec::with_capacity(16),
287            prepared_statements: HashMap::new(),
288            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
289            column_info_cache: HashMap::new(),
290            process_id: 0,
291            secret_key: 0,
292        };
293
294        conn.send(FrontendMessage::Startup {
295            user: user.to_string(),
296            database: database.to_string(),
297        })
298        .await?;
299
300        // mTLS typically uses cert auth, no password needed
301        conn.handle_startup(user, None).await?;
302
303        Ok(conn)
304    }
305
306    /// Connect to PostgreSQL server via Unix domain socket.
307    #[cfg(unix)]
308    pub async fn connect_unix(
309        socket_path: &str,
310        user: &str,
311        database: &str,
312        password: Option<&str>,
313    ) -> PgResult<Self> {
314        use tokio::net::UnixStream;
315
316        let unix_stream = UnixStream::connect(socket_path).await?;
317
318        let mut conn = Self {
319            stream: PgStream::Unix(unix_stream),
320            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
321            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
322            sql_buf: BytesMut::with_capacity(512),
323            params_buf: Vec::with_capacity(16),
324            prepared_statements: HashMap::new(),
325            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
326            column_info_cache: HashMap::new(),
327            process_id: 0,
328            secret_key: 0,
329        };
330
331        conn.send(FrontendMessage::Startup {
332            user: user.to_string(),
333            database: database.to_string(),
334        })
335        .await?;
336
337        conn.handle_startup(user, password).await?;
338
339        Ok(conn)
340    }
341
342    /// Handle startup sequence (auth + params).
343    async fn handle_startup(&mut self, user: &str, password: Option<&str>) -> PgResult<()> {
344        let mut scram_client: Option<ScramClient> = None;
345
346        loop {
347            let msg = self.recv().await?;
348            match msg {
349                BackendMessage::AuthenticationOk => {}
350                BackendMessage::AuthenticationMD5Password(_salt) => {
351                    return Err(PgError::Auth(
352                        "MD5 auth not supported. Use SCRAM-SHA-256.".to_string(),
353                    ));
354                }
355                BackendMessage::AuthenticationSASL(mechanisms) => {
356                    let password = password.ok_or_else(|| {
357                        PgError::Auth("Password required for SCRAM authentication".to_string())
358                    })?;
359
360                    if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
361                        return Err(PgError::Auth(format!(
362                            "Server doesn't support SCRAM-SHA-256. Available: {:?}",
363                            mechanisms
364                        )));
365                    }
366
367                    let client = ScramClient::new(user, password);
368                    let first_message = client.client_first_message();
369
370                    self.send(FrontendMessage::SASLInitialResponse {
371                        mechanism: "SCRAM-SHA-256".to_string(),
372                        data: first_message,
373                    })
374                    .await?;
375
376                    scram_client = Some(client);
377                }
378                BackendMessage::AuthenticationSASLContinue(server_data) => {
379                    let client = scram_client.as_mut().ok_or_else(|| {
380                        PgError::Auth("Received SASL Continue without SASL init".to_string())
381                    })?;
382
383                    let final_message = client
384                        .process_server_first(&server_data)
385                        .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
386
387                    self.send(FrontendMessage::SASLResponse(final_message))
388                        .await?;
389                }
390                BackendMessage::AuthenticationSASLFinal(server_signature) => {
391                    if let Some(client) = scram_client.as_ref() {
392                        client.verify_server_final(&server_signature).map_err(|e| {
393                            PgError::Auth(format!("Server verification failed: {}", e))
394                        })?;
395                    }
396                }
397                BackendMessage::ParameterStatus { .. } => {}
398                BackendMessage::BackendKeyData {
399                    process_id,
400                    secret_key,
401                } => {
402                    self.process_id = process_id;
403                    self.secret_key = secret_key;
404                }
405                BackendMessage::ReadyForQuery(TransactionStatus::Idle)
406                | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
407                | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
408                    return Ok(());
409                }
410                BackendMessage::ErrorResponse(err) => {
411                    return Err(PgError::Connection(err.message));
412                }
413                _ => {}
414            }
415        }
416    }
417
418    /// Gracefully close the connection by sending a Terminate message.
419    /// This tells the server we're done and allows proper cleanup.
420    pub async fn close(mut self) -> PgResult<()> {
421        use crate::protocol::PgEncoder;
422        
423        // Send Terminate packet ('X')
424        let terminate = PgEncoder::encode_terminate();
425        self.stream.write_all(&terminate).await?;
426        self.stream.flush().await?;
427        
428        Ok(())
429    }
430}
431
432/// Drop implementation sends Terminate packet if possible.
433/// This ensures proper cleanup even without explicit close() call.
434impl Drop for PgConnection {
435    fn drop(&mut self) {
436        // Try to send Terminate packet synchronously using try_write
437        // This is best-effort - if it fails, TCP RST will handle cleanup
438        let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
439        
440        match &mut self.stream {
441            PgStream::Tcp(tcp) => {
442                // try_write is non-blocking
443                let _ = tcp.try_write(&terminate);
444            }
445            PgStream::Tls(_) => {
446                // TLS requires async write which we can't do in Drop.
447                // The TCP connection close will still notify the server.
448                // For graceful TLS shutdown, use connection.close() explicitly.
449            }
450            #[cfg(unix)]
451            PgStream::Unix(unix) => {
452                let _ = unix.try_write(&terminate);
453            }
454        }
455    }
456}
457
/// Extract the row count from a command tag such as `"INSERT 0 5"` or
/// `"UPDATE 3"`.
///
/// The count is taken from the last whitespace-separated word. Returns 0 when
/// the tag is empty or that word is not a `u64` (e.g. `"BEGIN"`).
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    match tag.split_whitespace().next_back() {
        Some(count) => count.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}