Skip to main content

chopin_pg/
connection.rs

1//! PostgreSQL connection with poll-based non-blocking I/O.
2//!
3//! Designed for integration with a thread-per-core event loop. The socket is set
4//! to **non-blocking mode** at connect time; all reads and writes go through
5//! `try_fill_read_buf` / `try_flush_write_buf` which return `WouldBlock` when
6//! the socket is not ready. Higher-level methods use `poll_read` / `poll_write`
7//! with a configurable timeout so the caller can integrate with epoll/kqueue.
8//!
9//! Features:
10//! - **Non-blocking I/O** with application-level timeouts
11//! - SCRAM-SHA-256 and cleartext authentication
12//! - Extended Query Protocol with implicit statement caching
13//! - Transaction support with safe closure-based API
14//! - COPY IN (writer) and COPY OUT (reader)
15//! - LISTEN/NOTIFY with notification buffering
16//! - Proper affected row count from CommandComplete
17//! - Raw socket fd accessor for event-loop registration
18
19use std::collections::VecDeque;
20use std::io::{Read, Write};
21use std::net::TcpStream;
22#[cfg(unix)]
23use std::os::unix::io::AsRawFd;
24#[cfg(unix)]
25use std::os::unix::net::UnixStream;
26use std::rc::Rc;
27use std::time::{Duration, Instant};
28
29use crate::auth::ScramClient;
30use crate::codec;
31use crate::error::{PgError, PgResult};
32use crate::protocol::*;
33use crate::row::Row;
34use crate::statement::StatementCache;
35#[cfg(feature = "tls")]
36use crate::tls;
37use crate::types::{PgValue, ToSql};
38
/// Timeout used by poll-based I/O on a freshly opened connection: five
/// seconds. Override it per connection with `set_io_timeout` or
/// `connect_with_timeout`.
const DEFAULT_IO_TIMEOUT: Duration = Duration::from_millis(5_000);
41
42// ─── Stream Abstraction ──────────────────────────────────────
43
/// Unified stream type supporting TCP, Unix domain sockets, and TLS.
///
/// `PgConnection` performs all socket I/O through this enum; the `Read` and
/// `Write` impls simply forward to whichever transport is active. Each
/// variant is gated on the same `cfg` as its underlying type, so a non-unix
/// build without the `tls` feature only has `Tcp`.
enum PgStream {
    /// Plain TCP connection (also used when TLS is declined or fails in `Prefer` mode).
    Tcp(TcpStream),
    /// Unix domain socket, chosen when `PgConfig::socket_dir` is set.
    #[cfg(unix)]
    Unix(UnixStream),
    /// TLS-wrapped TCP connection (only with the `tls` feature).
    #[cfg(feature = "tls")]
    Tls(tls::TlsStream),
}
52
53impl Read for PgStream {
54    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
55        match self {
56            PgStream::Tcp(s) => s.read(buf),
57            #[cfg(unix)]
58            PgStream::Unix(s) => s.read(buf),
59            #[cfg(feature = "tls")]
60            PgStream::Tls(s) => s.read(buf),
61        }
62    }
63}
64
65impl Write for PgStream {
66    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
67        match self {
68            PgStream::Tcp(s) => s.write(buf),
69            #[cfg(unix)]
70            PgStream::Unix(s) => s.write(buf),
71            #[cfg(feature = "tls")]
72            PgStream::Tls(s) => s.write(buf),
73        }
74    }
75
76    fn flush(&mut self) -> std::io::Result<()> {
77        match self {
78            PgStream::Tcp(s) => s.flush(),
79            #[cfg(unix)]
80            PgStream::Unix(s) => s.flush(),
81            #[cfg(feature = "tls")]
82            PgStream::Tls(s) => s.flush(),
83        }
84    }
85
86    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
87        match self {
88            PgStream::Tcp(s) => s.write_all(buf),
89            #[cfg(unix)]
90            PgStream::Unix(s) => s.write_all(buf),
91            #[cfg(feature = "tls")]
92            PgStream::Tls(s) => s.write_all(buf),
93        }
94    }
95}
96
97impl PgStream {
98    fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
99        match self {
100            PgStream::Tcp(s) => s.set_nonblocking(nonblocking),
101            #[cfg(unix)]
102            PgStream::Unix(s) => s.set_nonblocking(nonblocking),
103            #[cfg(feature = "tls")]
104            PgStream::Tls(s) => s.set_nonblocking(nonblocking),
105        }
106    }
107
108    #[cfg(unix)]
109    fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
110        match self {
111            PgStream::Tcp(s) => s.as_raw_fd(),
112            PgStream::Unix(s) => s.as_raw_fd(),
113            #[cfg(feature = "tls")]
114            PgStream::Tls(s) => s.as_raw_fd(),
115        }
116    }
117
118    /// Get the SHA-256 hash of the server's TLS certificate for channel binding.
119    /// Returns `None` if the connection is not TLS or no cert is available.
120    #[cfg(feature = "tls")]
121    fn tls_server_cert_hash(&self) -> Option<Vec<u8>> {
122        match self {
123            PgStream::Tls(s) => s.server_cert_hash(),
124            _ => None,
125        }
126    }
127}
128
/// Connection configuration.
///
/// `Debug` is implemented by hand so the password never ends up in logs or
/// panic messages: it is always rendered as `"<redacted>"`.
#[derive(Clone)]
pub struct PgConfig {
    /// Host name or address used for TCP connections.
    pub host: String,
    /// Server port; also selects the Unix socket file `.s.PGSQL.<port>`.
    pub port: u16,
    /// Role name sent in the startup message.
    pub user: String,
    /// Password used for cleartext, MD5 or SCRAM authentication.
    pub password: String,
    /// Database name sent in the startup message.
    pub database: String,
    /// Optional Unix domain socket directory.
    /// When set, connect via `<socket_dir>/.s.PGSQL.<port>` instead of TCP.
    pub socket_dir: Option<String>,
    /// SSL/TLS mode. Only effective when the `tls` feature is enabled.
    /// Default: `Prefer` (try TLS, fall back to plaintext).
    #[cfg(feature = "tls")]
    pub ssl_mode: tls::SslMode,
}

impl std::fmt::Debug for PgConfig {
    /// Format every field except the password, which is replaced by a
    /// placeholder so credentials are not leaked through `{:?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("PgConfig");
        dbg.field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            .field("password", &"<redacted>")
            .field("database", &self.database)
            .field("socket_dir", &self.socket_dir);
        #[cfg(feature = "tls")]
        dbg.field("ssl_mode", &self.ssl_mode);
        dbg.finish()
    }
}
145
146impl PgConfig {
147    pub fn new(host: &str, port: u16, user: &str, password: &str, database: &str) -> Self {
148        Self {
149            host: host.to_string(),
150            port,
151            user: user.to_string(),
152            password: password.to_string(),
153            database: database.to_string(),
154            socket_dir: None,
155            #[cfg(feature = "tls")]
156            ssl_mode: tls::SslMode::default(),
157        }
158    }
159
160    /// Set a Unix domain socket directory for the connection.
161    /// The actual socket path will be `<dir>/.s.PGSQL.<port>`.
162    pub fn with_socket_dir(mut self, dir: &str) -> Self {
163        self.socket_dir = Some(dir.to_string());
164        self
165    }
166
167    /// Set the SSL/TLS mode for the connection.
168    #[cfg(feature = "tls")]
169    pub fn with_ssl_mode(mut self, mode: tls::SslMode) -> Self {
170        self.ssl_mode = mode;
171        self
172    }
173
174    /// Parse from a connection string: `postgres://user:pass@host:port/db`
175    ///
176    /// For Unix domain sockets, use a path as the host:
177    /// `postgres://user:pass@%2Fvar%2Frun%2Fpostgresql/db`  (URL-encoded slashes)
178    /// or `postgres://user:pass@/db?host=/var/run/postgresql`
179    pub fn from_url(url: &str) -> PgResult<Self> {
180        let url = url
181            .strip_prefix("postgres://")
182            .or_else(|| url.strip_prefix("postgresql://"))
183            .ok_or_else(|| PgError::Protocol("Invalid URL scheme".to_string()))?;
184
185        // user:pass@host:port/db
186        let (userpass, hostdb) = url
187            .split_once('@')
188            .ok_or_else(|| PgError::Protocol("Missing @ in URL".to_string()))?;
189        let (user, password) = userpass.split_once(':').unwrap_or((userpass, ""));
190
191        // Check for ?host= query parameter (Unix socket)
192        let (hostdb_part, query_part) = hostdb.split_once('?').unwrap_or((hostdb, ""));
193
194        let (hostport, database) = hostdb_part
195            .split_once('/')
196            .ok_or_else(|| PgError::Protocol("Missing database in URL".to_string()))?;
197
198        // Parse query params for socket dir and sslmode
199        let mut socket_dir: Option<String> = None;
200        #[cfg(feature = "tls")]
201        let mut ssl_mode = tls::SslMode::default();
202        if !query_part.is_empty() {
203            for param in query_part.split('&') {
204                if let Some(value) = param.strip_prefix("host=")
205                    && value.starts_with('/')
206                {
207                    socket_dir = Some(value.to_string());
208                }
209                #[cfg(feature = "tls")]
210                if let Some(value) = param.strip_prefix("sslmode=") {
211                    if let Some(mode) = tls::SslMode::parse(value) {
212                        ssl_mode = mode;
213                    }
214                }
215            }
216        }
217
218        // Decode percent-encoded host (e.g., %2Fvar%2Frun -> /var/run)
219        let decoded_host = percent_decode(hostport);
220        let is_unix_path = decoded_host.starts_with('/');
221        if is_unix_path {
222            socket_dir = Some(decoded_host);
223        }
224
225        let (host, port) = if socket_dir.is_some() {
226            // Unix socket — host is irrelevant, use default port
227            let port_str = if hostport.is_empty() || is_unix_path {
228                "5432"
229            } else {
230                hostport.rsplit_once(':').map(|(_, p)| p).unwrap_or("5432")
231            };
232            let port: u16 = port_str
233                .parse()
234                .map_err(|_| PgError::Protocol("Invalid port".to_string()))?;
235            ("localhost".to_string(), port)
236        } else {
237            let (h, port_str) = hostport.split_once(':').unwrap_or((hostport, "5432"));
238            let port: u16 = port_str
239                .parse()
240                .map_err(|_| PgError::Protocol("Invalid port".to_string()))?;
241            (h.to_string(), port)
242        };
243
244        Ok(Self {
245            host,
246            port,
247            user: user.to_string(),
248            password: password.to_string(),
249            database: database.to_string(),
250            socket_dir,
251            #[cfg(feature = "tls")]
252            ssl_mode,
253        })
254    }
255}
256
/// Minimal percent-decoding for URL components.
///
/// Each `%XX` escape (hex digits, either case) becomes one raw byte.
/// Malformed or truncated escapes (`%zz`, a trailing `%` or `%2`) are kept
/// literally. The resulting bytes are interpreted as UTF-8, so multi-byte
/// characters round-trip whether written literally (`é`) or encoded
/// (`%C3%A9`). Byte sequences that are not valid UTF-8 are replaced with
/// U+FFFD.
fn percent_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            if let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2])) {
                decoded.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        // Copy the byte as-is. Pushing bytes rather than `char`s keeps raw
        // UTF-8 sequences intact.
        decoded.push(bytes[i]);
        i += 1;
    }
    match String::from_utf8(decoded) {
        Ok(text) => text,
        Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
    }
}

