1use std::collections::VecDeque;
9use std::sync::Arc;
10
11use bytes::BytesMut;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::sync::{mpsc, oneshot, Mutex};
14
15use crate::connection::WireConn;
16use crate::error::PgWireError;
17use crate::protocol::backend;
18use crate::protocol::frontend;
19use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
20
/// A unit of work handed to the writer task: pre-encoded frontend messages
/// plus instructions for how the reader task must collect the reply.
pub(crate) struct PipelineRequest {
    /// Fully encoded frontend protocol messages, written to the socket as-is.
    pub(crate) messages: BytesMut,
    /// Strategy the reader task uses to consume this request's response.
    pub(crate) collector: ResponseCollector,
    /// Channel on which the collected response (or error) is delivered.
    pub(crate) response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
32
/// How the reader task should consume the backend messages produced by one
/// [`PipelineRequest`].
#[allow(dead_code)]
#[non_exhaustive]
pub enum ResponseCollector {
    /// Buffer row description, data rows and command tag until ReadyForQuery.
    Rows,
    /// Discard everything up to ReadyForQuery (setup/cleanup statements).
    Drain,
    /// Deliver rows incrementally over channels instead of buffering.
    Stream {
        /// Receives the row description (or first error) once it is known.
        header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
        /// Receives each data row as it arrives.
        row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    },
    /// Collect the server's reply to a COPY ... FROM STDIN exchange.
    CopyIn {
        /// Unused by the reader (matched with `..`); callers pass an empty Vec.
        data: Vec<u8>,
    },
    /// Collect CopyData chunks from a COPY ... TO STDOUT exchange.
    CopyOut,
}
56
/// Final result of one pipeline request.
#[non_exhaustive]
pub enum PipelineResponse {
    /// A query result buffered in full.
    Rows {
        /// Column metadata from RowDescription (empty when none was sent).
        fields: Vec<crate::protocol::types::FieldDescription>,
        /// Raw data rows in arrival order.
        rows: Vec<RawRow>,
        /// CommandComplete tag, e.g. "SELECT 3" (empty if never received).
        command_tag: String,
    },
    /// The request completed without producing a buffered row result.
    Done,
}
73
/// Row metadata delivered before the first streamed row.
#[derive(Debug, Clone)]
pub struct StreamHeader {
    /// Column descriptions from the backend's RowDescription message.
    pub fields: Vec<crate::protocol::types::FieldDescription>,
}

/// A single row delivered over a streaming response channel.
pub type StreamedRow = RawRow;
83
/// A pipelined, asynchronous PostgreSQL connection.
///
/// Requests are pre-encoded and queued to a dedicated writer task; a reader
/// task pairs backend responses with queued requests in strict FIFO order.
pub struct AsyncConn {
    /// Channel feeding the writer task.
    request_tx: mpsc::Sender<PipelineRequest>,
    /// Prepared-statement cache: SQL text -> (statement name, recency stamp).
    stmt_cache: std::sync::Mutex<std::collections::HashMap<String, (String, u64)>>,
    /// Monotonic counter used for statement names and cache recency stamps.
    stmt_counter: std::sync::atomic::AtomicU64,
    /// Cleared when either background task exits.
    alive: Arc<std::sync::atomic::AtomicBool>,
    /// Backend process id (used for cancel requests).
    backend_pid: i32,
    /// Backend secret key (used for cancel requests).
    backend_secret: i32,
    /// Peer address, or "" when it could not be determined.
    addr: String,
    /// Kept so the notification channel stays open even if unused.
    #[allow(dead_code)]
    notification_tx: mpsc::Sender<crate::protocol::types::BackendMsg>,
    /// Receiver for async notifications; can be taken at most once.
    notification_rx: std::sync::Mutex<Option<mpsc::Receiver<crate::protocol::types::BackendMsg>>>,
    /// Set when a ReadyForQuery reports a non-idle transaction status, or via
    /// `mark_state_mutated()`.
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    /// Sticky flag set via `mark_broken()`; semantics owned by the caller.
    broken: Arc<std::sync::atomic::AtomicBool>,
    /// Count of notifications dropped because their channel was full.
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
}
124
/// Concise `Debug` reporting only addr, backend pid and liveness.
impl std::fmt::Debug for AsyncConn {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncConn")
            .field("addr", &self.addr)
            .field("backend_pid", &self.backend_pid)
            .field("alive", &self.is_alive())
            .finish()
    }
}
134
impl AsyncConn {
    /// Whether both background tasks are still running.
    pub fn is_alive(&self) -> bool {
        self.alive.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Backend process id reported at connection startup.
    pub fn backend_pid(&self) -> i32 {
        self.backend_pid
    }

    /// Peer address string ("" if it could not be resolved).
    pub fn addr(&self) -> &str {
        &self.addr
    }

    /// Build a token that can cancel this connection's in-flight query from
    /// a separate connection.
    pub fn cancel_token(&self) -> crate::cancel::CancelToken {
        crate::cancel::CancelToken::new(self.addr.clone(), self.backend_pid, self.backend_secret)
    }

    /// Flag that backend session state may have been mutated.
    pub fn mark_state_mutated(&self) {
        self.state_mutated
            .store(true, std::sync::atomic::Ordering::Release);
    }

    /// Read and clear the state-mutated flag in one atomic step.
    pub fn take_state_mutated(&self) -> bool {
        self.state_mutated
            .swap(false, std::sync::atomic::Ordering::AcqRel)
    }

    /// Read the state-mutated flag without clearing it.
    pub fn is_state_mutated(&self) -> bool {
        self.state_mutated
            .load(std::sync::atomic::Ordering::Acquire)
    }

    /// Mark the connection as unusable; the flag is sticky.
    pub fn mark_broken(&self) {
        self.broken
            .store(true, std::sync::atomic::Ordering::Release);
    }

    /// Whether `mark_broken` has been called.
    pub fn is_broken(&self) -> bool {
        self.broken.load(std::sync::atomic::Ordering::Acquire)
    }

    /// Test-only hook: force `is_alive()` to report false.
    #[doc(hidden)]
    pub fn __force_mark_dead_for_test(&self) {
        self.alive
            .store(false, std::sync::atomic::Ordering::Release);
    }

    /// Best-effort, non-blocking ROLLBACK enqueue; returns false when the
    /// connection is dead or the writer channel is full/closed.
    pub fn enqueue_rollback(&self) -> bool {
        if !self.is_alive() {
            return false;
        }
        try_enqueue_rollback(&self.request_tx)
    }
}
230
231fn try_enqueue_rollback(request_tx: &mpsc::Sender<PipelineRequest>) -> bool {
236 let mut buf = BytesMut::with_capacity(16);
237 frontend::encode_message(&FrontendMsg::Query(b"ROLLBACK"), &mut buf);
238 let (tx, _rx) = oneshot::channel();
239 request_tx
240 .try_send(PipelineRequest {
241 messages: buf,
242 collector: ResponseCollector::Drain,
243 response_tx: tx,
244 })
245 .is_ok()
246}
247
/// A request whose bytes have already been written to the socket and whose
/// response is still owed by the reader task.
struct PendingResponse {
    // How the reader must consume the backend's reply for this request.
    collector: ResponseCollector,
    // Where the collected result (or error) is delivered.
    response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
252
253impl AsyncConn {
    /// Wrap an established [`WireConn`], splitting its stream and spawning the
    /// writer and reader background tasks that drive the pipeline.
    ///
    /// Responses are matched to requests strictly in FIFO order via the shared
    /// `pending` queue; either task exiting clears the `alive` flag.
    pub fn new(conn: WireConn) -> Self {
        let backend_pid = conn.pid;
        let backend_secret = conn.secret;
        let addr = conn
            .stream
            .peer_addr()
            .map(|a| a.to_string())
            .unwrap_or_default();

        // Notifications get a large buffer; on overflow they are dropped and
        // counted rather than blocking the reader.
        let (notification_tx, notification_rx) = mpsc::channel(4096);
        let (request_tx, request_rx) = mpsc::channel::<PipelineRequest>(256);
        // Requests whose bytes were written but whose responses are still owed;
        // the writer pushes, the reader pops.
        let pending: Arc<Mutex<VecDeque<PendingResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
        let pending_notify = Arc::new(tokio::sync::Notify::new());
        let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
        let state_mutated = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let broken = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let dropped_notifications = Arc::new(std::sync::atomic::AtomicU64::new(0));

        let (stream_read, stream_write) = tokio::io::split(conn.into_stream());

        // Writer task: serializes queued requests onto the socket.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive = Arc::clone(&alive);
            tokio::spawn(async move {
                writer_task(request_rx, stream_write, pending, pending_notify).await;
                alive.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired writer task exited");
            });
        }

        // Reader task: parses backend messages and resolves pending requests.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive_clone = Arc::clone(&alive);
            let state_mutated = Arc::clone(&state_mutated);
            let ntf_tx = notification_tx.clone();
            let dropped = Arc::clone(&dropped_notifications);
            tokio::spawn(async move {
                reader_task(
                    stream_read,
                    pending,
                    pending_notify,
                    ntf_tx,
                    state_mutated,
                    dropped,
                )
                .await;
                alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired reader task exited");
            });
        }

