1use super::{PgConnection, PgError, PgResult, is_ignorable_session_message};
6use crate::protocol::{BackendMessage, FrontendMessage, PgEncoder};
7use bytes::BytesMut;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9
10pub(crate) const MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
15const DEFAULT_WRITE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
18
19#[inline]
20fn parse_data_row_payload_owned(payload: &[u8]) -> PgResult<Vec<Option<Vec<u8>>>> {
21 if payload.len() < 2 {
22 return Err(PgError::Protocol("DataRow payload too short".into()));
23 }
24
25 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
26 if raw_count < 0 {
27 return Err(PgError::Protocol(format!(
28 "DataRow invalid column count: {}",
29 raw_count
30 )));
31 }
32 let column_count = raw_count as usize;
33 if column_count > (payload.len() - 2) / 4 + 1 {
34 return Err(PgError::Protocol(format!(
35 "DataRow claims {} columns but payload is only {} bytes",
36 column_count,
37 payload.len()
38 )));
39 }
40
41 let mut columns = Vec::with_capacity(column_count);
42 let mut pos = 2;
43 for _ in 0..column_count {
44 if pos + 4 > payload.len() {
45 return Err(PgError::Protocol(
46 "DataRow truncated: missing column length".into(),
47 ));
48 }
49
50 let len = i32::from_be_bytes([
51 payload[pos],
52 payload[pos + 1],
53 payload[pos + 2],
54 payload[pos + 3],
55 ]);
56 pos += 4;
57
58 if len == -1 {
59 columns.push(None);
60 continue;
61 }
62 if len < -1 {
63 return Err(PgError::Protocol(format!(
64 "DataRow invalid column length: {}",
65 len
66 )));
67 }
68
69 let len = len as usize;
70 if len > payload.len().saturating_sub(pos) {
71 return Err(PgError::Protocol(
72 "DataRow truncated: column data exceeds payload".into(),
73 ));
74 }
75 columns.push(Some(payload[pos..pos + len].to_vec()));
76 pos += len;
77 }
78
79 if pos != payload.len() {
80 return Err(PgError::Protocol("DataRow has trailing bytes".into()));
81 }
82
83 Ok(columns)
84}
85
86impl PgConnection {
    /// Flags the connection as having lost wire-protocol framing with the
    /// server (set after failed/timed-out reads or writes), so callers can
    /// discard it rather than reuse a stream whose byte position is unknown.
    #[inline]
    pub(crate) fn mark_io_desynced(&mut self) {
        self.io_desynced = true;
    }
91
    /// Returns whether the connection has been marked desynced and must not
    /// be reused for further protocol traffic.
    #[inline]
    pub(crate) fn is_io_desynced(&self) -> bool {
        self.io_desynced
    }
96
    /// Marks the connection desynced and returns a `PgError::Protocol` in
    /// one step; used when a malformed or unexpected backend message means
    /// the stream position can no longer be trusted.
    #[inline]
    fn protocol_desync<T>(&mut self, msg: String) -> PgResult<T> {
        self.mark_io_desynced();
        Err(PgError::Protocol(msg))
    }
102
    /// Marks the connection desynced and returns a `PgError::Connection` in
    /// one step; used for transport-level failures (e.g. EOF mid-stream).
    #[inline]
    fn connection_desync<T>(&mut self, msg: String) -> PgResult<T> {
        self.mark_io_desynced();
        Err(PgError::Connection(msg))
    }
108
    /// Flushes deferred `Close` (prepared statement) messages: writes them
    /// all as one batch followed by a single `Sync`, then drains the
    /// backend's replies until the matching `ReadyForQuery`.
    ///
    /// `draining_statement_closes` is a re-entrancy guard — the drain loop
    /// calls `recv`, and any write issued while draining must not trigger
    /// this flush again. SQLSTATE 26000 "prepared statement ... does not
    /// exist" errors are tolerated (the statement is gone either way); the
    /// first other server error is remembered and returned only after the
    /// drain completes, so the stream is left at a message boundary.
    async fn flush_pending_statement_closes(&mut self) -> PgResult<()> {
        if self.draining_statement_closes || self.pending_statement_closes.is_empty() {
            return Ok(());
        }

        self.draining_statement_closes = true;
        // Take the queue up front so a failure cannot replay the same Close
        // messages later on a stream whose state is unknown.
        let close_names = std::mem::take(&mut self.pending_statement_closes);

        // ~16 bytes of framing per Close plus the statement name; the +5
        // below covers the trailing Sync frame.
        let estimated_payload_len: usize = close_names
            .iter()
            .map(|name| 16usize.saturating_add(name.len()))
            .sum();
        let mut buf = BytesMut::with_capacity(estimated_payload_len.saturating_add(5));
        for stmt_name in &close_names {
            // `false` presumably selects a statement close (vs. a portal) —
            // TODO confirm against PgEncoder::try_encode_close.
            let close_msg = PgEncoder::try_encode_close(false, stmt_name)
                .map_err(|e| PgError::Encode(e.to_string()))?;
            buf.extend_from_slice(&close_msg);
        }
        PgEncoder::encode_sync_to(&mut buf);

        // Use the inner write path directly: the public wrapper would
        // recurse back into this flush.
        if let Err(err) = self
            .write_all_with_timeout_inner(&buf, "pending statement close write")
            .await
        {
            self.draining_statement_closes = false;
            return Err(err);
        }
        if let Err(err) = self
            .flush_with_timeout("pending statement close flush")
            .await
        {
            self.draining_statement_closes = false;
            return Err(err);
        }

        // Drain CloseComplete replies (plus incidental session messages)
        // until the ReadyForQuery answering our Sync.
        let mut error: Option<PgError> = None;
        loop {
            let msg = match self.recv().await {
                Ok(msg) => msg,
                Err(err) => {
                    self.draining_statement_closes = false;
                    return Err(err);
                }
            };
            match msg {
                BackendMessage::CloseComplete => {}
                BackendMessage::ReadyForQuery(_) => {
                    self.draining_statement_closes = false;
                    if let Some(err) = error {
                        return Err(err);
                    }
                    return Ok(());
                }
                BackendMessage::ErrorResponse(err_fields) => {
                    // Keep only the first non-ignorable error so the
                    // original cause is what gets reported.
                    if error.is_none() {
                        let code_26000 = err_fields.code.eq_ignore_ascii_case("26000");
                        let msg_lower = err_fields.message.to_ascii_lowercase();
                        let missing_prepared = msg_lower.contains("prepared statement")
                            && msg_lower.contains("does not exist");
                        if !(code_26000 && missing_prepared) {
                            error = Some(PgError::QueryServer(err_fields.into()));
                        }
                    }
                }
                msg if is_ignorable_session_message(&msg) => {}
                other => {
                    self.draining_statement_closes = false;
                    return self.protocol_desync(format!(
                        "Unexpected backend message during pending statement close drain: {:?}",
                        other
                    ));
                }
            }
        }
    }
189
190 pub(crate) async fn write_all_with_timeout(
194 &mut self,
195 bytes: &[u8],
196 operation: &str,
197 ) -> PgResult<()> {
198 if !self.draining_statement_closes && !self.pending_statement_closes.is_empty() {
199 self.flush_pending_statement_closes().await?;
200 }
201 self.write_all_with_timeout_inner(bytes, operation).await
202 }
203
    /// Writes the full `bytes` payload to the underlying stream with
    /// `DEFAULT_WRITE_TIMEOUT` applied, without draining pending statement
    /// closes (that is the caller-facing wrapper's job).
    ///
    /// Rejects empty payloads — there is no zero-byte frontend frame, so an
    /// empty write indicates an upstream encoding bug. On any write error or
    /// timeout the connection is marked desynced, since an unknown number of
    /// bytes may already be on the wire. `operation` labels timeout errors.
    async fn write_all_with_timeout_inner(
        &mut self,
        bytes: &[u8],
        operation: &str,
    ) -> PgResult<()> {
        if bytes.is_empty() {
            return Err(PgError::Encode(
                "refusing to send empty frontend payload".to_string(),
            ));
        }
        use super::stream::PgStream;
        // Desync is recorded via a flag and applied after the match because
        // `self.stream` is mutably borrowed inside it.
        let mut mark_desync = false;
        // Each arm is identical apart from the concrete stream type; the
        // enum dispatch avoids a trait object on the hot write path.
        let result = match &mut self.stream {
            PgStream::Tcp(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Write error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
            PgStream::Tls(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Write error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
            #[cfg(all(target_os = "linux", feature = "io_uring"))]
            PgStream::Uring(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Write error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        // Cancel the submitted io_uring op so it cannot
                        // complete (and touch the buffer) after we bail out.
                        let _ = stream.abort_inflight();
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
            #[cfg(unix)]
            PgStream::Unix(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Write error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
            PgStream::GssEnc(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.write_all(bytes)).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Write error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
        };
        if mark_desync {
            self.mark_io_desynced();
        }
        result
    }
307
    /// Flushes the underlying stream with `DEFAULT_WRITE_TIMEOUT` applied.
    /// Mirrors `write_all_with_timeout_inner`: any flush error or timeout
    /// marks the connection desynced; `operation` labels timeout errors.
    pub(crate) async fn flush_with_timeout(&mut self, operation: &str) -> PgResult<()> {
        use super::stream::PgStream;
        // Desync is recorded via a flag and applied after the match because
        // `self.stream` is mutably borrowed inside it.
        let mut mark_desync = false;
        let result = match &mut self.stream {
            PgStream::Tcp(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Flush error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
            PgStream::Tls(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Flush error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
            #[cfg(all(target_os = "linux", feature = "io_uring"))]
            PgStream::Uring(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Flush error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        // Cancel the submitted io_uring op so it cannot
                        // complete after we bail out.
                        let _ = stream.abort_inflight();
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
            #[cfg(unix)]
            PgStream::Unix(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Flush error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
            PgStream::GssEnc(stream) => {
                match tokio::time::timeout(DEFAULT_WRITE_TIMEOUT, stream.flush()).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Flush error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Timeout(format!(
                            "{} timeout after {:?}",
                            operation, DEFAULT_WRITE_TIMEOUT
                        )))
                    }
                }
            }
        };
        if mark_desync {
            self.mark_io_desynced();
        }
        result
    }
403
404 pub async fn send(&mut self, msg: FrontendMessage) -> PgResult<()> {
406 let bytes = msg
407 .encode_checked()
408 .map_err(|e| PgError::Encode(e.to_string()))?;
409 self.write_all_with_timeout(&bytes, "send frontend message")
410 .await?;
411 Ok(())
412 }
413
    /// Receives the next backend message, reading from the socket (with the
    /// default read timeout) until a complete frame is buffered.
    ///
    /// Frames are `[type: u8][len: u32 BE][payload]`, where `len` counts
    /// itself plus the payload but not the type byte.
    /// `NotificationResponse` messages are diverted into
    /// `self.notifications` instead of being returned. EOF mid-stream marks
    /// the connection desynced and yields `PgError::Connection`.
    pub async fn recv(&mut self) -> PgResult<BackendMessage> {
        loop {
            // Header = 1 type byte + 4 length bytes.
            if self.buffer.len() >= 5 {
                let msg_len = u32::from_be_bytes([
                    self.buffer[1],
                    self.buffer[2],
                    self.buffer[3],
                    self.buffer[4],
                ]) as usize;

                // The length field includes its own 4 bytes; anything
                // smaller cannot be a valid frame, so framing is lost.
                if msg_len < 4 {
                    return self.protocol_desync(format!(
                        "Invalid message length: {} (minimum 4)",
                        msg_len
                    ));
                }

                // Cap the frame size so a corrupt or hostile length field
                // cannot force unbounded buffering.
                if msg_len > MAX_MESSAGE_SIZE {
                    return self.protocol_desync(format!(
                        "Message too large: {} bytes (max {})",
                        msg_len, MAX_MESSAGE_SIZE
                    ));
                }

                // `>` (not `>=`) because the frame also carries the type
                // byte, which `msg_len` does not count.
                if self.buffer.len() > msg_len {
                    let msg_bytes = self.buffer.split_to(msg_len + 1);
                    let (msg, _) = match BackendMessage::decode(&msg_bytes) {
                        Ok(decoded) => decoded,
                        Err(e) => return self.protocol_desync(e),
                    };

                    // Async notifications (LISTEN/NOTIFY) are queued so
                    // they never interleave with request/response traffic.
                    if let BackendMessage::NotificationResponse {
                        process_id,
                        channel,
                        payload,
                    } = msg
                    {
                        self.notifications
                            .push_back(super::notification::Notification {
                                process_id,
                                channel,
                                payload,
                            });
                        continue;
                    }

                    return Ok(msg);
                }
            }

            // Not enough buffered for a full frame — read more bytes.
            let n = self.read_with_timeout().await?;
            if n == 0 {
                return self.connection_desync("Connection closed".to_string());
            }
        }
    }
475
    /// Like `recv`, but allows an unbounded wait for the *start* of the next
    /// message — presumably for idle waits such as LISTEN/NOTIFY; confirm
    /// against callers.
    ///
    /// Once any bytes of a frame have arrived, it switches back to the
    /// timed read so a stalled peer cannot hang us mid-message.
    pub(crate) async fn recv_without_timeout(&mut self) -> PgResult<BackendMessage> {
        loop {
            // Header = 1 type byte + 4 length bytes.
            if self.buffer.len() >= 5 {
                let msg_len = u32::from_be_bytes([
                    self.buffer[1],
                    self.buffer[2],
                    self.buffer[3],
                    self.buffer[4],
                ]) as usize;

                // Length includes its own 4 bytes; smaller values mean
                // framing is lost.
                if msg_len < 4 {
                    return self.protocol_desync(format!(
                        "Invalid message length: {} (minimum 4)",
                        msg_len
                    ));
                }

                if msg_len > MAX_MESSAGE_SIZE {
                    return self.protocol_desync(format!(
                        "Message too large: {} bytes (max {})",
                        msg_len, MAX_MESSAGE_SIZE
                    ));
                }

                // `>` because msg_len does not count the type byte.
                if self.buffer.len() > msg_len {
                    let msg_bytes = self.buffer.split_to(msg_len + 1);
                    let (msg, _) = match BackendMessage::decode(&msg_bytes) {
                        Ok(decoded) => decoded,
                        Err(e) => return self.protocol_desync(e),
                    };

                    // Queue async notifications rather than returning them.
                    if let BackendMessage::NotificationResponse {
                        process_id,
                        channel,
                        payload,
                    } = msg
                    {
                        self.notifications
                            .push_back(super::notification::Notification {
                                process_id,
                                channel,
                                payload,
                            });
                        continue;
                    }

                    return Ok(msg);
                }
            }

            // Wait without a deadline only while the buffer is empty (an
            // idle wait for the next message); with a partial frame
            // buffered, use the timed read instead.
            let n = if self.buffer.is_empty() {
                self.read_without_timeout().await?
            } else {
                self.read_with_timeout().await?
            };
            if n == 0 {
                return self.connection_desync("Connection closed".to_string());
            }
        }
    }
542
    /// Reads more bytes from the stream into `self.buffer`, bounded by
    /// `DEFAULT_READ_TIMEOUT`. Returns the number of bytes read (0 = EOF).
    /// Any read error or timeout marks the connection desynced.
    ///
    /// NOTE(review): a read timeout is reported as `PgError::Connection`
    /// while write timeouts use `PgError::Timeout` — confirm whether that
    /// asymmetry is intentional.
    #[inline]
    pub(crate) async fn read_with_timeout(&mut self) -> PgResult<usize> {
        // Keep at least 64 KiB of headroom; reserving 128 KiB at a time
        // amortizes buffer growth across many reads.
        if self.buffer.capacity() - self.buffer.len() < 65536 {
            self.buffer.reserve(131072);
        }

        use super::stream::PgStream;
        let (stream, buffer) = (&mut self.stream, &mut self.buffer);
        // Desync is recorded via a flag and applied after the match because
        // the stream is mutably borrowed inside it.
        let mut mark_desync = false;
        let result = match stream {
            PgStream::Tcp(stream) => {
                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
                    Ok(Ok(n)) => Ok(n),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Read error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!(
                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
                            DEFAULT_READ_TIMEOUT
                        )))
                    }
                }
            }
            PgStream::Tls(stream) => {
                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
                    Ok(Ok(n)) => Ok(n),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Read error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!(
                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
                            DEFAULT_READ_TIMEOUT
                        )))
                    }
                }
            }
            #[cfg(all(target_os = "linux", feature = "io_uring"))]
            PgStream::Uring(stream) => {
                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_into(buffer, 131072))
                    .await
                {
                    Ok(Ok(n)) => Ok(n),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Read error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        // Cancel the submitted io_uring op so it cannot
                        // land into the buffer after we bail out.
                        let _ = stream.abort_inflight();
                        Err(PgError::Connection(format!(
                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
                            DEFAULT_READ_TIMEOUT
                        )))
                    }
                }
            }
            #[cfg(unix)]
            PgStream::Unix(stream) => {
                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
                    Ok(Ok(n)) => Ok(n),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Read error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!(
                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
                            DEFAULT_READ_TIMEOUT
                        )))
                    }
                }
            }
            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
            PgStream::GssEnc(stream) => {
                match tokio::time::timeout(DEFAULT_READ_TIMEOUT, stream.read_buf(buffer)).await {
                    Ok(Ok(n)) => Ok(n),
                    Ok(Err(e)) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!("Read error: {}", e)))
                    }
                    Err(_) => {
                        mark_desync = true;
                        Err(PgError::Connection(format!(
                            "Read timeout after {:?} — possible Slowloris attack or dead connection",
                            DEFAULT_READ_TIMEOUT
                        )))
                    }
                }
            }
        };
        if mark_desync {
            self.mark_io_desynced();
        }
        result
    }
