1use std::collections::VecDeque;
9use std::sync::Arc;
10
11use bytes::BytesMut;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::sync::{mpsc, oneshot, Mutex};
14
15use crate::connection::WireConn;
16use crate::error::PgWireError;
17use crate::protocol::backend;
18use crate::protocol::frontend;
19use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
20
/// A unit of work handed to the writer task: pre-encoded frontend messages
/// plus instructions for how the reader task should consume the reply.
pub(crate) struct PipelineRequest {
    // Wire-ready frontend message bytes, written verbatim to the socket.
    pub(crate) messages: BytesMut,
    // Strategy the reader task uses to collect the backend's response.
    pub(crate) collector: ResponseCollector,
    // Channel on which the collected response (or error) is delivered.
    pub(crate) response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
32
/// How the reader task should consume one pipelined backend response.
#[allow(dead_code)]
#[non_exhaustive]
pub enum ResponseCollector {
    /// Buffer all data rows and return them in one `PipelineResponse::Rows`.
    Rows,
    /// Discard everything up to and including `ReadyForQuery`.
    Drain,
    /// Forward rows to a consumer as they arrive instead of buffering.
    Stream {
        /// Receives the row description (or the first error) once known.
        header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
        /// Receives each data row, or a mid-stream error.
        row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    },
    /// Collect the server's acknowledgement of a `COPY ... FROM STDIN`.
    CopyIn {
        // NOTE(review): never read by the collector code in this file —
        // confirm whether this payload buffer is still needed.
        data: Vec<u8>,
    },
    /// Collect all `CopyData` frames of a `COPY ... TO STDOUT`.
    CopyOut,
}
56
/// Final result of one pipelined request.
#[non_exhaustive]
pub enum PipelineResponse {
    /// A statement completed, possibly with rows.
    Rows {
        /// Column descriptions from `RowDescription` (empty if none arrived).
        fields: Vec<crate::protocol::types::FieldDescription>,
        /// Raw data rows in arrival order.
        rows: Vec<RawRow>,
        /// The `CommandComplete` tag, e.g. `"SELECT 3"` or `"COPY 10"`.
        command_tag: String,
    },
    /// The request produced no row payload (e.g. a drained response).
    Done,
}
73
/// Column metadata delivered to a streaming consumer before any rows.
#[derive(Debug, Clone)]
pub struct StreamHeader {
    /// Field descriptions for the streamed result set.
    pub fields: Vec<crate::protocol::types::FieldDescription>,
}
80
/// A single streamed result row; same representation as [`RawRow`].
pub type StreamedRow = RawRow;
83
/// A pipelined, full-duplex PostgreSQL connection.
///
/// The socket is split into writer and reader halves driven by two
/// background tokio tasks; callers enqueue [`PipelineRequest`]s and
/// responses are matched back to them in FIFO order.
pub struct AsyncConn {
    // Queue feeding the background writer task.
    request_tx: mpsc::Sender<PipelineRequest>,
    // sql -> (statement name, allocation counter) for prepared statements.
    stmt_cache: std::sync::Mutex<std::collections::HashMap<String, (String, u64)>>,
    // Monotonic source of unique prepared-statement names ("s0", "s1", ...).
    stmt_counter: std::sync::atomic::AtomicU64,
    // Cleared when either background task exits.
    alive: Arc<std::sync::atomic::AtomicBool>,
    // Backend process id/secret from startup, used for cancel requests.
    backend_pid: i32,
    backend_secret: i32,
    // Peer address string; empty if it could not be determined.
    addr: String,
    #[allow(dead_code)]
    notification_tx: mpsc::Sender<crate::protocol::types::BackendMsg>,
    // Handed out at most once via `take_notification_receiver`.
    notification_rx: std::sync::Mutex<Option<mpsc::Receiver<crate::protocol::types::BackendMsg>>>,
    // Set when a ReadyForQuery reports a non-idle transaction status.
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    // Set to mark the connection unsafe for reuse (e.g. desynced pipeline).
    broken: Arc<std::sync::atomic::AtomicBool>,
    // Count of LISTEN/NOTIFY messages dropped due to a full channel.
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
}
124
125impl std::fmt::Debug for AsyncConn {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.debug_struct("AsyncConn")
128 .field("addr", &self.addr)
129 .field("backend_pid", &self.backend_pid)
130 .field("alive", &self.is_alive())
131 .finish()
132 }
133}
134
135impl AsyncConn {
136 pub fn is_alive(&self) -> bool {
138 self.alive.load(std::sync::atomic::Ordering::Relaxed)
139 }
140
141 pub fn backend_pid(&self) -> i32 {
143 self.backend_pid
144 }
145
146 pub fn addr(&self) -> &str {
148 &self.addr
149 }
150
151 pub fn cancel_token(&self) -> crate::cancel::CancelToken {
153 crate::cancel::CancelToken::new(self.addr.clone(), self.backend_pid, self.backend_secret)
154 }
155
156 pub fn mark_state_mutated(&self) {
162 self.state_mutated
163 .store(true, std::sync::atomic::Ordering::Release);
164 }
165
166 pub fn take_state_mutated(&self) -> bool {
169 self.state_mutated
170 .swap(false, std::sync::atomic::Ordering::AcqRel)
171 }
172
173 pub fn is_state_mutated(&self) -> bool {
175 self.state_mutated
176 .load(std::sync::atomic::Ordering::Acquire)
177 }
178
179 pub fn mark_broken(&self) {
186 self.broken
187 .store(true, std::sync::atomic::Ordering::Release);
188 }
189
190 pub fn is_broken(&self) -> bool {
194 self.broken.load(std::sync::atomic::Ordering::Acquire)
195 }
196
197 #[doc(hidden)]
205 pub fn __force_mark_dead_for_test(&self) {
206 self.alive
207 .store(false, std::sync::atomic::Ordering::Release);
208 }
209
210 pub fn enqueue_rollback(&self) -> bool {
224 if !self.is_alive() {
225 return false;
226 }
227 try_enqueue_rollback(&self.request_tx)
228 }
229}
230
231fn try_enqueue_rollback(request_tx: &mpsc::Sender<PipelineRequest>) -> bool {
236 let mut buf = BytesMut::with_capacity(16);
237 frontend::encode_message(&FrontendMsg::Query(b"ROLLBACK"), &mut buf);
238 let (tx, _rx) = oneshot::channel();
239 request_tx
240 .try_send(PipelineRequest {
241 messages: buf,
242 collector: ResponseCollector::Drain,
243 response_tx: tx,
244 })
245 .is_ok()
246}
247
/// A request the writer has flushed whose response the reader has not yet
/// consumed; queued FIFO so pipelined responses match their callers.
struct PendingResponse {
    // How to consume the backend messages for this request.
    collector: ResponseCollector,
    // Where to deliver the outcome.
    response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
252
253impl AsyncConn {
    /// Consumes a handshaken [`WireConn`], splits its stream, and spawns the
    /// writer/reader background tasks that drive the pipeline. Either task
    /// exiting clears the shared `alive` flag.
    pub fn new(conn: WireConn) -> Self {
        let backend_pid = conn.pid;
        let backend_secret = conn.secret;
        // Best effort: addr stays empty if the peer address is unavailable.
        let addr = conn
            .stream
            .peer_addr()
            .map(|a| a.to_string())
            .unwrap_or_default();

        // Async notifications are buffered here; the reader drops (and
        // counts) overflow instead of blocking.
        let (notification_tx, notification_rx) = mpsc::channel(4096);
        let (request_tx, request_rx) = mpsc::channel::<PipelineRequest>(256);
        // FIFO of flushed-but-unanswered requests, shared writer -> reader.
        let pending: Arc<Mutex<VecDeque<PendingResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
        let pending_notify = Arc::new(tokio::sync::Notify::new());
        let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
        let state_mutated = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let broken = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let dropped_notifications = Arc::new(std::sync::atomic::AtomicU64::new(0));

        let (stream_read, stream_write) = tokio::io::split(conn.into_stream());

        // Writer task: drains `request_rx` and writes batches to the socket.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive = Arc::clone(&alive);
            tokio::spawn(async move {
                writer_task(request_rx, stream_write, pending, pending_notify).await;
                alive.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired writer task exited");
            });
        }

        // Reader task: matches backend responses to pending requests.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive_clone = Arc::clone(&alive);
            let state_mutated = Arc::clone(&state_mutated);
            let ntf_tx = notification_tx.clone();
            let dropped = Arc::clone(&dropped_notifications);
            tokio::spawn(async move {
                reader_task(
                    stream_read,
                    pending,
                    pending_notify,
                    ntf_tx,
                    state_mutated,
                    dropped,
                )
                .await;
                alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired reader task exited");
            });
        }