        Self {
            request_tx,
            stmt_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            stmt_counter: std::sync::atomic::AtomicU64::new(0),
            alive,
            backend_pid,
            backend_secret,
            addr,
            notification_tx,
            notification_rx: std::sync::Mutex::new(Some(notification_rx)),
            state_mutated,
            broken,
            dropped_notifications,
        }
    }
327
    /// Number of notifications dropped because the notification channel was
    /// full at delivery time.
    pub fn dropped_notifications(&self) -> u64 {
        self.dropped_notifications
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Take ownership of the notification receiver; returns `None` on every
    /// call after the first (or if the guarding mutex is poisoned — note the
    /// `.ok()` silently maps poisoning to `None`).
    pub fn take_notification_receiver(
        &self,
    ) -> Option<mpsc::Receiver<crate::protocol::types::BackendMsg>> {
        self.notification_rx
            .lock()
            .ok()
            .and_then(|mut guard| guard.take())
    }
350
351 pub fn lookup_or_alloc(&self, sql: &str) -> (Vec<u8>, bool) {
356 let mut cache = match self.stmt_cache.lock() {
357 Ok(c) => c,
358 Err(poisoned) => poisoned.into_inner(),
359 };
360 if let Some((name, _)) = cache.get(sql) {
361 return (name.as_bytes().to_vec(), false);
362 }
363 if cache.len() >= 256 {
366 if let Some((oldest_key, oldest_name)) = cache
367 .iter()
368 .min_by_key(|(_, (_, counter))| *counter)
369 .map(|(k, (name, _))| (k.clone(), name.clone()))
370 {
371 cache.remove(&oldest_key);
372 let mut close_buf = BytesMut::with_capacity(32);
375 frontend::encode_message(
376 &FrontendMsg::Close {
377 kind: b'S',
378 name: oldest_name.as_bytes(),
379 },
380 &mut close_buf,
381 );
382 frontend::encode_message(&FrontendMsg::Sync, &mut close_buf);
383 let (tx, _rx) = oneshot::channel();
384 let _ = self.request_tx.try_send(PipelineRequest {
385 messages: close_buf,
386 collector: ResponseCollector::Drain,
387 response_tx: tx,
388 });
389 }
390 }
391 let n = self
392 .stmt_counter
393 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
394 let name = format!("s{n}");
395 cache.insert(sql.to_string(), (name.clone(), n));
396 (name.into_bytes(), true)
397 }
398
399 pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, PgWireError> {
405 use crate::protocol::types::FrontendMsg;
406 const CHUNK_SIZE: usize = 1024 * 1024; let mut buf = BytesMut::with_capacity(copy_sql.len() + data.len().min(CHUNK_SIZE) + 64);
410 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
411
412 for chunk in data.chunks(CHUNK_SIZE) {
414 frontend::encode_message(&FrontendMsg::CopyData(chunk), &mut buf);
415 }
416 frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
418
419 let resp = self
420 .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
421 .await?;
422 match resp {
423 PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
424 PipelineResponse::Done => Ok(0),
425 }
426 }
427
    /// Run a `COPY ... FROM STDIN` fed from an async reader.
    ///
    /// NOTE(review): despite the name, the entire reader is consumed into the
    /// request buffer *before* anything is submitted, so memory use is
    /// proportional to the stream size — confirm this is acceptable for large
    /// copies.
    ///
    /// Returns the row count parsed from the `COPY <n>` tag (0 when absent).
    pub async fn copy_in_stream<R: tokio::io::AsyncRead + Unpin>(
        &self,
        copy_sql: &str,
        mut reader: R,
    ) -> Result<u64, PgWireError> {
        use tokio::io::AsyncReadExt;
        const CHUNK_SIZE: usize = 1024 * 1024; let mut buf = BytesMut::with_capacity(copy_sql.len() + 16);
        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);

        // Read to EOF, encoding each chunk as a CopyData message.
        let mut chunk = vec![0u8; CHUNK_SIZE];
        loop {
            let n = reader.read(&mut chunk).await?;
            if n == 0 {
                break;
            }
            frontend::encode_message(&FrontendMsg::CopyData(&chunk[..n]), &mut buf);
        }
        frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);

