qail_pg/driver/
io.rs

1//! Core I/O operations for PostgreSQL connection.
2//!
3//! This module provides low-level send/receive methods.
4
5use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, FrontendMessage};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
/// Largest value of a message's length field that this module accepts (1 GiB).
///
/// Frames announcing a bigger length are rejected with a protocol error
/// instead of being buffered.
const MAX_MESSAGE_SIZE: usize = 1 << 30;
10
11impl PgConnection {
12    /// Send a frontend message.
13    pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
14        let bytes = msg.encode();
15        self.stream.write_all(&bytes).await?;
16        Ok(())
17    }
18
19    /// Loops until a complete message is available.
20    pub async fn recv(&mut self) -> PgResult<BackendMessage> {
21        loop {
22            // Try to decode from buffer first
23            if self.buffer.len() >= 5 {
24                let msg_len = u32::from_be_bytes([
25                    self.buffer[1],
26                    self.buffer[2],
27                    self.buffer[3],
28                    self.buffer[4],
29                ]) as usize;
30
31                if msg_len > MAX_MESSAGE_SIZE {
32                    return Err(PgError::Protocol(format!(
33                        "Message too large: {} bytes (max {})",
34                        msg_len, MAX_MESSAGE_SIZE
35                    )));
36                }
37
38                if self.buffer.len() > msg_len {
39                    // We have a complete message - zero-copy split
40                    let msg_bytes = self.buffer.split_to(msg_len + 1);
41                    let (msg, _) = BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
42                    return Ok(msg);
43                }
44            }
45
46            if self.buffer.capacity() - self.buffer.len() < 65536 {
47                self.buffer.reserve(131072); // 128KB buffer - reserve once, use many
48            }
49
50            let n = self.stream.read_buf(&mut self.buffer).await?;
51            if n == 0 {
52                return Err(PgError::Connection("Connection closed".to_string()));
53            }
54        }
55    }
56
57    /// Send raw bytes to the stream.
58    pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
59        self.stream.write_all(bytes).await?;
60        self.stream.flush().await?; // CRITICAL: Must flush for PostgreSQL to process!
61        Ok(())
62    }
63
64    // ==================== BUFFERED WRITE API (High Performance) ====================
65
66    /// Buffer bytes for later flush (NO SYSCALL).
67    /// Use flush_write_buf() to send all buffered data.
68    #[inline]
69    pub fn buffer_bytes(&mut self, bytes: &[u8]) {
70        self.write_buf.extend_from_slice(bytes);
71    }
72
73    /// Flush the write buffer to the stream.
74    /// This is the only syscall in the buffered write path.
75    pub async fn flush_write_buf(&mut self) -> PgResult<()> {
76        if !self.write_buf.is_empty() {
77            self.stream.write_all(&self.write_buf).await?;
78            self.write_buf.clear();
79        }
80        Ok(())
81    }
82
83    /// FAST receive - returns only message type byte, skips parsing.
84    /// This is ~10x faster than recv() for pipelining benchmarks.
85    /// Returns: message_type
86    #[inline]
87    pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
88        loop {
89            if self.buffer.len() >= 5 {
90                let msg_len = u32::from_be_bytes([
91                    self.buffer[1],
92                    self.buffer[2],
93                    self.buffer[3],
94                    self.buffer[4],
95                ]) as usize;
96
97                if msg_len > MAX_MESSAGE_SIZE {
98                    return Err(PgError::Protocol(format!(
99                        "Message too large: {} bytes (max {})",
100                        msg_len, MAX_MESSAGE_SIZE
101                    )));
102                }
103
104                if self.buffer.len() > msg_len {
105                    let msg_type = self.buffer[0];
106
107                    if msg_type == b'E' {
108                        let msg_bytes = self.buffer.split_to(msg_len + 1);
109                        let (msg, _) =
110                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
111                        if let BackendMessage::ErrorResponse(err) = msg {
112                            return Err(PgError::Query(err.message));
113                        }
114                    }
115
116                    let _ = self.buffer.split_to(msg_len + 1);
117                    return Ok(msg_type);
118                }
119            }
120
121            if self.buffer.capacity() - self.buffer.len() < 65536 {
122                self.buffer.reserve(131072); // 128KB buffer - reserve once, use many
123            }
124
125            let n = self.stream.read_buf(&mut self.buffer).await?;
126            if n == 0 {
127                return Err(PgError::Connection("Connection closed".to_string()));
128            }
129        }
130    }
131
132    /// FAST receive for result consumption - inline DataRow parsing.
133    /// Returns: (msg_type, Option<row_data>)
134    /// For 'D' (DataRow): returns parsed columns
135    /// For other types: returns None
136    /// This avoids BackendMessage enum allocation for non-DataRow messages.
137    #[inline]
138    pub(crate) async fn recv_with_data_fast(
139        &mut self,
140    ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
141        loop {
142            if self.buffer.len() >= 5 {
143                let msg_len = u32::from_be_bytes([
144                    self.buffer[1],
145                    self.buffer[2],
146                    self.buffer[3],
147                    self.buffer[4],
148                ]) as usize;
149
150                if msg_len > MAX_MESSAGE_SIZE {
151                    return Err(PgError::Protocol(format!(
152                        "Message too large: {} bytes (max {})",
153                        msg_len, MAX_MESSAGE_SIZE
154                    )));
155                }
156
157                if self.buffer.len() > msg_len {
158                    let msg_type = self.buffer[0];
159
160                    if msg_type == b'E' {
161                        let msg_bytes = self.buffer.split_to(msg_len + 1);
162                        let (msg, _) =
163                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
164                        if let BackendMessage::ErrorResponse(err) = msg {
165                            return Err(PgError::Query(err.message));
166                        }
167                    }
168
169                    // Fast path: DataRow - parse inline
170                    if msg_type == b'D' {
171                        let payload = &self.buffer[5..msg_len + 1];
172
173                        if payload.len() >= 2 {
174                            let column_count =
175                                u16::from_be_bytes([payload[0], payload[1]]) as usize;
176                            let mut columns = Vec::with_capacity(column_count);
177                            let mut pos = 2;
178
179                            for _ in 0..column_count {
180                                if pos + 4 > payload.len() {
181                                    break;
182                                }
183
184                                let len = i32::from_be_bytes([
185                                    payload[pos],
186                                    payload[pos + 1],
187                                    payload[pos + 2],
188                                    payload[pos + 3],
189                                ]);
190                                pos += 4;
191
192                                if len == -1 {
193                                    columns.push(None);
194                                } else {
195                                    let len = len as usize;
196                                    if pos + len <= payload.len() {
197                                        columns.push(Some(payload[pos..pos + len].to_vec()));
198                                        pos += len;
199                                    }
200                                }
201                            }
202
203                            let _ = self.buffer.split_to(msg_len + 1);
204                            return Ok((msg_type, Some(columns)));
205                        }
206                    }
207
208                    // Other messages - skip
209                    let _ = self.buffer.split_to(msg_len + 1);
210                    return Ok((msg_type, None));
211                }
212            }
213
214            if self.buffer.capacity() - self.buffer.len() < 65536 {
215                self.buffer.reserve(131072);
216            }
217
218            let n = self.stream.read_buf(&mut self.buffer).await?;
219            if n == 0 {
220                return Err(PgError::Connection("Connection closed".to_string()));
221            }
222        }
223    }
224
225    /// ZERO-COPY receive for DataRow.
226    /// Uses bytes::Bytes for reference-counted slicing instead of Vec copy.
227    /// Returns: (msg_type, Option<row_data>)
228    /// For 'D' (DataRow): returns Bytes slices (no copy!)
229    /// For other types: returns None
230    #[inline]
231    pub(crate) async fn recv_data_zerocopy(
232        &mut self,
233    ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
234        use bytes::Buf;
235
236        loop {
237            if self.buffer.len() >= 5 {
238                let msg_len = u32::from_be_bytes([
239                    self.buffer[1],
240                    self.buffer[2],
241                    self.buffer[3],
242                    self.buffer[4],
243                ]) as usize;
244
245                if msg_len > MAX_MESSAGE_SIZE {
246                    return Err(PgError::Protocol(format!(
247                        "Message too large: {} bytes (max {})",
248                        msg_len, MAX_MESSAGE_SIZE
249                    )));
250                }
251
252                if self.buffer.len() > msg_len {
253                    let msg_type = self.buffer[0];
254
255                    if msg_type == b'E' {
256                        let msg_bytes = self.buffer.split_to(msg_len + 1);
257                        let (msg, _) =
258                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
259                        if let BackendMessage::ErrorResponse(err) = msg {
260                            return Err(PgError::Query(err.message));
261                        }
262                    }
263
264                    // Fast path: DataRow - ZERO-COPY using Bytes
265                    if msg_type == b'D' {
266                        // Split off the entire message
267                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
268
269                        // Skip type byte (1) + length (4) = 5 bytes
270                        msg_bytes.advance(5);
271
272                        if msg_bytes.len() >= 2 {
273                            let column_count = msg_bytes.get_u16() as usize;
274                            let mut columns = Vec::with_capacity(column_count);
275
276                            for _ in 0..column_count {
277                                if msg_bytes.remaining() < 4 {
278                                    break;
279                                }
280
281                                let len = msg_bytes.get_i32();
282
283                                if len == -1 {
284                                    columns.push(None);
285                                } else {
286                                    let len = len as usize;
287                                    if msg_bytes.remaining() >= len {
288                                        let col_data = msg_bytes.split_to(len).freeze();
289                                        columns.push(Some(col_data));
290                                    }
291                                }
292                            }
293
294                            return Ok((msg_type, Some(columns)));
295                        }
296                        return Ok((msg_type, None));
297                    }
298
299                    // Other messages - skip
300                    let _ = self.buffer.split_to(msg_len + 1);
301                    return Ok((msg_type, None));
302                }
303            }
304
305            if self.buffer.capacity() - self.buffer.len() < 65536 {
306                self.buffer.reserve(131072);
307            }
308
309            let n = self.stream.read_buf(&mut self.buffer).await?;
310            if n == 0 {
311                return Err(PgError::Connection("Connection closed".to_string()));
312            }
313        }
314    }
315
316    /// ULTRA-FAST receive for 2-column DataRow (id, name pattern).
317    /// Uses fixed-size array instead of Vec allocation.
318    /// Returns: (msg_type, Option<(col0, col1)>)
319    #[inline(always)]
320    pub(crate) async fn recv_data_ultra(
321        &mut self,
322    ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
323        use bytes::Buf;
324
325        loop {
326            if self.buffer.len() >= 5 {
327                let msg_len = u32::from_be_bytes([
328                    self.buffer[1],
329                    self.buffer[2],
330                    self.buffer[3],
331                    self.buffer[4],
332                ]) as usize;
333
334                if msg_len > MAX_MESSAGE_SIZE {
335                    return Err(PgError::Protocol(format!(
336                        "Message too large: {} bytes (max {})",
337                        msg_len, MAX_MESSAGE_SIZE
338                    )));
339                }
340
341                if self.buffer.len() > msg_len {
342                    let msg_type = self.buffer[0];
343
344                    // Error check
345                    if msg_type == b'E' {
346                        let msg_bytes = self.buffer.split_to(msg_len + 1);
347                        let (msg, _) =
348                            BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
349                        if let BackendMessage::ErrorResponse(err) = msg {
350                            return Err(PgError::Query(err.message));
351                        }
352                    }
353
354                    if msg_type == b'D' {
355                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
356                        msg_bytes.advance(5); // Skip type + length
357
358                        // Read column count (expect 2)
359                        let _col_count = msg_bytes.get_u16();
360
361                        let len0 = msg_bytes.get_i32();
362                        let col0 = if len0 > 0 {
363                            msg_bytes.split_to(len0 as usize).freeze()
364                        } else {
365                            bytes::Bytes::new()
366                        };
367
368                        let len1 = msg_bytes.get_i32();
369                        let col1 = if len1 > 0 {
370                            msg_bytes.split_to(len1 as usize).freeze()
371                        } else {
372                            bytes::Bytes::new()
373                        };
374
375                        return Ok((msg_type, Some((col0, col1))));
376                    }
377
378                    // Other messages - skip
379                    let _ = self.buffer.split_to(msg_len + 1);
380                    return Ok((msg_type, None));
381                }
382            }
383
384            if self.buffer.capacity() - self.buffer.len() < 65536 {
385                self.buffer.reserve(131072);
386            }
387
388            let n = self.stream.read_buf(&mut self.buffer).await?;
389            if n == 0 {
390                return Err(PgError::Connection("Connection closed".to_string()));
391            }
392        }
393    }
394}