        Self {
            request_tx,
            stmt_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            stmt_counter: std::sync::atomic::AtomicU64::new(0),
            alive,
            backend_pid,
            backend_secret,
            addr,
            notification_tx,
            notification_rx: std::sync::Mutex::new(Some(notification_rx)),
            state_mutated,
            broken,
            dropped_notifications,
        }
    }
327
    /// Number of LISTEN/NOTIFY messages dropped because the notification
    /// channel was full when the reader tried to forward them.
    pub fn dropped_notifications(&self) -> u64 {
        self.dropped_notifications
            .load(std::sync::atomic::Ordering::Relaxed)
    }
339
340 pub fn take_notification_receiver(
343 &self,
344 ) -> Option<mpsc::Receiver<crate::protocol::types::BackendMsg>> {
345 self.notification_rx
346 .lock()
347 .ok()
348 .and_then(|mut guard| guard.take())
349 }
350
    /// Looks up the cached prepared-statement name for `sql`, allocating a
    /// new one (and eagerly pipelining its Parse) on a miss.
    ///
    /// Returns `(statement_name, needs_parse)`: `needs_parse` is `true`
    /// only when the Parse could not be enqueued here, in which case the
    /// caller must include its own Parse message.
    pub fn lookup_or_alloc(&self, sql: &str, param_oids: &[u32]) -> (Vec<u8>, bool) {
        let mut cache = match self.stmt_cache.lock() {
            Ok(c) => c,
            Err(poisoned) => poisoned.into_inner(),
        };
        if let Some((name, _)) = cache.get(sql) {
            // Cache hit: the statement is already parsed server-side.
            return (name.as_bytes().to_vec(), false);
        }
        // Capacity cap: evict the entry with the smallest allocation counter
        // (oldest-allocated, not least-recently-used) before inserting.
        if cache.len() >= 256 {
            if let Some((oldest_key, oldest_name)) = cache
                .iter()
                .min_by_key(|(_, (_, counter))| *counter)
                .map(|(k, (name, _))| (k.clone(), name.clone()))
            {
                cache.remove(&oldest_key);
                // Best-effort Close of the evicted statement; if try_send
                // fails, the server-side statement lingers until disconnect.
                let mut close_buf = BytesMut::with_capacity(32);
                frontend::encode_message(
                    &FrontendMsg::Close {
                        kind: b'S',
                        name: oldest_name.as_bytes(),
                    },
                    &mut close_buf,
                );
                frontend::encode_message(&FrontendMsg::Sync, &mut close_buf);
                let (tx, _rx) = oneshot::channel();
                let _ = self.request_tx.try_send(PipelineRequest {
                    messages: close_buf,
                    collector: ResponseCollector::Drain,
                    response_tx: tx,
                });
            }
        }
        let n = self
            .stmt_counter
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let name = format!("s{n}");

        // Pipeline the Parse eagerly so a following Bind can assume it.
        let mut parse_buf = BytesMut::with_capacity(32 + sql.len());
        frontend::encode_message(
            &FrontendMsg::Parse {
                name: name.as_bytes(),
                sql: sql.as_bytes(),
                param_oids,
            },
            &mut parse_buf,
        );
        frontend::encode_message(&FrontendMsg::Sync, &mut parse_buf);
        let (parse_tx, _parse_rx) = oneshot::channel();
        match self.request_tx.try_send(PipelineRequest {
            messages: parse_buf,
            collector: ResponseCollector::Drain,
            response_tx: parse_tx,
        }) {
            Ok(()) => {
                // Parse is in flight: cache the name; caller need not parse.
                cache.insert(sql.to_string(), (name.clone(), n));
                (name.into_bytes(), false)
            }
            Err(_) => {
                // Queue full/closed: leave the cache untouched and make the
                // caller send its own Parse.
                (name.into_bytes(), true)
            }
        }
    }
