qail_pg/driver/
connection.rs

1//! PostgreSQL Connection
2//!
3//! Low-level TCP connection with wire protocol handling.
4//! This is Layer 3 (async I/O).
5//!
6//! Methods are split across modules for easier maintenance:
7//! - `io.rs` - Core I/O (send, recv)
8//! - `query.rs` - Query execution
9//! - `transaction.rs` - Transaction control
10//! - `cursor.rs` - Streaming cursors
11//! - `copy.rs` - COPY protocol
12//! - `pipeline.rs` - High-performance pipelining
13//! - `cancel.rs` - Query cancellation
14
15use super::stream::PgStream;
16use super::{PgError, PgResult};
17use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
18use bytes::BytesMut;
19use lru::LruCache;
20use std::collections::HashMap;
21use std::num::NonZeroUsize;
22use std::sync::Arc;
23use tokio::io::AsyncWriteExt;
24use tokio::net::TcpStream;
25
/// Starting capacity, in bytes, of each per-connection I/O buffer.
///
/// 64 KiB is large enough that pipelined batches rarely force a reallocation.
pub(crate) const BUFFER_CAPACITY: usize = 64 * 1024;
28
/// Complete SSLRequest packet.
///
/// It is a big-endian Int32 length (8) followed by the big-endian Int32 request
/// code 80877103 (1234 in the high 16 bits, 5679 in the low 16 bits).
const SSL_REQUEST: [u8; 8] = {
    let len = 8u32.to_be_bytes();
    let code = 80877103u32.to_be_bytes();
    [len[0], len[1], len[2], len[3], code[0], code[1], code[2], code[3]]
};
31
/// Request code identifying a CancelRequest packet (80877102).
///
/// It is 1234 in the high 16 bits and 5678 in the low 16 bits.
pub(crate) const CANCEL_REQUEST_CODE: i32 = (1234 << 16) | 5678;
34
/// TLS configuration for mutual TLS (client certificate authentication).
///
/// All material is held as raw PEM bytes and parsed when a connection is opened.
/// The `Debug` output never includes the private key and shows only the sizes of
/// the certificate blobs.
#[derive(Clone)]
pub struct TlsConfig {
    /// Client certificate (chain) in PEM format.
    pub client_cert_pem: Vec<u8>,
    /// Client private key in PEM format.
    pub client_key_pem: Vec<u8>,
    /// Optional CA certificate(s) in PEM format for verifying the server.
    /// When `None`, the platform's native root certificates are used instead.
    pub ca_cert_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Prints only the length of a PEM blob, never its contents.
        struct PemLen(usize);

