Skip to main content

bsql_driver_postgres/
conn.rs

//! PostgreSQL connection — startup, authentication, statement cache, query execution.
//!
//! `Connection` owns a TCP, TLS, or Unix domain socket stream and implements the
//! extended query protocol with pipelining. Statements are cached by the rapidhash of
//! the SQL text. On first use, Parse+Describe+Bind+Execute+Sync are pipelined in one
//! write. On subsequent uses, only Bind+Execute+Sync are sent.
//!
//! # Unix domain sockets
//!
//! When `Config::host` starts with `/`, the driver connects via Unix domain socket
//! at `{host}/.s.PGSQL.{port}` (libpq convention). Use `?host=/tmp` in the connection
//! URL to enable UDS. This avoids TCP overhead for localhost connections.
13
14use std::sync::Arc;
15
16use rapidhash::quality::RapidHasher;
17
// --- Vec-based statement cache ---
//
// For typical workloads of 5-20 cached statements, a linear scan over a Vec
// with u64 comparisons is faster than a HashMap probe sequence because:
// - Vec = contiguous memory, perfect cache locality (all entries in L1)
// - u64 comparison = one instruction per entry
// - No hash probe, no bucket lookup, no load-factor math
25
26/// Vec-backed statement cache with O(n) lookup. Faster than HashMap for
27/// small n (< ~30 entries) due to cache locality and zero hashing overhead.
28struct StmtCache {
29    entries: Vec<(u64, StmtInfo)>,
30}
31
32impl Default for StmtCache {
33    fn default() -> Self {
34        Self {
35            entries: Vec::with_capacity(16),
36        }
37    }
38}
39
40impl StmtCache {
41    #[inline]
42    fn get_mut(&mut self, hash: &u64) -> Option<&mut StmtInfo> {
43        self.entries
44            .iter_mut()
45            .find(|(h, _)| h == hash)
46            .map(|(_, info)| info)
47    }
48
49    #[inline]
50    fn get(&self, hash: &u64) -> Option<&StmtInfo> {
51        self.entries
52            .iter()
53            .find(|(h, _)| h == hash)
54            .map(|(_, info)| info)
55    }
56
57    #[inline]
58    fn contains_key(&self, hash: &u64) -> bool {
59        self.entries.iter().any(|(h, _)| h == hash)
60    }
61
62    #[inline]
63    fn insert(&mut self, hash: u64, info: StmtInfo) {
64        if let Some(entry) = self.entries.iter_mut().find(|(h, _)| *h == hash) {
65            entry.1 = info;
66        } else {
67            self.entries.push((hash, info));
68        }
69    }
70
71    #[inline]
72    fn remove(&mut self, hash: &u64) -> Option<StmtInfo> {
73        if let Some(pos) = self.entries.iter().position(|(h, _)| h == hash) {
74            Some(self.entries.swap_remove(pos).1)
75        } else {
76            None
77        }
78    }
79
80    #[inline]
81    fn len(&self) -> usize {
82        self.entries.len()
83    }
84
85    /// Evict the least recently used entry (lowest `last_used` counter).
86    fn evict_lru(&mut self) -> Option<(u64, StmtInfo)> {
87        if self.entries.is_empty() {
88            return None;
89        }
90        let min_idx = self
91            .entries
92            .iter()
93            .enumerate()
94            .min_by_key(|(_, (_, info))| info.last_used)
95            .map(|(i, _)| i)?;
96        Some(self.entries.swap_remove(min_idx))
97    }
98}
99
100use tokio::io::{AsyncRead, AsyncWriteExt};
101use tokio::net::TcpStream;
102
103use crate::DriverError;
104use crate::arena::Arena;
105use crate::auth;
106use crate::codec::Encode;
107use crate::proto::{self, BackendMessage};
108
109#[cfg(feature = "tls")]
110use crate::tls;
111
112// --- Stream abstraction ---
113
114/// The underlying stream type — plain TCP, TLS, or Unix domain socket.
115enum Stream {
116    Plain(TcpStream),
117    #[cfg(feature = "tls")]
118    Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
119    #[cfg(unix)]
120    Unix(tokio::net::UnixStream),
121}
122
123impl Stream {
124    async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
125        match self {
126            Stream::Plain(s) => s.write_all(buf).await,
127            #[cfg(feature = "tls")]
128            Stream::Tls(s) => s.write_all(buf).await,
129            #[cfg(unix)]
130            Stream::Unix(s) => s.write_all(buf).await,
131        }
132    }
133
134    async fn flush(&mut self) -> std::io::Result<()> {
135        match self {
136            Stream::Plain(s) => s.flush().await,
137            #[cfg(feature = "tls")]
138            Stream::Tls(s) => s.flush().await,
139            #[cfg(unix)]
140            Stream::Unix(s) => s.flush().await,
141        }
142    }
143}
144
145/// Wrapper to implement AsyncRead for Stream.
146struct StreamReader<'a>(&'a mut Stream);
147
148impl AsyncRead for StreamReader<'_> {
149    fn poll_read(
150        mut self: std::pin::Pin<&mut Self>,
151        cx: &mut std::task::Context<'_>,
152        buf: &mut tokio::io::ReadBuf<'_>,
153    ) -> std::task::Poll<std::io::Result<()>> {
154        match &mut *self.0 {
155            Stream::Plain(s) => std::pin::Pin::new(s).poll_read(cx, buf),
156            #[cfg(feature = "tls")]
157            Stream::Tls(s) => std::pin::Pin::new(s.as_mut()).poll_read(cx, buf),
158            #[cfg(unix)]
159            Stream::Unix(s) => std::pin::Pin::new(s).poll_read(cx, buf),
160        }
161    }
162}
163
164// --- Config ---
165
166/// Connection configuration parsed from a URL.
167///
168/// Format: `postgres://user:password@host:port/database`
169///
170/// Implements Drop to zeroize the password field, minimizing the
171/// window where plaintext credentials live in memory.
172#[derive(Debug, Clone)]
173pub struct Config {
174    pub host: String,
175    pub port: u16,
176    pub user: String,
177    pub password: String,
178    pub database: String,
179    pub ssl: SslMode,
180    /// PG-side statement timeout in seconds. Default: 30. 0 = no timeout.
181    ///
182    /// After connecting, the driver sends `SET statement_timeout = '{N}s'`.
183    /// If a query exceeds this duration, PostgreSQL kills it and returns an error.
184    pub statement_timeout_secs: u32,
185}
186
187/// Zeroize password on drop to minimize credential lifetime in memory.
188impl Drop for Config {
189    fn drop(&mut self) {
190        use zeroize::Zeroize;
191        self.password.zeroize();
192    }
193}
194
/// How the driver negotiates TLS with the server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SslMode {
    /// Plain connection only; TLS is never attempted.
    Disable,
    /// Attempt TLS first and continue unencrypted if the server answers 'N'.
    Prefer,
    /// TLS is mandatory; the connection fails if the server answers 'N'.
    Require,
}
205
206impl Config {
207    /// Parse a PostgreSQL connection URL.
208    ///
209    /// Format: `postgres://user:password@host:port/database?sslmode=prefer`
210    ///
211    /// # Unix domain sockets
212    ///
213    /// Use the `host` query parameter to specify a UDS directory (libpq convention):
214    /// ```text
215    /// postgres://user@localhost/dbname?host=/tmp
216    /// postgres:///dbname?host=/var/run/postgresql
217    /// ```
218    /// When `host` starts with `/`, the driver connects via Unix domain socket at
219    /// `{host}/.s.PGSQL.{port}` instead of TCP. TLS is skipped for UDS connections.
220    pub fn from_url(url: &str) -> Result<Self, DriverError> {
221        let url = url
222            .strip_prefix("postgres://")
223            .or_else(|| url.strip_prefix("postgresql://"))
224            .ok_or_else(|| DriverError::Protocol("URL must start with postgres://".into()))?;
225
226        // Split user:password@host:port/database
227        let (userinfo, rest) = url
228            .split_once('@')
229            .ok_or_else(|| DriverError::Protocol("missing @ in connection URL".into()))?;
230
231        let (user, password) = userinfo.split_once(':').unwrap_or((userinfo, ""));
232
233        // Split host:port/database?params
234        let (hostport, rest) = rest.split_once('/').unwrap_or((rest, ""));
235        let (database, params) = rest.split_once('?').unwrap_or((rest, ""));
236
237        let (host, port) = if let Some((h, p)) = hostport.split_once(':') {
238            let port = p
239                .parse::<u16>()
240                .map_err(|_| DriverError::Protocol(format!("invalid port: {p}")))?;
241            (h.to_owned(), port)
242        } else {
243            (hostport.to_owned(), 5432)
244        };
245
246        let mut ssl = SslMode::Prefer;
247        let mut statement_timeout_secs: u32 = 30;
248        let mut host_override: Option<String> = None;
249        for param in params.split('&') {
250            if param.is_empty() {
251                continue;
252            }
253            if let Some(val) = param.strip_prefix("sslmode=") {
254                // A typo like "sslmode=require" (missing 'e') would go unencrypted.
255                ssl = match val {
256                    "disable" => SslMode::Disable,
257                    "prefer" => SslMode::Prefer,
258                    "require" => SslMode::Require,
259                    _ => {
260                        return Err(DriverError::Protocol(format!(
261                            "unknown sslmode: '{val}' (expected: disable, prefer, require)"
262                        )));
263                    }
264                };
265            } else if let Some(val) = param.strip_prefix("statement_timeout=") {
266                statement_timeout_secs = val.parse::<u32>().unwrap_or(30);
267            } else if let Some(val) = param.strip_prefix("host=") {
268                host_override = Some(url_decode(val)?);
269            }
270        }
271
272        // If ?host=/path was specified, override the URL hostname with it.
273        // This follows the libpq convention: host=/tmp means UDS.
274        let final_host = if let Some(h) = host_override {
275            h
276        } else {
277            url_decode(&host)?
278        };
279
280        let config = Config {
281            host: final_host,
282            port,
283            user: url_decode(user)?,
284            password: url_decode(password)?,
285            database: if database.is_empty() {
286                url_decode(user)?
287            } else {
288                url_decode(database)?
289            },
290            ssl,
291            statement_timeout_secs,
292        };
293        config.validate()?;
294        Ok(config)
295    }
296
297    /// Validate configuration fields before attempting a connection.
298    ///
299    /// Called automatically by `from_url()`. Call manually if constructing
300    /// a `Config` by hand.
301    pub fn validate(&self) -> Result<(), DriverError> {
302        if self.host.is_empty() {
303            return Err(DriverError::Protocol("host cannot be empty".into()));
304        }
305        if self.user.is_empty() {
306            return Err(DriverError::Protocol("user cannot be empty".into()));
307        }
308        if self.database.is_empty() {
309            return Err(DriverError::Protocol("database cannot be empty".into()));
310        }
311        Ok(())
312    }
313
314    /// Returns `true` if the host is a Unix domain socket directory path.
315    ///
316    /// libpq convention: if `host` starts with `/`, the connection uses a
317    /// Unix domain socket at `{host}/.s.PGSQL.{port}`.
318    pub fn host_is_uds(&self) -> bool {
319        self.host.starts_with('/')
320    }
321
322    /// Returns the Unix domain socket path: `{host}/.s.PGSQL.{port}`.
323    ///
324    /// Only meaningful when [`host_is_uds()`](Self::host_is_uds) returns `true`.
325    pub fn uds_path(&self) -> String {
326        format!("{}/.s.PGSQL.{}", self.host, self.port)
327    }
328}
329
330/// Minimal percent-decoding for connection URL components.
331///
332/// Decodes `%XX` hex sequences into raw bytes, then validates as UTF-8.
333/// This correctly handles multi-byte UTF-8 characters that are percent-encoded
334/// byte-by-byte (e.g. `%C3%A9` for 'é').
335fn url_decode(s: &str) -> Result<String, DriverError> {
336    let mut bytes = Vec::with_capacity(s.len());
337    let input = s.as_bytes();
338    let mut i = 0;
339    while i < input.len() {
340        if input[i] == b'%' {
341            if i + 2 >= input.len() {
342                return Err(DriverError::Protocol(format!(
343                    "malformed percent-encoding in URL: '{s}'"
344                )));
345            }
346            let hi = hex_val(input[i + 1]).ok_or_else(|| {
347                DriverError::Protocol(format!(
348                    "invalid hex digit '{}' in URL: '{s}'",
349                    input[i + 1] as char
350                ))
351            })?;
352            let lo = hex_val(input[i + 2]).ok_or_else(|| {
353                DriverError::Protocol(format!(
354                    "invalid hex digit '{}' in URL: '{s}'",
355                    input[i + 2] as char
356                ))
357            })?;
358            bytes.push(hi * 16 + lo);
359            i += 3;
360        } else {
361            bytes.push(input[i]);
362            i += 1;
363        }
364    }
365    String::from_utf8(bytes)
366        .map_err(|_| DriverError::Protocol(format!("invalid UTF-8 in URL: '{s}'")))
367}
368
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value 0..=15.
///
/// Returns `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII hex digits in either case; the
    // result is always below 16, so the narrowing cast cannot truncate.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
