Skip to main content

qail_pg/driver/
io.rs

1//! Core I/O operations for PostgreSQL connection.
2//!
3//! This module provides low-level send/receive methods.
4
5use super::{PgConnection, PgError, PgResult, is_ignorable_session_message};
6use crate::protocol::{BackendMessage, FrontendMessage, PgEncoder};
7use bytes::BytesMut;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9
/// Upper bound on the length a backend frame header may declare (64 MiB).
///
/// Frames declaring more are rejected as soon as their 5-byte header is
/// buffered, before waiting for the body, so a malicious or broken server
/// cannot make the driver buffer an arbitrarily large message.
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;

/// Default timeout applied to each individual socket read.
///
/// Stops a Slowloris-style peer (partial data, then silence) from hanging a read forever.
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
/// Default timeout applied to each individual socket write or flush.
///
/// Keeps an indefinitely blocked write from pinning a pool slot.
const DEFAULT_WRITE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
18
19#[inline]
20fn parse_data_row_payload_owned(payload: &[u8]) -> PgResult<Vec<Option<Vec<u8>>>> {
21    if payload.len() < 2 {
22        return Err(PgError::Protocol("DataRow payload too short".into()));
23    }
24
25    let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
26    if raw_count < 0 {
27        return Err(PgError::Protocol(format!(
28            "DataRow invalid column count: {}",
29            raw_count
30        )));
31    }
32    let column_count = raw_count as usize;
33    if column_count > (payload.len() - 2) / 4 + 1 {
34        return Err(PgError::Protocol(format!(
35            "DataRow claims {} columns but payload is only {} bytes",
36            column_count,
37            payload.len()
38        )));
39    }
40
41    let mut columns = Vec::with_capacity(column_count);
42    let mut pos = 2;
43    for _ in 0..column_count {
44        if pos + 4 > payload.len() {
45            return Err(PgError::Protocol(
46                "DataRow truncated: missing column length".into(),
47            ));
48        }
49
50        let len = i32::from_be_bytes([
51            payload[pos],
52            payload[pos + 1],
53            payload[pos + 2],
54            payload[pos + 3],
55        ]);
56        pos += 4;
57
58        if len == -1 {
59            columns.push(None);
60            continue;
61        }
62        if len < -1 {
63            return Err(PgError::Protocol(format!(
64                "DataRow invalid column length: {}",
65                len
66            )));
67        }
68
69        let len = len as usize;
70        if len > payload.len().saturating_sub(pos) {
71            return Err(PgError::Protocol(
72                "DataRow truncated: column data exceeds payload".into(),
73            ));
74        }
75        columns.push(Some(payload[pos..pos + len].to_vec()));
76        pos += len;
77    }
78
79    if pos != payload.len() {
80        return Err(PgError::Protocol("DataRow has trailing bytes".into()));
81    }
82
83    Ok(columns)
84}
85
86impl PgConnection {
    /// Flag this connection's I/O stream as desynchronized.
    ///
    /// Called by the I/O paths in this module after read/write/flush failures,
    /// timeouts, EOF and protocol violations. Nothing in this method clears the flag.
    /// NOTE(review): presumably pool/caller code consults `is_io_desynced` to
    /// discard the connection — confirm at the call sites.
    #[inline]
    pub(crate) fn mark_io_desynced(&mut self) {
        self.io_desynced = true;
    }
91
    /// Report whether this connection has been flagged as I/O-desynchronized.
    ///
    /// Returns `true` once `mark_io_desynced` has run. That happens directly, or via
    /// `protocol_desync` / `connection_desync` and the failing I/O paths in this module.
    #[inline]
    pub(crate) fn is_io_desynced(&self) -> bool {
        self.io_desynced
    }
96
    /// Mark the connection desynced and return `Err(PgError::Protocol(msg))`.
    ///
    /// Generic over `T` so any method returning `PgResult<T>` can write
    /// `return self.protocol_desync(...)`.
    #[inline]
    fn protocol_desync<T>(&mut self, msg: String) -> PgResult<T> {
        self.mark_io_desynced();
        Err(PgError::Protocol(msg))
    }
102
    /// Mark the connection desynced and return `Err(PgError::Connection(msg))`.
    ///
    /// Generic over `T` so any method returning `PgResult<T>` can write
    /// `return self.connection_desync(...)`. In this module it is used when a read returns EOF.
    #[inline]
    fn connection_desync<T>(&mut self, msg: String) -> PgResult<T> {
        self.mark_io_desynced();
        Err(PgError::Connection(msg))
    }
108
109    /// Send queued statement `Close` messages and drain until `ReadyForQuery`.
110    ///
111    /// We ignore `26000 prepared statement ... does not exist` because this
112    /// can happen after failover or server-side invalidation, and in that case
113    /// local state is already being reconciled by retry paths.
114    async fn flush_pending_statement_closes(&mut self) -> PgResult<()> {
115        if self.draining_statement_closes || self.pending_statement_closes.is_empty() {
116            return Ok(());
117        }
118
119        self.draining_statement_closes = true;
120        let close_names = std::mem::take(&mut self.pending_statement_closes);
121
122        let estimated_payload_len: usize = close_names
123            .iter()
124            .map(|name| 16usize.saturating_add(name.len()))
125            .sum();
126        let mut buf = BytesMut::with_capacity(estimated_payload_len.saturating_add(5));
127        for stmt_name in &close_names {
128            let close_msg = PgEncoder::try_encode_close(false, stmt_name)
129                .map_err(|e| PgError::Encode(e.to_string()))?;
130            buf.extend_from_slice(&close_msg);
131        }
132        PgEncoder::encode_sync_to(&mut buf);
133
134        if let Err(err) = self
135            .write_all_with_timeout_inner(&buf, "pending statement close write")
136            .await
137        {
138            self.draining_statement_closes = false;
139            return Err(err);
140        }
141        if let Err(err) = self
142            .flush_with_timeout("pending statement close flush")
143            .await
144        {
145            self.draining_statement_closes = false;
146            return Err(err);
147        }
148
149        let mut error: Option<PgError> = None;
150        loop {
151            let msg = match self.recv().await {
152                Ok(msg) => msg,
153                Err(err) => {
154                    self.draining_statement_closes = false;
155                    return Err(err);
156                }
157            };
158            match msg {
159                BackendMessage::CloseComplete => {}
160                BackendMessage::ReadyForQuery(_) => {
161                    self.draining_statement_closes = false;
162                    if let Some(err) = error {
163                        return Err(err);
164                    }
165                    return Ok(());
166                }
167                BackendMessage::ErrorResponse(err_fields) => {
168                    if error.is_none() {
169                        let code_26000 = err_fields.code.eq_ignore_ascii_case("26000");
170                        let msg_lower = err_fields.message.to_ascii_lowercase();
171                        let missing_prepared = msg_lower.contains("prepared statement")
172                            && msg_lower.contains("does not exist");
173                        if !(code_26000 && missing_prepared) {
174                            error = Some(PgError::QueryServer(err_fields.into()));
175                        }
176                    }
177                }
178                msg if is_ignorable_session_message(&msg) => {}
179                other => {
180                    self.draining_statement_closes = false;
181                    return self.protocol_desync(format!(
182                        "Unexpected backend message during pending statement close drain: {:?}",
183                        other
184                    ));
185                }
186            }
187        }
188    }
189
190    /// Write all bytes with a timeout guard.
191    ///
192    /// Prevents stuck kernel send buffers or dead sockets from hanging forever.
193    pub(crate) async fn write_all_with_timeout(
194        &mut self,
195        bytes: &[u8],
196        operation: &str,
197    ) -> PgResult<()> {
198        if !self.draining_statement_closes && !self.pending_statement_closes.is_empty() {
199            self.flush_pending_statement_closes().await?;
200        }
201        self.write_all_with_timeout_inner(bytes, operation).await
202    }
203
204    async fn write_all_with_timeout_inner(
205        &mut self,
206        bytes: &[u8],
207        operation: &str,
208    ) -> PgResult<()> {
209        if bytes.is_empty() {
210            return Err(PgError::Encode(
211                "refusing to send empty frontend payload".to_string(),
212            ));
213        }
214        use super::stream::PgStream;
215        let mut mark_desync = false;
216        let result = match &mut self.stream {
217            PgStream::Tcp(stream) => {
218                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
219                    Ok(Ok(())) => Ok(()),
220                    Ok(Err(e)) => {
221                        mark_desync = true;
222                        Err(PgError::Connection(format!("Write error: {}", e)))
223                    }
224                    Err(_) => {
225                        mark_desync = true;
226                        Err(PgError::Timeout(format!(
227                            "{} timeout after {:?}",
228                            operation, DEFAULT_WRITE_TIMEOUT
229                        )))
230                    }
231                }
232            }
233            PgStream::Tls(stream) => {
234                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
235                    Ok(Ok(())) => Ok(()),
236                    Ok(Err(e)) => {
237                        mark_desync = true;
238                        Err(PgError::Connection(format!("Write error: {}", e)))
239                    }
240                    Err(_) => {
241                        mark_desync = true;
242                        Err(PgError::Timeout(format!(
243                            "{} timeout after {:?}",
244                            operation, DEFAULT_WRITE_TIMEOUT
245                        )))
246                    }
247                }
248            }
249            #[cfg(all(target_os = "linux", feature = "io_uring"))]
250            PgStream::Uring(stream) => {
251                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
252                    Ok(Ok(())) => Ok(()),
253                    Ok(Err(e)) => {
254                        mark_desync = true;
255                        Err(PgError::Connection(format!("Write error: {}", e)))
256                    }
257                    Err(_) => {
258                        mark_desync = true;
259                        let _ = stream.abort_inflight();
260                        Err(PgError::Timeout(format!(
261                            "{} timeout after {:?}",
262                            operation, DEFAULT_WRITE_TIMEOUT
263                        )))
264                    }
265                }
266            }
267            #[cfg(unix)]
268            PgStream::Unix(stream) => {
269                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
270                    Ok(Ok(())) => Ok(()),
271                    Ok(Err(e)) => {
272                        mark_desync = true;
273                        Err(PgError::Connection(format!("Write error: {}", e)))
274                    }
275                    Err(_) => {
276                        mark_desync = true;
277                        Err(PgError::Timeout(format!(
278                            "{} timeout after {:?}",
279                            operation, DEFAULT_WRITE_TIMEOUT
280                        )))
281                    }
282                }
283            }
284            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
285            PgStream::GssEnc(stream) => {
286                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
287                    Ok(Ok(())) => Ok(()),
288                    Ok(Err(e)) => {
289                        mark_desync = true;
290                        Err(PgError::Connection(format!("Write error: {}", e)))
291                    }
292                    Err(_) => {
293                        mark_desync = true;
294                        Err(PgError::Timeout(format!(
295                            "{} timeout after {:?}",
296                            operation, DEFAULT_WRITE_TIMEOUT
297                        )))
298                    }
299                }
300            }
301        };
302        if mark_desync {
303            self.mark_io_desynced();
304        }
305        result
306    }
307
308    /// Flush with a timeout guard.
309    pub(crate) async fn flush_with_timeout(&mut self, operation: &str) -> PgResult<()> {
310        use super::stream::PgStream;
311        let mut mark_desync = false;
312        let result = match &mut self.stream {
313            PgStream::Tcp(stream) => {
314                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
315                    Ok(Ok(())) => Ok(()),
316                    Ok(Err(e)) => {
317                        mark_desync = true;
318                        Err(PgError::Connection(format!("Flush error: {}", e)))
319                    }
320                    Err(_) => {
321                        mark_desync = true;
322                        Err(PgError::Timeout(format!(
323                            "{} timeout after {:?}",
324                            operation, DEFAULT_WRITE_TIMEOUT
325                        )))
326                    }
327                }
328            }
329            PgStream::Tls(stream) => {
330                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
331                    Ok(Ok(())) => Ok(()),
332                    Ok(Err(e)) => {
333                        mark_desync = true;
334                        Err(PgError::Connection(format!("Flush error: {}", e)))
335                    }
336                    Err(_) => {
337                        mark_desync = true;
338                        Err(PgError::Timeout(format!(
339                            "{} timeout after {:?}",
340                            operation, DEFAULT_WRITE_TIMEOUT
341                        )))
342                    }
343                }
344            }
345            #[cfg(all(target_os = "linux", feature = "io_uring"))]
346            PgStream::Uring(stream) => {
347                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
348                    Ok(Ok(())) => Ok(()),
349                    Ok(Err(e)) => {
350                        mark_desync = true;
351                        Err(PgError::Connection(format!("Flush error: {}", e)))
352                    }
353                    Err(_) => {
354                        mark_desync = true;
355                        let _ = stream.abort_inflight();
356                        Err(PgError::Timeout(format!(
357                            "{} timeout after {:?}",
358                            operation, DEFAULT_WRITE_TIMEOUT
359                        )))
360                    }
361                }
362            }
363            #[cfg(unix)]
364            PgStream::Unix(stream) => {
365                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
366                    Ok(Ok(())) => Ok(()),
367                    Ok(Err(e)) => {
368                        mark_desync = true;
369                        Err(PgError::Connection(format!("Flush error: {}", e)))
370                    }
371                    Err(_) => {
372                        mark_desync = true;
373                        Err(PgError::Timeout(format!(
374                            "{} timeout after {:?}",
375                            operation, DEFAULT_WRITE_TIMEOUT
376                        )))
377                    }
378                }
379            }
380            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
381            PgStream::GssEnc(stream) => {
382                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
383                    Ok(Ok(())) => Ok(()),
384                    Ok(Err(e)) => {
385                        mark_desync = true;
386                        Err(PgError::Connection(format!("Flush error: {}", e)))
387                    }
388                    Err(_) => {
389                        mark_desync = true;
390                        Err(PgError::Timeout(format!(
391                            "{} timeout after {:?}",
392                            operation, DEFAULT_WRITE_TIMEOUT
393                        )))
394                    }
395                }
396            }
397        };
398        if mark_desync {
399            self.mark_io_desynced();
400        }
401        result
402    }
403
404    /// Send a frontend message.
405    pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
406        let bytes = msg
407            .encode_checked()
408            .map_err(|e| PgError::Encode(e.to_string()))?;
409        self.write_all_with_timeout(&bytes, "send frontend message")
410            .await?;
411        Ok(())
412    }
413
414    /// Loops until a complete message is available.
415    /// Automatically buffers NotificationResponse messages for LISTEN/NOTIFY.
416    pub async fn recv(&mut self) -> PgResult<BackendMessage> {
417        loop {
418            // Try to decode from buffer first
419            if self.buffer.len() >= 5 {
420                let msg_len = u32::from_be_bytes([
421                    self.buffer[1],
422                    self.buffer[2],
423                    self.buffer[3],
424                    self.buffer[4],
425                ]) as usize;
426
427                if msg_len < 4 {
428                    return self.protocol_desync(format!(
429                        "Invalid message length: {} (minimum 4)",
430                        msg_len
431                    ));
432                }
433
434                if msg_len > MAX_MESSAGE_SIZE {
435                    return self.protocol_desync(format!(
436                        "Message too large: {} bytes (max {})",
437                        msg_len, MAX_MESSAGE_SIZE
438                    ));
439                }
440
441                if self.buffer.len() > msg_len {
442                    // We have a complete message - zero-copy split
443                    let msg_bytes = self.buffer.split_to(msg_len + 1);
444                    let (msg, _) = match BackendMessage::decode(&msg_bytes) {
445                        Ok(decoded) => decoded,
446                        Err(e) => return self.protocol_desync(e),
447                    };
448
449                    // Intercept async notifications — buffer them instead of returning
450                    if let BackendMessage::NotificationResponse {
451                        process_id,
452                        channel,
453                        payload,
454                    } = msg
455                    {
456                        self.notifications
457                            .push_back(super::notification::Notification {
458                                process_id,
459                                channel,
460                                payload,
461                            });
462                        continue; // Keep reading for the actual response
463                    }
464
465                    return Ok(msg);
466                }
467            }
468
469            let n = self.read_with_timeout().await?;
470            if n == 0 {
471                return self.connection_desync("Connection closed".to_string());
472            }
473        }
474    }
475
476    /// Receive a backend message with idle-friendly timeout behavior.
477    ///
478    /// For long-lived idle streams (e.g. logical replication), an empty
479    /// buffer uses no-timeout reads so inactivity does not fail the stream.
480    /// If a backend frame is already partially buffered, switch back to the
481    /// normal read timeout to fail-closed on partial-frame stalls.
482    pub(crate) async fn recv_without_timeout(&mut self) -> PgResult<BackendMessage> {
483        loop {
484            if self.buffer.len() >= 5 {
485                let msg_len = u32::from_be_bytes([
486                    self.buffer[1],
487                    self.buffer[2],
488                    self.buffer[3],
489                    self.buffer[4],
490                ]) as usize;
491
492                if msg_len < 4 {
493                    return self.protocol_desync(format!(
494                        "Invalid message length: {} (minimum 4)",
495                        msg_len
496                    ));
497                }
498
499                if msg_len > MAX_MESSAGE_SIZE {
500                    return self.protocol_desync(format!(
501                        "Message too large: {} bytes (max {})",
502                        msg_len, MAX_MESSAGE_SIZE
503                    ));
504                }
505
506                if self.buffer.len() > msg_len {
507                    let msg_bytes = self.buffer.split_to(msg_len + 1);
508                    let (msg, _) = match BackendMessage::decode(&msg_bytes) {
509                        Ok(decoded) => decoded,
510                        Err(e) => return self.protocol_desync(e),
511                    };
512
513                    if let BackendMessage::NotificationResponse {
514                        process_id,
515                        channel,
516                        payload,
517                    } = msg
518                    {
519                        self.notifications
520                            .push_back(super::notification::Notification {
521                                process_id,
522                                channel,
523                                payload,
524                            });
525                        continue;
526                    }
527
528                    return Ok(msg);
529                }
530            }
531
532            let n = if self.buffer.is_empty() {
533                self.read_without_timeout().await?
534            } else {
535                self.read_with_timeout().await?
536            };
537            if n == 0 {
538                return self.connection_desync("Connection closed".to_string());
539            }
540        }
541    }
542
543    /// Read from the socket with a timeout guard.
544    /// Returns the number of bytes read, or an error if the timeout fires.
545    /// This prevents Slowloris DoS attacks where a malicious server sends
546    /// partial data then goes silent, causing the driver to hang forever.
547    #[inline]
548    pub(crate) async fn read_with_timeout(&mut self) -> PgResult<usize> {
549        if self.buffer.capacity() - self.buffer.len() < 65536 {
550            self.buffer.reserve(131072);
551        }
552
553        use super::stream::PgStream;
554        let (stream, buffer) = (&mut self.stream, &mut self.buffer);
555        let mut mark_desync = false;
556        let result = match stream {
557            PgStream::Tcp(stream) => {
558                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
559                    Ok(Ok(n)) => Ok(n),
560                    Ok(Err(e)) => {
561                        mark_desync = true;
562                        Err(PgError::Connection(format!("Read error: {}", e)))
563                    }
564                    Err(_) => {
565                        mark_desync = true;
566                        Err(PgError::Connection(format!(
567                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
568                            DEFAULT_READ_TIMEOUT
569                        )))
570                    }
571                }
572            }
573            PgStream::Tls(stream) => {
574                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
575                    Ok(Ok(n)) => Ok(n),
576                    Ok(Err(e)) => {
577                        mark_desync = true;
578                        Err(PgError::Connection(format!("Read error: {}", e)))
579                    }
580                    Err(_) => {
581                        mark_desync = true;
582                        Err(PgError::Connection(format!(
583                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
584                            DEFAULT_READ_TIMEOUT
585                        )))
586                    }
587                }
588            }
589            #[cfg(all(target_os = "linux", feature = "io_uring"))]
590            PgStream::Uring(stream) => {
591                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_into(buffer, 131072))
592                    .await
593                {
594                    Ok(Ok(n)) => Ok(n),
595                    Ok(Err(e)) => {
596                        mark_desync = true;
597                        Err(PgError::Connection(format!("Read error: {}", e)))
598                    }
599                    Err(_) => {
600                        mark_desync = true;
601                        let _ = stream.abort_inflight();
602                        Err(PgError::Connection(format!(
603                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
604                            DEFAULT_READ_TIMEOUT
605                        )))
606                    }
607                }
608            }
609            #[cfg(unix)]
610            PgStream::Unix(stream) => {
611                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
612                    Ok(Ok(n)) => Ok(n),
613                    Ok(Err(e)) => {
614                        mark_desync = true;
615                        Err(PgError::Connection(format!("Read error: {}", e)))
616                    }
617                    Err(_) => {
618                        mark_desync = true;
619                        Err(PgError::Connection(format!(
620                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
621                            DEFAULT_READ_TIMEOUT
622                        )))
623                    }
624                }
625            }
626            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
627            PgStream::GssEnc(stream) => {
628                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
629                    Ok(Ok(n)) => Ok(n),
630                    Ok(Err(e)) => {
631                        mark_desync = true;
632                        Err(PgError::Connection(format!("Read error: {}", e)))
633                    }
634                    Err(_) => {
635                        mark_desync = true;
636                        Err(PgError::Connection(format!(
637                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
638                            DEFAULT_READ_TIMEOUT
639                        )))
640                    }
641                }
642            }
643        };
644        if mark_desync {
645            self.mark_io_desynced();
646        }
647        result
648    }
649
650    /// Read from socket without timeout guard.
651    ///
652    /// Used for long-idle LISTEN/NOTIFY connections.
653    pub(crate) async fn read_without_timeout(&mut self) -> PgResult<usize> {
654        if self.buffer.capacity() - self.buffer.len() < 65536 {
655            self.buffer.reserve(131072);
656        }
657
658        use super::stream::PgStream;
659        let (stream, buffer) = (&mut self.stream, &mut self.buffer);
660        let read_result = match stream {
661            PgStream::Tcp(stream) => stream.read_buf(buffer).await,
662            PgStream::Tls(stream) => stream.read_buf(buffer).await,
663            #[cfg(all(target_os = "linux", feature = "io_uring"))]
664            PgStream::Uring(stream) => stream.read_into(buffer, 131072).await,
665            #[cfg(unix)]
666            PgStream::Unix(stream) => stream.read_buf(buffer).await,
667            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
668            PgStream::GssEnc(stream) => stream.read_buf(buffer).await,
669        };
670
671        match read_result {
672            Ok(n) => Ok(n),
673            Err(e) => {
674                self.mark_io_desynced();
675                Err(PgError::Connection(format!("Read error: {}", e)))
676            }
677        }
678    }
679
680    /// Send raw bytes to the stream.
681    /// Includes flush for TLS safety — TLS buffers internally and
682    /// needs flush to push encrypted data to the underlying TCP socket.
683    pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
684        self.write_all_with_timeout(bytes, "send raw bytes").await?;
685        self.flush_with_timeout("flush raw bytes").await?;
686        Ok(())
687    }
688
689    // ==================== BUFFERED WRITE API (High Performance) ====================
690
    /// Buffer bytes for later flush (NO SYSCALL).
    ///
    /// Appends `bytes` to `self.write_buf` without touching the socket. Nothing
    /// is sent until `flush_write_buf()` writes and flushes the whole buffer.
    #[inline]
    pub fn buffer_bytes(&mut self, bytes: &[u8]) {
        self.write_buf.extend_from_slice(bytes);
    }