        let resp = self
            .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
            .await?;
        match resp {
            PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
            PipelineResponse::Done => Ok(0),
        }
    }
471
    /// Run a `COPY ... TO STDOUT` and return the concatenated copy data.
    ///
    /// The reader stores each CopyData chunk as a pseudo-row; this method
    /// flattens them back into one contiguous byte buffer.
    pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, PgWireError> {
        use crate::protocol::types::FrontendMsg;
        let mut buf = BytesMut::new();
        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);

        let resp = self.submit(buf, ResponseCollector::CopyOut).await?;
        match resp {
            PipelineResponse::Rows { rows, .. } => {
                // Each "row" is one CopyData chunk; concatenate in order.
                let mut result = Vec::new();
                for row in rows {
                    for data in row.iter().flatten() {
                        result.extend_from_slice(data);
                    }
                }
                Ok(result)
            }
            PipelineResponse::Done => Ok(Vec::new()),
        }
    }
494
495 pub fn invalidate_statement(&self, sql: &str) {
498 let mut cache = match self.stmt_cache.lock() {
499 Ok(c) => c,
500 Err(poisoned) => poisoned.into_inner(),
501 };
502 cache.remove(sql);
503 }
504
505 pub fn clear_statement_cache(&self) {
508 let mut cache = match self.stmt_cache.lock() {
509 Ok(c) => c,
510 Err(poisoned) => poisoned.into_inner(),
511 };
512 cache.clear();
513 }
514
    /// Convenience wrapper over [`Self::pipeline_transaction`] that resolves
    /// the prepared-statement name for `query_sql` from the cache first.
    pub async fn exec_transaction(
        &self,
        setup_sql: &str,
        query_sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
    ) -> Result<Vec<RawRow>, PgWireError> {
        let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql);
        self.pipeline_transaction(
            setup_sql,
            query_sql,
            params,
            param_oids,
            &stmt_name,
            needs_parse,
        )
        .await
    }