377
/// One backend message from the startup phase, converted into owned data so
/// that handling it does not keep a borrow on `self.read_buf`.
enum StartupAction {
    /// Authentication finished successfully.
    AuthOk,
    /// Server asks for the password in cleartext.
    AuthCleartext,
    /// Server asks for an MD5 password response using this 4-byte salt.
    AuthMd5([u8; 4]),
    /// Server asks for SASL; carries the raw mechanism list.
    AuthSasl(Vec<u8>),
    /// A `(name, value)` server parameter.
    ParameterStatus(Box<str>, Box<str>),
    /// Backend `(pid, secret)` pair.
    BackendKeyData(i32, i32),
    /// Startup is complete; carries the transaction status byte.
    ReadyForQuery(u8),
    /// Server reported an error; carries its formatted text.
    Error(String),
    /// Server sent a notice; its content is discarded.
    Notice,
}
390
391// --- Statement cache ---
392
/// Builds the prepared-statement name for a SQL hash: `"s_"` followed by the
/// hash as 16 zero-padded lowercase hex digits (the same text as
/// `format!("s_{hash:016x}")`).
///
/// The name is always exactly 18 bytes: "s_" (2) + 16 hex digits. It is
/// assembled in a stack buffer with a manual hex table rather than through
/// the formatting machinery; the only heap allocation is the returned
/// `Box<str>`.
#[inline]
fn make_stmt_name(hash: u64) -> Box<str> {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    // "s_" prefix followed by 16 placeholder digits that the loop overwrites.
    let mut name = *b"s_0000000000000000";
    // Most significant byte first, two hex digits per byte.
    for (pair, byte) in name[2..].chunks_exact_mut(2).zip(hash.to_be_bytes()) {
        pair[0] = DIGITS[usize::from(byte >> 4)];
        pair[1] = DIGITS[usize::from(byte & 0x0f)];
    }
    // The buffer holds only 's', '_' and ASCII hex digits, so it is valid
    // UTF-8; the expect documents that invariant.
    std::str::from_utf8(&name)
        .expect("BUG: stmt name buffer contains only ASCII hex")
        .into()
}
414
415/// Cached information about a prepared statement.
416///
417/// The statement name is a 64-bit rapidhash formatted as `"s_{hash:016x}"`.
418/// With 2^64 possible values, collision probability is negligible for realistic
419/// workloads (e.g., ~1 in 10^13 for 10,000 distinct queries). A collision would
420/// cause a protocol error from PostgreSQL (parameter mismatch), not silent
421/// data corruption. If you have an adversarial workload that could craft
422/// collisions, consider a verified cache keyed on the full SQL text.
423struct StmtInfo {
424    /// Statement name: `"s_{hash:016x}"`
425    name: Box<str>,
426    /// Column metadata from RowDescription.
427    columns: Arc<[ColumnDesc]>,
428    /// Monotonic counter value at last use for LRU eviction.
429    /// Cheaper than `Instant::now()` which is a syscall on macOS (~20-40ns).
430    last_used: u64,
431    /// Pre-built Bind message template for fast re-execution.
432    ///
433    /// On the first execution of a cached statement, we snapshot the complete
434    /// Bind message bytes. On subsequent executions with fixed-size parameters,
435    /// we memcpy the template and patch only the parameter data in-place,
436    /// avoiding the full `write_bind_params` rebuild (~100-200ns savings per
437    /// query on the hot path).
438    ///
439    /// `None` until the first execution populates it.
440    bind_template: Option<BindTemplate>,
441}
442
/// Snapshot of a Bind+Execute+Sync message sequence used for fast re-execution.
///
/// Holds the complete Bind message followed by EXECUTE_SYNC, plus the byte
/// offsets at which each parameter's data starts. Re-executing with
/// same-sized parameters copies the template and overwrites the parameter
/// data in place via `encode_at`, with no scratch buffer and no second copy.
struct BindTemplate {
    /// Bind message bytes with EXECUTE_SYNC (15 bytes) appended.
    bytes: Vec<u8>,
    /// Offset at which the Bind message ends (i.e. where EXECUTE_SYNC starts).
    /// Streaming queries use it because they need Execute+Flush instead.
    bind_end: usize,
    /// One `(data_offset, data_len)` pair per parameter, relative to `bytes`.
    /// `data_offset` is the first byte of the parameter data (just after its
    /// i32 length prefix); `data_len` is the data length, with -1 meaning NULL.
    param_slots: Vec<(usize, i32)>,
}
460
/// Metadata for one column of a query result.
#[derive(Debug, Clone)]
pub struct ColumnDesc {
    /// Name of the column as reported for the query.
    pub name: Box<str>,
    /// OID of the column's PostgreSQL type.
    pub type_oid: u32,
    /// Size of the type in bytes; -1 marks a variable-length type.
    pub type_size: i16,
    /// OID of the table the column comes from, or 0 when it is not a table
    /// column (e.g. a computed expression).
    pub table_oid: u32,
    /// Number of the column within its source table, or 0 when it is not a
    /// table column.
    pub column_id: i16,
}
475
476/// Result of a `prepare_describe` call — column and parameter metadata
477/// without executing the query.
478#[derive(Debug, Clone)]
479pub struct PrepareResult {
480    /// Output columns (empty for INSERT/UPDATE/DELETE without RETURNING).
481    pub columns: Vec<ColumnDesc>,
482    /// PostgreSQL OIDs of the expected parameter types.
483    pub param_oids: Vec<u32>,
484}
485
/// One row of text-format values as returned by `simple_query_rows`.
///
/// A field is `None` when the value is SQL NULL and `Some(text)` otherwise.
/// Intended only for compile-time schema introspection queries.
pub type SimpleRow = Vec<Option<String>>;
491
492// --- Connection ---
493
/// A notification that arrived while a query was being processed.
///
/// When the read loop meets a NotificationResponse in the middle of a query
/// it stores it rather than discarding it. Use
/// [`Connection::drain_notifications`] to take the buffered notifications.
#[derive(Debug, Clone)]
pub struct Notification {
    /// Process ID of the backend that sent the notification.
    pub pid: i32,
    /// Name of the channel the notification was sent on.
    pub channel: String,
    /// Notification payload; may be an empty string.
    pub payload: String,
}
508
509/// A PostgreSQL connection with statement cache and inline message processing.
510///
511/// Connections are not `Send` — they must be used on one task at a time. The pool
512/// handles concurrent access by lending connections to individual tasks.
513pub struct Connection {
514    stream: Stream,
515    /// Message payload buffer (re-used per message).
516    read_buf: Vec<u8>,
517    /// Buffered read: raw bytes from the TCP stream. We read 64KB chunks and
518    /// parse messages from this buffer, issuing a new read only when exhausted.
519    stream_buf: Vec<u8>,
520    /// How many valid bytes are in `stream_buf[stream_buf_pos..]`.
521    stream_buf_pos: usize,
522    /// One past the last valid byte in `stream_buf`.
523    stream_buf_end: usize,
524    write_buf: Vec<u8>,
525    stmts: StmtCache,
526    params: Vec<(Box<str>, Box<str>)>,
527    pid: i32,
528    secret: i32,
529    tx_status: u8,
530    /// Timestamp of the last successful query completion. Used by the pool
531    /// to detect stale connections and discard them instead of returning
532    /// a potentially dead TCP socket.
533    last_used: std::time::Instant,
534    /// Whether a streaming query is in progress. When true, the
535    /// connection is in an indeterminate protocol state (portal open, no
536    /// ReadyForQuery) and cannot be reused. PoolGuard::drop checks this flag.
537    streaming_active: bool,
538    /// Timestamp of connection creation. Used by pool max_lifetime.
539    created_at: std::time::Instant,
540    /// Notifications received during query processing. Buffered here
541    /// instead of dropped; call `drain_notifications()` to retrieve.
542    pending_notifications: Vec<Notification>,
543    /// Maximum number of cached prepared statements. When the cache exceeds
544    /// this size, the least recently used statement is evicted (Close sent to PG).
545    /// Default: 256.
546    max_stmt_cache_size: usize,
547    /// Monotonic counter for LRU eviction — incremented on each cache access.
548    /// Replaces `Instant::now()` to avoid syscall overhead (~20-40ns on macOS).
549    query_counter: u64,
550}
551
552impl Connection {
553    /// Connect to PostgreSQL and complete the startup/auth handshake.
554    ///
555    /// When `config.host` starts with `/` (Unix domain socket directory),
556    /// connects via `UnixStream` at `{host}/.s.PGSQL.{port}` instead of TCP.
557    /// TCP_NODELAY and keepalive are skipped for UDS since they are TCP-only.
558    pub async fn connect(config: &Config) -> Result<Self, DriverError> {
559        // Config::from_url() already validates. Manual Config construction
560        // should call validate() explicitly before passing to connect().
561
562        #[cfg(unix)]
563        if config.host_is_uds() {
564            let path = config.uds_path();
565            let unix = tokio::net::UnixStream::connect(&path)
566                .await
567                .map_err(DriverError::Io)?;
568            let stream = Stream::Unix(unix);
569            return Self::finish_connect(stream, config).await;
570        }
571
572        let addr = format!("{}:{}", config.host, config.port);
573        let tcp = TcpStream::connect(&addr).await.map_err(DriverError::Io)?;
574
575        // Set TCP_NODELAY to avoid Nagle delay on pipelined messages
576        tcp.set_nodelay(true).map_err(DriverError::Io)?;
577
578        // Without keepalive, a half-open connection (server crashed, firewall
579        // timeout) can hang forever on read.
580        Self::set_keepalive(&tcp)?;
581
582        let stream = match config.ssl {
583            SslMode::Disable => Stream::Plain(tcp),
584            #[cfg(feature = "tls")]
585            SslMode::Prefer | SslMode::Require => {
586                match tls::try_upgrade(tcp, &config.host, config.ssl == SslMode::Require).await {
587                    Ok(tls_stream) => Stream::Tls(Box::new(tls_stream)),
588                    Err(e) if config.ssl == SslMode::Require => return Err(e),
589                    Err(_) => {
590                        // Prefer mode: TLS failed, reconnect plain
591                        let tcp = TcpStream::connect(&addr).await.map_err(DriverError::Io)?;
592                        tcp.set_nodelay(true).map_err(DriverError::Io)?;
593                        Self::set_keepalive(&tcp)?;
594                        Stream::Plain(tcp)
595                    }
596                }
597            }
598            #[cfg(not(feature = "tls"))]
599            SslMode::Require => {
600                return Err(DriverError::Protocol(
601                    "TLS required but bsql-driver-postgres compiled without 'tls' feature".into(),
602                ));
603            }
604            #[cfg(not(feature = "tls"))]
605            SslMode::Prefer => Stream::Plain(tcp),
606        };
607
608        Self::finish_connect(stream, config).await
609    }
610
611    /// Shared connection setup: build the `Connection`, run startup handshake,
612    /// validate server params, and set statement timeout. Called by both the
613    /// TCP and UDS paths in [`connect`].
614    async fn finish_connect(stream: Stream, config: &Config) -> Result<Self, DriverError> {
615        let mut conn = Self {
616            stream,
617            read_buf: Vec::with_capacity(8192),
618
619            stream_buf: vec![0u8; 65536],
620            stream_buf_pos: 0,
621            stream_buf_end: 0,
622            write_buf: Vec::with_capacity(4096),
623            stmts: StmtCache::default(),
624            params: Vec::new(),
625            pid: 0,
626            secret: 0,
627            tx_status: b'I',
628            last_used: std::time::Instant::now(),
629            streaming_active: false,
630            created_at: std::time::Instant::now(),
631            pending_notifications: Vec::new(),
632            max_stmt_cache_size: 256,
633            query_counter: 0,
634        };
635
636        conn.startup(config).await?;
637
638        // Validate critical server parameters received during startup.
639        conn.validate_server_params()?;
640
641        if config.statement_timeout_secs > 0 {
642            conn.simple_query(&format!(
643                "SET statement_timeout = '{}s'",
644                config.statement_timeout_secs
645            ))
646            .await?;
647        }
648
649        Ok(conn)
650    }
651
652    /// Perform the startup handshake: StartupMessage -> auth -> parameter status -> ReadyForQuery.
653    ///
654    /// Uses a two-phase read approach: first read the message type + copy needed
655    /// data out of the borrow, then act on it. This avoids holding a borrow on
656    /// `self.read_buf` while calling other `&mut self` methods.
657    async fn startup(&mut self, config: &Config) -> Result<(), DriverError> {
658        // Send StartupMessage
659        self.write_buf.clear();
660        proto::write_startup(&mut self.write_buf, &config.user, &config.database);
661        self.flush_write().await?;
662
663        // Process auth and startup messages
664        loop {
665            let action = self.read_startup_action().await?;
666            match action {
667                StartupAction::AuthOk => {}
668                StartupAction::AuthCleartext => {
669                    self.write_buf.clear();
670                    let mut pw = config.password.as_bytes().to_vec();
671                    pw.push(0);
672                    proto::write_password(&mut self.write_buf, &pw);
673                    self.flush_write().await?;
674                }
675                StartupAction::AuthMd5(salt) => {
676                    self.write_buf.clear();
677                    let hash = auth::md5_password(&config.user, &config.password, &salt);
678                    proto::write_password(&mut self.write_buf, &hash);
679                    self.flush_write().await?;
680                }
681                StartupAction::AuthSasl(mechanisms_data) => {
682                    self.handle_scram(config, &mechanisms_data).await?;
683                }
684                StartupAction::ParameterStatus(name, value) => {
685                    // Linear scan on ~10 entries is faster than HashMap
686                    if let Some(entry) = self.params.iter_mut().find(|(k, _)| *k == name) {
687                        entry.1 = value;
688                    } else {
689                        self.params.push((name, value));
690                    }
691                }
692                StartupAction::BackendKeyData(pid, secret) => {
693                    self.pid = pid;
694                    self.secret = secret;
695                }
696                StartupAction::ReadyForQuery(status) => {
697                    self.tx_status = status;
698                    return Ok(());
699                }
700                StartupAction::Error(msg) => {
701                    return Err(DriverError::Auth(msg));
702                }
703                StartupAction::Notice => {}
704            }
705        }
706    }
707
708    /// Read one startup message, parse it, copy needed data, and return an owned action.
709    ///
710    /// This method reads the raw message into `self.read_buf`, parses it, extracts
711    /// all needed data into owned types, and drops the borrow before returning.
712    async fn read_startup_action(&mut self) -> Result<StartupAction, DriverError> {
713        let (msg_type, _) = self.read_message_buffered().await?;
714        self.read_startup_message_from_type(msg_type)
715    }
716
717    fn read_startup_message_from_type(&self, msg_type: u8) -> Result<StartupAction, DriverError> {
718        let payload = &self.read_buf;
719        let msg = proto::parse_backend_message(msg_type, payload)?;
720        match msg {
721            BackendMessage::AuthOk => Ok(StartupAction::AuthOk),
722            BackendMessage::AuthCleartext => Ok(StartupAction::AuthCleartext),
723            BackendMessage::AuthMd5 { salt } => Ok(StartupAction::AuthMd5(salt)),
724            BackendMessage::AuthSasl { mechanisms } => {
725                Ok(StartupAction::AuthSasl(mechanisms.to_vec()))
726            }
727            BackendMessage::ParameterStatus { name, value } => {
728                Ok(StartupAction::ParameterStatus(name.into(), value.into()))
729            }
730            BackendMessage::BackendKeyData { pid, secret } => {
731                Ok(StartupAction::BackendKeyData(pid, secret))
732            }
733            BackendMessage::ReadyForQuery { status } => Ok(StartupAction::ReadyForQuery(status)),
734            BackendMessage::ErrorResponse { data } => {
735                let fields = proto::parse_error_response(data);
736                Ok(StartupAction::Error(fields.to_string()))
737            }
738            BackendMessage::NoticeResponse { .. } => Ok(StartupAction::Notice),
739            other => Err(DriverError::Protocol(format!(
740                "unexpected message during startup: {other:?}"
741            ))),
742        }
743    }
744
745    /// Handle SCRAM-SHA-256 authentication exchange.
746    async fn handle_scram(
747        &mut self,
748        config: &Config,
749        mechanisms_data: &[u8],
750    ) -> Result<(), DriverError> {
751        let mechs = auth::parse_sasl_mechanisms(mechanisms_data);
752        if !mechs.contains(&"SCRAM-SHA-256") {
753            return Err(DriverError::Auth(format!(
754                "server requires unsupported SASL mechanism(s): {mechs:?}"
755            )));
756        }
757
758        let mut scram = auth::ScramClient::new(&config.user, &config.password)?;
759
760        // Send SASLInitialResponse
761        let client_first = scram.client_first_message();
762        self.write_buf.clear();
763        proto::write_sasl_initial(&mut self.write_buf, "SCRAM-SHA-256", &client_first);
764        self.flush_write().await?;
765
766        // Read SASLContinue — read message, extract data, drop borrow
767        let (msg_type, _) = self.read_message_buffered().await?;
768        let server_first = {
769            let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
770            match msg {
771                BackendMessage::AuthSaslContinue { data } => data.to_vec(),
772                BackendMessage::ErrorResponse { data } => {
773                    let fields = proto::parse_error_response(data);
774                    return Err(DriverError::Auth(fields.to_string()));
775                }
776                other => {
777                    return Err(DriverError::Protocol(format!(
778                        "expected AuthSaslContinue, got: {other:?}"
779                    )));
780                }
781            }
782        };
783
784        scram.process_server_first(&server_first)?;
785
786        // Send SASLResponse (client-final)
787        let client_final = scram.client_final_message()?;
788        self.write_buf.clear();
789        proto::write_sasl_response(&mut self.write_buf, &client_final);
790        self.flush_write().await?;
791
792        // Read SASLFinal — read message, extract data, drop borrow
793        let (msg_type, _) = self.read_message_buffered().await?;
794        {
795            let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
796            match msg {
797                BackendMessage::AuthSaslFinal { data } => {
798                    // Copy server final data to verify after the borrow ends
799                    let data_owned = data.to_vec();
800                    scram.verify_server_final(&data_owned)?;
801                }
802                BackendMessage::ErrorResponse { data } => {
803                    let fields = proto::parse_error_response(data);
804                    return Err(DriverError::Auth(fields.to_string()));
805                }
806                other => {
807                    return Err(DriverError::Protocol(format!(
808                        "expected AuthSaslFinal, got: {other:?}"
809                    )));
810                }
811            }
812        }
813
814        // AuthOk should follow
815        let (msg_type, _) = self.read_message_buffered().await?;
816        let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
817        match msg {
818            BackendMessage::AuthOk => Ok(()),
819            BackendMessage::ErrorResponse { data } => {
820                let fields = proto::parse_error_response(data);
821                Err(DriverError::Auth(fields.to_string()))
822            }
823            other => Err(DriverError::Protocol(format!(
824                "expected AuthOk after SCRAM, got: {other:?}"
825            ))),
826        }
827    }
828
829    // --- Query execution ---
830
831    /// Prepare a statement without executing it (Parse+Describe+Sync only).
832    ///
833    /// Used by connection warmup to pre-cache statements without executing them.
834    /// If the statement is already cached, this is a no-op.
835    pub async fn prepare_only(&mut self, sql: &str, sql_hash: u64) -> Result<(), DriverError> {
836        if self.stmts.contains_key(&sql_hash) {
837            return Ok(());
838        }
839        let name = make_stmt_name(sql_hash);
840        self.write_buf.clear();
841        proto::write_parse(&mut self.write_buf, &name, sql, &[]);
842        proto::write_describe(&mut self.write_buf, b'S', &name);
843        proto::write_sync(&mut self.write_buf);
844        self.flush_write().await?;
845
846        // Read ParseComplete
847        self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
848            .await?;
849
850        // Read ParameterDescription + RowDescription/NoData via existing helper
851        let columns = self.read_column_description().await?;
852
853        // ReadyForQuery
854        self.expect_ready().await?;
855
856        // Cache the statement (with LRU eviction if needed)
857        self.query_counter += 1;
858        self.cache_stmt(
859            sql_hash,
860            StmtInfo {
861                name,
862                columns,
863                last_used: self.query_counter,
864                bind_template: None,
865            },
866        );
867        Ok(())
868    }
869
870    /// Prepare a statement and return full column + parameter metadata.
871    ///
872    /// Sends Parse + Describe(Statement) + Sync, then reads:
873    /// - ParseComplete
874    /// - ParameterDescription (param type OIDs)
875    /// - RowDescription or NoData (column metadata)
876    /// - ReadyForQuery
877    ///
878    /// Unlike `prepare_only`, this always sends Parse (no cache check) and
879    /// uses the unnamed statement `""` so it does not pollute the statement
880    /// cache. This is designed for compile-time SQL validation in the proc
881    /// macro, where we need column + param metadata but never execute.
882    pub async fn prepare_describe(&mut self, sql: &str) -> Result<PrepareResult, DriverError> {
883        self.write_buf.clear();
884        // Use unnamed statement "" — PG replaces it on every Parse,
885        // so there is no cache pollution.
886        proto::write_parse(&mut self.write_buf, "", sql, &[]);
887        proto::write_describe(&mut self.write_buf, b'S', "");
888        proto::write_sync(&mut self.write_buf);
889        self.flush_write().await?;
890
891        // Read ParseComplete
892        self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
893            .await?;
894
895        // Read ParameterDescription + RowDescription/NoData
896        let mut param_oids: Vec<u32> = Vec::new();
897        let columns;
898        loop {
899            let msg = self.read_one_message().await?;
900            match msg {
901                BackendMessage::ParameterDescription { data } => {
902                    param_oids = proto::parse_parameter_description(data)?;
903                }
904                BackendMessage::RowDescription { data } => {
905                    columns = proto::parse_row_description(data)?;
906                    break;
907                }
908                BackendMessage::NoData => {
909                    columns = Vec::new();
910                    break;
911                }
912                BackendMessage::NoticeResponse { .. } => {}
913                BackendMessage::ErrorResponse { data } => {
914                    let fields = proto::parse_error_response(data);
915                    self.drain_to_ready().await?;
916                    return Err(self.make_server_error(fields));
917                }
918                other => {
919                    return Err(DriverError::Protocol(format!(
920                        "expected ParameterDescription/RowDescription/NoData, got: {other:?}"
921                    )));
922                }
923            }
924        }
925
926        // ReadyForQuery
927        self.expect_ready().await?;
928
929        Ok(PrepareResult {
930            columns,
931            param_oids,
932        })
933    }
934
935    /// Execute a simple (text protocol) query and return all result rows.
936    ///
937    /// Each row is a `Vec<Option<String>>` — NULL values are `None`, text
938    /// values are `Some(String)`. This uses the simple query protocol which
939    /// always returns text-format results.
940    ///
941    /// Designed for compile-time schema introspection queries in the proc
942    /// macro (e.g. `pg_attribute`, `information_schema`). Not intended for
943    /// high-performance runtime use.
944    pub async fn simple_query_rows(&mut self, sql: &str) -> Result<Vec<SimpleRow>, DriverError> {
945        self.write_buf.clear();
946        proto::write_simple_query(&mut self.write_buf, sql);
947        self.flush_write().await?;
948
949        let mut rows: Vec<SimpleRow> = Vec::new();
950        loop {
951            let msg = self.read_one_message().await?;
952            match msg {
953                BackendMessage::ReadyForQuery { status } => {
954                    self.tx_status = status;
955                    return Ok(rows);
956                }
957                BackendMessage::DataRow { data } => {
958                    rows.push(proto::parse_simple_data_row(data)?);
959                }
960                BackendMessage::RowDescription { .. }
961                | BackendMessage::CommandComplete { .. }
962                | BackendMessage::EmptyQuery
963                | BackendMessage::NoticeResponse { .. } => {}
964                BackendMessage::ErrorResponse { data } => {
965                    let fields = proto::parse_error_response(data);
966                    self.drain_to_ready().await?;
967                    return Err(self.make_server_error(fields));
968                }
969                BackendMessage::ParameterStatus { .. } => {}
970                other => {
971                    return Err(DriverError::Protocol(format!(
972                        "unexpected message during simple_query_rows: {other:?}"
973                    )));
974                }
975            }
976        }
977    }
978
979    /// Begin a streaming query using the PG extended query protocol with
980    /// `Execute(max_rows=chunk_size)`.
981    ///
982    /// Returns column metadata and puts the connection into streaming mode.
983    /// The caller must repeatedly call `streaming_next_chunk()` until it returns
984    /// `Ok(false)` (all rows consumed) before issuing any other query on this
985    /// connection.
986    ///
987    /// Uses the unnamed portal `""` which stays open between Execute calls
988    /// as long as Sync is NOT sent. We use Flush (not Sync) to force PG to
989    /// send buffered output without destroying the portal. Sync is only sent
990    /// after CommandComplete to cleanly end the query cycle.
991    pub async fn query_streaming_start(
992        &mut self,
993        sql: &str,
994        sql_hash: u64,
995        params: &[&(dyn Encode + Sync)],
996        chunk_size: i32,
997    ) -> Result<(Arc<[ColumnDesc]>, bool), DriverError> {
998        self.write_buf.clear();
999
1000        // Single hash lookup via get_mut — avoids contains_key + index double-lookup.
1001        let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
1002            // Cache hit: try bind template, fall back to write_bind_params.
1003            self.query_counter += 1;
1004            info.last_used = self.query_counter;
1005
1006            let can_use_template = info
1007                .bind_template
1008                .as_ref()
1009                .is_some_and(|t| t.param_slots.len() == params.len());
1010
1011            if can_use_template {
1012                let tmpl = info.bind_template.as_ref().unwrap();
1013                // Copy only the Bind portion (not EXECUTE_SYNC) — streaming
1014                // needs Execute+Flush instead.
1015                self.write_buf
1016                    .extend_from_slice(&tmpl.bytes[..tmpl.bind_end]);
1017
1018                let mut template_ok = true;
1019                for (i, param) in params.iter().enumerate() {
1020                    let (data_offset, old_len) = tmpl.param_slots[i];
1021                    if param.is_null() {
1022                        let len_offset = data_offset - 4;
1023                        self.write_buf[len_offset..len_offset + 4]
1024                            .copy_from_slice(&(-1i32).to_be_bytes());
1025                    } else if old_len >= 0 {
1026                        let end = data_offset + old_len as usize;
1027                        if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1028                            template_ok = false;
1029                            break;
1030                        }
1031                    } else {
1032                        template_ok = false;
1033                        break;
1034                    }
1035                }
1036
1037                if !template_ok {
1038                    self.write_buf.clear();
1039                    proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1040                    info.bind_template = None;
1041                }
1042            } else {
1043                proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1044            }
1045
1046            let cols = info.columns.clone();
1047
1048            if info.bind_template.is_none() && !self.write_buf.is_empty() {
1049                info.bind_template = build_bind_template(&self.write_buf, params.len());
1050            }
1051
1052            proto::write_execute(&mut self.write_buf, "", chunk_size);
1053            // Use Flush (not Sync!) to keep the portal alive between chunks.
1054            proto::write_flush(&mut self.write_buf);
1055            self.flush_write().await?;
1056
1057            cols
1058        } else {
1059            // Cache miss: Parse+Describe+Bind+Execute+Flush
1060            let name = make_stmt_name(sql_hash);
1061            let param_oids: smallvec::SmallVec<[u32; 8]> =
1062                params.iter().map(|p| p.type_oid()).collect();
1063            proto::write_parse(&mut self.write_buf, &name, sql, &param_oids);
1064            proto::write_describe(&mut self.write_buf, b'S', &name);
1065            proto::write_bind_params(&mut self.write_buf, "", &name, params);
1066
1067            proto::write_execute(&mut self.write_buf, "", chunk_size);
1068            proto::write_flush(&mut self.write_buf);
1069            self.flush_write().await?;
1070
1071            self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
1072                .await?;
1073            let columns = self.read_column_description().await?;
1074            self.query_counter += 1;
1075            self.cache_stmt(
1076                sql_hash,
1077                StmtInfo {
1078                    name,
1079                    columns: columns.clone(),
1080                    last_used: self.query_counter,
1081                    bind_template: None,
1082                },
1083            );
1084            columns
1085        };
1086
1087        // BindComplete
1088        self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1089            .await?;
1090
1091        self.streaming_active = true;
1092
1093        Ok((columns, false))
1094    }
1095
1096    /// Read the next chunk of rows from an in-progress streaming query.
1097    ///
1098    /// Returns `Ok(true)` if more rows are available (PortalSuspended),
1099    /// `Ok(false)` when all rows have been consumed (CommandComplete).
1100    ///
1101    /// After CommandComplete, this method sends Sync and reads ReadyForQuery,
1102    /// returning the connection to a clean protocol state.
1103    pub async fn streaming_next_chunk(
1104        &mut self,
1105        arena: &mut Arena,
1106        all_col_offsets: &mut Vec<(usize, i32)>,
1107    ) -> Result<bool, DriverError> {
1108        all_col_offsets.clear();
1109
1110        loop {
1111            let msg = self.read_one_message().await?;
1112            match msg {
1113                BackendMessage::DataRow { data } => {
1114                    parse_data_row_flat(data, arena, all_col_offsets)?;
1115                }
1116                BackendMessage::PortalSuspended => {
1117                    // More rows available. The portal stays open because we
1118                    // used Flush (not Sync). The caller will call
1119                    // streaming_send_execute() to request the next chunk.
1120                    return Ok(true);
1121                }
1122                BackendMessage::CommandComplete { .. } => {
1123                    // All rows consumed. Send Sync to end the query cycle
1124                    // and read ReadyForQuery to restore clean state.
1125                    self.write_buf.clear();
1126                    proto::write_sync(&mut self.write_buf);
1127                    self.flush_write().await?;
1128                    self.expect_ready().await?;
1129                    self.shrink_buffers();
1130
1131                    self.streaming_active = false;
1132                    return Ok(false);
1133                }
1134                BackendMessage::EmptyQuery => {
1135                    self.write_buf.clear();
1136                    proto::write_sync(&mut self.write_buf);
1137                    self.flush_write().await?;
1138                    self.expect_ready().await?;
1139
1140                    self.streaming_active = false;
1141                    return Ok(false);
1142                }
1143                BackendMessage::ErrorResponse { data } => {
1144                    let fields = proto::parse_error_response(data);
1145                    // Send Sync to reset and drain to ReadyForQuery
1146                    self.write_buf.clear();
1147                    proto::write_sync(&mut self.write_buf);
1148                    self.flush_write().await?;
1149                    self.drain_to_ready().await?;
1150
1151                    self.streaming_active = false;
1152                    return Err(self.make_server_error(fields));
1153                }
1154                BackendMessage::NoticeResponse { .. } => {}
1155                other => {
1156                    return Err(DriverError::Protocol(format!(
1157                        "unexpected message during streaming: {other:?}"
1158                    )));
1159                }
1160            }
1161        }
1162    }
1163
1164    /// Send Execute+Flush for the next chunk of a streaming query.
1165    ///
1166    /// Must be called before `streaming_next_chunk()` on the 2nd and
1167    /// subsequent chunks (the first chunk's Execute is sent by
1168    /// `query_streaming_start`).
1169    ///
1170    /// Uses Flush (not Sync) to keep the unnamed portal alive.
1171    pub async fn streaming_send_execute(&mut self, chunk_size: i32) -> Result<(), DriverError> {
1172        self.write_buf.clear();
1173        proto::write_execute(&mut self.write_buf, "", chunk_size);
1174        proto::write_flush(&mut self.write_buf);
1175        self.flush_write().await
1176    }
1177
1178    /// Common pipeline setup — builds Parse+Describe+Bind+Execute+Sync (or
1179    /// Bind+Execute+Sync on cache hit), sends to wire, reads ParseComplete+Describe
1180    /// responses if needed, reads BindComplete. Returns column metadata.
1181    ///
1182    /// When `need_columns` is false (e.g. `for_each_raw`, `execute`), the Arc
1183    /// clone of column metadata is skipped — saving an atomic increment on the
1184    /// hot path.
1185    ///
1186    /// When `skip_bind_complete` is true, the BindComplete message is NOT
1187    /// consumed here — the caller reads it inline from stream_buf (e.g.
1188    /// `for_each_raw` which already has a zero-copy stream_buf reader).
1189    async fn send_pipeline(
1190        &mut self,
1191        sql: &str,
1192        sql_hash: u64,
1193        params: &[&(dyn Encode + Sync)],
1194        need_columns: bool,
1195        skip_bind_complete: bool,
1196    ) -> Result<Option<Arc<[ColumnDesc]>>, DriverError> {
1197        debug_assert_eq!(
1198            hash_sql(sql),
1199            sql_hash,
1200            "sql_hash mismatch: caller-provided hash does not match hash_sql(sql)"
1201        );
1202
1203        if params.len() > i16::MAX as usize {
1204            return Err(DriverError::Protocol(format!(
1205                "parameter count {} exceeds maximum {} for PG wire protocol",
1206                params.len(),
1207                i16::MAX
1208            )));
1209        }
1210
1211        self.write_buf.clear();
1212
1213        // Single hash lookup — get_mut avoids the contains_key + index double-lookup.
1214        let columns = if let Some(info) = self.stmts.get_mut(&sql_hash) {
1215            // Cache hit: try bind template for fast path, fall back to write_bind_params.
1216            self.query_counter += 1;
1217            info.last_used = self.query_counter;
1218
1219            let can_use_template = info
1220                .bind_template
1221                .as_ref()
1222                .is_some_and(|t| t.param_slots.len() == params.len());
1223
1224            // Tracks whether write_buf already contains EXECUTE_SYNC (from template).
1225            let mut has_exec_sync = false;
1226
1227            if can_use_template {
1228                // Fast path: copy template (includes EXECUTE_SYNC) and patch params
1229                // directly via encode_at — no scratch buffer, no double-copy.
1230                let tmpl = info.bind_template.as_ref().unwrap();
1231                self.write_buf.extend_from_slice(&tmpl.bytes);
1232
1233                let mut template_ok = true;
1234                for (i, param) in params.iter().enumerate() {
1235                    let (data_offset, old_len) = tmpl.param_slots[i];
1236                    if param.is_null() {
1237                        let len_offset = data_offset - 4;
1238                        self.write_buf[len_offset..len_offset + 4]
1239                            .copy_from_slice(&(-1i32).to_be_bytes());
1240                    } else if old_len >= 0 {
1241                        let end = data_offset + old_len as usize;
1242                        if !param.encode_at(&mut self.write_buf[data_offset..end]) {
1243                            template_ok = false;
1244                            break;
1245                        }
1246                    } else {
1247                        // Template had NULL here but now non-NULL — rebuild.
1248                        template_ok = false;
1249                        break;
1250                    }
1251                }
1252
1253                if template_ok {
1254                    has_exec_sync = true; // Template includes EXECUTE_SYNC.
1255                } else {
1256                    self.write_buf.clear();
1257                    proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1258                    info.bind_template = None;
1259                }
1260            } else {
1261                proto::write_bind_params(&mut self.write_buf, "", &info.name, params);
1262            }
1263
1264            // Clone Arc only when caller needs columns (query path).
1265            // for_each_raw / execute skip this atomic increment.
1266            let cols = if need_columns {
1267                Some(info.columns.clone())
1268            } else {
1269                None
1270            };
1271
1272            // Snapshot bind template on first use or after invalidation.
1273            // build_bind_template appends EXECUTE_SYNC to the template bytes.
1274            if info.bind_template.is_none() && !self.write_buf.is_empty() {
1275                info.bind_template = build_bind_template(&self.write_buf, params.len());
1276            }
1277
1278            if !has_exec_sync {
1279                self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1280            }
1281            self.flush_write().await?;
1282
1283            cols
1284        } else {
1285            // Cache miss: Parse+Describe+Bind+Execute+Sync
1286            let name = make_stmt_name(sql_hash);
1287            let param_oids: smallvec::SmallVec<[u32; 8]> =
1288                params.iter().map(|p| p.type_oid()).collect();
1289            proto::write_parse(&mut self.write_buf, &name, sql, &param_oids);
1290            proto::write_describe(&mut self.write_buf, b'S', &name);
1291            proto::write_bind_params(&mut self.write_buf, "", &name, params);
1292
1293            self.write_buf.extend_from_slice(proto::EXECUTE_SYNC);
1294            self.flush_write().await?;
1295
1296            self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
1297                .await?;
1298            let columns = self.read_column_description().await?;
1299            self.query_counter += 1;
1300            self.cache_stmt(
1301                sql_hash,
1302                StmtInfo {
1303                    name,
1304                    columns: columns.clone(),
1305                    last_used: self.query_counter,
1306                    bind_template: None,
1307                },
1308            );
1309            if need_columns { Some(columns) } else { None }
1310        };
1311
1312        // BindComplete — skip when caller handles it inline (for_each_raw).
1313        if !skip_bind_complete {
1314            self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1315                .await?;
1316        }
1317
1318        Ok(columns)
1319    }
1320
1321    /// Execute a prepared query and return rows in arena-allocated storage.
1322    ///
1323    /// If the statement is not yet cached, Parse+Describe+Bind+Execute+Sync are
1324    /// pipelined in a single TCP write. On cache hit, only Bind+Execute+Sync are sent.
1325    pub async fn query(
1326        &mut self,
1327        sql: &str,
1328        sql_hash: u64,
1329        params: &[&(dyn Encode + Sync)],
1330        arena: &mut Arena,
1331    ) -> Result<QueryResult, DriverError> {
1332        let columns = self
1333            .send_pipeline(sql, sql_hash, params, true, false)
1334            .await?
1335            .expect("send_pipeline(need_columns=true) must return Some");
1336
1337        // Read DataRow messages and CommandComplete.
1338        // Flat column offsets: all rows' columns are stored contiguously in
1339        // `all_col_offsets`. Row N starts at index `N * num_cols`.
1340
1341        // is just num_cols; for fetch_all we grow dynamically. The previous
1342        // `num_cols * 64` over-allocates for single-row queries.
1343        let num_cols = columns.len();
1344        // .max(1) prevents zero-capacity allocation when num_cols is 0 (e.g., INSERT/UPDATE/DELETE
1345        // with no RETURNING clause), ensuring Vec has a reasonable initial capacity.
1346        let mut all_col_offsets: Vec<(usize, i32)> = Vec::with_capacity(num_cols.max(1) * 8);
1347        let mut affected_rows: u64 = 0;
1348
1349        loop {
1350            let msg = self.read_one_message().await?;
1351            match msg {
1352                BackendMessage::DataRow { data } => {
1353                    parse_data_row_flat(data, arena, &mut all_col_offsets)?;
1354                }
1355                BackendMessage::CommandComplete { tag } => {
1356                    affected_rows = proto::parse_command_tag(tag);
1357                    break;
1358                }
1359                BackendMessage::EmptyQuery => {
1360                    break;
1361                }
1362                BackendMessage::NoticeResponse { .. } => {
1363                    // Async messages can arrive mid-query — skip them
1364                }
1365                BackendMessage::ErrorResponse { data } => {
1366                    let fields = proto::parse_error_response(data);
1367
1368                    self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1369                    self.drain_to_ready().await?;
1370                    return Err(self.make_server_error(fields));
1371                }
1372                other => {
1373                    return Err(DriverError::Protocol(format!(
1374                        "unexpected message during query: {other:?}"
1375                    )));
1376                }
1377            }
1378        }
1379
1380        // ReadyForQuery
1381        self.expect_ready().await?;
1382        self.shrink_buffers();
1383
1384        Ok(QueryResult {
1385            all_col_offsets,
1386            num_cols,
1387            columns,
1388            affected_rows,
1389        })
1390    }
1391
1392    /// Read RowDescription / NoData after ParseComplete+Describe, handling
1393    /// ParameterDescription that precedes RowDescription for Describe Statement.
1394    async fn read_column_description(&mut self) -> Result<Arc<[ColumnDesc]>, DriverError> {
1395        loop {
1396            let msg = self.read_one_message().await?;
1397            match msg {
1398                BackendMessage::RowDescription { data } => {
1399                    let cols = proto::parse_row_description(data)?;
1400                    return Ok(cols.into());
1401                }
1402                BackendMessage::ParameterDescription { .. } => {
1403                    // ParameterDescription precedes RowDescription — continue reading
1404                }
1405                BackendMessage::NoData => return Ok(Arc::from(Vec::new())),
1406                BackendMessage::NoticeResponse { .. } => {}
1407                BackendMessage::ErrorResponse { data } => {
1408                    let fields = proto::parse_error_response(data);
1409                    self.drain_to_ready().await?;
1410                    return Err(self.make_server_error(fields));
1411                }
1412                other => {
1413                    return Err(DriverError::Protocol(format!(
1414                        "expected RowDescription/NoData after Parse, got: {other:?}"
1415                    )));
1416                }
1417            }
1418        }
1419    }
1420
1421    /// Execute a query without result rows (INSERT/UPDATE/DELETE).
1422    ///
1423    /// Skips DataRow parsing entirely — only reads until CommandComplete.
1424    /// Does not allocate an Arena.
1425    pub async fn execute(
1426        &mut self,
1427        sql: &str,
1428        sql_hash: u64,
1429        params: &[&(dyn Encode + Sync)],
1430    ) -> Result<u64, DriverError> {
1431        let _ = self
1432            .send_pipeline(sql, sql_hash, params, false, false)
1433            .await?;
1434
1435        // Skip DataRow messages, read until CommandComplete
1436        let mut affected_rows: u64 = 0;
1437        loop {
1438            let msg = self.read_one_message().await?;
1439            match msg {
1440                BackendMessage::DataRow { .. } => {
1441                    // execute() discards row data — no arena allocation
1442                }
1443                BackendMessage::CommandComplete { tag } => {
1444                    affected_rows = proto::parse_command_tag(tag);
1445                    break;
1446                }
1447                BackendMessage::EmptyQuery => break,
1448                BackendMessage::NoticeResponse { .. } => {}
1449                BackendMessage::ErrorResponse { data } => {
1450                    let fields = proto::parse_error_response(data);
1451
1452                    self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1453                    self.drain_to_ready().await?;
1454                    return Err(self.make_server_error(fields));
1455                }
1456                other => {
1457                    return Err(DriverError::Protocol(format!(
1458                        "unexpected message during execute: {other:?}"
1459                    )));
1460                }
1461            }
1462        }
1463
1464        self.expect_ready().await?;
1465        self.shrink_buffers();
1466        Ok(affected_rows)
1467    }
1468
1469    /// Execute the same prepared statement N times with different parameters
1470    /// in a single pipeline round-trip.
1471    ///
1472    /// Sends all N Bind+Execute messages followed by one Sync. PostgreSQL
1473    /// processes them in order and returns N BindComplete+CommandComplete
1474    /// responses followed by one ReadyForQuery.
1475    ///
1476    /// This is a real optimization for bulk operations: N inserts in a
1477    /// transaction become 1 round-trip instead of N round-trips.
1478    ///
1479    /// The statement must already be cached (call `execute` at least once first,
1480    /// or use `prepare_describe`). If not cached, it will be prepared inline
1481    /// for the first entry, then the rest use the cached version.
1482    ///
1483    /// Returns the number of affected rows for each parameter set.
1484    pub async fn execute_pipeline(
1485        &mut self,
1486        sql: &str,
1487        sql_hash: u64,
1488        param_sets: &[&[&(dyn Encode + Sync)]],
1489    ) -> Result<Vec<u64>, DriverError> {
1490        if param_sets.is_empty() {
1491            return Ok(Vec::new());
1492        }
1493
1494        debug_assert_eq!(
1495            hash_sql(sql),
1496            sql_hash,
1497            "sql_hash mismatch: caller-provided hash does not match hash_sql(sql)"
1498        );
1499
1500        self.write_buf.clear();
1501
1502        // Ensure statement is prepared. If not cached, prepare it first with
1503        // a standalone Parse+Describe+Sync pipeline.
1504        if !self.stmts.contains_key(&sql_hash) {
1505            let name = make_stmt_name(sql_hash);
1506            let first_params = param_sets[0];
1507            if first_params.len() > i16::MAX as usize {
1508                return Err(DriverError::Protocol(format!(
1509                    "parameter count {} exceeds maximum {}",
1510                    first_params.len(),
1511                    i16::MAX
1512                )));
1513            }
1514            let param_oids: smallvec::SmallVec<[u32; 8]> =
1515                first_params.iter().map(|p| p.type_oid()).collect();
1516            proto::write_parse(&mut self.write_buf, &name, sql, &param_oids);
1517            proto::write_describe(&mut self.write_buf, b'S', &name);
1518            proto::write_sync(&mut self.write_buf);
1519            self.flush_write().await?;
1520
1521            self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
1522                .await?;
1523            let columns = self.read_column_description().await?;
1524            self.expect_ready().await?;
1525
1526            self.query_counter += 1;
1527            self.cache_stmt(
1528                sql_hash,
1529                StmtInfo {
1530                    name,
1531                    columns,
1532                    last_used: self.query_counter,
1533                    bind_template: None,
1534                },
1535            );
1536
1537            self.write_buf.clear();
1538        }
1539
1540        // Build N x (Bind + Execute) + 1 x Sync
1541        let stmt_name = self
1542            .stmts
1543            .get(&sql_hash)
1544            .expect("BUG: stmt just cached but not found")
1545            .name
1546            .clone();
1547        let count = param_sets.len();
1548
1549        for params in param_sets {
1550            if params.len() > i16::MAX as usize {
1551                return Err(DriverError::Protocol(format!(
1552                    "parameter count {} exceeds maximum {}",
1553                    params.len(),
1554                    i16::MAX
1555                )));
1556            }
1557            proto::write_bind_params(&mut self.write_buf, "", &stmt_name, params);
1558            self.write_buf.extend_from_slice(proto::EXECUTE_ONLY);
1559        }
1560
1561        // One Sync at the end
1562        self.write_buf.extend_from_slice(proto::SYNC_ONLY);
1563        self.flush_write().await?;
1564
1565        // Read N x (BindComplete + CommandComplete) + ReadyForQuery
1566        let mut results = Vec::with_capacity(count);
1567        for _ in 0..count {
1568            self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1569                .await?;
1570
1571            // Read until CommandComplete, skipping DataRow/EmptyQuery/Notice
1572            let mut affected_rows: u64 = 0;
1573            loop {
1574                let msg = self.read_one_message().await?;
1575                match msg {
1576                    BackendMessage::DataRow { .. } => {}
1577                    BackendMessage::CommandComplete { tag } => {
1578                        affected_rows = proto::parse_command_tag(tag);
1579                        break;
1580                    }
1581                    BackendMessage::EmptyQuery => break,
1582                    BackendMessage::NoticeResponse { .. } => {}
1583                    BackendMessage::ErrorResponse { data } => {
1584                        let fields = proto::parse_error_response(data);
1585                        self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1586                        self.drain_to_ready().await?;
1587                        return Err(self.make_server_error(fields));
1588                    }
1589                    other => {
1590                        return Err(DriverError::Protocol(format!(
1591                            "unexpected message during execute_pipeline: {other:?}"
1592                        )));
1593                    }
1594                }
1595            }
1596            results.push(affected_rows);
1597        }
1598
1599        self.expect_ready().await?;
1600        self.shrink_buffers();
1601        Ok(results)
1602    }
1603
1604    /// Ensure a statement is prepared and cached, doing a round-trip if needed.
1605    ///
1606    /// Returns the cached statement name. If the statement is already cached,
1607    /// this is a no-op (hash lookup only). Otherwise, sends Parse+Describe+Sync
1608    /// and waits for the response.
1609    ///
1610    /// Used by deferred pipeline execution to separate the prepare step
1611    /// (which requires I/O) from the Bind+Execute buffering step (which doesn't).
1612    pub(crate) async fn ensure_stmt_prepared(
1613        &mut self,
1614        sql: &str,
1615        sql_hash: u64,
1616        params: &[&(dyn Encode + Sync)],
1617    ) -> Result<Box<str>, DriverError> {
1618        if let Some(info) = self.stmts.get(&sql_hash) {
1619            return Ok(info.name.clone());
1620        }
1621
1622        // Cache miss: Parse+Describe+Sync round-trip
1623        let name = make_stmt_name(sql_hash);
1624        if params.len() > i16::MAX as usize {
1625            return Err(DriverError::Protocol(format!(
1626                "parameter count {} exceeds maximum {}",
1627                params.len(),
1628                i16::MAX
1629            )));
1630        }
1631        let param_oids: smallvec::SmallVec<[u32; 8]> =
1632            params.iter().map(|p| p.type_oid()).collect();
1633
1634        self.write_buf.clear();
1635        proto::write_parse(&mut self.write_buf, &name, sql, &param_oids);
1636        proto::write_describe(&mut self.write_buf, b'S', &name);
1637        proto::write_sync(&mut self.write_buf);
1638        self.flush_write().await?;
1639
1640        self.expect_message(|m| matches!(m, BackendMessage::ParseComplete))
1641            .await?;
1642        let columns = self.read_column_description().await?;
1643        self.expect_ready().await?;
1644
1645        self.query_counter += 1;
1646        let stmt_name = name.clone();
1647        self.cache_stmt(
1648            sql_hash,
1649            StmtInfo {
1650                name,
1651                columns,
1652                last_used: self.query_counter,
1653                bind_template: None,
1654            },
1655        );
1656
1657        Ok(stmt_name)
1658    }
1659
1660    /// Write Bind+Execute message bytes for a prepared statement into an
1661    /// external buffer. Does NOT send anything on the wire.
1662    ///
1663    /// The statement must already be prepared (call `ensure_stmt_prepared` first).
1664    /// Panics in debug mode if the statement is not cached.
1665    pub(crate) fn write_deferred_bind_execute(
1666        &self,
1667        sql_hash: u64,
1668        params: &[&(dyn Encode + Sync)],
1669        buf: &mut Vec<u8>,
1670    ) {
1671        let stmt_name = &self
1672            .stmts
1673            .get(&sql_hash)
1674            .expect("BUG: stmt just cached but not found")
1675            .name;
1676        proto::write_bind_params(buf, "", stmt_name, params);
1677        buf.extend_from_slice(proto::EXECUTE_ONLY);
1678    }
1679
1680    /// Flush a buffer of deferred Bind+Execute messages as a single pipeline.
1681    ///
1682    /// Appends Sync to the buffer, writes everything in one TCP write, then
1683    /// reads `count` x (BindComplete + CommandComplete) + ReadyForQuery.
1684    /// Returns the affected row count for each deferred operation.
1685    pub(crate) async fn flush_deferred_pipeline(
1686        &mut self,
1687        buf: &mut Vec<u8>,
1688        count: usize,
1689    ) -> Result<Vec<u64>, DriverError> {
1690        if count == 0 {
1691            buf.clear();
1692            return Ok(Vec::new());
1693        }
1694
1695        buf.extend_from_slice(proto::SYNC_ONLY);
1696
1697        // Write the entire buffer in one TCP write
1698        self.stream.write_all(buf).await.map_err(DriverError::Io)?;
1699        self.stream.flush().await.map_err(DriverError::Io)?;
1700        buf.clear();
1701
1702        // Read count x (BindComplete + CommandComplete) + ReadyForQuery
1703        let mut results = Vec::with_capacity(count);
1704        for _ in 0..count {
1705            self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1706                .await?;
1707
1708            let mut affected_rows: u64 = 0;
1709            loop {
1710                let msg = self.read_one_message().await?;
1711                match msg {
1712                    BackendMessage::DataRow { .. } => {}
1713                    BackendMessage::CommandComplete { tag } => {
1714                        affected_rows = proto::parse_command_tag(tag);
1715                        break;
1716                    }
1717                    BackendMessage::EmptyQuery => break,
1718                    BackendMessage::NoticeResponse { .. } => {}
1719                    BackendMessage::ErrorResponse { data } => {
1720                        let fields = proto::parse_error_response(data);
1721                        self.drain_to_ready().await?;
1722                        return Err(self.make_server_error(fields));
1723                    }
1724                    other => {
1725                        return Err(DriverError::Protocol(format!(
1726                            "unexpected message during flush_deferred_pipeline: {other:?}"
1727                        )));
1728                    }
1729                }
1730            }
1731            results.push(affected_rows);
1732        }
1733
1734        self.expect_ready().await?;
1735        self.shrink_buffers();
1736        Ok(results)
1737    }
1738
1739    /// Process each row directly from the wire buffer via a closure.
1740    ///
1741    /// Zero arena allocation — the closure receives a [`PgDataRow`] that reads
1742    /// columns directly from the DataRow message bytes in the read buffer.
1743    /// Column offsets are pre-scanned once per row into a stack-allocated SmallVec.
1744    ///
1745    /// This is the fastest path for row-by-row processing: no arena, no Vec of
1746    /// offsets, no materialization of the entire result set.
1747    pub async fn for_each<F>(
1748        &mut self,
1749        sql: &str,
1750        sql_hash: u64,
1751        params: &[&(dyn Encode + Sync)],
1752        mut f: F,
1753    ) -> Result<(), DriverError>
1754    where
1755        F: FnMut(PgDataRow<'_>) -> Result<(), DriverError>,
1756    {
1757        let _ = self
1758            .send_pipeline(sql, sql_hash, params, false, false)
1759            .await?;
1760
1761        loop {
1762            let msg = self.read_one_message().await?;
1763            match msg {
1764                BackendMessage::DataRow { data } => {
1765                    let row = PgDataRow::new(data)?;
1766                    f(row)?;
1767                }
1768                BackendMessage::CommandComplete { .. } => break,
1769                BackendMessage::EmptyQuery => break,
1770                BackendMessage::NoticeResponse { .. } => {}
1771                BackendMessage::ErrorResponse { data } => {
1772                    let fields = proto::parse_error_response(data);
1773                    self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1774                    self.drain_to_ready().await?;
1775                    return Err(self.make_server_error(fields));
1776                }
1777                other => {
1778                    return Err(DriverError::Protocol(format!(
1779                        "unexpected message during for_each: {other:?}"
1780                    )));
1781                }
1782            }
1783        }
1784
1785        self.expect_ready().await?;
1786        self.shrink_buffers();
1787        Ok(())
1788    }
1789
1790    /// Process each DataRow as raw bytes — no `PgDataRow`, no SmallVec, no
1791    /// pre-scanning of column offsets.
1792    ///
1793    /// The closure receives the raw DataRow message payload (starting with the
1794    /// `i16` column count). Generated code decodes columns sequentially inline,
1795    /// advancing a position cursor through the bytes.
1796    ///
1797    /// This is faster than `for_each` because it eliminates the SmallVec
1798    /// construction (~20-30ns per row) and the per-column method call overhead.
1799    ///
1800    /// Optimization: DataRow messages that fit entirely within `stream_buf` are
1801    /// parsed directly from the buffer (zero-copy — no memcpy into `read_buf`).
1802    /// Messages that span the buffer boundary fall back to `read_message_buffered`.
1803    pub async fn for_each_raw<F>(
1804        &mut self,
1805        sql: &str,
1806        sql_hash: u64,
1807        params: &[&(dyn Encode + Sync)],
1808        mut f: F,
1809    ) -> Result<(), DriverError>
1810    where
1811        F: FnMut(&[u8]) -> Result<(), DriverError>,
1812    {
1813        let _ = self
1814            .send_pipeline(sql, sql_hash, params, false, true)
1815            .await?;
1816
1817        // Read BindComplete inline from stream_buf — avoids the full
1818        // expect_message -> read_one_message -> read_message_buffered path.
1819        // BindComplete is always exactly 5 bytes: type='2'(1) + len=4(4).
1820        loop {
1821            let avail = self.stream_buf_end - self.stream_buf_pos;
1822            if avail >= 5 {
1823                let bc_type = self.stream_buf[self.stream_buf_pos];
1824                match bc_type {
1825                    b'2' => {
1826                        // BindComplete — skip the 5-byte message.
1827                        self.stream_buf_pos += 5;
1828                        break;
1829                    }
1830                    b'E' => {
1831                        // ErrorResponse — fall back to full message reader.
1832                        let msg = self.read_one_message().await?;
1833                        if let BackendMessage::ErrorResponse { data } = msg {
1834                            let fields = proto::parse_error_response(data);
1835                            self.drain_to_ready().await?;
1836                            return Err(self.make_server_error(fields));
1837                        }
1838                    }
1839                    b'N' | b'S' => {
1840                        // NoticeResponse or ParameterStatus — parse length,
1841                        // skip, and continue looking for BindComplete.
1842                        let raw_len = i32::from_be_bytes([
1843                            self.stream_buf[self.stream_buf_pos + 1],
1844                            self.stream_buf[self.stream_buf_pos + 2],
1845                            self.stream_buf[self.stream_buf_pos + 3],
1846                            self.stream_buf[self.stream_buf_pos + 4],
1847                        ]);
1848                        let total = 1 + raw_len as usize;
1849                        if avail >= total {
1850                            self.stream_buf_pos += total;
1851                            continue;
1852                        }
1853                        // Async message spans buffer boundary — fall back.
1854                        self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1855                            .await?;
1856                        break;
1857                    }
1858                    _ => {
1859                        // Unexpected type — fall back to full reader for
1860                        // proper error handling.
1861                        self.expect_message(|m| matches!(m, BackendMessage::BindComplete))
1862                            .await?;
1863                        break;
1864                    }
1865                }
1866            } else {
1867                // Not enough data in stream_buf — compact and refill.
1868                let remaining = self.stream_buf_end - self.stream_buf_pos;
1869                if remaining > 0 && self.stream_buf_pos > 0 {
1870                    self.stream_buf
1871                        .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1872                }
1873                self.stream_buf_pos = 0;
1874                self.stream_buf_end = remaining;
1875
1876                let n = {
1877                    let mut reader = StreamReader(&mut self.stream);
1878                    use tokio::io::AsyncReadExt;
1879                    reader
1880                        .read(&mut self.stream_buf[remaining..])
1881                        .await
1882                        .map_err(DriverError::Io)?
1883                };
1884                if n == 0 {
1885                    return Err(DriverError::Io(std::io::Error::new(
1886                        std::io::ErrorKind::UnexpectedEof,
1887                        "connection closed",
1888                    )));
1889                }
1890                self.stream_buf_end = remaining + n;
1891            }
1892        }
1893
1894        // Bulk DataRow loop: parse messages directly from stream_buf when possible.
1895        'outer: loop {
1896            // Inner loop: process all complete messages already in stream_buf.
1897            loop {
1898                let avail = self.stream_buf_end - self.stream_buf_pos;
1899                if avail < 5 {
1900                    break; // need more data from TCP
1901                }
1902
1903                let msg_type = self.stream_buf[self.stream_buf_pos];
1904                let raw_len = i32::from_be_bytes([
1905                    self.stream_buf[self.stream_buf_pos + 1],
1906                    self.stream_buf[self.stream_buf_pos + 2],
1907                    self.stream_buf[self.stream_buf_pos + 3],
1908                    self.stream_buf[self.stream_buf_pos + 4],
1909                ]);
1910
1911                if raw_len < 4 {
1912                    return Err(DriverError::Protocol(format!(
1913                        "invalid message length {raw_len} for type '{}'",
1914                        msg_type as char
1915                    )));
1916                }
1917
1918                let payload_len = (raw_len - 4) as usize;
1919                let total_msg_len = 5 + payload_len; // type(1) + length(4) + payload
1920
1921                if avail < total_msg_len {
1922                    // Message doesn't fit in available buffer data.
1923                    if total_msg_len > self.stream_buf.len() {
1924                        // Message is larger than entire stream_buf — fall back to
1925                        // read_message_buffered which handles arbitrary sizes.
1926                        let msg = self.read_one_message().await?;
1927                        match msg {
1928                            BackendMessage::DataRow { data } => {
1929                                f(data)?;
1930                                continue;
1931                            }
1932                            BackendMessage::CommandComplete { .. } | BackendMessage::EmptyQuery => {
1933                                break 'outer;
1934                            }
1935                            BackendMessage::ErrorResponse { data } => {
1936                                let fields = proto::parse_error_response(data);
1937                                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
1938                                self.drain_to_ready().await?;
1939                                return Err(self.make_server_error(fields));
1940                            }
1941                            BackendMessage::NoticeResponse { .. } => continue,
1942                            other => {
1943                                return Err(DriverError::Protocol(format!(
1944                                    "unexpected message during for_each_raw: {other:?}"
1945                                )));
1946                            }
1947                        }
1948                    }
1949                    // Partial message in buffer — compact and refill below.
1950                    break;
1951                }
1952
1953                // Full message is available in stream_buf — zero-copy path.
1954                let payload_start = self.stream_buf_pos + 5;
1955                let payload_end = payload_start + payload_len;
1956
1957                // Happy path first: DataRow is ~99.9% of messages during
1958                // bulk streaming. Single predicted branch.
1959                if msg_type == b'D' {
1960                    // DataRow — ZERO COPY from stream_buf.
1961                    // Safety: payload_start..payload_end is within stream_buf bounds
1962                    // (checked by `avail < total_msg_len` above).
1963                    f(&self.stream_buf[payload_start..payload_end])?;
1964                } else if msg_type == b'C' || msg_type == b'I' {
1965                    // CommandComplete / EmptyQuery — done.
1966                    self.stream_buf_pos += total_msg_len;
1967                    break 'outer;
1968                } else {
1969                    self.handle_non_datarow_async(msg_type, payload_start, payload_end, sql_hash)
1970                        .await?;
1971                }
1972
1973                self.stream_buf_pos += total_msg_len;
1974            }
1975
1976            // Compact: move unprocessed bytes to front of buffer.
1977            let remaining = self.stream_buf_end - self.stream_buf_pos;
1978            if remaining > 0 && self.stream_buf_pos > 0 {
1979                self.stream_buf
1980                    .copy_within(self.stream_buf_pos..self.stream_buf_end, 0);
1981            }
1982            self.stream_buf_pos = 0;
1983            self.stream_buf_end = remaining;
1984
1985            // Read more from TCP.
1986            let n = {
1987                let mut reader = StreamReader(&mut self.stream);
1988                use tokio::io::AsyncReadExt;
1989                reader
1990                    .read(&mut self.stream_buf[remaining..])
1991                    .await
1992                    .map_err(DriverError::Io)?
1993            };
1994            if n == 0 {
1995                return Err(DriverError::Io(std::io::Error::new(
1996                    std::io::ErrorKind::UnexpectedEof,
1997                    "connection closed",
1998                )));
1999            }
2000            self.stream_buf_end = remaining + n;
2001        }
2002
2003        // Read ReadyForQuery.
2004        self.expect_ready().await?;
2005        self.shrink_buffers();
2006        Ok(())
2007    }
2008
2009    /// Simple query protocol — for non-prepared SQL (BEGIN, COMMIT, SET, etc.).
2010    ///
2011    /// Does not use the extended query protocol. Cannot have parameters.
2012    pub async fn simple_query(&mut self, sql: &str) -> Result<(), DriverError> {
2013        self.write_buf.clear();
2014        proto::write_simple_query(&mut self.write_buf, sql);
2015        self.flush_write().await?;
2016
2017        // Read until ReadyForQuery
2018        loop {
2019            let msg = self.read_one_message().await?;
2020            match msg {
2021                BackendMessage::ReadyForQuery { status } => {
2022                    self.tx_status = status;
2023                    return Ok(());
2024                }
2025                BackendMessage::CommandComplete { .. }
2026                | BackendMessage::RowDescription { .. }
2027                | BackendMessage::DataRow { .. }
2028                | BackendMessage::EmptyQuery
2029                | BackendMessage::NoticeResponse { .. } => {}
2030                BackendMessage::ErrorResponse { data } => {
2031                    let fields = proto::parse_error_response(data);
2032                    self.drain_to_ready().await?;
2033                    return Err(self.make_server_error(fields));
2034                }
2035
2036                // ParameterStatus can arrive asynchronously during any query.
2037                BackendMessage::ParameterStatus { .. } => {}
2038
2039                // Startup messages should not appear post-startup, but if
2040                // the stream buffer contains leftover data, skip them safely.
2041                BackendMessage::AuthOk
2042                | BackendMessage::AuthSaslFinal { .. }
2043                | BackendMessage::AuthSaslContinue { .. }
2044                | BackendMessage::AuthSasl { .. }
2045                | BackendMessage::AuthMd5 { .. }
2046                | BackendMessage::AuthCleartext
2047                | BackendMessage::BackendKeyData { .. } => {}
2048
2049                other => {
2050                    return Err(DriverError::Protocol(format!(
2051                        "unexpected message during simple_query: {other:?}"
2052                    )));
2053                }
2054            }
2055        }
2056    }
2057
2058    /// Block until a NotificationResponse arrives on this connection.
2059    ///
2060    /// Reads raw messages from the stream and skips everything except
2061    /// `NotificationResponse`. Returns the `(channel, payload)` pair.
2062    /// Used by the listener's background task to receive LISTEN/NOTIFY events.
2063    ///
2064    /// This method never returns `Ok` for non-notification messages -- it loops
2065    /// internally, discarding `ParameterStatus`, `NoticeResponse`, etc.
2066    pub async fn wait_for_notification(&mut self) -> Result<(String, String), DriverError> {
2067        loop {
2068            let (msg_type, _payload_len) = self.read_message_buffered().await?;
2069            let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2070            match msg {
2071                BackendMessage::NotificationResponse {
2072                    channel, payload, ..
2073                } => {
2074                    return Ok((channel.to_owned(), payload.to_owned()));
2075                }
2076                BackendMessage::ParameterStatus { .. } | BackendMessage::NoticeResponse { .. } => {
2077                    continue;
2078                }
2079                _ => continue,
2080            }
2081        }
2082    }
2083
2084    /// Send Terminate and close the connection.
2085    pub async fn close(mut self) -> Result<(), DriverError> {
2086        self.write_buf.clear();
2087        proto::write_terminate(&mut self.write_buf);
2088        // Best-effort flush — ignore errors since we're closing
2089        let _ = self.flush_write().await;
2090        Ok(())
2091    }
2092
2093    /// Whether the connection is in an idle transaction state.
2094    pub fn is_idle(&self) -> bool {
2095        self.tx_status == b'I'
2096    }
2097
2098    /// Whether the connection is in a transaction.
2099    pub fn is_in_transaction(&self) -> bool {
2100        self.tx_status == b'T'
2101    }
2102
2103    /// Whether the connection is in a failed transaction.
2104    pub fn is_in_failed_transaction(&self) -> bool {
2105        self.tx_status == b'E'
2106    }
2107
2108    /// Record that the connection was just used. Called after successful
2109    /// query completion so the pool can detect stale connections.
2110    pub fn touch(&mut self) {
2111        self.last_used = std::time::Instant::now();
2112    }
2113
2114    /// How long since this connection last completed a query.
2115    pub fn idle_duration(&self) -> std::time::Duration {
2116        self.last_used.elapsed()
2117    }
2118
2119    /// Get a server parameter value (set during startup or via SET).
2120    pub fn parameter(&self, name: &str) -> Option<&str> {
2121        self.params
2122            .iter()
2123            .find(|(k, _)| &**k == name)
2124            .map(|(_, v)| &**v)
2125    }
2126
2127    /// All server parameters received during startup.
2128    pub fn server_params(&self) -> &[(Box<str>, Box<str>)] {
2129        &self.params
2130    }
2131
2132    /// Validate critical server parameters after startup.
2133    ///
2134    /// Checks:
2135    /// - `server_encoding` must be UTF-8 (or UTF8). Our SIMD UTF-8 validation
2136    ///   and text decoding assume UTF-8 encoding.
2137    /// - `integer_datetimes` must be "on". Our timestamp/date codecs assume
2138    ///   integer-format timestamps (microseconds since 2000-01-01). If "off",
2139    ///   PG uses float-format timestamps and our decode is wrong.
2140    fn validate_server_params(&self) -> Result<(), DriverError> {
2141        // Check server_encoding — must be UTF-8
2142        if let Some(encoding) = self.parameter("server_encoding") {
2143            let normalized = encoding.to_uppercase();
2144            if normalized != "UTF8" && normalized != "UTF-8" {
2145                return Err(DriverError::Protocol(format!(
2146                    "server_encoding is '{encoding}', but bsql requires UTF-8. \
2147                     Set server encoding to UTF-8 in postgresql.conf or \
2148                     use CREATE DATABASE ... ENCODING 'UTF8'."
2149                )));
2150            }
2151        }
2152
2153        // Check client_encoding — must be UTF-8
2154        if let Some(encoding) = self.parameter("client_encoding") {
2155            let normalized = encoding.to_uppercase();
2156            if normalized != "UTF8" && normalized != "UTF-8" {
2157                return Err(DriverError::Protocol(format!(
2158                    "client_encoding is '{encoding}', but bsql requires UTF-8. \
2159                     Check your connection or database configuration."
2160                )));
2161            }
2162        }
2163
2164        // Check integer_datetimes — MUST be "on"
2165        if let Some(idt) = self.parameter("integer_datetimes") {
2166            if idt != "on" {
2167                return Err(DriverError::Protocol(format!(
2168                    "integer_datetimes is '{idt}', but bsql requires 'on'. \
2169                     Our timestamp codec assumes integer-format timestamps \
2170                     (microseconds since 2000-01-01). Float-format timestamps \
2171                     would produce incorrect decode results."
2172                )));
2173            }
2174        }
2175
2176        Ok(())
2177    }
2178
2179    /// Backend process ID (for cancel requests).
2180    pub fn pid(&self) -> i32 {
2181        self.pid
2182    }
2183
2184    /// Backend secret key (for cancel requests).
2185    pub fn secret_key(&self) -> i32 {
2186        self.secret
2187    }
2188
2189    /// Cancel the currently running query on this connection.
2190    ///
2191    /// Opens a NEW TCP connection to the same host:port and sends a
2192    /// CancelRequest message (16 bytes: length=16, code=80877102, pid, secret).
2193    /// The cancel connection is closed immediately after sending.
2194    ///
2195    /// The `config` is needed to get the host:port for the new TCP connection.
2196    pub async fn cancel(&self, config: &Config) -> Result<(), DriverError> {
2197        let addr = format!("{}:{}", config.host, config.port);
2198        let mut tcp = TcpStream::connect(&addr).await.map_err(DriverError::Io)?;
2199        let mut buf = Vec::with_capacity(16);
2200        proto::write_cancel_request(&mut buf, self.pid, self.secret);
2201        tcp.write_all(&buf).await.map_err(DriverError::Io)?;
2202        tcp.flush().await.map_err(DriverError::Io)?;
2203        // Close immediately — PG expects no further data
2204        drop(tcp);
2205        Ok(())
2206    }
2207
2208    /// Whether a streaming query is in progress.
2209    pub fn is_streaming(&self) -> bool {
2210        self.streaming_active
2211    }
2212
2213    /// Drain all buffered notifications received during query processing.
2214    ///
2215    /// Returns the pending notifications and clears the buffer.
2216    /// Notifications arrive asynchronously from PG (via LISTEN/NOTIFY)
2217    /// and are buffered during normal query execution instead of being dropped.
2218    pub fn drain_notifications(&mut self) -> Vec<Notification> {
2219        std::mem::take(&mut self.pending_notifications)
2220    }
2221
2222    /// Number of pending notifications in the buffer.
2223    pub fn pending_notification_count(&self) -> usize {
2224        self.pending_notifications.len()
2225    }
2226
2227    /// Set the maximum number of cached prepared statements.
2228    ///
2229    /// When the cache exceeds this size, the least recently used statement
2230    /// is evicted and a Close message is sent to PG to free server memory.
2231    /// Default: 256.
2232    pub fn set_max_stmt_cache_size(&mut self, size: usize) {
2233        self.max_stmt_cache_size = size;
2234    }
2235
2236    /// Number of currently cached prepared statements.
2237    pub fn stmt_cache_len(&self) -> usize {
2238        self.stmts.len()
2239    }
2240
2241    /// Set TCP keepalive on a socket to detect dead connections.
2242    fn set_keepalive(tcp: &TcpStream) -> Result<(), DriverError> {
2243        let sock = socket2::SockRef::from(tcp);
2244        let ka = socket2::TcpKeepalive::new()
2245            .with_time(std::time::Duration::from_secs(60))
2246            .with_interval(std::time::Duration::from_secs(15));
2247        sock.set_tcp_keepalive(&ka).map_err(DriverError::Io)?;
2248        Ok(())
2249    }
2250
2251    /// When this connection was created.
2252    pub fn created_at(&self) -> std::time::Instant {
2253        self.created_at
2254    }
2255
2256    // --- Internal helpers ---
2257
2258    /// Insert a statement into the cache, evicting the LRU entry if full.
2259    ///
2260    /// When the cache exceeds `max_stmt_cache_size`, the least recently used
2261    /// statement is evicted. A Close(Statement) message is queued to free
2262    /// server-side memory. The Close is sent lazily on the next flush.
2263    ///
2264    /// 256 entries = negligible linear scan cost (~1us worst case).
2265    fn cache_stmt(&mut self, sql_hash: u64, info: StmtInfo) {
2266        // Evict LRU if cache is full
2267        if self.stmts.len() >= self.max_stmt_cache_size && !self.stmts.contains_key(&sql_hash) {
2268            if let Some((_lru_hash, evicted)) = self.stmts.evict_lru() {
2269                // Queue Close(Statement) to free server-side memory.
2270                // This will be sent on the next write+flush.
2271                proto::write_close(&mut self.write_buf, b'S', &evicted.name);
2272            }
2273        }
2274        self.stmts.insert(sql_hash, info);
2275    }
2276
2277    /// Buffer a notification received during query processing.
2278    fn buffer_notification(&mut self, pid: i32, channel: &str, payload: &str) {
2279        // Cap at 1024 buffered notifications to prevent unbounded memory growth
2280        if self.pending_notifications.len() < 1024 {
2281            self.pending_notifications.push(Notification {
2282                pid,
2283                channel: channel.to_owned(),
2284                payload: payload.to_owned(),
2285            });
2286        }
2287    }
2288
2289    /// Reclaim memory if buffers grew beyond normal thresholds.
2290    ///
2291    /// Called after query()/execute() to prevent a single large result from
2292    /// permanently bloating the connection's buffers.
2293    fn shrink_buffers(&mut self) {
2294        // Only check every 64 queries — the capacity comparisons are cheap
2295        // but the shrink itself (realloc) is not. Most queries never trigger
2296        // the threshold, so this saves ~2-5ns of branch overhead per query.
2297        if self.query_counter & 63 != 0 {
2298            return;
2299        }
2300        if self.read_buf.capacity() > 64 * 1024 {
2301            self.read_buf.clear();
2302            self.read_buf.shrink_to(8192);
2303        }
2304        if self.write_buf.capacity() > 16 * 1024 {
2305            self.write_buf.clear();
2306            self.write_buf.shrink_to(8192);
2307        }
2308    }
2309
2310    /// Read one backend message. The returned message borrows from `self.read_buf`.
2311    ///
2312    /// When a NotificationResponse is received, it is automatically buffered
2313    /// in `self.pending_notifications` and the next message is read instead.
2314    /// This means callers never see NotificationResponse from this method.
2315    async fn read_one_message(&mut self) -> Result<BackendMessage<'_>, DriverError> {
2316        loop {
2317            let (msg_type, _payload_len) = self.read_message_buffered().await?;
2318            // Check for NotificationResponse before parsing into BackendMessage,
2319            // because we need to extract owned data while we have exclusive access.
2320            if msg_type == b'A' {
2321                let msg = proto::parse_backend_message(msg_type, &self.read_buf)?;
2322                if let BackendMessage::NotificationResponse {
2323                    pid,
2324                    channel,
2325                    payload,
2326                } = msg
2327                {
2328                    // Extract owned data before releasing the borrow on self.read_buf.
2329                    let pid_owned = pid;
2330                    let channel_owned = channel.to_owned();
2331                    let payload_owned = payload.to_owned();
2332                    self.buffer_notification(pid_owned, &channel_owned, &payload_owned);
2333                    continue; // read next message
2334                }
2335            }
2336            return proto::parse_backend_message(msg_type, &self.read_buf);
2337        }
2338    }
2339
2340    /// Read messages until we find one matching `pred`, erroring on ErrorResponse.
2341    ///
2342    /// On error, drains to ReadyForQuery so the connection remains usable.
2343    /// Skips NotificationResponse, NoticeResponse, and ParameterStatus — all
2344    /// of which PostgreSQL can send asynchronously at any time.
2345    async fn expect_message(
2346        &mut self,
2347        pred: impl Fn(&BackendMessage<'_>) -> bool,
2348    ) -> Result<(), DriverError> {
2349        loop {
2350            let msg = self.read_one_message().await?;
2351            if pred(&msg) {
2352                return Ok(());
2353            }
2354            match msg {
2355                BackendMessage::ErrorResponse { data } => {
2356                    let fields = proto::parse_error_response(data);
2357                    self.drain_to_ready().await?;
2358                    return Err(self.make_server_error(fields));
2359                }
2360                BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {
2361                    // Asynchronous messages — skip them
2362                    // (NotificationResponse is auto-buffered by read_one_message)
2363                }
2364                other => {
2365                    return Err(DriverError::Protocol(format!(
2366                        "unexpected message while waiting for expected type: {other:?}"
2367                    )));
2368                }
2369            }
2370        }
2371    }
2372
2373    /// Read until ReadyForQuery. Skips NotificationResponse and other async messages.
2374    async fn expect_ready(&mut self) -> Result<(), DriverError> {
2375        loop {
2376            let msg = self.read_one_message().await?;
2377            match msg {
2378                BackendMessage::ReadyForQuery { status } => {
2379                    self.tx_status = status;
2380                    return Ok(());
2381                }
2382                BackendMessage::NoticeResponse { .. } | BackendMessage::ParameterStatus { .. } => {}
2383                BackendMessage::ErrorResponse { data } => {
2384                    let fields = proto::parse_error_response(data);
2385                    // Continue draining until ReadyForQuery
2386                    self.drain_to_ready().await?;
2387                    return Err(self.make_server_error(fields));
2388                }
2389                _ => {}
2390            }
2391        }
2392    }
2393
2394    /// Drain messages until ReadyForQuery (used after an error).
2395    /// Skips all intermediate messages including NotificationResponse.
2396    async fn drain_to_ready(&mut self) -> Result<(), DriverError> {
2397        loop {
2398            let msg = self.read_one_message().await?;
2399            if let BackendMessage::ReadyForQuery { status } = msg {
2400                self.tx_status = status;
2401                return Ok(());
2402            }
2403        }
2404    }
2405
2406    /// Check if an error is SQLSTATE 26000 ("prepared statement does not exist").
2407    /// If so, remove the stale entry from the statement cache so the caller can retry.
2408    fn maybe_invalidate_stmt_cache(&mut self, fields: &proto::ErrorFields, sql_hash: u64) -> bool {
2409        if &*fields.code == "26000" {
2410            self.stmts.remove(&sql_hash);
2411            true
2412        } else {
2413            false
2414        }
2415    }
2416
2417    /// Convert parsed ErrorFields into a DriverError::Server.
2418    #[cold]
2419    #[inline(never)]
2420    fn make_server_error(&self, fields: proto::ErrorFields) -> DriverError {
2421        DriverError::Server {
2422            code: fields.code,
2423            message: fields.message.into_boxed_str(),
2424            detail: fields.detail.map(String::into_boxed_str),
2425            hint: fields.hint.map(String::into_boxed_str),
2426            position: fields.position,
2427        }
2428    }
2429
2430    /// Handle non-DataRow messages during for_each_raw inline parsing (async).
2431    ///
2432    /// Separated from the hot loop so the compiler keeps DataRow processing
2433    /// tight in the instruction cache.
2434    #[cold]
2435    async fn handle_non_datarow_async(
2436        &mut self,
2437        msg_type: u8,
2438        payload_start: usize,
2439        payload_end: usize,
2440        sql_hash: u64,
2441    ) -> Result<(), DriverError> {
2442        match msg_type {
2443            b'E' => {
2444                let fields =
2445                    proto::parse_error_response(&self.stream_buf[payload_start..payload_end]);
2446                self.maybe_invalidate_stmt_cache(&fields, sql_hash);
2447                self.drain_to_ready().await?;
2448                return Err(self.make_server_error(fields));
2449            }
2450            b'A' => {
2451                let msg = proto::parse_backend_message(
2452                    msg_type,
2453                    &self.stream_buf[payload_start..payload_end],
2454                )?;
2455                if let BackendMessage::NotificationResponse {
2456                    pid,
2457                    channel,
2458                    payload,
2459                } = msg
2460                {
2461                    let ch = channel.to_owned();
2462                    let pl = payload.to_owned();
2463                    self.buffer_notification(pid, &ch, &pl);
2464                }
2465            }
2466            _ => {} // NoticeResponse, ParameterStatus — skip
2467        }
2468        Ok(())
2469    }
2470
2471    /// Flush the write buffer to the stream.
2472    ///
2473    /// Always flush after write_all for correctness. TCP_NODELAY only
2474    /// affects the kernel's Nagle algorithm; tokio's BufWriter (used internally
2475    /// by TcpStream) may still buffer. Always flushing ensures data reaches
2476    /// the wire immediately for both plain TCP and TLS.
2477    async fn flush_write(&mut self) -> Result<(), DriverError> {
2478        self.stream
2479            .write_all(&self.write_buf)
2480            .await
2481            .map_err(DriverError::Io)?;
2482        self.stream.flush().await.map_err(DriverError::Io)?;
2483        Ok(())
2484    }
2485
2486    /// Read one complete backend message using the internal buffer.
2487    ///
2488    /// Returns `(msg_type, payload_len)`. The payload is stored in `self.read_buf`.
2489    async fn read_message_buffered(&mut self) -> Result<(u8, usize), DriverError> {
2490        // Read 5-byte header: type(1) + length(4)
2491        let mut header = [0u8; 5];
2492        buffered_read_exact(
2493            &mut self.stream,
2494            &mut self.stream_buf,
2495            &mut self.stream_buf_pos,
2496            &mut self.stream_buf_end,
2497            &mut header,
2498        )
2499        .await?;
2500
2501        let msg_type = header[0];
2502        let len = i32::from_be_bytes([header[1], header[2], header[3], header[4]]);
2503
2504        if len < 4 {
2505            return Err(DriverError::Protocol(format!(
2506                "invalid message length {len} for type '{}'",
2507                msg_type as char
2508            )));
2509        }
2510
2511        const MAX_MESSAGE_LEN: i32 = 128 * 1024 * 1024;
2512        if len > MAX_MESSAGE_LEN {
2513            return Err(DriverError::Protocol(format!(
2514                "message length {len} exceeds maximum ({MAX_MESSAGE_LEN}) for type '{}'",
2515                msg_type as char
2516            )));
2517        }
2518
2519        let payload_len = (len - 4) as usize;
2520
2521        // the length (truncation or zeroes only new bytes beyond current len).
2522        // For the common case where read_buf was already large enough, the
2523        // zeroing cost is minimal. This is the price of safe Rust — we cannot
2524        // use set_len() without unsafe.
2525        self.read_buf.clear();
2526        self.read_buf.resize(payload_len, 0);
2527        if payload_len > 0 {
2528            buffered_read_exact(
2529                &mut self.stream,
2530                &mut self.stream_buf,
2531                &mut self.stream_buf_pos,
2532                &mut self.stream_buf_end,
2533                &mut self.read_buf[..payload_len],
2534            )
2535            .await?;
2536        }
2537
2538        Ok((msg_type, payload_len))
2539    }
2540}
2541
2542/// Read exactly `out.len()` bytes using a persistent read buffer.
2543///
2544/// This is a free function to avoid double-mutable-borrow issues when the caller
2545/// also needs to write into `self.read_buf`.
2546async fn buffered_read_exact(
2547    stream: &mut Stream,
2548    buf: &mut [u8],
2549    pos: &mut usize,
2550    end: &mut usize,
2551    out: &mut [u8],
2552) -> Result<(), DriverError> {
2553    let mut filled = 0;
2554    while filled < out.len() {
2555        let avail = *end - *pos;
2556        if avail > 0 {
2557            let take = avail.min(out.len() - filled);
2558            out[filled..filled + take].copy_from_slice(&buf[*pos..*pos + take]);
2559            *pos += take;
2560            filled += take;
2561        } else {
2562            // Buffer exhausted — refill from the stream
2563            *pos = 0;
2564            let n = {
2565                let mut reader = StreamReader(stream);
2566                use tokio::io::AsyncReadExt;
2567                reader.read(buf).await.map_err(DriverError::Io)?
2568            };
2569            if n == 0 {
2570                return Err(DriverError::Io(std::io::Error::new(
2571                    std::io::ErrorKind::UnexpectedEof,
2572                    "connection closed",
2573                )));
2574            }
2575            *end = n;
2576        }
2577    }
2578    Ok(())
2579}
2580
2581// --- Bind template builder ---
2582
2583/// Build a `BindTemplate` from the current write_buf contents.
2584///
2585/// Parses the Bind message to locate each parameter's data offset and length.
2586/// Appends EXECUTE_SYNC to the template bytes so the hot path is a single memcpy.
2587/// Returns `None` if the message cannot be parsed.
2588fn build_bind_template(write_buf: &[u8], param_count: usize) -> Option<BindTemplate> {
2589    if write_buf.is_empty() || write_buf[0] != b'B' {
2590        return None;
2591    }
2592    if write_buf.len() < 5 {
2593        return None;
2594    }
2595
2596    let mut pos = 5; // skip type byte (1) + length (4)
2597
2598    // Skip portal name (NUL-terminated).
2599    while pos < write_buf.len() && write_buf[pos] != 0 {
2600        pos += 1;
2601    }
2602    pos += 1;
2603
2604    // Skip statement name (NUL-terminated).
2605    while pos < write_buf.len() && write_buf[pos] != 0 {
2606        pos += 1;
2607    }
2608    pos += 1;
2609
2610    // Skip format codes.
2611    if pos + 2 > write_buf.len() {
2612        return None;
2613    }
2614    let num_fmt_codes = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]);
2615    pos += 2;
2616    pos += num_fmt_codes.max(0) as usize * 2;
2617
2618    // Parameter count.
2619    if pos + 2 > write_buf.len() {
2620        return None;
2621    }
2622    let wire_param_count = i16::from_be_bytes([write_buf[pos], write_buf[pos + 1]]) as usize;
2623    pos += 2;
2624
2625    if wire_param_count != param_count {
2626        return None;
2627    }
2628
2629    let mut param_slots = Vec::with_capacity(param_count);
2630    for _ in 0..param_count {
2631        if pos + 4 > write_buf.len() {
2632            return None;
2633        }
2634        let data_len = i32::from_be_bytes([
2635            write_buf[pos],
2636            write_buf[pos + 1],
2637            write_buf[pos + 2],
2638            write_buf[pos + 3],
2639        ]);
2640        pos += 4;
2641
2642        if data_len < 0 {
2643            param_slots.push((pos, -1));
2644        } else {
2645            param_slots.push((pos, data_len));
2646            pos += data_len as usize;
2647        }
2648    }
2649
2650    // Include EXECUTE_SYNC in the template so the hot path is one memcpy.
2651    let bind_end = write_buf.len();
2652    let mut bytes = Vec::with_capacity(bind_end + proto::EXECUTE_SYNC.len());
2653    bytes.extend_from_slice(write_buf);
2654    bytes.extend_from_slice(proto::EXECUTE_SYNC);
2655
2656    Some(BindTemplate {
2657        bytes,
2658        bind_end,
2659        param_slots,
2660    })
2661}
2662
2663// --- QueryResult ---
2664
2665/// Result of a query execution. Owns the row offset metadata.
2666///
2667/// Uses flat column offset storage: all rows' `(arena_offset, length)` pairs
2668/// are stored contiguously in `all_col_offsets`. Row N starts at index
2669/// `N * num_cols`. No separate `row_starts` Vec needed.
2670///
2671/// # Example
2672///
2673/// ```no_run
2674/// # async fn example() -> Result<(), bsql_driver_postgres::DriverError> {
2675/// # let mut conn: bsql_driver_postgres::Connection = unimplemented!();
2676/// # let mut arena = bsql_driver_postgres::Arena::new();
2677/// let result = conn.query("SELECT 1 as n", 0, &[], &mut arena).await?;
2678/// for i in 0..result.len() {
2679///     let row = result.row(i, &arena);
2680///     // Access columns by index
2681/// }
2682/// # Ok(())
2683/// # }
2684/// ```
2685pub struct QueryResult {
2686    /// All rows' column (arena_offset, length) pairs, contiguous.
2687    /// length = -1 means NULL.
2688    all_col_offsets: Vec<(usize, i32)>,
2689    /// Number of columns per row.
2690    num_cols: usize,
2691    columns: Arc<[ColumnDesc]>,
2692    affected_rows: u64,
2693}
2694
2695impl QueryResult {
2696    /// Construct a `QueryResult` from its constituent parts.
2697    ///
2698    /// Used by `bsql-core`'s streaming layer to assemble per-chunk results.
2699    pub fn from_parts(
2700        all_col_offsets: Vec<(usize, i32)>,
2701        num_cols: usize,
2702        columns: Arc<[ColumnDesc]>,
2703        affected_rows: u64,
2704    ) -> Self {
2705        Self {
2706            all_col_offsets,
2707            num_cols,
2708            columns,
2709            affected_rows,
2710        }
2711    }
2712
2713    /// Number of rows in the result.
2714    pub fn len(&self) -> usize {
2715        if self.num_cols == 0 {
2716            return 0;
2717        }
2718        self.all_col_offsets.len() / self.num_cols
2719    }
2720
2721    /// Whether the result set is empty.
2722    pub fn is_empty(&self) -> bool {
2723        self.all_col_offsets.is_empty()
2724    }
2725
2726    /// Number of affected rows (for INSERT/UPDATE/DELETE).
2727    pub fn affected_rows(&self) -> u64 {
2728        self.affected_rows
2729    }
2730
2731    /// Column descriptors.
2732    pub fn columns(&self) -> &[ColumnDesc] {
2733        &self.columns
2734    }
2735
2736    /// Get a row by index. The returned `Row` borrows from the arena.
2737    pub fn row<'a>(&'a self, idx: usize, arena: &'a Arena) -> Row<'a> {
2738        let start = idx * self.num_cols;
2739        let end = start + self.num_cols;
2740        Row {
2741            arena,
2742            col_offsets: &self.all_col_offsets[start..end],
2743            columns: &self.columns,
2744        }
2745    }
2746
2747    /// Take the `col_offsets` vec out of this result, leaving it empty.
2748    ///
2749    /// Used by `QueryStream` to reclaim and reuse the allocation between chunks
2750    /// instead of allocating a new `Vec` per chunk.
2751    pub fn take_col_offsets(&mut self) -> Vec<(usize, i32)> {
2752        std::mem::take(&mut self.all_col_offsets)
2753    }
2754
2755    /// Iterate over rows.
2756    pub fn rows<'a>(&'a self, arena: &'a Arena) -> impl Iterator<Item = Row<'a>> {
2757        let num_cols = self.num_cols;
2758        let columns = &self.columns;
2759        self.all_col_offsets
2760            // .max(1) prevents a panic from chunks(0) when num_cols is 0
2761            // (e.g., commands with no columns like INSERT without RETURNING).
2762            .chunks(num_cols.max(1))
2763            .map(move |chunk| Row {
2764                arena,
2765                col_offsets: chunk,
2766                columns,
2767            })
2768    }
2769}
2770
2771// --- Row ---
2772
2773/// A view into a single result row, borrowing data from the arena.
2774///
2775/// Column values are accessed by index. NULL values return `None`.
2776/// Decode errors (protocol violations from a malicious server) are treated
2777/// as `None` rather than panicking — a compliant PostgreSQL server always
2778/// sends correctly-sized data for the declared type.
2779pub struct Row<'a> {
2780    arena: &'a Arena,
2781    col_offsets: &'a [(usize, i32)],
2782    columns: &'a [ColumnDesc],
2783}
2784
2785impl<'a> Row<'a> {
2786    /// Get the raw bytes for a column, or `None` if NULL.
2787    pub fn get_raw(&self, idx: usize) -> Option<&'a [u8]> {
2788        let (offset, len) = self.col_offsets[idx];
2789        if len < 0 {
2790            None
2791        } else {
2792            Some(self.arena.get(offset, len as usize))
2793        }
2794    }
2795
2796    /// Whether a column is NULL.
2797    pub fn is_null(&self, idx: usize) -> bool {
2798        self.col_offsets[idx].1 < 0
2799    }
2800
2801    /// Number of columns.
2802    pub fn column_count(&self) -> usize {
2803        self.col_offsets.len()
2804    }
2805
2806    /// Get a boolean column value. Returns `None` on NULL or decode error.
2807    pub fn get_bool(&self, idx: usize) -> Option<bool> {
2808        self.get_raw(idx)
2809            .and_then(|data| crate::codec::decode_bool(data).ok())
2810    }
2811
2812    /// Get an i16 column value. Returns `None` on NULL or decode error.
2813    pub fn get_i16(&self, idx: usize) -> Option<i16> {
2814        self.get_raw(idx)
2815            .and_then(|data| crate::codec::decode_i16(data).ok())
2816    }
2817
2818    /// Get an i32 column value. Returns `None` on NULL or decode error.
2819    pub fn get_i32(&self, idx: usize) -> Option<i32> {
2820        self.get_raw(idx)
2821            .and_then(|data| crate::codec::decode_i32(data).ok())
2822    }
2823
2824    /// Get an i64 column value. Returns `None` on NULL or decode error.
2825    pub fn get_i64(&self, idx: usize) -> Option<i64> {
2826        self.get_raw(idx)
2827            .and_then(|data| crate::codec::decode_i64(data).ok())
2828    }
2829
2830    /// Get an f32 column value. Returns `None` on NULL or decode error.
2831    pub fn get_f32(&self, idx: usize) -> Option<f32> {
2832        self.get_raw(idx)
2833            .and_then(|data| crate::codec::decode_f32(data).ok())
2834    }
2835
2836    /// Get an f64 column value. Returns `None` on NULL or decode error.
2837    pub fn get_f64(&self, idx: usize) -> Option<f64> {
2838        self.get_raw(idx)
2839            .and_then(|data| crate::codec::decode_f64(data).ok())
2840    }
2841
2842    /// Get a string column value. Returns `None` on NULL or decode error.
2843    pub fn get_str(&self, idx: usize) -> Option<&'a str> {
2844        self.get_raw(idx)
2845            .and_then(|data| crate::codec::decode_str(data).ok())
2846    }
2847
2848    /// Get a byte slice column value.
2849    pub fn get_bytes(&self, idx: usize) -> Option<&'a [u8]> {
2850        self.get_raw(idx)
2851    }
2852
2853    /// Get the column name by index.
2854    pub fn column_name(&self, idx: usize) -> &str {
2855        &self.columns[idx].name
2856    }
2857
2858    /// Get the column type OID by index.
2859    pub fn column_type_oid(&self, idx: usize) -> u32 {
2860        self.columns[idx].type_oid
2861    }
2862}
2863
2864// --- PgDataRow (zero-copy row view for for_each) ---
2865
2866/// A temporary view of a single PostgreSQL DataRow message.
2867///
2868/// Reads columns directly from the wire buffer — no arena copy.
2869/// Column offsets are pre-computed on construction using a `SmallVec`
2870/// that is stack-allocated for up to 16 columns (zero heap allocation
2871/// for the common case).
2872///
2873/// Lifetime `'a` borrows from `Connection::read_buf`.
2874pub struct PgDataRow<'a> {
2875    data: &'a [u8],
2876    /// Pre-scanned `(byte_offset, wire_len)` pairs for each column.
2877    /// `wire_len = -1` means NULL.
2878    offsets: smallvec::SmallVec<[(usize, i32); 16]>,
2879}
2880
2881impl<'a> PgDataRow<'a> {
2882    /// Parse column boundaries from a raw DataRow payload.
2883    ///
2884    /// `data` is the DataRow message payload (after the 'D' type byte and
2885    /// 4-byte length prefix have been stripped by the framing layer).
2886    pub fn new(data: &'a [u8]) -> Result<Self, DriverError> {
2887        if data.len() < 2 {
2888            return Err(DriverError::Protocol("DataRow too short".into()));
2889        }
2890        let num_cols = i16::from_be_bytes([data[0], data[1]]);
2891        if num_cols < 0 {
2892            return Err(DriverError::Protocol(
2893                "DataRow: negative column count".into(),
2894            ));
2895        }
2896        let num_cols = num_cols as usize;
2897        let mut offsets = smallvec::SmallVec::<[(usize, i32); 16]>::with_capacity(num_cols);
2898        let mut pos = 2usize;
2899        for _ in 0..num_cols {
2900            if pos + 4 > data.len() {
2901                return Err(DriverError::Protocol("DataRow truncated".into()));
2902            }
2903            let col_len =
2904                i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
2905            pos += 4;
2906            offsets.push((pos, col_len));
2907            if col_len > 0 {
2908                pos += col_len as usize;
2909            }
2910        }
2911        Ok(Self { data, offsets })
2912    }
2913
2914    /// Get the raw bytes for a column, or `None` if NULL.
2915    #[inline]
2916    pub fn get_raw(&self, idx: usize) -> Option<&'a [u8]> {
2917        let (offset, len) = self.offsets[idx];
2918        if len < 0 {
2919            None
2920        } else {
2921            Some(&self.data[offset..offset + len as usize])
2922        }
2923    }
2924
2925    /// Whether a column is NULL.
2926    #[inline]
2927    pub fn is_null(&self, idx: usize) -> bool {
2928        self.offsets[idx].1 < 0
2929    }
2930
2931    /// Number of columns.
2932    #[inline]
2933    pub fn column_count(&self) -> usize {
2934        self.offsets.len()
2935    }
2936
2937    /// Get a boolean column value. Returns `None` on NULL or decode error.
2938    #[inline]
2939    pub fn get_bool(&self, idx: usize) -> Option<bool> {
2940        self.get_raw(idx)
2941            .and_then(|data| crate::codec::decode_bool(data).ok())
2942    }
2943
2944    /// Get an i16 column value.
2945    #[inline]
2946    pub fn get_i16(&self, idx: usize) -> Option<i16> {
2947        self.get_raw(idx)
2948            .and_then(|data| crate::codec::decode_i16(data).ok())
2949    }
2950
2951    /// Get an i32 column value.
2952    #[inline]
2953    pub fn get_i32(&self, idx: usize) -> Option<i32> {
2954        self.get_raw(idx)
2955            .and_then(|data| crate::codec::decode_i32(data).ok())
2956    }
2957
2958    /// Get an i64 column value.
2959    #[inline]
2960    pub fn get_i64(&self, idx: usize) -> Option<i64> {
2961        self.get_raw(idx)
2962            .and_then(|data| crate::codec::decode_i64(data).ok())
2963    }
2964
2965    /// Get an f32 column value.
2966    #[inline]
2967    pub fn get_f32(&self, idx: usize) -> Option<f32> {
2968        self.get_raw(idx)
2969            .and_then(|data| crate::codec::decode_f32(data).ok())
2970    }
2971
2972    /// Get an f64 column value.
2973    #[inline]
2974    pub fn get_f64(&self, idx: usize) -> Option<f64> {
2975        self.get_raw(idx)
2976            .and_then(|data| crate::codec::decode_f64(data).ok())
2977    }
2978
2979    /// Get a string column value (zero-copy borrow from the wire buffer).
2980    #[inline]
2981    pub fn get_str(&self, idx: usize) -> Option<&'a str> {
2982        self.get_raw(idx)
2983            .and_then(|data| crate::codec::decode_str(data).ok())
2984    }
2985
2986    /// Get a byte slice column value (zero-copy borrow from the wire buffer).
2987    #[inline]
2988    pub fn get_bytes(&self, idx: usize) -> Option<&'a [u8]> {
2989        self.get_raw(idx)
2990    }
2991}
2992
2993// --- DataRow parsing ---
2994
2995/// Parse a DataRow message into the flat column offset storage.
2996///
2997/// Appends `(arena_offset, length)` pairs for each column to `out`.
2998/// `length = -1` indicates NULL.
2999///
3000/// DataRow format: `[num_columns: i16] ([col_len: i32] [col_data: col_len bytes])...`
3001fn parse_data_row_flat(
3002    data: &[u8],
3003    arena: &mut Arena,
3004    out: &mut Vec<(usize, i32)>,
3005) -> Result<(), DriverError> {
3006    if data.len() < 2 {
3007        return Err(DriverError::Protocol("DataRow too short".into()));
3008    }
3009
3010    let num_cols_raw = i16::from_be_bytes([data[0], data[1]]);
3011    if num_cols_raw < 0 {
3012        return Err(DriverError::Protocol(
3013            "DataRow: negative column count".into(),
3014        ));
3015    }
3016    let num_cols = num_cols_raw as usize;
3017    out.reserve(num_cols);
3018    let mut pos = 2;
3019
3020    for _ in 0..num_cols {
3021        if pos + 4 > data.len() {
3022            return Err(DriverError::Protocol("DataRow truncated".into()));
3023        }
3024
3025        let col_len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
3026        pos += 4;
3027
3028        if col_len < 0 {
3029            // NULL
3030            out.push((0, -1));
3031        } else {
3032            let len = col_len as usize;
3033            if pos + len > data.len() {
3034                return Err(DriverError::Protocol(
3035                    "DataRow column data truncated".into(),
3036                ));
3037            }
3038
3039            let offset = arena.alloc_copy(&data[pos..pos + len]);
3040            out.push((offset, col_len));
3041            pos += len;
3042        }
3043    }
3044
3045    Ok(())
3046}
3047
3048/// Compute a rapidhash of a SQL string.
3049///
3050/// Uses `str::hash()` via the `Hash` trait, matching `bsql_core::rapid_hash_str`.
3051pub fn hash_sql(sql: &str) -> u64 {
3052    use std::hash::{Hash, Hasher};
3053    let mut hasher = RapidHasher::default();
3054    sql.hash(&mut hasher);
3055    hasher.finish()
3056}
3057
3058#[cfg(test)]
3059#[allow(clippy::approx_constant)]
3060mod tests {
3061    use super::*;
3062
3063    #[test]
3064    fn config_parse_full_url() {
3065        let cfg = Config::from_url("postgres://user:pass@localhost:5432/mydb").unwrap();
3066        assert_eq!(cfg.user, "user");
3067        assert_eq!(cfg.password, "pass");
3068        assert_eq!(cfg.host, "localhost");
3069        assert_eq!(cfg.port, 5432);
3070        assert_eq!(cfg.database, "mydb");
3071    }
3072
3073    #[test]
3074    fn config_parse_default_port() {
3075        let cfg = Config::from_url("postgres://user:pass@localhost/mydb").unwrap();
3076        assert_eq!(cfg.port, 5432);
3077    }
3078
3079    #[test]
3080    fn config_parse_no_password() {
3081        let cfg = Config::from_url("postgres://user@localhost/mydb").unwrap();
3082        assert_eq!(cfg.user, "user");
3083        assert_eq!(cfg.password, "");
3084    }
3085
3086    #[test]
3087    fn config_parse_empty_database() {
3088        let cfg = Config::from_url("postgres://user:pass@localhost").unwrap();
3089        // database defaults to user
3090        assert_eq!(cfg.database, "user");
3091    }
3092
3093    #[test]
3094    fn config_parse_sslmode() {
3095        let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
3096        assert_eq!(cfg.ssl, SslMode::Require);
3097    }
3098
3099    #[test]
3100    fn config_parse_percent_encoding() {
3101        let cfg = Config::from_url("postgres://user%40domain:p%40ss@localhost/db").unwrap();
3102        assert_eq!(cfg.user, "user@domain");
3103        assert_eq!(cfg.password, "p@ss");
3104    }
3105
3106    #[test]
3107    fn config_rejects_bad_scheme() {
3108        let result = Config::from_url("mysql://user:pass@localhost/db");
3109        assert!(result.is_err());
3110    }
3111
3112    /// Unknown sslmode should error, not silently default to Prefer.
3113    #[test]
3114    fn config_rejects_unknown_sslmode() {
3115        let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=requre");
3116        assert!(result.is_err(), "typo 'requre' should be rejected");
3117        let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=REQUIRE");
3118        assert!(result.is_err(), "uppercase should be rejected");
3119        let result = Config::from_url("postgres://user:pass@localhost/db?sslmode=bogus");
3120        assert!(result.is_err(), "bogus value should be rejected");
3121    }
3122
3123    /// Valid sslmodes should still work.
3124    #[test]
3125    fn config_accepts_valid_sslmodes() {
3126        let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=disable").unwrap();
3127        assert_eq!(cfg.ssl, SslMode::Disable);
3128        let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=prefer").unwrap();
3129        assert_eq!(cfg.ssl, SslMode::Prefer);
3130        let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
3131        assert_eq!(cfg.ssl, SslMode::Require);
3132    }
3133
3134    /// Vec-based StmtCache basic operations.
3135    #[test]
3136    fn stmt_cache_basic_ops() {
3137        let mut cache = StmtCache::default();
3138        assert_eq!(cache.len(), 0);
3139        assert!(!cache.contains_key(&42));
3140        assert!(cache.get(&42).is_none());
3141        assert!(cache.get_mut(&42).is_none());
3142        assert!(cache.remove(&42).is_none());
3143    }
3144
3145    /// Statement name formatting uses hex encoding.
3146    #[test]
3147    fn stmt_name_format() {
3148        let name = make_stmt_name(0);
3149        assert_eq!(&*name, "s_0000000000000000");
3150        let name = make_stmt_name(0xDEADBEEF12345678);
3151        assert_eq!(&*name, "s_deadbeef12345678");
3152        let name = make_stmt_name(u64::MAX);
3153        assert_eq!(&*name, "s_ffffffffffffffff");
3154    }
3155
3156    #[test]
3157    fn hash_sql_deterministic() {
3158        let h1 = hash_sql("SELECT 1");
3159        let h2 = hash_sql("SELECT 1");
3160        assert_eq!(h1, h2);
3161    }
3162
3163    #[test]
3164    fn hash_sql_different_queries() {
3165        let h1 = hash_sql("SELECT 1");
3166        let h2 = hash_sql("SELECT 2");
3167        assert_ne!(h1, h2);
3168    }
3169
3170    #[test]
3171    fn data_row_parsing() {
3172        let mut arena = Arena::new();
3173        let mut out = Vec::new();
3174
3175        // Build a DataRow with 2 columns: i32(42) and NULL
3176        let mut data = Vec::new();
3177        data.extend_from_slice(&2i16.to_be_bytes()); // 2 columns
3178
3179        // Column 1: i32 = 42
3180        data.extend_from_slice(&4i32.to_be_bytes()); // length = 4
3181        data.extend_from_slice(&42i32.to_be_bytes()); // value
3182
3183        // Column 2: NULL
3184        data.extend_from_slice(&(-1i32).to_be_bytes()); // length = -1
3185
3186        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3187        assert_eq!(out.len(), 2);
3188
3189        // First column should have length 4
3190        assert_eq!(out[0].1, 4);
3191
3192        // Second column should be NULL
3193        assert_eq!(out[1].1, -1);
3194    }
3195
3196    #[test]
3197    fn data_row_empty() {
3198        let mut arena = Arena::new();
3199        let mut out = Vec::new();
3200        let data = 0i16.to_be_bytes();
3201        parse_data_row_flat(&data, &mut arena, &mut out).unwrap();
3202        assert_eq!(out.len(), 0);
3203    }
3204
3205    #[test]
3206    fn query_result_empty() {
3207        let result = QueryResult {
3208            all_col_offsets: vec![],
3209            num_cols: 0,
3210            columns: Arc::from(Vec::new()),
3211            affected_rows: 0,
3212        };
3213        assert!(result.is_empty());
3214        assert_eq!(result.len(), 0);
3215    }
3216
3217    #[test]
3218    fn url_decode_works() {
3219        assert_eq!(url_decode("hello%20world").unwrap(), "hello world");
3220        assert_eq!(url_decode("no%20escape").unwrap(), "no escape");
3221        assert_eq!(url_decode("plain").unwrap(), "plain");
3222        assert_eq!(url_decode("a%40b").unwrap(), "a@b");
3223    }
3224
3225    #[test]
3226    fn url_decode_malformed_percent_trailing() {
3227        // Truncated percent sequence at end of string
3228        let result = url_decode("abc%2");
3229        assert!(result.is_err(), "truncated %2 should error");
3230    }
3231
3232    #[test]
3233    fn url_decode_malformed_percent_no_digits() {
3234        // % followed by no digits at all
3235        let result = url_decode("abc%");
3236        assert!(result.is_err(), "bare % at end should error");
3237    }
3238
3239    #[test]
3240    fn url_decode_invalid_hex_digit() {
3241        // %GG — 'G' is not a valid hex digit
3242        let result = url_decode("abc%GG");
3243        assert!(result.is_err(), "%GG should error");
3244    }
3245
3246    #[test]
3247    fn url_decode_invalid_hex_second_digit() {
3248        // %2Z — 'Z' is not a valid hex digit
3249        let result = url_decode("abc%2Z");
3250        assert!(result.is_err(), "%2Z should error");
3251    }
3252
3253    /// url_decode with invalid UTF-8 from percent-decoded bytes
3254    #[test]
3255    fn url_decode_invalid_utf8_percent() {
3256        // %80%81 are not valid UTF-8 start bytes
3257        let result = url_decode("%80%81");
3258        assert!(result.is_err(), "invalid UTF-8 bytes should error");
3259    }
3260
3261    /// url_decode with percent-encoded chars in all positions
3262    #[test]
3263    fn url_decode_percent_everywhere() {
3264        assert_eq!(url_decode("%41%42%43").unwrap(), "ABC");
3265        assert_eq!(url_decode("%61").unwrap(), "a");
3266        assert_eq!(url_decode("x%2Fy%2Fz").unwrap(), "x/y/z");
3267    }
3268
3269    /// url_decode with bare percent at various positions
3270    #[test]
3271    fn url_decode_bare_percent_middle() {
3272        assert!(url_decode("a%b").is_err(), "bare % in middle should error");
3273    }
3274
3275    /// T-02: url_decode with multi-byte UTF-8 (%C3%A9 -> e with acute)
3276    #[test]
3277    fn url_decode_multibyte_utf8() {
3278        let result = url_decode("caf%C3%A9").unwrap();
3279        assert_eq!(result, "caf\u{00e9}"); // cafe with accent
3280    }
3281
3282    // --- Audit gap tests ---
3283
3284    // #68: Config with postgresql:// scheme
3285    #[test]
3286    fn config_parse_postgresql_scheme() {
3287        let cfg = Config::from_url("postgresql://user:pass@localhost:5432/mydb").unwrap();
3288        assert_eq!(cfg.user, "user");
3289        assert_eq!(cfg.password, "pass");
3290        assert_eq!(cfg.host, "localhost");
3291        assert_eq!(cfg.port, 5432);
3292        assert_eq!(cfg.database, "mydb");
3293    }
3294
3295    // #69: Config URL without password
3296    #[test]
3297    fn config_parse_no_password_standalone() {
3298        let cfg = Config::from_url("postgres://admin@db.example.com/myapp").unwrap();
3299        assert_eq!(cfg.user, "admin");
3300        assert_eq!(cfg.password, "");
3301        assert_eq!(cfg.host, "db.example.com");
3302        assert_eq!(cfg.database, "myapp");
3303    }
3304
3305    // #70: Config URL with empty database (falls back to user)
3306    #[test]
3307    fn config_empty_database_falls_back_to_user() {
3308        let cfg = Config::from_url("postgres://testuser:pass@localhost").unwrap();
3309        assert_eq!(cfg.database, "testuser");
3310    }
3311
3312    // #71: Config URL with unknown sslmode error
3313    #[test]
3314    fn config_unknown_sslmode_error() {
3315        let result = Config::from_url("postgres://u:p@h/d?sslmode=verify-full");
3316        assert!(result.is_err());
3317        let err = result.unwrap_err().to_string();
3318        assert!(
3319            err.contains("unknown sslmode"),
3320            "should describe unknown sslmode: {err}"
3321        );
3322    }
3323
3324    // #72: Config URL with multiple query params
3325    #[test]
3326    fn config_multiple_query_params() {
3327        let cfg = Config::from_url(
3328            "postgres://user:pass@localhost/db?sslmode=disable&statement_timeout=60",
3329        )
3330        .unwrap();
3331        assert_eq!(cfg.ssl, SslMode::Disable);
3332        assert_eq!(cfg.statement_timeout_secs, 60);
3333    }
3334
3335    // #73: url_decode with invalid percent (%ZZ)
3336    #[test]
3337    fn url_decode_invalid_percent_zz() {
3338        let result = url_decode("abc%ZZ");
3339        assert!(result.is_err(), "%ZZ should error");
3340    }
3341
3342    // #74: url_decode with truncated percent (trailing %)
3343    #[test]
3344    fn url_decode_truncated_percent_trailing() {
3345        let result = url_decode("abc%");
3346        assert!(result.is_err(), "trailing % should error");
3347    }
3348
3349    // #75: url_decode producing invalid UTF-8
3350    #[test]
3351    fn url_decode_invalid_utf8() {
3352        // 0x80 alone is not valid UTF-8
3353        let result = url_decode("%80");
3354        assert!(result.is_err(), "invalid UTF-8 should error");
3355    }
3356
3357    // #76: Config SslMode::Require without tls feature
3358    #[cfg(not(feature = "tls"))]
3359    #[test]
3360    fn config_sslmode_require_without_tls_feature() {
3361        // The config parses fine, but validate doesn't check this.
3362        // The error occurs at connection time. Just verify parsing works.
3363        let cfg = Config::from_url("postgres://user:pass@localhost/db?sslmode=require").unwrap();
3364        assert_eq!(cfg.ssl, SslMode::Require);
3365    }
3366
3367    // #77: statement_name format: "s_" + 16 hex chars
3368    #[test]
3369    fn stmt_name_format_verification() {
3370        let name = make_stmt_name(0xDEADBEEFCAFEBABE);
3371        assert!(name.starts_with("s_"), "must start with s_");
3372        assert_eq!(name.len(), 18, "s_ (2) + 16 hex = 18");
3373        assert!(
3374            name[2..].chars().all(|c| c.is_ascii_hexdigit()),
3375            "remaining chars must be hex: {}",
3376            &*name
3377        );
3378    }
3379
3380    // stmt_name for 0 is all zeros
3381    #[test]
3382    fn stmt_name_zero() {
3383        let name = make_stmt_name(0);
3384        assert_eq!(&*name, "s_0000000000000000");
3385    }
3386
3387    // stmt_name for u64::MAX is all f's
3388    #[test]
3389    fn stmt_name_max() {
3390        let name = make_stmt_name(u64::MAX);
3391        assert_eq!(&*name, "s_ffffffffffffffff");
3392    }
3393
3394    // Config validation: empty host
3395    #[test]
3396    fn config_validate_empty_host() {
3397        let cfg = Config {
3398            host: String::new(),
3399            port: 5432,
3400            user: "user".into(),
3401            password: "pass".into(),
3402            database: "db".into(),
3403            ssl: SslMode::Disable,
3404            statement_timeout_secs: 30,
3405        };
3406        assert!(cfg.validate().is_err());
3407    }
3408
3409    // Config validation: empty user
3410    #[test]
3411    fn config_validate_empty_user() {
3412        let cfg = Config {
3413            host: "localhost".into(),
3414            port: 5432,
3415            user: String::new(),
3416            password: "pass".into(),
3417            database: "db".into(),
3418            ssl: SslMode::Disable,
3419            statement_timeout_secs: 30,
3420        };
3421        assert!(cfg.validate().is_err());
3422    }
3423
3424    // Config validation: empty database
3425    #[test]
3426    fn config_validate_empty_database() {
3427        let cfg = Config {
3428            host: "localhost".into(),
3429            port: 5432,
3430            user: "user".into(),
3431            password: "pass".into(),
3432            database: String::new(),
3433            ssl: SslMode::Disable,
3434            statement_timeout_secs: 30,
3435        };
3436        assert!(cfg.validate().is_err());
3437    }
3438
3439    // Config missing @ in URL
3440    #[test]
3441    fn config_missing_at_sign() {
3442        let result = Config::from_url("postgres://userpasslocalhost/db");
3443        assert!(result.is_err());
3444    }
3445
3446    // Config with custom port
3447    #[test]
3448    fn config_custom_port() {
3449        let cfg = Config::from_url("postgres://user:pass@localhost:5433/db").unwrap();
3450        assert_eq!(cfg.port, 5433);
3451    }
3452
3453    // Config with invalid port
3454    #[test]
3455    fn config_invalid_port() {
3456        let result = Config::from_url("postgres://user:pass@localhost:notaport/db");
3457        assert!(result.is_err());
3458    }
3459
3460    // --- Task 6: Notification buffering ---
3461
3462    #[test]
3463    fn notification_struct_fields() {
3464        let n = Notification {
3465            pid: 42,
3466            channel: "test_chan".to_owned(),
3467            payload: "hello".to_owned(),
3468        };
3469        assert_eq!(n.pid, 42);
3470        assert_eq!(n.channel, "test_chan");
3471        assert_eq!(n.payload, "hello");
3472    }
3473
3474    #[test]
3475    fn notification_clone() {
3476        let n = Notification {
3477            pid: 1,
3478            channel: "c".to_owned(),
3479            payload: "p".to_owned(),
3480        };
3481        let n2 = n.clone();
3482        assert_eq!(n2.pid, 1);
3483        assert_eq!(n2.channel, "c");
3484    }
3485
3486    #[test]
3487    fn notification_debug() {
3488        let n = Notification {
3489            pid: 1,
3490            channel: "c".to_owned(),
3491            payload: "p".to_owned(),
3492        };
3493        let dbg = format!("{n:?}");
3494        assert!(dbg.contains("Notification"));
3495    }
3496
3497    // --- Task 7: Statement cache size ---
3498
3499    #[test]
3500    fn stmt_info_has_last_used_counter() {
3501        let info = StmtInfo {
3502            name: "s_test".into(),
3503            columns: Arc::from(Vec::new()),
3504            last_used: 42,
3505            bind_template: None,
3506        };
3507        // Verify last_used counter is stored correctly
3508        assert_eq!(info.last_used, 42);
3509    }
3510
3511    // --- PgDataRow tests ---
3512
3513    /// Build a DataRow payload: [i16 num_cols] ([i32 len] [bytes])...
3514    /// len = -1 for NULL
3515    fn make_data_row(columns: &[Option<&[u8]>]) -> Vec<u8> {
3516        let mut buf = Vec::new();
3517        buf.extend_from_slice(&(columns.len() as i16).to_be_bytes());
3518        for col in columns {
3519            match col {
3520                Some(data) => {
3521                    buf.extend_from_slice(&(data.len() as i32).to_be_bytes());
3522                    buf.extend_from_slice(data);
3523                }
3524                None => {
3525                    buf.extend_from_slice(&(-1i32).to_be_bytes());
3526                }
3527            }
3528        }
3529        buf
3530    }
3531
3532    #[test]
3533    fn pg_data_row_get_i32() {
3534        let data = make_data_row(&[Some(&42i32.to_be_bytes())]);
3535        let row = PgDataRow::new(&data).unwrap();
3536        assert_eq!(row.get_i32(0), Some(42));
3537        assert_eq!(row.column_count(), 1);
3538    }
3539
3540    #[test]
3541    fn pg_data_row_get_i64() {
3542        let data = make_data_row(&[Some(&12345i64.to_be_bytes())]);
3543        let row = PgDataRow::new(&data).unwrap();
3544        assert_eq!(row.get_i64(0), Some(12345));
3545    }
3546
3547    #[test]
3548    fn pg_data_row_get_str() {
3549        let data = make_data_row(&[Some(b"hello")]);
3550        let row = PgDataRow::new(&data).unwrap();
3551        assert_eq!(row.get_str(0), Some("hello"));
3552    }
3553
3554    #[test]
3555    fn pg_data_row_get_bytes() {
3556        let data = make_data_row(&[Some(&[0xDE, 0xAD, 0xBE, 0xEF])]);
3557        let row = PgDataRow::new(&data).unwrap();
3558        assert_eq!(row.get_bytes(0), Some(&[0xDE, 0xAD, 0xBE, 0xEF][..]));
3559    }
3560
3561    #[test]
3562    fn pg_data_row_get_bool() {
3563        let data = make_data_row(&[Some(&[1u8])]);
3564        let row = PgDataRow::new(&data).unwrap();
3565        assert_eq!(row.get_bool(0), Some(true));
3566
3567        let data = make_data_row(&[Some(&[0u8])]);
3568        let row = PgDataRow::new(&data).unwrap();
3569        assert_eq!(row.get_bool(0), Some(false));
3570    }
3571
3572    #[test]
3573    fn pg_data_row_get_f64() {
3574        let data = make_data_row(&[Some(&3.14f64.to_be_bytes())]);
3575        let row = PgDataRow::new(&data).unwrap();
3576        assert!((row.get_f64(0).unwrap() - 3.14).abs() < 1e-10);
3577    }
3578
3579    #[test]
3580    fn pg_data_row_null_column() {
3581        let data = make_data_row(&[None]);
3582        let row = PgDataRow::new(&data).unwrap();
3583        assert!(row.is_null(0));
3584        assert_eq!(row.get_i32(0), None);
3585        assert_eq!(row.get_str(0), None);
3586    }
3587
3588    #[test]
3589    fn pg_data_row_multiple_columns() {
3590        let data = make_data_row(&[
3591            Some(&42i32.to_be_bytes()),
3592            Some(b"alice"),
3593            Some(b"alice@example.com"),
3594            Some(&[1u8]),
3595            Some(&3.14f64.to_be_bytes()),
3596        ]);
3597        let row = PgDataRow::new(&data).unwrap();
3598        assert_eq!(row.column_count(), 5);
3599        assert_eq!(row.get_i32(0), Some(42));
3600        assert_eq!(row.get_str(1), Some("alice"));
3601        assert_eq!(row.get_str(2), Some("alice@example.com"));
3602        assert_eq!(row.get_bool(3), Some(true));
3603        assert!((row.get_f64(4).unwrap() - 3.14).abs() < 1e-10);
3604    }
3605
3606    #[test]
3607    fn pg_data_row_mixed_null() {
3608        let data = make_data_row(&[Some(&42i32.to_be_bytes()), None, Some(b"text")]);
3609        let row = PgDataRow::new(&data).unwrap();
3610        assert_eq!(row.get_i32(0), Some(42));
3611        assert!(row.is_null(1));
3612        assert_eq!(row.get_str(1), None);
3613        assert_eq!(row.get_str(2), Some("text"));
3614    }
3615
3616    #[test]
3617    fn pg_data_row_empty() {
3618        let data = make_data_row(&[]);
3619        let row = PgDataRow::new(&data).unwrap();
3620        assert_eq!(row.column_count(), 0);
3621    }
3622
3623    #[test]
3624    fn pg_data_row_too_short() {
3625        let data = vec![0u8]; // only 1 byte, need at least 2
3626        assert!(PgDataRow::new(&data).is_err());
3627    }
3628
3629    #[test]
3630    fn pg_data_row_truncated() {
3631        // Declare 2 columns but only include 1
3632        let mut data = Vec::new();
3633        data.extend_from_slice(&2i16.to_be_bytes());
3634        data.extend_from_slice(&4i32.to_be_bytes());
3635        data.extend_from_slice(&42i32.to_be_bytes());
3636        // Missing second column
3637        assert!(PgDataRow::new(&data).is_err());
3638    }
3639
3640    #[test]
3641    fn pg_data_row_get_i16() {
3642        let data = make_data_row(&[Some(&7i16.to_be_bytes())]);
3643        let row = PgDataRow::new(&data).unwrap();
3644        assert_eq!(row.get_i16(0), Some(7));
3645    }
3646
3647    #[test]
3648    fn pg_data_row_get_f32() {
3649        let data = make_data_row(&[Some(&2.5f32.to_be_bytes())]);
3650        let row = PgDataRow::new(&data).unwrap();
3651        assert!((row.get_f32(0).unwrap() - 2.5).abs() < 1e-6);
3652    }
3653
3654    #[test]
3655    fn pg_data_row_get_raw_null() {
3656        let data = make_data_row(&[None]);
3657        let row = PgDataRow::new(&data).unwrap();
3658        assert_eq!(row.get_raw(0), None);
3659    }
3660
3661    #[test]
3662    fn pg_data_row_get_raw_data() {
3663        let data = make_data_row(&[Some(&[1, 2, 3])]);
3664        let row = PgDataRow::new(&data).unwrap();
3665        assert_eq!(row.get_raw(0), Some(&[1u8, 2, 3][..]));
3666    }
3667
3668    #[test]
3669    fn pg_data_row_stack_alloc_16_columns() {
3670        // SmallVec<16> should not heap-allocate for <= 16 columns
3671        let cols: Vec<Option<&[u8]>> = (0..16).map(|_| Some(&[0u8][..])).collect();
3672        let data = make_data_row(&cols);
3673        let row = PgDataRow::new(&data).unwrap();
3674        assert_eq!(row.column_count(), 16);
3675        // All columns should be accessible
3676        for i in 0..16 {
3677            assert_eq!(row.get_raw(i), Some(&[0u8][..]));
3678        }
3679    }
3680
3681    // --- Inline sequential decode tests (validates the raw-bytes pattern) ---
3682
    /// Validate inline sequential decode of a 5-column DataRow
    /// (i32, str, str, bool, f64) — the same pattern the generated code uses.
    ///
    /// The payload is walked by hand with a running cursor `pos`: for each column
    /// the 4-byte big-endian length is read first, then the value bytes, and the
    /// cursor advances by that length. The final assertion checks that the cursor
    /// lands exactly on the end of the payload.
    #[test]
    fn inline_sequential_decode_five_columns() {
        let data = make_data_row(&[
            Some(&42i32.to_be_bytes()),
            Some(b"alice"),
            Some(b"alice@example.com"),
            Some(&[1u8]),
            Some(&3.14f64.to_be_bytes()),
        ]);

        // Simulate generated inline decode
        let mut pos: usize = 2; // skip i16 num_cols

        // Column 0: i32 (length prefix, then 4 value bytes)
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        assert_eq!(len, 4);
        let id = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += len as usize;
        assert_eq!(id, 42);

        // Column 1: str (length prefix, then UTF-8 bytes)
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        assert_eq!(len, 5);
        let name = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
        pos += len as usize;
        assert_eq!(name, "alice");

        // Column 2: str (length is not asserted here; only the decoded text is)
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let email = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
        pos += len as usize;
        assert_eq!(email, "alice@example.com");

        // Column 3: bool (one byte; non-zero means true)
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        assert_eq!(len, 1);
        let active = data[pos] != 0;
        pos += len as usize;
        assert!(active);

        // Column 4: f64 (8 big-endian bytes)
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        assert_eq!(len, 8);
        let score = f64::from_be_bytes([
            data[pos],
            data[pos + 1],
            data[pos + 2],
            data[pos + 3],
            data[pos + 4],
            data[pos + 5],
            data[pos + 6],
            data[pos + 7],
        ]);
        pos += len as usize;
        assert!((score - 3.14).abs() < 1e-10);
        // Every byte of the payload has been consumed.
        assert_eq!(pos, data.len());
    }
3747
    /// Validate inline decode with NULL columns.
    ///
    /// A NULL column is encoded as length `-1` with no value bytes, so the cursor
    /// advances only past the 4-byte length prefix before the next column starts.
    #[test]
    fn inline_sequential_decode_with_nulls() {
        let data = make_data_row(&[
            Some(&42i32.to_be_bytes()),
            None, // NULL name
            Some(b"text"),
        ]);

        // Cursor starts after the 2-byte column count.
        let mut pos: usize = 2;

        // Column 0: i32 NOT NULL
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let id = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += len as usize;
        assert_eq!(id, 42);

        // Column 1: str NULLABLE -> None
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let name: Option<&str> = if len < 0 {
            // Negative length marks NULL; no value bytes follow, so `pos` stays put.
            None
        } else {
            let s = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
            pos += len as usize;
            Some(s)
        };
        assert!(name.is_none());

        // Column 2: str NOT NULL
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let txt = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
        pos += len as usize;
        assert_eq!(txt, "text");
        // Every byte of the payload has been consumed.
        assert_eq!(pos, data.len());
    }
3786
    /// Validate inline decode with all supported scalar types.
    ///
    /// One column each of bool, i16, i32, i64, f32 and f64 is decoded in order with
    /// a running cursor. Each step reads the 4-byte length prefix, decodes the
    /// big-endian value, and advances by the stated length.
    #[test]
    fn inline_sequential_decode_all_scalar_types() {
        let data = make_data_row(&[
            Some(&[1u8]),                  // bool
            Some(&7i16.to_be_bytes()),     // i16
            Some(&42i32.to_be_bytes()),    // i32
            Some(&12345i64.to_be_bytes()), // i64
            Some(&2.5f32.to_be_bytes()),   // f32
            Some(&3.14f64.to_be_bytes()),  // f64
        ]);

        // Cursor starts after the 2-byte column count.
        let mut pos: usize = 2;

        // bool: one byte, non-zero means true
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let v_bool = data[pos] != 0;
        pos += len as usize;
        assert!(v_bool);

        // i16: 2 big-endian bytes
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let v_i16 = i16::from_be_bytes([data[pos], data[pos + 1]]);
        pos += len as usize;
        assert_eq!(v_i16, 7);

        // i32: 4 big-endian bytes
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let v_i32 = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += len as usize;
        assert_eq!(v_i32, 42);

        // i64: 8 big-endian bytes
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let v_i64 = i64::from_be_bytes([
            data[pos],
            data[pos + 1],
            data[pos + 2],
            data[pos + 3],
            data[pos + 4],
            data[pos + 5],
            data[pos + 6],
            data[pos + 7],
        ]);
        pos += len as usize;
        assert_eq!(v_i64, 12345);

        // f32: 4 big-endian bytes
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let v_f32 = f32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += len as usize;
        assert!((v_f32 - 2.5).abs() < 1e-6);

        // f64: 8 big-endian bytes
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let v_f64 = f64::from_be_bytes([
            data[pos],
            data[pos + 1],
            data[pos + 2],
            data[pos + 3],
            data[pos + 4],
            data[pos + 5],
            data[pos + 6],
            data[pos + 7],
        ]);
        pos += len as usize;
        assert!((v_f64 - 3.14).abs() < 1e-10);
        // Every byte of the payload has been consumed.
        assert_eq!(pos, data.len());
    }
3862
3863    /// Validate PgDataRow::new is public (callable from external code).
3864    #[test]
3865    fn pg_data_row_new_is_public() {
3866        let data = make_data_row(&[Some(&42i32.to_be_bytes())]);
3867        // This compiles because PgDataRow::new is pub.
3868        let row = PgDataRow::new(&data).unwrap();
3869        assert_eq!(row.get_i32(0), Some(42));
3870    }
3871
    /// Inline decode produces identical results to PgDataRow for mixed data.
    ///
    /// The same 5-column payload (i32, str, NULL, bool, f64) is decoded twice: once
    /// through the `PgDataRow` getters and once by hand with a running cursor. The
    /// two sets of values are then compared column by column.
    #[test]
    fn inline_decode_matches_pgdatarow() {
        let data = make_data_row(&[
            Some(&99i32.to_be_bytes()),
            Some(b"hello world"),
            None,
            Some(&[0u8]),
            Some(&1.23f64.to_be_bytes()),
        ]);

        // PgDataRow results
        let row = PgDataRow::new(&data).unwrap();
        let dr_i32 = row.get_i32(0);
        let dr_str = row.get_str(1);
        let dr_null = row.get_str(2);
        let dr_bool = row.get_bool(3);
        let dr_f64 = row.get_f64(4);

        // Inline results
        let mut pos: usize = 2; // skip i16 num_cols

        // Column 0: i32
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let in_i32 = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += len as usize;

        // Column 1: str
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let in_str = std::str::from_utf8(&data[pos..pos + len as usize]).unwrap();
        pos += len as usize;

        // Column 2: NULL — the length must be negative, and no value bytes follow.
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let in_null: Option<&str> = if len < 0 { None } else { unreachable!() };

        // Column 3: bool
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let in_bool = data[pos] != 0;
        pos += len as usize;

        // Column 4: f64
        let len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
        pos += 4;
        let in_f64 = f64::from_be_bytes([
            data[pos],
            data[pos + 1],
            data[pos + 2],
            data[pos + 3],
            data[pos + 4],
            data[pos + 5],
            data[pos + 6],
            data[pos + 7],
        ]);
        pos += len as usize;

        // Both paths must produce identical results
        assert_eq!(dr_i32, Some(in_i32));
        assert_eq!(dr_str, Some(in_str));
        assert_eq!(dr_null, in_null);
        assert_eq!(dr_bool, Some(in_bool));
        assert!((dr_f64.unwrap() - in_f64).abs() < 1e-15);
        // The hand-rolled cursor consumed the whole payload.
        assert_eq!(pos, data.len());
    }
3935
3936    // --- Unix domain socket (UDS) tests ---
3937
3938    #[test]
3939    fn config_host_is_uds_absolute_path() {
3940        let cfg = Config {
3941            host: "/tmp".into(),
3942            port: 5432,
3943            user: "user".into(),
3944            password: "".into(),
3945            database: "db".into(),
3946            ssl: SslMode::Disable,
3947            statement_timeout_secs: 30,
3948        };
3949        assert!(cfg.host_is_uds());
3950        assert_eq!(cfg.uds_path(), "/tmp/.s.PGSQL.5432");
3951    }
3952
/// A nested absolute directory (`/var/run/postgresql`) with a non-default port
/// is a UDS host, and the port appears as the socket file suffix.
#[test]
fn config_host_is_uds_var_run() {
    let config = Config {
        ssl: SslMode::Disable,
        statement_timeout_secs: 30,
        host: "/var/run/postgresql".into(),
        port: 5433,
        database: "db".into(),
        user: "user".into(),
        password: "".into(),
    };

    assert!(config.host_is_uds());

    let socket_path = config.uds_path();
    assert_eq!(socket_path, "/var/run/postgresql/.s.PGSQL.5433");
}
3967
/// A plain hostname (`localhost`) is not classified as a Unix domain socket host.
#[test]
fn config_host_is_not_uds_for_hostname() {
    let config = Config {
        ssl: SslMode::Disable,
        statement_timeout_secs: 30,
        host: "localhost".into(),
        port: 5432,
        database: "db".into(),
        user: "user".into(),
        password: "".into(),
    };

    let is_uds = config.host_is_uds();
    assert!(!is_uds);
}
3981
/// A dotted IPv4 address (`127.0.0.1`) is not classified as a Unix domain socket host.
#[test]
fn config_host_is_not_uds_for_ip() {
    let config = Config {
        ssl: SslMode::Disable,
        statement_timeout_secs: 30,
        host: "127.0.0.1".into(),
        port: 5432,
        database: "db".into(),
        user: "user".into(),
        password: "".into(),
    };

    let is_uds = config.host_is_uds();
    assert!(!is_uds);
}
3995
/// `?host=/tmp` in the URL sets the host to the socket directory while the
/// user and database are still taken from the rest of the URL.
#[test]
fn config_parse_uds_host_query_param() {
    let url = "postgres://user@localhost/mydb?host=/tmp";
    let parsed = Config::from_url(url).unwrap();

    // Host comes from the query parameter and is a socket directory.
    assert_eq!(parsed.host, "/tmp");
    assert!(parsed.host_is_uds());
    assert_eq!(parsed.uds_path(), "/tmp/.s.PGSQL.5432");

    // Remaining URL components are unaffected.
    assert_eq!(parsed.database, "mydb");
    assert_eq!(parsed.user, "user");
}
4005
/// With `?host=` pointing at a socket directory, the port written in the URL
/// authority (5433) is kept and used in the socket file name.
#[test]
fn config_parse_uds_host_query_param_custom_port() {
    let url = "postgres://user@localhost:5433/mydb?host=/var/run/postgresql";
    let parsed = Config::from_url(url).unwrap();

    assert_eq!(parsed.host, "/var/run/postgresql");
    assert_eq!(parsed.port, 5433);

    let socket_path = parsed.uds_path();
    assert_eq!(socket_path, "/var/run/postgresql/.s.PGSQL.5433");
}
4014
/// `host=` can be combined with `sslmode=` and `statement_timeout=` in the
/// same query string; all three are applied.
#[test]
fn config_parse_uds_host_with_other_params() {
    let url = "postgres://user@localhost/db?host=/tmp&sslmode=disable&statement_timeout=60";
    let parsed = Config::from_url(url).unwrap();

    assert_eq!(parsed.host, "/tmp");
    assert!(parsed.host_is_uds());
    assert_eq!(parsed.ssl, SslMode::Disable);
    assert_eq!(parsed.statement_timeout_secs, 60);
}
4026
/// A percent-encoded slash (`%2F`) in the `host` parameter is decoded, so
/// `host=%2Ftmp` yields the socket directory `/tmp`.
#[test]
fn config_parse_uds_host_percent_encoded() {
    let url = "postgres://user@localhost/db?host=%2Ftmp";
    let parsed = Config::from_url(url).unwrap();

    assert_eq!(parsed.host, "/tmp");
    assert!(parsed.host_is_uds());
}
4034
/// Without a `?host=` parameter the hostname from the URL authority is used
/// and is not treated as a socket directory.
#[test]
fn config_parse_tcp_host_not_overridden_without_param() {
    let url = "postgres://user@myserver/db";
    let parsed = Config::from_url(url).unwrap();

    assert_eq!(parsed.host, "myserver");
    assert!(!parsed.host_is_uds());
}
4042
/// The `?host=` parameter takes precedence over an explicit hostname
/// (`db.example.com`) in the URL authority.
#[test]
fn config_parse_uds_host_overrides_url_hostname() {
    let url = "postgres://user@db.example.com/mydb?host=/var/run/postgresql";
    let parsed = Config::from_url(url).unwrap();

    assert_eq!(parsed.host, "/var/run/postgresql");
    assert!(parsed.host_is_uds());
}
4051
/// The URL `postgres://user@/mydb?host=/tmp` has nothing between `user@` and
/// the path; the host is then supplied entirely by the `host` parameter and
/// the database name is still parsed from the path.
#[test]
fn config_parse_uds_empty_url_host() {
    let url = "postgres://user@/mydb?host=/tmp";
    let parsed = Config::from_url(url).unwrap();

    assert_eq!(parsed.host, "/tmp");
    assert!(parsed.host_is_uds());
    assert_eq!(parsed.database, "mydb");
}
4060
4061    // ===============================================================
4062    // PgDataRow — comprehensive tests
4063    // ===============================================================
4064
/// A row of five NULL columns reports five columns, flags each one as null,
/// and returns `None` from the raw accessor and every typed getter checked here.
#[test]
fn pg_data_row_all_null_columns() {
    let wire = make_data_row(&[None; 5]);
    let row = PgDataRow::new(&wire).unwrap();

    assert_eq!(row.column_count(), 5);

    (0..5).for_each(|i| {
        assert!(row.is_null(i), "column {i} should be null");
        assert_eq!(row.get_raw(i), None);
        assert_eq!(row.get_i32(i), None);
        assert_eq!(row.get_i64(i), None);
        assert_eq!(row.get_str(i), None);
        assert_eq!(row.get_bool(i), None);
        assert_eq!(row.get_f64(i), None);
    });
}
4080
/// A 2048-byte text column is returned unchanged by `get_str`.
#[test]
fn pg_data_row_very_long_text() {
    let expected: String = std::iter::repeat('x').take(2048).collect();
    let wire = make_data_row(&[Some(expected.as_bytes())]);
    let row = PgDataRow::new(&wire).unwrap();

    let actual = row.get_str(0);
    assert_eq!(actual, Some(&expected[..]));
}
4088
/// A zero-length column is distinct from NULL: it is not null and reads back
/// as an empty string and an empty byte slice.
#[test]
fn pg_data_row_empty_text() {
    let empty: &[u8] = &[];
    let wire = make_data_row(&[Some(empty)]);
    let row = PgDataRow::new(&wire).unwrap();

    assert!(!row.is_null(0));
    assert_eq!(row.get_str(0), Some(""));
    assert_eq!(row.get_bytes(0), Some(empty));
}
4097
/// A row with 20 four-byte columns holding the big-endian values 0..20 exposes
/// every column through `get_i32`.
///
/// NOTE(review): the test name suggests 20 is above `PgDataRow`'s inline column
/// capacity; confirm against the `PgDataRow` definition.
#[test]
fn pg_data_row_20_columns_exceeds_inline() {
    let encoded: Vec<[u8; 4]> = (0i32..20).map(i32::to_be_bytes).collect();
    let columns: Vec<Option<&[u8]>> = encoded.iter().map(|bytes| Some(&bytes[..])).collect();
    let wire = make_data_row(&columns);
    let row = PgDataRow::new(&wire).unwrap();

    assert_eq!(row.column_count(), 20);
    for col in 0..20 {
        let expected = col as i32;
        assert_eq!(row.get_i32(col), Some(expected));
    }
}
4109
/// In a three-column row laid out as value / NULL / value, only the middle
/// column is reported as null.
#[test]
fn pg_data_row_is_null_each_position() {
    let first = 1i32.to_be_bytes();
    let third = 3i32.to_be_bytes();
    let wire = make_data_row(&[Some(&first[..]), None, Some(&third[..])]);
    let row = PgDataRow::new(&wire).unwrap();

    assert!(!row.is_null(0));
    assert!(row.is_null(1));
    assert!(!row.is_null(2));
}
4119
/// A payload whose 16-bit column count is -1 is rejected by `PgDataRow::new`.
#[test]
fn pg_data_row_negative_column_count() {
    let header: [u8; 2] = i16::to_be_bytes(-1);
    let parsed = PgDataRow::new(&header);
    assert!(parsed.is_err());
}
4125
/// Bytes that are not valid UTF-8 make `get_str` return `None`, while
/// `get_bytes` still hands back the raw column data.
#[test]
fn pg_data_row_get_str_invalid_utf8() {
    let raw: [u8; 3] = [0xFF, 0xFE, 0x80];
    let wire = make_data_row(&[Some(&raw[..])]);
    let row = PgDataRow::new(&wire).unwrap();

    // The text accessor refuses the bytes; the byte accessor does not.
    assert_eq!(row.get_str(0), None);
    assert_eq!(row.get_bytes(0), Some(&raw[..]));
}
4135
/// A 2-byte column cannot be read as an `i32` (`get_i32` yields `None`) but
/// reads correctly as an `i16`.
#[test]
fn pg_data_row_get_i32_wrong_length() {
    let two_bytes = 7i16.to_be_bytes();
    let wire = make_data_row(&[Some(&two_bytes[..])]);
    let row = PgDataRow::new(&wire).unwrap();

    // Column is 2 bytes wide, so the 4-byte getter returns None.
    assert_eq!(row.get_i32(0), None);
    // The matching-width getter decodes the value.
    assert_eq!(row.get_i16(0), Some(7));
}
4144
/// A 4-byte column cannot be read as an `i64`: `get_i64` yields `None`.
#[test]
fn pg_data_row_get_i64_wrong_length() {
    let four_bytes = 42i32.to_be_bytes();
    let wire = make_data_row(&[Some(&four_bytes[..])]);
    let row = PgDataRow::new(&wire).unwrap();

    let decoded = row.get_i64(0);
    assert_eq!(decoded, None);
}
4152
/// A 4-byte (`f32`-sized) column cannot be read as an `f64`: `get_f64` yields `None`.
#[test]
fn pg_data_row_get_f64_wrong_length() {
    let four_bytes = 2.5f32.to_be_bytes();
    let wire = make_data_row(&[Some(&four_bytes[..])]);
    let row = PgDataRow::new(&wire).unwrap();

    let decoded = row.get_f64(0);
    assert_eq!(decoded, None);
}
4159
/// An 8-byte (`f64`-sized) column cannot be read as an `f32`: `get_f32` yields `None`.
///
/// The payload value itself is irrelevant; only its width matters. `2.5` is
/// used rather than `3.14` so the literal does not trip Clippy's
/// `approx_constant` lint (which flags approximations of `PI`).
#[test]
fn pg_data_row_get_f32_wrong_length() {
    let eight_bytes = 2.5f64.to_be_bytes();
    let data = make_data_row(&[Some(&eight_bytes[..])]);
    let row = PgDataRow::new(&data).unwrap();
    assert_eq!(row.get_f32(0), None); // 8 bytes != 4 bytes
}
4166
/// A 4-byte column cannot be read as a `bool`: `get_bool` yields `None`.
#[test]
fn pg_data_row_get_bool_wrong_length() {
    let four_bytes = 42i32.to_be_bytes();
    let wire = make_data_row(&[Some(&four_bytes[..])]);
    let row = PgDataRow::new(&wire).unwrap();

    let decoded = row.get_bool(0);
    assert_eq!(decoded, None);
}
4174
/// Multi-byte UTF-8 text (emoji, CJK, Arabic, and a zero-width-joiner
/// sequence) round-trips through `get_str` unchanged.
#[test]
fn pg_data_row_unicode_text() {
    let samples = [
        "\u{1F600}\u{1F4A9}\u{1F680}", // emoji
        "\u{4e16}\u{754c}",            // CJK
        "\u{0645}\u{0631}\u{062D}",    // Arabic
        "\u{1F468}\u{200D}\u{1F469}",  // ZWJ
    ];

    samples.iter().for_each(|sample| {
        let wire = make_data_row(&[Some(sample.as_bytes())]);
        let row = PgDataRow::new(&wire).unwrap();
        assert_eq!(row.get_str(0), Some(*sample));
    });
}
4189
/// `get_i32` decodes the extremes and the values around zero
/// (`i32::MIN`, -1, 0, 1, `i32::MAX`) from big-endian bytes.
#[test]
fn pg_data_row_i32_boundary_values() {
    for val in [i32::MIN, -1, 0, 1, i32::MAX] {
        let encoded = val.to_be_bytes();
        let wire = make_data_row(&[Some(&encoded[..])]);
        let row = PgDataRow::new(&wire).unwrap();
        assert_eq!(row.get_i32(0), Some(val), "failed for {val}");
    }
}
4198
/// `get_i64` decodes the extremes and the values around zero
/// (`i64::MIN`, -1, 0, 1, `i64::MAX`) from big-endian bytes.
#[test]
fn pg_data_row_i64_boundary_values() {
    for val in [i64::MIN, -1, 0, 1, i64::MAX] {
        let encoded = val.to_be_bytes();
        let wire = make_data_row(&[Some(&encoded[..])]);
        let row = PgDataRow::new(&wire).unwrap();
        assert_eq!(row.get_i64(0), Some(val), "failed for {val}");
    }
}
4207
/// `get_f64` preserves the non-finite values: both infinities compare equal
/// to the input, and NaN decodes to a NaN.
#[test]
fn pg_data_row_f64_special_values() {
    // Infinities are comparable, so check them by equality.
    for expected in [f64::INFINITY, f64::NEG_INFINITY] {
        let encoded = expected.to_be_bytes();
        let wire = make_data_row(&[Some(&encoded[..])]);
        let row = PgDataRow::new(&wire).unwrap();
        assert_eq!(row.get_f64(0), Some(expected));
    }

    // NaN never equals itself, so check it with `is_nan`.
    let encoded = f64::NAN.to_be_bytes();
    let wire = make_data_row(&[Some(&encoded[..])]);
    let row = PgDataRow::new(&wire).unwrap();
    assert!(row.get_f64(0).unwrap().is_nan());
}
4222
/// `get_f32` preserves the non-finite values: both infinities compare equal
/// to the input, and NaN decodes to a NaN.
///
/// Negative infinity is included so this test covers the same set of special
/// values as the `f64` counterpart.
#[test]
fn pg_data_row_f32_special_values() {
    // Infinities are comparable, so check them by equality.
    for expected in [f32::INFINITY, f32::NEG_INFINITY] {
        let encoded = expected.to_be_bytes();
        let data = make_data_row(&[Some(&encoded[..])]);
        let row = PgDataRow::new(&data).unwrap();
        assert_eq!(row.get_f32(0), Some(expected));
    }

    // NaN never equals itself, so check it with `is_nan`.
    let encoded = f32::NAN.to_be_bytes();
    let data = make_data_row(&[Some(&encoded[..])]);
    let row = PgDataRow::new(&data).unwrap();
    assert!(row.get_f32(0).unwrap().is_nan());
}
4233
/// `get_i16` decodes the extremes and the values around zero
/// (`i16::MIN`, -1, 0, 1, `i16::MAX`) from big-endian bytes.
#[test]
fn pg_data_row_i16_boundary_values() {
    for expected in [i16::MIN, -1, 0, 1, i16::MAX] {
        let encoded = expected.to_be_bytes();
        let wire = make_data_row(&[Some(&encoded[..])]);
        let row = PgDataRow::new(&wire).unwrap();
        assert_eq!(row.get_i16(0), Some(expected));
    }
}
4242
4243    // ===============================================================
4244    // DataRow flat parsing — comprehensive edge cases
4245    // ===============================================================
4246
/// A payload declaring four columns, each with length -1, produces four
/// output entries whose recorded length is -1.
#[test]
fn data_row_flat_all_null() {
    // Wire layout: i16 column count, then one i32 length per column.
    let mut wire = Vec::new();
    wire.extend_from_slice(&4i16.to_be_bytes());
    (0..4).for_each(|_| wire.extend_from_slice(&(-1i32).to_be_bytes()));

    let mut arena = Arena::new();
    let mut columns = Vec::new();
    parse_data_row_flat(&wire, &mut arena, &mut columns).unwrap();

    assert_eq!(columns.len(), 4);
    for entry in columns.iter() {
        assert_eq!(entry.1, -1);
    }
}
4262
/// A single 1024-byte column is recorded with length 1024, and the bytes read
/// back from the arena at the recorded offset are all `b'A'`.
#[test]
fn data_row_flat_long_text() {
    let payload = vec![b'A'; 1024];

    // Wire layout: i16 column count, i32 length, then the column bytes.
    let mut wire = Vec::new();
    wire.extend_from_slice(&1i16.to_be_bytes());
    wire.extend_from_slice(&(payload.len() as i32).to_be_bytes());
    wire.extend_from_slice(&payload);

    let mut arena = Arena::new();
    let mut columns = Vec::new();
    parse_data_row_flat(&wire, &mut arena, &mut columns).unwrap();

    assert_eq!(columns[0].1, 1024);
    let stored = arena.get(columns[0].0, 1024);
    assert!(stored.iter().all(|&byte| byte == b'A'));
}
4277
/// A column with length 0 (empty, as opposed to the -1 used for NULL) is
/// recorded with length 0.
#[test]
fn data_row_flat_empty_text() {
    // Wire layout: one column, whose length field is zero.
    let mut wire = Vec::new();
    wire.extend_from_slice(&1i16.to_be_bytes());
    wire.extend_from_slice(&0i32.to_be_bytes());

    let mut arena = Arena::new();
    let mut columns = Vec::new();
    parse_data_row_flat(&wire, &mut arena, &mut columns).unwrap();

    let recorded_len = columns[0].1;
    assert_eq!(recorded_len, 0);
}
4288
4289    // ===============================================================
4290    // QueryResult edge cases
4291    // ===============================================================
4292
/// `QueryResult::from_parts` with two offset entries and two columns per row
/// yields one row, and keeps the column count and affected-row count it was given.
#[test]
fn query_result_from_parts() {
    let built = QueryResult::from_parts(
        vec![(0, 4), (0, -1)],
        2,
        Arc::from(Vec::new()),
        5,
    );

    assert_eq!(built.len(), 1);
    assert_eq!(built.num_cols, 2);
    assert_eq!(built.affected_rows, 5);
}
4300
/// A result with no offsets and no columns is empty but still carries its
/// affected-row count.
#[test]
fn query_result_affected_rows() {
    let rowless = QueryResult {
        affected_rows: 42,
        num_cols: 0,
        all_col_offsets: vec![],
        columns: Arc::from(Vec::new()),
    };

    assert_eq!(rowless.affected_rows, 42);
    assert!(rowless.is_empty());
}
4312
4313    // ===============================================================
4314    // DriverError edge cases
4315    // ===============================================================
4316
/// The display text of a server error with a hint and a position contains
/// `HINT: <hint>` and `(at position <n>)`.
#[test]
fn driver_error_server_with_hint() {
    let err = DriverError::Server {
        code: "42601".into(),
        message: "syntax error".into(),
        detail: None,
        hint: Some("check your SQL".into()),
        position: Some(10),
    };

    let rendered = err.to_string();
    for fragment in ["HINT: check your SQL", "(at position 10)"] {
        assert!(rendered.contains(fragment));
    }
}
4330
/// With every optional field populated, the display text of a server error
/// includes the code, message, detail, hint, and position.
#[test]
fn driver_error_server_with_all_fields() {
    let err = DriverError::Server {
        code: "23505".into(),
        message: "unique violation".into(),
        detail: Some("Key (id)=(1) already exists.".into()),
        hint: Some("change the id".into()),
        position: Some(1),
    };

    let rendered = err.to_string();
    let expected_fragments = [
        "23505",
        "unique violation",
        "Key (id)=(1) already exists.",
        "change the id",
        "(at position 1)",
    ];
    for fragment in expected_fragments {
        assert!(rendered.contains(fragment));
    }
}
4347
4348    // ===============================================================
4349    // Config edge cases
4350    // ===============================================================
4351
/// A URL without a `statement_timeout` parameter gets a timeout of 30 seconds.
#[test]
fn config_statement_timeout_default() {
    let url = "postgres://user:pass@localhost/db";
    let parsed = Config::from_url(url).unwrap();
    assert_eq!(parsed.statement_timeout_secs, 30);
}
4357
/// `statement_timeout=120` in the query string sets the timeout to 120 seconds.
#[test]
fn config_statement_timeout_custom() {
    let url = "postgres://user:pass@localhost/db?statement_timeout=120";
    let parsed = Config::from_url(url).unwrap();
    assert_eq!(parsed.statement_timeout_secs, 120);
}
4364
/// `statement_timeout=0` is accepted and stored as 0 rather than replaced by the default.
#[test]
fn config_statement_timeout_zero() {
    let url = "postgres://user:pass@localhost/db?statement_timeout=0";
    let parsed = Config::from_url(url).unwrap();
    assert_eq!(parsed.statement_timeout_secs, 0);
}
4371
/// A non-numeric `statement_timeout` does not make parsing fail; the timeout
/// ends up at 30 seconds, the same value seen when the parameter is absent.
#[test]
fn config_statement_timeout_invalid_falls_back() {
    let url = "postgres://user:pass@localhost/db?statement_timeout=notanumber";
    let parsed = Config::from_url(url).unwrap();
    assert_eq!(parsed.statement_timeout_secs, 30);
}
4379
/// With `host=/tmp` and no explicit port, the socket path is `/tmp/.s.PGSQL.5432`.
#[test]
fn config_uds_path_format() {
    let url = "postgres://user@localhost/db?host=/tmp";
    let parsed = Config::from_url(url).unwrap();

    let socket_path = parsed.uds_path();
    assert_eq!(socket_path, "/tmp/.s.PGSQL.5432");
}
4385
/// With `host=/tmp` and port 5433 in the URL, the socket path ends in `.s.PGSQL.5433`.
#[test]
fn config_uds_path_custom_port() {
    let url = "postgres://user@localhost:5433/db?host=/tmp";
    let parsed = Config::from_url(url).unwrap();

    let socket_path = parsed.uds_path();
    assert_eq!(socket_path, "/tmp/.s.PGSQL.5433");
}
4391
4392    // ===============================================================
4393    // url_decode edge cases
4394    // ===============================================================
4395
/// Decoding the empty string succeeds and yields the empty string.
#[test]
fn url_decode_empty_string() {
    let decoded = url_decode("").unwrap();
    assert_eq!(decoded, "");
}
4400
/// Input without any percent escapes is returned unchanged.
#[test]
fn url_decode_no_encoding() {
    let decoded = url_decode("hello").unwrap();
    assert_eq!(decoded, "hello");
}
4405
/// Percent escapes decode the same way whether the hex digits are uppercase
/// (`%2F`) or lowercase (`%2f`); both yield `/`.
#[test]
fn url_decode_all_ascii_hex() {
    for encoded in ["%2F", "%2f"] {
        assert_eq!(url_decode(encoded).unwrap(), "/");
    }
}
4412
4413    // ===============================================================
4414    // hash_sql edge cases
4415    // ===============================================================
4416
/// Hashing the empty string completes without panicking; the value itself is
/// not checked.
#[test]
fn hash_sql_empty() {
    let _digest = hash_sql("");
}
4421
/// A string of three spaces hashes differently from the empty string.
#[test]
fn hash_sql_whitespace_only() {
    let blank = hash_sql("   ");
    let empty = hash_sql("");
    assert_ne!(blank, empty);
}
4427
/// Hashing the same ~10 KB SQL string twice gives the same value.
#[test]
fn hash_sql_very_long() {
    let mut sql = String::from("SELECT ");
    sql.push_str(&"x".repeat(10_000));

    let first = hash_sql(&sql);
    let second = hash_sql(&sql);
    assert_eq!(first, second);
}
4434
/// SQL containing a non-ASCII character (an emoji literal) is hashed without
/// panicking and differs from the hash of an ASCII-only variant.
#[test]
fn hash_sql_unicode() {
    let with_emoji = hash_sql("SELECT '\u{1F600}'");
    let ascii_only = hash_sql("SELECT 'x'");
    assert_ne!(with_emoji, ascii_only);
}
4440}