450
451 pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, PgWireError> {
457 use crate::protocol::types::FrontendMsg;
458 const CHUNK_SIZE: usize = 1024 * 1024; let mut buf = BytesMut::with_capacity(copy_sql.len() + data.len().min(CHUNK_SIZE) + 64);
462 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
463
464 for chunk in data.chunks(CHUNK_SIZE) {
466 frontend::encode_message(&FrontendMsg::CopyData(chunk), &mut buf);
467 }
468 frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
470
471 let resp = self
472 .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
473 .await?;
474 match resp {
475 PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
476 PipelineResponse::Done => Ok(0),
477 }
478 }
479
480 pub async fn copy_in_stream<R: tokio::io::AsyncRead + Unpin>(
493 &self,
494 copy_sql: &str,
495 mut reader: R,
496 ) -> Result<u64, PgWireError> {
497 use tokio::io::AsyncReadExt;
498 const CHUNK_SIZE: usize = 1024 * 1024; let mut buf = BytesMut::with_capacity(copy_sql.len() + 16);
502 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
503
504 let mut chunk = vec![0u8; CHUNK_SIZE];
506 loop {
507 let n = reader.read(&mut chunk).await?;
508 if n == 0 {
509 break;
510 }
511 frontend::encode_message(&FrontendMsg::CopyData(&chunk[..n]), &mut buf);
512 }
513 frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
514
515 let resp = self
516 .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
517 .await?;
518 match resp {
519 PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
520 PipelineResponse::Done => Ok(0),
521 }
522 }
523
524 pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, PgWireError> {
526 use crate::protocol::types::FrontendMsg;
527 let mut buf = BytesMut::new();
528 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
529
530 let resp = self.submit(buf, ResponseCollector::CopyOut).await?;
531 match resp {
532 PipelineResponse::Rows { rows, .. } => {
533 let mut result = Vec::new();
536 for row in rows {
537 for data in row.iter().flatten() {
538 result.extend_from_slice(data);
539 }
540 }
541 Ok(result)
542 }
543 PipelineResponse::Done => Ok(Vec::new()),
544 }
545 }
546
547 pub fn invalidate_statement(&self, sql: &str) {
550 let mut cache = match self.stmt_cache.lock() {
551 Ok(c) => c,
552 Err(poisoned) => poisoned.into_inner(),
553 };
554 cache.remove(sql);
555 }
556
557 pub fn clear_statement_cache(&self) {
560 let mut cache = match self.stmt_cache.lock() {
561 Ok(c) => c,
562 Err(poisoned) => poisoned.into_inner(),
563 };
564 cache.clear();
565 }
566
567 pub async fn exec_transaction(
569 &self,
570 setup_sql: &str,
571 query_sql: &str,
572 params: &[Option<&[u8]>],
573 param_oids: &[u32],
574 ) -> Result<Vec<RawRow>, PgWireError> {
575 let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql, param_oids);
576 self.pipeline_transaction(
577 setup_sql,
578 query_sql,
579 params,
580 param_oids,
581 &stmt_name,
582 needs_parse,
583 )
584 .await
585 }
586
587 pub async fn exec_query(
591 &self,
592 sql: &str,
593 params: &[Option<&[u8]>],
594 param_oids: &[u32],
595 ) -> Result<Vec<RawRow>, PgWireError> {
596 let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
597 match self
598 .query(sql, params, param_oids, &stmt_name, needs_parse)
599 .await
600 {
601 Ok(rows) => Ok(rows),
602 Err(PgWireError::Pg(ref pg_err))
603 if !needs_parse && is_stale_statement_error(pg_err) =>
604 {
605 tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
606 self.invalidate_statement(sql);
607 let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
608 self.query(sql, params, param_oids, &stmt_name, true).await
609 }
610 Err(e) => Err(e),
611 }
612 }
613
    /// Upper bound on waiting for a pipelined response; hitting it almost
    /// always means a background task died or the pipeline desynchronized.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
