Skip to main content

chopin_pg/
connection.rs

1//! PostgreSQL connection with poll-based non-blocking I/O.
2//!
3//! Designed for integration with a thread-per-core event loop. The socket is set
4//! to **non-blocking mode** at connect time; all reads and writes go through
5//! `try_fill_read_buf` / `try_flush_write_buf` which return `WouldBlock` when
6//! the socket is not ready. Higher-level methods use `poll_read` / `poll_write`
7//! with a configurable timeout so the caller can integrate with epoll/kqueue.
8//!
9//! Features:
10//! - **Non-blocking I/O** with application-level timeouts
11//! - SCRAM-SHA-256 and cleartext authentication
12//! - Extended Query Protocol with implicit statement caching
13//! - Transaction support with safe closure-based API
14//! - COPY IN (writer) and COPY OUT (reader)
15//! - LISTEN/NOTIFY with notification buffering
16//! - Proper affected row count from CommandComplete
17//! - Raw socket fd accessor for event-loop registration
18
19use std::collections::VecDeque;
20use std::io::{Read, Write};
21use std::net::TcpStream;
22#[cfg(unix)]
23use std::os::unix::io::AsRawFd;
24#[cfg(unix)]
25use std::os::unix::net::UnixStream;
26use std::rc::Rc;
27use std::time::{Duration, Instant};
28
29use crate::auth::ScramClient;
30use crate::codec;
31use crate::error::{PgError, PgResult};
32use crate::protocol::*;
33use crate::row::Row;
34use crate::statement::StatementCache;
35#[cfg(feature = "tls")]
36use crate::tls;
37use crate::types::{PgValue, ToSql};
38
/// Default application-level timeout applied to poll-based I/O: five seconds
/// (expressed in milliseconds).
const DEFAULT_IO_TIMEOUT: Duration = Duration::from_millis(5_000);
41
42// ─── Stream Abstraction ──────────────────────────────────────
43
/// Transport beneath a connection: plain TCP, a Unix domain socket, or (with
/// the `tls` feature) a TLS session.
///
/// `Read`/`Write` and the helper methods below simply forward to whichever
/// transport is active.
enum PgStream {
    Tcp(TcpStream),
    #[cfg(unix)]
    Unix(UnixStream),
    #[cfg(feature = "tls")]
    Tls(tls::TlsStream),
}

impl PgStream {
    /// Borrow the active transport as a `dyn Read`.
    fn as_dyn_read(&mut self) -> &mut dyn Read {
        match self {
            PgStream::Tcp(inner) => inner,
            #[cfg(unix)]
            PgStream::Unix(inner) => inner,
            #[cfg(feature = "tls")]
            PgStream::Tls(inner) => inner,
        }
    }

    /// Borrow the active transport as a `dyn Write`.
    fn as_dyn_write(&mut self) -> &mut dyn Write {
        match self {
            PgStream::Tcp(inner) => inner,
            #[cfg(unix)]
            PgStream::Unix(inner) => inner,
            #[cfg(feature = "tls")]
            PgStream::Tls(inner) => inner,
        }
    }

    /// Switch the underlying transport between blocking and non-blocking mode.
    fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
        match self {
            PgStream::Tcp(inner) => inner.set_nonblocking(nonblocking),
            #[cfg(unix)]
            PgStream::Unix(inner) => inner.set_nonblocking(nonblocking),
            #[cfg(feature = "tls")]
            PgStream::Tls(inner) => inner.set_nonblocking(nonblocking),
        }
    }

    /// File descriptor of the underlying socket.
    #[cfg(unix)]
    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
        match self {
            PgStream::Tcp(inner) => inner.as_raw_fd(),
            PgStream::Unix(inner) => inner.as_raw_fd(),
            #[cfg(feature = "tls")]
            PgStream::Tls(inner) => inner.as_raw_fd(),
        }
    }

    /// SHA-256 hash of the server's TLS certificate, used for SCRAM channel
    /// binding. `None` for non-TLS transports or when no certificate hash is
    /// available.
    #[cfg(feature = "tls")]
    fn tls_server_cert_hash(&self) -> Option<Vec<u8>> {
        if let PgStream::Tls(inner) = self {
            inner.server_cert_hash()
        } else {
            None
        }
    }
}

impl Read for PgStream {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.as_dyn_read().read(buf)
    }
}

impl Write for PgStream {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.as_dyn_write().write(buf)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.as_dyn_write().flush()
    }

    /// Forwarded explicitly so the transport's own `write_all` is used.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        self.as_dyn_write().write_all(buf)
    }
}
128
/// Connection configuration.
///
/// Build one with `PgConfig::new` or `PgConfig::from_url`. The `Debug`
/// implementation redacts `password`, so a configuration can be logged
/// without leaking credentials.
#[derive(Clone)]
pub struct PgConfig {
    /// Server host name or address used for TCP connections.
    pub host: String,
    /// Server port; also selects the Unix socket file `.s.PGSQL.<port>`.
    pub port: u16,
    /// Role name sent in the startup message.
    pub user: String,
    /// Password used when the server requests password authentication.
    pub password: String,
    /// Database to connect to.
    pub database: String,
    /// Optional Unix domain socket directory.
    /// When set, connect via `<socket_dir>/.s.PGSQL.<port>` instead of TCP.
    pub socket_dir: Option<String>,
    /// SSL/TLS mode. Only effective when the `tls` feature is enabled.
    /// Default: `Prefer` (try TLS, fall back to plaintext).
    #[cfg(feature = "tls")]
    pub ssl_mode: tls::SslMode,
}