649
    /// Reads from the stream into `self.buffer` with no time bound — used
    /// by `recv_without_timeout` while waiting idle for the next message.
    /// Returns the number of bytes read (0 = EOF); read errors mark the
    /// connection desynced.
    pub(crate) async fn read_without_timeout(&mut self) -> PgResult<usize> {
        // Same headroom policy as read_with_timeout: keep >= 64 KiB free,
        // growing in 128 KiB steps.
        if self.buffer.capacity() - self.buffer.len() < 65536 {
            self.buffer.reserve(131072);
        }

        use super::stream::PgStream;
        let (stream, buffer) = (&mut self.stream, &mut self.buffer);
        let read_result = match stream {
            PgStream::Tcp(stream) => stream.read_buf(buffer).await,
            PgStream::Tls(stream) => stream.read_buf(buffer).await,
            #[cfg(all(target_os = "linux", feature = "io_uring"))]
            PgStream::Uring(stream) => stream.read_into(buffer, 131072).await,
            #[cfg(unix)]
            PgStream::Unix(stream) => stream.read_buf(buffer).await,
            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
            PgStream::GssEnc(stream) => stream.read_buf(buffer).await,
        };

        match read_result {
            Ok(n) => Ok(n),
            Err(e) => {
                self.mark_io_desynced();
                Err(PgError::Connection(format!("Read error: {}", e)))
            }
        }
    }
