1use super::{PgBytesRow, PgConnection, PgError, PgResult, is_ignorable_session_message};
6use crate::protocol::{BackendMessage, FrontendMessage, PgEncoder};
7use bytes::{Bytes, BytesMut};
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9
/// Largest length word (in bytes) accepted for a single backend message: 64 MiB.
/// Frames declaring more than this are rejected as a protocol desync.
pub(crate) const MAX_MESSAGE_SIZE: usize = 64 << 20;

/// Upper bound on a single timed socket read; exceeding it fails the read with a
/// `PgError::Connection` and marks the connection desynced.
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

/// Upper bound on a single socket write or flush; exceeding it yields `PgError::Timeout`.
const DEFAULT_WRITE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

/// Free space (64 KiB) the read buffer should have before each socket read.
const READ_SPARE_LOW_WATERMARK: usize = 64 << 10;
19
20#[inline]
21fn reserve_read_spare_capacity(buffer: &mut BytesMut) {
22 let spare = buffer.capacity().saturating_sub(buffer.len());
23 if spare < READ_SPARE_LOW_WATERMARK {
24 let target_spare = READ_SPARE_LOW_WATERMARK.max(buffer.capacity());
25 buffer.reserve(target_spare.saturating_sub(spare));
26 }
27}
28
29#[inline]
30fn parse_data_row_payload_owned(payload: &[u8]) -> PgResult<Vec<Option<Vec<u8>>>> {
31 if payload.len() < 2 {
32 return Err(PgError::Protocol("DataRow payload too short".into()));
33 }
34
35 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
36 if raw_count < 0 {
37 return Err(PgError::Protocol(format!(
38 "DataRow invalid column count: {}",
39 raw_count
40 )));
41 }
42 let column_count = raw_count as usize;
43 if column_count > (payload.len() - 2) / 4 + 1 {
44 return Err(PgError::Protocol(format!(
45 "DataRow claims {} columns but payload is only {} bytes",
46 column_count,
47 payload.len()
48 )));
49 }
50
51 let mut columns = Vec::with_capacity(column_count);
52 let mut pos = 2;
53 for _ in 0..column_count {
54 if pos + 4 > payload.len() {
55 return Err(PgError::Protocol(
56 "DataRow truncated: missing column length".into(),
57 ));
58 }
59
60 let len = i32::from_be_bytes([
61 payload[pos],
62 payload[pos + 1],
63 payload[pos + 2],
64 payload[pos + 3],
65 ]);
66 pos += 4;
67
68 if len == -1 {
69 columns.push(None);
70 continue;
71 }
72 if len < -1 {
73 return Err(PgError::Protocol(format!(
74 "DataRow invalid column length: {}",
75 len
76 )));
77 }
78
79 let len = len as usize;
80 if len > payload.len().saturating_sub(pos) {
81 return Err(PgError::Protocol(
82 "DataRow truncated: column data exceeds payload".into(),
83 ));
84 }
85 columns.push(Some(payload[pos..pos + len].to_vec()));
86 pos += len;
87 }
88
89 if pos != payload.len() {
90 return Err(PgError::Protocol("DataRow has trailing bytes".into()));
91 }
92
93 Ok(columns)
94}
95
96#[inline]
97fn parse_data_row_payload_reuse(
98 payload: &[u8],
99 columns: &mut Vec<Option<Vec<u8>>>,
100) -> PgResult<()> {
101 if payload.len() < 2 {
102 return Err(PgError::Protocol("DataRow payload too short".into()));
103 }
104
105 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
106 if raw_count < 0 {
107 return Err(PgError::Protocol(format!(
108 "DataRow invalid column count: {}",
109 raw_count
110 )));
111 }
112 let column_count = raw_count as usize;
113 if column_count > (payload.len() - 2) / 4 + 1 {
114 return Err(PgError::Protocol(format!(
115 "DataRow claims {} columns but payload is only {} bytes",
116 column_count,
117 payload.len()
118 )));
119 }
120
121 let previous_len = columns.len();
122 if previous_len < column_count {
123 columns.reserve(column_count - previous_len);
124 }
125
126 let mut pos = 2usize;
127 for idx in 0..column_count {
128 if pos + 4 > payload.len() {
129 return Err(PgError::Protocol(
130 "DataRow truncated: missing column length".into(),
131 ));
132 }
133
134 let len = i32::from_be_bytes([
135 payload[pos],
136 payload[pos + 1],
137 payload[pos + 2],
138 payload[pos + 3],
139 ]);
140 pos += 4;
141
142 if len == -1 {
143 if idx < previous_len {
144 columns[idx] = None;
145 } else {
146 columns.push(None);
147 }
148 continue;
149 }
150 if len < -1 {
151 return Err(PgError::Protocol(format!(
152 "DataRow invalid column length: {}",
153 len
154 )));
155 }
156
157 let len = len as usize;
158 if len > payload.len().saturating_sub(pos) {
159 return Err(PgError::Protocol(
160 "DataRow truncated: column data exceeds payload".into(),
161 ));
162 }
163 let value = &payload[pos..pos + len];
164 pos += len;
165
166 if idx < previous_len {
167 match &mut columns[idx] {
168 Some(buf) => {
169 buf.clear();
170 buf.extend_from_slice(value);
171 }
172 None => columns[idx] = Some(value.to_vec()),
173 }
174 } else {
175 columns.push(Some(value.to_vec()));
176 }
177 }
178
179 if columns.len() > column_count {
180 columns.truncate(column_count);
181 }
182
183 if pos != payload.len() {
184 return Err(PgError::Protocol("DataRow has trailing bytes".into()));
185 }
186
187 Ok(())
188}
189
190#[inline]
191fn parse_data_row_payload_zerocopy(payload: Bytes, row: &mut PgBytesRow) -> PgResult<()> {
192 if payload.len() < 2 {
193 return Err(PgError::Protocol("DataRow payload too short".into()));
194 }
195
196 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
197 if raw_count < 0 {
198 return Err(PgError::Protocol(format!(
199 "DataRow invalid column count: {}",
200 raw_count
201 )));
202 }
203 let column_count = raw_count as usize;
204 if column_count > (payload.len() - 2) / 4 + 1 {
205 return Err(PgError::Protocol(format!(
206 "DataRow claims {} columns but payload is only {} bytes",
207 column_count,
208 payload.len()
209 )));
210 }
211
212 row.payload = payload;
213 row.spans.clear();
214 if row.spans.capacity() < column_count {
215 row.spans.reserve(column_count - row.spans.capacity());
216 }
217
218 let mut pos = 2usize;
219 for _ in 0..column_count {
220 if pos + 4 > row.payload.len() {
221 return Err(PgError::Protocol(
222 "DataRow truncated: missing column length".into(),
223 ));
224 }
225
226 let len = i32::from_be_bytes([
227 row.payload[pos],
228 row.payload[pos + 1],
229 row.payload[pos + 2],
230 row.payload[pos + 3],
231 ]);
232 pos += 4;
233
234 if len == -1 {
235 row.spans.push(None);
236 continue;
237 }
238 if len < -1 {
239 return Err(PgError::Protocol(format!(
240 "DataRow invalid column length: {}",
241 len
242 )));
243 }
244
245 let len = len as usize;
246 if len > row.payload.len().saturating_sub(pos) {
247 return Err(PgError::Protocol(
248 "DataRow truncated: column data exceeds payload".into(),
249 ));
250 }
251 row.spans.push(Some((pos, len)));
252 pos += len;
253 }
254
255 if pos != row.payload.len() {
256 return Err(PgError::Protocol("DataRow has trailing bytes".into()));
257 }
258
259 Ok(())
260}
261
262#[inline]
263fn parse_first_column_payload_zerocopy(payload: Bytes) -> PgResult<Option<Bytes>> {
264 if payload.len() < 2 {
265 return Err(PgError::Protocol("DataRow payload too short".into()));
266 }
267
268 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
269 if raw_count < 0 {
270 return Err(PgError::Protocol(format!(
271 "DataRow invalid column count: {}",
272 raw_count
273 )));
274 }
275 let column_count = raw_count as usize;
276 if column_count > (payload.len() - 2) / 4 + 1 {
277 return Err(PgError::Protocol(format!(
278 "DataRow claims {} columns but payload is only {} bytes",
279 column_count,
280 payload.len()
281 )));
282 }
283
284 let mut pos = 2usize;
285 let mut first_column = None;
286
287 for idx in 0..column_count {
288 if pos + 4 > payload.len() {
289 return Err(PgError::Protocol(
290 "DataRow truncated: missing column length".into(),
291 ));
292 }
293
294 let len = i32::from_be_bytes([
295 payload[pos],
296 payload[pos + 1],
297 payload[pos + 2],
298 payload[pos + 3],
299 ]);
300 pos += 4;
301
302 if len == -1 {
303 if idx == 0 {
304 first_column = None;
305 }
306 continue;
307 }
308 if len < -1 {
309 return Err(PgError::Protocol(format!(
310 "DataRow invalid column length: {}",
311 len
312 )));
313 }
314
315 let len = len as usize;
316 if len > payload.len().saturating_sub(pos) {
317 return Err(PgError::Protocol(
318 "DataRow truncated: column data exceeds payload".into(),
319 ));
320 }
321
322 if idx == 0 {
323 first_column = Some(payload.slice(pos..pos + len));
324 }
325 pos += len;
326 }
327
328 if pos != payload.len() {
329 return Err(PgError::Protocol("DataRow has trailing bytes".into()));
330 }
331
332 Ok(first_column)
333}
334
335#[inline]
336fn parse_first_four_columns_payload_zerocopy(
337 payload: Bytes,
338 columns: &mut [Option<Bytes>; 4],
339) -> PgResult<()> {
340 if payload.len() < 2 {
341 return Err(PgError::Protocol("DataRow payload too short".into()));
342 }
343
344 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
345 if raw_count < 0 {
346 return Err(PgError::Protocol(format!(
347 "DataRow invalid column count: {}",
348 raw_count
349 )));
350 }
351 let column_count = raw_count as usize;
352 if column_count > (payload.len() - 2) / 4 + 1 {
353 return Err(PgError::Protocol(format!(
354 "DataRow claims {} columns but payload is only {} bytes",
355 column_count,
356 payload.len()
357 )));
358 }
359 if column_count != 4 {
360 return Err(PgError::Protocol(format!(
361 "DataRow fast-path expects exactly 4 columns, got {}",
362 column_count
363 )));
364 }
365
366 let mut pos = 2usize;
367 for slot in columns.iter_mut() {
368 if pos + 4 > payload.len() {
369 return Err(PgError::Protocol(
370 "DataRow truncated: missing column length".into(),
371 ));
372 }
373
374 let len = i32::from_be_bytes([
375 payload[pos],
376 payload[pos + 1],
377 payload[pos + 2],
378 payload[pos + 3],
379 ]);
380 pos += 4;
381
382 if len == -1 {
383 *slot = None;
384 continue;
385 }
386 if len < -1 {
387 return Err(PgError::Protocol(format!(
388 "DataRow invalid column length: {}",
389 len
390 )));
391 }
392
393 let len = len as usize;
394 if len > payload.len().saturating_sub(pos) {
395 return Err(PgError::Protocol(
396 "DataRow truncated: column data exceeds payload".into(),
397 ));
398 }
399 *slot = Some(payload.slice(pos..pos + len));
400 pos += len;
401 }
402
403 if pos != payload.len() {
404 return Err(PgError::Protocol("DataRow has trailing bytes".into()));
405 }
406
407 Ok(())
408}
409
410impl PgConnection {
411 #[inline]
412 fn stream_requires_flush(&self) -> bool {
413 use super::stream::PgStream;
414
415 match &self.stream {
416 PgStream::Tcp(_) => false,
417 PgStream::Tls(_) => true,
418 #[cfg(all(target_os = "linux", feature = "io_uring"))]
419 PgStream::Uring(_) => false,
420 #[cfg(unix)]
421 PgStream::Unix(_) => false,
422 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
423 PgStream::GssEnc(_) => true,
424 }
425 }
426
427 #[inline]
428 pub(crate) fn mark_io_desynced(&mut self) {
429 self.io_desynced = true;
430 }
431
432 #[inline]
433 pub(crate) fn is_io_desynced(&self) -> bool {
434 self.io_desynced
435 }
436
437 #[inline]
438 fn protocol_desync<T>(&mut self, msg: String) -> PgResult<T> {
439 self.mark_io_desynced();
440 Err(PgError::Protocol(msg))
441 }
442
443 #[inline]
444 fn protocol_desync_error<T>(&mut self, err: PgError) -> PgResult<T> {
445 match err {
446 PgError::Protocol(msg) => self.protocol_desync(msg),
447 err => {
448 self.mark_io_desynced();
449 Err(err)
450 }
451 }
452 }
453
454 #[inline]
455 fn connection_desync<T>(&mut self, msg: String) -> PgResult<T> {
456 self.mark_io_desynced();
457 Err(PgError::Connection(msg))
458 }
459
460 async fn flush_pending_statement_closes(&mut self) -> PgResult<()> {
466 if self.draining_statement_closes || self.pending_statement_closes.is_empty() {
467 return Ok(());
468 }
469
470 self.draining_statement_closes = true;
471 let close_names = std::mem::take(&mut self.pending_statement_closes);
472
473 let estimated_payload_len: usize = close_names
474 .iter()
475 .map(|name| 16usize.saturating_add(name.len()))
476 .sum();
477 let mut buf = BytesMut::with_capacity(estimated_payload_len.saturating_add(5));
478 for stmt_name in &close_names {
479 let close_msg = PgEncoder::try_encode_close(false, stmt_name)
480 .map_err(|e| PgError::Encode(e.to_string()))?;
481 buf.extend_from_slice(&close_msg);
482 }
483 PgEncoder::encode_sync_to(&mut buf);
484
485 if let Err(err) = self
486 .write_all_with_timeout_inner(&buf, "pending statement close write")
487 .await
488 {
489 self.draining_statement_closes = false;
490 return Err(err);
491 }
492 if let Err(err) = self
493 .flush_with_timeout("pending statement close flush")
494 .await
495 {
496 self.draining_statement_closes = false;
497 return Err(err);
498 }
499
500 let mut error: Option<PgError> = None;
501 loop {
502 let msg = match self.recv().await {
503 Ok(msg) => msg,
504 Err(err) => {
505 self.draining_statement_closes = false;
506 return Err(err);
507 }
508 };
509 match msg {
510 BackendMessage::CloseComplete => {}
511 BackendMessage::ReadyForQuery(_) => {
512 self.draining_statement_closes = false;
513 if let Some(err) = error {
514 return Err(err);
515 }
516 return Ok(());
517 }
518 BackendMessage::ErrorResponse(err_fields) => {
519 if error.is_none() {
520 let code_26000 = err_fields.code.eq_ignore_ascii_case("26000");
521 let msg_lower = err_fields.message.to_ascii_lowercase();
522 let missing_prepared = msg_lower.contains("prepared statement")
523 && msg_lower.contains("does not exist");
524 if !(code_26000 && missing_prepared) {
525 error = Some(PgError::QueryServer(err_fields.into()));
526 }
527 }
528 }
529 msg if is_ignorable_session_message(&msg) => {}
530 other => {
531 self.draining_statement_closes = false;
532 return self.protocol_desync(format!(
533 "Unexpected backend message during pending statement close drain: {:?}",
534 other
535 ));
536 }
537 }
538 }
539 }
540
541 pub(crate) async fn write_all_with_timeout(
545 &mut self,
546 bytes: &[u8],
547 operation: &str,
548 ) -> PgResult<()> {
549 if !self.draining_statement_closes && !self.pending_statement_closes.is_empty() {
550 self.flush_pending_statement_closes().await?;
551 }
552 self.write_all_with_timeout_inner(bytes, operation).await
553 }
554
555 async fn write_all_with_timeout_inner(
556 &mut self,
557 bytes: &[u8],
558 operation: &str,
559 ) -> PgResult<()> {
560 if bytes.is_empty() {
561 return Err(PgError::Encode(
562 "refusing to send empty frontend payload".to_string(),
563 ));
564 }
565 use super::stream::PgStream;
566 let mut mark_desync = false;
567 let result = match &mut self.stream {
568 PgStream::Tcp(stream) => {
569 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
570 Ok(Ok(())) => Ok(()),
571 Ok(Err(e)) => {
572 mark_desync = true;
573 Err(PgError::Connection(format!("Write error: {}", e)))
574 }
575 Err(_) => {
576 mark_desync = true;
577 Err(PgError::Timeout(format!(
578 "{} timeout after {:?}",
579 operation, DEFAULT_WRITE_TIMEOUT
580 )))
581 }
582 }
583 }
584 PgStream::Tls(stream) => {
585 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
586 Ok(Ok(())) => Ok(()),
587 Ok(Err(e)) => {
588 mark_desync = true;
589 Err(PgError::Connection(format!("Write error: {}", e)))
590 }
591 Err(_) => {
592 mark_desync = true;
593 Err(PgError::Timeout(format!(
594 "{} timeout after {:?}",
595 operation, DEFAULT_WRITE_TIMEOUT
596 )))
597 }
598 }
599 }
600 #[cfg(all(target_os = "linux", feature = "io_uring"))]
601 PgStream::Uring(stream) => {
602 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
603 Ok(Ok(())) => Ok(()),
604 Ok(Err(e)) => {
605 mark_desync = true;
606 Err(PgError::Connection(format!("Write error: {}", e)))
607 }
608 Err(_) => {
609 mark_desync = true;
610 let _ = stream.abort_inflight();
611 Err(PgError::Timeout(format!(
612 "{} timeout after {:?}",
613 operation, DEFAULT_WRITE_TIMEOUT
614 )))
615 }
616 }
617 }
618 #[cfg(unix)]
619 PgStream::Unix(stream) => {
620 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
621 Ok(Ok(())) => Ok(()),
622 Ok(Err(e)) => {
623 mark_desync = true;
624 Err(PgError::Connection(format!("Write error: {}", e)))
625 }
626 Err(_) => {
627 mark_desync = true;
628 Err(PgError::Timeout(format!(
629 "{} timeout after {:?}",
630 operation, DEFAULT_WRITE_TIMEOUT
631 )))
632 }
633 }
634 }
635 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
636 PgStream::GssEnc(stream) => {
637 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
638 Ok(Ok(())) => Ok(()),
639 Ok(Err(e)) => {
640 mark_desync = true;
641 Err(PgError::Connection(format!("Write error: {}", e)))
642 }
643 Err(_) => {
644 mark_desync = true;
645 Err(PgError::Timeout(format!(
646 "{} timeout after {:?}",
647 operation, DEFAULT_WRITE_TIMEOUT
648 )))
649 }
650 }
651 }
652 };
653 if mark_desync {
654 self.mark_io_desynced();
655 }
656 result
657 }
658
659 pub(crate) async fn flush_with_timeout(&mut self, operation: &str) -> PgResult<()> {
661 if !self.stream_requires_flush() {
662 return Ok(());
663 }
664
665 use super::stream::PgStream;
666 let mut mark_desync = false;
667 let result = match &mut self.stream {
668 PgStream::Tcp(stream) => {
669 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
670 Ok(Ok(())) => Ok(()),
671 Ok(Err(e)) => {
672 mark_desync = true;
673 Err(PgError::Connection(format!("Flush error: {}", e)))
674 }
675 Err(_) => {
676 mark_desync = true;
677 Err(PgError::Timeout(format!(
678 "{} timeout after {:?}",
679 operation, DEFAULT_WRITE_TIMEOUT
680 )))
681 }
682 }
683 }
684 PgStream::Tls(stream) => {
685 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
686 Ok(Ok(())) => Ok(()),
687 Ok(Err(e)) => {
688 mark_desync = true;
689 Err(PgError::Connection(format!("Flush error: {}", e)))
690 }
691 Err(_) => {
692 mark_desync = true;
693 Err(PgError::Timeout(format!(
694 "{} timeout after {:?}",
695 operation, DEFAULT_WRITE_TIMEOUT
696 )))
697 }
698 }
699 }
700 #[cfg(all(target_os = "linux", feature = "io_uring"))]
701 PgStream::Uring(stream) => {
702 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
703 Ok(Ok(())) => Ok(()),
704 Ok(Err(e)) => {
705 mark_desync = true;
706 Err(PgError::Connection(format!("Flush error: {}", e)))
707 }
708 Err(_) => {
709 mark_desync = true;
710 let _ = stream.abort_inflight();
711 Err(PgError::Timeout(format!(
712 "{} timeout after {:?}",
713 operation, DEFAULT_WRITE_TIMEOUT
714 )))
715 }
716 }
717 }
718 #[cfg(unix)]
719 PgStream::Unix(stream) => {
720 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
721 Ok(Ok(())) => Ok(()),
722 Ok(Err(e)) => {
723 mark_desync = true;
724 Err(PgError::Connection(format!("Flush error: {}", e)))
725 }
726 Err(_) => {
727 mark_desync = true;
728 Err(PgError::Timeout(format!(
729 "{} timeout after {:?}",
730 operation, DEFAULT_WRITE_TIMEOUT
731 )))
732 }
733 }
734 }
735 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
736 PgStream::GssEnc(stream) => {
737 match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
738 Ok(Ok(())) => Ok(()),
739 Ok(Err(e)) => {
740 mark_desync = true;
741 Err(PgError::Connection(format!("Flush error: {}", e)))
742 }
743 Err(_) => {
744 mark_desync = true;
745 Err(PgError::Timeout(format!(
746 "{} timeout after {:?}",
747 operation, DEFAULT_WRITE_TIMEOUT
748 )))
749 }
750 }
751 }
752 };
753 if mark_desync {
754 self.mark_io_desynced();
755 }
756 result
757 }
758
759 pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
761 let bytes = msg
762 .encode_checked()
763 .map_err(|e| PgError::Encode(e.to_string()))?;
764 self.write_all_with_timeout(&bytes, "send frontend message")
765 .await?;
766 Ok(())
767 }
768
769 pub async fn recv(&mut self) -> PgResult<BackendMessage> {
772 loop {
773 if self.buffer.len() >= 5 {
775 let msg_len = u32::from_be_bytes([
776 self.buffer[1],
777 self.buffer[2],
778 self.buffer[3],
779 self.buffer[4],
780 ]) as usize;
781
782 if msg_len < 4 {
783 return self.protocol_desync(format!(
784 "Invalid message length: {} (minimum 4)",
785 msg_len
786 ));
787 }
788
789 if msg_len > MAX_MESSAGE_SIZE {
790 return self.protocol_desync(format!(
791 "Message too large: {} bytes (max {})",
792 msg_len, MAX_MESSAGE_SIZE
793 ));
794 }
795
796 if self.buffer.len() > msg_len {
797 let msg_bytes = self.buffer.split_to(msg_len + 1);
799 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
800 Ok(decoded) => decoded,
801 Err(e) => return self.protocol_desync(e),
802 };
803
804 if let BackendMessage::NotificationResponse {
806 process_id,
807 channel,
808 payload,
809 } = msg
810 {
811 self.notifications
812 .push_back(super::notification::Notification {
813 process_id,
814 channel,
815 payload,
816 });
817 continue; }
819
820 return Ok(msg);
821 }
822 }
823
824 let n = self.read_with_timeout().await?;
825 if n == 0 {
826 return self.connection_desync("Connection closed".to_string());
827 }
828 }
829 }
830
831 pub(crate) async fn recv_without_timeout(&mut self) -> PgResult<BackendMessage> {
838 loop {
839 if self.buffer.len() >= 5 {
840 let msg_len = u32::from_be_bytes([
841 self.buffer[1],
842 self.buffer[2],
843 self.buffer[3],
844 self.buffer[4],
845 ]) as usize;
846
847 if msg_len < 4 {
848 return self.protocol_desync(format!(
849 "Invalid message length: {} (minimum 4)",
850 msg_len
851 ));
852 }
853
854 if msg_len > MAX_MESSAGE_SIZE {
855 return self.protocol_desync(format!(
856 "Message too large: {} bytes (max {})",
857 msg_len, MAX_MESSAGE_SIZE
858 ));
859 }
860
861 if self.buffer.len() > msg_len {
862 let msg_bytes = self.buffer.split_to(msg_len + 1);
863 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
864 Ok(decoded) => decoded,
865 Err(e) => return self.protocol_desync(e),
866 };
867
868 if let BackendMessage::NotificationResponse {
869 process_id,
870 channel,
871 payload,
872 } = msg
873 {
874 self.notifications
875 .push_back(super::notification::Notification {
876 process_id,
877 channel,
878 payload,
879 });
880 continue;
881 }
882
883 return Ok(msg);
884 }
885 }
886
887 let n = if self.buffer.is_empty() {
888 self.read_without_timeout().await?
889 } else {
890 self.read_with_timeout().await?
891 };
892 if n == 0 {
893 return self.connection_desync("Connection closed".to_string());
894 }
895 }
896 }
897
898 #[inline]
903 pub(crate) async fn read_with_timeout(&mut self) -> PgResult<usize> {
904 reserve_read_spare_capacity(&mut self.buffer);
905
906 use super::stream::PgStream;
907 let (stream, buffer) = (&mut self.stream, &mut self.buffer);
908 let mut mark_desync = false;
909 let result = match stream {
910 PgStream::Tcp(stream) => {
911 match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
912 Ok(Ok(n)) => Ok(n),
913 Ok(Err(e)) => {
914 mark_desync = true;
915 Err(PgError::Connection(format!("Read error: {}", e)))
916 }
917 Err(_) => {
918 mark_desync = true;
919 Err(PgError::Connection(format!(
920 "Read timeout after {:?} — possible Slowloris attack or dead connection",
921 DEFAULT_READ_TIMEOUT
922 )))
923 }
924 }
925 }
926 PgStream::Tls(stream) => {
927 match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
928 Ok(Ok(n)) => Ok(n),
929 Ok(Err(e)) => {
930 mark_desync = true;
931 Err(PgError::Connection(format!("Read error: {}", e)))
932 }
933 Err(_) => {
934 mark_desync = true;
935 Err(PgError::Connection(format!(
936 "Read timeout after {:?} — possible Slowloris attack or dead connection",
937 DEFAULT_READ_TIMEOUT
938 )))
939 }
940 }
941 }
942 #[cfg(all(target_os = "linux", feature = "io_uring"))]
943 PgStream::Uring(stream) => {
944 match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_into(buffer, 131072))
945 .await
946 {
947 Ok(Ok(n)) => Ok(n),
948 Ok(Err(e)) => {
949 mark_desync = true;
950 Err(PgError::Connection(format!("Read error: {}", e)))
951 }
952 Err(_) => {
953 mark_desync = true;
954 let _ = stream.abort_inflight();
955 Err(PgError::Connection(format!(
956 "Read timeout after {:?} — possible Slowloris attack or dead connection",
957 DEFAULT_READ_TIMEOUT
958 )))
959 }
960 }
961 }
962 #[cfg(unix)]
963 PgStream::Unix(stream) => {
964 match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
965 Ok(Ok(n)) => Ok(n),
966 Ok(Err(e)) => {
967 mark_desync = true;
968 Err(PgError::Connection(format!("Read error: {}", e)))
969 }
970 Err(_) => {
971 mark_desync = true;
972 Err(PgError::Connection(format!(
973 "Read timeout after {:?} — possible Slowloris attack or dead connection",
974 DEFAULT_READ_TIMEOUT
975 )))
976 }
977 }
978 }
979 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
980 PgStream::GssEnc(stream) => {
981 match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
982 Ok(Ok(n)) => Ok(n),
983 Ok(Err(e)) => {
984 mark_desync = true;
985 Err(PgError::Connection(format!("Read error: {}", e)))
986 }
987 Err(_) => {
988 mark_desync = true;
989 Err(PgError::Connection(format!(
990 "Read timeout after {:?} — possible Slowloris attack or dead connection",
991 DEFAULT_READ_TIMEOUT
992 )))
993 }
994 }
995 }
996 };
997 if mark_desync {
998 self.mark_io_desynced();
999 }
1000 result
1001 }
1002
1003 pub(crate) async fn read_without_timeout(&mut self) -> PgResult<usize> {
1007 reserve_read_spare_capacity(&mut self.buffer);
1008
1009 use super::stream::PgStream;
1010 let (stream, buffer) = (&mut self.stream, &mut self.buffer);
1011 let read_result = match stream {
1012 PgStream::Tcp(stream) => stream.read_buf(buffer).await,
1013 PgStream::Tls(stream) => stream.read_buf(buffer).await,
1014 #[cfg(all(target_os = "linux", feature = "io_uring"))]
1015 PgStream::Uring(stream) => stream.read_into(buffer, 131072).await,
1016 #[cfg(unix)]
1017 PgStream::Unix(stream) => stream.read_buf(buffer).await,
1018 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
1019 PgStream::GssEnc(stream) => stream.read_buf(buffer).await,
1020 };
1021
1022 match read_result {
1023 Ok(n) => Ok(n),
1024 Err(e) => {
1025 self.mark_io_desynced();
1026 Err(PgError::Connection(format!("Read error: {}", e)))
1027 }
1028 }
1029 }
1030
1031 pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
1035 self.write_all_with_timeout(bytes, "send raw bytes").await?;
1036 self.flush_with_timeout("flush raw bytes").await?;
1037 Ok(())
1038 }
1039
1040 #[inline]
1045 pub fn buffer_bytes(&mut self, bytes: &[u8]) {
1046 self.write_buf.extend_from_slice(bytes);
1047 }
1048
1049 pub async fn flush_write_buf(&mut self) -> PgResult<()> {
1052 if !self.write_buf.is_empty() {
1053 let payload = self.write_buf.split().freeze();
1054 self.write_all_with_timeout(&payload, "flush write buffer")
1055 .await?;
1056 self.flush_with_timeout("flush write buffer").await?;
1057 }
1058 Ok(())
1059 }
1060
1061 #[inline]
1065 pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
1066 loop {
1067 if self.buffer.len() >= 5 {
1068 let msg_len = u32::from_be_bytes([
1069 self.buffer[1],
1070 self.buffer[2],
1071 self.buffer[3],
1072 self.buffer[4],
1073 ]) as usize;
1074
1075 if msg_len < 4 {
1076 return self.protocol_desync(format!(
1077 "Invalid message length: {} (minimum 4)",
1078 msg_len
1079 ));
1080 }
1081
1082 if msg_len > MAX_MESSAGE_SIZE {
1083 return self.protocol_desync(format!(
1084 "Message too large: {} bytes (max {})",
1085 msg_len, MAX_MESSAGE_SIZE
1086 ));
1087 }
1088
1089 if self.buffer.len() > msg_len {
1090 let msg_type = self.buffer[0];
1091
1092 if matches!(
1093 msg_type,
1094 b'E' | b'A'
1095 | b'Z'
1096 | b'C'
1097 | b'1'
1098 | b'2'
1099 | b'3'
1100 | b'n'
1101 | b's'
1102 | b'I'
1103 | b'S'
1104 | b'N'
1105 ) {
1106 let msg_bytes = self.buffer.split_to(msg_len + 1);
1107 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
1108 Ok(decoded) => decoded,
1109 Err(e) => return self.protocol_desync(e),
1110 };
1111 match msg {
1112 BackendMessage::ErrorResponse(err) => {
1113 return Err(PgError::QueryServer(err.into()));
1114 }
1115 BackendMessage::NotificationResponse {
1116 process_id,
1117 channel,
1118 payload,
1119 } => {
1120 self.notifications
1121 .push_back(super::notification::Notification {
1122 process_id,
1123 channel,
1124 payload,
1125 });
1126 continue;
1127 }
1128 BackendMessage::ReadyForQuery(_)
1129 | BackendMessage::CommandComplete(_)
1130 | BackendMessage::ParseComplete
1131 | BackendMessage::BindComplete
1132 | BackendMessage::CloseComplete
1133 | BackendMessage::NoData
1134 | BackendMessage::PortalSuspended
1135 | BackendMessage::EmptyQueryResponse
1136 | BackendMessage::ParameterStatus { .. }
1137 | BackendMessage::NoticeResponse(_) => {
1138 return Ok(msg_type);
1139 }
1140 _ => {
1141 return Err(PgError::Protocol(
1142 "Unexpected fast-path message".into(),
1143 ));
1144 }
1145 }
1146 }
1147
1148 let _ = self.buffer.split_to(msg_len + 1);
1149 return Ok(msg_type);
1150 }
1151 }
1152
1153 let n = self.read_with_timeout().await?;
1154 if n == 0 {
1155 return self.connection_desync("Connection closed".to_string());
1156 }
1157 }
1158 }
1159
1160 #[inline]
1166 pub(crate) async fn recv_with_data_fast(
1167 &mut self,
1168 ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
1169 loop {
1170 if self.buffer.len() >= 5 {
1171 let msg_len = u32::from_be_bytes([
1172 self.buffer[1],
1173 self.buffer[2],
1174 self.buffer[3],
1175 self.buffer[4],
1176 ]) as usize;
1177
1178 if msg_len < 4 {
1179 return self.protocol_desync(format!(
1180 "Invalid message length: {} (minimum 4)",
1181 msg_len
1182 ));
1183 }
1184
1185 if msg_len > MAX_MESSAGE_SIZE {
1186 return self.protocol_desync(format!(
1187 "Message too large: {} bytes (max {})",
1188 msg_len, MAX_MESSAGE_SIZE
1189 ));
1190 }
1191
1192 if self.buffer.len() > msg_len {
1193 let msg_type = self.buffer[0];
1194
1195 if msg_type == b'E' || msg_type == b'A' {
1196 let msg_bytes = self.buffer.split_to(msg_len + 1);
1197 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
1198 Ok(decoded) => decoded,
1199 Err(e) => return self.protocol_desync(e),
1200 };
1201 match msg {
1202 BackendMessage::ErrorResponse(err) => {
1203 return Err(PgError::QueryServer(err.into()));
1204 }
1205 BackendMessage::NotificationResponse {
1206 process_id,
1207 channel,
1208 payload,
1209 } => {
1210 self.notifications
1211 .push_back(super::notification::Notification {
1212 process_id,
1213 channel,
1214 payload,
1215 });
1216 continue;
1217 }
1218 _ => {
1219 return Err(PgError::Protocol(
1220 "Unexpected fast-path message".into(),
1221 ));
1222 }
1223 }
1224 }
1225
1226 if msg_type == b'D' {
1228 let parse_result = {
1229 let payload = &self.buffer[5..msg_len + 1];
1230 parse_data_row_payload_owned(payload)
1231 };
1232
1233 let _ = self.buffer.split_to(msg_len + 1);
1234 match parse_result {
1235 Ok(columns) => return Ok((msg_type, Some(columns))),
1236 Err(err) => return self.protocol_desync_error(err),
1237 }
1238 }
1239
1240 let _ = self.buffer.split_to(msg_len + 1);
1242 return Ok((msg_type, None));
1243 }
1244 }
1245
1246 let n = self.read_with_timeout().await?;
1247 if n == 0 {
1248 return self.connection_desync("Connection closed".to_string());
1249 }
1250 }
1251 }
1252
1253 #[inline]
1258 pub(crate) async fn recv_fill_data_row_fast(
1259 &mut self,
1260 row_buf: &mut Vec<Option<Vec<u8>>>,
1261 ) -> PgResult<u8> {
1262 loop {
1263 if self.buffer.len() >= 5 {
1264 let msg_len = u32::from_be_bytes([
1265 self.buffer[1],
1266 self.buffer[2],
1267 self.buffer[3],
1268 self.buffer[4],
1269 ]) as usize;
1270
1271 if msg_len < 4 {
1272 return self.protocol_desync(format!(
1273 "Invalid message length: {} (minimum 4)",
1274 msg_len
1275 ));
1276 }
1277
1278 if msg_len > MAX_MESSAGE_SIZE {
1279 return self.protocol_desync(format!(
1280 "Message too large: {} bytes (max {})",
1281 msg_len, MAX_MESSAGE_SIZE
1282 ));
1283 }
1284
1285 if self.buffer.len() > msg_len {
1286 let msg_type = self.buffer[0];
1287
1288 if msg_type == b'E' || msg_type == b'A' {
1289 let msg_bytes = self.buffer.split_to(msg_len + 1);
1290 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
1291 Ok(decoded) => decoded,
1292 Err(e) => return self.protocol_desync(e),
1293 };
1294 match msg {
1295 BackendMessage::ErrorResponse(err) => {
1296 return Err(PgError::QueryServer(err.into()));
1297 }
1298 BackendMessage::NotificationResponse {
1299 process_id,
1300 channel,
1301 payload,
1302 } => {
1303 self.notifications
1304 .push_back(super::notification::Notification {
1305 process_id,
1306 channel,
1307 payload,
1308 });
1309 continue;
1310 }
1311 _ => {
1312 return Err(PgError::Protocol(
1313 "Unexpected fast-path message".into(),
1314 ));
1315 }
1316 }
1317 }
1318
1319 if msg_type == b'D' {
1320 let parse_result = {
1321 let payload = &self.buffer[5..msg_len + 1];
1322 parse_data_row_payload_reuse(payload, row_buf)
1323 };
1324
1325 let _ = self.buffer.split_to(msg_len + 1);
1326 if let Err(err) = parse_result {
1327 return self.protocol_desync_error(err);
1328 }
1329 return Ok(msg_type);
1330 }
1331
1332 let _ = self.buffer.split_to(msg_len + 1);
1333 return Ok(msg_type);
1334 }
1335 }
1336
1337 let n = self.read_with_timeout().await?;
1338 if n == 0 {
1339 return self.connection_desync("Connection closed".to_string());
1340 }
1341 }
1342 }
1343
1344 #[inline]
1346 pub(crate) async fn recv_fill_zerocopy_row_fast(
1347 &mut self,
1348 row: &mut PgBytesRow,
1349 ) -> PgResult<u8> {
1350 loop {
1351 if self.buffer.len() >= 5 {
1352 let msg_len = u32::from_be_bytes([
1353 self.buffer[1],
1354 self.buffer[2],
1355 self.buffer[3],
1356 self.buffer[4],
1357 ]) as usize;
1358
1359 if msg_len < 4 {
1360 return self.protocol_desync(format!(
1361 "Invalid message length: {} (minimum 4)",
1362 msg_len
1363 ));
1364 }
1365
1366 if msg_len > MAX_MESSAGE_SIZE {
1367 return self.protocol_desync(format!(
1368 "Message too large: {} bytes (max {})",
1369 msg_len, MAX_MESSAGE_SIZE
1370 ));
1371 }
1372
1373 if self.buffer.len() > msg_len {
1374 let msg_type = self.buffer[0];
1375
1376 if msg_type == b'E' || msg_type == b'A' {
1377 let msg_bytes = self.buffer.split_to(msg_len + 1);
1378 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
1379 Ok(decoded) => decoded,
1380 Err(e) => return self.protocol_desync(e),
1381 };
1382 match msg {
1383 BackendMessage::ErrorResponse(err) => {
1384 return Err(PgError::QueryServer(err.into()));
1385 }
1386 BackendMessage::NotificationResponse {
1387 process_id,
1388 channel,
1389 payload,
1390 } => {
1391 self.notifications
1392 .push_back(super::notification::Notification {
1393 process_id,
1394 channel,
1395 payload,
1396 });
1397 continue;
1398 }
1399 _ => {
1400 return Err(PgError::Protocol(
1401 "Unexpected fast-path message".into(),
1402 ));
1403 }
1404 }
1405 }
1406
1407 if msg_type == b'D' {
1408 let msg_bytes = self.buffer.split_to(msg_len + 1).freeze();
1409 let payload = msg_bytes.slice(5..);
1410 if let Err(err) = parse_data_row_payload_zerocopy(payload, row) {
1411 return self.protocol_desync_error(err);
1412 }
1413 return Ok(msg_type);
1414 }
1415
1416 let _ = self.buffer.split_to(msg_len + 1);
1417 return Ok(msg_type);
1418 }
1419 }
1420
1421 let n = self.read_with_timeout().await?;
1422 if n == 0 {
1423 return self.connection_desync("Connection closed".to_string());
1424 }
1425 }
1426 }
1427
1428 #[inline]
1430 pub(crate) async fn recv_fill_first_column_zerocopy_fast(
1431 &mut self,
1432 first_column: &mut Option<Bytes>,
1433 ) -> PgResult<u8> {
1434 loop {
1435 if self.buffer.len() >= 5 {
1436 let msg_len = u32::from_be_bytes([
1437 self.buffer[1],
1438 self.buffer[2],
1439 self.buffer[3],
1440 self.buffer[4],
1441 ]) as usize;
1442
1443 if msg_len < 4 {
1444 return self.protocol_desync(format!(
1445 "Invalid message length: {} (minimum 4)",
1446 msg_len
1447 ));
1448 }
1449
1450 if msg_len > MAX_MESSAGE_SIZE {
1451 return self.protocol_desync(format!(
1452 "Message too large: {} bytes (max {})",
1453 msg_len, MAX_MESSAGE_SIZE
1454 ));
1455 }
1456
1457 if self.buffer.len() > msg_len {
1458 let msg_type = self.buffer[0];
1459
1460 if msg_type == b'E' || msg_type == b'A' {
1461 let msg_bytes = self.buffer.split_to(msg_len + 1);
1462 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
1463 Ok(decoded) => decoded,
1464 Err(e) => return self.protocol_desync(e),
1465 };
1466 match msg {
1467 BackendMessage::ErrorResponse(err) => {
1468 return Err(PgError::QueryServer(err.into()));
1469 }
1470 BackendMessage::NotificationResponse {
1471 process_id,
1472 channel,
1473 payload,
1474 } => {
1475 self.notifications
1476 .push_back(super::notification::Notification {
1477 process_id,
1478 channel,
1479 payload,
1480 });
1481 continue;
1482 }
1483 _ => {
1484 return Err(PgError::Protocol(
1485 "Unexpected fast-path message".into(),
1486 ));
1487 }
1488 }
1489 }
1490
1491 if msg_type == b'D' {
1492 let msg_bytes = self.buffer.split_to(msg_len + 1).freeze();
1493 let payload = msg_bytes.slice(5..);
1494 match parse_first_column_payload_zerocopy(payload) {
1495 Ok(column) => *first_column = column,
1496 Err(err) => return self.protocol_desync_error(err),
1497 }
1498 return Ok(msg_type);
1499 }
1500
1501 let _ = self.buffer.split_to(msg_len + 1);
1502 return Ok(msg_type);
1503 }
1504 }
1505
1506 let n = self.read_with_timeout().await?;
1507 if n == 0 {
1508 return self.connection_desync("Connection closed".to_string());
1509 }
1510 }
1511 }
1512
1513 #[inline]
1515 pub(crate) async fn recv_fill_first_four_columns_zerocopy_fast(
1516 &mut self,
1517 columns: &mut [Option<Bytes>; 4],
1518 ) -> PgResult<u8> {
1519 loop {
1520 if self.buffer.len() >= 5 {
1521 let msg_len = u32::from_be_bytes([
1522 self.buffer[1],
1523 self.buffer[2],
1524 self.buffer[3],
1525 self.buffer[4],
1526 ]) as usize;
1527
1528 if msg_len < 4 {
1529 return self.protocol_desync(format!(
1530 "Invalid message length: {} (minimum 4)",
1531 msg_len
1532 ));
1533 }
1534
1535 if msg_len > MAX_MESSAGE_SIZE {
1536 return self.protocol_desync(format!(
1537 "Message too large: {} bytes (max {})",
1538 msg_len, MAX_MESSAGE_SIZE
1539 ));
1540 }
1541
1542 if self.buffer.len() > msg_len {
1543 let msg_type = self.buffer[0];
1544
1545 if msg_type == b'E' || msg_type == b'A' {
1546 let msg_bytes = self.buffer.split_to(msg_len + 1);
1547 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
1548 Ok(decoded) => decoded,
1549 Err(e) => return self.protocol_desync(e),
1550 };
1551 match msg {
1552 BackendMessage::ErrorResponse(err) => {
1553 return Err(PgError::QueryServer(err.into()));
1554 }
1555 BackendMessage::NotificationResponse {
1556 process_id,
1557 channel,
1558 payload,
1559 } => {
1560 self.notifications
1561 .push_back(super::notification::Notification {
1562 process_id,
1563 channel,
1564 payload,
1565 });
1566 continue;
1567 }
1568 _ => {
1569 return Err(PgError::Protocol(
1570 "Unexpected fast-path message".into(),
1571 ));
1572 }
1573 }
1574 }
1575
1576 if msg_type == b'D' {
1577 let msg_bytes = self.buffer.split_to(msg_len + 1).freeze();
1578 let payload = msg_bytes.slice(5..);
1579 if let Err(err) =
1580 parse_first_four_columns_payload_zerocopy(payload, columns)
1581 {
1582 return self.protocol_desync_error(err);
1583 }
1584 return Ok(msg_type);
1585 }
1586
1587 let _ = self.buffer.split_to(msg_len + 1);
1588 return Ok(msg_type);
1589 }
1590 }
1591
1592 let n = self.read_with_timeout().await?;
1593 if n == 0 {
1594 return self.connection_desync("Connection closed".to_string());
1595 }
1596 }
1597 }
1598
1599 #[inline]
1605 pub(crate) async fn recv_data_zerocopy(
1606 &mut self,
1607 ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
1608 use bytes::Buf;
1609
1610 loop {
1611 if self.buffer.len() >= 5 {
1612 let msg_len = u32::from_be_bytes([
1613 self.buffer[1],
1614 self.buffer[2],
1615 self.buffer[3],
1616 self.buffer[4],
1617 ]) as usize;
1618
1619 if msg_len < 4 {
1620 return self.protocol_desync(format!(
1621 "Invalid message length: {} (minimum 4)",
1622 msg_len
1623 ));
1624 }
1625
1626 if msg_len > MAX_MESSAGE_SIZE {
1627 return self.protocol_desync(format!(
1628 "Message too large: {} bytes (max {})",
1629 msg_len, MAX_MESSAGE_SIZE
1630 ));
1631 }
1632
1633 if self.buffer.len() > msg_len {
1634 let msg_type = self.buffer[0];
1635
1636 if msg_type == b'E' || msg_type == b'A' {
1637 let msg_bytes = self.buffer.split_to(msg_len + 1);
1638 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
1639 Ok(decoded) => decoded,
1640 Err(e) => return self.protocol_desync(e),
1641 };
1642 match msg {
1643 BackendMessage::ErrorResponse(err) => {
1644 return Err(PgError::QueryServer(err.into()));
1645 }
1646 BackendMessage::NotificationResponse {
1647 process_id,
1648 channel,
1649 payload,
1650 } => {
1651 self.notifications
1652 .push_back(super::notification::Notification {
1653 process_id,
1654 channel,
1655 payload,
1656 });
1657 continue;
1658 }
1659 _ => {
1660 return Err(PgError::Protocol(
1661 "Unexpected fast-path message".into(),
1662 ));
1663 }
1664 }
1665 }
1666
1667 if msg_type == b'D' {
1669 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
1671
1672 msg_bytes.advance(5);
1674
1675 if msg_bytes.len() >= 2 {
1676 let raw_count = msg_bytes.get_i16();
1677 if raw_count < 0 {
1678 return self.protocol_desync(format!(
1679 "DataRow invalid column count: {}",
1680 raw_count
1681 ));
1682 }
1683 let column_count = raw_count as usize;
1684 if column_count > msg_bytes.remaining() / 4 + 1 {
1685 return self.protocol_desync(format!(
1686 "DataRow claims {} columns but payload is only {} bytes",
1687 column_count,
1688 msg_bytes.remaining() + 2
1689 ));
1690 }
1691 let mut columns = Vec::with_capacity(column_count);
1692
1693 for _ in 0..column_count {
1694 if msg_bytes.remaining() < 4 {
1695 return self.protocol_desync(
1696 "DataRow truncated: missing column length".into(),
1697 );
1698 }
1699
1700 let len = msg_bytes.get_i32();
1701
1702 if len == -1 {
1703 columns.push(None);
1704 } else {
1705 if len < -1 {
1706 return self.protocol_desync(format!(
1707 "DataRow invalid column length: {}",
1708 len
1709 ));
1710 }
1711 let len = len as usize;
1712 if msg_bytes.remaining() < len {
1713 return self.protocol_desync(
1714 "DataRow truncated: column data exceeds payload".into(),
1715 );
1716 }
1717 let col_data = msg_bytes.split_to(len).freeze();
1718 columns.push(Some(col_data));
1719 }
1720 }
1721
1722 if msg_bytes.remaining() != 0 {
1723 return self.protocol_desync("DataRow has trailing bytes".into());
1724 }
1725
1726 return Ok((msg_type, Some(columns)));
1727 }
1728 return self.protocol_desync("DataRow payload too short".into());
1729 }
1730
1731 let _ = self.buffer.split_to(msg_len + 1);
1733 return Ok((msg_type, None));
1734 }
1735 }
1736
1737 let n = self.read_with_timeout().await?;
1738 if n == 0 {
1739 return self.connection_desync("Connection closed".to_string());
1740 }
1741 }
1742 }
1743
1744 #[inline(always)]
1748 pub(crate) async fn recv_data_ultra(
1749 &mut self,
1750 ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
1751 use bytes::Buf;
1752
1753 loop {
1754 if self.buffer.len() >= 5 {
1755 let msg_len = u32::from_be_bytes([
1756 self.buffer[1],
1757 self.buffer[2],
1758 self.buffer[3],
1759 self.buffer[4],
1760 ]) as usize;
1761
1762 if msg_len < 4 {
1763 return self.protocol_desync(format!(
1764 "Invalid message length: {} (minimum 4)",
1765 msg_len
1766 ));
1767 }
1768
1769 if msg_len > MAX_MESSAGE_SIZE {
1770 return self.protocol_desync(format!(
1771 "Message too large: {} bytes (max {})",
1772 msg_len, MAX_MESSAGE_SIZE
1773 ));
1774 }
1775
1776 if self.buffer.len() > msg_len {
1777 let msg_type = self.buffer[0];
1778
1779 if msg_type == b'E' || msg_type == b'A' {
1781 let msg_bytes = self.buffer.split_to(msg_len + 1);
1782 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
1783 Ok(decoded) => decoded,
1784 Err(e) => return self.protocol_desync(e),
1785 };
1786 match msg {
1787 BackendMessage::ErrorResponse(err) => {
1788 return Err(PgError::QueryServer(err.into()));
1789 }
1790 BackendMessage::NotificationResponse {
1791 process_id,
1792 channel,
1793 payload,
1794 } => {
1795 self.notifications
1796 .push_back(super::notification::Notification {
1797 process_id,
1798 channel,
1799 payload,
1800 });
1801 continue;
1802 }
1803 _ => {
1804 return Err(PgError::Protocol(
1805 "Unexpected fast-path message".into(),
1806 ));
1807 }
1808 }
1809 }
1810
1811 if msg_type == b'D' {
1812 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
1813 msg_bytes.advance(5); if msg_bytes.remaining() < 2 {
1817 return self.protocol_desync(
1818 "DataRow ultra: too short for column count".into(),
1819 );
1820 }
1821
1822 let col_count = msg_bytes.get_i16();
1824 if col_count != 2 {
1825 return self.protocol_desync(format!(
1826 "DataRow ultra expects exactly 2 columns, got {}",
1827 col_count
1828 ));
1829 }
1830
1831 if msg_bytes.remaining() < 4 {
1832 return self.protocol_desync(
1833 "DataRow ultra: truncated before col0 length".into(),
1834 );
1835 }
1836 let len0 = msg_bytes.get_i32();
1837 let col0 = if len0 > 0 {
1838 let len0 = len0 as usize;
1839 if msg_bytes.remaining() < len0 {
1840 return self.protocol_desync(
1841 "DataRow ultra: col0 data exceeds payload".into(),
1842 );
1843 }
1844 msg_bytes.split_to(len0).freeze()
1845 } else if len0 == 0 {
1846 bytes::Bytes::new()
1847 } else if len0 == -1 {
1848 return self.protocol_desync(
1849 "DataRow ultra does not support NULL columns".into(),
1850 );
1851 } else {
1852 return self.protocol_desync(format!(
1853 "DataRow ultra: invalid col0 length {}",
1854 len0
1855 ));
1856 };
1857
1858 if msg_bytes.remaining() < 4 {
1859 return self.protocol_desync(
1860 "DataRow ultra: truncated before col1 length".into(),
1861 );
1862 }
1863 let len1 = msg_bytes.get_i32();
1864 let col1 = if len1 > 0 {
1865 let len1 = len1 as usize;
1866 if msg_bytes.remaining() < len1 {
1867 return self.protocol_desync(
1868 "DataRow ultra: col1 data exceeds payload".into(),
1869 );
1870 }
1871 msg_bytes.split_to(len1).freeze()
1872 } else if len1 == 0 {
1873 bytes::Bytes::new()
1874 } else if len1 == -1 {
1875 return self.protocol_desync(
1876 "DataRow ultra does not support NULL columns".into(),
1877 );
1878 } else {
1879 return self.protocol_desync(format!(
1880 "DataRow ultra: invalid col1 length {}",
1881 len1
1882 ));
1883 };
1884
1885 if msg_bytes.remaining() != 0 {
1886 return self.protocol_desync(
1887 "DataRow ultra: trailing bytes after expected columns".into(),
1888 );
1889 }
1890
1891 return Ok((msg_type, Some((col0, col1))));
1892 }
1893
1894 let _ = self.buffer.split_to(msg_len + 1);
1896 return Ok((msg_type, None));
1897 }
1898 }
1899
1900 let n = self.read_with_timeout().await?;
1901 if n == 0 {
1902 return self.connection_desync("Connection closed".to_string());
1903 }
1904 }
1905 }
1906}
1907
#[cfg(test)]
mod tests {
    // Unit tests for the DataRow fast-path decoders and the `recv_*` loops.
    // Frames are written directly into `conn.buffer`, so no socket I/O is
    // needed for complete frames.
    use super::*;