impl std::fmt::Debug for PgConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself; only reveal whether one is set.
        let password = if self.password.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        let mut out = f.debug_struct("PgConfig");
        out.field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            .field("password", &password)
            .field("database", &self.database)
            .field("socket_dir", &self.socket_dir);
        #[cfg(feature = "tls")]
        {
            out.field("ssl_mode", &self.ssl_mode);
        }
        out.finish()
    }
}
145
146impl PgConfig {
147    pub fn new(host: &str, port: u16, user: &str, password: &str, database: &str) -> Self {
148        Self {
149            host: host.to_string(),
150            port,
151            user: user.to_string(),
152            password: password.to_string(),
153            database: database.to_string(),
154            socket_dir: None,
155            #[cfg(feature = "tls")]
156            ssl_mode: tls::SslMode::default(),
157        }
158    }
159
160    /// Set a Unix domain socket directory for the connection.
161    /// The actual socket path will be `<dir>/.s.PGSQL.<port>`.
162    pub fn with_socket_dir(mut self, dir: &str) -> Self {
163        self.socket_dir = Some(dir.to_string());
164        self
165    }
166
167    /// Set the SSL/TLS mode for the connection.
168    #[cfg(feature = "tls")]
169    pub fn with_ssl_mode(mut self, mode: tls::SslMode) -> Self {
170        self.ssl_mode = mode;
171        self
172    }
173
174    /// Parse from a connection string: `postgres://user:pass@host:port/db`
175    ///
176    /// For Unix domain sockets, use a path as the host:
177    /// `postgres://user:pass@%2Fvar%2Frun%2Fpostgresql/db`  (URL-encoded slashes)
178    /// or `postgres://user:pass@/db?host=/var/run/postgresql`
179    pub fn from_url(url: &str) -> PgResult<Self> {
180        let url = url
181            .strip_prefix("postgres://")
182            .or_else(|| url.strip_prefix("postgresql://"))
183            .ok_or_else(|| PgError::Protocol("Invalid URL scheme".to_string()))?;
184
185        // user:pass@host:port/db
186        let (userpass, hostdb) = url
187            .split_once('@')
188            .ok_or_else(|| PgError::Protocol("Missing @ in URL".to_string()))?;
189        let (user, password) = userpass.split_once(':').unwrap_or((userpass, ""));
190
191        // Check for ?host= query parameter (Unix socket)
192        let (hostdb_part, query_part) = hostdb.split_once('?').unwrap_or((hostdb, ""));
193
194        let (hostport, database) = hostdb_part
195            .split_once('/')
196            .ok_or_else(|| PgError::Protocol("Missing database in URL".to_string()))?;
197
198        // Parse query params for socket dir and sslmode
199        let mut socket_dir: Option<String> = None;
200        #[cfg(feature = "tls")]
201        let mut ssl_mode = tls::SslMode::default();
202        if !query_part.is_empty() {
203            for param in query_part.split('&') {
204                if let Some(value) = param.strip_prefix("host=")
205                    && value.starts_with('/')
206                {
207                    socket_dir = Some(value.to_string());
208                }
209                #[cfg(feature = "tls")]
210                if let Some(value) = param.strip_prefix("sslmode=") {
211                    if let Some(mode) = tls::SslMode::parse(value) {
212                        ssl_mode = mode;
213                    }
214                }
215            }
216        }
217
218        // Decode percent-encoded host (e.g., %2Fvar%2Frun -> /var/run)
219        let decoded_host = percent_decode(hostport);
220        let is_unix_path = decoded_host.starts_with('/');
221        if is_unix_path {
222            socket_dir = Some(decoded_host);
223        }
224
225        let (host, port) = if socket_dir.is_some() {
226            // Unix socket — host is irrelevant, use default port
227            let port_str = if hostport.is_empty() || is_unix_path {
228                "5432"
229            } else {
230                hostport.rsplit_once(':').map(|(_, p)| p).unwrap_or("5432")
231            };
232            let port: u16 = port_str
233                .parse()
234                .map_err(|_| PgError::Protocol("Invalid port".to_string()))?;
235            ("localhost".to_string(), port)
236        } else {
237            let (h, port_str) = hostport.split_once(':').unwrap_or((hostport, "5432"));
238            let port: u16 = port_str
239                .parse()
240                .map_err(|_| PgError::Protocol("Invalid port".to_string()))?;
241            (h.to_string(), port)
242        };
243
244        Ok(Self {
245            host,
246            port,
247            user: user.to_string(),
248            password: password.to_string(),
249            database: database.to_string(),
250            socket_dir,
251            #[cfg(feature = "tls")]
252            ssl_mode,
253        })
254    }
255}
256
/// Minimal percent-decoding for URL components (`%XX` -> one byte).
///
/// Decoding is done at the byte level and the result is reassembled as UTF-8.
/// As a result, multi-byte sequences such as `%C3%A9` and raw non-ASCII input
/// both come out as the intended characters. A `%` that is not followed by
/// two hex digits is kept verbatim. Decoded bytes that do not form valid UTF-8
/// are replaced with U+FFFD.
fn percent_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            if let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2])) {
                decoded.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` otherwise.
fn hex_digit(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
285
/// A notification received via LISTEN/NOTIFY.
///
/// Plain data: it can be cloned, compared and debug-printed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Notification {
    /// Process ID of the notifying backend.
    pub process_id: i32,
    /// Channel name.
    pub channel: String,
    /// Payload string.
    pub payload: String,
}
296
/// Boxed callback for server NoticeResponse messages.
///
/// Called with `(severity, code, message)`; the callback must be thread-safe
/// (`Sync + Send`).
type NoticeHandler = Box<dyn Fn(&str, &str, &str) + Sync + Send>;
299
300/// A synchronous PostgreSQL connection with poll-based non-blocking I/O.
301///
302/// The socket is set to non-blocking mode at connect time.  I/O methods
303/// internally poll with a configurable timeout so they can be adapted to
304/// an event-loop's readiness notifications.
305pub struct PgConnection {
306    stream: PgStream,
307    read_buf: Vec<u8>,
308    write_buf: Vec<u8>,
309    read_pos: usize,
310    tx_status: TransactionStatus,
311    stmt_cache: StatementCache,
312    process_id: i32,
313    secret_key: i32,
314    server_params: Vec<(String, String)>,
315    /// Buffered notifications received during query processing.
316    notifications: VecDeque<Notification>,
317    /// Number of rows affected by the last command (from CommandComplete).
318    last_affected_rows: u64,
319    /// The last CommandComplete tag string.
320    last_command_tag: String,
321    /// Whether the socket is in non-blocking mode.
322    nonblocking: bool,
323    /// Application-level I/O timeout for poll operations.
324    io_timeout: Duration,
325    /// Optional callback invoked when the server sends a NoticeResponse.
326    notice_handler: Option<NoticeHandler>,
327    /// Flag set on fatal I/O errors. A broken connection must not be
328    /// returned to the pool; it will be discarded on drop.
329    broken: bool,
330}
331
332impl PgConnection {
333    /// Connect to PostgreSQL (blocking during handshake, then switches to
334    /// non-blocking mode once authentication completes).
335    pub fn connect(config: &PgConfig) -> PgResult<Self> {
336        let stream = if let Some(ref socket_dir) = config.socket_dir {
337            // Unix domain socket connection
338            #[cfg(unix)]
339            {
340                let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, config.port);
341                let unix_stream = UnixStream::connect(&socket_path).map_err(PgError::Io)?;
342                PgStream::Unix(unix_stream)
343            }
344            #[cfg(not(unix))]
345            {
346                let _ = socket_dir;
347                return Err(PgError::Protocol(
348                    "Unix domain sockets are not supported on this platform".to_string(),
349                ));
350            }
351        } else {
352            let addr = format!("{}:{}", config.host, config.port);
353            let tcp = TcpStream::connect(&addr).map_err(PgError::Io)?;
354            // Disable Nagle's algorithm for lower latency
355            let _ = tcp.set_nodelay(true);
356
357            // TLS negotiation (when feature enabled)
358            #[cfg(feature = "tls")]
359            {
360                match config.ssl_mode {
361                    tls::SslMode::Disable => PgStream::Tcp(tcp),
362                    tls::SslMode::Prefer => match tls::negotiate(tcp, &config.host) {
363                        Ok(tls::TlsNegotiateResult::Tls(tls_stream)) => PgStream::Tls(tls_stream),
364                        Ok(tls::TlsNegotiateResult::Rejected(tcp)) => PgStream::Tcp(tcp),
365                        Err(_) => {
366                            // TLS negotiation failed — reconnect plain-text
367                            let tcp = TcpStream::connect(&addr).map_err(PgError::Io)?;
368                            let _ = tcp.set_nodelay(true);
369                            PgStream::Tcp(tcp)
370                        }
371                    },
372                    tls::SslMode::Require => match tls::negotiate(tcp, &config.host)? {
373                        tls::TlsNegotiateResult::Tls(tls_stream) => PgStream::Tls(tls_stream),
374                        tls::TlsNegotiateResult::Rejected(_) => {
375                            return Err(PgError::Protocol(
376                                "Server does not support SSL (sslmode=require)".to_string(),
377                            ));
378                        }
379                    },
380                }
381            }
382
383            #[cfg(not(feature = "tls"))]
384            PgStream::Tcp(tcp)
385        };
386
387        let mut conn = Self {
388            stream,
389            read_buf: vec![0u8; 64 * 1024],  // 64 KB read buffer
390            write_buf: vec![0u8; 64 * 1024], // 64 KB write buffer
391            read_pos: 0,
392            tx_status: TransactionStatus::Idle,
393            stmt_cache: StatementCache::new(),
394            process_id: 0,
395            secret_key: 0,
396            server_params: Vec::new(),
397            notifications: VecDeque::new(),
398            last_affected_rows: 0,
399            last_command_tag: String::new(),
400            nonblocking: false,
401            io_timeout: DEFAULT_IO_TIMEOUT,
402            notice_handler: None,
403            broken: false,
404        };
405
406        conn.startup(config)?;
407
408        // Switch to non-blocking after successful authentication
409        conn.stream.set_nonblocking(true).map_err(PgError::Io)?;
410        conn.nonblocking = true;
411
412        Ok(conn)
413    }
414
415    /// Connect with a custom I/O timeout.
416    pub fn connect_with_timeout(config: &PgConfig, timeout: Duration) -> PgResult<Self> {
417        let mut conn = Self::connect(config)?;
418        conn.io_timeout = timeout;
419        Ok(conn)
420    }
421
422    /// Set the application-level I/O timeout.
423    pub fn set_io_timeout(&mut self, timeout: Duration) {
424        self.io_timeout = timeout;
425    }
426
427    /// Get the current I/O timeout.
428    pub fn io_timeout(&self) -> Duration {
429        self.io_timeout
430    }
431
432    /// Set a callback that is invoked when the server sends a NoticeResponse.
433    ///
434    /// The callback receives `(severity, code, message)`. This is useful for
435    /// logging warnings, deprecation notices, etc.
436    ///
437    /// # Example
438    /// ```ignore
439    /// conn.set_notice_handler(|severity, code, message| {
440    ///     eprintln!("PG {}: {} ({})", severity, message, code);
441    /// });
442    /// ```
443    pub fn set_notice_handler<F>(&mut self, handler: F)
444    where
445        F: Fn(&str, &str, &str) + Send + Sync + 'static,
446    {
447        self.notice_handler = Some(Box::new(handler));
448    }
449
450    /// Remove the notice handler.
451    pub fn clear_notice_handler(&mut self) {
452        self.notice_handler = None;
453    }
454
455    /// Set the maximum number of statements to cache before LRU eviction.
456    pub fn set_statement_cache_capacity(&mut self, capacity: usize) {
457        self.stmt_cache.set_max_capacity(capacity);
458    }
459
460    /// Return the raw file descriptor for event-loop registration
461    /// (epoll / kqueue).
462    #[cfg(unix)]
463    pub fn raw_fd(&self) -> std::os::unix::io::RawFd {
464        self.stream.as_raw_fd()
465    }
466
467    /// Check if the socket is in non-blocking mode.
468    pub fn is_nonblocking(&self) -> bool {
469        self.nonblocking
470    }
471
472    /// Set non-blocking mode on the socket.
473    pub fn set_nonblocking(&mut self, nonblocking: bool) -> PgResult<()> {
474        self.stream
475            .set_nonblocking(nonblocking)
476            .map_err(PgError::Io)?;
477        self.nonblocking = nonblocking;
478        Ok(())
479    }
480
481    /// Perform the startup and authentication handshake.
482    fn startup(&mut self, config: &PgConfig) -> PgResult<()> {
483        // Send StartupMessage
484        self.ensure_write_capacity(512);
485        let n = codec::encode_startup(&mut self.write_buf, &config.user, &config.database, &[]);
486        self.stream
487            .write_all(&self.write_buf[..n])
488            .map_err(PgError::Io)?;
489
490        // Read server response
491        loop {
492            self.fill_read_buf(None)?;
493
494            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
495                let header = codec::decode_header(&self.read_buf)
496                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
497                let body = &self.read_buf[5..msg_len];
498
499                match header.tag {
500                    BackendTag::AuthenticationRequest => {
501                        let auth_type = codec::read_i32(&self.read_buf, 5);
502                        match AuthType::from_i32(auth_type) {
503                            Some(AuthType::Ok) => {
504                                // Handled! Keep going to ReadyForQuery
505                            }
506                            Some(AuthType::CleartextPassword) => {
507                                let n =
508                                    codec::encode_password(&mut self.write_buf, &config.password);
509                                self.stream
510                                    .write_all(&self.write_buf[..n])
511                                    .map_err(PgError::Io)?;
512                            }
513                            Some(AuthType::SASLInit) => {
514                                // B.3: Use SCRAM-SHA-256-PLUS with channel binding when TLS is active
515                                #[cfg(feature = "tls")]
516                                let (mut scram, mechanism) =
517                                    if let Some(cb_data) = self.stream.tls_server_cert_hash() {
518                                        (
519                                            ScramClient::new_with_channel_binding(
520                                                &config.user,
521                                                &config.password,
522                                                cb_data,
523                                            ),
524                                            "SCRAM-SHA-256-PLUS",
525                                        )
526                                    } else {
527                                        (
528                                            ScramClient::new(&config.user, &config.password),
529                                            "SCRAM-SHA-256",
530                                        )
531                                    };
532                                #[cfg(not(feature = "tls"))]
533                                let (mut scram, mechanism) = (
534                                    ScramClient::new(&config.user, &config.password),
535                                    "SCRAM-SHA-256",
536                                );
537
538                                let client_first = scram.client_first_message();
539                                let n = codec::encode_sasl_initial(
540                                    &mut self.write_buf,
541                                    mechanism,
542                                    &client_first,
543                                );
544                                self.stream
545                                    .write_all(&self.write_buf[..n])
546                                    .map_err(PgError::Io)?;
547
548                                self.consume_read(msg_len);
549                                self.wait_for_sasl_continue(&mut scram, config)?;
550                                // After SASL, we might still have messages in the buffer
551                                // so we don't return, we continue the outer loop.
552                                continue;
553                            }
554                            Some(AuthType::MD5Password) => {
555                                // Salt is 4 bytes following the auth_type int32
556                                if body.len() < 8 {
557                                    return Err(PgError::Protocol(
558                                        "MD5Password message too short".to_string(),
559                                    ));
560                                }
561                                let salt: [u8; 4] = [body[4], body[5], body[6], body[7]];
562                                let hash = crate::auth::md5_password_hash(
563                                    &config.user,
564                                    &config.password,
565                                    &salt,
566                                );
567                                let n = codec::encode_password(&mut self.write_buf, &hash);
568                                self.stream
569                                    .write_all(&self.write_buf[..n])
570                                    .map_err(PgError::Io)?;
571                            }
572                            _ => {
573                                return Err(PgError::Auth(format!(
574                                    "Unsupported auth type: {}",
575                                    auth_type
576                                )));
577                            }
578                        }
579                    }
580                    BackendTag::ParameterStatus => {
581                        let (name, consumed) = codec::read_cstring(body, 0);
582                        let (value, _) = codec::read_cstring(body, consumed);
583                        self.server_params
584                            .push((name.to_string(), value.to_string()));
585                    }
586                    BackendTag::BackendKeyData => {
587                        self.process_id = codec::read_i32(body, 0);
588                        self.secret_key = codec::read_i32(body, 4);
589                    }
590                    BackendTag::ReadyForQuery => {
591                        self.tx_status = TransactionStatus::from(body[0]);
592                        self.consume_read(msg_len);
593                        return Ok(()); // Connection is ready!
594                    }
595                    BackendTag::ErrorResponse => {
596                        let fields = codec::parse_error_fields(body);
597                        return Err(PgError::from_fields(&fields));
598                    }
599                    _ => {
600                        // Skip unknown messages
601                    }
602                }
603                self.consume_read(msg_len);
604            }
605        }
606    }
607
608    /// Handle SASL Continue/Final exchange.
609    fn wait_for_sasl_continue(
610        &mut self,
611        scram: &mut ScramClient,
612        _config: &PgConfig,
613    ) -> PgResult<()> {
614        loop {
615            self.fill_read_buf(None)?;
616
617            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
618                let header = codec::decode_header(&self.read_buf)
619                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
620                let body = &self.read_buf[5..msg_len].to_vec();
621
622                match header.tag {
623                    BackendTag::AuthenticationRequest => {
624                        let auth_type = codec::read_i32(&self.read_buf, 5);
625                        match AuthType::from_i32(auth_type) {
626                            Some(AuthType::SASLContinue) => {
627                                let server_first = &body[4..];
628                                let client_final = scram
629                                    .process_server_first(server_first)
630                                    .map_err(PgError::Auth)?;
631
632                                let n =
633                                    codec::encode_sasl_response(&mut self.write_buf, &client_final);
634                                self.stream
635                                    .write_all(&self.write_buf[..n])
636                                    .map_err(PgError::Io)?;
637                            }
638                            Some(AuthType::SASLFinal) => {
639                                let server_final = &body[4..];
640                                scram
641                                    .verify_server_final(server_final)
642                                    .map_err(PgError::Auth)?;
643                            }
644                            Some(AuthType::Ok) => {
645                                self.consume_read(msg_len);
646                                return Ok(());
647                            }
648                            _ => {
649                                return Err(PgError::Auth(
650                                    "Unexpected auth message during SASL".to_string(),
651                                ));
652                            }
653                        }
654                    }
655                    _ => {
656                        // Skip
657                    }
658                }
659                self.consume_read(msg_len);
660            }
661        }
662    }
663
664    // ─── Query Methods ────────────────────────────────────────
665
666    /// Execute a simple query (no parameters). Returns all result rows.
667    pub fn query_simple(&mut self, sql: &str) -> PgResult<Vec<Row>> {
668        self.ensure_write_capacity(5 + sql.len());
669        let n = codec::encode_query(&mut self.write_buf, sql);
670        self.flush_write_buf(n)?;
671        self.read_query_results()
672    }
673
674    /// Execute a parameterized query using the Extended Query Protocol.
675    /// Uses implicit statement caching for performance.
676    pub fn query(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Vec<Row>> {
677        let stmt = self.stmt_cache.get_or_create(sql);
678
679        // Conservative upper bound for write buffer
680        let estimated = 10 + sql.len() + (params.len() * 256);
681        self.ensure_write_capacity(estimated);
682
683        let mut pos = 0;
684
685        if stmt.is_new {
686            // Parse
687            let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
688            pos += n;
689
690            // Describe (to get column info)
691            let n = codec::encode_describe(
692                &mut self.write_buf[pos..],
693                DescribeTarget::Statement,
694                &stmt.name,
695            );
696            pos += n;
697        }
698
699        // Bind — encode parameters with per-parameter format codes
700        let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
701        let param_formats: Vec<i16> = pg_values
702            .iter()
703            .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
704            .collect();
705        let param_values: Vec<Option<Vec<u8>>> = pg_values
706            .iter()
707            .zip(param_formats.iter())
708            .map(|(v, &fmt)| {
709                if fmt == 1 {
710                    v.to_binary_bytes()
711                } else {
712                    v.to_text_bytes()
713                }
714            })
715            .collect();
716        let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
717        let n = codec::encode_bind(
718            &mut self.write_buf[pos..],
719            "", // unnamed portal
720            &stmt.name,
721            &param_formats,
722            &param_refs,
723            &[1], // request all results in binary format
724        );
725        pos += n;
726
727        // Execute
728        let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
729        pos += n;
730
731        // Sync
732        let n = codec::encode_sync(&mut self.write_buf[pos..]);
733        pos += n;
734
735        self.flush_write_buf(pos)?;
736
737        // Read results
738        let rows = self.read_extended_results(sql, &stmt.name, stmt.is_new, stmt.columns)?;
739        Ok(rows)
740    }
741
742    /// Execute a query expecting exactly one row.
743    ///
744    /// Optimised path: reads the Extended Query Protocol response stream and
745    /// returns the **first** `DataRow` directly, without collecting into a
746    /// `Vec<Row>`.  Subsequent rows (if any) are still drained so the
747    /// connection is left in a clean state.
748    pub fn query_one(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Row> {
749        let stmt = self.stmt_cache.get_or_create(sql);
750
751        let estimated = 10 + sql.len() + (params.len() * 256);
752        self.ensure_write_capacity(estimated);
753
754        let mut pos = 0;
755
756        if stmt.is_new {
757            let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
758            pos += n;
759            let n = codec::encode_describe(
760                &mut self.write_buf[pos..],
761                DescribeTarget::Statement,
762                &stmt.name,
763            );
764            pos += n;
765        }
766
767        let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
768        let param_formats: Vec<i16> = pg_values
769            .iter()
770            .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
771            .collect();
772        let param_values: Vec<Option<Vec<u8>>> = pg_values
773            .iter()
774            .zip(param_formats.iter())
775            .map(|(v, &fmt)| {
776                if fmt == 1 {
777                    v.to_binary_bytes()
778                } else {
779                    v.to_text_bytes()
780                }
781            })
782            .collect();
783        let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
784        let n = codec::encode_bind(
785            &mut self.write_buf[pos..],
786            "",
787            &stmt.name,
788            &param_formats,
789            &param_refs,
790            &[1],
791        );
792        pos += n;
793
794        // Execute with max_rows = 0 (unlimited) — PostgreSQL will send all
795        // DataRows but we stop caring after the first one.
796        let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
797        pos += n;
798
799        let n = codec::encode_sync(&mut self.write_buf[pos..]);
800        pos += n;
801
802        self.flush_write_buf(pos)?;
803
804        self.read_extended_result_one(sql, &stmt.name, stmt.is_new, stmt.columns)
805    }
806
807    /// Execute a query expecting zero or one row. Returns `Ok(None)` when
808    /// the query returns no rows, avoiding the `PgError::NoRows` error path.
809    pub fn query_opt(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Option<Row>> {
810        let stmt = self.stmt_cache.get_or_create(sql);
811
812        let estimated = 10 + sql.len() + (params.len() * 256);
813        self.ensure_write_capacity(estimated);
814
815        let mut pos = 0;
816
817        if stmt.is_new {
818            let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
819            pos += n;
820            let n = codec::encode_describe(
821                &mut self.write_buf[pos..],
822                DescribeTarget::Statement,
823                &stmt.name,
824            );
825            pos += n;
826        }
827
828        let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
829        let param_formats: Vec<i16> = pg_values
830            .iter()
831            .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
832            .collect();
833        let param_values: Vec<Option<Vec<u8>>> = pg_values
834            .iter()
835            .zip(param_formats.iter())
836            .map(|(v, &fmt)| {
837                if fmt == 1 {
838                    v.to_binary_bytes()
839                } else {
840                    v.to_text_bytes()
841                }
842            })
843            .collect();
844        let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
845        let n = codec::encode_bind(
846            &mut self.write_buf[pos..],
847            "",
848            &stmt.name,
849            &param_formats,
850            &param_refs,
851            &[1],
852        );
853        pos += n;
854
855        let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
856        pos += n;
857
858        let n = codec::encode_sync(&mut self.write_buf[pos..]);
859        pos += n;
860
861        self.flush_write_buf(pos)?;
862
863        self.read_extended_result_opt(sql, &stmt.name, stmt.is_new, stmt.columns)
864    }
865
866    /// Execute a statement that returns no rows (INSERT, UPDATE, DELETE).
867    /// Returns the number of affected rows as reported by the server.
868    pub fn execute(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<u64> {
869        let _rows = self.query(sql, params)?;
870        Ok(self.last_affected_rows)
871    }
872
873    // ─── Transaction Support ──────────────────────────────────
874
875    /// Begin a transaction.
876    pub fn begin(&mut self) -> PgResult<()> {
877        self.query_simple("BEGIN")?;
878        Ok(())
879    }
880
881    /// Commit the current transaction.
882    pub fn commit(&mut self) -> PgResult<()> {
883        self.query_simple("COMMIT")?;
884        Ok(())
885    }
886
887    /// Rollback the current transaction.
888    pub fn rollback(&mut self) -> PgResult<()> {
889        self.query_simple("ROLLBACK")?;
890        Ok(())
891    }
892
893    /// Create a savepoint.
894    pub fn savepoint(&mut self, name: &str) -> PgResult<()> {
895        self.query_simple(&format!("SAVEPOINT {}", name))?;
896        Ok(())
897    }
898
899    /// Rollback to a savepoint.
900    pub fn rollback_to(&mut self, name: &str) -> PgResult<()> {
901        self.query_simple(&format!("ROLLBACK TO SAVEPOINT {}", name))?;
902        Ok(())
903    }
904
905    /// Release a savepoint.
906    pub fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
907        self.query_simple(&format!("RELEASE SAVEPOINT {}", name))?;
908        Ok(())
909    }
910
911    /// Execute a closure within a transaction.
912    ///
913    /// Automatically BEGINs before calling `f`, COMMITs on success,
914    /// and ROLLBACKs on error. This ensures the transaction is always
915    /// finalized, even if the closure panics (via Drop).
916    ///
917    /// # Example
918    /// ```ignore
919    /// conn.transaction(|tx| {
920    ///     tx.execute("INSERT INTO users (name) VALUES ($1)", &[&"Alice"])?;
921    ///     tx.execute("INSERT INTO logs (msg) VALUES ($1)", &[&"User created"])?;
922    ///     Ok(())
923    /// })?;
924    /// ```
925    pub fn transaction<F, T>(&mut self, f: F) -> PgResult<T>
926    where
927        F: FnOnce(&mut Transaction<'_>) -> PgResult<T>,
928    {
929        self.begin()?;
930        let mut tx = Transaction {
931            conn: self,
932            finished: false,
933            savepoint_name: None,
934            savepoint_counter: 0,
935        };
936        match f(&mut tx) {
937            Ok(val) => {
938                tx.commit()?;
939                Ok(val)
940            }
941            Err(e) => {
942                // Attempt rollback, but propagate original error
943                let _ = tx.rollback();
944                Err(e)
945            }
946        }
947    }
948
949    // ─── COPY Protocol ────────────────────────────────────────
950
951    /// Start a COPY FROM STDIN operation.
952    pub fn copy_in(&mut self, sql: &str) -> PgResult<CopyWriter<'_>> {
953        let n = codec::encode_query(&mut self.write_buf, sql);
954        #[allow(clippy::unnecessary_to_owned)]
955        self.write_all(&self.write_buf[..n].to_vec())?;
956
957        // Read until CopyInResponse
958        loop {
959            self.fill_read_buf(None)?;
960            let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) else {
961                continue;
962            };
963            let header = codec::decode_header(&self.read_buf)
964                .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
965            match header.tag {
966                BackendTag::CopyInResponse => {
967                    self.consume_read(msg_len);
968                    return Ok(CopyWriter { conn: self });
969                }
970                BackendTag::ErrorResponse => {
971                    let body = &self.read_buf[5..msg_len];
972                    return Err(self.parse_error(body));
973                }
974                _ => {
975                    self.consume_read(msg_len);
976                }
977            }
978        }
979    }
980
981    /// Start a COPY TO STDOUT operation.
982    /// Returns a CopyReader that yields data chunks.
983    pub fn copy_out(&mut self, sql: &str) -> PgResult<CopyReader<'_>> {
984        let n = codec::encode_query(&mut self.write_buf, sql);
985        #[allow(clippy::unnecessary_to_owned)]
986        self.write_all(&self.write_buf[..n].to_vec())?;
987
988        // Read until CopyOutResponse
989        loop {
990            self.fill_read_buf(None)?;
991            let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) else {
992                continue;
993            };
994            let header = codec::decode_header(&self.read_buf)
995                .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
996            match header.tag {
997                BackendTag::CopyOutResponse => {
998                    self.consume_read(msg_len);
999                    return Ok(CopyReader {
1000                        conn: self,
1001                        done: false,
1002                    });
1003                }
1004                BackendTag::ErrorResponse => {
1005                    let body = &self.read_buf[5..msg_len];
1006                    return Err(self.parse_error(body));
1007                }
1008                _ => {
1009                    self.consume_read(msg_len);
1010                }
1011            }
1012        }
1013    }
1014
1015    // ─── LISTEN / NOTIFY ──────────────────────────────────────
1016
1017    /// Subscribe to a notification channel.
1018    pub fn listen(&mut self, channel: &str) -> PgResult<()> {
1019        self.query_simple(&format!("LISTEN {}", channel))?;
1020        Ok(())
1021    }
1022
1023    /// Send a notification.
1024    pub fn notify(&mut self, channel: &str, payload: &str) -> PgResult<()> {
1025        self.query_simple(&format!("NOTIFY {}, '{}'", channel, payload))?;
1026        Ok(())
1027    }
1028
1029    /// Unsubscribe from a notification channel.
1030    pub fn unlisten(&mut self, channel: &str) -> PgResult<()> {
1031        self.query_simple(&format!("UNLISTEN {}", channel))?;
1032        Ok(())
1033    }
1034
1035    /// Unsubscribe from all notification channels.
1036    pub fn unlisten_all(&mut self) -> PgResult<()> {
1037        self.query_simple("UNLISTEN *")?;
1038        Ok(())
1039    }
1040
1041    /// Drain and return all buffered notifications.
1042    pub fn drain_notifications(&mut self) -> Vec<Notification> {
1043        self.notifications.drain(..).collect()
1044    }
1045
1046    /// Check if there are buffered notifications.
1047    pub fn has_notifications(&self) -> bool {
1048        !self.notifications.is_empty()
1049    }
1050
1051    /// Get the number of buffered notifications.
1052    pub fn notification_count(&self) -> usize {
1053        self.notifications.len()
1054    }
1055
1056    /// Poll for a notification (always non-blocking).
1057    /// Reads from the socket and returns the first notification found,
1058    /// or None if no notification is immediately available.
1059    pub fn poll_notification(&mut self) -> PgResult<Option<Notification>> {
1060        // First check buffer
1061        if let Some(n) = self.notifications.pop_front() {
1062            return Ok(Some(n));
1063        }
1064
1065        // Try a non-blocking read (socket is already non-blocking)
1066        self.ensure_read_space();
1067        match self.stream.read(&mut self.read_buf[self.read_pos..]) {
1068            Ok(0) => return Err(PgError::ConnectionClosed),
1069            Ok(n) => {
1070                self.read_pos += n;
1071                // Process any complete messages
1072                while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
1073                    let header = codec::decode_header(&self.read_buf).ok_or_else(|| {
1074                        PgError::Protocol("Incomplete message header".to_string())
1075                    })?;
1076                    if header.tag == BackendTag::NotificationResponse {
1077                        let body = &self.read_buf[5..msg_len];
1078                        let notification = Self::parse_notification(body);
1079                        self.notifications.push_back(notification);
1080                    }
1081                    self.consume_read(msg_len);
1082                }
1083            }
1084            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1085                // No data available
1086            }
1087            Err(e) => return Err(PgError::Io(e)),
1088        }
1089
1090        Ok(self.notifications.pop_front())
1091    }
1092
1093    // ─── Accessors ────────────────────────────────────────────
1094
1095    /// Get the current transaction status.
1096    pub fn transaction_status(&self) -> TransactionStatus {
1097        self.tx_status
1098    }
1099
1100    /// Get the number of cached statements.
1101    pub fn cached_statements(&self) -> usize {
1102        self.stmt_cache.len()
1103    }
1104
1105    /// Get the number of rows affected by the last command.
1106    pub fn last_affected_rows(&self) -> u64 {
1107        self.last_affected_rows
1108    }
1109
1110    /// Get the last CommandComplete tag string.
1111    pub fn last_command_tag(&self) -> &str {
1112        &self.last_command_tag
1113    }
1114
1115    /// Get the backend process ID.
1116    pub fn process_id(&self) -> i32 {
1117        self.process_id
1118    }
1119
1120    /// Get the backend secret key (used for cancel requests).
1121    pub fn secret_key(&self) -> i32 {
1122        self.secret_key
1123    }
1124
1125    /// Get server parameters received during startup.
1126    pub fn server_params(&self) -> &[(String, String)] {
1127        &self.server_params
1128    }
1129
1130    /// Get a specific server parameter by name.
1131    pub fn server_param(&self, name: &str) -> Option<&str> {
1132        self.server_params
1133            .iter()
1134            .find(|(k, _)| k == name)
1135            .map(|(_, v)| v.as_str())
1136    }
1137
1138    /// Check if the connection is in a transaction.
1139    pub fn in_transaction(&self) -> bool {
1140        matches!(
1141            self.tx_status,
1142            TransactionStatus::InTransaction | TransactionStatus::Failed
1143        )
1144    }
1145
1146    /// Clear the statement cache and deallocate all server-side prepared statements.
1147    ///
1148    /// Sends `DEALLOCATE ALL` to the server before clearing the client-side
1149    /// cache.  The statement name counter is preserved to prevent name
1150    /// collisions with any stale server-side references.
1151    pub fn clear_statement_cache(&mut self) {
1152        let _ = self.query_simple("DEALLOCATE ALL");
1153        self.stmt_cache.clear();
1154    }
1155
1156    /// Returns `true` if the connection has been marked as broken due to a
1157    /// fatal I/O error.  A broken connection should be discarded (not
1158    /// returned to the pool).
1159    pub fn is_broken(&self) -> bool {
1160        self.broken
1161    }
1162
1163    /// Reset the connection to a clean state for pool reuse.
1164    ///
1165    /// Sends `DISCARD ALL` which resets session state, deallocates prepared
1166    /// statements, closes cursors, drops temps, releases advisory locks.
1167    /// Then clears the client-side statement cache.
1168    pub fn reset(&mut self) -> PgResult<()> {
1169        self.query_simple("DISCARD ALL")?;
1170        self.stmt_cache.clear();
1171        Ok(())
1172    }
1173
1174    /// Execute one or more SQL statements separated by semicolons, using
1175    /// the Simple Query Protocol.  Returns the number of affected rows from
1176    /// the **last** command.
1177    ///
1178    /// This is useful for running DDL migrations, multi-statement scripts,
1179    /// or any sequence of commands that don't require parameters.
1180    ///
1181    /// # Example
1182    /// ```ignore
1183    /// conn.execute_batch("CREATE TABLE t(id INT); INSERT INTO t VALUES (1); INSERT INTO t VALUES (2);")?;
1184    /// ```
1185    pub fn execute_batch(&mut self, sql: &str) -> PgResult<u64> {
1186        self.query_simple(sql)?;
1187        Ok(self.last_affected_rows)
1188    }
1189
1190    /// Check if the connection is alive by sending a simple query.
1191    pub fn is_alive(&mut self) -> bool {
1192        self.query_simple("SELECT 1").is_ok()
1193    }
1194
1195    // ─── Internal Methods ─────────────────────────────────────
1196
1197    // ─── Non-blocking read/write primitives ───────────────────
1198
1199    /// Try to read data into the read buffer without blocking.
1200    /// Returns `Ok(n)` with bytes read, or `Err(PgError::WouldBlock)` if no
1201    /// data is available, or another error on failure.
1202    pub fn try_fill_read_buf(&mut self) -> PgResult<usize> {
1203        self.ensure_read_space();
1204
1205        match self.stream.read(&mut self.read_buf[self.read_pos..]) {
1206            Ok(0) => {
1207                self.broken = true;
1208                Err(PgError::ConnectionClosed)
1209            }
1210            Ok(n) => {
1211                self.read_pos += n;
1212                Ok(n)
1213            }
1214            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(PgError::WouldBlock),
1215            Err(e) => {
1216                self.broken = true;
1217                Err(PgError::Io(e))
1218            }
1219        }
1220    }
1221
1222    /// Try to write a buffer to the socket without blocking.
1223    /// Returns `Ok(n)` with bytes written, or `Err(PgError::WouldBlock)` if the
1224    /// socket is not writable.
1225    pub fn try_write(&mut self, data: &[u8]) -> PgResult<usize> {
1226        match self.stream.write(data) {
1227            Ok(n) => Ok(n),
1228            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(PgError::WouldBlock),
1229            Err(e) => {
1230                self.broken = true;
1231                Err(PgError::Io(e))
1232            }
1233        }
1234    }
1235
1236    /// Wait for the socket to become readable using `poll(2)`.
1237    ///
1238    /// This is a true OS-level wait — the thread sleeps in the kernel until
1239    /// the socket has data or the timeout expires.  No busy-waiting, no
1240    /// `thread::sleep`.
1241    #[cfg(unix)]
1242    fn wait_readable(&self, timeout: Duration) -> PgResult<()> {
1243        let fd = self.stream.as_raw_fd();
1244        let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
1245        let mut pfd = libc::pollfd {
1246            fd,
1247            events: libc::POLLIN,
1248            revents: 0,
1249        };
1250        let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
1251        if ret < 0 {
1252            let e = std::io::Error::last_os_error();
1253            if e.kind() == std::io::ErrorKind::Interrupted {
1254                return Ok(()); // EINTR — caller will retry
1255            }
1256            return Err(PgError::Io(e));
1257        }
1258        if ret == 0 {
1259            return Err(PgError::Timeout);
1260        }
1261        if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 {
1262            return Err(PgError::ConnectionClosed);
1263        }
1264        Ok(())
1265    }
1266
1267    /// Wait for the socket to become writable using `poll(2)`.
1268    #[cfg(unix)]
1269    fn wait_writable(&self, timeout: Duration) -> PgResult<()> {
1270        let fd = self.stream.as_raw_fd();
1271        let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
1272        let mut pfd = libc::pollfd {
1273            fd,
1274            events: libc::POLLOUT,
1275            revents: 0,
1276        };
1277        let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
1278        if ret < 0 {
1279            let e = std::io::Error::last_os_error();
1280            if e.kind() == std::io::ErrorKind::Interrupted {
1281                return Ok(());
1282            }
1283            return Err(PgError::Io(e));
1284        }
1285        if ret == 0 {
1286            return Err(PgError::Timeout);
1287        }
1288        if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 {
1289            return Err(PgError::ConnectionClosed);
1290        }
1291        Ok(())
1292    }
1293
1294    /// Poll the socket for readability with a timeout.
1295    ///
1296    /// Uses the OS `poll(2)` syscall for efficient, zero-waste waiting.
1297    /// The thread sleeps in the kernel until data arrives or the timeout
1298    /// expires — no busy-loop, no `thread::sleep`.
1299    pub fn poll_read(&mut self, timeout: Duration) -> PgResult<usize> {
1300        let start = Instant::now();
1301        loop {
1302            match self.try_fill_read_buf() {
1303                Ok(n) => return Ok(n),
1304                Err(PgError::WouldBlock) => {
1305                    let elapsed = start.elapsed();
1306                    if elapsed >= timeout {
1307                        return Err(PgError::Timeout);
1308                    }
1309                    #[cfg(unix)]
1310                    self.wait_readable(timeout - elapsed)?;
1311                    #[cfg(not(unix))]
1312                    std::thread::sleep(Duration::from_micros(50));
1313                }
1314                Err(e) => return Err(e),
1315            }
1316        }
1317    }
1318
1319    /// Poll the socket for writability with a timeout.
1320    /// Writes all of `data` or times out.
1321    pub fn poll_write(&mut self, data: &[u8], timeout: Duration) -> PgResult<()> {
1322        let start = Instant::now();
1323        let mut written = 0;
1324        while written < data.len() {
1325            match self.try_write(&data[written..]) {
1326                Ok(n) => written += n,
1327                Err(PgError::WouldBlock) => {
1328                    let elapsed = start.elapsed();
1329                    if elapsed >= timeout {
1330                        return Err(PgError::Timeout);
1331                    }
1332                    #[cfg(unix)]
1333                    self.wait_writable(timeout - elapsed)?;
1334                    #[cfg(not(unix))]
1335                    std::thread::sleep(Duration::from_micros(50));
1336                }
1337                Err(e) => return Err(e),
1338            }
1339        }
1340        Ok(())
1341    }
1342
1343    /// Internal: fill the read buffer, blocking with the connection's
1344    /// configured timeout. This is the workhorse used by query methods.
1345    fn fill_read_buf(&mut self, min_size: Option<usize>) -> PgResult<()> {
1346        if let Some(min) = min_size {
1347            self.ensure_read_capacity(min);
1348        }
1349
1350        self.ensure_read_space();
1351
1352        if self.nonblocking {
1353            // Use poll_read with timeout
1354            self.poll_read(self.io_timeout)?;
1355        } else {
1356            // Blocking path (used during startup before we switch to NB)
1357            let n = self
1358                .stream
1359                .read(&mut self.read_buf[self.read_pos..])
1360                .map_err(PgError::Io)?;
1361            if n == 0 {
1362                return Err(PgError::ConnectionClosed);
1363            }
1364            self.read_pos += n;
1365        }
1366        Ok(())
1367    }
1368
1369    /// Internal: write all bytes to the socket, respecting non-blocking mode.
1370    fn write_all(&mut self, data: &[u8]) -> PgResult<()> {
1371        if self.nonblocking {
1372            self.poll_write(data, self.io_timeout)
1373        } else {
1374            self.stream.write_all(data).map_err(PgError::Io)
1375        }
1376    }
1377
1378    /// Internal: flush the first `n` bytes of `self.write_buf` to the stream.
1379    ///
1380    /// This avoids the `.to_vec()` copy that was previously needed to work
1381    /// around borrow-checker limitations when `self.write_buf` is the source
1382    /// and `self.write_all()` takes `&mut self`.  By inlining the write loop
1383    /// here, the compiler can see that `stream` and `write_buf` are disjoint
1384    /// fields (split borrow).
1385    fn flush_write_buf(&mut self, n: usize) -> PgResult<()> {
1386        if self.nonblocking {
1387            let timeout = self.io_timeout;
1388            let start = Instant::now();
1389            let mut written = 0;
1390            while written < n {
1391                match self.stream.write(&self.write_buf[written..n]) {
1392                    Ok(w) => written += w,
1393                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1394                        let elapsed = start.elapsed();
1395                        if elapsed >= timeout {
1396                            return Err(PgError::Timeout);
1397                        }
1398                        #[cfg(unix)]
1399                        self.wait_writable(timeout - elapsed)?;
1400                        #[cfg(not(unix))]
1401                        std::thread::sleep(Duration::from_micros(50));
1402                    }
1403                    Err(e) => {
1404                        self.broken = true;
1405                        return Err(PgError::Io(e));
1406                    }
1407                }
1408            }
1409            Ok(())
1410        } else {
1411            self.stream
1412                .write_all(&self.write_buf[..n])
1413                .map_err(PgError::Io)
1414        }
1415    }
1416
1417    /// Ensure there is room in read_buf for at least one read call.
1418    fn ensure_read_space(&mut self) {
1419        if self.read_pos == self.read_buf.len() {
1420            if self.read_pos >= 5
1421                && let Some(header) = codec::decode_header(&self.read_buf)
1422            {
1423                let total = 1 + header.length as usize;
1424                self.ensure_read_capacity(total - self.read_pos);
1425                return;
1426            }
1427            self.ensure_read_capacity(8192);
1428        }
1429    }
1430
1431    fn consume_read(&mut self, n: usize) {
1432        self.read_buf.copy_within(n..self.read_pos, 0);
1433        self.read_pos -= n;
1434    }
1435
1436    fn ensure_read_capacity(&mut self, additional: usize) {
1437        if self.read_pos + additional > self.read_buf.len() {
1438            let new_len = (self.read_pos + additional).max(self.read_buf.len() * 2);
1439            self.read_buf.resize(new_len, 0);
1440        }
1441    }
1442
1443    fn ensure_write_capacity(&mut self, additional: usize) {
1444        if additional > self.write_buf.len() {
1445            let new_len = additional.max(self.write_buf.len() * 2);
1446            self.write_buf.resize(new_len, 0);
1447        }
1448    }
1449
    /// Read the backend's reply to a Simple Query, up to and including
    /// `ReadyForQuery`, and collect every `DataRow`.
    ///
    /// Each `RowDescription` replaces the column metadata used for the rows
    /// that follow it. For a multi-statement query, the returned `Vec`
    /// therefore holds the rows of all statements, each `Row` sharing the
    /// columns of the statement that produced it.
    ///
    /// Side effects per message:
    /// - `CommandComplete` updates `last_command_tag` and
    ///   `last_affected_rows` (the last statement wins);
    /// - `ReadyForQuery` updates `tx_status`;
    /// - notifications are buffered;
    /// - notices are passed to `dispatch_notice`.
    ///
    /// # Errors
    ///
    /// On `ErrorResponse`, the error is parsed and the stream is drained via
    /// `drain_to_ready`, then the server error is returned. If draining fails,
    /// that error is returned instead. I/O failures and undecodable headers
    /// are propagated.
    fn read_query_results(&mut self) -> PgResult<Vec<Row>> {
        let mut rows = Vec::new();
        // Placeholder until the first RowDescription arrives.
        let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = Rc::new(Vec::new());

        loop {
            // Only touch the socket when no complete message is buffered.
            if codec::message_complete(&self.read_buf[..self.read_pos]).is_none() {
                self.fill_read_buf(None)?;
            }

            // Handle every complete message currently in the buffer.
            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
                let header = codec::decode_header(&self.read_buf)
                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
                // Payload after the 1-byte tag and 4-byte length.
                let body = &self.read_buf[5..msg_len];

                match header.tag {
                    BackendTag::RowDescription => {
                        columns_rc = Rc::new(codec::parse_row_description(body));
                    }
                    BackendTag::DataRow => {
                        let raw_values = codec::parse_data_row(body);
                        rows.push(Row::new(Rc::clone(&columns_rc), raw_values));
                    }
                    BackendTag::CommandComplete => {
                        let (tag, rows_affected) = extract_command_complete(body);
                        self.last_command_tag = tag;
                        self.last_affected_rows = rows_affected;
                    }
                    BackendTag::ReadyForQuery => {
                        // First body byte is the transaction status indicator.
                        self.tx_status = TransactionStatus::from(body[0]);
                        self.consume_read(msg_len);
                        return Ok(rows);
                    }
                    BackendTag::ErrorResponse => {
                        let err = self.parse_error(body);
                        self.consume_read(msg_len);
                        // Drain to ReadyForQuery
                        self.drain_to_ready()?;
                        return Err(err);
                    }
                    BackendTag::NotificationResponse => {
                        let notification = Self::parse_notification(body);
                        self.notifications.push_back(notification);
                    }
                    BackendTag::EmptyQueryResponse => {}
                    BackendTag::NoticeResponse => {
                        self.dispatch_notice(body);
                    }
                    // Anything else (e.g. ParameterStatus) is skipped here.
                    _ => {}
                }
                self.consume_read(msg_len);
            }
        }
    }