/// Value of a single ASCII hex digit, or `None` if `b` is not one.
fn hex_digit(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
285
/// A notification received via LISTEN/NOTIFY.
///
/// Notifications that arrive while a query is being processed are buffered
/// in the connection's `notifications` queue rather than dropped.
#[derive(Debug, Clone)]
pub struct Notification {
    /// Process ID of the notifying backend.
    pub process_id: i32,
    /// Name of the channel the notification was sent on.
    pub channel: String,
    /// Payload string.
    pub payload: String,
}
296
/// Boxed callback for server `NoticeResponse` messages, as stored by
/// `PgConnection::set_notice_handler`.
///
/// It is a boxed closure, not a bare function pointer, so handlers may
/// capture state. The arguments are `(severity, code, message)`. The
/// `Send + Sync` bounds are part of the type, so every handler must
/// satisfy them.
type NoticeHandler = Box<dyn Fn(&str, &str, &str) + Send + Sync>;
299
/// A synchronous PostgreSQL connection with poll-based non-blocking I/O.
///
/// The socket is set to non-blocking mode at connect time.  I/O methods
/// internally poll with a configurable timeout so they can be adapted to
/// an event-loop's readiness notifications.
///
/// `connect` runs the startup/authentication handshake before switching
/// the socket to non-blocking mode.
pub struct PgConnection {
    /// Transport (TCP, Unix socket, or TLS) that all protocol I/O goes through.
    stream: PgStream,
    /// Receive buffer; only `read_buf[..read_pos]` holds received bytes.
    read_buf: Vec<u8>,
    /// Buffer that outgoing protocol messages are encoded into before sending.
    write_buf: Vec<u8>,
    /// Number of valid bytes at the start of `read_buf`.
    read_pos: usize,
    /// Transaction status taken from the most recently processed ReadyForQuery.
    tx_status: TransactionStatus,
    /// Prepared-statement cache consulted by the Extended Query methods.
    stmt_cache: StatementCache,
    /// Backend process ID from BackendKeyData (0 until received).
    process_id: i32,
    /// Backend secret key from BackendKeyData (0 until received).
    /// NOTE(review): presumably kept for cancel requests — confirm.
    secret_key: i32,
    /// `(name, value)` pairs from ParameterStatus messages, in arrival order.
    server_params: Vec<(String, String)>,
    /// Buffered notifications received during query processing.
    notifications: VecDeque<Notification>,
    /// Number of rows affected by the last command (from CommandComplete).
    last_affected_rows: u64,
    /// The last CommandComplete tag string.
    last_command_tag: String,
    /// Whether the socket is in non-blocking mode.
    nonblocking: bool,
    /// Application-level I/O timeout for poll operations.
    io_timeout: Duration,
    /// Optional callback invoked when the server sends a NoticeResponse.
    notice_handler: Option<NoticeHandler>,
    /// Flag set on fatal I/O errors. A broken connection must not be
    /// returned to the pool; it will be discarded on drop.
    broken: bool,
}
331
332impl PgConnection {
/// Connect to PostgreSQL (blocking during handshake, then switches to
/// non-blocking mode once authentication completes).
///
/// Transport selection:
/// - `config.socket_dir` set: Unix domain socket at
///   `<socket_dir>/.s.PGSQL.<port>` (an error on non-unix platforms).
/// - otherwise: TCP to `host:port` with Nagle disabled. When the `tls`
///   feature is enabled, it is upgraded according to `config.ssl_mode`.
///   `Prefer` falls back to plaintext if TLS is refused or negotiation fails.
///
/// The returned connection starts with `DEFAULT_IO_TIMEOUT` as its I/O timeout.
///
/// # Errors
/// - `PgError::Io` if a socket cannot be opened or switched to non-blocking mode.
/// - `PgError::Protocol` if Unix sockets are requested on a non-unix
///   platform, or if the server refuses TLS under `sslmode=require`.
/// - Errors from TLS negotiation in `Require` mode and from the
///   startup/authentication handshake.
///
/// NOTE(review): neither `TcpStream::connect` nor the handshake is given a
/// timeout here. The handshake reads with `fill_read_buf(None)` before the
/// socket is made non-blocking, so an unresponsive server may stall this
/// call. Confirm against `fill_read_buf`.
pub fn connect(config: &PgConfig) -> PgResult<Self> {
    let stream = if let Some(ref socket_dir) = config.socket_dir {
        // Unix domain socket connection
        #[cfg(unix)]
        {
            let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, config.port);
            let unix_stream = UnixStream::connect(&socket_path).map_err(PgError::Io)?;
            PgStream::Unix(unix_stream)
        }
        #[cfg(not(unix))]
        {
            let _ = socket_dir;
            return Err(PgError::Protocol(
                "Unix domain sockets are not supported on this platform".to_string(),
            ));
        }
    } else {
        let addr = format!("{}:{}", config.host, config.port);
        let tcp = TcpStream::connect(&addr).map_err(PgError::Io)?;
        // Disable Nagle's algorithm for lower latency
        let _ = tcp.set_nodelay(true);

        // TLS negotiation (when feature enabled)
        #[cfg(feature = "tls")]
        {
            match config.ssl_mode {
                tls::SslMode::Disable => PgStream::Tcp(tcp),
                tls::SslMode::Prefer => match tls::negotiate(tcp, &config.host) {
                    Ok(tls::TlsNegotiateResult::Tls(tls_stream)) => PgStream::Tls(tls_stream),
                    Ok(tls::TlsNegotiateResult::Rejected(tcp)) => PgStream::Tcp(tcp),
                    Err(_) => {
                        // TLS negotiation failed — reconnect plain-text.
                        // `tcp` was moved into `negotiate`, so a fresh socket is needed.
                        let tcp = TcpStream::connect(&addr).map_err(PgError::Io)?;
                        let _ = tcp.set_nodelay(true);
                        PgStream::Tcp(tcp)
                    }
                },
                tls::SslMode::Require => match tls::negotiate(tcp, &config.host)? {
                    tls::TlsNegotiateResult::Tls(tls_stream) => PgStream::Tls(tls_stream),
                    tls::TlsNegotiateResult::Rejected(_) => {
                        return Err(PgError::Protocol(
                            "Server does not support SSL (sslmode=require)".to_string(),
                        ));
                    }
                },
            }
        }

        #[cfg(not(feature = "tls"))]
        PgStream::Tcp(tcp)
    };

    let mut conn = Self {
        stream,
        read_buf: vec![0u8; 64 * 1024],  // 64 KB read buffer
        write_buf: vec![0u8; 64 * 1024], // 64 KB write buffer
        read_pos: 0,
        tx_status: TransactionStatus::Idle,
        stmt_cache: StatementCache::new(),
        process_id: 0,
        secret_key: 0,
        server_params: Vec::new(),
        notifications: VecDeque::new(),
        last_affected_rows: 0,
        last_command_tag: String::new(),
        // The socket is still in blocking mode while the handshake runs.
        nonblocking: false,
        io_timeout: DEFAULT_IO_TIMEOUT,
        notice_handler: None,
        broken: false,
    };

    conn.startup(config)?;

    // Switch to non-blocking after successful authentication
    conn.stream.set_nonblocking(true).map_err(PgError::Io)?;
    conn.nonblocking = true;

    Ok(conn)
}
414
415    /// Connect with a custom I/O timeout.
416    pub fn connect_with_timeout(config: &PgConfig, timeout: Duration) -> PgResult<Self> {
417        let mut conn = Self::connect(config)?;
418        conn.io_timeout = timeout;
419        Ok(conn)
420    }
421
422    /// Set the application-level I/O timeout.
423    pub fn set_io_timeout(&mut self, timeout: Duration) {
424        self.io_timeout = timeout;
425    }
426
427    /// Get the current I/O timeout.
428    pub fn io_timeout(&self) -> Duration {
429        self.io_timeout
430    }
431
432    /// Set a callback that is invoked when the server sends a NoticeResponse.
433    ///
434    /// The callback receives `(severity, code, message)`. This is useful for
435    /// logging warnings, deprecation notices, etc.
436    ///
437    /// # Example
438    /// ```ignore
439    /// conn.set_notice_handler(|severity, code, message| {
440    ///     eprintln!("PG {}: {} ({})", severity, message, code);
441    /// });
442    /// ```
443    pub fn set_notice_handler<F>(&mut self, handler: F)
444    where
445        F: Fn(&str, &str, &str) + Send + Sync + 'static,
446    {
447        self.notice_handler = Some(Box::new(handler));
448    }
449
450    /// Remove the notice handler.
451    pub fn clear_notice_handler(&mut self) {
452        self.notice_handler = None;
453    }
454
455    /// Set the maximum number of statements to cache before LRU eviction.
456    pub fn set_statement_cache_capacity(&mut self, capacity: usize) {
457        self.stmt_cache.set_max_capacity(capacity);
458    }
459
460    /// Return the raw file descriptor for event-loop registration
461    /// (epoll / kqueue).
462    #[cfg(unix)]
463    pub fn raw_fd(&self) -> std::os::unix::io::RawFd {
464        self.stream.as_raw_fd()
465    }
466
467    /// Check if the socket is in non-blocking mode.
468    pub fn is_nonblocking(&self) -> bool {
469        self.nonblocking
470    }
471
472    /// Set non-blocking mode on the socket.
473    pub fn set_nonblocking(&mut self, nonblocking: bool) -> PgResult<()> {
474        self.stream
475            .set_nonblocking(nonblocking)
476            .map_err(PgError::Io)?;
477        self.nonblocking = nonblocking;
478        Ok(())
479    }
480
481    /// Perform the startup and authentication handshake.
482    fn startup(&mut self, config: &PgConfig) -> PgResult<()> {
483        // Send StartupMessage
484        self.ensure_write_capacity(512);
485        let n = codec::encode_startup(&mut self.write_buf, &config.user, &config.database, &[]);
486        self.stream
487            .write_all(&self.write_buf[..n])
488            .map_err(PgError::Io)?;
489
490        // Read server response
491        loop {
492            self.fill_read_buf(None)?;
493
494            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
495                let header = codec::decode_header(&self.read_buf)
496                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
497                let body = &self.read_buf[5..msg_len];
498
499                match header.tag {
500                    BackendTag::AuthenticationRequest => {
501                        let auth_type = codec::read_i32(&self.read_buf, 5);
502                        match AuthType::from_i32(auth_type) {
503                            Some(AuthType::Ok) => {
504                                // Handled! Keep going to ReadyForQuery
505                            }
506                            Some(AuthType::CleartextPassword) => {
507                                let n =
508                                    codec::encode_password(&mut self.write_buf, &config.password);
509                                self.stream
510                                    .write_all(&self.write_buf[..n])
511                                    .map_err(PgError::Io)?;
512                            }
513                            Some(AuthType::SASLInit) => {
514                                // B.3: Use SCRAM-SHA-256-PLUS with channel binding when TLS is active
515                                #[cfg(feature = "tls")]
516                                let (mut scram, mechanism) =
517                                    if let Some(cb_data) = self.stream.tls_server_cert_hash() {
518                                        (
519                                            ScramClient::new_with_channel_binding(
520                                                &config.user,
521                                                &config.password,
522                                                cb_data,
523                                            ),
524                                            "SCRAM-SHA-256-PLUS",
525                                        )
526                                    } else {
527                                        (
528                                            ScramClient::new(&config.user, &config.password),
529                                            "SCRAM-SHA-256",
530                                        )
531                                    };
532                                #[cfg(not(feature = "tls"))]
533                                let (mut scram, mechanism) = (
534                                    ScramClient::new(&config.user, &config.password),
535                                    "SCRAM-SHA-256",
536                                );
537
538                                let client_first = scram.client_first_message();
539                                let n = codec::encode_sasl_initial(
540                                    &mut self.write_buf,
541                                    mechanism,
542                                    &client_first,
543                                );
544                                self.stream
545                                    .write_all(&self.write_buf[..n])
546                                    .map_err(PgError::Io)?;
547
548                                self.consume_read(msg_len);
549                                self.wait_for_sasl_continue(&mut scram, config)?;
550                                // After SASL, we might still have messages in the buffer
551                                // so we don't return, we continue the outer loop.
552                                continue;
553                            }
554                            Some(AuthType::MD5Password) => {
555                                // Salt is 4 bytes following the auth_type int32
556                                if body.len() < 8 {
557                                    return Err(PgError::Protocol(
558                                        "MD5Password message too short".to_string(),
559                                    ));
560                                }
561                                let salt: [u8; 4] = [body[4], body[5], body[6], body[7]];
562                                let hash = crate::auth::md5_password_hash(
563                                    &config.user,
564                                    &config.password,
565                                    &salt,
566                                );
567                                let n = codec::encode_password(&mut self.write_buf, &hash);
568                                self.stream
569                                    .write_all(&self.write_buf[..n])
570                                    .map_err(PgError::Io)?;
571                            }
572                            _ => {
573                                return Err(PgError::Auth(format!(
574                                    "Unsupported auth type: {}",
575                                    auth_type
576                                )));
577                            }
578                        }
579                    }
580                    BackendTag::ParameterStatus => {
581                        let (name, consumed) = codec::read_cstring(body, 0);
582                        let (value, _) = codec::read_cstring(body, consumed);
583                        self.server_params
584                            .push((name.to_string(), value.to_string()));
585                    }
586                    BackendTag::BackendKeyData => {
587                        self.process_id = codec::read_i32(body, 0);
588                        self.secret_key = codec::read_i32(body, 4);
589                    }
590                    BackendTag::ReadyForQuery => {
591                        self.tx_status = TransactionStatus::from(body[0]);
592                        self.consume_read(msg_len);
593                        return Ok(()); // Connection is ready!
594                    }
595                    BackendTag::ErrorResponse => {
596                        let fields = codec::parse_error_fields(body);
597                        return Err(PgError::from_fields(&fields));
598                    }
599                    _ => {
600                        // Skip unknown messages
601                    }
602                }
603                self.consume_read(msg_len);
604            }
605        }
606    }
607
608    /// Handle SASL Continue/Final exchange.
609    fn wait_for_sasl_continue(
610        &mut self,
611        scram: &mut ScramClient,
612        _config: &PgConfig,
613    ) -> PgResult<()> {
614        loop {
615            self.fill_read_buf(None)?;
616
617            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
618                let header = codec::decode_header(&self.read_buf)
619                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
620                let body = &self.read_buf[5..msg_len].to_vec();
621
622                match header.tag {
623                    BackendTag::AuthenticationRequest => {
624                        let auth_type = codec::read_i32(&self.read_buf, 5);
625                        match AuthType::from_i32(auth_type) {
626                            Some(AuthType::SASLContinue) => {
627                                let server_first = &body[4..];
628                                let client_final = scram
629                                    .process_server_first(server_first)
630                                    .map_err(PgError::Auth)?;
631
632                                let n =
633                                    codec::encode_sasl_response(&mut self.write_buf, &client_final);
634                                self.stream
635                                    .write_all(&self.write_buf[..n])
636                                    .map_err(PgError::Io)?;
637                            }
638                            Some(AuthType::SASLFinal) => {
639                                let server_final = &body[4..];
640                                scram
641                                    .verify_server_final(server_final)
642                                    .map_err(PgError::Auth)?;
643                            }
644                            Some(AuthType::Ok) => {
645                                self.consume_read(msg_len);
646                                return Ok(());
647                            }
648                            _ => {
649                                return Err(PgError::Auth(
650                                    "Unexpected auth message during SASL".to_string(),
651                                ));
652                            }
653                        }
654                    }
655                    _ => {
656                        // Skip
657                    }
658                }
659                self.consume_read(msg_len);
660            }
661        }
662    }
663
664    // ─── Query Methods ────────────────────────────────────────
665
666    /// Execute a simple query (no parameters). Returns all result rows.
667    pub fn query_simple(&mut self, sql: &str) -> PgResult<Vec<Row>> {
668        self.ensure_write_capacity(5 + sql.len());
669        let n = codec::encode_query(&mut self.write_buf, sql);
670        self.flush_write_buf(n)?;
671        self.read_query_results()
672    }
673
674    /// Execute a parameterized query using the Extended Query Protocol.
675    /// Uses implicit statement caching for performance.
676    pub fn query(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Vec<Row>> {
677        let stmt = self.stmt_cache.get_or_create(sql);
678
679        // Conservative upper bound for write buffer
680        let estimated = 10 + sql.len() + (params.len() * 256);
681        self.ensure_write_capacity(estimated);
682
683        let mut pos = 0;
684
685        if stmt.is_new {
686            // Parse
687            let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
688            pos += n;
689
690            // Describe (to get column info)
691            let n = codec::encode_describe(
692                &mut self.write_buf[pos..],
693                DescribeTarget::Statement,
694                &stmt.name,
695            );
696            pos += n;
697        }
698
699        // Bind — encode parameters with per-parameter format codes
700        let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
701        let param_formats: Vec<i16> = pg_values
702            .iter()
703            .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
704            .collect();
705        let param_values: Vec<Option<Vec<u8>>> = pg_values
706            .iter()
707            .zip(param_formats.iter())
708            .map(|(v, &fmt)| {
709                if fmt == 1 {
710                    v.to_binary_bytes()
711                } else {
712                    v.to_text_bytes()
713                }
714            })
715            .collect();
716        let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
717        let n = codec::encode_bind(
718            &mut self.write_buf[pos..],
719            "", // unnamed portal
720            &stmt.name,
721            &param_formats,
722            &param_refs,
723            &[1], // request all results in binary format
724        );
725        pos += n;
726
727        // Execute
728        let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
729        pos += n;
730
731        // Sync
732        let n = codec::encode_sync(&mut self.write_buf[pos..]);
733        pos += n;
734
735        self.flush_write_buf(pos)?;
736
737        // Read results
738        let rows = self.read_extended_results(sql, &stmt.name, stmt.is_new, stmt.columns)?;
739        Ok(rows)
740    }
741
742    /// Execute a query expecting exactly one row.
743    ///
744    /// Optimised path: reads the Extended Query Protocol response stream and
745    /// returns the **first** `DataRow` directly, without collecting into a
746    /// `Vec<Row>`.  Subsequent rows (if any) are still drained so the
747    /// connection is left in a clean state.
748    pub fn query_one(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Row> {
749        let stmt = self.stmt_cache.get_or_create(sql);
750
751        let estimated = 10 + sql.len() + (params.len() * 256);
752        self.ensure_write_capacity(estimated);
753
754        let mut pos = 0;
755
756        if stmt.is_new {
757            let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
758            pos += n;
759            let n = codec::encode_describe(
760                &mut self.write_buf[pos..],
761                DescribeTarget::Statement,
762                &stmt.name,
763            );
764            pos += n;
765        }
766
767        let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
768        let param_formats: Vec<i16> = pg_values
769            .iter()
770            .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
771            .collect();
772        let param_values: Vec<Option<Vec<u8>>> = pg_values
773            .iter()
774            .zip(param_formats.iter())
775            .map(|(v, &fmt)| {
776                if fmt == 1 {
777                    v.to_binary_bytes()
778                } else {
779                    v.to_text_bytes()
780                }
781            })
782            .collect();
783        let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
784        let n = codec::encode_bind(
785            &mut self.write_buf[pos..],
786            "",
787            &stmt.name,
788            &param_formats,
789            &param_refs,
790            &[1],
791        );
792        pos += n;
793
794        // Execute with max_rows = 0 (unlimited) — PostgreSQL will send all
795        // DataRows but we stop caring after the first one.
796        let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
797        pos += n;
798
799        let n = codec::encode_sync(&mut self.write_buf[pos..]);
800        pos += n;
801
802        self.flush_write_buf(pos)?;
803
804        self.read_extended_result_one(sql, &stmt.name, stmt.is_new, stmt.columns)
805    }
806
807    /// Execute a query expecting zero or one row. Returns `Ok(None)` when
808    /// the query returns no rows, avoiding the `PgError::NoRows` error path.
809    pub fn query_opt(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Option<Row>> {
810        let stmt = self.stmt_cache.get_or_create(sql);
811
812        let estimated = 10 + sql.len() + (params.len() * 256);
813        self.ensure_write_capacity(estimated);
814
815        let mut pos = 0;
816
817        if stmt.is_new {
818            let n = codec::encode_parse(&mut self.write_buf[pos..], &stmt.name, sql, &[]);
819            pos += n;
820            let n = codec::encode_describe(
821                &mut self.write_buf[pos..],
822                DescribeTarget::Statement,
823                &stmt.name,
824            );
825            pos += n;
826        }
827
828        let pg_values: Vec<PgValue> = params.iter().map(|p| p.to_sql()).collect();
829        let param_formats: Vec<i16> = pg_values
830            .iter()
831            .map(|v| if v.prefers_binary() { 1_i16 } else { 0_i16 })
832            .collect();
833        let param_values: Vec<Option<Vec<u8>>> = pg_values
834            .iter()
835            .zip(param_formats.iter())
836            .map(|(v, &fmt)| {
837                if fmt == 1 {
838                    v.to_binary_bytes()
839                } else {
840                    v.to_text_bytes()
841                }
842            })
843            .collect();
844        let param_refs: Vec<Option<&[u8]>> = param_values.iter().map(|p| p.as_deref()).collect();
845        let n = codec::encode_bind(
846            &mut self.write_buf[pos..],
847            "",
848            &stmt.name,
849            &param_formats,
850            &param_refs,
851            &[1],
852        );
853        pos += n;
854
855        let n = codec::encode_execute(&mut self.write_buf[pos..], "", 0);
856        pos += n;
857
858        let n = codec::encode_sync(&mut self.write_buf[pos..]);
859        pos += n;
860
861        self.flush_write_buf(pos)?;
862
863        self.read_extended_result_opt(sql, &stmt.name, stmt.is_new, stmt.columns)
864    }
865
866    /// Execute a statement that returns no rows (INSERT, UPDATE, DELETE).
867    /// Returns the number of affected rows as reported by the server.
868    pub fn execute(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<u64> {
869        let _rows = self.query(sql, params)?;
870        Ok(self.last_affected_rows)
871    }
872
873    // ─── Transaction Support ──────────────────────────────────
874
875    /// Begin a transaction.
876    pub fn begin(&mut self) -> PgResult<()> {
877        self.query_simple("BEGIN")?;
878        Ok(())
879    }
880
881    /// Commit the current transaction.
882    pub fn commit(&mut self) -> PgResult<()> {
883        self.query_simple("COMMIT")?;
884        Ok(())
885    }
886
887    /// Rollback the current transaction.
888    pub fn rollback(&mut self) -> PgResult<()> {
889        self.query_simple("ROLLBACK")?;
890        Ok(())
891    }
892
893    /// Create a savepoint.
894    pub fn savepoint(&mut self, name: &str) -> PgResult<()> {
895        self.query_simple(&format!("SAVEPOINT {}", name))?;
896        Ok(())
897    }
898
899    /// Rollback to a savepoint.
900    pub fn rollback_to(&mut self, name: &str) -> PgResult<()> {
901        self.query_simple(&format!("ROLLBACK TO SAVEPOINT {}", name))?;
902        Ok(())
903    }
904
905    /// Release a savepoint.
906    pub fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
907        self.query_simple(&format!("RELEASE SAVEPOINT {}", name))?;
908        Ok(())
909    }
910
911    /// Execute a closure within a transaction.
912    ///
913    /// Automatically BEGINs before calling `f`, COMMITs on success,
914    /// and ROLLBACKs on error. This ensures the transaction is always
915    /// finalized, even if the closure panics (via Drop).
916    ///
917    /// # Example
918    /// ```ignore
919    /// conn.transaction(|tx| {
920    ///     tx.execute("INSERT INTO users (name) VALUES ($1)", &[&"Alice"])?;
921    ///     tx.execute("INSERT INTO logs (msg) VALUES ($1)", &[&"User created"])?;
922    ///     Ok(())
923    /// })?;
924    /// ```
925    pub fn transaction<F, T>(&mut self, f: F) -> PgResult<T>
926    where
927        F: FnOnce(&mut Transaction<'_>) -> PgResult<T>,
928    {
929        self.begin()?;
930        let mut tx = Transaction {
931            conn: self,
932            finished: false,
933            savepoint_name: None,
934            savepoint_counter: 0,
935        };
936        match f(&mut tx) {
937            Ok(val) => {
938                tx.commit()?;
939                Ok(val)
940            }
941            Err(e) => {
942                // Attempt rollback, but propagate original error
943                let _ = tx.rollback();
944                Err(e)
945            }
946        }
947    }
948
949    // ─── COPY Protocol ────────────────────────────────────────
950
951    /// Start a COPY FROM STDIN operation.
952    pub fn copy_in(&mut self, sql: &str) -> PgResult<CopyWriter<'_>> {
953        let n = codec::encode_query(&mut self.write_buf, sql);
954        #[allow(clippy::unnecessary_to_owned)]
955        self.write_all(&self.write_buf[..n].to_vec())?;
956
957        // Read until CopyInResponse
958        loop {
959            self.fill_read_buf(None)?;
960            let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? else {
961                continue;
962            };
963            let header = codec::decode_header(&self.read_buf)
964                .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
965            match header.tag {
966                BackendTag::CopyInResponse => {
967                    self.consume_read(msg_len);
968                    return Ok(CopyWriter { conn: self });
969                }
970                BackendTag::ErrorResponse => {
971                    let body = &self.read_buf[5..msg_len];
972                    return Err(self.parse_error(body));
973                }
974                _ => {
975                    self.consume_read(msg_len);
976                }
977            }
978        }
979    }
980
981    /// Start a COPY TO STDOUT operation.
982    /// Returns a CopyReader that yields data chunks.
983    pub fn copy_out(&mut self, sql: &str) -> PgResult<CopyReader<'_>> {
984        let n = codec::encode_query(&mut self.write_buf, sql);
985        #[allow(clippy::unnecessary_to_owned)]
986        self.write_all(&self.write_buf[..n].to_vec())?;
987
988        // Read until CopyOutResponse
989        loop {
990            self.fill_read_buf(None)?;
991            let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? else {
992                continue;
993            };
994            let header = codec::decode_header(&self.read_buf)
995                .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
996            match header.tag {
997                BackendTag::CopyOutResponse => {
998                    self.consume_read(msg_len);
999                    return Ok(CopyReader {
1000                        conn: self,
1001                        done: false,
1002                    });
1003                }
1004                BackendTag::ErrorResponse => {
1005                    let body = &self.read_buf[5..msg_len];
1006                    return Err(self.parse_error(body));
1007                }
1008                _ => {
1009                    self.consume_read(msg_len);
1010                }
1011            }
1012        }
1013    }
1014
1015    // ─── LISTEN / NOTIFY ──────────────────────────────────────
1016
1017    /// Subscribe to a notification channel.
1018    pub fn listen(&mut self, channel: &str) -> PgResult<()> {
1019        self.query_simple(&format!("LISTEN {}", channel))?;
1020        Ok(())
1021    }
1022
1023    /// Send a notification.
1024    pub fn notify(&mut self, channel: &str, payload: &str) -> PgResult<()> {
1025        self.query_simple(&format!("NOTIFY {}, '{}'", channel, payload))?;
1026        Ok(())
1027    }
1028
1029    /// Unsubscribe from a notification channel.
1030    pub fn unlisten(&mut self, channel: &str) -> PgResult<()> {
1031        self.query_simple(&format!("UNLISTEN {}", channel))?;
1032        Ok(())
1033    }
1034
1035    /// Unsubscribe from all notification channels.
1036    pub fn unlisten_all(&mut self) -> PgResult<()> {
1037        self.query_simple("UNLISTEN *")?;
1038        Ok(())
1039    }
1040
1041    /// Drain and return all buffered notifications.
1042    pub fn drain_notifications(&mut self) -> Vec<Notification> {
1043        self.notifications.drain(..).collect()
1044    }
1045
1046    /// Check if there are buffered notifications.
1047    pub fn has_notifications(&self) -> bool {
1048        !self.notifications.is_empty()
1049    }
1050
1051    /// Get the number of buffered notifications.
1052    pub fn notification_count(&self) -> usize {
1053        self.notifications.len()
1054    }
1055
1056    /// Poll for a notification (always non-blocking).
1057    /// Reads from the socket and returns the first notification found,
1058    /// or None if no notification is immediately available.
1059    pub fn poll_notification(&mut self) -> PgResult<Option<Notification>> {
1060        // First check buffer
1061        if let Some(n) = self.notifications.pop_front() {
1062            return Ok(Some(n));
1063        }
1064
1065        // Try a non-blocking read (socket is already non-blocking)
1066        self.ensure_read_space();
1067        match self.stream.read(&mut self.read_buf[self.read_pos..]) {
1068            Ok(0) => return Err(PgError::ConnectionClosed),
1069            Ok(n) => {
1070                self.read_pos += n;
1071                // Process any complete messages
1072                while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])?
1073                {
1074                    let header = codec::decode_header(&self.read_buf).ok_or_else(|| {
1075                        PgError::Protocol("Incomplete message header".to_string())
1076                    })?;
1077                    if header.tag == BackendTag::NotificationResponse {
1078                        let body = &self.read_buf[5..msg_len];
1079                        let notification = Self::parse_notification(body);
1080                        self.notifications.push_back(notification);
1081                    }
1082                    self.consume_read(msg_len);
1083                }
1084            }
1085            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1086                // No data available
1087            }
1088            Err(e) => return Err(PgError::Io(e)),
1089        }
1090
1091        Ok(self.notifications.pop_front())
1092    }
1093
1094    // ─── Accessors ────────────────────────────────────────────
1095
1096    /// Get the current transaction status.
1097    pub fn transaction_status(&self) -> TransactionStatus {
1098        self.tx_status
1099    }
1100
1101    /// Get the number of cached statements.
1102    pub fn cached_statements(&self) -> usize {
1103        self.stmt_cache.len()
1104    }
1105
1106    /// Get the number of rows affected by the last command.
1107    pub fn last_affected_rows(&self) -> u64 {
1108        self.last_affected_rows
1109    }
1110
1111    /// Get the last CommandComplete tag string.
1112    pub fn last_command_tag(&self) -> &str {
1113        &self.last_command_tag
1114    }
1115
1116    /// Get the backend process ID.
1117    pub fn process_id(&self) -> i32 {
1118        self.process_id
1119    }
1120
1121    /// Get the backend secret key (used for cancel requests).
1122    pub fn secret_key(&self) -> i32 {
1123        self.secret_key
1124    }
1125
1126    /// Get server parameters received during startup.
1127    pub fn server_params(&self) -> &[(String, String)] {
1128        &self.server_params
1129    }
1130
1131    /// Get a specific server parameter by name.
1132    pub fn server_param(&self, name: &str) -> Option<&str> {
1133        self.server_params
1134            .iter()
1135            .find(|(k, _)| k == name)
1136            .map(|(_, v)| v.as_str())
1137    }
1138
1139    /// Check if the connection is in a transaction.
1140    pub fn in_transaction(&self) -> bool {
1141        matches!(
1142            self.tx_status,
1143            TransactionStatus::InTransaction | TransactionStatus::Failed
1144        )
1145    }
1146
1147    /// Clear the statement cache and deallocate all server-side prepared statements.
1148    ///
1149    /// Sends `DEALLOCATE ALL` to the server before clearing the client-side
1150    /// cache.  The statement name counter is preserved to prevent name
1151    /// collisions with any stale server-side references.
1152    pub fn clear_statement_cache(&mut self) {
1153        let _ = self.query_simple("DEALLOCATE ALL");
1154        self.stmt_cache.clear();
1155    }
1156
1157    /// Returns `true` if the connection has been marked as broken due to a
1158    /// fatal I/O error.  A broken connection should be discarded (not
1159    /// returned to the pool).
1160    pub fn is_broken(&self) -> bool {
1161        self.broken
1162    }
1163
1164    /// Reset the connection to a clean state for pool reuse.
1165    ///
1166    /// Sends `DISCARD ALL` which resets session state, deallocates prepared
1167    /// statements, closes cursors, drops temps, releases advisory locks.
1168    /// Then clears the client-side statement cache.
1169    pub fn reset(&mut self) -> PgResult<()> {
1170        self.query_simple("DISCARD ALL")?;
1171        self.stmt_cache.clear();
1172        Ok(())
1173    }
1174
1175    /// Execute one or more SQL statements separated by semicolons, using
1176    /// the Simple Query Protocol.  Returns the number of affected rows from
1177    /// the **last** command.
1178    ///
1179    /// This is useful for running DDL migrations, multi-statement scripts,
1180    /// or any sequence of commands that don't require parameters.
1181    ///
1182    /// # Example
1183    /// ```ignore
1184    /// conn.execute_batch("CREATE TABLE t(id INT); INSERT INTO t VALUES (1); INSERT INTO t VALUES (2);")?;
1185    /// ```
1186    pub fn execute_batch(&mut self, sql: &str) -> PgResult<u64> {
1187        self.query_simple(sql)?;
1188        Ok(self.last_affected_rows)
1189    }
1190
1191    /// Check if the connection is alive by sending a simple query.
1192    pub fn is_alive(&mut self) -> bool {
1193        self.query_simple("SELECT 1").is_ok()
1194    }
1195
1196    // ─── Internal Methods ─────────────────────────────────────
1197
1198    // ─── Non-blocking read/write primitives ───────────────────
1199
1200    /// Try to read data into the read buffer without blocking.
1201    /// Returns `Ok(n)` with bytes read, or `Err(PgError::WouldBlock)` if no
1202    /// data is available, or another error on failure.
1203    pub fn try_fill_read_buf(&mut self) -> PgResult<usize> {
1204        self.ensure_read_space();
1205
1206        match self.stream.read(&mut self.read_buf[self.read_pos..]) {
1207            Ok(0) => {
1208                self.broken = true;
1209                Err(PgError::ConnectionClosed)
1210            }
1211            Ok(n) => {
1212                self.read_pos += n;
1213                Ok(n)
1214            }
1215            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(PgError::WouldBlock),
1216            Err(e) => {
1217                self.broken = true;
1218                Err(PgError::Io(e))
1219            }
1220        }
1221    }
1222
1223    /// Try to write a buffer to the socket without blocking.
1224    /// Returns `Ok(n)` with bytes written, or `Err(PgError::WouldBlock)` if the
1225    /// socket is not writable.
1226    pub fn try_write(&mut self, data: &[u8]) -> PgResult<usize> {
1227        match self.stream.write(data) {
1228            Ok(n) => Ok(n),
1229            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(PgError::WouldBlock),
1230            Err(e) => {
1231                self.broken = true;
1232                Err(PgError::Io(e))
1233            }
1234        }
1235    }
1236
1237    /// Wait for the socket to become readable using `poll(2)`.
1238    ///
1239    /// This is a true OS-level wait — the thread sleeps in the kernel until
1240    /// the socket has data or the timeout expires.  No busy-waiting, no
1241    /// `thread::sleep`.
1242    #[cfg(unix)]
1243    fn wait_readable(&self, timeout: Duration) -> PgResult<()> {
1244        let fd = self.stream.as_raw_fd();
1245        let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
1246        let mut pfd = libc::pollfd {
1247            fd,
1248            events: libc::POLLIN,
1249            revents: 0,
1250        };
1251        let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
1252        if ret < 0 {
1253            let e = std::io::Error::last_os_error();
1254            if e.kind() == std::io::ErrorKind::Interrupted {
1255                return Ok(()); // EINTR — caller will retry
1256            }
1257            return Err(PgError::Io(e));
1258        }
1259        if ret == 0 {
1260            return Err(PgError::Timeout);
1261        }
1262        if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 {
1263            return Err(PgError::ConnectionClosed);
1264        }
1265        Ok(())
1266    }
1267
1268    /// Wait for the socket to become writable using `poll(2)`.
1269    #[cfg(unix)]
1270    fn wait_writable(&self, timeout: Duration) -> PgResult<()> {
1271        let fd = self.stream.as_raw_fd();
1272        let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
1273        let mut pfd = libc::pollfd {
1274            fd,
1275            events: libc::POLLOUT,
1276            revents: 0,
1277        };
1278        let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
1279        if ret < 0 {
1280            let e = std::io::Error::last_os_error();
1281            if e.kind() == std::io::ErrorKind::Interrupted {
1282                return Ok(());
1283            }
1284            return Err(PgError::Io(e));
1285        }
1286        if ret == 0 {
1287            return Err(PgError::Timeout);
1288        }
1289        if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 {
1290            return Err(PgError::ConnectionClosed);
1291        }
1292        Ok(())
1293    }
1294
1295    /// Poll the socket for readability with a timeout.
1296    ///
1297    /// Uses the OS `poll(2)` syscall for efficient, zero-waste waiting.
1298    /// The thread sleeps in the kernel until data arrives or the timeout
1299    /// expires — no busy-loop, no `thread::sleep`.
1300    pub fn poll_read(&mut self, timeout: Duration) -> PgResult<usize> {
1301        let start = Instant::now();
1302        loop {
1303            match self.try_fill_read_buf() {
1304                Ok(n) => return Ok(n),
1305                Err(PgError::WouldBlock) => {
1306                    let elapsed = start.elapsed();
1307                    if elapsed >= timeout {
1308                        return Err(PgError::Timeout);
1309                    }
1310                    #[cfg(unix)]
1311                    self.wait_readable(timeout - elapsed)?;
1312                    #[cfg(not(unix))]
1313                    std::thread::sleep(Duration::from_micros(50));
1314                }
1315                Err(e) => return Err(e),
1316            }
1317        }
1318    }
1319
1320    /// Poll the socket for writability with a timeout.
1321    /// Writes all of `data` or times out.
1322    pub fn poll_write(&mut self, data: &[u8], timeout: Duration) -> PgResult<()> {
1323        let start = Instant::now();
1324        let mut written = 0;
1325        while written < data.len() {
1326            match self.try_write(&data[written..]) {
1327                Ok(n) => written += n,
1328                Err(PgError::WouldBlock) => {
1329                    let elapsed = start.elapsed();
1330                    if elapsed >= timeout {
1331                        return Err(PgError::Timeout);
1332                    }
1333                    #[cfg(unix)]
1334                    self.wait_writable(timeout - elapsed)?;
1335                    #[cfg(not(unix))]
1336                    std::thread::sleep(Duration::from_micros(50));
1337                }
1338                Err(e) => return Err(e),
1339            }
1340        }
1341        Ok(())
1342    }
1343
1344    /// Internal: fill the read buffer, blocking with the connection's
1345    /// configured timeout. This is the workhorse used by query methods.
1346    fn fill_read_buf(&mut self, min_size: Option<usize>) -> PgResult<()> {
1347        if let Some(min) = min_size {
1348            self.ensure_read_capacity(min);
1349        }
1350
1351        self.ensure_read_space();
1352
1353        if self.nonblocking {
1354            // Use poll_read with timeout
1355            self.poll_read(self.io_timeout)?;
1356        } else {
1357            // Blocking path (used during startup before we switch to NB)
1358            let n = self
1359                .stream
1360                .read(&mut self.read_buf[self.read_pos..])
1361                .map_err(PgError::Io)?;
1362            if n == 0 {
1363                return Err(PgError::ConnectionClosed);
1364            }
1365            self.read_pos += n;
1366        }
1367        Ok(())
1368    }
1369
1370    /// Internal: write all bytes to the socket, respecting non-blocking mode.
1371    fn write_all(&mut self, data: &[u8]) -> PgResult<()> {
1372        if self.nonblocking {
1373            self.poll_write(data, self.io_timeout)
1374        } else {
1375            self.stream.write_all(data).map_err(PgError::Io)
1376        }
1377    }
1378
1379    /// Internal: flush the first `n` bytes of `self.write_buf` to the stream.
1380    ///
1381    /// This avoids the `.to_vec()` copy that was previously needed to work
1382    /// around borrow-checker limitations when `self.write_buf` is the source
1383    /// and `self.write_all()` takes `&mut self`.  By inlining the write loop
1384    /// here, the compiler can see that `stream` and `write_buf` are disjoint
1385    /// fields (split borrow).
1386    fn flush_write_buf(&mut self, n: usize) -> PgResult<()> {
1387        if self.nonblocking {
1388            let timeout = self.io_timeout;
1389            let start = Instant::now();
1390            let mut written = 0;
1391            while written < n {
1392                match self.stream.write(&self.write_buf[written..n]) {
1393                    Ok(w) => written += w,
1394                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1395                        let elapsed = start.elapsed();
1396                        if elapsed >= timeout {
1397                            return Err(PgError::Timeout);
1398                        }
1399                        #[cfg(unix)]
1400                        self.wait_writable(timeout - elapsed)?;
1401                        #[cfg(not(unix))]
1402                        std::thread::sleep(Duration::from_micros(50));
1403                    }
1404                    Err(e) => {
1405                        self.broken = true;
1406                        return Err(PgError::Io(e));
1407                    }
1408                }
1409            }
1410            Ok(())
1411        } else {
1412            self.stream
1413                .write_all(&self.write_buf[..n])
1414                .map_err(PgError::Io)
1415        }
1416    }
1417
1418    /// Ensure there is room in read_buf for at least one read call.
1419    fn ensure_read_space(&mut self) {
1420        if self.read_pos == self.read_buf.len() {
1421            if self.read_pos >= 5
1422                && let Some(header) = codec::decode_header(&self.read_buf)
1423            {
1424                // Guard against malicious or corrupt servers advertising a
1425                // length that exceeds our safety limit — do not allocate.
1426                // The overflow will be caught as Err(BufferOverflow) by the
1427                // next call to message_complete().
1428                if header.length as usize > codec::MAX_MESSAGE_SIZE {
1429                    return;
1430                }
1431                let total = 1 + header.length as usize;
1432                self.ensure_read_capacity(total - self.read_pos);
1433                return;
1434            }
1435            self.ensure_read_capacity(8192);
1436        }
1437    }
1438
1439    fn consume_read(&mut self, n: usize) {
1440        self.read_buf.copy_within(n..self.read_pos, 0);
1441        self.read_pos -= n;
1442    }
1443
1444    fn ensure_read_capacity(&mut self, additional: usize) {
1445        if self.read_pos + additional > self.read_buf.len() {
1446            let new_len = (self.read_pos + additional).max(self.read_buf.len() * 2);
1447            self.read_buf.resize(new_len, 0);
1448        }
1449    }
1450
1451    fn ensure_write_capacity(&mut self, additional: usize) {
1452        if additional > self.write_buf.len() {
1453            let new_len = additional.max(self.write_buf.len() * 2);
1454            self.write_buf.resize(new_len, 0);
1455        }
1456    }
1457
1458    fn read_query_results(&mut self) -> PgResult<Vec<Row>> {
1459        let mut rows = Vec::new();
1460        let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = Rc::new(Vec::new());
1461
1462        loop {
1463            if codec::message_complete(&self.read_buf[..self.read_pos])?.is_none() {
1464                self.fill_read_buf(None)?;
1465            }
1466
1467            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
1468                let header = codec::decode_header(&self.read_buf)
1469                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1470                let body = &self.read_buf[5..msg_len];
1471
1472                match header.tag {
1473                    BackendTag::RowDescription => {
1474                        columns_rc = Rc::new(codec::parse_row_description(body));
1475                    }
1476                    BackendTag::DataRow => {
1477                        let raw_values = codec::parse_data_row(body);
1478                        rows.push(Row::new(Rc::clone(&columns_rc), raw_values));
1479                    }
1480                    BackendTag::CommandComplete => {
1481                        let (tag, rows_affected) = extract_command_complete(body);
1482                        self.last_command_tag = tag;
1483                        self.last_affected_rows = rows_affected;
1484                    }
1485                    BackendTag::ReadyForQuery => {
1486                        self.tx_status = TransactionStatus::from(body[0]);
1487                        self.consume_read(msg_len);
1488                        return Ok(rows);
1489                    }
1490                    BackendTag::ErrorResponse => {
1491                        let err = self.parse_error(body);
1492                        self.consume_read(msg_len);
1493                        // Drain to ReadyForQuery
1494                        self.drain_to_ready()?;
1495                        return Err(err);
1496                    }
1497                    BackendTag::NotificationResponse => {
1498                        let notification = Self::parse_notification(body);
1499                        self.notifications.push_back(notification);
1500                    }
1501                    BackendTag::EmptyQueryResponse => {}
1502                    BackendTag::NoticeResponse => {
1503                        self.dispatch_notice(body);
1504                    }
1505                    _ => {}
1506                }
1507                self.consume_read(msg_len);
1508            }
1509        }
1510    }
1511
1512    fn read_extended_results(
1513        &mut self,
1514        sql: &str,
1515        stmt_name: &str,
1516        is_new: bool,
1517        cached_columns: Option<Vec<codec::ColumnDesc>>,
1518    ) -> PgResult<Vec<Row>> {
1519        let mut rows = Vec::new();
1520        let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = match cached_columns {
1521            Some(c) => Rc::new(c),
1522            None => Rc::new(Vec::new()),
1523        };
1524
1525        loop {
1526            if codec::message_complete(&self.read_buf[..self.read_pos])?.is_none() {
1527                self.fill_read_buf(None)?;
1528            }
1529
1530            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
1531                let header = codec::decode_header(&self.read_buf)
1532                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1533                let body = &self.read_buf[5..msg_len];
1534
1535                match header.tag {
1536                    BackendTag::ParseComplete => {}
1537                    BackendTag::ParameterDescription => {}
1538                    BackendTag::RowDescription => {
1539                        let mut columns = codec::parse_row_description(body);
1540                        // The driver always requests binary-format results in the
1541                        // Bind message (&[1]).  RowDescription from a Describe
1542                        // *Statement* always has format_code = Text (0x0) because
1543                        // no Bind has occurred yet.  Override every column to
1544                        // Binary so DataRow bytes are decoded correctly.
1545                        for col in &mut columns {
1546                            col.format_code = FormatCode::Binary;
1547                        }
1548                        if is_new
1549                            && let Some(evicted) = self.stmt_cache.insert(
1550                                sql,
1551                                stmt_name.to_string(),
1552                                0,
1553                                Some(columns.clone()),
1554                            )
1555                        {
1556                            self.close_statement_on_server(&evicted.name);
1557                        }
1558                        columns_rc = Rc::new(columns);
1559                    }
1560                    BackendTag::NoData if is_new => {
1561                        if let Some(evicted) =
1562                            self.stmt_cache.insert(sql, stmt_name.to_string(), 0, None)
1563                        {
1564                            self.close_statement_on_server(&evicted.name);
1565                        }
1566                    }
1567                    BackendTag::NoData => {}
1568                    BackendTag::BindComplete => {}
1569                    BackendTag::DataRow => {
1570                        let raw_values = codec::parse_data_row(body);
1571                        rows.push(Row::new(Rc::clone(&columns_rc), raw_values));
1572                    }
1573                    BackendTag::CommandComplete => {
1574                        let (tag, rows_affected) = extract_command_complete(body);
1575                        self.last_command_tag = tag;
1576                        self.last_affected_rows = rows_affected;
1577                    }
1578                    BackendTag::ReadyForQuery => {
1579                        self.tx_status = TransactionStatus::from(body[0]);
1580                        self.consume_read(msg_len);
1581                        return Ok(rows);
1582                    }
1583                    BackendTag::ErrorResponse => {
1584                        let err = self.parse_error_with_context(body, sql);
1585                        self.consume_read(msg_len);
1586                        self.drain_to_ready()?;
1587                        return Err(err);
1588                    }
1589                    BackendTag::NotificationResponse => {
1590                        let notification = Self::parse_notification(body);
1591                        self.notifications.push_back(notification);
1592                    }
1593                    BackendTag::NoticeResponse => {
1594                        self.dispatch_notice(body);
1595                    }
1596                    _ => {}
1597                }
1598                self.consume_read(msg_len);
1599            }
1600        }
1601    }
1602
1603    /// Optimised read path for `query_one`: returns the first `DataRow`
1604    /// directly without collecting into a `Vec`. Remaining rows and
1605    /// protocol messages are drained so the connection stays clean.
1606    fn read_extended_result_one(
1607        &mut self,
1608        sql: &str,
1609        stmt_name: &str,
1610        is_new: bool,
1611        cached_columns: Option<Vec<codec::ColumnDesc>>,
1612    ) -> PgResult<Row> {
1613        match self.read_extended_result_opt(sql, stmt_name, is_new, cached_columns)? {
1614            Some(row) => Ok(row),
1615            None => Err(PgError::NoRows),
1616        }
1617    }
1618
1619    /// Optimised read path for `query_opt`: returns the first `DataRow`
1620    /// as `Some(Row)`, or `None` if the query returns zero rows.
1621    /// Remaining rows are drained.
1622    fn read_extended_result_opt(
1623        &mut self,
1624        sql: &str,
1625        stmt_name: &str,
1626        is_new: bool,
1627        cached_columns: Option<Vec<codec::ColumnDesc>>,
1628    ) -> PgResult<Option<Row>> {
1629        let mut result: Option<Row> = None;
1630        let mut columns_rc: Rc<Vec<codec::ColumnDesc>> = match cached_columns {
1631            Some(c) => Rc::new(c),
1632            None => Rc::new(Vec::new()),
1633        };
1634
1635        loop {
1636            if codec::message_complete(&self.read_buf[..self.read_pos])?.is_none() {
1637                self.fill_read_buf(None)?;
1638            }
1639
1640            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
1641                let header = codec::decode_header(&self.read_buf)
1642                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1643                let body = &self.read_buf[5..msg_len];
1644
1645                match header.tag {
1646                    BackendTag::ParseComplete => {}
1647                    BackendTag::ParameterDescription => {}
1648                    BackendTag::RowDescription => {
1649                        let mut columns = codec::parse_row_description(body);
1650                        for col in &mut columns {
1651                            col.format_code = FormatCode::Binary;
1652                        }
1653                        if is_new
1654                            && let Some(evicted) = self.stmt_cache.insert(
1655                                sql,
1656                                stmt_name.to_string(),
1657                                0,
1658                                Some(columns.clone()),
1659                            )
1660                        {
1661                            self.close_statement_on_server(&evicted.name);
1662                        }
1663                        columns_rc = Rc::new(columns);
1664                    }
1665                    BackendTag::NoData if is_new => {
1666                        if let Some(evicted) =
1667                            self.stmt_cache.insert(sql, stmt_name.to_string(), 0, None)
1668                        {
1669                            self.close_statement_on_server(&evicted.name);
1670                        }
1671                    }
1672                    BackendTag::NoData => {}
1673                    BackendTag::BindComplete => {}
1674                    BackendTag::DataRow
1675                        // Only capture the first row; subsequent DataRows are skipped.
1676                        if result.is_none() => {
1677                            let raw_values = codec::parse_data_row(body);
1678                            result = Some(Row::new(Rc::clone(&columns_rc), raw_values));
1679                    }
1680                    BackendTag::DataRow => {
1681                        // Subsequent DataRows are skipped (not allocated).
1682                    }
1683                    BackendTag::CommandComplete => {
1684                        let (tag, rows_affected) = extract_command_complete(body);
1685                        self.last_command_tag = tag;
1686                        self.last_affected_rows = rows_affected;
1687                    }
1688                    BackendTag::ReadyForQuery => {
1689                        self.tx_status = TransactionStatus::from(body[0]);
1690                        self.consume_read(msg_len);
1691                        return Ok(result);
1692                    }
1693                    BackendTag::ErrorResponse => {
1694                        let err = self.parse_error_with_context(body, sql);
1695                        self.consume_read(msg_len);
1696                        self.drain_to_ready()?;
1697                        return Err(err);
1698                    }
1699                    BackendTag::NotificationResponse => {
1700                        let notification = Self::parse_notification(body);
1701                        self.notifications.push_back(notification);
1702                    }
1703                    BackendTag::NoticeResponse => {
1704                        self.dispatch_notice(body);
1705                    }
1706                    _ => {}
1707                }
1708                self.consume_read(msg_len);
1709            }
1710        }
1711    }
1712
1713    fn drain_to_ready(&mut self) -> PgResult<()> {
1714        loop {
1715            // Only read from the socket when the buffer has no complete message;
1716            // data may already be buffered from an earlier fill_read_buf call.
1717            if codec::message_complete(&self.read_buf[..self.read_pos])?.is_none() {
1718                self.fill_read_buf(None)?;
1719            }
1720            while let Some(msg_len) = codec::message_complete(&self.read_buf[..self.read_pos])? {
1721                let header = codec::decode_header(&self.read_buf)
1722                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
1723                if header.tag == BackendTag::ReadyForQuery {
1724                    let body = &self.read_buf[5..msg_len];
1725                    self.tx_status = TransactionStatus::from(body[0]);
1726                    self.consume_read(msg_len);
1727                    return Ok(());
1728                }
1729                self.consume_read(msg_len);
1730            }
1731        }
1732    }
1733
1734    fn parse_error(&self, body: &[u8]) -> PgError {
1735        let fields = codec::parse_error_fields(body);
1736        PgError::from_fields(&fields)
1737    }
1738
1739    /// Parse an error and attach query context for better debugging.
1740    fn parse_error_with_context(&self, body: &[u8], query: &str) -> PgError {
1741        let fields = codec::parse_error_fields(body);
1742        let mut err = PgError::from_fields(&fields);
1743        if let PgError::Server(ref mut server_err) = err
1744            && server_err.internal_query.is_none()
1745        {
1746            server_err.internal_query = Some(query.to_string());
1747        }
1748        err
1749    }
1750
1751    /// Dispatch a NoticeResponse to the registered handler.
1752    fn dispatch_notice(&self, body: &[u8]) {
1753        if let Some(ref handler) = self.notice_handler {
1754            let fields = codec::parse_error_fields(body);
1755            let mut severity = "";
1756            let mut code = "";
1757            let mut message = "";
1758            for (field_type, value) in &fields {
1759                match field_type {
1760                    b'S' => severity = value,
1761                    b'C' => code = value,
1762                    b'M' => message = value,
1763                    _ => {}
1764                }
1765            }
1766            handler(severity, code, message);
1767        }
1768    }
1769
1770    /// Send a Close('S') message to deallocate a server-side prepared statement.
1771    /// This is fire-and-forget — we don't wait for CloseComplete.
1772    fn close_statement_on_server(&mut self, name: &str) {
1773        self.ensure_write_capacity(7 + name.len());
1774        let n = codec::encode_close(&mut self.write_buf, CloseTarget::Statement, name);
1775        let _ = self.flush_write_buf(n);
1776    }
1777
    // CommandComplete tags look like "INSERT 0 5", "UPDATE 3", "DELETE 1", "SELECT 10".
    // They are parsed by the free function `extract_command_complete` (after this
    // impl block), so callers can pass a slice of the read buffer without a
    // conflicting borrow of `self`.