534
    /// Execute `sql` via the cached prepared statement, retrying exactly once
    /// with a fresh `Parse` if the server reports the cached statement is
    /// stale (see [`is_stale_statement_error`]).
    ///
    /// The retry only fires when the name came from the cache
    /// (`!needs_parse`): a failure on a freshly parsed statement is genuine.
    pub async fn exec_query(
        &self,
        sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
    ) -> Result<Vec<RawRow>, PgWireError> {
        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
        match self
            .query(sql, params, param_oids, &stmt_name, needs_parse)
            .await
        {
            Ok(rows) => Ok(rows),
            Err(PgWireError::Pg(ref pg_err))
                if !needs_parse && is_stale_statement_error(pg_err) =>
            {
                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
                self.invalidate_statement(sql);
                let (stmt_name, _) = self.lookup_or_alloc(sql);
                self.query(sql, params, param_oids, &stmt_name, true).await
            }
            Err(e) => Err(e),
        }
    }
561
    /// Upper bound on waiting for one pipelined response; elapsing it is
    /// treated the same as a closed connection.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);

    /// Queue pre-encoded `messages` for the writer task and await the
    /// response collected per `collector`.
    ///
    /// # Errors
    /// `ConnectionClosed` when the writer channel is gone, the response
    /// channel is dropped, or [`Self::REQUEST_TIMEOUT`] elapses.
    pub async fn submit(
        &self,
        messages: BytesMut,
        collector: ResponseCollector,
    ) -> Result<PipelineResponse, PgWireError> {
        let (response_tx, response_rx) = oneshot::channel();
        let req = PipelineRequest {
            messages,
            collector,
            response_tx,
        };
        self.request_tx
            .send(req)
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;
        match tokio::time::timeout(Self::REQUEST_TIMEOUT, response_rx).await {
            Ok(Ok(result)) => result,
            Ok(Err(_)) => Err(PgWireError::ConnectionClosed),
            Err(_elapsed) => {
                tracing::error!(
                    "request timed out after {:?} — reader/writer task may be dead",
                    Self::REQUEST_TIMEOUT
                );
                Err(PgWireError::ConnectionClosed)
            }
        }
    }
596
    /// Queue several requests back-to-back, then await their responses in
    /// order; each response gets its own timeout window.
    ///
    /// The outer `Err` only occurs when a request could not be queued at all;
    /// per-request failures appear inside the returned `Vec`.
    pub async fn submit_batch(
        &self,
        items: Vec<(BytesMut, ResponseCollector)>,
    ) -> Result<Vec<Result<PipelineResponse, PgWireError>>, PgWireError> {
        let mut receivers = Vec::with_capacity(items.len());
        // Queue everything first so the writer can batch the socket writes.
        for (messages, collector) in items {
            let (response_tx, response_rx) = oneshot::channel();
            self.request_tx
                .send(PipelineRequest {
                    messages,
                    collector,
                    response_tx,
                })
                .await
                .map_err(|_| PgWireError::ConnectionClosed)?;
            receivers.push(response_rx);
        }
        let mut results = Vec::with_capacity(receivers.len());
        for rx in receivers {
            match tokio::time::timeout(Self::REQUEST_TIMEOUT, rx).await {
                Ok(Ok(r)) => results.push(r),
                Ok(Err(_)) => results.push(Err(PgWireError::ConnectionClosed)),
                Err(_) => {
                    tracing::error!(
                        "submit_batch request timed out after {:?}",
                        Self::REQUEST_TIMEOUT
                    );
                    results.push(Err(PgWireError::ConnectionClosed));
                }
            }
        }
        Ok(results)
    }
639
    /// Send a graceful `Terminate` to the backend.
    ///
    /// An already-dead connection, a closed pipeline, or a broken pipe are all
    /// treated as success — the goal is merely "the connection is gone".
    pub async fn close(&self) -> Result<(), PgWireError> {
        if !self.is_alive() {
            return Ok(());
        }
        let mut buf = BytesMut::with_capacity(5);
        frontend::encode_message(&FrontendMsg::Terminate, &mut buf);
        match self.submit(buf, ResponseCollector::Drain).await {
            Ok(_) | Err(PgWireError::ConnectionClosed) => Ok(()),
            Err(PgWireError::Io(e)) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
            Err(e) => Err(e),
        }
    }