1503
    /// Read the reply to one Extended Query round trip (Parse/Describe when
    /// new, then Bind/Execute/Sync) up to `ReadyForQuery`, collecting every
    /// `DataRow`.
    ///
    /// # Parameters
    /// - `sql`: the query text, used as the statement-cache key and for
    ///   error context.
    /// - `stmt_name`: the server-side prepared statement name.
    /// - `is_new`: the statement was just parsed. The `RowDescription` or
    ///   `NoData` reply to its Describe is inserted into the statement cache.
    ///   A statement evicted by that insert is handed to
    ///   `close_statement_on_server`, presumably to close it on the server
    ///   (defined elsewhere).
    /// - `cached_columns`: column metadata from the cache for an
    ///   already-prepared statement. It is used as-is because no Describe was
    ///   sent.
    ///
    /// `CommandComplete` updates `last_command_tag` / `last_affected_rows`,
    /// and `ReadyForQuery` updates `tx_status`.
    ///
    /// # Errors
    ///
    /// On `ErrorResponse`, the error is built with `parse_error_with_context`
    /// and the stream is drained via `drain_to_ready`. Then the server error
    /// is returned, or the drain error if draining fails.
    fn read_extended_results(
        &mut self,
        sql: &str,
        stmt_name: &str,
        is_new: bool,
        cached_columns: Option<Vec<codec::ColumnDesc>>,
    ) -> PgResult<Vec<Row>> {
        let mut rows = Vec::new();
        let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = match cached_columns {
            Some(c) => Rc::new(c),
            None => Rc::new(Vec::new()),
        };

        loop {
            // Only touch the socket when no complete message is buffered.
            if codec::message_complete(&self.read_buf[..self.read_pos]).is_none() {
                self.fill_read_buf(None)?;
            }

            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
                let header = codec::decode_header(&self.read_buf)
                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
                // Payload after the 1-byte tag and 4-byte length.
                let body = &self.read_buf[5..msg_len];

                match header.tag {
                    BackendTag::ParseComplete => {}
                    BackendTag::ParameterDescription => {}
                    BackendTag::RowDescription => {
                        let mut columns = codec::parse_row_description(body);
                        // The driver always requests binary-format results in the
                        // Bind message (&[1]).  RowDescription from a Describe
                        // *Statement* always has format_code = Text (0x0) because
                        // no Bind has occurred yet.  Override every column to
                        // Binary so DataRow bytes are decoded correctly.
                        for col in &mut columns {
                            col.format_code = FormatCode::Binary;
                        }
                        // Cache the (binary-adjusted) columns for later executions.
                        if is_new
                            && let Some(evicted) = self.stmt_cache.insert(
                                sql,
                                stmt_name.to_string(),
                                0,
                                Some(columns.clone()),
                            )
                        {
                            self.close_statement_on_server(&evicted.name);
                        }
                        columns_rc = Rc::new(columns);
                    }
                    // A new statement that returns no rows is cached without columns.
                    BackendTag::NoData if is_new => {
                        if let Some(evicted) =
                            self.stmt_cache.insert(sql, stmt_name.to_string(), 0, None)
                        {
                            self.close_statement_on_server(&evicted.name);
                        }
                    }
                    BackendTag::NoData => {}
                    BackendTag::BindComplete => {}
                    BackendTag::DataRow => {
                        let raw_values = codec::parse_data_row(body);
                        rows.push(Row::new(Rc::clone(&columns_rc), raw_values));
                    }
                    BackendTag::CommandComplete => {
                        let (tag, rows_affected) = extract_command_complete(body);
                        self.last_command_tag = tag;
                        self.last_affected_rows = rows_affected;
                    }
                    BackendTag::ReadyForQuery => {
                        // First body byte is the transaction status indicator.
                        self.tx_status = TransactionStatus::from(body[0]);
                        self.consume_read(msg_len);
                        return Ok(rows);
                    }
                    BackendTag::ErrorResponse => {
                        let err = self.parse_error_with_context(body, sql);
                        self.consume_read(msg_len);
                        self.drain_to_ready()?;
                        return Err(err);
                    }
                    BackendTag::NotificationResponse => {
                        let notification = Self::parse_notification(body);
                        self.notifications.push_back(notification);
                    }
                    BackendTag::NoticeResponse => {
                        self.dispatch_notice(body);
                    }
                    // Anything else (e.g. CloseComplete, ParameterStatus) is skipped.
                    _ => {}
                }
                self.consume_read(msg_len);
            }
        }
    }