617
618 pub async fn submit(
622 &self,
623 messages: BytesMut,
624 collector: ResponseCollector,
625 ) -> Result<PipelineResponse, PgWireError> {
626 let (response_tx, response_rx) = oneshot::channel();
627 let req = PipelineRequest {
628 messages,
629 collector,
630 response_tx,
631 };
632 self.request_tx
633 .send(req)
634 .await
635 .map_err(|_| PgWireError::ConnectionClosed)?;
636 match tokio::time::timeout(Self::REQUEST_TIMEOUT, response_rx).await {
637 Ok(Ok(result)) => result,
638 Ok(Err(_)) => Err(PgWireError::ConnectionClosed),
639 Err(_elapsed) => {
640 tracing::error!(
641 "request timed out after {:?} — reader/writer task may be dead",
642 Self::REQUEST_TIMEOUT
643 );
644 Err(PgWireError::ConnectionClosed)
645 }
646 }
647 }
648
649 pub async fn submit_batch(
659 &self,
660 items: Vec<(BytesMut, ResponseCollector)>,
661 ) -> Result<Vec<Result<PipelineResponse, PgWireError>>, PgWireError> {
662 let mut receivers = Vec::with_capacity(items.len());
663 for (messages, collector) in items {
664 let (response_tx, response_rx) = oneshot::channel();
665 self.request_tx
666 .send(PipelineRequest {
667 messages,
668 collector,
669 response_tx,
670 })
671 .await
672 .map_err(|_| PgWireError::ConnectionClosed)?;
673 receivers.push(response_rx);
674 }
675 let mut results = Vec::with_capacity(receivers.len());
676 for rx in receivers {
677 match tokio::time::timeout(Self::REQUEST_TIMEOUT, rx).await {
678 Ok(Ok(r)) => results.push(r),
679 Ok(Err(_)) => results.push(Err(PgWireError::ConnectionClosed)),
680 Err(_) => {
681 tracing::error!(
682 "submit_batch request timed out after {:?}",
683 Self::REQUEST_TIMEOUT
684 );
685 results.push(Err(PgWireError::ConnectionClosed));
686 }
687 }
688 }
689 Ok(results)
690 }
691
692 pub async fn close(&self) -> Result<(), PgWireError> {
697 if !self.is_alive() {
698 return Ok(());
699 }
700 let mut buf = BytesMut::with_capacity(5);
701 frontend::encode_message(&FrontendMsg::Terminate, &mut buf);
702 match self.submit(buf, ResponseCollector::Drain).await {
707 Ok(_) | Err(PgWireError::ConnectionClosed) => Ok(()),
708 Err(PgWireError::Io(e)) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
709 Err(e) => Err(e),
710 }
711 }
712
    /// Queues `messages` and returns a streaming handle: the row header plus
    /// a bounded channel of rows (capacity `row_buffer`).
    pub async fn submit_stream(
        &self,
        messages: BytesMut,
        row_buffer: usize,
    ) -> Result<
        (
            StreamHeader,
            mpsc::Receiver<Result<StreamedRow, PgWireError>>,
        ),
        PgWireError,
    > {
        let (header_tx, header_rx) = oneshot::channel();
        let (row_tx, row_rx) = mpsc::channel(row_buffer);
        // The final PipelineResponse::Done is intentionally discarded: the
        // receiver half is dropped immediately, and completion is signalled
        // to the consumer by `row_rx` closing instead.
        let (response_tx, _response_rx) = oneshot::channel();
        let req = PipelineRequest {
            messages,
            collector: ResponseCollector::Stream { header_tx, row_tx },
            response_tx,
        };
        self.request_tx
            .send(req)
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;
        // First `?`: sender dropped; second `?`: the stream reported an error.
        let header = header_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;
        Ok((header, row_rx))
    }
