Skip to main content

qail_pg/driver/
io.rs

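//! Core I/O operations for a PostgreSQL connection.
//!
//! This module provides the low-level send/receive methods: direct and buffered writes,
//! a fully decoding receive, and several specialized fast-path readers.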
4
5use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, FrontendMessage};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
/// Largest value (1 GiB) accepted in a backend message's length field; the receive
/// paths report anything bigger as a protocol error instead of buffering it.
const MAX_MESSAGE_SIZE: usize = 1 << 30;
10
11impl PgConnection {
12    /// Send a frontend message.
13    pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
14        let bytes = msg.encode();
15        self.stream.write_all(&bytes).await?;
16        Ok(())
17    }
18
19    /// Loops until a complete message is available.
20    pub async fn recv(&mut self) -> PgResult<BackendMessage> {
21        loop {
22            // Try to decode from buffer first
23            if self.buffer.len() >= 5 {
24                let msg_len = u32::from_be_bytes([
25                    self.buffer[1],
26                    self.buffer[2],
27                    self.buffer[3],
28                    self.buffer[4],
29                ]) as usize;
30
31                if msg_len > MAX_MESSAGE_SIZE {
32                    return Err(PgError::Protocol(format!(
33                        "Message too large: {} bytes (max {})",
34                        msg_len, MAX_MESSAGE_SIZE
35                    )));
36                }
37
38                if self.buffer.len() > msg_len {
39                    // We have a complete message - zero-copy split
40                    let msg_bytes = self.buffer.split_to(msg_len + 1);
41                    let (msg, _) = BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
42                    return Ok(msg);
43                }
44            }
45
46            if self.buffer.capacity() - self.buffer.len() < 65536 {
47                self.buffer.reserve(131072); // 128KB buffer - reserve once, use many
48            }
49
50            let n = self.stream.read_buf(&mut self.buffer).await?;
51            if n == 0 {
52                return Err(PgError::Connection("Connection closed".to_string()));
53            }
54        }
55    }
56
57    /// Send raw bytes to the stream.
58    /// Includes flush for TLS safety — TLS buffers internally and
59    /// needs flush to push encrypted data to the underlying TCP socket.
60    pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
61        self.stream.write_all(bytes).await?;
62        self.stream.flush().await?;
63        Ok(())
64    }
65
66    // ==================== BUFFERED WRITE API (High Performance) ====================
67
68    /// Buffer bytes for later flush (NO SYSCALL).
69    /// Use flush_write_buf() to send all buffered data.
70    #[inline]
71    pub fn buffer_bytes(&mut self, bytes: &[u8]) {
72        self.write_buf.extend_from_slice(bytes);
73    }
74
75    /// Flush the write buffer to the stream (single write_all + flush).
76    /// The flush is critical for TLS connections.
77    pub async fn flush_write_buf(&mut self) -> PgResult<()> {
78        if !self.write_buf.is_empty() {
79            self.stream.write_all(&self.write_buf).await?;
80            self.write_buf.clear();
81            self.stream.flush().await?;
82        }
83        Ok(())
84    }
85
86    /// FAST receive - returns only message type byte, skips parsing.
87    /// This is ~10x faster than recv() for pipelining benchmarks.
88    /// Returns: message_type
89    #[inline]
90    pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
91        loop {
92            if self.buffer.len() >= 5 {
93                let msg_len = u32::from_be_bytes([
94                    self.buffer[1],
95                    self.buffer[2],
96                    self.buffer[3],
97                    self.buffer[4],
98                ]) as usize;
99
100                if msg_len > MAX_MESSAGE_SIZE {
101                    return Err(PgError::Protocol(format!(
102                        "Message too large: {} bytes (max {})",
103                        msg_len, MAX_MESSAGE_SIZE
104                    )));
105                }
106
107                if self.buffer.len() > msg_len {
108                    let msg_type = self.buffer[0];
109
110                    if msg_type == b'E' {
111                        let msg_bytes = self.buffer.split_to(msg_len + 1);
112                        let (msg, _) =
113                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
114                        if let BackendMessage::ErrorResponse(err) = msg {
115                            return Err(PgError::Query(err.message));
116                        }
117                    }
118
119                    let _ = self.buffer.split_to(msg_len + 1);
120                    return Ok(msg_type);
121                }
122            }
123
124            if self.buffer.capacity() - self.buffer.len() < 65536 {
125                self.buffer.reserve(131072); // 128KB buffer - reserve once, use many
126            }
127
128            let n = self.stream.read_buf(&mut self.buffer).await?;
129            if n == 0 {
130                return Err(PgError::Connection("Connection closed".to_string()));
131            }
132        }
133    }
134
135    /// FAST receive for result consumption - inline DataRow parsing.
136    /// Returns: (msg_type, Option<row_data>)
137    /// For 'D' (DataRow): returns parsed columns
138    /// For other types: returns None
139    /// This avoids BackendMessage enum allocation for non-DataRow messages.
140    #[inline]
141    pub(crate) async fn recv_with_data_fast(
142        &mut self,
143    ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
144        loop {
145            if self.buffer.len() >= 5 {
146                let msg_len = u32::from_be_bytes([
147                    self.buffer[1],
148                    self.buffer[2],
149                    self.buffer[3],
150                    self.buffer[4],
151                ]) as usize;
152
153                if msg_len > MAX_MESSAGE_SIZE {
154                    return Err(PgError::Protocol(format!(
155                        "Message too large: {} bytes (max {})",
156                        msg_len, MAX_MESSAGE_SIZE
157                    )));
158                }
159
160                if self.buffer.len() > msg_len {
161                    let msg_type = self.buffer[0];
162
163                    if msg_type == b'E' {
164                        let msg_bytes = self.buffer.split_to(msg_len + 1);
165                        let (msg, _) =
166                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
167                        if let BackendMessage::ErrorResponse(err) = msg {
168                            return Err(PgError::Query(err.message));
169                        }
170                    }
171
172                    // Fast path: DataRow - parse inline
173                    if msg_type == b'D' {
174                        let payload = &self.buffer[5..msg_len + 1];
175
176                        if payload.len() >= 2 {
177                            let column_count =
178                                u16::from_be_bytes([payload[0], payload[1]]) as usize;
179                            let mut columns = Vec::with_capacity(column_count);
180                            let mut pos = 2;
181
182                            for _ in 0..column_count {
183                                if pos + 4 > payload.len() {
184                                    break;
185                                }
186
187                                let len = i32::from_be_bytes([
188                                    payload[pos],
189                                    payload[pos + 1],
190                                    payload[pos + 2],
191                                    payload[pos + 3],
192                                ]);
193                                pos += 4;
194
195                                if len == -1 {
196                                    columns.push(None);
197                                } else {
198                                    let len = len as usize;
199                                    if pos + len <= payload.len() {
200                                        columns.push(Some(payload[pos..pos + len].to_vec()));
201                                        pos += len;
202                                    }
203                                }
204                            }
205
206                            let _ = self.buffer.split_to(msg_len + 1);
207                            return Ok((msg_type, Some(columns)));
208                        }
209                    }
210
211                    // Other messages - skip
212                    let _ = self.buffer.split_to(msg_len + 1);
213                    return Ok((msg_type, None));
214                }
215            }
216
217            if self.buffer.capacity() - self.buffer.len() < 65536 {
218                self.buffer.reserve(131072);
219            }
220
221            let n = self.stream.read_buf(&mut self.buffer).await?;
222            if n == 0 {
223                return Err(PgError::Connection("Connection closed".to_string()));
224            }
225        }
226    }
227
228    /// ZERO-COPY receive for DataRow.
229    /// Uses bytes::Bytes for reference-counted slicing instead of Vec copy.
230    /// Returns: (msg_type, Option<row_data>)
231    /// For 'D' (DataRow): returns Bytes slices (no copy!)
232    /// For other types: returns None
233    #[inline]
234    pub(crate) async fn recv_data_zerocopy(
235        &mut self,
236    ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
237        use bytes::Buf;
238
239        loop {
240            if self.buffer.len() >= 5 {
241                let msg_len = u32::from_be_bytes([
242                    self.buffer[1],
243                    self.buffer[2],
244                    self.buffer[3],
245                    self.buffer[4],
246                ]) as usize;
247
248                if msg_len > MAX_MESSAGE_SIZE {
249                    return Err(PgError::Protocol(format!(
250                        "Message too large: {} bytes (max {})",
251                        msg_len, MAX_MESSAGE_SIZE
252                    )));
253                }
254
255                if self.buffer.len() > msg_len {
256                    let msg_type = self.buffer[0];
257
258                    if msg_type == b'E' {
259                        let msg_bytes = self.buffer.split_to(msg_len + 1);
260                        let (msg, _) =
261                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
262                        if let BackendMessage::ErrorResponse(err) = msg {
263                            return Err(PgError::Query(err.message));
264                        }
265                    }
266
267                    // Fast path: DataRow - ZERO-COPY using Bytes
268                    if msg_type == b'D' {
269                        // Split off the entire message
270                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
271
272                        // Skip type byte (1) + length (4) = 5 bytes
273                        msg_bytes.advance(5);
274
275                        if msg_bytes.len() >= 2 {
276                            let column_count = msg_bytes.get_u16() as usize;
277                            let mut columns = Vec::with_capacity(column_count);
278
279                            for _ in 0..column_count {
280                                if msg_bytes.remaining() < 4 {
281                                    break;
282                                }
283
284                                let len = msg_bytes.get_i32();
285
286                                if len == -1 {
287                                    columns.push(None);
288                                } else {
289                                    let len = len as usize;
290                                    if msg_bytes.remaining() >= len {
291                                        let col_data = msg_bytes.split_to(len).freeze();
292                                        columns.push(Some(col_data));
293                                    }
294                                }
295                            }
296
297                            return Ok((msg_type, Some(columns)));
298                        }
299                        return Ok((msg_type, None));
300                    }
301
302                    // Other messages - skip
303                    let _ = self.buffer.split_to(msg_len + 1);
304                    return Ok((msg_type, None));
305                }
306            }
307
308            if self.buffer.capacity() - self.buffer.len() < 65536 {
309                self.buffer.reserve(131072);
310            }
311
312            let n = self.stream.read_buf(&mut self.buffer).await?;
313            if n == 0 {
314                return Err(PgError::Connection("Connection closed".to_string()));
315            }
316        }
317    }
318
319    /// ULTRA-FAST receive for 2-column DataRow (id, name pattern).
320    /// Uses fixed-size array instead of Vec allocation.
321    /// Returns: (msg_type, Option<(col0, col1)>)
322    #[inline(always)]
323    pub(crate) async fn recv_data_ultra(
324        &mut self,
325    ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
326        use bytes::Buf;
327
328        loop {
329            if self.buffer.len() >= 5 {
330                let msg_len = u32::from_be_bytes([
331                    self.buffer[1],
332                    self.buffer[2],
333                    self.buffer[3],
334                    self.buffer[4],
335                ]) as usize;
336
337                if msg_len > MAX_MESSAGE_SIZE {
338                    return Err(PgError::Protocol(format!(
339                        "Message too large: {} bytes (max {})",
340                        msg_len, MAX_MESSAGE_SIZE
341                    )));
342                }
343
344                if self.buffer.len() > msg_len {
345                    let msg_type = self.buffer[0];
346
347                    // Error check
348                    if msg_type == b'E' {
349                        let msg_bytes = self.buffer.split_to(msg_len + 1);
350                        let (msg, _) =
351                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
352                        if let BackendMessage::ErrorResponse(err) = msg {
353                            return Err(PgError::Query(err.message));
354                        }
355                    }
356
357                    if msg_type == b'D' {
358                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
359                        msg_bytes.advance(5); // Skip type + length
360
361                        // Read column count (expect 2)
362                        let _col_count = msg_bytes.get_u16();
363
364                        let len0 = msg_bytes.get_i32();
365                        let col0 = if len0 > 0 {
366                            msg_bytes.split_to(len0 as usize).freeze()
367                        } else {
368                            bytes::Bytes::new()
369                        };
370
371                        let len1 = msg_bytes.get_i32();
372                        let col1 = if len1 > 0 {
373                            msg_bytes.split_to(len1 as usize).freeze()
374                        } else {
375                            bytes::Bytes::new()
376                        };
377
378                        return Ok((msg_type, Some((col0, col1))));
379                    }
380
381                    // Other messages - skip
382                    let _ = self.buffer.split_to(msg_len + 1);
383                    return Ok((msg_type, None));
384                }
385            }
386
387            if self.buffer.capacity() - self.buffer.len() < 65536 {
388                self.buffer.reserve(131072);
389            }
390
391            let n = self.stream.read_buf(&mut self.buffer).await?;
392            if n == 0 {
393                return Err(PgError::Connection("Connection closed".to_string()));
394            }
395        }
396    }
397}