679
680 pub async fn send_bytes(&mut self, bytes: &[u8]) -> PgResult<()> {
684 self.write_all_with_timeout(bytes, "send raw bytes").await?;
685 self.flush_with_timeout("flush raw bytes").await?;
686 Ok(())
687 }
688
    /// Appends pre-encoded protocol bytes to the local write buffer without
    /// touching the socket; `flush_write_buf` sends the accumulated batch.
    #[inline]
    pub fn buffer_bytes(&mut self, bytes: &[u8]) {
        self.write_buf.extend_from_slice(bytes);
    }
697
698 pub async fn flush_write_buf(&mut self) -> PgResult<()> {
701 if !self.write_buf.is_empty() {
702 let payload = std::mem::take(&mut self.write_buf);
703 self.write_all_with_timeout(&payload, "flush write buffer")
704 .await?;
705 self.flush_with_timeout("flush write buffer").await?;
706 }
707 Ok(())
708 }
709
    /// Fast path that returns only the next message's type byte, discarding
    /// the payload without decoding it.
    ///
    /// Two exceptions are decoded fully: 'E' (ErrorResponse) is surfaced as
    /// `PgError::QueryServer`, and 'A' (NotificationResponse) is queued on
    /// `self.notifications` and skipped. Framing validation matches `recv`.
    #[inline]
    pub(crate) async fn recv_msg_type_fast(&mut self) -> PgResult<u8> {
        loop {
            // Header = 1 type byte + 4 length bytes.
            if self.buffer.len() >= 5 {
                let msg_len = u32::from_be_bytes([
                    self.buffer[1],
                    self.buffer[2],
                    self.buffer[3],
                    self.buffer[4],
                ]) as usize;

                // Length includes its own 4 bytes; smaller means framing
                // is lost.
                if msg_len < 4 {
                    return self.protocol_desync(format!(
                        "Invalid message length: {} (minimum 4)",
                        msg_len
                    ));
                }

                if msg_len > MAX_MESSAGE_SIZE {
                    return self.protocol_desync(format!(
                        "Message too large: {} bytes (max {})",
                        msg_len, MAX_MESSAGE_SIZE
                    ));
                }

                // `>` because msg_len does not count the type byte.
                if self.buffer.len() > msg_len {
                    let msg_type = self.buffer[0];

                    if msg_type == b'E' || msg_type == b'A' {
                        let msg_bytes = self.buffer.split_to(msg_len + 1);
                        let (msg, _) = match BackendMessage::decode(&msg_bytes) {
                            Ok(decoded) => decoded,
                            Err(e) => return self.protocol_desync(e),
                        };
                        match msg {
                            BackendMessage::ErrorResponse(err) => {
                                return Err(PgError::QueryServer(err.into()));
                            }
                            BackendMessage::NotificationResponse {
                                process_id,
                                channel,
                                payload,
                            } => {
                                self.notifications
                                    .push_back(super::notification::Notification {
                                        process_id,
                                        channel,
                                        payload,
                                    });
                                continue;
                            }
                            // Decoding 'E'/'A' should only produce the two
                            // variants above; anything else is a decoder
                            // inconsistency. The frame was fully consumed,
                            // so the stream itself stays in sync.
                            _ => {
                                return Err(PgError::Protocol(
                                    "Unexpected fast-path message".into(),
                                ));
                            }
                        }
                    }

                    // Drop the frame unparsed and report just its type.
                    let _ = self.buffer.split_to(msg_len + 1);
                    return Ok(msg_type);
                }
            }

            let n = self.read_with_timeout().await?;
            if n == 0 {
                return self.connection_desync("Connection closed".to_string());
            }
        }
    }