743
    /// Runs a whole transaction in one pipelined round trip: `setup_sql`
    /// (simple protocol, e.g. BEGIN plus SETs), then Parse?/Bind/Execute/
    /// Sync for `query_sql`, then `COMMIT` — all enqueued before any reply
    /// is awaited.
    ///
    /// Note: because all three requests are queued up front, the data and
    /// COMMIT messages reach the server even if the setup step fails; the
    /// setup error is surfaced first via the `?` below.
    pub async fn pipeline_transaction(
        &self,
        setup_sql: &str,
        query_sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        stmt_name: &[u8],
        needs_parse: bool,
    ) -> Result<Vec<RawRow>, PgWireError> {
        let mut buf = BytesMut::with_capacity(1024);

        // Step 1: transaction setup as a simple Query message.
        frontend::encode_message(&FrontendMsg::Query(setup_sql.as_bytes()), &mut buf);

        let setup_msgs = buf.split();

        // All parameters and results use the text format; `.max(1)` keeps
        // the vec non-empty, but only `params.len()` entries are used below.
        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
        let result_fmts = [FormatCode::Text];

        if needs_parse {
            // Statement not yet prepared server-side: include a Parse.
            frontend::encode_message(
                &FrontendMsg::Parse {
                    name: stmt_name,
                    sql: query_sql.as_bytes(),
                    param_oids,
                },
                &mut buf,
            );
        }

        frontend::encode_message(
            &FrontendMsg::Bind {
                portal: b"",
                statement: stmt_name,
                param_formats: &text_fmts[..params.len()],
                params,
                result_formats: &result_fmts,
            },
            &mut buf,
        );

        // max_rows == 0 requests all rows.
        frontend::encode_message(
            &FrontendMsg::Execute {
                portal: b"",
                max_rows: 0,
            },
            &mut buf,
        );

        frontend::encode_message(&FrontendMsg::Sync, &mut buf);

        let data_msgs = buf.split();

        // Step 3: COMMIT, again via the simple protocol.
        let mut commit_buf = BytesMut::with_capacity(32);
        frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut commit_buf);

        let (setup_tx, setup_rx) = oneshot::channel();
        let (data_tx, data_rx) = oneshot::channel();
        let (commit_tx, commit_rx) = oneshot::channel();

        // Enqueue the three requests back-to-back so the writer task can
        // flush them as a single batch.
        self.request_tx
            .send(PipelineRequest {
                messages: setup_msgs,
                collector: ResponseCollector::Drain,
                response_tx: setup_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: data_msgs,
                collector: ResponseCollector::Rows,
                response_tx: data_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: commit_buf,
                collector: ResponseCollector::Drain,
                response_tx: commit_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        // Await the responses in pipeline order; the first failure wins.
        setup_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        let data_resp = data_rx.await.map_err(|_| PgWireError::ConnectionClosed)??;

        commit_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        match data_resp {
            PipelineResponse::Rows { rows, .. } => Ok(rows),
            PipelineResponse::Done => Ok(Vec::new()),
        }
    }
858
859 pub async fn query(
861 &self,
862 sql: &str,
863 params: &[Option<&[u8]>],
864 param_oids: &[u32],
865 stmt_name: &[u8],
866 needs_parse: bool,
867 ) -> Result<Vec<RawRow>, PgWireError> {
868 self.query_with_formats(sql, params, param_oids, &[], &[], stmt_name, needs_parse)
869 .await
870 }
871
    /// Extended-protocol execution with caller-chosen wire formats.
    ///
    /// Empty `param_formats` / `result_formats` default to all-text. Sends
    /// Parse (when `needs_parse`), Bind, Execute, and Sync as one submission.
    #[allow(clippy::too_many_arguments)]
    pub async fn query_with_formats(
        &self,
        sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        param_formats: &[FormatCode],
        result_formats: &[FormatCode],
        stmt_name: &[u8],
        needs_parse: bool,
    ) -> Result<Vec<RawRow>, PgWireError> {
        let mut buf = BytesMut::with_capacity(512);

        // Default parameter formats: one Text entry per parameter. The
        // `.max(1)` keeps the backing vec non-empty; the slice is trimmed
        // back to exactly `params.len()` entries.
        let text_param_fmts: Vec<FormatCode>;
        let param_fmts_slice: &[FormatCode] = if param_formats.is_empty() {
            text_param_fmts = vec![FormatCode::Text; params.len().max(1)];
            &text_param_fmts[..params.len()]
        } else {
            param_formats
        };
        // Default result format: a single Text entry.
        let default_result_fmts = [FormatCode::Text];
        let result_fmts_slice: &[FormatCode] = if result_formats.is_empty() {
            &default_result_fmts
        } else {
            result_formats
        };

        if needs_parse {
            frontend::encode_message(
                &FrontendMsg::Parse {
                    name: stmt_name,
                    sql: sql.as_bytes(),
                    param_oids,
                },
                &mut buf,
            );
        }

        frontend::encode_message(
            &FrontendMsg::Bind {
                portal: b"",
                statement: stmt_name,
                param_formats: param_fmts_slice,
                params,
                result_formats: result_fmts_slice,
            },
            &mut buf,
        );

        // max_rows == 0 requests all rows.
        frontend::encode_message(
            &FrontendMsg::Execute {
                portal: b"",
                max_rows: 0,
            },
            &mut buf,
        );

        frontend::encode_message(&FrontendMsg::Sync, &mut buf);

        let resp = self.submit(buf, ResponseCollector::Rows).await?;
        match resp {
            PipelineResponse::Rows { rows, .. } => Ok(rows),
            PipelineResponse::Done => Ok(Vec::new()),
        }
    }
948
    /// Executes `sql` with caller-chosen parameter/result formats, retrying
    /// once with a forced Parse when the server reports the cached prepared
    /// statement has gone stale.
    pub async fn exec_query_with_formats(
        &self,
        sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        param_formats: &[FormatCode],
        result_formats: &[FormatCode],
    ) -> Result<Vec<RawRow>, PgWireError> {
        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
        match self
            .query_with_formats(
                sql,
                params,
                param_oids,
                param_formats,
                result_formats,
                &stmt_name,
                needs_parse,
            )
            .await
        {
            Ok(rows) => Ok(rows),
            // Retry only applies when we relied on a cached statement.
            Err(PgWireError::Pg(ref pg_err))
                if !needs_parse && is_stale_statement_error(pg_err) =>
            {
                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
                self.invalidate_statement(sql);
                let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
                // Force a Parse this time regardless of cache state.
                self.query_with_formats(
                    sql,
                    params,
                    param_oids,
                    param_formats,
                    result_formats,
                    &stmt_name,
                    true,
                )
                .await
            }
            Err(e) => Err(e),
        }
    }
993}
994
/// Writer half: batches queued requests, writes them to the socket, and
/// registers each as pending so the reader can match responses FIFO.
/// Exits — failing all outstanding requests — when the request channel
/// closes or a write/flush fails.
async fn writer_task(
    mut rx: mpsc::Receiver<PipelineRequest>,
    mut stream: tokio::io::WriteHalf<crate::tls::MaybeTlsStream>,
    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
    pending_notify: Arc<tokio::sync::Notify>,
) {
    let mut write_buf = BytesMut::with_capacity(8192);

    loop {
        // Block for the first request of the next batch.
        let first = match rx.recv().await {
            Some(req) => req,
            None => {
                // All senders dropped: fail whatever is still pending.
                drain_pending_on_exit(&pending).await;
                return;
            }
        };

        write_buf.clear();
        write_buf.extend_from_slice(&first.messages);

        let mut batch: Vec<PendingResponse> = vec![PendingResponse {
            collector: first.collector,
            response_tx: first.response_tx,
        }];

        // Opportunistically coalesce already-queued requests into the same
        // write to reduce syscalls.
        while let Ok(req) = rx.try_recv() {
            write_buf.extend_from_slice(&req.messages);
            batch.push(PendingResponse {
                collector: req.collector,
                response_tx: req.response_tx,
            });
        }

        let write_result = stream.write_all(&write_buf).await;
        let write_err = match write_result {
            Ok(_) => stream.flush().await.err(),
            Err(e) => Some(e),
        };

        if let Some(e) = write_err {
            tracing::error!("Writer error: {e}");
            let msg = e.to_string();
            // This batch may not have reached the wire; fail it directly...
            for p in batch {
                let _ = p.response_tx.send(Err(PgWireError::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    msg.clone(),
                ))));
            }
            // ...and fail anything the reader still had outstanding.
            drain_pending_on_exit(&pending).await;
            return;
        }

        // Publish the batch for the reader only after a successful flush,
        // then wake it.
        {
            let mut pq = pending.lock().await;
            for p in batch {
                pq.push_back(p);
            }
        }
        pending_notify.notify_one();
    }
}
1070
1071async fn drain_pending_on_exit(pending: &Arc<Mutex<VecDeque<PendingResponse>>>) {
1074 let mut pq = pending.lock().await;
1075 while let Some(pr) = pq.pop_front() {
1076 let _ = pr.response_tx.send(Err(PgWireError::ConnectionClosed));
1077 }
1078}
1079
/// Reader half: pops pending requests in FIFO order and consumes each
/// backend reply according to the request's collector.
///
/// NOTE(review): this loop has no exit path of its own — after a fatal read
/// error it keeps failing subsequent pending requests and then parks on
/// `pending_notify`; the shared `alive` flag is only cleared when the
/// writer task exits. Confirm this is the intended shutdown model.
async fn reader_task(
    mut stream: tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
    pending_notify: Arc<tokio::sync::Notify>,
    notification_tx: mpsc::Sender<BackendMsg>,
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
) {
    let mut recv_buf = BytesMut::with_capacity(32 * 1024);

    loop {
        // Wait until the writer publishes at least one flushed request.
        let pr = loop {
            {
                let mut pq = pending.lock().await;
                if let Some(pr) = pq.pop_front() {
                    break pr;
                }
            }
            pending_notify.notified().await;
        };

        // Consume this request's response per its collector strategy.
        let result = match pr.collector {
            ResponseCollector::Rows => {
                collect_rows(
                    &mut stream,
                    &mut recv_buf,
                    &notification_tx,
                    &state_mutated,
                    &dropped_notifications,
                )
                .await
            }
            ResponseCollector::Drain => {
                drain_until_ready(&mut stream, &mut recv_buf, Some(&state_mutated))
                    .await
                    .map(|_| PipelineResponse::Done)
            }
            ResponseCollector::Stream { header_tx, row_tx } => {
                // Streaming reports errors through its own channels, so the
                // pipeline-level result is always Done.
                stream_rows(
                    &mut stream,
                    &mut recv_buf,
                    header_tx,
                    row_tx,
                    &notification_tx,
                    &state_mutated,
                    &dropped_notifications,
                )
                .await;
                Ok(PipelineResponse::Done)
            }
            ResponseCollector::CopyIn { .. } => {
                collect_copy_in_response(&mut stream, &mut recv_buf, &state_mutated).await
            }
            ResponseCollector::CopyOut => {
                collect_copy_out(&mut stream, &mut recv_buf, &state_mutated).await
            }
        };

        // Best effort: the submitter may have given up (e.g. timed out).
        let _ = pr.response_tx.send(result);
    }
}
1149
/// Reads one complete backend message, pulling more bytes from the socket
/// as needed. Returns `ConnectionClosed` on EOF mid-message.
async fn read_msg(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
) -> Result<BackendMsg, PgWireError> {
    loop {
        // First try to parse from what is already buffered.
        if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
            return Ok(msg);
        }
        let n = stream.read_buf(buf).await?;
        if n == 0 {
            // EOF: one last parse attempt before reporting closure.
            // NOTE(review): `buf` was not modified since the parse above, so
            // this re-parse seemingly cannot yield a message — confirm
            // whether `parse_message` is stateful before simplifying.
            if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
                return Ok(msg);
            }
            return Err(PgWireError::ConnectionClosed);
        }
    }
}
1170
/// Records transaction state from a ReadyForQuery status byte: anything
/// other than `'I'` (idle) means session state was mutated.
fn note_rfq_status(status: u8, state_mutated: &std::sync::atomic::AtomicBool) {
    match status {
        b'I' => {}
        _ => state_mutated.store(true, std::sync::atomic::Ordering::Release),
    }
}
1179
/// Buffers an entire extended-protocol response (RowDescription, DataRows,
/// CommandComplete) until ReadyForQuery, forwarding async notifications on
/// the side. On ErrorResponse the wire is drained to ReadyForQuery first so
/// the connection stays usable.
async fn collect_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) -> Result<PipelineResponse, PgWireError> {
    let mut rows = Vec::new();
    let mut fields = Vec::new();
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            BackendMsg::DataRow(row) => rows.push(row),
            BackendMsg::RowDescription { fields: f } => fields = f,
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            BackendMsg::ReadyForQuery { status } => {
                // End of response: record txn status, hand back results.
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields,
                    rows,
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                // Resynchronize to ReadyForQuery before surfacing the error.
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            msg @ BackendMsg::NotificationResponse { .. } => {
                // LISTEN/NOTIFY may arrive interleaved; forward without
                // blocking and count drops when the channel is full.
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // Protocol acknowledgements carrying no payload for us.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1225
1226async fn drain_until_ready(
1227 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1228 buf: &mut BytesMut,
1229 state_mutated: Option<&std::sync::atomic::AtomicBool>,
1230) -> Result<(), PgWireError> {
1231 loop {
1232 let msg = read_msg(stream, buf).await?;
1233 if let BackendMsg::ReadyForQuery { status } = msg {
1234 if let Some(sm) = state_mutated {
1235 note_rfq_status(status, sm);
1236 }
1237 return Ok(());
1238 }
1239 if let BackendMsg::ErrorResponse { ref fields } = msg {
1240 tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
1241 }
1242 }
1243}
1244
/// Streaming collector: delivers the header once, then each DataRow, over
/// the caller's channels. Never returns an error — failures are routed
/// through `header_tx` (before the first row) or `row_tx` (after), and the
/// wire is drained to ReadyForQuery where possible to keep the connection
/// usable.
async fn stream_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
    row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) {
    // `Some` until the header has been delivered.
    let mut header_tx = Some(header_tx);
    let mut fields = Vec::new();
    loop {
        let msg = match read_msg(stream, buf).await {
            Ok(msg) => msg,
            Err(e) => {
                // Route the read error to whichever channel is still live.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(e));
                } else {
                    let _ = row_tx.send(Err(e)).await;
                }
                return;
            }
        };
        match msg {
            BackendMsg::RowDescription { fields: f } => {
                fields = f;
            }
            BackendMsg::DataRow(row) => {
                // First row: publish the header (cloned, since later arms
                // may still need `fields`).
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: fields.clone(),
                    }));
                }
                if row_tx.send(Ok(row)).await.is_err() {
                    // Consumer hung up: discard the rest of this response.
                    let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                    return;
                }
            }
            BackendMsg::CommandComplete { .. } => {
                // Zero-row result: the header may not have been sent yet.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
            }
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
                return;
            }
            BackendMsg::ErrorResponse { fields: err } => {
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(PgWireError::Pg(err)));
                } else {
                    let _ = row_tx.send(Err(PgWireError::Pg(err))).await;
                }
                let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                return;
            }
            msg @ BackendMsg::NotificationResponse { .. } => {
                // Forward LISTEN/NOTIFY without blocking; count drops.
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // No-payload protocol acknowledgements.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::PortalSuspended
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1326
1327async fn collect_copy_in_response(
1330 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1331 buf: &mut BytesMut,
1332 state_mutated: &std::sync::atomic::AtomicBool,
1333) -> Result<PipelineResponse, PgWireError> {
1334 let mut command_tag = String::new();
1335 loop {
1336 let msg = read_msg(stream, buf).await?;
1337 match msg {
1338 BackendMsg::CopyInResponse { .. } => {}
1339 BackendMsg::CommandComplete { tag } => command_tag = tag,
1340 BackendMsg::ReadyForQuery { status } => {
1341 note_rfq_status(status, state_mutated);
1342 return Ok(PipelineResponse::Rows {
1343 fields: Vec::new(),
1344 rows: Vec::new(),
1345 command_tag,
1346 });
1347 }
1348 BackendMsg::ErrorResponse { fields } => {
1349 drain_until_ready(stream, buf, Some(state_mutated)).await?;
1350 return Err(PgWireError::Pg(fields));
1351 }
1352 _ => {}
1353 }
1354 }
1355}
1356
/// Buffers a COPY TO STDOUT response: each `CopyData` frame becomes one
/// `RawRow` (via `from_full_body`), returned under `rows` alongside the
/// final command tag.
async fn collect_copy_out(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    state_mutated: &std::sync::atomic::AtomicBool,
) -> Result<PipelineResponse, PgWireError> {
    let mut data_chunks: Vec<RawRow> = Vec::new();
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            BackendMsg::CopyOutResponse { .. } => {}
            BackendMsg::CopyData { data } => {
                // Store each frame as a pseudo-row for the caller to splice.
                let body = bytes::Bytes::from(data);
                data_chunks.push(RawRow::from_full_body(body));
            }
            BackendMsg::CopyDone => {}
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields: Vec::new(),
                    rows: data_chunks,
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                // Resynchronize before surfacing the error.
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            _ => {}
        }
    }
}
1391
/// True for SQLSTATEs indicating a cached prepared statement can no longer
/// be used and should be re-parsed:
/// - `26000` invalid_sql_statement_name (statement unknown to the server)
/// - `0A000` feature_not_supported (e.g. "cached plan must not change
///   result type" after a schema change)
fn is_stale_statement_error(err: &crate::protocol::types::PgError) -> bool {
    matches!(err.code.as_str(), "26000" | "0A000")
}
1398
/// Extracts the row count from a `COPY <n>` command tag; any other shape
/// (missing prefix, non-numeric remainder) yields 0.
fn parse_copy_count(tag: &str) -> u64 {
    match tag.strip_prefix("COPY ") {
        Some(rest) => rest.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}
1405
impl WireConn {
    /// Consumes the handshaken connection, yielding the underlying
    /// (possibly TLS-wrapped) byte stream so it can be split.
    pub(crate) fn into_stream(self) -> crate::tls::MaybeTlsStream {
        self.stream
    }
}
1412
#[cfg(test)]
mod tests {
    use super::*;

    // A bounded channel filled to capacity must make try_enqueue_rollback
    // report failure instead of blocking.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_full() {
        let (tx, _rx) = mpsc::channel::<PipelineRequest>(2);
        let mut filled = false;
        for _ in 0..16 {
            if !try_enqueue_rollback(&tx) {
                filled = true;
                break;
            }
        }
        assert!(
            filled,
            "expected try_enqueue_rollback to eventually return false on a full channel"
        );
        assert!(
            !try_enqueue_rollback(&tx),
            "subsequent calls on a full channel must keep returning false"
        );
    }

    // Dropping the receiver closes the channel; enqueueing must then fail
    // cleanly rather than panic.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_closed() {
        let (tx, rx) = mpsc::channel::<PipelineRequest>(8);
        drop(rx);
        assert!(
            !try_enqueue_rollback(&tx),
            "try_enqueue_rollback must return false when the receiver has been dropped"
        );
    }

    // Happy path: the queued request is a simple 'Q' (Query) message whose
    // payload contains the ROLLBACK statement text.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_true_and_enqueues_query() {
        let (tx, mut rx) = mpsc::channel::<PipelineRequest>(2);
        assert!(try_enqueue_rollback(&tx));
        let req = rx.recv().await.expect("request should be received");
        assert_eq!(
            req.messages.first().copied(),
            Some(b'Q'),
            "queued request should be a simple Query message"
        );
        assert!(
            req.messages.windows(8).any(|w| w == b"ROLLBACK"),
            "queued request should contain the ROLLBACK statement text"
        );
    }
}