1594
1595    /// Optimised read path for `query_one`: returns the first `DataRow`
1596    /// directly without collecting into a `Vec`. Remaining rows and
1597    /// protocol messages are drained so the connection stays clean.
1598    fn read_extended_result_one(
1599        &mut self,
1600        sql: &str,
1601        stmt_name: &str,
1602        is_new: bool,
1603        cached_columns: Option<Vec<codec::ColumnDesc>>,
1604    ) -> PgResult<Row> {
1605        match self.read_extended_result_opt(sql, stmt_name, is_new, cached_columns)? {
1606            Some(row) => Ok(row),
1607            None => Err(PgError::NoRows),
1608        }
1609    }
1610
1611    /// Optimised read path for `query_opt`: returns the first `DataRow`
1612    /// as `Some(Row)`, or `None` if the query returns zero rows.
1613    /// Remaining rows are drained.
1614    fn read_extended_result_opt(
1615        &mut self,
1616        sql: &str,
1617        stmt_name: &str,
1618        is_new: bool,
1619        cached_columns: Option<Vec<codec::ColumnDesc>>,
1620    ) -> PgResult<Option<Row>> {
1621        let mut result: Option<Row> = None;
1622        let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = match cached_columns {
1623            Some(c) => Rc::new(c),
1624            None => Rc::new(Vec::new()),
1625        };
1626
1627        loop {
1628            if codec::message_complete(&self.read_buf[..self.read_pos]).is_none() {
1629                self.fill_read_buf(None)?;
1630            }
1631
1632            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
1633                let header = codec::decode_header(&self.read_buf)
1634                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1635                let body = &self.read_buf[5..msg_len];
1636
1637                match header.tag {
1638                    BackendTag::ParseComplete => {}
1639                    BackendTag::ParameterDescription => {}
1640                    BackendTag::RowDescription => {
1641                        let mut columns = codec::parse_row_description(body);
1642                        for col in &mut columns {
1643                            col.format_code = FormatCode::Binary;
1644                        }
1645                        if is_new
1646                            && let Some(evicted) = self.stmt_cache.insert(
1647                                sql,
1648                                stmt_name.to_string(),
1649                                0,
1650                                Some(columns.clone()),
1651                            )
1652                        {
1653                            self.close_statement_on_server(&evicted.name);
1654                        }
1655                        columns_rc = Rc::new(columns);
1656                    }
1657                    BackendTag::NoData if is_new => {
1658                        if let Some(evicted) =
1659                            self.stmt_cache.insert(sql, stmt_name.to_string(), 0, None)
1660                        {
1661                            self.close_statement_on_server(&evicted.name);
1662                        }
1663                    }
1664                    BackendTag::NoData => {}
1665                    BackendTag::BindComplete => {}
1666                    BackendTag::DataRow
1667                        // Only capture the first row; subsequent DataRows are skipped.
1668                        if result.is_none() => {
1669                            let raw_values = codec::parse_data_row(body);
1670                            result = Some(Row::new(Rc::clone(&columns_rc), raw_values));
1671                    }
1672                    BackendTag::DataRow => {
1673                        // Subsequent DataRows are skipped (not allocated).
1674                    }
1675                    BackendTag::CommandComplete => {
1676                        let (tag, rows_affected) = extract_command_complete(body);
1677                        self.last_command_tag = tag;
1678                        self.last_affected_rows = rows_affected;
1679                    }
1680                    BackendTag::ReadyForQuery => {
1681                        self.tx_status = TransactionStatus::from(body[0]);
1682                        self.consume_read(msg_len);
1683                        return Ok(result);
1684                    }
1685                    BackendTag::ErrorResponse => {
1686                        let err = self.parse_error_with_context(body, sql);
1687                        self.consume_read(msg_len);
1688                        self.drain_to_ready()?;
1689                        return Err(err);
1690                    }
1691                    BackendTag::NotificationResponse => {
1692                        let notification = Self::parse_notification(body);
1693                        self.notifications.push_back(notification);
1694                    }
1695                    BackendTag::NoticeResponse => {
1696                        self.dispatch_notice(body);
1697                    }
1698                    _ => {}
1699                }
1700                self.consume_read(msg_len);
1701            }
1702        }
1703    }
1704
1705    fn drain_to_ready(&mut self) -> PgResult<()> {
1706        loop {
1707            // Only read from the socket when the buffer has no complete message;
1708            // data may already be buffered from an earlier fill_read_buf call.
1709            if codec::message_complete(&self.read_buf[..self.read_pos]).is_none() {
1710                self.fill_read_buf(None)?;
1711            }
1712            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos]) {
1713                let header = codec::decode_header(&self.read_buf)
1714                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1715                if header.tag == BackendTag::ReadyForQuery {
1716                    let body = &self.read_buf[5..msg_len];
1717                    self.tx_status = TransactionStatus::from(body[0]);
1718                    self.consume_read(msg_len);
1719                    return Ok(());
1720                }
1721                self.consume_read(msg_len);
1722            }
1723        }
1724    }
1725
1726    fn parse_error(&self, body: &[u8]) -> PgError {
1727        let fields = codec::parse_error_fields(body);
1728        PgError::from_fields(&fields)
1729    }
1730
1731    /// Parse an error and attach query context for better debugging.
1732    fn parse_error_with_context(&self, body: &[u8], query: &str) -> PgError {
1733        let fields = codec::parse_error_fields(body);
1734        let mut err = PgError::from_fields(&fields);
1735        if let PgError::Server(ref mut server_err) = err
1736            && server_err.internal_query.is_none()
1737        {
1738            server_err.internal_query = Some(query.to_string());
1739        }
1740        err
1741    }
1742
1743    /// Dispatch a NoticeResponse to the registered handler.
1744    fn dispatch_notice(&self, body: &[u8]) {
1745        if let Some(ref handler) = self.notice_handler {
1746            let fields = codec::parse_error_fields(body);
1747            let mut severity = "";
1748            let mut code = "";
1749            let mut message = "";
1750            for (field_type, value) in &fields {
1751                match field_type {
1752                    b'S' => severity = value,
1753                    b'C' => code = value,
1754                    b'M' => message = value,
1755                    _ => {}
1756                }
1757            }
1758            handler(severity, code, message);
1759        }
1760    }
1761
1762    /// Send a Close('S') message to deallocate a server-side prepared statement.
1763    /// This is fire-and-forget — we don't wait for CloseComplete.
1764    fn close_statement_on_server(&mut self, name: &str) {
1765        self.ensure_write_capacity(7 + name.len());
1766        let n = codec::encode_close(&mut self.write_buf, CloseTarget::Statement, name);
1767        let _ = self.flush_write_buf(n);
1768    }
1769
    // CommandComplete parsing (tags such as "INSERT 0 5", "UPDATE 3", "DELETE 1",
    // "SELECT 10") lives in the free function `extract_command_complete`, which
    // takes a plain byte slice so callers can pass a slice of `read_buf` without
    // borrowing `self`.
