Skip to main content

qail_pg/driver/
io.rs

1//! Core I/O operations for PostgreSQL connection.
2//!
3//! This module provides low-level send/receive methods.
4
5use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, FrontendMessage};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
/// Upper bound on the length a single backend message may declare (64 MiB).
///
/// Any frame advertising more than this is rejected as soon as its header is
/// seen, so a malicious or broken server cannot drive the client out of memory.
const MAX_MESSAGE_SIZE: usize = 64 << 20;

/// Timeout applied to every individual socket read.
///
/// Guards against a Slowloris-style peer that sends partial data and then goes
/// silent, which would otherwise leave the driver waiting forever.
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(30_000);
14
15impl PgConnection {
16    /// Send a frontend message.
17    pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
18        let bytes = msg.encode();
19        self.stream.write_all(&bytes).await?;
20        Ok(())
21    }
22
23    /// Loops until a complete message is available.
24    /// Automatically buffers NotificationResponse messages for LISTEN/NOTIFY.
25    pub async fn recv(&mut self) -> PgResult<BackendMessage> {
26        loop {
27            // Try to decode from buffer first
28            if self.buffer.len() >= 5 {
29                let msg_len = u32::from_be_bytes([
30                    self.buffer[1],
31                    self.buffer[2],
32                    self.buffer[3],
33                    self.buffer[4],
34                ]) as usize;
35
36                if msg_len > MAX_MESSAGE_SIZE {
37                    return Err(PgError::Protocol(format!(
38                        "Message too large: {} bytes (max {})",
39                        msg_len, MAX_MESSAGE_SIZE
40                    )));
41                }
42
43                if self.buffer.len() > msg_len {
44                    // We have a complete message - zero-copy split
45                    let msg_bytes = self.buffer.split_to(msg_len + 1);
46                    let (msg, _) = BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
47
48                    // Intercept async notifications — buffer them instead of returning
49                    if let BackendMessage::NotificationResponse {
50                        process_id,
51                        channel,
52                        payload,
53                    } = msg
54                    {
55                        self.notifications
56                            .push_back(super::notification::Notification {
57                                process_id,
58                                channel,
59                                payload,
60                            });
61                        continue; // Keep reading for the actual response
62                    }
63
64                    return Ok(msg);
65                }
66            }
67
68            let n = self.read_with_timeout().await?;
69            if n == 0 {
70                return Err(PgError::Connection("Connection closed".to_string()));
71            }
72        }
73    }
74
75    /// Read from the socket with a timeout guard.
76    /// Returns the number of bytes read, or an error if the timeout fires.
77    /// This prevents Slowloris DoS attacks where a malicious server sends
78    /// partial data then goes silent, causing the driver to hang forever.
79    #[inline]
80    pub(crate) async fn read_with_timeout(&mut self) -> PgResult<usize> {
81        if self.buffer.capacity() - self.buffer.len() < 65536 {
82            self.buffer.reserve(131072);
83        }
84
85        match tokio::time::timeout(DEFAULT_READ_TIMEOUT, self.stream.read_buf(&mut self.buffer))
86            .await
87        {
88            Ok(Ok(n)) => Ok(n),
89            Ok(Err(e)) => Err(PgError::Connection(format!("Read error: {}", e))),
90            Err(_) => Err(PgError::Connection(format!(
91                "Read timeout after {:?} — possible Slowloris attack or dead connection",
92                DEFAULT_READ_TIMEOUT
93            ))),
94        }
95    }
96
97    /// Send raw bytes to the stream.
98    /// Includes flush for TLS safety — TLS buffers internally and
99    /// needs flush to push encrypted data to the underlying TCP socket.
100    pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
101        self.stream.write_all(bytes).await?;
102        self.stream.flush().await?;
103        Ok(())
104    }
105
106    // ==================== BUFFERED WRITE API (High Performance) ====================
107
108    /// Buffer bytes for later flush (NO SYSCALL).
109    /// Use flush_write_buf() to send all buffered data.
110    #[inline]
111    pub fn buffer_bytes(&mut self, bytes: &[u8]) {
112        self.write_buf.extend_from_slice(bytes);
113    }
114
115    /// Flush the write buffer to the stream (single write_all + flush).
116    /// The flush is critical for TLS connections.
117    pub async fn flush_write_buf(&mut self) -> PgResult<()> {
118        if !self.write_buf.is_empty() {
119            self.stream.write_all(&self.write_buf).await?;
120            self.write_buf.clear();
121            self.stream.flush().await?;
122        }
123        Ok(())
124    }
125
126    /// FAST receive - returns only message type byte, skips parsing.
127    /// This is ~10x faster than recv() for pipelining benchmarks.
128    /// Returns: message_type
129    #[inline]
130    pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
131        loop {
132            if self.buffer.len() >= 5 {
133                let msg_len = u32::from_be_bytes([
134                    self.buffer[1],
135                    self.buffer[2],
136                    self.buffer[3],
137                    self.buffer[4],
138                ]) as usize;
139
140                if msg_len > MAX_MESSAGE_SIZE {
141                    return Err(PgError::Protocol(format!(
142                        "Message too large: {} bytes (max {})",
143                        msg_len, MAX_MESSAGE_SIZE
144                    )));
145                }
146
147                if self.buffer.len() > msg_len {
148                    let msg_type = self.buffer[0];
149
150                    if msg_type == b'E' {
151                        let msg_bytes = self.buffer.split_to(msg_len + 1);
152                        let (msg, _) =
153                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
154                        if let BackendMessage::ErrorResponse(err) = msg {
155                            return Err(PgError::QueryServer(err.into()));
156                        }
157                    }
158
159                    let _ = self.buffer.split_to(msg_len + 1);
160                    return Ok(msg_type);
161                }
162            }
163
164            let n = self.read_with_timeout().await?;
165            if n == 0 {
166                return Err(PgError::Connection("Connection closed".to_string()));
167            }
168        }
169    }
170
171    /// FAST receive for result consumption - inline DataRow parsing.
172    /// Returns: (msg_type, Option<row_data>)
173    /// For 'D' (DataRow): returns parsed columns
174    /// For other types: returns None
175    /// This avoids BackendMessage enum allocation for non-DataRow messages.
176    #[inline]
177    pub(crate) async fn recv_with_data_fast(
178        &mut self,
179    ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
180        loop {
181            if self.buffer.len() >= 5 {
182                let msg_len = u32::from_be_bytes([
183                    self.buffer[1],
184                    self.buffer[2],
185                    self.buffer[3],
186                    self.buffer[4],
187                ]) as usize;
188
189                if msg_len > MAX_MESSAGE_SIZE {
190                    return Err(PgError::Protocol(format!(
191                        "Message too large: {} bytes (max {})",
192                        msg_len, MAX_MESSAGE_SIZE
193                    )));
194                }
195
196                if self.buffer.len() > msg_len {
197                    let msg_type = self.buffer[0];
198
199                    if msg_type == b'E' {
200                        let msg_bytes = self.buffer.split_to(msg_len + 1);
201                        let (msg, _) =
202                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
203                        if let BackendMessage::ErrorResponse(err) = msg {
204                            return Err(PgError::QueryServer(err.into()));
205                        }
206                    }
207
208                    // Fast path: DataRow - parse inline
209                    if msg_type == b'D' {
210                        let payload = &self.buffer[5..msg_len + 1];
211
212                        if payload.len() >= 2 {
213                            let column_count =
214                                u16::from_be_bytes([payload[0], payload[1]]) as usize;
215                            let mut columns = Vec::with_capacity(column_count);
216                            let mut pos = 2;
217
218                            for _ in 0..column_count {
219                                if pos + 4 > payload.len() {
220                                    let _ = self.buffer.split_to(msg_len + 1);
221                                    return Err(PgError::Protocol(
222                                        "DataRow truncated: missing column length".into(),
223                                    ));
224                                }
225
226                                let len = i32::from_be_bytes([
227                                    payload[pos],
228                                    payload[pos + 1],
229                                    payload[pos + 2],
230                                    payload[pos + 3],
231                                ]);
232                                pos += 4;
233
234                                if len == -1 {
235                                    columns.push(None);
236                                } else {
237                                    let len = len as usize;
238                                    if pos + len > payload.len() {
239                                        let _ = self.buffer.split_to(msg_len + 1);
240                                        return Err(PgError::Protocol(
241                                            "DataRow truncated: column data exceeds payload".into(),
242                                        ));
243                                    }
244                                    columns.push(Some(payload[pos..pos + len].to_vec()));
245                                    pos += len;
246                                }
247                            }
248
249                            let _ = self.buffer.split_to(msg_len + 1);
250                            return Ok((msg_type, Some(columns)));
251                        }
252                    }
253
254                    // Other messages - skip
255                    let _ = self.buffer.split_to(msg_len + 1);
256                    return Ok((msg_type, None));
257                }
258            }
259
260            let n = self.read_with_timeout().await?;
261            if n == 0 {
262                return Err(PgError::Connection("Connection closed".to_string()));
263            }
264        }
265    }
266
267    /// ZERO-COPY receive for DataRow.
268    /// Uses bytes::Bytes for reference-counted slicing instead of Vec copy.
269    /// Returns: (msg_type, Option<row_data>)
270    /// For 'D' (DataRow): returns Bytes slices (no copy!)
271    /// For other types: returns None
272    #[inline]
273    pub(crate) async fn recv_data_zerocopy(
274        &mut self,
275    ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
276        use bytes::Buf;
277
278        loop {
279            if self.buffer.len() >= 5 {
280                let msg_len = u32::from_be_bytes([
281                    self.buffer[1],
282                    self.buffer[2],
283                    self.buffer[3],
284                    self.buffer[4],
285                ]) as usize;
286
287                if msg_len > MAX_MESSAGE_SIZE {
288                    return Err(PgError::Protocol(format!(
289                        "Message too large: {} bytes (max {})",
290                        msg_len, MAX_MESSAGE_SIZE
291                    )));
292                }
293
294                if self.buffer.len() > msg_len {
295                    let msg_type = self.buffer[0];
296
297                    if msg_type == b'E' {
298                        let msg_bytes = self.buffer.split_to(msg_len + 1);
299                        let (msg, _) =
300                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
301                        if let BackendMessage::ErrorResponse(err) = msg {
302                            return Err(PgError::QueryServer(err.into()));
303                        }
304                    }
305
306                    // Fast path: DataRow - ZERO-COPY using Bytes
307                    if msg_type == b'D' {
308                        // Split off the entire message
309                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
310
311                        // Skip type byte (1) + length (4) = 5 bytes
312                        msg_bytes.advance(5);
313
314                        if msg_bytes.len() >= 2 {
315                            let column_count = msg_bytes.get_u16() as usize;
316                            let mut columns = Vec::with_capacity(column_count);
317
318                            for _ in 0..column_count {
319                                if msg_bytes.remaining() < 4 {
320                                    return Err(PgError::Protocol(
321                                        "DataRow truncated: missing column length".into(),
322                                    ));
323                                }
324
325                                let len = msg_bytes.get_i32();
326
327                                if len == -1 {
328                                    columns.push(None);
329                                } else {
330                                    let len = len as usize;
331                                    if msg_bytes.remaining() < len {
332                                        return Err(PgError::Protocol(
333                                            "DataRow truncated: column data exceeds payload".into(),
334                                        ));
335                                    }
336                                    let col_data = msg_bytes.split_to(len).freeze();
337                                    columns.push(Some(col_data));
338                                }
339                            }
340
341                            return Ok((msg_type, Some(columns)));
342                        }
343                        return Ok((msg_type, None));
344                    }
345
346                    // Other messages - skip
347                    let _ = self.buffer.split_to(msg_len + 1);
348                    return Ok((msg_type, None));
349                }
350            }
351
352            let n = self.read_with_timeout().await?;
353            if n == 0 {
354                return Err(PgError::Connection("Connection closed".to_string()));
355            }
356        }
357    }
358
359    /// ULTRA-FAST receive for 2-column DataRow (id, name pattern).
360    /// Uses fixed-size array instead of Vec allocation.
361    /// Returns: (msg_type, Option<(col0, col1)>)
362    #[inline(always)]
363    pub(crate) async fn recv_data_ultra(
364        &mut self,
365    ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
366        use bytes::Buf;
367
368        loop {
369            if self.buffer.len() >= 5 {
370                let msg_len = u32::from_be_bytes([
371                    self.buffer[1],
372                    self.buffer[2],
373                    self.buffer[3],
374                    self.buffer[4],
375                ]) as usize;
376
377                if msg_len > MAX_MESSAGE_SIZE {
378                    return Err(PgError::Protocol(format!(
379                        "Message too large: {} bytes (max {})",
380                        msg_len, MAX_MESSAGE_SIZE
381                    )));
382                }
383
384                if self.buffer.len() > msg_len {
385                    let msg_type = self.buffer[0];
386
387                    // Error check
388                    if msg_type == b'E' {
389                        let msg_bytes = self.buffer.split_to(msg_len + 1);
390                        let (msg, _) =
391                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
392                        if let BackendMessage::ErrorResponse(err) = msg {
393                            return Err(PgError::QueryServer(err.into()));
394                        }
395                    }
396
397                    if msg_type == b'D' {
398                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
399                        msg_bytes.advance(5); // Skip type + length
400
401                        // Bounds checks to prevent panic on truncated DataRow
402                        if msg_bytes.remaining() < 2 {
403                            return Err(PgError::Protocol(
404                                "DataRow ultra: too short for column count".into(),
405                            ));
406                        }
407
408                        // Read column count (expect 2)
409                        let _col_count = msg_bytes.get_u16();
410
411                        if msg_bytes.remaining() < 4 {
412                            return Err(PgError::Protocol(
413                                "DataRow ultra: truncated before col0 length".into(),
414                            ));
415                        }
416                        let len0 = msg_bytes.get_i32();
417                        let col0 = if len0 > 0 {
418                            let len0 = len0 as usize;
419                            if msg_bytes.remaining() < len0 {
420                                return Err(PgError::Protocol(
421                                    "DataRow ultra: col0 data exceeds payload".into(),
422                                ));
423                            }
424                            msg_bytes.split_to(len0).freeze()
425                        } else {
426                            bytes::Bytes::new()
427                        };
428
429                        if msg_bytes.remaining() < 4 {
430                            return Err(PgError::Protocol(
431                                "DataRow ultra: truncated before col1 length".into(),
432                            ));
433                        }
434                        let len1 = msg_bytes.get_i32();
435                        let col1 = if len1 > 0 {
436                            let len1 = len1 as usize;
437                            if msg_bytes.remaining() < len1 {
438                                return Err(PgError::Protocol(
439                                    "DataRow ultra: col1 data exceeds payload".into(),
440                                ));
441                            }
442                            msg_bytes.split_to(len1).freeze()
443                        } else {
444                            bytes::Bytes::new()
445                        };
446
447                        return Ok((msg_type, Some((col0, col1))));
448                    }
449
450                    // Other messages - skip
451                    let _ = self.buffer.split_to(msg_len + 1);
452                    return Ok((msg_type, None));
453                }
454            }
455
456            let n = self.read_with_timeout().await?;
457            if n == 0 {
458                return Err(PgError::Connection("Connection closed".to_string()));
459            }
460        }
461    }
462}