1use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, FrontendMessage};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8
/// Largest backend message length (64 MiB) accepted from the server; a message
/// whose length word exceeds this is rejected with a protocol error.
const MAX_MESSAGE_SIZE: usize = 64 << 20;
/// Maximum time a single socket read may wait for data before it fails.
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
14
15impl PgConnection {
16 pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
18 let bytes = msg.encode();
19 self.stream.write_all(&bytes).await?;
20 Ok(())
21 }
22
23 pub async fn recv(&mut self) -> PgResult<BackendMessage> {
26 loop {
27 if self.buffer.len() >= 5 {
29 let msg_len = u32::from_be_bytes([
30 self.buffer[1],
31 self.buffer[2],
32 self.buffer[3],
33 self.buffer[4],
34 ]) as usize;
35
36 if msg_len > MAX_MESSAGE_SIZE {
37 return Err(PgError::Protocol(format!(
38 "Message too large: {} bytes (max {})",
39 msg_len, MAX_MESSAGE_SIZE
40 )));
41 }
42
43 if self.buffer.len() > msg_len {
44 let msg_bytes = self.buffer.split_to(msg_len + 1);
46 let (msg, _) = BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
47
48 if let BackendMessage::NotificationResponse {
50 process_id,
51 channel,
52 payload,
53 } = msg
54 {
55 self.notifications
56 .push_back(super::notification::Notification {
57 process_id,
58 channel,
59 payload,
60 });
61 continue; }
63
64 return Ok(msg);
65 }
66 }
67
68 let n = self.read_with_timeout().await?;
69 if n == 0 {
70 return Err(PgError::Connection("Connection closed".to_string()));
71 }
72 }
73 }
74
75 #[inline]
80 pub(crate) async fn read_with_timeout(&mut self) -> PgResult<usize> {
81 if self.buffer.capacity() - self.buffer.len() < 65536 {
82 self.buffer.reserve(131072);
83 }
84
85 match tokio::time::timeout(DEFAULT_READ_TIMEOUT, self.stream.read_buf(&mut self.buffer))
86 .await
87 {
88 Ok(Ok(n)) => Ok(n),
89 Ok(Err(e)) => Err(PgError::Connection(format!("Read error: {}", e))),
90 Err(_) => Err(PgError::Connection(format!(
91 "Read timeout after {:?} — possible Slowloris attack or dead connection",
92 DEFAULT_READ_TIMEOUT
93 ))),
94 }
95 }
96
97 pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
101 self.stream.write_all(bytes).await?;
102 self.stream.flush().await?;
103 Ok(())
104 }
105
106 #[inline]
111 pub fn buffer_bytes(&mut self, bytes: &[u8]) {
112 self.write_buf.extend_from_slice(bytes);
113 }
114
115 pub async fn flush_write_buf(&mut self) -> PgResult<()> {
118 if !self.write_buf.is_empty() {
119 self.stream.write_all(&self.write_buf).await?;
120 self.write_buf.clear();
121 self.stream.flush().await?;
122 }
123 Ok(())
124 }
125
126 #[inline]
130 pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
131 loop {
132 if self.buffer.len() >= 5 {
133 let msg_len = u32::from_be_bytes([
134 self.buffer[1],
135 self.buffer[2],
136 self.buffer[3],
137 self.buffer[4],
138 ]) as usize;
139
140 if msg_len > MAX_MESSAGE_SIZE {
141 return Err(PgError::Protocol(format!(
142 "Message too large: {} bytes (max {})",
143 msg_len, MAX_MESSAGE_SIZE
144 )));
145 }
146
147 if self.buffer.len() > msg_len {
148 let msg_type = self.buffer[0];
149
150 if msg_type == b'E' {
151 let msg_bytes = self.buffer.split_to(msg_len + 1);
152 let (msg, _) =
153 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
154 if let BackendMessage::ErrorResponse(err) = msg {
155 return Err(PgError::QueryServer(err.into()));
156 }
157 }
158
159 let _ = self.buffer.split_to(msg_len + 1);
160 return Ok(msg_type);
161 }
162 }
163
164 let n = self.read_with_timeout().await?;
165 if n == 0 {
166 return Err(PgError::Connection("Connection closed".to_string()));
167 }
168 }
169 }
170
171 #[inline]
177 pub(crate) async fn recv_with_data_fast(
178 &mut self,
179 ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
180 loop {
181 if self.buffer.len() >= 5 {
182 let msg_len = u32::from_be_bytes([
183 self.buffer[1],
184 self.buffer[2],
185 self.buffer[3],
186 self.buffer[4],
187 ]) as usize;
188
189 if msg_len > MAX_MESSAGE_SIZE {
190 return Err(PgError::Protocol(format!(
191 "Message too large: {} bytes (max {})",
192 msg_len, MAX_MESSAGE_SIZE
193 )));
194 }
195
196 if self.buffer.len() > msg_len {
197 let msg_type = self.buffer[0];
198
199 if msg_type == b'E' {
200 let msg_bytes = self.buffer.split_to(msg_len + 1);
201 let (msg, _) =
202 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
203 if let BackendMessage::ErrorResponse(err) = msg {
204 return Err(PgError::QueryServer(err.into()));
205 }
206 }
207
208 if msg_type == b'D' {
210 let payload = &self.buffer[5..msg_len + 1];
211
212 if payload.len() >= 2 {
213 let column_count =
214 u16::from_be_bytes([payload[0], payload[1]]) as usize;
215 let mut columns = Vec::with_capacity(column_count);
216 let mut pos = 2;
217
218 for _ in 0..column_count {
219 if pos + 4 > payload.len() {
220 let _ = self.buffer.split_to(msg_len + 1);
221 return Err(PgError::Protocol(
222 "DataRow truncated: missing column length".into(),
223 ));
224 }
225
226 let len = i32::from_be_bytes([
227 payload[pos],
228 payload[pos + 1],
229 payload[pos + 2],
230 payload[pos + 3],
231 ]);
232 pos += 4;
233
234 if len == -1 {
235 columns.push(None);
236 } else {
237 let len = len as usize;
238 if pos + len > payload.len() {
239 let _ = self.buffer.split_to(msg_len + 1);
240 return Err(PgError::Protocol(
241 "DataRow truncated: column data exceeds payload".into(),
242 ));
243 }
244 columns.push(Some(payload[pos..pos + len].to_vec()));
245 pos += len;
246 }
247 }
248
249 let _ = self.buffer.split_to(msg_len + 1);
250 return Ok((msg_type, Some(columns)));
251 }
252 }
253
254 let _ = self.buffer.split_to(msg_len + 1);
256 return Ok((msg_type, None));
257 }
258 }
259
260 let n = self.read_with_timeout().await?;
261 if n == 0 {
262 return Err(PgError::Connection("Connection closed".to_string()));
263 }
264 }
265 }
266
267 #[inline]
273 pub(crate) async fn recv_data_zerocopy(
274 &mut self,
275 ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
276 use bytes::Buf;
277
278 loop {
279 if self.buffer.len() >= 5 {
280 let msg_len = u32::from_be_bytes([
281 self.buffer[1],
282 self.buffer[2],
283 self.buffer[3],
284 self.buffer[4],
285 ]) as usize;
286
287 if msg_len > MAX_MESSAGE_SIZE {
288 return Err(PgError::Protocol(format!(
289 "Message too large: {} bytes (max {})",
290 msg_len, MAX_MESSAGE_SIZE
291 )));
292 }
293
294 if self.buffer.len() > msg_len {
295 let msg_type = self.buffer[0];
296
297 if msg_type == b'E' {
298 let msg_bytes = self.buffer.split_to(msg_len + 1);
299 let (msg, _) =
300 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
301 if let BackendMessage::ErrorResponse(err) = msg {
302 return Err(PgError::QueryServer(err.into()));
303 }
304 }
305
306 if msg_type == b'D' {
308 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
310
311 msg_bytes.advance(5);
313
314 if msg_bytes.len() >= 2 {
315 let column_count = msg_bytes.get_u16() as usize;
316 let mut columns = Vec::with_capacity(column_count);
317
318 for _ in 0..column_count {
319 if msg_bytes.remaining() < 4 {
320 return Err(PgError::Protocol(
321 "DataRow truncated: missing column length".into(),
322 ));
323 }
324
325 let len = msg_bytes.get_i32();
326
327 if len == -1 {
328 columns.push(None);
329 } else {
330 let len = len as usize;
331 if msg_bytes.remaining() < len {
332 return Err(PgError::Protocol(
333 "DataRow truncated: column data exceeds payload".into(),
334 ));
335 }
336 let col_data = msg_bytes.split_to(len).freeze();
337 columns.push(Some(col_data));
338 }
339 }
340
341 return Ok((msg_type, Some(columns)));
342 }
343 return Ok((msg_type, None));
344 }
345
346 let _ = self.buffer.split_to(msg_len + 1);
348 return Ok((msg_type, None));
349 }
350 }
351
352 let n = self.read_with_timeout().await?;
353 if n == 0 {
354 return Err(PgError::Connection("Connection closed".to_string()));
355 }
356 }
357 }
358
359 #[inline(always)]
363 pub(crate) async fn recv_data_ultra(
364 &mut self,
365 ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
366 use bytes::Buf;
367
368 loop {
369 if self.buffer.len() >= 5 {
370 let msg_len = u32::from_be_bytes([
371 self.buffer[1],
372 self.buffer[2],
373 self.buffer[3],
374 self.buffer[4],
375 ]) as usize;
376
377 if msg_len > MAX_MESSAGE_SIZE {
378 return Err(PgError::Protocol(format!(
379 "Message too large: {} bytes (max {})",
380 msg_len, MAX_MESSAGE_SIZE
381 )));
382 }
383
384 if self.buffer.len() > msg_len {
385 let msg_type = self.buffer[0];
386
387 if msg_type == b'E' {
389 let msg_bytes = self.buffer.split_to(msg_len + 1);
390 let (msg, _) =
391 BackendMessage::decode(&msg_bytes).map_err(PgError::Protocol)?;
392 if let BackendMessage::ErrorResponse(err) = msg {
393 return Err(PgError::QueryServer(err.into()));
394 }
395 }
396
397 if msg_type == b'D' {
398 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
399 msg_bytes.advance(5); if msg_bytes.remaining() < 2 {
403 return Err(PgError::Protocol(
404 "DataRow ultra: too short for column count".into(),
405 ));
406 }
407
408 let _col_count = msg_bytes.get_u16();
410
411 if msg_bytes.remaining() < 4 {
412 return Err(PgError::Protocol(
413 "DataRow ultra: truncated before col0 length".into(),
414 ));
415 }
416 let len0 = msg_bytes.get_i32();
417 let col0 = if len0 > 0 {
418 let len0 = len0 as usize;
419 if msg_bytes.remaining() < len0 {
420 return Err(PgError::Protocol(
421 "DataRow ultra: col0 data exceeds payload".into(),
422 ));
423 }
424 msg_bytes.split_to(len0).freeze()
425 } else {
426 bytes::Bytes::new()
427 };
428
429 if msg_bytes.remaining() < 4 {
430 return Err(PgError::Protocol(
431 "DataRow ultra: truncated before col1 length".into(),
432 ));
433 }
434 let len1 = msg_bytes.get_i32();
435 let col1 = if len1 > 0 {
436 let len1 = len1 as usize;
437 if msg_bytes.remaining() < len1 {
438 return Err(PgError::Protocol(
439 "DataRow ultra: col1 data exceeds payload".into(),
440 ));
441 }
442 msg_bytes.split_to(len1).freeze()
443 } else {
444 bytes::Bytes::new()
445 };
446
447 return Ok((msg_type, Some((col0, col1))));
448 }
449
450 let _ = self.buffer.split_to(msg_len + 1);
452 return Ok((msg_type, None));
453 }
454 }
455
456 let n = self.read_with_timeout().await?;
457 if n == 0 {
458 return Err(PgError::Connection("Connection closed".to_string()));
459 }
460 }
461 }
462}