1773
1774    /// Parse a NotificationResponse message body.
1775    fn parse_notification(body: &[u8]) -> Notification {
1776        let process_id = codec::read_i32(body, 0);
1777        let (channel, consumed) = codec::read_cstring(body, 4);
1778        let (payload, _) = codec::read_cstring(body, 4 + consumed);
1779        Notification {
1780            process_id,
1781            channel: channel.to_string(),
1782            payload: payload.to_string(),
1783        }
1784    }
1785}
1786
1787/// Extract the command tag and affected row count from a CommandComplete body.
1788/// This is a free function (not a method) to avoid borrow conflicts when
1789/// `body` is a slice of the connection's read buffer.
1790fn extract_command_complete(body: &[u8]) -> (String, u64) {
1791    let (tag, _) = codec::read_cstring(body, 0);
1792    let tag_str = tag.to_string();
1793    let affected_rows = tag
1794        .rsplit(' ')
1795        .next()
1796        .and_then(|s| s.parse::<u64>().ok())
1797        .unwrap_or(0);
1798    (tag_str, affected_rows)
1799}
1800
1801impl Drop for PgConnection {
1802    fn drop(&mut self) {
1803        // Switch to blocking mode so the Terminate message is reliably sent.
1804        // On a non-blocking socket write_all may fail with WouldBlock,
1805        // silently leaving the server-side session open.
1806        if self.nonblocking {
1807            let _ = self.stream.set_nonblocking(false);
1808        }
1809        let n = codec::encode_terminate(&mut self.write_buf);
1810        let _ = self.stream.write_all(&self.write_buf[..n]);
1811    }
1812}
1813
1814// ─── Transaction ──────────────────────────────────────────────
1815
1816/// A transaction guard. Ensures the transaction is committed or rolled back.
1817///
1818/// Created via `PgConnection::transaction()`. Provides the same query
1819/// methods as `PgConnection`. On drop, if neither `commit` nor `rollback`
1820/// was called, automatically rolls back.
1821pub struct Transaction<'a> {
1822    conn: &'a mut PgConnection,
1823    finished: bool,
1824    /// If Some, this is a nested transaction backed by a SAVEPOINT.
1825    savepoint_name: Option<String>,
1826    /// Counter for generating unique savepoint names in nested calls.
1827    savepoint_counter: u32,
1828}
1829
1830impl<'a> Transaction<'a> {
1831    /// Commit this transaction (or release savepoint if nested).
1832    pub fn commit(&mut self) -> PgResult<()> {
1833        if !self.finished {
1834            self.finished = true;
1835            if let Some(ref name) = self.savepoint_name {
1836                self.conn.release_savepoint(name)
1837            } else {
1838                self.conn.commit()
1839            }
1840        } else {
1841            Ok(())
1842        }
1843    }
1844
1845    /// Rollback this transaction (or rollback to savepoint if nested).
1846    pub fn rollback(&mut self) -> PgResult<()> {
1847        if !self.finished {
1848            self.finished = true;
1849            if let Some(ref name) = self.savepoint_name {
1850                self.conn.rollback_to(name)
1851            } else {
1852                self.conn.rollback()
1853            }
1854        } else {
1855            Ok(())
1856        }
1857    }
1858
1859    /// Execute a nested transaction using a SAVEPOINT.
1860    ///
1861    /// Creates a savepoint, calls the closure, and either releases
1862    /// (on success) or rolls back to the savepoint (on error/drop).
1863    ///
1864    /// Nesting is unlimited — each level creates a new savepoint.
1865    ///
1866    /// # Example
1867    /// ```ignore
1868    /// conn.transaction(|tx| {
1869    ///     tx.execute("INSERT INTO users (name) VALUES ($1)", &[&"Alice"])?;
1870    ///     tx.transaction(|nested| {
1871    ///         nested.execute("INSERT INTO logs (msg) VALUES ($1)", &[&"nested"])?;
1872    ///         Ok(())
1873    ///     })?;
1874    ///     Ok(())
1875    /// })?;
1876    /// ```
1877    pub fn transaction<F, T>(&mut self, f: F) -> PgResult<T>
1878    where
1879        F: FnOnce(&mut Transaction<'_>) -> PgResult<T>,
1880    {
1881        self.savepoint_counter += 1;
1882        let sp_name = format!("chopin_sp_{}", self.savepoint_counter);
1883        self.conn.savepoint(&sp_name)?;
1884        let mut nested = Transaction {
1885            conn: self.conn,
1886            finished: false,
1887            savepoint_name: Some(sp_name),
1888            savepoint_counter: 0,
1889        };
1890        match f(&mut nested) {
1891            Ok(val) => {
1892                nested.commit()?;
1893                Ok(val)
1894            }
1895            Err(e) => {
1896                let _ = nested.rollback();
1897                Err(e)
1898            }
1899        }
1900    }
1901
1902    /// Execute a simple query (no parameters).
1903    pub fn query_simple(&mut self, sql: &str) -> PgResult<Vec<Row>> {
1904        self.conn.query_simple(sql)
1905    }
1906
1907    /// Execute a parameterized query.
1908    pub fn query(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Vec<Row>> {
1909        self.conn.query(sql, params)
1910    }
1911
1912    /// Execute a query expecting exactly one row.
1913    pub fn query_one(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Row> {
1914        self.conn.query_one(sql, params)
1915    }
1916
1917    /// Execute a query expecting zero or one row.
1918    pub fn query_opt(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Option<Row>> {
1919        self.conn.query_opt(sql, params)
1920    }
1921
1922    /// Execute a statement that returns no rows.
1923    pub fn execute(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<u64> {
1924        self.conn.execute(sql, params)
1925    }
1926
1927    /// Create a savepoint within this transaction.
1928    pub fn savepoint(&mut self, name: &str) -> PgResult<()> {
1929        self.conn.savepoint(name)
1930    }
1931
1932    /// Rollback to a savepoint.
1933    pub fn rollback_to(&mut self, name: &str) -> PgResult<()> {
1934        self.conn.rollback_to(name)
1935    }
1936
1937    /// Release a savepoint.
1938    pub fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
1939        self.conn.release_savepoint(name)
1940    }
1941
1942    /// Get the transaction status.
1943    pub fn status(&self) -> TransactionStatus {
1944        self.conn.transaction_status()
1945    }
1946}
1947
1948impl<'a> Drop for Transaction<'a> {
1949    fn drop(&mut self) {
1950        if !self.finished {
1951            // Auto-rollback on drop (savepoint if nested, full rollback otherwise)
1952            if let Some(ref name) = self.savepoint_name {
1953                let _ = self.conn.rollback_to(name);
1954            } else {
1955                let _ = self.conn.rollback();
1956            }
1957        }
1958    }
1959}
1960
1961// ─── COPY Writer ──────────────────────────────────────────────
1962
1963/// COPY writer for streaming data into PostgreSQL via COPY FROM STDIN.
1964pub struct CopyWriter<'a> {
1965    conn: &'a mut PgConnection,
1966}
1967
1968impl<'a> CopyWriter<'a> {
1969    /// Write a chunk of COPY data.
1970    pub fn write_data(&mut self, data: &[u8]) -> PgResult<()> {
1971        self.conn.ensure_write_capacity(5 + data.len());
1972        let n = codec::encode_copy_data(&mut self.conn.write_buf, data);
1973        self.conn.flush_write_buf(n)
1974    }
1975
1976    /// Abort the COPY operation with an error message.
1977    ///
1978    /// Sends a CopyFail message to the server. The server will respond
1979    /// with an ErrorResponse and then ReadyForQuery. The connection
1980    /// remains usable after this call.
1981    pub fn fail(self, reason: &str) -> PgResult<()> {
1982        self.conn.ensure_write_capacity(6 + reason.len());
1983        let n = codec::encode_copy_fail(&mut self.conn.write_buf, reason);
1984        self.conn.flush_write_buf(n)?;
1985
1986        // Drain to ReadyForQuery (server sends ErrorResponse first)
1987        loop {
1988            self.conn.fill_read_buf(None)?;
1989            while let Some(msg_len) =
1990                codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])
1991            {
1992                let header = codec::decode_header(&self.conn.read_buf)
1993                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1994                match header.tag {
1995                    BackendTag::ErrorResponse => {
1996                        // Expected — server acknowledges the CopyFail
1997                        self.conn.consume_read(msg_len);
1998                    }
1999                    BackendTag::ReadyForQuery => {
2000                        let body = &self.conn.read_buf[5..msg_len];
2001                        self.conn.tx_status = TransactionStatus::from(body[0]);
2002                        self.conn.consume_read(msg_len);
2003                        return Ok(());
2004                    }
2005                    _ => {
2006                        self.conn.consume_read(msg_len);
2007                    }
2008                }
2009            }
2010        }
2011    }
2012
2013    /// Write a text row (tab-separated values with newline).
2014    pub fn write_row(&mut self, columns: &[&str]) -> PgResult<()> {
2015        let line = columns.join("\t") + "\n";
2016        self.write_data(line.as_bytes())
2017    }
2018
2019    /// Finish the COPY operation successfully.
2020    pub fn finish(self) -> PgResult<u64> {
2021        let n = codec::encode_copy_done(&mut self.conn.write_buf);
2022        self.conn.flush_write_buf(n)?;
2023
2024        // Drain to ReadyForQuery
2025        loop {
2026            self.conn.fill_read_buf(None)?;
2027            while let Some(msg_len) =
2028                codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])
2029            {
2030                let header = codec::decode_header(&self.conn.read_buf)
2031                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2032                let body = &self.conn.read_buf[5..msg_len];
2033                match header.tag {
2034                    BackendTag::CommandComplete => {
2035                        let (tag, rows_affected) = extract_command_complete(body);
2036                        self.conn.last_command_tag = tag;
2037                        self.conn.last_affected_rows = rows_affected;
2038                    }
2039                    BackendTag::ReadyForQuery => {
2040                        self.conn.tx_status = TransactionStatus::from(body[0]);
2041                        self.conn.consume_read(msg_len);
2042                        return Ok(self.conn.last_affected_rows);
2043                    }
2044                    BackendTag::ErrorResponse => {
2045                        let err = self.conn.parse_error(body);
2046                        self.conn.consume_read(msg_len);
2047                        return Err(err);
2048                    }
2049                    _ => {}
2050                }
2051                self.conn.consume_read(msg_len);
2052            }
2053        }
2054    }
2055}
2056
2057// ─── COPY Reader ──────────────────────────────────────────────
2058
2059/// COPY reader for receiving data from PostgreSQL via COPY TO STDOUT.
2060pub struct CopyReader<'a> {
2061    conn: &'a mut PgConnection,
2062    done: bool,
2063}
2064
2065impl<'a> CopyReader<'a> {
2066    /// Read the next chunk of COPY data.
2067    /// Returns None when the COPY operation is complete.
2068    pub fn read_data(&mut self) -> PgResult<Option<Vec<u8>>> {
2069        if self.done {
2070            return Ok(None);
2071        }
2072
2073        loop {
2074            // Only refill the buffer when it doesn't already hold a complete
2075            // message.  If a previous fill read the entire COPY response in
2076            // one TCP segment, the CopyData / CopyDone / ReadyForQuery bytes
2077            // are already in `read_buf` and we must not block waiting for
2078            // more socket data before processing them.
2079            if codec::message_complete(&self.conn.read_buf[..self.conn.read_pos]).is_none() {
2080                self.conn.fill_read_buf(None)?;
2081            }
2082
2083            while let Some(msg_len) =
2084                codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])
2085            {
2086                let header = codec::decode_header(&self.conn.read_buf)
2087                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2088                let body = &self.conn.read_buf[5..msg_len];
2089
2090                match header.tag {
2091                    BackendTag::CopyData => {
2092                        let data = body.to_vec();
2093                        self.conn.consume_read(msg_len);
2094                        return Ok(Some(data));
2095                    }
2096                    BackendTag::CopyDone => {
2097                        self.conn.consume_read(msg_len);
2098                        // Continue to receive CommandComplete + ReadyForQuery
2099                    }
2100                    BackendTag::CommandComplete => {
2101                        let (tag, rows_affected) = extract_command_complete(body);
2102                        self.conn.last_command_tag = tag;
2103                        self.conn.last_affected_rows = rows_affected;
2104                        self.conn.consume_read(msg_len);
2105                    }
2106                    BackendTag::ReadyForQuery => {
2107                        self.conn.tx_status = TransactionStatus::from(body[0]);
2108                        self.conn.consume_read(msg_len);
2109                        self.done = true;
2110                        return Ok(None);
2111                    }
2112                    BackendTag::ErrorResponse => {
2113                        let err = self.conn.parse_error(body);
2114                        self.conn.consume_read(msg_len);
2115                        self.done = true;
2116                        return Err(err);
2117                    }
2118                    _ => {
2119                        self.conn.consume_read(msg_len);
2120                    }
2121                }
2122            }
2123        }
2124    }
2125
2126    /// Read all remaining COPY data into a single Vec.
2127    pub fn read_all(&mut self) -> PgResult<Vec<u8>> {
2128        let mut result = Vec::new();
2129        while let Some(chunk) = self.read_data()? {
2130            result.extend_from_slice(&chunk);
2131        }
2132        Ok(result)
2133    }
2134
2135    /// Check if the COPY operation is complete.
2136    pub fn is_done(&self) -> bool {
2137        self.done
2138    }
2139}
2140
#[cfg(test)]
mod tests {
    // Unit tests for `PgConfig` construction and URL parsing, and for the
    // `Notification` value type. None of them open a database connection.
    use super::*;