660
    /// Queue pre-encoded `messages` and stream the resulting rows instead of
    /// buffering them.
    ///
    /// Returns the row-description header plus a receiver that yields each
    /// row (buffered up to `row_buffer` rows). The pipeline's final `Done`
    /// oneshot is deliberately dropped (`_response_rx`): completion is
    /// signalled by the row channel closing.
    pub async fn submit_stream(
        &self,
        messages: BytesMut,
        row_buffer: usize,
    ) -> Result<
        (
            StreamHeader,
            mpsc::Receiver<Result<StreamedRow, PgWireError>>,
        ),
        PgWireError,
    > {
        let (header_tx, header_rx) = oneshot::channel();
        let (row_tx, row_rx) = mpsc::channel(row_buffer);
        let (response_tx, _response_rx) = oneshot::channel();
        let req = PipelineRequest {
            messages,
            collector: ResponseCollector::Stream { header_tx, row_tx },
            response_tx,
        };
        self.request_tx
            .send(req)
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;
        // First `?` is the dropped-channel case, second is a backend error.
        let header = header_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;
        Ok((header, row_rx))
    }
691
    /// Run `setup_sql` (simple protocol), then an extended-protocol
    /// bind/execute of `query_sql`, then `COMMIT` — all three queued
    /// back-to-back before any response is awaited, so the server sees one
    /// pipeline.
    ///
    /// NOTE(review): `setup_sql` is assumed to open the transaction (e.g.
    /// BEGIN plus session setup) since a bare COMMIT is sent afterwards —
    /// confirm with callers. If the setup stage fails, the data/commit
    /// requests are already queued; their responses are discarded via the
    /// dropped receivers and the connection drains them normally.
    pub async fn pipeline_transaction(
        &self,
        setup_sql: &str,
        query_sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        stmt_name: &[u8],
        needs_parse: bool,
    ) -> Result<Vec<RawRow>, PgWireError> {
        let mut buf = BytesMut::with_capacity(1024);

        frontend::encode_message(&FrontendMsg::Query(setup_sql.as_bytes()), &mut buf);

        // split() hands off the setup bytes and lets us keep encoding into buf.
        let setup_msgs = buf.split();

        // max(1) keeps the vec non-empty; it is sliced back to params.len().
        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
        let result_fmts = [FormatCode::Text];

        if needs_parse {
            frontend::encode_message(
                &FrontendMsg::Parse {
                    name: stmt_name,
                    sql: query_sql.as_bytes(),
                    param_oids,
                },
                &mut buf,
            );
        }

        frontend::encode_message(
            &FrontendMsg::Bind {
                portal: b"",
                statement: stmt_name,
                param_formats: &text_fmts[..params.len()],
                params,
                result_formats: &result_fmts,
            },
            &mut buf,
        );

        frontend::encode_message(
            &FrontendMsg::Execute {
                portal: b"",
                max_rows: 0,
            },
            &mut buf,
        );

        frontend::encode_message(&FrontendMsg::Sync, &mut buf);

        let data_msgs = buf.split();

        let mut commit_buf = BytesMut::with_capacity(32);
        frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut commit_buf);

        let (setup_tx, setup_rx) = oneshot::channel();
        let (data_tx, data_rx) = oneshot::channel();
        let (commit_tx, commit_rx) = oneshot::channel();

        // Queue all three stages before awaiting anything.
        self.request_tx
            .send(PipelineRequest {
                messages: setup_msgs,
                collector: ResponseCollector::Drain,
                response_tx: setup_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: data_msgs,
                collector: ResponseCollector::Rows,
                response_tx: data_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: commit_buf,
                collector: ResponseCollector::Drain,
                response_tx: commit_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        // Await in pipeline order; any stage error aborts early.
        setup_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        let data_resp = data_rx.await.map_err(|_| PgWireError::ConnectionClosed)??;

        commit_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        match data_resp {
            PipelineResponse::Rows { rows, .. } => Ok(rows),
            PipelineResponse::Done => Ok(Vec::new()),
        }
    }
806
    /// Execute a prepared statement with all-text parameter and result
    /// formats; thin wrapper over [`Self::query_with_formats`] with empty
    /// format slices (which default to text there).
    pub async fn query(
        &self,
        sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        stmt_name: &[u8],
        needs_parse: bool,
    ) -> Result<Vec<RawRow>, PgWireError> {
        self.query_with_formats(sql, params, param_oids, &[], &[], stmt_name, needs_parse)
            .await
    }
819
    /// Execute a prepared statement with explicit parameter/result format
    /// codes, submitting (optional Parse,) Bind, Execute, Sync as one request.
    ///
    /// Empty `param_formats` / `result_formats` default to all-text.
    /// `max_rows: 0` asks for all rows in one Execute.
    #[allow(clippy::too_many_arguments)]
    pub async fn query_with_formats(
        &self,
        sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        param_formats: &[FormatCode],
        result_formats: &[FormatCode],
        stmt_name: &[u8],
        needs_parse: bool,
    ) -> Result<Vec<RawRow>, PgWireError> {
        let mut buf = BytesMut::with_capacity(512);

        // Owned fallback vec lives here so the slice can borrow from it.
        let text_param_fmts: Vec<FormatCode>;
        let param_fmts_slice: &[FormatCode] = if param_formats.is_empty() {
            text_param_fmts = vec![FormatCode::Text; params.len().max(1)];
            &text_param_fmts[..params.len()]
        } else {
            param_formats
        };
        let default_result_fmts = [FormatCode::Text];
        let result_fmts_slice: &[FormatCode] = if result_formats.is_empty() {
            &default_result_fmts
        } else {
            result_formats
        };

        if needs_parse {
            frontend::encode_message(
                &FrontendMsg::Parse {
                    name: stmt_name,
                    sql: sql.as_bytes(),
                    param_oids,
                },
                &mut buf,
            );
        }

        frontend::encode_message(
            &FrontendMsg::Bind {
                portal: b"",
                statement: stmt_name,
                param_formats: param_fmts_slice,
                params,
                result_formats: result_fmts_slice,
            },
            &mut buf,
        );

        frontend::encode_message(
            &FrontendMsg::Execute {
                portal: b"",
                max_rows: 0,
            },
            &mut buf,
        );

        frontend::encode_message(&FrontendMsg::Sync, &mut buf);

        let resp = self.submit(buf, ResponseCollector::Rows).await?;
        match resp {
            PipelineResponse::Rows { rows, .. } => Ok(rows),
            PipelineResponse::Done => Ok(Vec::new()),
        }
    }
896
    /// Format-aware variant of [`Self::exec_query`]: uses the statement cache
    /// and retries exactly once with a fresh `Parse` when the server reports
    /// the cached statement is stale.
    pub async fn exec_query_with_formats(
        &self,
        sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        param_formats: &[FormatCode],
        result_formats: &[FormatCode],
    ) -> Result<Vec<RawRow>, PgWireError> {
        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
        match self
            .query_with_formats(
                sql,
                params,
                param_oids,
                param_formats,
                result_formats,
                &stmt_name,
                needs_parse,
            )
            .await
        {
            Ok(rows) => Ok(rows),
            // Retry only applies to cache hits; a fresh Parse failing is real.
            Err(PgWireError::Pg(ref pg_err))
                if !needs_parse && is_stale_statement_error(pg_err) =>
            {
                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
                self.invalidate_statement(sql);
                let (stmt_name, _) = self.lookup_or_alloc(sql);
                self.query_with_formats(
                    sql,
                    params,
                    param_oids,
                    param_formats,
                    result_formats,
                    &stmt_name,
                    true,
                )
                .await
            }
            Err(e) => Err(e),
        }
    }
941}
942
/// Background task owning the write half of the socket.
///
/// Waits for one request, then opportunistically batches every request
/// already queued (`try_recv`) into a single write+flush. On success the
/// batch's collectors are appended to the shared pending queue and the
/// reader is woken; on a write error the current batch and everything
/// previously pending are failed and the task exits.
async fn writer_task(
    mut rx: mpsc::Receiver<PipelineRequest>,
    mut stream: tokio::io::WriteHalf<crate::tls::MaybeTlsStream>,
    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
    pending_notify: Arc<tokio::sync::Notify>,
) {
    let mut write_buf = BytesMut::with_capacity(8192);

    loop {
        // Block for the first request; channel closure means AsyncConn is gone.
        let first = match rx.recv().await {
            Some(req) => req,
            None => {
                drain_pending_on_exit(&pending).await;
                return;
            }
        };

        write_buf.clear();
        write_buf.extend_from_slice(&first.messages);

        let mut batch: Vec<PendingResponse> = vec![PendingResponse {
            collector: first.collector,
            response_tx: first.response_tx,
        }];

        // Coalesce whatever else is already queued into the same write.
        while let Ok(req) = rx.try_recv() {
            write_buf.extend_from_slice(&req.messages);
            batch.push(PendingResponse {
                collector: req.collector,
                response_tx: req.response_tx,
            });
        }

        let write_result = stream.write_all(&write_buf).await;
        let write_err = match write_result {
            Ok(_) => stream.flush().await.err(),
            Err(e) => Some(e),
        };

        if let Some(e) = write_err {
            tracing::error!("Writer error: {e}");
            // Fail this batch explicitly; the bytes may be partially written.
            let msg = e.to_string();
            for p in batch {
                let _ = p.response_tx.send(Err(PgWireError::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    msg.clone(),
                ))));
            }
            drain_pending_on_exit(&pending).await;
            return;
        }

        // Only after a successful write do the collectors become "pending".
        {
            let mut pq = pending.lock().await;
            for p in batch {
                pq.push_back(p);
            }
        }
        pending_notify.notify_one();
    }
}
1018
1019async fn drain_pending_on_exit(pending: &Arc<Mutex<VecDeque<PendingResponse>>>) {
1022 let mut pq = pending.lock().await;
1023 while let Some(pr) = pq.pop_front() {
1024 let _ = pr.response_tx.send(Err(PgWireError::ConnectionClosed));
1025 }
1026}
1027
/// Background task owning the read half of the socket.
///
/// Pops the oldest pending request and runs the collector matching its
/// `ResponseCollector`, delivering the result over the request's oneshot.
/// Relies on `Notify` storing a permit when `notify_one` fires with no
/// waiter, so a push between the queue check and `notified().await` is not
/// lost.
///
/// NOTE(review): this loop has no exit path — when the connection dies and
/// the pending queue empties, the task parks on `notified()` forever,
/// keeping the read half and Arcs alive. Consider an explicit shutdown
/// signal.
async fn reader_task(
    mut stream: tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
    pending_notify: Arc<tokio::sync::Notify>,
    notification_tx: mpsc::Sender<BackendMsg>,
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
) {
    let mut recv_buf = BytesMut::with_capacity(32 * 1024);

    loop {
        // Wait until a pending request exists; pop the oldest.
        let pr = loop {
            {
                let mut pq = pending.lock().await;
                if let Some(pr) = pq.pop_front() {
                    break pr;
                }
            }
            pending_notify.notified().await;
        };

        let result = match pr.collector {
            ResponseCollector::Rows => {
                collect_rows(
                    &mut stream,
                    &mut recv_buf,
                    &notification_tx,
                    &state_mutated,
                    &dropped_notifications,
                )
                .await
            }
            ResponseCollector::Drain => {
                drain_until_ready(&mut stream, &mut recv_buf, Some(&state_mutated))
                    .await
                    .map(|_| PipelineResponse::Done)
            }
            ResponseCollector::Stream { header_tx, row_tx } => {
                // Streaming reports errors through its own channels, so the
                // pipeline-level result is always Done.
                stream_rows(
                    &mut stream,
                    &mut recv_buf,
                    header_tx,
                    row_tx,
                    &notification_tx,
                    &state_mutated,
                    &dropped_notifications,
                )
                .await;
                Ok(PipelineResponse::Done)
            }
            ResponseCollector::CopyIn { .. } => {
                collect_copy_in_response(&mut stream, &mut recv_buf, &state_mutated).await
            }
            ResponseCollector::CopyOut => {
                collect_copy_out(&mut stream, &mut recv_buf, &state_mutated).await
            }
        };

        // Receiver may have timed out and dropped; ignore send failure.
        let _ = pr.response_tx.send(result);
    }
}
1097
/// Read one complete backend message, pulling more bytes from the socket as
/// needed.
///
/// On EOF a final parse of the buffered bytes is attempted before reporting
/// `ConnectionClosed`, so a message fully received just before shutdown is
/// not lost.
async fn read_msg(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
) -> Result<BackendMsg, PgWireError> {
    loop {
        if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
            return Ok(msg);
        }
        let n = stream.read_buf(buf).await?;
        if n == 0 {
            if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
                return Ok(msg);
            }
            return Err(PgWireError::ConnectionClosed);
        }
    }
}
1118
/// Record a session-state hint from a ReadyForQuery transaction status byte.
///
/// `b'I'` means idle (no open transaction) and leaves the flag untouched;
/// any other status sets `state_mutated` with Release ordering.
fn note_rfq_status(status: u8, state_mutated: &std::sync::atomic::AtomicBool) {
    if status == b'I' {
        return;
    }
    state_mutated.store(true, std::sync::atomic::Ordering::Release);
}
1127
/// Buffer one query's full response (fields, rows, command tag) until
/// ReadyForQuery.
///
/// On ErrorResponse the stream is drained to ReadyForQuery (keeping the
/// connection usable) before the error is returned. Notifications seen
/// mid-response are forwarded best-effort and counted when dropped.
async fn collect_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) -> Result<PipelineResponse, PgWireError> {
    let mut rows = Vec::new();
    let mut fields = Vec::new();
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            BackendMsg::DataRow(row) => rows.push(row),
            BackendMsg::RowDescription { fields: f } => fields = f,
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            // ReadyForQuery terminates the response.
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields,
                    rows,
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            msg @ BackendMsg::NotificationResponse { .. } => {
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // Protocol bookkeeping messages carry no data we need here.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1173
/// Discard backend messages until ReadyForQuery, logging (but not
/// propagating) any ErrorResponse seen along the way.
///
/// Used both for Drain-collector requests and to resynchronize the stream
/// after an error mid-response.
async fn drain_until_ready(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    state_mutated: Option<&std::sync::atomic::AtomicBool>,
) -> Result<(), PgWireError> {
    loop {
        let msg = read_msg(stream, buf).await?;
        if let BackendMsg::ReadyForQuery { status } = msg {
            if let Some(sm) = state_mutated {
                note_rfq_status(status, sm);
            }
            return Ok(());
        }
        if let BackendMsg::ErrorResponse { ref fields } = msg {
            tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
        }
    }
}
1192
/// Stream one query's rows to `row_tx`, sending the header lazily.
///
/// The header goes out on the first data row, or (for row-less results) at
/// CommandComplete/ReadyForQuery. Errors go to whichever channel is still
/// live: the header oneshot if it has not been consumed, else the row
/// channel. If the row receiver is dropped, the remaining response is
/// drained so the connection stays usable.
async fn stream_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
    row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) {
    // Some(_) until the header has been delivered once.
    let mut header_tx = Some(header_tx);
    let mut fields = Vec::new();
    loop {
        let msg = match read_msg(stream, buf).await {
            Ok(msg) => msg,
            Err(e) => {
                // Fatal read error: report on the still-live channel and stop.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(e));
                } else {
                    let _ = row_tx.send(Err(e)).await;
                }
                return;
            }
        };
        match msg {
            BackendMsg::RowDescription { fields: f } => {
                fields = f;
            }
            BackendMsg::DataRow(row) => {
                // First row: release the header before the row itself.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: fields.clone(),
                    }));
                }
                if row_tx.send(Ok(row)).await.is_err() {
                    // Receiver gone — drain to keep the connection in sync.
                    let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                    return;
                }
            }
            BackendMsg::CommandComplete { .. } => {
                // Zero-row result: the header is still owed.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
            }
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
                return;
            }
            BackendMsg::ErrorResponse { fields: err } => {
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(PgWireError::Pg(err)));
                } else {
                    let _ = row_tx.send(Err(PgWireError::Pg(err))).await;
                }
                let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                return;
            }
            msg @ BackendMsg::NotificationResponse { .. } => {
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // Protocol bookkeeping messages carry nothing to forward.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::PortalSuspended
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1274
/// Consume the backend's reply to a COPY ... FROM STDIN exchange, returning
/// the command tag (e.g. "COPY 42") wrapped in an empty Rows response.
///
/// On ErrorResponse the stream is drained to ReadyForQuery before the error
/// is returned.
async fn collect_copy_in_response(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    state_mutated: &std::sync::atomic::AtomicBool,
) -> Result<PipelineResponse, PgWireError> {
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            // The CopyData/CopyDone we sent need no per-message ack.
            BackendMsg::CopyInResponse { .. } => {}
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields: Vec::new(),
                    rows: Vec::new(),
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            _ => {}
        }
    }
}
1304
/// Consume a COPY ... TO STDOUT response, storing each CopyData chunk as a
/// pseudo-row in the `rows` field (callers concatenate them; see `copy_out`).
///
/// On ErrorResponse the stream is drained to ReadyForQuery before the error
/// is returned.
async fn collect_copy_out(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    state_mutated: &std::sync::atomic::AtomicBool,
) -> Result<PipelineResponse, PgWireError> {
    let mut data_chunks: Vec<RawRow> = Vec::new();
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            BackendMsg::CopyOutResponse { .. } => {}
            BackendMsg::CopyData { data } => {
                // Wrap the raw chunk so it travels in the Rows response shape.
                let body = bytes::Bytes::from(data);
                data_chunks.push(RawRow::from_full_body(body));
            }
            BackendMsg::CopyDone => {}
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields: Vec::new(),
                    rows: data_chunks,
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            _ => {}
        }
    }
}
1339
/// Whether a PostgreSQL error means the cached prepared statement is no
/// longer valid server-side: 26000 (invalid_sql_statement_name) or 0A000
/// (feature_not_supported, e.g. "cached plan must not change result type").
fn is_stale_statement_error(err: &crate::protocol::types::PgError) -> bool {
    matches!(err.code.as_str(), "26000" | "0A000")
}
1346
/// Extract the row count from a `COPY <n>` command tag.
///
/// Returns 0 for tags of any other shape (different command, or an
/// unparsable count).
fn parse_copy_count(tag: &str) -> u64 {
    match tag.strip_prefix("COPY ") {
        Some(rest) => rest.parse().unwrap_or(0),
        None => 0,
    }
}
1353
impl WireConn {
    /// Consume the handshake-level connection, yielding the raw (possibly
    /// TLS-wrapped) stream for splitting into reader/writer halves.
    pub(crate) fn into_stream(self) -> crate::tls::MaybeTlsStream {
        self.stream
    }
}
1360
#[cfg(test)]
mod tests {
    use super::*;