1781
1782    /// Parse a NotificationResponse message body.
1783    fn parse_notification(body: &[u8]) -> Notification {
1784        let process_id = codec::read_i32(body, 0);
1785        let (channel, consumed) = codec::read_cstring(body, 4);
1786        let (payload, _) = codec::read_cstring(body, 4 + consumed);
1787        Notification {
1788            process_id,
1789            channel: channel.to_string(),
1790            payload: payload.to_string(),
1791        }
1792    }
1793}
1794
1795/// Extract the command tag and affected row count from a CommandComplete body.
1796/// This is a free function (not a method) to avoid borrow conflicts when
1797/// `body` is a slice of the connection's read buffer.
1798fn extract_command_complete(body: &[u8]) -> (String, u64) {
1799    let (tag, _) = codec::read_cstring(body, 0);
1800    let tag_str = tag.to_string();
1801    let affected_rows = tag
1802        .rsplit(' ')
1803        .next()
1804        .and_then(|s| s.parse::<u64>().ok())
1805        .unwrap_or(0);
1806    (tag_str, affected_rows)
1807}
1808
1809impl Drop for PgConnection {
1810    fn drop(&mut self) {
1811        // Switch to blocking mode so the Terminate message is reliably sent.
1812        // On a non-blocking socket write_all may fail with WouldBlock,
1813        // silently leaving the server-side session open.
1814        if self.nonblocking {
1815            let _ = self.stream.set_nonblocking(false);
1816        }
1817        let n = codec::encode_terminate(&mut self.write_buf);
1818        let _ = self.stream.write_all(&self.write_buf[..n]);
1819    }
1820}
1821
1822// ─── Transaction ──────────────────────────────────────────────
1823
1824/// A transaction guard. Ensures the transaction is committed or rolled back.
1825///
1826/// Created via `PgConnection::transaction()`. Provides the same query
1827/// methods as `PgConnection`. On drop, if neither `commit` nor `rollback`
1828/// was called, automatically rolls back.
1829pub struct Transaction<'a> {
1830    conn: &'a mut PgConnection,
1831    finished: bool,
1832    /// If Some, this is a nested transaction backed by a SAVEPOINT.
1833    savepoint_name: Option<String>,
1834    /// Counter for generating unique savepoint names in nested calls.
1835    savepoint_counter: u32,
1836}
1837
1838impl<'a> Transaction<'a> {
1839    /// Commit this transaction (or release savepoint if nested).
1840    pub fn commit(&mut self) -> PgResult<()> {
1841        if !self.finished {
1842            self.finished = true;
1843            if let Some(ref name) = self.savepoint_name {
1844                self.conn.release_savepoint(name)
1845            } else {
1846                self.conn.commit()
1847            }
1848        } else {
1849            Ok(())
1850        }
1851    }
1852
1853    /// Rollback this transaction (or rollback to savepoint if nested).
1854    pub fn rollback(&mut self) -> PgResult<()> {
1855        if !self.finished {
1856            self.finished = true;
1857            if let Some(ref name) = self.savepoint_name {
1858                self.conn.rollback_to(name)
1859            } else {
1860                self.conn.rollback()
1861            }
1862        } else {
1863            Ok(())
1864        }
1865    }
1866
1867    /// Execute a nested transaction using a SAVEPOINT.
1868    ///
1869    /// Creates a savepoint, calls the closure, and either releases
1870    /// (on success) or rolls back to the savepoint (on error/drop).
1871    ///
1872    /// Nesting is unlimited — each level creates a new savepoint.
1873    ///
1874    /// # Example
1875    /// ```ignore
1876    /// conn.transaction(|tx| {
1877    ///     tx.execute("INSERT INTO users (name) VALUES ($1)", &[&"Alice"])?;
1878    ///     tx.transaction(|nested| {
1879    ///         nested.execute("INSERT INTO logs (msg) VALUES ($1)", &[&"nested"])?;
1880    ///         Ok(())
1881    ///     })?;
1882    ///     Ok(())
1883    /// })?;
1884    /// ```
1885    pub fn transaction<F, T>(&mut self, f: F) -> PgResult<T>
1886    where
1887        F: FnOnce(&mut Transaction<'_>) -> PgResult<T>,
1888    {
1889        self.savepoint_counter += 1;
1890        let sp_name = format!("chopin_sp_{}", self.savepoint_counter);
1891        self.conn.savepoint(&sp_name)?;
1892        let mut nested = Transaction {
1893            conn: self.conn,
1894            finished: false,
1895            savepoint_name: Some(sp_name),
1896            savepoint_counter: 0,
1897        };
1898        match f(&mut nested) {
1899            Ok(val) => {
1900                nested.commit()?;
1901                Ok(val)
1902            }
1903            Err(e) => {
1904                let _ = nested.rollback();
1905                Err(e)
1906            }
1907        }
1908    }
1909
1910    /// Execute a simple query (no parameters).
1911    pub fn query_simple(&mut self, sql: &str) -> PgResult<Vec<Row>> {
1912        self.conn.query_simple(sql)
1913    }
1914
1915    /// Execute a parameterized query.
1916    pub fn query(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Vec<Row>> {
1917        self.conn.query(sql, params)
1918    }
1919
1920    /// Execute a query expecting exactly one row.
1921    pub fn query_one(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Row> {
1922        self.conn.query_one(sql, params)
1923    }
1924
1925    /// Execute a query expecting zero or one row.
1926    pub fn query_opt(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<Option<Row>> {
1927        self.conn.query_opt(sql, params)
1928    }
1929
1930    /// Execute a statement that returns no rows.
1931    pub fn execute(&mut self, sql: &str, params: &[&dyn ToSql]) -> PgResult<u64> {
1932        self.conn.execute(sql, params)
1933    }
1934
1935    /// Create a savepoint within this transaction.
1936    pub fn savepoint(&mut self, name: &str) -> PgResult<()> {
1937        self.conn.savepoint(name)
1938    }
1939
1940    /// Rollback to a savepoint.
1941    pub fn rollback_to(&mut self, name: &str) -> PgResult<()> {
1942        self.conn.rollback_to(name)
1943    }
1944
1945    /// Release a savepoint.
1946    pub fn release_savepoint(&mut self, name: &str) -> PgResult<()> {
1947        self.conn.release_savepoint(name)
1948    }
1949
1950    /// Get the transaction status.
1951    pub fn status(&self) -> TransactionStatus {
1952        self.conn.transaction_status()
1953    }
1954}
1955
1956impl<'a> Drop for Transaction<'a> {
1957    fn drop(&mut self) {
1958        if !self.finished {
1959            // Auto-rollback on drop (savepoint if nested, full rollback otherwise)
1960            if let Some(ref name) = self.savepoint_name {
1961                let _ = self.conn.rollback_to(name);
1962            } else {
1963                let _ = self.conn.rollback();
1964            }
1965        }
1966    }
1967}
1968
1969// ─── COPY Writer ──────────────────────────────────────────────
1970
1971/// COPY writer for streaming data into PostgreSQL via COPY FROM STDIN.
1972pub struct CopyWriter<'a> {
1973    conn: &'a mut PgConnection,
1974}
1975
1976impl<'a> CopyWriter<'a> {
1977    /// Write a chunk of COPY data.
1978    pub fn write_data(&mut self, data: &[u8]) -> PgResult<()> {
1979        self.conn.ensure_write_capacity(5 + data.len());
1980        let n = codec::encode_copy_data(&mut self.conn.write_buf, data);
1981        self.conn.flush_write_buf(n)
1982    }
1983
1984    /// Abort the COPY operation with an error message.
1985    ///
1986    /// Sends a CopyFail message to the server. The server will respond
1987    /// with an ErrorResponse and then ReadyForQuery. The connection
1988    /// remains usable after this call.
1989    pub fn fail(self, reason: &str) -> PgResult<()> {
1990        self.conn.ensure_write_capacity(6 + reason.len());
1991        let n = codec::encode_copy_fail(&mut self.conn.write_buf, reason);
1992        self.conn.flush_write_buf(n)?;
1993
1994        // Drain to ReadyForQuery (server sends ErrorResponse first)
1995        loop {
1996            self.conn.fill_read_buf(None)?;
1997            while let Some(msg_len) =
1998                codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])?
1999            {
2000                let header = codec::decode_header(&self.conn.read_buf)
2001                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2002                match header.tag {
2003                    BackendTag::ErrorResponse => {
2004                        // Expected — server acknowledges the CopyFail
2005                        self.conn.consume_read(msg_len);
2006                    }
2007                    BackendTag::ReadyForQuery => {
2008                        let body = &self.conn.read_buf[5..msg_len];
2009                        self.conn.tx_status = TransactionStatus::from(body[0]);
2010                        self.conn.consume_read(msg_len);
2011                        return Ok(());
2012                    }
2013                    _ => {
2014                        self.conn.consume_read(msg_len);
2015                    }
2016                }
2017            }
2018        }
2019    }
2020
2021    /// Write a text row (tab-separated values with newline).
2022    pub fn write_row(&mut self, columns: &[&str]) -> PgResult<()> {
2023        let line = columns.join("\t") + "\n";
2024        self.write_data(line.as_bytes())
2025    }
2026
2027    /// Finish the COPY operation successfully.
2028    pub fn finish(self) -> PgResult<u64> {
2029        let n = codec::encode_copy_done(&mut self.conn.write_buf);
2030        self.conn.flush_write_buf(n)?;
2031
2032        // Drain to ReadyForQuery
2033        loop {
2034            self.conn.fill_read_buf(None)?;
2035            while let Some(msg_len) =
2036                codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])?
2037            {
2038                let header = codec::decode_header(&self.conn.read_buf)
2039                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2040                let body = &self.conn.read_buf[5..msg_len];
2041                match header.tag {
2042                    BackendTag::CommandComplete => {
2043                        let (tag, rows_affected) = extract_command_complete(body);
2044                        self.conn.last_command_tag = tag;
2045                        self.conn.last_affected_rows = rows_affected;
2046                    }
2047                    BackendTag::ReadyForQuery => {
2048                        self.conn.tx_status = TransactionStatus::from(body[0]);
2049                        self.conn.consume_read(msg_len);
2050                        return Ok(self.conn.last_affected_rows);
2051                    }
2052                    BackendTag::ErrorResponse => {
2053                        let err = self.conn.parse_error(body);
2054                        self.conn.consume_read(msg_len);
2055                        return Err(err);
2056                    }
2057                    _ => {}
2058                }
2059                self.conn.consume_read(msg_len);
2060            }
2061        }
2062    }
2063}
2064
2065// ─── COPY Reader ──────────────────────────────────────────────
2066
2067/// COPY reader for receiving data from PostgreSQL via COPY TO STDOUT.
2068pub struct CopyReader<'a> {
2069    conn: &'a mut PgConnection,
2070    done: bool,
2071}
2072
2073impl<'a> CopyReader<'a> {
2074    /// Read the next chunk of COPY data.
2075    /// Returns None when the COPY operation is complete.
2076    pub fn read_data(&mut self) -> PgResult<Option<Vec<u8>>> {
2077        if self.done {
2078            return Ok(None);
2079        }
2080
2081        loop {
2082            // Only refill the buffer when it doesn't already hold a complete
2083            // message.  If a previous fill read the entire COPY response in
2084            // one TCP segment, the CopyData / CopyDone / ReadyForQuery bytes
2085            // are already in `read_buf` and we must not block waiting for
2086            // more socket data before processing them.
2087            if codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])?.is_none() {
2088                self.conn.fill_read_buf(None)?;
2089            }
2090
2091            while let Some(msg_len) =
2092                codec::message_complete(&self.conn.read_buf[..self.conn.read_pos])?
2093            {
2094                let header = codec::decode_header(&self.conn.read_buf)
2095                    .ok_or_else(|| PgError::Protocol("Incomplete message header".to_string()))?;
2096                let body = &self.conn.read_buf[5..msg_len];
2097
2098                match header.tag {
2099                    BackendTag::CopyData => {
2100                        let data = body.to_vec();
2101                        self.conn.consume_read(msg_len);
2102                        return Ok(Some(data));
2103                    }
2104                    BackendTag::CopyDone => {
2105                        self.conn.consume_read(msg_len);
2106                        // Continue to receive CommandComplete + ReadyForQuery
2107                    }
2108                    BackendTag::CommandComplete => {
2109                        let (tag, rows_affected) = extract_command_complete(body);
2110                        self.conn.last_command_tag = tag;
2111                        self.conn.last_affected_rows = rows_affected;
2112                        self.conn.consume_read(msg_len);
2113                    }
2114                    BackendTag::ReadyForQuery => {
2115                        self.conn.tx_status = TransactionStatus::from(body[0]);
2116                        self.conn.consume_read(msg_len);
2117                        self.done = true;
2118                        return Ok(None);
2119                    }
2120                    BackendTag::ErrorResponse => {
2121                        let err = self.conn.parse_error(body);
2122                        self.conn.consume_read(msg_len);
2123                        self.done = true;
2124                        return Err(err);
2125                    }
2126                    _ => {
2127                        self.conn.consume_read(msg_len);
2128                    }
2129                }
2130            }
2131        }
2132    }
2133
2134    /// Read all remaining COPY data into a single Vec.
2135    pub fn read_all(&mut self) -> PgResult<Vec<u8>> {
2136        let mut result = Vec::new();
2137        while let Some(chunk) = self.read_data()? {
2138            result.extend_from_slice(&chunk);
2139        }
2140        Ok(result)
2141    }
2142
2143    /// Check if the COPY operation is complete.
2144    pub fn is_done(&self) -> bool {
2145        self.done
2146    }
2147}
2148
#[cfg(test)]
mod tests {
    // Unit tests for the network-free parts of this module: `PgConfig`
    // construction and URL parsing, and the `Notification` value type.
    // None of these tests opens a connection to a server.
    use super::*;