783
    /// Fast path that returns the next message's type byte and, for
    /// `DataRow` ('D') frames, the parsed columns as owned `Vec`s (NULL
    /// columns are `None`).
    ///
    /// 'E' (ErrorResponse) and 'A' (NotificationResponse) are decoded as in
    /// `recv_msg_type_fast`; all other message types are consumed unparsed
    /// and returned as `(type, None)`.
    #[inline]
    pub(crate) async fn recv_with_data_fast(
        &mut self,
    ) -> PgResult<(u8, Option<Vec<Option<Vec<u8>>>>)> {
        loop {
            // Header = 1 type byte + 4 length bytes.
            if self.buffer.len() >= 5 {
                let msg_len = u32::from_be_bytes([
                    self.buffer[1],
                    self.buffer[2],
                    self.buffer[3],
                    self.buffer[4],
                ]) as usize;

                // Length includes its own 4 bytes.
                if msg_len < 4 {
                    return self.protocol_desync(format!(
                        "Invalid message length: {} (minimum 4)",
                        msg_len
                    ));
                }

                if msg_len > MAX_MESSAGE_SIZE {
                    return self.protocol_desync(format!(
                        "Message too large: {} bytes (max {})",
                        msg_len, MAX_MESSAGE_SIZE
                    ));
                }

                // `>` because msg_len does not count the type byte.
                if self.buffer.len() > msg_len {
                    let msg_type = self.buffer[0];

                    if msg_type == b'E' || msg_type == b'A' {
                        let msg_bytes = self.buffer.split_to(msg_len + 1);
                        let (msg, _) = match BackendMessage::decode(&msg_bytes) {
                            Ok(decoded) => decoded,
                            Err(e) => return self.protocol_desync(e),
                        };
                        match msg {
                            BackendMessage::ErrorResponse(err) => {
                                return Err(PgError::QueryServer(err.into()));
                            }
                            BackendMessage::NotificationResponse {
                                process_id,
                                channel,
                                payload,
                            } => {
                                self.notifications
                                    .push_back(super::notification::Notification {
                                        process_id,
                                        channel,
                                        payload,
                                    });
                                continue;
                            }
                            _ => {
                                return Err(PgError::Protocol(
                                    "Unexpected fast-path message".into(),
                                ));
                            }
                        }
                    }

                    if msg_type == b'D' {
                        // Parse from a borrowed slice first; the frame is
                        // then consumed regardless of outcome so a parse
                        // error does not desync framing.
                        let parse_result = {
                            // [5..msg_len + 1] skips the type byte and the
                            // 4-byte length field.
                            let payload = &self.buffer[5..msg_len + 1];
                            parse_data_row_payload_owned(payload)
                        };

                        let _ = self.buffer.split_to(msg_len + 1);
                        match parse_result {
                            Ok(columns) => return Ok((msg_type, Some(columns))),
                            Err(err) => return Err(err),
                        }
                    }

                    // Any other message type: discard payload, report type.
                    let _ = self.buffer.split_to(msg_len + 1);
                    return Ok((msg_type, None));
                }
            }

            let n = self.read_with_timeout().await?;
            if n == 0 {
                return self.connection_desync("Connection closed".to_string());
            }
        }
    }