    /// A bounded channel must eventually reject rollback enqueues once full.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_full() {
        let (tx, _rx) = mpsc::channel::<PipelineRequest>(2);
        let mut filled = false;
        // Capacity is 2, so well before 16 attempts try_send must fail.
        for _ in 0..16 {
            if !try_enqueue_rollback(&tx) {
                filled = true;
                break;
            }
        }
        assert!(
            filled,
            "expected try_enqueue_rollback to eventually return false on a full channel"
        );
        assert!(
            !try_enqueue_rollback(&tx),
            "subsequent calls on a full channel must keep returning false"
        );
    }

    /// Dropping the receiver closes the channel; enqueue must report failure.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_closed() {
        let (tx, rx) = mpsc::channel::<PipelineRequest>(8);
        drop(rx);
        assert!(
            !try_enqueue_rollback(&tx),
            "try_enqueue_rollback must return false when the receiver has been dropped"
        );
    }

    /// Success path: the queued request is a simple Query carrying ROLLBACK.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_true_and_enqueues_query() {
        let (tx, mut rx) = mpsc::channel::<PipelineRequest>(2);
        assert!(try_enqueue_rollback(&tx));
        let req = rx.recv().await.expect("request should be received");
        // 'Q' is the simple-protocol Query message type byte.
        assert_eq!(
            req.messages.first().copied(),
            Some(b'Q'),
            "queued request should be a simple Query message"
        );
        assert!(
            req.messages.windows(8).any(|w| w == b"ROLLBACK"),
            "queued request should contain the ROLLBACK statement text"
        );
    }
}
1422}