    /// Builds a `PgConnection` over one half of a Unix socket pair.
    ///
    /// The peer half is dropped when this function returns, so a socket read
    /// should observe EOF. Tests therefore pre-load complete frames into
    /// `buffer`.
    #[cfg(unix)]
    fn test_conn() -> PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (unix_stream, _peer) = UnixStream::pair().expect("unix stream pair");
        PgConnection {
            stream: PgStream::Unix(unix_stream),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(NonZeroUsize::new(2).expect("non-zero")),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        }
    }

    /// Encodes a DataRow payload (no tag or length header).
    ///
    /// Layout: i16 column count, then per column an i32 length followed by the
    /// bytes, or -1 for NULL.
    fn build_data_row_payload(columns: &[Option<&[u8]>]) -> Bytes {
        let mut payload = Vec::new();
        payload.extend_from_slice(&(columns.len() as i16).to_be_bytes());
        for column in columns {
            match column {
                Some(bytes) => {
                    payload.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
                    payload.extend_from_slice(bytes);
                }
                None => payload.extend_from_slice(&(-1i32).to_be_bytes()),
            }
        }
        Bytes::from(payload)
    }

    /// Appends a complete `'D'` frame wrapping `payload` to the read buffer.
    fn push_data_row_frame(conn: &mut PgConnection, payload: &[u8]) {
        // The length field counts its own 4 bytes plus the payload.
        let msg_len = payload.len() + 4;
        conn.buffer.extend_from_slice(b"D");
        conn.buffer
            .extend_from_slice(&(msg_len as u32).to_be_bytes());
        conn.buffer.extend_from_slice(payload);
    }

    /// Appends a complete `'A'` (NotificationResponse) frame to the read buffer.
    ///
    /// Body: i32 process id, NUL-terminated channel, NUL-terminated payload.
    fn push_notification_frame(
        conn: &mut PgConnection,
        process_id: i32,
        channel: &str,
        payload: &str,
    ) {
        let mut body = Vec::new();
        body.extend_from_slice(&process_id.to_be_bytes());
        body.extend_from_slice(channel.as_bytes());
        body.push(0);
        body.extend_from_slice(payload.as_bytes());
        body.push(0);
        conn.buffer.extend_from_slice(b"A");
        conn.buffer
            .extend_from_slice(&((body.len() + 4) as u32).to_be_bytes());
        conn.buffer.extend_from_slice(&body);
    }

    /// Pushes a DataRow that claims one column but omits the column length.
    fn push_one_column_datarow_without_column_length(conn: &mut PgConnection) {
        push_data_row_frame(conn, &[0, 1]);
    }

    /// Asserts `err` is `PgError::Protocol` and its message contains `expected`.
    fn assert_protocol_error_contains(err: PgError, expected: &str) {
        match err {
            PgError::Protocol(msg) => assert!(
                msg.contains(expected),
                "expected protocol error containing {expected:?}, got {msg:?}"
            ),
            err => panic!("expected protocol error containing {expected:?}, got {err:?}"),
        }
    }

    #[test]
    fn parse_first_four_columns_payload_zerocopy_reads_values() {
        let payload = build_data_row_payload(&[Some(b"10"), None, Some(b"30"), Some(b"")]);
        let mut columns = [None, None, None, None];

        parse_first_four_columns_payload_zerocopy(payload, &mut columns).unwrap();

        assert_eq!(columns[0].as_deref(), Some(&b"10"[..]));
        assert_eq!(columns[1].as_deref(), None);
        assert_eq!(columns[2].as_deref(), Some(&b"30"[..]));
        assert_eq!(columns[3].as_deref(), Some(&b""[..]));
    }

    #[test]
    fn parse_first_four_columns_payload_zerocopy_rejects_wrong_arity() {
        let payload = build_data_row_payload(&[Some(b"1"), Some(b"2"), Some(b"3")]);
        let mut columns = [None, None, None, None];

        let err = parse_first_four_columns_payload_zerocopy(payload, &mut columns).unwrap_err();

        assert!(
            err.to_string()
                .contains("fast-path expects exactly 4 columns")
        );
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_data_zerocopy_rejects_datarow_length_4() {
        let mut conn = test_conn();
        conn.buffer.extend_from_slice(&[b'D', 0, 0, 0, 4]);

        let err = conn.recv_data_zerocopy().await.unwrap_err();

        assert!(err.to_string().contains("DataRow payload too short"));
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_data_zerocopy_rejects_datarow_length_5() {
        let mut conn = test_conn();
        conn.buffer.extend_from_slice(&[b'D', 0, 0, 0, 5, 0]);

        let err = conn.recv_data_zerocopy().await.unwrap_err();

        assert!(err.to_string().contains("DataRow payload too short"));
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_data_zerocopy_returns_columns_and_nulls() {
        let mut conn = test_conn();
        let payload = build_data_row_payload(&[Some(b"abc"), None, Some(b"")]);
        push_data_row_frame(&mut conn, &payload);

        let (msg_type, columns) = conn.recv_data_zerocopy().await.unwrap();

        assert_eq!(msg_type, b'D');
        let columns = columns.expect("DataRow should yield columns");
        assert_eq!(columns.len(), 3);
        assert_eq!(columns[0].as_deref(), Some(&b"abc"[..]));
        assert!(columns[1].is_none());
        assert_eq!(columns[2].as_deref(), Some(&b""[..]));
        assert!(conn.buffer.is_empty());
        assert!(!conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_data_zerocopy_rejects_trailing_bytes() {
        let mut conn = test_conn();
        let mut payload = build_data_row_payload(&[Some(b"x")]).to_vec();
        payload.push(0);
        push_data_row_frame(&mut conn, &payload);

        let err = conn.recv_data_zerocopy().await.unwrap_err();

        assert_protocol_error_contains(err, "DataRow has trailing bytes");
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_data_zerocopy_passes_through_non_datarow_messages() {
        let mut conn = test_conn();
        conn.buffer.extend_from_slice(&[b'Z', 0, 0, 0, 5, b'I']);

        let (msg_type, columns) = conn.recv_data_zerocopy().await.unwrap();

        assert_eq!(msg_type, b'Z');
        assert!(columns.is_none());
        assert!(conn.buffer.is_empty());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_data_zerocopy_queues_interleaved_notification() {
        let mut conn = test_conn();
        push_notification_frame(&mut conn, 42, "jobs", "ready");
        let payload = build_data_row_payload(&[Some(b"1")]);
        push_data_row_frame(&mut conn, &payload);

        let (msg_type, columns) = conn.recv_data_zerocopy().await.unwrap();

        assert_eq!(msg_type, b'D');
        assert_eq!(columns.map(|cols| cols.len()), Some(1));
        assert_eq!(conn.notifications.len(), 1);
        assert!(conn.buffer.is_empty());
        assert!(!conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_with_data_fast_desyncs_on_malformed_datarow() {
        let mut conn = test_conn();
        push_one_column_datarow_without_column_length(&mut conn);

        let err = conn.recv_with_data_fast().await.unwrap_err();

        assert_protocol_error_contains(err, "DataRow truncated");
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_fill_data_row_fast_desyncs_on_malformed_datarow() {
        let mut conn = test_conn();
        let mut row = Vec::new();
        push_one_column_datarow_without_column_length(&mut conn);

        let err = conn.recv_fill_data_row_fast(&mut row).await.unwrap_err();

        assert_protocol_error_contains(err, "DataRow truncated");
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_fill_zerocopy_row_fast_desyncs_on_malformed_datarow() {
        let mut conn = test_conn();
        let mut row = PgBytesRow::default();
        push_one_column_datarow_without_column_length(&mut conn);

        let err = conn
            .recv_fill_zerocopy_row_fast(&mut row)
            .await
            .unwrap_err();

        assert_protocol_error_contains(err, "DataRow truncated");
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_fill_first_column_zerocopy_fast_desyncs_on_malformed_datarow() {
        let mut conn = test_conn();
        let mut first_column = None;
        push_one_column_datarow_without_column_length(&mut conn);

        let err = conn
            .recv_fill_first_column_zerocopy_fast(&mut first_column)
            .await
            .unwrap_err();

        assert_protocol_error_contains(err, "DataRow truncated");
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_fill_first_four_columns_zerocopy_fast_desyncs_on_malformed_datarow() {
        let mut conn = test_conn();
        let mut columns = [None, None, None, None];
        push_one_column_datarow_without_column_length(&mut conn);

        let err = conn
            .recv_fill_first_four_columns_zerocopy_fast(&mut columns)
            .await
            .unwrap_err();

        assert_protocol_error_contains(err, "DataRow fast-path expects exactly 4 columns");
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_data_ultra_desyncs_on_malformed_datarow() {
        let mut conn = test_conn();
        push_one_column_datarow_without_column_length(&mut conn);

        let err = conn.recv_data_ultra().await.unwrap_err();

        assert_protocol_error_contains(err, "DataRow ultra expects exactly 2 columns");
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_data_ultra_returns_both_columns() {
        let mut conn = test_conn();
        let payload = build_data_row_payload(&[Some(b"key"), Some(b"")]);
        push_data_row_frame(&mut conn, &payload);

        let (msg_type, columns) = conn.recv_data_ultra().await.unwrap();

        assert_eq!(msg_type, b'D');
        let (first, second) = columns.expect("DataRow should yield two columns");
        assert_eq!(&first[..], &b"key"[..]);
        assert!(second.is_empty());
        assert!(!conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_data_ultra_rejects_null_columns() {
        let mut conn = test_conn();
        let payload = build_data_row_payload(&[Some(b"key"), None]);
        push_data_row_frame(&mut conn, &payload);

        let err = conn.recv_data_ultra().await.unwrap_err();

        assert_protocol_error_contains(err, "DataRow ultra does not support NULL columns");
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_msg_type_fast_rejects_malformed_ready_for_query() {
        let mut conn = test_conn();
        conn.buffer.extend_from_slice(&[b'Z', 0, 0, 0, 5, b'X']);

        let err = conn.recv_msg_type_fast().await.unwrap_err();

        assert!(err.to_string().contains("Unknown transaction status"));
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn recv_msg_type_fast_rejects_malformed_command_complete() {
        let mut conn = test_conn();
        conn.buffer.extend_from_slice(&[
            b'C', 0, 0, 0, 12, b'S', b'E', b'L', b'E', b'C', b'T', b' ', b'1',
        ]);

        let err = conn.recv_msg_type_fast().await.unwrap_err();

        assert!(
            err.to_string()
                .contains("CommandComplete missing null terminator")
        );
        assert!(conn.is_io_desynced());
    }
}