876
877 #[inline]
883 pub(crate) async fn recv_data_zerocopy(
884 &mut self,
885 ) -> PgResult<(u8, Option<Vec<Option<bytes::Bytes>>>)> {
886 use bytes::Buf;
887
888 loop {
889 if self.buffer.len() >= 5 {
890 let msg_len = u32::from_be_bytes([
891 self.buffer[1],
892 self.buffer[2],
893 self.buffer[3],
894 self.buffer[4],
895 ]) as usize;
896
897 if msg_len < 4 {
898 return self.protocol_desync(format!(
899 "Invalid message length: {} (minimum 4)",
900 msg_len
901 ));
902 }
903
904 if msg_len > MAX_MESSAGE_SIZE {
905 return self.protocol_desync(format!(
906 "Message too large: {} bytes (max {})",
907 msg_len, MAX_MESSAGE_SIZE
908 ));
909 }
910
911 if self.buffer.len() > msg_len {
912 let msg_type = self.buffer[0];
913
914 if msg_type == b'E' || msg_type == b'A' {
915 let msg_bytes = self.buffer.split_to(msg_len + 1);
916 let (msg, _) = match BackendMessage::decode(&msg_bytes) {
917 Ok(decoded) => decoded,
918 Err(e) => return self.protocol_desync(e),
919 };
920 match msg {
921 BackendMessage::ErrorResponse(err) => {
922 return Err(PgError::QueryServer(err.into()));
923 }
924 BackendMessage::NotificationResponse {
925 process_id,
926 channel,
927 payload,
928 } => {
929 self.notifications
930 .push_back(super::notification::Notification {
931 process_id,
932 channel,
933 payload,
934 });
935 continue;
936 }
937 _ => {
938 return Err(PgError::Protocol(
939 "Unexpected fast-path message".into(),
940 ));
941 }
942 }
943 }
944
945 if msg_type == b'D' {
947 let mut msg_bytes = self.buffer.split_to(msg_len + 1);
949
950 msg_bytes.advance(5);
952
953 if msg_bytes.len() >= 2 {
954 let raw_count = msg_bytes.get_i16();
955 if raw_count < 0 {
956 return Err(PgError::Protocol(format!(
957 "DataRow invalid column count: {}",
958 raw_count
959 )));
960 }
961 let column_count = raw_count as usize;
962 if column_count > msg_bytes.remaining() / 4 + 1 {
963 return Err(PgError::Protocol(format!(
964 "DataRow claims {} columns but payload is only {} bytes",
965 column_count,
966 msg_bytes.remaining() + 2
967 )));
968 }
969 let mut columns = Vec::with_capacity(column_count);
970
971 for _ in 0..column_count {
972 if msg_bytes.remaining() < 4 {
973 return Err(PgError::Protocol(
974 "DataRow truncated: missing column length".into(),
975 ));
976 }
977
978 let len = msg_bytes.get_i32();
979
980 if len == -1 {
981 columns.push(None);
982 } else {
983 if len < -1 {
984 return Err(PgError::Protocol(format!(
985 "DataRow invalid column length: {}",
986 len
987 )));
988 }
989 let len = len as usize;
990 if msg_bytes.remaining() < len {
991 return Err(PgError::Protocol(
992 "DataRow truncated: column data exceeds payload".into(),
993 ));
994 }
995 let col_data = msg_bytes.split_to(len).freeze();
996 columns.push(Some(col_data));
997 }
998 }
999
1000 if msg_bytes.remaining() != 0 {
1001 return Err(PgError::Protocol("DataRow has trailing bytes".into()));
1002 }
1003
1004 return Ok((msg_type, Some(columns)));
1005 }
1006 return Ok((msg_type, None));
1007 }
1008
1009 let _ = self.buffer.split_to(msg_len + 1);
1011 return Ok((msg_type, None));
1012 }
1013 }
1014
1015 let n = self.read_with_timeout().await?;
1016 if n == 0 {
1017 return self.connection_desync("Connection closed".to_string());
1018 }
1019 }
1020 }
1021
    /// Ultra-specialized fast path for `DataRow` frames known to carry
    /// exactly two non-NULL columns: both values are returned as zero-copy
    /// `Bytes`. Any other column count, a NULL column, or trailing bytes is
    /// a protocol error.
    ///
    /// 'E' (ErrorResponse) and 'A' (NotificationResponse) are decoded as in
    /// `recv_msg_type_fast`; other message types are consumed unparsed and
    /// returned as `(type, None)`.
    #[inline(always)]
    pub(crate) async fn recv_data_ultra(
        &mut self,
    ) -> PgResult<(u8, Option<(bytes::Bytes, bytes::Bytes)>)> {
        use bytes::Buf;

        loop {
            // Header = 1 type byte + 4 length bytes.
            if self.buffer.len() >= 5 {
                let msg_len = u32::from_be_bytes([
                    self.buffer[1],
                    self.buffer[2],
                    self.buffer[3],
                    self.buffer[4],
                ]) as usize;

                // Length includes its own 4 bytes.
                if msg_len < 4 {
                    return self.protocol_desync(format!(
                        "Invalid message length: {} (minimum 4)",
                        msg_len
                    ));
                }

                if msg_len > MAX_MESSAGE_SIZE {
                    return self.protocol_desync(format!(
                        "Message too large: {} bytes (max {})",
                        msg_len, MAX_MESSAGE_SIZE
                    ));
                }

                // `>` because msg_len does not count the type byte.
                if self.buffer.len() > msg_len {
                    let msg_type = self.buffer[0];

                    if msg_type == b'E' || msg_type == b'A' {
                        let msg_bytes = self.buffer.split_to(msg_len + 1);
                        let (msg, _) = match BackendMessage::decode(&msg_bytes) {
                            Ok(decoded) => decoded,
                            Err(e) => return self.protocol_desync(e),
                        };
                        match msg {
                            BackendMessage::ErrorResponse(err) => {
                                return Err(PgError::QueryServer(err.into()));
                            }
                            BackendMessage::NotificationResponse {
                                process_id,
                                channel,
                                payload,
                            } => {
                                self.notifications
                                    .push_back(super::notification::Notification {
                                        process_id,
                                        channel,
                                        payload,
                                    });
                                continue;
                            }
                            _ => {
                                return Err(PgError::Protocol(
                                    "Unexpected fast-path message".into(),
                                ));
                            }
                        }
                    }

                    if msg_type == b'D' {
                        let mut msg_bytes = self.buffer.split_to(msg_len + 1);
                        // Skip the type byte and the 4-byte length field;
                        // the frame is already consumed, so parse errors
                        // below do not desync framing.
                        msg_bytes.advance(5);
                        if msg_bytes.remaining() < 2 {
                            return Err(PgError::Protocol(
                                "DataRow ultra: too short for column count".into(),
                            ));
                        }

                        // This path is only valid for two-column rows.
                        let col_count = msg_bytes.get_i16();
                        if col_count != 2 {
                            return Err(PgError::Protocol(format!(
                                "DataRow ultra expects exactly 2 columns, got {}",
                                col_count
                            )));
                        }

                        if msg_bytes.remaining() < 4 {
                            return Err(PgError::Protocol(
                                "DataRow ultra: truncated before col0 length".into(),
                            ));
                        }
                        let len0 = msg_bytes.get_i32();
                        // Positive length: zero-copy slice; 0: empty value;
                        // -1 (NULL) is unsupported here; < -1 is malformed.
                        let col0 = if len0 > 0 {
                            let len0 = len0 as usize;
                            if msg_bytes.remaining() < len0 {
                                return Err(PgError::Protocol(
                                    "DataRow ultra: col0 data exceeds payload".into(),
                                ));
                            }
                            msg_bytes.split_to(len0).freeze()
                        } else if len0 == 0 {
                            bytes::Bytes::new()
                        } else if len0 == -1 {
                            return Err(PgError::Protocol(
                                "DataRow ultra does not support NULL columns".into(),
                            ));
                        } else {
                            return Err(PgError::Protocol(format!(
                                "DataRow ultra: invalid col0 length {}",
                                len0
                            )));
                        };

                        if msg_bytes.remaining() < 4 {
                            return Err(PgError::Protocol(
                                "DataRow ultra: truncated before col1 length".into(),
                            ));
                        }
                        let len1 = msg_bytes.get_i32();
                        // Same length handling as col0.
                        let col1 = if len1 > 0 {
                            let len1 = len1 as usize;
                            if msg_bytes.remaining() < len1 {
                                return Err(PgError::Protocol(
                                    "DataRow ultra: col1 data exceeds payload".into(),
                                ));
                            }
                            msg_bytes.split_to(len1).freeze()
                        } else if len1 == 0 {
                            bytes::Bytes::new()
                        } else if len1 == -1 {
                            return Err(PgError::Protocol(
                                "DataRow ultra does not support NULL columns".into(),
                            ));
                        } else {
                            return Err(PgError::Protocol(format!(
                                "DataRow ultra: invalid col1 length {}",
                                len1
                            )));
                        };

                        // A well-formed two-column row is consumed exactly.
                        if msg_bytes.remaining() != 0 {
                            return Err(PgError::Protocol(
                                "DataRow ultra: trailing bytes after expected columns".into(),
                            ));
                        }

                        return Ok((msg_type, Some((col0, col1))));
                    }

                    // Any other message type: discard payload, report type.
                    let _ = self.buffer.split_to(msg_len + 1);
                    return Ok((msg_type, None));
                }
            }

            let n = self.read_with_timeout().await?;
            if n == 0 {
                return self.connection_desync("Connection closed".to_string());
            }
        }
    }
1184}