    // ─── PgConfig::new ────────────────────────────────────────────────────────

    // Every constructor argument lands in the matching field; no socket dir by default.
    #[test]
    fn test_pgconfig_new_fields() {
        let cfg = PgConfig::new("db.example.com", 5432, "alice", "s3cret", "mydb");
        assert_eq!(cfg.host, "db.example.com");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.user, "alice");
        assert_eq!(cfg.password, "s3cret");
        assert_eq!(cfg.database, "mydb");
        assert!(cfg.socket_dir.is_none());
    }

    #[test]
    fn test_pgconfig_new_custom_port() {
        let cfg = PgConfig::new("host", 9999, "u", "p", "d");
        assert_eq!(cfg.port, 9999);
    }

    #[test]
    fn test_pgconfig_with_socket_dir_sets_field() {
        let cfg =
            PgConfig::new("localhost", 5432, "u", "p", "d").with_socket_dir("/var/run/postgresql");
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
    }

    #[test]
    fn test_pgconfig_clone_preserves_all_fields() {
        let cfg = PgConfig::new("h", 1234, "u", "p", "db").with_socket_dir("/tmp");
        let cloned = cfg.clone();
        assert_eq!(cloned.host, "h");
        assert_eq!(cloned.port, 1234);
        assert_eq!(cloned.user, "u");
        assert_eq!(cloned.password, "p");
        assert_eq!(cloned.database, "db");
        assert_eq!(cloned.socket_dir, Some("/tmp".to_string()));
    }

    // NOTE(review): this only checks that the host is printed. Whether the
    // password is redacted depends on PgConfig's Debug impl (not visible here).
    #[test]
    fn test_pgconfig_debug_contains_host() {
        let cfg = PgConfig::new("myhost", 5432, "u", "p", "d");
        let s = format!("{:?}", cfg);
        assert!(s.contains("myhost"), "Debug must include host: {}", s);
    }

    // ─── PgConfig::from_url — happy paths ────────────────────────────────────

    #[test]
    fn test_from_url_basic_postgres_scheme() {
        let cfg = PgConfig::from_url("postgres://bob:hunter2@dbhost:5432/appdb").unwrap();
        assert_eq!(cfg.host, "dbhost");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.user, "bob");
        assert_eq!(cfg.password, "hunter2");
        assert_eq!(cfg.database, "appdb");
        assert!(cfg.socket_dir.is_none());
    }

    // The long-form `postgresql://` scheme is accepted as well as `postgres://`.
    #[test]
    fn test_from_url_postgresql_scheme() {
        let cfg = PgConfig::from_url("postgresql://u:p@host:5432/db").unwrap();
        assert_eq!(cfg.host, "host");
        assert_eq!(cfg.user, "u");
    }

    #[test]
    fn test_from_url_default_port() {
        // When no port is given, should default to 5432
        let cfg = PgConfig::from_url("postgres://u:p@myhost/mydb").unwrap();
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.host, "myhost");
    }

    #[test]
    fn test_from_url_no_password() {
        // user only, no colon → password is empty string
        let cfg = PgConfig::from_url("postgres://alice@host:5432/db").unwrap();
        assert_eq!(cfg.user, "alice");
        assert_eq!(cfg.password, "");
    }

    #[test]
    fn test_from_url_custom_port() {
        let cfg = PgConfig::from_url("postgres://u:p@host:9000/db").unwrap();
        assert_eq!(cfg.port, 9000);
    }

    // An empty host plus a `host=` query parameter selects a Unix socket directory.
    #[test]
    fn test_from_url_unix_socket_query_param() {
        let cfg = PgConfig::from_url("postgres://u:p@/db?host=/var/run/postgresql").unwrap();
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
        assert_eq!(cfg.database, "db");
    }

    // A percent-encoded absolute path in the host position is also treated as a
    // Unix socket directory.
    #[test]
    fn test_from_url_unix_socket_percent_encoded() {
        let cfg = PgConfig::from_url("postgres://u:p@%2Fvar%2Frun%2Fpostgresql/db").unwrap();
        assert_eq!(cfg.socket_dir.as_deref(), Some("/var/run/postgresql"));
        assert_eq!(cfg.database, "db");
    }

    // ─── PgConfig::from_url — error paths ────────────────────────────────────

    #[test]
    fn test_from_url_invalid_scheme_errors() {
        let result = PgConfig::from_url("mysql://u:p@host/db");
        assert!(result.is_err(), "Non-postgres scheme must fail");
    }

    #[test]
    fn test_from_url_missing_at_symbol_errors() {
        let result = PgConfig::from_url("postgres://no-at-sign/db");
        assert!(result.is_err(), "URL without @ must fail");
    }

    #[test]
    fn test_from_url_missing_database_errors() {
        // No "/" after host — no database segment
        let result = PgConfig::from_url("postgres://u:p@host");
        assert!(result.is_err(), "URL without database must fail");
    }

    #[test]
    fn test_from_url_invalid_port_errors() {
        let result = PgConfig::from_url("postgres://u:p@host:notaport/db");
        assert!(result.is_err(), "Non-numeric port must fail");
    }

    #[test]
    fn test_from_url_empty_string_errors() {
        let result = PgConfig::from_url("");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_url_special_chars_in_password() {
        // `%40` is a percent-encoded '@' inside the password.
        let cfg = PgConfig::from_url("postgres://user:p%40ss@host:5432/db");
        // Whether from_url decodes the password is deliberately not asserted
        // here; the test only checks that parsing such input does not panic.
        let _ = cfg; // result can be Ok or Err, but must not panic
    }

    // ─── Notification struct ──────────────────────────────────────────────────

    #[test]
    fn test_notification_fields() {
        let n = Notification {
            process_id: 12345,
            channel: "my_channel".to_string(),
            payload: "hello world".to_string(),
        };
        assert_eq!(n.process_id, 12345);
        assert_eq!(n.channel, "my_channel");
        assert_eq!(n.payload, "hello world");
    }

    #[test]
    fn test_notification_clone() {
        let n = Notification {
            process_id: 42,
            channel: "ch".to_string(),
            payload: "pay".to_string(),
        };
        let n2 = n.clone();
        assert_eq!(n2.process_id, n.process_id);
        assert_eq!(n2.channel, n.channel);
        assert_eq!(n2.payload, n.payload);
    }

    #[test]
    fn test_notification_debug() {
        let n = Notification {
            process_id: 1,
            channel: "c".to_string(),
            payload: "p".to_string(),
        };
        let s = format!("{:?}", n);
        assert!(
            s.contains("process_id"),
            "Debug must include process_id: {}",
            s
        );
    }
}