    // ─── PgConfig::new ────────────────────────────────────────────────────────

    #[test]
    fn test_pgconfig_new_fields() {
        let cfg = PgConfig::new("db.example.com", 5432, "alice", "s3cret", "mydb");
        assert_eq!(cfg.host, "db.example.com");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.user, "alice");
        assert_eq!(cfg.password, "s3cret");
        assert_eq!(cfg.database, "mydb");
        assert!(cfg.socket_dir.is_none());
    }

    #[test]
    fn test_pgconfig_new_custom_port() {
        let cfg = PgConfig::new("host", 9999, "u", "p", "d");
        assert_eq!(cfg.port, 9999);
    }

    #[test]
    fn test_pgconfig_with_socket_dir_sets_field() {
        let cfg =
            PgConfig::new("localhost", 5432, "u", "p", "d").with_socket_dir("/var/run/postgresql");
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
    }

    #[test]
    fn test_pgconfig_clone_preserves_all_fields() {
        let cfg = PgConfig::new("h", 1234, "u", "p", "db").with_socket_dir("/tmp");
        let cloned = cfg.clone();
        assert_eq!(cloned.host, "h");
        assert_eq!(cloned.port, 1234);
        assert_eq!(cloned.user, "u");
        assert_eq!(cloned.password, "p");
        assert_eq!(cloned.database, "db");
        assert_eq!(cloned.socket_dir, Some("/tmp".to_string()));
    }

    #[test]
    fn test_pgconfig_debug_contains_host() {
        let cfg = PgConfig::new("myhost", 5432, "u", "p", "d");
        let s = format!("{:?}", cfg);
        assert!(s.contains("myhost"), "Debug must include host: {}", s);
    }

    // ─── PgConfig::from_url — happy paths ────────────────────────────────────

    #[test]
    fn test_from_url_basic_postgres_scheme() {
        let cfg = PgConfig::from_url("postgres://bob:hunter2@dbhost:5432/appdb").unwrap();
        assert_eq!(cfg.host, "dbhost");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.user, "bob");
        assert_eq!(cfg.password, "hunter2");
        assert_eq!(cfg.database, "appdb");
        assert!(cfg.socket_dir.is_none());
    }

    #[test]
    fn test_from_url_postgresql_scheme() {
        let cfg = PgConfig::from_url("postgresql://u:p@host:5432/db").unwrap();
        assert_eq!(cfg.host, "host");
        assert_eq!(cfg.user, "u");
    }

    #[test]
    fn test_from_url_default_port() {
        // When no port is given, should default to 5432
        let cfg = PgConfig::from_url("postgres://u:p@myhost/mydb").unwrap();
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.host, "myhost");
    }

    #[test]
    fn test_from_url_no_password() {
        // user only, no colon → password is empty string
        let cfg = PgConfig::from_url("postgres://alice@host:5432/db").unwrap();
        assert_eq!(cfg.user, "alice");
        assert_eq!(cfg.password, "");
    }

    #[test]
    fn test_from_url_custom_port() {
        let cfg = PgConfig::from_url("postgres://u:p@host:9000/db").unwrap();
        assert_eq!(cfg.port, 9000);
    }

    #[test]
    fn test_from_url_unix_socket_query_param() {
        // Empty host plus a `host=` query parameter selects a Unix socket directory.
        let cfg = PgConfig::from_url("postgres://u:p@/db?host=/var/run/postgresql").unwrap();
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
        assert_eq!(cfg.database, "db");
    }

    #[test]
    fn test_from_url_unix_socket_percent_encoded() {
        // A percent-encoded path in the host position (`%2F` = `/`) is also
        // treated as a Unix socket directory.
        let cfg = PgConfig::from_url("postgres://u:p@%2Fvar%2Frun%2Fpostgresql/db").unwrap();
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
        assert_eq!(cfg.database, "db");
    }

    // ─── PgConfig::from_url — error paths ────────────────────────────────────

    #[test]
    fn test_from_url_invalid_scheme_errors() {
        let result = PgConfig::from_url("mysql://u:p@host/db");
        assert!(result.is_err(), "Non-postgres scheme must fail");
    }

    #[test]
    fn test_from_url_missing_at_symbol_errors() {
        let result = PgConfig::from_url("postgres://no-at-sign/db");
        assert!(result.is_err(), "URL without @ must fail");
    }

    #[test]
    fn test_from_url_missing_database_errors() {
        // No "/" after host — no database segment
        let result = PgConfig::from_url("postgres://u:p@host");
        assert!(result.is_err(), "URL without database must fail");
    }

    #[test]
    fn test_from_url_invalid_port_errors() {
        let result = PgConfig::from_url("postgres://u:p@host:notaport/db");
        assert!(result.is_err(), "Non-numeric port must fail");
    }

    #[test]
    fn test_from_url_empty_string_errors() {
        let result = PgConfig::from_url("");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_url_special_chars_in_password() {
        // A literal `@` in a password must be percent-encoded (`%40`), so this
        // URL contains exactly one raw `@`, separating credentials from the host.
        let cfg = PgConfig::from_url("postgres://user:p%40ss@host:5432/db");
        // Whether the parser percent-decodes the password is not asserted here;
        // this test only checks that parsing does not panic.
        let _ = cfg; // result can be Ok or Err, but must not panic
    }

    // ─── Notification struct ──────────────────────────────────────────────────

    #[test]
    fn test_notification_fields() {
        let n = Notification {
            process_id: 12345,
            channel: "my_channel".to_string(),
            payload: "hello world".to_string(),
        };
        assert_eq!(n.process_id, 12345);
        assert_eq!(n.channel, "my_channel");
        assert_eq!(n.payload, "hello world");
    }

    #[test]
    fn test_notification_clone() {
        let n = Notification {
            process_id: 42,
            channel: "ch".to_string(),
            payload: "pay".to_string(),
        };
        let n2 = n.clone();
        assert_eq!(n2.process_id, n.process_id);
        assert_eq!(n2.channel, n.channel);
        assert_eq!(n2.payload, n.payload);
    }

    #[test]
    fn test_notification_debug() {
        let n = Notification {
            process_id: 1,
            channel: "c".to_string(),
            payload: "p".to_string(),
        };
        let s = format!("{:?}", n);
        assert!(
            s.contains("process_id"),
            "Debug must include process_id: {}",
            s
        );
    }
}