        impl std::fmt::Debug for PemLen {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "<{} bytes>", self.0)
            }
        }

        f.debug_struct("TlsConfig")
            .field("client_cert_pem", &PemLen(self.client_cert_pem.len()))
            // The private key is secret material: never show it, not even its length.
            .field("client_key_pem", &format_args!("<redacted>"))
            .field(
                "ca_cert_pem",
                &self.ca_cert_pem.as_ref().map(|ca| PemLen(ca.len())),
            )
            .finish()
    }
}
45
46impl TlsConfig {
47    /// Create a new TLS config from file paths.
48    pub fn from_files(
49        cert_path: impl AsRef<std::path::Path>,
50        key_path: impl AsRef<std::path::Path>,
51        ca_path: Option<impl AsRef<std::path::Path>>,
52    ) -> std::io::Result<Self> {
53        Ok(Self {
54            client_cert_pem: std::fs::read(cert_path)?,
55            client_key_pem: std::fs::read(key_path)?,
56            ca_cert_pem: ca_path.map(|p| std::fs::read(p)).transpose()?,
57        })
58    }
59}
60
61/// A raw PostgreSQL connection.
62pub struct PgConnection {
63    pub(crate) stream: PgStream,
64    pub(crate) buffer: BytesMut,
65    pub(crate) write_buf: BytesMut,
66    pub(crate) sql_buf: BytesMut,
67    pub(crate) params_buf: Vec<Option<Vec<u8>>>,
68    pub(crate) prepared_statements: HashMap<String, String>,
69    pub(crate) stmt_cache: LruCache<u64, String>,
70    pub(crate) process_id: i32,
71    pub(crate) secret_key: i32,
72}
73
74impl PgConnection {
75    /// Connect to PostgreSQL server without authentication (trust mode).
76    pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
77        Self::connect_with_password(host, port, user, database, None).await
78    }
79
80    /// Connect to PostgreSQL server with optional password authentication.
81    pub async fn connect_with_password(
82        host: &str,
83        port: u16,
84        user: &str,
85        database: &str,
86        password: Option<&str>,
87    ) -> PgResult<Self> {
88        let addr = format!("{}:{}", host, port);
89        let tcp_stream = TcpStream::connect(&addr).await?;
90
91        // Disable Nagle's algorithm for lower latency
92        tcp_stream.set_nodelay(true)?;
93
94        let mut conn = Self {
95            stream: PgStream::Tcp(tcp_stream),
96            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
97            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY), // 64KB write buffer
98            sql_buf: BytesMut::with_capacity(512),
99            params_buf: Vec::with_capacity(16), // SQL encoding buffer
100            prepared_statements: HashMap::new(),
101            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
102            process_id: 0,
103            secret_key: 0,
104        };
105
106        conn.send(FrontendMessage::Startup {
107            user: user.to_string(),
108            database: database.to_string(),
109        })
110        .await?;
111
112        conn.handle_startup(user, password).await?;
113
114        Ok(conn)
115    }
116
117    /// Connect to PostgreSQL server with TLS encryption.
118    pub async fn connect_tls(
119        host: &str,
120        port: u16,
121        user: &str,
122        database: &str,
123        password: Option<&str>,
124    ) -> PgResult<Self> {
125        use tokio::io::AsyncReadExt;
126        use tokio_rustls::TlsConnector;
127        use tokio_rustls::rustls::ClientConfig;
128        use tokio_rustls::rustls::pki_types::ServerName;
129
130        let addr = format!("{}:{}", host, port);
131        let mut tcp_stream = TcpStream::connect(&addr).await?;
132
133        // Send SSLRequest
134        tcp_stream.write_all(&SSL_REQUEST).await?;
135
136        // Read response
137        let mut response = [0u8; 1];
138        tcp_stream.read_exact(&mut response).await?;
139
140        if response[0] != b'S' {
141            return Err(PgError::Connection(
142                "Server does not support TLS".to_string(),
143            ));
144        }
145
146        // TLS handshake
147        let certs = rustls_native_certs::load_native_certs();
148        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
149        for cert in certs.certs {
150            let _ = root_cert_store.add(cert);
151        }
152
153        let config = ClientConfig::builder()
154            .with_root_certificates(root_cert_store)
155            .with_no_client_auth();
156
157        let connector = TlsConnector::from(Arc::new(config));
158        let server_name = ServerName::try_from(host.to_string())
159            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
160
161        let tls_stream = connector
162            .connect(server_name, tcp_stream)
163            .await
164            .map_err(|e| PgError::Connection(format!("TLS handshake failed: {}", e)))?;
165
166        let mut conn = Self {
167            stream: PgStream::Tls(tls_stream),
168            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
169            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
170            sql_buf: BytesMut::with_capacity(512),
171            params_buf: Vec::with_capacity(16),
172            prepared_statements: HashMap::new(),
173            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
174            process_id: 0,
175            secret_key: 0,
176        };
177
178        conn.send(FrontendMessage::Startup {
179            user: user.to_string(),
180            database: database.to_string(),
181        })
182        .await?;
183
184        conn.handle_startup(user, password).await?;
185
186        Ok(conn)
187    }
188
189    /// Connect with mutual TLS (client certificate authentication).
190    /// # Arguments
191    /// * `host` - PostgreSQL server hostname
192    /// * `port` - PostgreSQL server port
193    /// * `user` - Database user
194    /// * `database` - Database name
195    /// * `config` - TLS configuration with client cert/key
196    /// # Example
197    /// ```ignore
198    /// let config = TlsConfig {
199    ///     client_cert_pem: include_bytes!("client.crt").to_vec(),
200    ///     client_key_pem: include_bytes!("client.key").to_vec(),
201    ///     ca_cert_pem: Some(include_bytes!("ca.crt").to_vec()),
202    /// };
203    /// let conn = PgConnection::connect_mtls("localhost", 5432, "user", "db", config).await?;
204    /// ```
205    pub async fn connect_mtls(
206        host: &str,
207        port: u16,
208        user: &str,
209        database: &str,
210        config: TlsConfig,
211    ) -> PgResult<Self> {
212        use tokio::io::AsyncReadExt;
213        use tokio_rustls::TlsConnector;
214        use tokio_rustls::rustls::{
215            ClientConfig,
216            pki_types::{CertificateDer, ServerName},
217        };
218
219        let addr = format!("{}:{}", host, port);
220        let mut tcp_stream = TcpStream::connect(&addr).await?;
221
222        // Send SSLRequest
223        tcp_stream.write_all(&SSL_REQUEST).await?;
224
225        // Read response
226        let mut response = [0u8; 1];
227        tcp_stream.read_exact(&mut response).await?;
228
229        if response[0] != b'S' {
230            return Err(PgError::Connection(
231                "Server does not support TLS".to_string(),
232            ));
233        }
234
235        let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
236
237        if let Some(ca_pem) = &config.ca_cert_pem {
238            let certs = rustls_pemfile::certs(&mut ca_pem.as_slice())
239                .filter_map(|r| r.ok())
240                .collect::<Vec<_>>();
241            for cert in certs {
242                let _ = root_cert_store.add(cert);
243            }
244        } else {
245            // Use system certs
246            let certs = rustls_native_certs::load_native_certs();
247            for cert in certs.certs {
248                let _ = root_cert_store.add(cert);
249            }
250        }
251
252        let client_certs: Vec<CertificateDer<'static>> =
253            rustls_pemfile::certs(&mut config.client_cert_pem.as_slice())
254                .filter_map(|r| r.ok())
255                .collect();
256
257        let client_key = rustls_pemfile::private_key(&mut config.client_key_pem.as_slice())
258            .map_err(|e| PgError::Connection(format!("Invalid client key: {:?}", e)))?
259            .ok_or_else(|| PgError::Connection("No private key found in PEM".to_string()))?;
260
261        let tls_config = ClientConfig::builder()
262            .with_root_certificates(root_cert_store)
263            .with_client_auth_cert(client_certs, client_key)
264            .map_err(|e| PgError::Connection(format!("Invalid client cert/key: {}", e)))?;
265
266        let connector = TlsConnector::from(Arc::new(tls_config));
267        let server_name = ServerName::try_from(host.to_string())
268            .map_err(|_| PgError::Connection("Invalid hostname for TLS".to_string()))?;
269
270        let tls_stream = connector
271            .connect(server_name, tcp_stream)
272            .await
273            .map_err(|e| PgError::Connection(format!("mTLS handshake failed: {}", e)))?;
274
275        let mut conn = Self {
276            stream: PgStream::Tls(tls_stream),
277            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
278            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
279            sql_buf: BytesMut::with_capacity(512),
280            params_buf: Vec::with_capacity(16),
281            prepared_statements: HashMap::new(),
282            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
283            process_id: 0,
284            secret_key: 0,
285        };
286
287        conn.send(FrontendMessage::Startup {
288            user: user.to_string(),
289            database: database.to_string(),
290        })
291        .await?;
292
293        // mTLS typically uses cert auth, no password needed
294        conn.handle_startup(user, None).await?;
295
296        Ok(conn)
297    }
298
299    /// Connect to PostgreSQL server via Unix domain socket.
300    #[cfg(unix)]
301    pub async fn connect_unix(
302        socket_path: &str,
303        user: &str,
304        database: &str,
305        password: Option<&str>,
306    ) -> PgResult<Self> {
307        use tokio::net::UnixStream;
308
309        let unix_stream = UnixStream::connect(socket_path).await?;
310
311        let mut conn = Self {
312            stream: PgStream::Unix(unix_stream),
313            buffer: BytesMut::with_capacity(BUFFER_CAPACITY),
314            write_buf: BytesMut::with_capacity(BUFFER_CAPACITY),
315            sql_buf: BytesMut::with_capacity(512),
316            params_buf: Vec::with_capacity(16),
317            prepared_statements: HashMap::new(),
318            stmt_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
319            process_id: 0,
320            secret_key: 0,
321        };
322
323        conn.send(FrontendMessage::Startup {
324            user: user.to_string(),
325            database: database.to_string(),
326        })
327        .await?;
328
329        conn.handle_startup(user, password).await?;
330
331        Ok(conn)
332    }
333
334    /// Handle startup sequence (auth + params).
335    async fn handle_startup(&mut self, user: &str, password: Option<&str>) -> PgResult<()> {
336        let mut scram_client: Option<ScramClient> = None;
337
338        loop {
339            let msg = self.recv().await?;
340            match msg {
341                BackendMessage::AuthenticationOk => {}
342                BackendMessage::AuthenticationMD5Password(_salt) => {
343                    return Err(PgError::Auth(
344                        "MD5 auth not supported. Use SCRAM-SHA-256.".to_string(),
345                    ));
346                }
347                BackendMessage::AuthenticationSASL(mechanisms) => {
348                    let password = password.ok_or_else(|| {
349                        PgError::Auth("Password required for SCRAM authentication".to_string())
350                    })?;
351
352                    if !mechanisms.iter().any(|m| m == "SCRAM-SHA-256") {
353                        return Err(PgError::Auth(format!(
354                            "Server doesn't support SCRAM-SHA-256. Available: {:?}",
355                            mechanisms
356                        )));
357                    }
358
359                    let client = ScramClient::new(user, password);
360                    let first_message = client.client_first_message();
361
362                    self.send(FrontendMessage::SASLInitialResponse {
363                        mechanism: "SCRAM-SHA-256".to_string(),
364                        data: first_message,
365                    })
366                    .await?;
367
368                    scram_client = Some(client);
369                }
370                BackendMessage::AuthenticationSASLContinue(server_data) => {
371                    let client = scram_client.as_mut().ok_or_else(|| {
372                        PgError::Auth("Received SASL Continue without SASL init".to_string())
373                    })?;
374
375                    let final_message = client
376                        .process_server_first(&server_data)
377                        .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
378
379                    self.send(FrontendMessage::SASLResponse(final_message))
380                        .await?;
381                }
382                BackendMessage::AuthenticationSASLFinal(server_signature) => {
383                    if let Some(client) = scram_client.as_ref() {
384                        client.verify_server_final(&server_signature).map_err(|e| {
385                            PgError::Auth(format!("Server verification failed: {}", e))
386                        })?;
387                    }
388                }
389                BackendMessage::ParameterStatus { .. } => {}
390                BackendMessage::BackendKeyData {
391                    process_id,
392                    secret_key,
393                } => {
394                    self.process_id = process_id;
395                    self.secret_key = secret_key;
396                }
397                BackendMessage::ReadyForQuery(TransactionStatus::Idle)
398                | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
399                | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
400                    return Ok(());
401                }
402                BackendMessage::ErrorResponse(err) => {
403                    return Err(PgError::Connection(err.message));
404                }
405                _ => {}
406            }
407        }
408    }
409
410    /// Gracefully close the connection by sending a Terminate message.
411    /// This tells the server we're done and allows proper cleanup.
412    pub async fn close(mut self) -> PgResult<()> {
413        use crate::protocol::PgEncoder;
414        
415        // Send Terminate packet ('X')
416        let terminate = PgEncoder::encode_terminate();
417        self.stream.write_all(&terminate).await?;
418        self.stream.flush().await?;
419        
420        Ok(())
421    }
422}
423
424/// Drop implementation sends Terminate packet if possible.
425/// This ensures proper cleanup even without explicit close() call.
426impl Drop for PgConnection {
427    fn drop(&mut self) {
428        // Try to send Terminate packet synchronously using try_write
429        // This is best-effort - if it fails, TCP RST will handle cleanup
430        let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
431        
432        match &mut self.stream {
433            PgStream::Tcp(tcp) => {
434                // try_write is non-blocking
435                let _ = tcp.try_write(&terminate);
436            }
437            PgStream::Tls(_) => {
438                // TLS requires async write which we can't do in Drop.
439                // The TCP connection close will still notify the server.
440                // For graceful TLS shutdown, use connection.close() explicitly.
441            }
442            #[cfg(unix)]
443            PgStream::Unix(unix) => {
444                let _ = unix.try_write(&terminate);
445            }
446        }
447    }
448}
449
/// Extract the row count from a CommandComplete tag such as `"INSERT 0 5"` or
/// `"UPDATE 12"`.
///
/// The count is the last whitespace-separated word. Tags whose last word is not an
/// unsigned integer (e.g. `"BEGIN"`), and empty tags, yield 0.
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    match tag.split_whitespace().next_back().map(str::parse::<u64>) {
        Some(Ok(rows)) => rows,
        _ => 0,
    }
}