Skip to main content

qail_pg/driver/
io.rs

1//! Core I/O operations for PostgreSQL connection.
2//!
3//! This module provides low-level send/receive methods.
4
5use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, FrontendMessage};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
/// Largest declared message length accepted from the server: 64 MB.
/// Anything larger is rejected, which prevents OOM from malicious server messages.
const MAX_MESSAGE_SIZE: usize = 64 << 20;

/// Default read timeout for individual socket reads (30 seconds).
/// Prevents Slowloris DoS where a server sends partial data then goes silent.
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
14
15impl PgConnection {
16    /// Send a frontend message.
17    pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
18        let bytes = msg.encode();
19        self.stream.write_all(&bytes).await?;
20        Ok(())
21    }
22
23    /// Loops until a complete message is available.
24    /// Automatically buffers NotificationResponse messages for LISTEN/NOTIFY.
25    pub async fn recv(&mut self) -> PgResult<BackendMessage> {
26        loop {
27            // Try to decode from buffer first
28            if self.buffer.len() >= 5 {
29                let msg_len = u32::from_be_bytes([
30                    self.buffer[1],
31                    self.buffer[2],
32                    self.buffer[3],
33                    self.buffer[4],
34                ]) as usize;
35
36                if msg_len > MAX_MESSAGE_SIZE {
37                    return Err(PgError::Protocol(format!(
38                        "Message too large: {} bytes (max {})",
39                        msg_len, MAX_MESSAGE_SIZE
40                    )));
41                }
42
43                if self.buffer.len() > msg_len {
44                    // We have a complete message - zero-copy split
45                    let msg_bytes = self.buffer.split_to(msg_len + 1);
46                    let (msg, _) = BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
47
48                    // Intercept async notifications — buffer them instead of returning
49                    if let BackendMessage::NotificationResponse { process_id, channel, payload } = msg {
50                        self.notifications.push_back(
51                            super::notification::Notification { process_id, channel, payload }
52                        );
53                        continue; // Keep reading for the actual response
54                    }
55
56                    return Ok(msg);
57                }
58            }
59
60
61            let n = self.read_with_timeout().await?;
62            if n == 0 {
63                return Err(PgError::Connection("Connection closed".to_string()));
64            }
65        }
66    }
67
68    /// Read from the socket with a timeout guard.
69    /// Returns the number of bytes read, or an error if the timeout fires.
70    /// This prevents Slowloris DoS attacks where a malicious server sends
71    /// partial data then goes silent, causing the driver to hang forever.
72    #[inline]
73    pub(crate) async fn read_with_timeout(&mut self) -> PgResult<usize> {
74        if self.buffer.capacity() - self.buffer.len() < 65536 {
75            self.buffer.reserve(131072);
76        }
77        
78        match tokio::time::timeout(
79            DEFAULT_READ_TIMEOUT,
80            self.stream.read_buf(&mut self.buffer),
81        ).await {
82            Ok(Ok(n)) => Ok(n),
83            Ok(Err(e)) => Err(PgError::Connection(format!("Read error: {}", e))),
84            Err(_) => Err(PgError::Connection(format!(
85                "Read timeout after {:?} — possible Slowloris attack or dead connection",
86                DEFAULT_READ_TIMEOUT
87            ))),
88        }
89    }
90
91    /// Send raw bytes to the stream.
92    /// Includes flush for TLS safety — TLS buffers internally and
93    /// needs flush to push encrypted data to the underlying TCP socket.
94    pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
95        self.stream.write_all(bytes).await?;
96        self.stream.flush().await?;
97        Ok(())
98    }
99
100    // ==================== BUFFERED WRITE API (High Performance) ====================
101
102    /// Buffer bytes for later flush (NO SYSCALL).
103    /// Use flush_write_buf() to send all buffered data.
104    #[inline]
105    pub fn buffer_bytes(&mut self, bytes: &[u8]) {
106        self.write_buf.extend_from_slice(bytes);
107    }
108
109    /// Flush the write buffer to the stream (single write_all + flush).
110    /// The flush is critical for TLS connections.
111    pub async fn flush_write_buf(&mut self) -> PgResult<()> {
112        if !self.write_buf.is_empty() {
113            self.stream.write_all(&self.write_buf).await?;
114            self.write_buf.clear();
115            self.stream.flush().await?;
116        }
117        Ok(())
118    }
119
120    /// FAST receive - returns only message type byte, skips parsing.
121    /// This is ~10x faster than recv() for pipelining benchmarks.
122    /// Returns: message_type
123    #[inline]
124    pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
125        loop {
126            if self.buffer.len() >= 5 {
127                let msg_len = u32::from_be_bytes([
128                    self.buffer[1],
129                    self.buffer[2],
130                    self.buffer[3],
131                    self.buffer[4],
132                ]) as usize;
133
134                if msg_len > MAX_MESSAGE_SIZE {
135                    return Err(PgError::Protocol(format!(
136                        "Message too large: {} bytes (max {})",
137                        msg_len, MAX_MESSAGE_SIZE
138                    )));
139                }
140
141                if self.buffer.len() > msg_len {
142                    let msg_type = self.buffer[0];
143
144                    if msg_type == b'E' {
145                        let msg_bytes = self.buffer.split_to(msg_len + 1);
146                        let (msg, _) =
147                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
148                        if let BackendMessage::ErrorResponse(err) = msg {
149                            return Err(PgError::Query(err.message));
150                        }
151                    }
152
153                    let _ = self.buffer.split_to(msg_len + 1);
154                    return Ok(msg_type);
155                }
156            }
157
158
159            let n = self.read_with_timeout().await?;
160            if n == 0 {
161                return Err(PgError::Connection("Connection closed".to_string()));
162            }
163        }
164    }
165
166    /// FAST receive for result consumption - inline DataRow parsing.
167    /// Returns: (msg_type, Option<row_data>)
168    /// For 'D' (DataRow): returns parsed columns
169    /// For other types: returns None
170    /// This avoids BackendMessage enum allocation for non-DataRow messages.
171    #[inline]
172    pub(crate) async fn recv_with_data_fast(
173        &mut self,
174    ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
175        loop {
176            if self.buffer.len() >= 5 {
177                let msg_len = u32::from_be_bytes([
178                    self.buffer[1],
179                    self.buffer[2],
180                    self.buffer[3],
181                    self.buffer[4],
182                ]) as usize;
183
184                if msg_len > MAX_MESSAGE_SIZE {
185                    return Err(PgError::Protocol(format!(
186                        "Message too large: {} bytes (max {})",
187                        msg_len, MAX_MESSAGE_SIZE
188                    )));
189                }
190
191                if self.buffer.len() > msg_len {
192                    let msg_type = self.buffer[0];
193
194                    if msg_type == b'E' {
195                        let msg_bytes = self.buffer.split_to(msg_len + 1);
196                        let (msg, _) =
197                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
198                        if let BackendMessage::ErrorResponse(err) = msg {
199                            return Err(PgError::Query(err.message));
200                        }
201                    }
202
203                    // Fast path: DataRow - parse inline
204                    if msg_type == b'D' {
205                        let payload = &self.buffer[5..msg_len + 1];
206
207                        if payload.len() >= 2 {
208                            let column_count =
209                                u16::from_be_bytes([payload[0], payload[1]]) as usize;
210                            let mut columns = Vec::with_capacity(column_count);
211                            let mut pos = 2;
212
213                            for _ in 0..column_count {
214                                if pos + 4 > payload.len() {
215                                    let _ = self.buffer.split_to(msg_len + 1);
216                                    return Err(PgError::Protocol("DataRow truncated: missing column length".into()));
217                                }
218
219                                let len = i32::from_be_bytes([
220                                    payload[pos],
221                                    payload[pos + 1],
222                                    payload[pos + 2],
223                                    payload[pos + 3],
224                                ]);
225                                pos += 4;
226
227                                if len == -1 {
228                                    columns.push(None);
229                                } else {
230                                    let len = len as usize;
231                                    if pos + len > payload.len() {
232                                        let _ = self.buffer.split_to(msg_len + 1);
233                                        return Err(PgError::Protocol("DataRow truncated: column data exceeds payload".into()));
234                                    }
235                                    columns.push(Some(payload[pos..pos + len].to_vec()));
236                                    pos += len;
237                                }
238                            }
239
240                            let _ = self.buffer.split_to(msg_len + 1);
241                            return Ok((msg_type, Some(columns)));
242                        }
243                    }
244
245                    // Other messages - skip
246                    let _ = self.buffer.split_to(msg_len + 1);
247                    return Ok((msg_type, None));
248                }
249            }
250
251
252            let n = self.read_with_timeout().await?;
253            if n == 0 {
254                return Err(PgError::Connection("Connection closed".to_string()));
255            }
256        }
257    }
258
259    /// ZERO-COPY receive for DataRow.
260    /// Uses bytes::Bytes for reference-counted slicing instead of Vec copy.
261    /// Returns: (msg_type, Option<row_data>)
262    /// For 'D' (DataRow): returns Bytes slices (no copy!)
263    /// For other types: returns None
264    #[inline]
265    pub(crate) async fn recv_data_zerocopy(
266        &mut self,
267    ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
268        use bytes::Buf;
269
270        loop {
271            if self.buffer.len() >= 5 {
272                let msg_len = u32::from_be_bytes([
273                    self.buffer[1],
274                    self.buffer[2],
275                    self.buffer[3],
276                    self.buffer[4],
277                ]) as usize;
278
279                if msg_len > MAX_MESSAGE_SIZE {
280                    return Err(PgError::Protocol(format!(
281                        "Message too large: {} bytes (max {})",
282                        msg_len, MAX_MESSAGE_SIZE
283                    )));
284                }
285
286                if self.buffer.len() > msg_len {
287                    let msg_type = self.buffer[0];
288
289                    if msg_type == b'E' {
290                        let msg_bytes = self.buffer.split_to(msg_len + 1);
291                        let (msg, _) =
292                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
293                        if let BackendMessage::ErrorResponse(err) = msg {
294                            return Err(PgError::Query(err.message));
295                        }
296                    }
297
298                    // Fast path: DataRow - ZERO-COPY using Bytes
299                    if msg_type == b'D' {
300                        // Split off the entire message
301                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
302
303                        // Skip type byte (1) + length (4) = 5 bytes
304                        msg_bytes.advance(5);
305
306                        if msg_bytes.len() >= 2 {
307                            let column_count = msg_bytes.get_u16() as usize;
308                            let mut columns = Vec::with_capacity(column_count);
309
310                            for _ in 0..column_count {
311                                if msg_bytes.remaining() < 4 {
312                                    return Err(PgError::Protocol("DataRow truncated: missing column length".into()));
313                                }
314
315                                let len = msg_bytes.get_i32();
316
317                                if len == -1 {
318                                    columns.push(None);
319                                } else {
320                                    let len = len as usize;
321                                    if msg_bytes.remaining() < len {
322                                        return Err(PgError::Protocol("DataRow truncated: column data exceeds payload".into()));
323                                    }
324                                    let col_data = msg_bytes.split_to(len).freeze();
325                                    columns.push(Some(col_data));
326                                }
327                            }
328
329                            return Ok((msg_type, Some(columns)));
330                        }
331                        return Ok((msg_type, None));
332                    }
333
334                    // Other messages - skip
335                    let _ = self.buffer.split_to(msg_len + 1);
336                    return Ok((msg_type, None));
337                }
338            }
339
340
341            let n = self.read_with_timeout().await?;
342            if n == 0 {
343                return Err(PgError::Connection("Connection closed".to_string()));
344            }
345        }
346    }
347
348    /// ULTRA-FAST receive for 2-column DataRow (id, name pattern).
349    /// Uses fixed-size array instead of Vec allocation.
350    /// Returns: (msg_type, Option<(col0, col1)>)
351    #[inline(always)]
352    pub(crate) async fn recv_data_ultra(
353        &mut self,
354    ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
355        use bytes::Buf;
356
357        loop {
358            if self.buffer.len() >= 5 {
359                let msg_len = u32::from_be_bytes([
360                    self.buffer[1],
361                    self.buffer[2],
362                    self.buffer[3],
363                    self.buffer[4],
364                ]) as usize;
365
366                if msg_len > MAX_MESSAGE_SIZE {
367                    return Err(PgError::Protocol(format!(
368                        "Message too large: {} bytes (max {})",
369                        msg_len, MAX_MESSAGE_SIZE
370                    )));
371                }
372
373                if self.buffer.len() > msg_len {
374                    let msg_type = self.buffer[0];
375
376                    // Error check
377                    if msg_type == b'E' {
378                        let msg_bytes = self.buffer.split_to(msg_len + 1);
379                        let (msg, _) =
380                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
381                        if let BackendMessage::ErrorResponse(err) = msg {
382                            return Err(PgError::Query(err.message));
383                        }
384                    }
385
386                    if msg_type == b'D' {
387                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
388                        msg_bytes.advance(5); // Skip type + length
389
390                        // Bounds checks to prevent panic on truncated DataRow
391                        if msg_bytes.remaining() < 2 {
392                            return Err(PgError::Protocol("DataRow ultra: too short for column count".into()));
393                        }
394
395                        // Read column count (expect 2)
396                        let _col_count = msg_bytes.get_u16();
397
398                        if msg_bytes.remaining() < 4 {
399                            return Err(PgError::Protocol("DataRow ultra: truncated before col0 length".into()));
400                        }
401                        let len0 = msg_bytes.get_i32();
402                        let col0 = if len0 > 0 {
403                            let len0 = len0 as usize;
404                            if msg_bytes.remaining() < len0 {
405                                return Err(PgError::Protocol("DataRow ultra: col0 data exceeds payload".into()));
406                            }
407                            msg_bytes.split_to(len0).freeze()
408                        } else {
409                            bytes::Bytes::new()
410                        };
411
412                        if msg_bytes.remaining() < 4 {
413                            return Err(PgError::Protocol("DataRow ultra: truncated before col1 length".into()));
414                        }
415                        let len1 = msg_bytes.get_i32();
416                        let col1 = if len1 > 0 {
417                            let len1 = len1 as usize;
418                            if msg_bytes.remaining() < len1 {
419                                return Err(PgError::Protocol("DataRow ultra: col1 data exceeds payload".into()));
420                            }
421                            msg_bytes.split_to(len1).freeze()
422                        } else {
423                            bytes::Bytes::new()
424                        };
425
426                        return Ok((msg_type, Some((col0, col1))));
427                    }
428
429                    // Other messages - skip
430                    let _ = self.buffer.split_to(msg_len + 1);
431                    return Ok((msg_type, None));
432                }
433            }
434
435
436            let n = self.read_with_timeout().await?;
437            if n == 0 {
438                return Err(PgError::Connection("Connection closed".to_string()));
439            }
440        }
441    }
442}