697
698    /// Flush the write buffer to the stream (single write_all + flush).
699    /// The flush is critical for TLS connections.
700    pub async fn flush_write_buf(&mut self) -> PgResult<()> {
701        if !self.write_buf.is_empty() {
702            let payload = std::mem::take(&mut self.write_buf);
703            self.write_all_with_timeout(&payload, "flush write buffer")
704                .await?;
705            self.flush_with_timeout("flush write buffer").await?;
706        }
707        Ok(())
708    }
709
710    /// FAST receive - returns only message type byte, skips parsing.
711    /// This is ~10x faster than recv() for pipelining benchmarks.
712    /// Returns: message_type
713    #[inline]
714    pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
715        loop {
716            if self.buffer.len() >= 5 {
717                let msg_len = u32::from_be_bytes([
718                    self.buffer[1],
719                    self.buffer[2],
720                    self.buffer[3],
721                    self.buffer[4],
722                ]) as usize;
723
724                if msg_len < 4 {
725                    return self.protocol_desync(format!(
726                        "Invalid message length: {} (minimum 4)",
727                        msg_len
728                    ));
729                }
730
731                if msg_len > MAX_MESSAGE_SIZE {
732                    return self.protocol_desync(format!(
733                        "Message too large: {} bytes (max {})",
734                        msg_len, MAX_MESSAGE_SIZE
735                    ));
736                }
737
738                if self.buffer.len() > msg_len {
739                    let msg_type = self.buffer[0];
740
741                    if msg_type == b'E' || msg_type == b'A' {
742                        let msg_bytes = self.buffer.split_to(msg_len + 1);
743                        let (msg, _) = match BackendMessage::decode(&msg_bytes) {
744                            Ok(decoded) => decoded,
745                            Err(e) => return self.protocol_desync(e),
746                        };
747                        match msg {
748                            BackendMessage::ErrorResponse(err) => {
749                                return Err(PgError::QueryServer(err.into()));
750                            }
751                            BackendMessage::NotificationResponse {
752                                process_id,
753                                channel,
754                                payload,
755                            } => {
756                                self.notifications
757                                    .push_back(super::notification::Notification {
758                                        process_id,
759                                        channel,
760                                        payload,
761                                    });
762                                continue;
763                            }
764                            _ => {
765                                return Err(PgError::Protocol(
766                                    "Unexpected fast-path message".into(),
767                                ));
768                            }
769                        }
770                    }
771
772                    let _ = self.buffer.split_to(msg_len + 1);
773                    return Ok(msg_type);
774                }
775            }
776
777            let n = self.read_with_timeout().await?;
778            if n == 0 {
779                return self.connection_desync("Connection closed".to_string());
780            }
781        }
782    }
783
784    /// FAST receive for result consumption - inline DataRow parsing.
785    /// Returns: (msg_type, Option<row_data>)
786    /// For 'D' (DataRow): returns parsed columns
787    /// For other types: returns None
788    /// This avoids BackendMessage enum allocation for non-DataRow messages.
789    #[inline]
790    pub(crate) async fn recv_with_data_fast(
791        &mut self,
792    ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
793        loop {
794            if self.buffer.len() >= 5 {
795                let msg_len = u32::from_be_bytes([
796                    self.buffer[1],
797                    self.buffer[2],
798                    self.buffer[3],
799                    self.buffer[4],
800                ]) as usize;
801
802                if msg_len < 4 {
803                    return self.protocol_desync(format!(
804                        "Invalid message length: {} (minimum 4)",
805                        msg_len
806                    ));
807                }
808
809                if msg_len > MAX_MESSAGE_SIZE {
810                    return self.protocol_desync(format!(
811                        "Message too large: {} bytes (max {})",
812                        msg_len, MAX_MESSAGE_SIZE
813                    ));
814                }
815
816                if self.buffer.len() > msg_len {
817                    let msg_type = self.buffer[0];
818
819                    if msg_type == b'E' || msg_type == b'A' {
820                        let msg_bytes = self.buffer.split_to(msg_len + 1);
821                        let (msg, _) = match BackendMessage::decode(&msg_bytes) {
822                            Ok(decoded) => decoded,
823                            Err(e) => return self.protocol_desync(e),
824                        };
825                        match msg {
826                            BackendMessage::ErrorResponse(err) => {
827                                return Err(PgError::QueryServer(err.into()));
828                            }
829                            BackendMessage::NotificationResponse {
830                                process_id,
831                                channel,
832                                payload,
833                            } => {
834                                self.notifications
835                                    .push_back(super::notification::Notification {
836                                        process_id,
837                                        channel,
838                                        payload,
839                                    });
840                                continue;
841                            }
842                            _ => {
843                                return Err(PgError::Protocol(
844                                    "Unexpected fast-path message".into(),
845                                ));
846                            }
847                        }
848                    }
849
850                    // Fast path: DataRow - parse inline
851                    if msg_type == b'D' {
852                        let parse_result = {
853                            let payload = &self.buffer[5..msg_len + 1];
854                            parse_data_row_payload_owned(payload)
855                        };
856
857                        let _ = self.buffer.split_to(msg_len + 1);
858                        match parse_result {
859                            Ok(columns) => return Ok((msg_type, Some(columns))),
860                            Err(err) => return Err(err),
861                        }
862                    }
863
864                    // Other messages - skip
865                    let _ = self.buffer.split_to(msg_len + 1);
866                    return Ok((msg_type, None));
867                }
868            }
869
870            let n = self.read_with_timeout().await?;
871            if n == 0 {
872                return self.connection_desync("Connection closed".to_string());
873            }
874        }
875    }
876
877    /// ZERO-COPY receive for DataRow.
878    /// Uses bytes::Bytes for reference-counted slicing instead of Vec copy.
879    /// Returns: (msg_type, Option<row_data>)
880    /// For 'D' (DataRow): returns Bytes slices (no copy!)
881    /// For other types: returns None
882    #[inline]
883    pub(crate) async fn recv_data_zerocopy(
884        &mut self,
885    ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
886        use bytes::Buf;
887
888        loop {
889            if self.buffer.len() >= 5 {
890                let msg_len = u32::from_be_bytes([
891                    self.buffer[1],
892                    self.buffer[2],
893                    self.buffer[3],
894                    self.buffer[4],
895                ]) as usize;
896
897                if msg_len < 4 {
898                    return self.protocol_desync(format!(
899                        "Invalid message length: {} (minimum 4)",
900                        msg_len
901                    ));
902                }
903
904                if msg_len > MAX_MESSAGE_SIZE {
905                    return self.protocol_desync(format!(
906                        "Message too large: {} bytes (max {})",
907                        msg_len, MAX_MESSAGE_SIZE
908                    ));
909                }
910
911                if self.buffer.len() > msg_len {
912                    let msg_type = self.buffer[0];
913
914                    if msg_type == b'E' || msg_type == b'A' {
915                        let msg_bytes = self.buffer.split_to(msg_len + 1);
916                        let (msg, _) = match BackendMessage::decode(&msg_bytes) {
917                            Ok(decoded) => decoded,
918                            Err(e) => return self.protocol_desync(e),
919                        };
920                        match msg {
921                            BackendMessage::ErrorResponse(err) => {
922                                return Err(PgError::QueryServer(err.into()));
923                            }
924                            BackendMessage::NotificationResponse {
925                                process_id,
926                                channel,
927                                payload,
928                            } => {
929                                self.notifications
930                                    .push_back(super::notification::Notification {
931                                        process_id,
932                                        channel,
933                                        payload,
934                                    });
935                                continue;
936                            }
937                            _ => {
938                                return Err(PgError::Protocol(
939                                    "Unexpected fast-path message".into(),
940                                ));
941                            }
942                        }
943                    }
944
945                    // Fast path: DataRow - ZERO-COPY using Bytes
946                    if msg_type == b'D' {
947                        // Split off the entire message
948                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
949
950                        // Skip type byte (1) + length (4) = 5 bytes
951                        msg_bytes.advance(5);
952
953                        if msg_bytes.len() >= 2 {
954                            let raw_count = msg_bytes.get_i16();
955                            if raw_count < 0 {
956                                return Err(PgError::Protocol(format!(
957                                    "DataRow invalid column count: {}",
958                                    raw_count
959                                )));
960                            }
961                            let column_count = raw_count as usize;
962                            if column_count > msg_bytes.remaining() / 4 + 1 {
963                                return Err(PgError::Protocol(format!(
964                                    "DataRow claims {} columns but payload is only {} bytes",
965                                    column_count,
966                                    msg_bytes.remaining() + 2
967                                )));
968                            }
969                            let mut columns = Vec::with_capacity(column_count);
970
971                            for _ in 0..column_count {
972                                if msg_bytes.remaining() < 4 {
973                                    return Err(PgError::Protocol(
974                                        "DataRow truncated: missing column length".into(),
975                                    ));
976                                }
977
978                                let len = msg_bytes.get_i32();
979
980                                if len == -1 {
981                                    columns.push(None);
982                                } else {
983                                    if len < -1 {
984                                        return Err(PgError::Protocol(format!(
985                                            "DataRow invalid column length: {}",
986                                            len
987                                        )));
988                                    }
989                                    let len = len as usize;
990                                    if msg_bytes.remaining() < len {
991                                        return Err(PgError::Protocol(
992                                            "DataRow truncated: column data exceeds payload".into(),
993                                        ));
994                                    }
995                                    let col_data = msg_bytes.split_to(len).freeze();
996                                    columns.push(Some(col_data));
997                                }
998                            }
999
1000                            if msg_bytes.remaining() != 0 {
1001                                return Err(PgError::Protocol("DataRow has trailing bytes".into()));
1002                            }
1003
1004                            return Ok((msg_type, Some(columns)));
1005                        }
1006                        return Ok((msg_type, None));
1007                    }
1008
1009                    // Other messages - skip
1010                    let _ = self.buffer.split_to(msg_len + 1);
1011                    return Ok((msg_type, None));
1012                }
1013            }
1014
1015            let n = self.read_with_timeout().await?;
1016            if n == 0 {
1017                return self.connection_desync("Connection closed".to_string());
1018            }
1019        }
1020    }
1021
1022    /// ULTRA-FAST receive for 2-column DataRow (id, name pattern).
1023    /// Uses fixed-size array instead of Vec allocation.
1024    /// Returns: (msg_type, Option<(col0, col1)>)
1025    #[inline(always)]
1026    pub(crate) async fn recv_data_ultra(
1027        &mut self,
1028    ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
1029        use bytes::Buf;
1030
1031        loop {
1032            if self.buffer.len() >= 5 {
1033                let msg_len = u32::from_be_bytes([
1034                    self.buffer[1],
1035                    self.buffer[2],
1036                    self.buffer[3],
1037                    self.buffer[4],
1038                ]) as usize;
1039
1040                if msg_len < 4 {
1041                    return self.protocol_desync(format!(
1042                        "Invalid message length: {} (minimum 4)",
1043                        msg_len
1044                    ));
1045                }
1046
1047                if msg_len > MAX_MESSAGE_SIZE {
1048                    return self.protocol_desync(format!(
1049                        "Message too large: {} bytes (max {})",
1050                        msg_len, MAX_MESSAGE_SIZE
1051                    ));
1052                }
1053
1054                if self.buffer.len() > msg_len {
1055                    let msg_type = self.buffer[0];
1056
1057                    // Error and async-notify checks
1058                    if msg_type == b'E' || msg_type == b'A' {
1059                        let msg_bytes = self.buffer.split_to(msg_len + 1);
1060                        let (msg, _) = match BackendMessage::decode(&msg_bytes) {
1061                            Ok(decoded) => decoded,
1062                            Err(e) => return self.protocol_desync(e),
1063                        };
1064                        match msg {
1065                            BackendMessage::ErrorResponse(err) => {
1066                                return Err(PgError::QueryServer(err.into()));
1067                            }
1068                            BackendMessage::NotificationResponse {
1069                                process_id,
1070                                channel,
1071                                payload,
1072                            } => {
1073                                self.notifications
1074                                    .push_back(super::notification::Notification {
1075                                        process_id,
1076                                        channel,
1077                                        payload,
1078                                    });
1079                                continue;
1080                            }
1081                            _ => {
1082                                return Err(PgError::Protocol(
1083                                    "Unexpected fast-path message".into(),
1084                                ));
1085                            }
1086                        }
1087                    }
1088
1089                    if msg_type == b'D' {
1090                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
1091                        msg_bytes.advance(5); // Skip type + length
1092
1093                        // Bounds checks to prevent panic on truncated DataRow
1094                        if msg_bytes.remaining() < 2 {
1095                            return Err(PgError::Protocol(
1096                                "DataRow ultra: too short for column count".into(),
1097                            ));
1098                        }
1099
1100                        // Read column count (expect 2)
1101                        let col_count = msg_bytes.get_i16();
1102                        if col_count != 2 {
1103                            return Err(PgError::Protocol(format!(
1104                                "DataRow ultra expects exactly 2 columns, got {}",
1105                                col_count
1106                            )));
1107                        }
1108
1109                        if msg_bytes.remaining() < 4 {
1110                            return Err(PgError::Protocol(
1111                                "DataRow ultra: truncated before col0 length".into(),
1112                            ));
1113                        }
1114                        let len0 = msg_bytes.get_i32();
1115                        let col0 = if len0 > 0 {
1116                            let len0 = len0 as usize;
1117                            if msg_bytes.remaining() < len0 {
1118                                return Err(PgError::Protocol(
1119                                    "DataRow ultra: col0 data exceeds payload".into(),
1120                                ));
1121                            }
1122                            msg_bytes.split_to(len0).freeze()
1123                        } else if len0 == 0 {
1124                            bytes::Bytes::new()
1125                        } else if len0 == -1 {
1126                            return Err(PgError::Protocol(
1127                                "DataRow ultra does not support NULL columns".into(),
1128                            ));
1129                        } else {
1130                            return Err(PgError::Protocol(format!(
1131                                "DataRow ultra: invalid col0 length {}",
1132                                len0
1133                            )));
1134                        };
1135
1136                        if msg_bytes.remaining() < 4 {
1137                            return Err(PgError::Protocol(
1138                                "DataRow ultra: truncated before col1 length".into(),
1139                            ));
1140                        }
1141                        let len1 = msg_bytes.get_i32();
1142                        let col1 = if len1 > 0 {
1143                            let len1 = len1 as usize;
1144                            if msg_bytes.remaining() < len1 {
1145                                return Err(PgError::Protocol(
1146                                    "DataRow ultra: col1 data exceeds payload".into(),
1147                                ));
1148                            }
1149                            msg_bytes.split_to(len1).freeze()
1150                        } else if len1 == 0 {
1151                            bytes::Bytes::new()
1152                        } else if len1 == -1 {
1153                            return Err(PgError::Protocol(
1154                                "DataRow ultra does not support NULL columns".into(),
1155                            ));
1156                        } else {
1157                            return Err(PgError::Protocol(format!(
1158                                "DataRow ultra: invalid col1 length {}",
1159                                len1
1160                            )));
1161                        };
1162
1163                        if msg_bytes.remaining() != 0 {
1164                            return Err(PgError::Protocol(
1165                                "DataRow ultra: trailing bytes after expected columns".into(),
1166                            ));
1167                        }
1168
1169                        return Ok((msg_type, Some((col0, col1))));
1170                    }
1171
1172                    // Other messages - skip
1173                    let _ = self.buffer.split_to(msg_len + 1);
1174                    return Ok((msg_type, None));
1175                }
1176            }
1177
1178            let n = self.read_with_timeout().await?;
1179            if n == 0 {
1180                return self.connection_desync("Connection closed".to_string());
1181            }
1182        }
1183    }
1184}