1use std::collections::VecDeque;
9use std::sync::Arc;
10
11use bytes::BytesMut;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::sync::{mpsc, oneshot, Mutex};
14
15use crate::connection::WireConn;
16use crate::error::PgWireError;
17use crate::protocol::backend;
18use crate::protocol::frontend;
19use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
20
/// A unit of work handed to the writer task: pre-encoded frontend messages
/// plus instructions for how the reader task should consume the reply.
pub(crate) struct PipelineRequest {
    /// Wire-encoded frontend messages, written to the socket verbatim.
    pub(crate) messages: BytesMut,
    /// How the reader task should interpret the backend's response.
    pub(crate) collector: ResponseCollector,
    /// Receives the final outcome once the reader has consumed the response.
    pub(crate) response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
32
/// Strategy the reader task uses to consume one backend response.
#[allow(dead_code)]
#[non_exhaustive]
pub enum ResponseCollector {
    /// Buffer all data rows and return them in one `PipelineResponse::Rows`.
    Rows,
    /// Discard everything up to (and including) `ReadyForQuery`.
    Drain,
    /// Forward rows through a channel as they arrive instead of buffering.
    Stream {
        /// Receives the row description (or the first error) once known.
        header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
        /// Receives each data row (or a mid-stream error).
        row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    },
    /// `COPY ... FROM STDIN`; the payload is pre-encoded into the request.
    CopyIn {
        /// Never read by the reader task (matched with `..`); constructed
        /// empty by the copy-in methods. Kept for API evolution.
        data: Vec<u8>,
    },
    /// `COPY ... TO STDOUT`; CopyData chunks are gathered into `rows`.
    CopyOut,
}
56
/// Fully-collected outcome of one pipelined request.
#[non_exhaustive]
pub enum PipelineResponse {
    /// A buffered result set. Also reused by the COPY collectors, which
    /// leave `fields` empty and may repurpose `rows` as raw data chunks.
    Rows {
        /// Column descriptions from RowDescription (empty if none was sent).
        fields: Vec<crate::protocol::types::FieldDescription>,
        /// All data rows, in arrival order.
        rows: Vec<RawRow>,
        /// CommandComplete tag, e.g. `SELECT 3` or `COPY 42`.
        command_tag: String,
    },
    /// The response was drained (or streamed elsewhere); nothing to return.
    Done,
}
73
/// Metadata delivered before the first streamed row.
#[derive(Debug, Clone)]
pub struct StreamHeader {
    /// Column descriptions from RowDescription (empty for row-less
    /// statements).
    pub fields: Vec<crate::protocol::types::FieldDescription>,
}
80
81pub type StreamedRow = RawRow;
83
/// A pipelined, shareable async PostgreSQL connection.
///
/// Requests are encoded up front and queued to a writer task; a reader task
/// matches backend responses to pending requests in FIFO order.
pub struct AsyncConn {
    /// Queue feeding the writer task.
    request_tx: mpsc::Sender<PipelineRequest>,
    /// SQL text -> (server-side statement name, allocation counter).
    stmt_cache: std::sync::Mutex<std::collections::HashMap<String, (String, u64)>>,
    /// Monotonic source of fresh statement names (`s0`, `s1`, ...).
    stmt_counter: std::sync::atomic::AtomicU64,
    /// Cleared when either background task exits.
    alive: Arc<std::sync::atomic::AtomicBool>,
    /// Backend process ID from startup; used to build cancel tokens.
    backend_pid: i32,
    /// Backend secret key from startup; used to build cancel tokens.
    backend_secret: i32,
    /// Peer address as a string (empty if it could not be determined).
    addr: String,
    /// Unused copy of the notification sender — the reader task holds its
    /// own clone; this field only keeps the type self-describing.
    #[allow(dead_code)]
    notification_tx: mpsc::Sender<crate::protocol::types::BackendMsg>,
    /// Receiver for LISTEN/NOTIFY messages; takeable exactly once.
    notification_rx: std::sync::Mutex<Option<mpsc::Receiver<crate::protocol::types::BackendMsg>>>,
    /// Set when a non-idle ReadyForQuery status is seen, or explicitly via
    /// `mark_state_mutated`.
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    /// Set via `mark_broken` to flag the connection unusable.
    broken: Arc<std::sync::atomic::AtomicBool>,
    /// Count of notifications dropped because the channel was full.
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
}
124
125impl std::fmt::Debug for AsyncConn {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.debug_struct("AsyncConn")
128 .field("addr", &self.addr)
129 .field("backend_pid", &self.backend_pid)
130 .field("alive", &self.is_alive())
131 .finish()
132 }
133}
134
135impl AsyncConn {
136 pub fn is_alive(&self) -> bool {
138 self.alive.load(std::sync::atomic::Ordering::Relaxed)
139 }
140
141 pub fn backend_pid(&self) -> i32 {
143 self.backend_pid
144 }
145
146 pub fn addr(&self) -> &str {
148 &self.addr
149 }
150
151 pub fn cancel_token(&self) -> crate::cancel::CancelToken {
153 crate::cancel::CancelToken::new(self.addr.clone(), self.backend_pid, self.backend_secret)
154 }
155
156 pub fn mark_state_mutated(&self) {
162 self.state_mutated
163 .store(true, std::sync::atomic::Ordering::Release);
164 }
165
166 pub fn take_state_mutated(&self) -> bool {
169 self.state_mutated
170 .swap(false, std::sync::atomic::Ordering::AcqRel)
171 }
172
173 pub fn is_state_mutated(&self) -> bool {
175 self.state_mutated
176 .load(std::sync::atomic::Ordering::Acquire)
177 }
178
179 pub fn mark_broken(&self) {
186 self.broken
187 .store(true, std::sync::atomic::Ordering::Release);
188 }
189
190 pub fn is_broken(&self) -> bool {
194 self.broken.load(std::sync::atomic::Ordering::Acquire)
195 }
196
197 #[doc(hidden)]
205 pub fn __force_mark_dead_for_test(&self) {
206 self.alive
207 .store(false, std::sync::atomic::Ordering::Release);
208 }
209
210 pub fn enqueue_rollback(&self) -> bool {
224 if !self.is_alive() {
225 return false;
226 }
227 try_enqueue_rollback(&self.request_tx)
228 }
229}
230
231fn try_enqueue_rollback(request_tx: &mpsc::Sender<PipelineRequest>) -> bool {
236 let mut buf = BytesMut::with_capacity(16);
237 frontend::encode_message(&FrontendMsg::Query(b"ROLLBACK"), &mut buf);
238 let (tx, _rx) = oneshot::channel();
239 request_tx
240 .try_send(PipelineRequest {
241 messages: buf,
242 collector: ResponseCollector::Drain,
243 response_tx: tx,
244 })
245 .is_ok()
246}
247
/// A response the reader task still owes: the collection strategy paired
/// with the channel that receives the final result.
struct PendingResponse {
    collector: ResponseCollector,
    response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
252
253impl AsyncConn {
    /// Wrap an authenticated `WireConn`, spawning the writer and reader
    /// background tasks that drive the pipelined request/response protocol.
    pub fn new(conn: WireConn) -> Self {
        // Capture identity before the stream is consumed below.
        let backend_pid = conn.pid;
        let backend_secret = conn.secret;
        let addr = conn
            .stream
            .peer_addr()
            .map(|a| a.to_string())
            .unwrap_or_default();

        // Channels: async notifications, the request pipeline, and the FIFO
        // of responses the reader task is expected to consume in order.
        let (notification_tx, notification_rx) = mpsc::channel(4096);
        let (request_tx, request_rx) = mpsc::channel::<PipelineRequest>(256);
        let pending: Arc<Mutex<VecDeque<PendingResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
        let pending_notify = Arc::new(tokio::sync::Notify::new());
        let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
        let state_mutated = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let broken = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let dropped_notifications = Arc::new(std::sync::atomic::AtomicU64::new(0));

        let (stream_read, stream_write) = tokio::io::split(conn.into_stream());

        // Writer task: serializes queued requests onto the socket. Its exit,
        // for any reason, marks the connection dead.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive = Arc::clone(&alive);
            tokio::spawn(async move {
                writer_task(request_rx, stream_write, pending, pending_notify).await;
                alive.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired writer task exited");
            });
        }

        // Reader task: consumes backend responses for pending requests and
        // forwards async notifications. Its exit also marks the conn dead.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive_clone = Arc::clone(&alive);
            let state_mutated = Arc::clone(&state_mutated);
            let ntf_tx = notification_tx.clone();
            let dropped = Arc::clone(&dropped_notifications);
            tokio::spawn(async move {
                reader_task(
                    stream_read,
                    pending,
                    pending_notify,
                    ntf_tx,
                    state_mutated,
                    dropped,
                )
                .await;
                alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired reader task exited");
            });
        }

        Self {
            request_tx,
            stmt_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            stmt_counter: std::sync::atomic::AtomicU64::new(0),
            alive,
            backend_pid,
            backend_secret,
            addr,
            notification_tx,
            notification_rx: std::sync::Mutex::new(Some(notification_rx)),
            state_mutated,
            broken,
            dropped_notifications,
        }
    }
327
328 pub fn dropped_notifications(&self) -> u64 {
336 self.dropped_notifications
337 .load(std::sync::atomic::Ordering::Relaxed)
338 }
339
340 pub fn take_notification_receiver(
343 &self,
344 ) -> Option<mpsc::Receiver<crate::protocol::types::BackendMsg>> {
345 self.notification_rx
346 .lock()
347 .ok()
348 .and_then(|mut guard| guard.take())
349 }
350
351 pub fn lookup_or_alloc(&self, sql: &str, _param_oids: &[u32]) -> (Vec<u8>, bool) {
382 let cache = match self.stmt_cache.lock() {
383 Ok(c) => c,
384 Err(poisoned) => poisoned.into_inner(),
385 };
386 if let Some((name, _)) = cache.get(sql) {
387 return (name.as_bytes().to_vec(), false);
388 }
389 let n = self
394 .stmt_counter
395 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
396 let name = format!("s{n}");
397 (name.into_bytes(), true)
398 }
399
    /// Record `sql` → prepared-statement `name` in the cache.
    ///
    /// The numeric suffix of the name (`s<counter>`) doubles as an age
    /// stamp; when the cache is full (256 entries) the entry with the
    /// smallest counter is evicted and a best-effort Close for its
    /// server-side statement is queued.
    pub fn cache_statement(&self, sql: &str, name: &[u8]) {
        // Names are expected to be ASCII `s<counter>`; ignore anything else.
        let Ok(name_str) = std::str::from_utf8(name) else {
            return;
        };
        // Recover the allocation counter from the name; fall back to the
        // current counter value when the name has an unexpected shape.
        let counter = name_str
            .strip_prefix('s')
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or_else(|| self.stmt_counter.load(std::sync::atomic::Ordering::Relaxed));
        let mut cache = match self.stmt_cache.lock() {
            Ok(c) => c,
            Err(poisoned) => poisoned.into_inner(),
        };
        if cache.contains_key(sql) {
            return;
        }
        // At capacity: evict the oldest entry and ask the server to close
        // its statement (fire-and-forget; a full queue is simply ignored).
        if cache.len() >= 256 {
            if let Some((oldest_key, oldest_name)) = cache
                .iter()
                .min_by_key(|(_, (_, counter))| *counter)
                .map(|(k, (name, _))| (k.clone(), name.clone()))
            {
                cache.remove(&oldest_key);
                let mut close_buf = BytesMut::with_capacity(32);
                frontend::encode_message(
                    &FrontendMsg::Close {
                        kind: b'S',
                        name: oldest_name.as_bytes(),
                    },
                    &mut close_buf,
                );
                frontend::encode_message(&FrontendMsg::Sync, &mut close_buf);
                let (tx, _rx) = oneshot::channel();
                let _ = self.request_tx.try_send(PipelineRequest {
                    messages: close_buf,
                    collector: ResponseCollector::Drain,
                    response_tx: tx,
                });
            }
        }
        cache.insert(sql.to_string(), (name_str.to_string(), counter));
    }
460
461 pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, PgWireError> {
467 use crate::protocol::types::FrontendMsg;
468 const CHUNK_SIZE: usize = 1024 * 1024; let mut buf = BytesMut::with_capacity(copy_sql.len() + data.len().min(CHUNK_SIZE) + 64);
472 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
473
474 for chunk in data.chunks(CHUNK_SIZE) {
476 frontend::encode_message(&FrontendMsg::CopyData(chunk), &mut buf);
477 }
478 frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
480
481 let resp = self
482 .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
483 .await?;
484 match resp {
485 PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
486 PipelineResponse::Done => Ok(0),
487 }
488 }
489
490 pub async fn copy_in_stream<R: tokio::io::AsyncRead + Unpin>(
503 &self,
504 copy_sql: &str,
505 mut reader: R,
506 ) -> Result<u64, PgWireError> {
507 use tokio::io::AsyncReadExt;
508 const CHUNK_SIZE: usize = 1024 * 1024; let mut buf = BytesMut::with_capacity(copy_sql.len() + 16);
512 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
513
514 let mut chunk = vec![0u8; CHUNK_SIZE];
516 loop {
517 let n = reader.read(&mut chunk).await?;
518 if n == 0 {
519 break;
520 }
521 frontend::encode_message(&FrontendMsg::CopyData(&chunk[..n]), &mut buf);
522 }
523 frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
524
525 let resp = self
526 .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
527 .await?;
528 match resp {
529 PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
530 PipelineResponse::Done => Ok(0),
531 }
532 }
533
534 pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, PgWireError> {
536 use crate::protocol::types::FrontendMsg;
537 let mut buf = BytesMut::new();
538 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
539
540 let resp = self.submit(buf, ResponseCollector::CopyOut).await?;
541 match resp {
542 PipelineResponse::Rows { rows, .. } => {
543 let mut result = Vec::new();
546 for row in rows {
547 for data in row.iter().flatten() {
548 result.extend_from_slice(data);
549 }
550 }
551 Ok(result)
552 }
553 PipelineResponse::Done => Ok(Vec::new()),
554 }
555 }
556
557 pub fn invalidate_statement(&self, sql: &str) {
560 let mut cache = match self.stmt_cache.lock() {
561 Ok(c) => c,
562 Err(poisoned) => poisoned.into_inner(),
563 };
564 cache.remove(sql);
565 }
566
567 pub fn clear_statement_cache(&self) {
570 let mut cache = match self.stmt_cache.lock() {
571 Ok(c) => c,
572 Err(poisoned) => poisoned.into_inner(),
573 };
574 cache.clear();
575 }
576
    /// Run `setup_sql` (e.g. BEGIN plus session setup), then the prepared
    /// `query_sql`, then COMMIT, as one pipelined exchange.
    ///
    /// Uses the statement cache; on a "stale prepared statement" error for
    /// a cached statement, the whole transaction is retried once with a
    /// fresh Parse.
    pub async fn exec_transaction(
        &self,
        setup_sql: &str,
        query_sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
    ) -> Result<Vec<RawRow>, PgWireError> {
        let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql, param_oids);
        match self
            .pipeline_transaction(
                setup_sql,
                query_sql,
                params,
                param_oids,
                &stmt_name,
                needs_parse,
            )
            .await
        {
            Ok(rows) => {
                // Only cache after the server has successfully parsed it.
                if needs_parse {
                    self.cache_statement(query_sql, &stmt_name);
                }
                Ok(rows)
            }
            // Retry only when the failure came from a cached (not freshly
            // parsed) statement that the server no longer recognizes.
            Err(PgWireError::Pg(ref pg_err))
                if !needs_parse && is_stale_statement_error(pg_err) =>
            {
                tracing::debug!(
                    sql = query_sql,
                    "prepared statement invalidated — re-parsing in transaction"
                );
                self.invalidate_statement(query_sql);
                let (stmt_name, _) = self.lookup_or_alloc(query_sql, param_oids);
                let result = self
                    .pipeline_transaction(
                        setup_sql, query_sql, params, param_oids, &stmt_name, true,
                    )
                    .await;
                if result.is_ok() {
                    self.cache_statement(query_sql, &stmt_name);
                }
                result
            }
            Err(e) => Err(e),
        }
    }
632
633 pub async fn exec_query(
637 &self,
638 sql: &str,
639 params: &[Option<&[u8]>],
640 param_oids: &[u32],
641 ) -> Result<Vec<RawRow>, PgWireError> {
642 let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
643 match self
644 .query(sql, params, param_oids, &stmt_name, needs_parse)
645 .await
646 {
647 Ok(rows) => {
648 if needs_parse {
649 self.cache_statement(sql, &stmt_name);
650 }
651 Ok(rows)
652 }
653 Err(PgWireError::Pg(ref pg_err))
654 if !needs_parse && is_stale_statement_error(pg_err) =>
655 {
656 tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
657 self.invalidate_statement(sql);
658 let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
659 let result = self.query(sql, params, param_oids, &stmt_name, true).await;
660 if result.is_ok() {
661 self.cache_statement(sql, &stmt_name);
662 }
663 result
664 }
665 Err(e) => Err(e),
666 }
667 }
668
669 const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
672
673 pub async fn submit(
677 &self,
678 messages: BytesMut,
679 collector: ResponseCollector,
680 ) -> Result<PipelineResponse, PgWireError> {
681 let (response_tx, response_rx) = oneshot::channel();
682 let req = PipelineRequest {
683 messages,
684 collector,
685 response_tx,
686 };
687 self.request_tx
688 .send(req)
689 .await
690 .map_err(|_| PgWireError::ConnectionClosed)?;
691 match tokio::time::timeout(Self::REQUEST_TIMEOUT, response_rx).await {
692 Ok(Ok(result)) => result,
693 Ok(Err(_)) => Err(PgWireError::ConnectionClosed),
694 Err(_elapsed) => {
695 tracing::error!(
696 "request timed out after {:?} — reader/writer task may be dead",
697 Self::REQUEST_TIMEOUT
698 );
699 Err(PgWireError::ConnectionClosed)
700 }
701 }
702 }
703
704 pub async fn submit_batch(
714 &self,
715 items: Vec<(BytesMut, ResponseCollector)>,
716 ) -> Result<Vec<Result<PipelineResponse, PgWireError>>, PgWireError> {
717 let mut receivers = Vec::with_capacity(items.len());
718 for (messages, collector) in items {
719 let (response_tx, response_rx) = oneshot::channel();
720 self.request_tx
721 .send(PipelineRequest {
722 messages,
723 collector,
724 response_tx,
725 })
726 .await
727 .map_err(|_| PgWireError::ConnectionClosed)?;
728 receivers.push(response_rx);
729 }
730 let mut results = Vec::with_capacity(receivers.len());
731 for rx in receivers {
732 match tokio::time::timeout(Self::REQUEST_TIMEOUT, rx).await {
733 Ok(Ok(r)) => results.push(r),
734 Ok(Err(_)) => results.push(Err(PgWireError::ConnectionClosed)),
735 Err(_) => {
736 tracing::error!(
737 "submit_batch request timed out after {:?}",
738 Self::REQUEST_TIMEOUT
739 );
740 results.push(Err(PgWireError::ConnectionClosed));
741 }
742 }
743 }
744 Ok(results)
745 }
746
747 pub async fn close(&self) -> Result<(), PgWireError> {
752 if !self.is_alive() {
753 return Ok(());
754 }
755 let mut buf = BytesMut::with_capacity(5);
756 frontend::encode_message(&FrontendMsg::Terminate, &mut buf);
757 match self.submit(buf, ResponseCollector::Drain).await {
762 Ok(_) | Err(PgWireError::ConnectionClosed) => Ok(()),
763 Err(PgWireError::Io(e)) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
764 Err(e) => Err(e),
765 }
766 }
767
    /// Submit a request whose rows are streamed back through a channel
    /// instead of being buffered.
    ///
    /// Resolves once the row description (or first error) arrives; rows
    /// then flow through the returned receiver.
    pub async fn submit_stream(
        &self,
        messages: BytesMut,
        row_buffer: usize,
    ) -> Result<
        (
            StreamHeader,
            mpsc::Receiver<Result<StreamedRow, PgWireError>>,
        ),
        PgWireError,
    > {
        let (header_tx, header_rx) = oneshot::channel();
        let (row_tx, row_rx) = mpsc::channel(row_buffer);
        // The response receiver is intentionally dropped: the reader task's
        // final `Done` send for a Stream collector is discarded; completion
        // is signalled by `row_rx` closing instead.
        let (response_tx, _response_rx) = oneshot::channel();
        let req = PipelineRequest {
            messages,
            collector: ResponseCollector::Stream { header_tx, row_tx },
            response_tx,
        };
        self.request_tx
            .send(req)
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;
        // First `?`: header channel dropped; second `?`: server-side error.
        let header = header_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;
        Ok((header, row_rx))
    }
798
    /// Issue setup / query / COMMIT as three pipelined requests and await
    /// their responses in order.
    ///
    /// All three requests are queued before any response is awaited, so the
    /// whole transaction costs a single round trip. The setup and COMMIT
    /// phases use the simple protocol (drained); the data phase uses the
    /// extended protocol and collects rows.
    pub async fn pipeline_transaction(
        &self,
        setup_sql: &str,
        query_sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        stmt_name: &[u8],
        needs_parse: bool,
    ) -> Result<Vec<RawRow>, PgWireError> {
        let mut buf = BytesMut::with_capacity(1024);

        // Phase 1: simple-protocol setup (typically BEGIN + SETs).
        frontend::encode_message(&FrontendMsg::Query(setup_sql.as_bytes()), &mut buf);

        // split() hands off the setup bytes and lets `buf` be reused below.
        let setup_msgs = buf.split();

        // All parameters and results use the text format.
        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
        let result_fmts = [FormatCode::Text];

        // Phase 2: extended protocol — optional Parse, then Bind/Execute/Sync.
        if needs_parse {
            frontend::encode_message(
                &FrontendMsg::Parse {
                    name: stmt_name,
                    sql: query_sql.as_bytes(),
                    param_oids,
                },
                &mut buf,
            );
        }

        frontend::encode_message(
            &FrontendMsg::Bind {
                portal: b"",
                statement: stmt_name,
                param_formats: &text_fmts[..params.len()],
                params,
                result_formats: &result_fmts,
            },
            &mut buf,
        );

        frontend::encode_message(
            &FrontendMsg::Execute {
                portal: b"",
                max_rows: 0,
            },
            &mut buf,
        );

        frontend::encode_message(&FrontendMsg::Sync, &mut buf);

        let data_msgs = buf.split();

        // Phase 3: simple-protocol COMMIT.
        let mut commit_buf = BytesMut::with_capacity(32);
        frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut commit_buf);

        let (setup_tx, setup_rx) = oneshot::channel();
        let (data_tx, data_rx) = oneshot::channel();
        let (commit_tx, commit_rx) = oneshot::channel();

        // Queue all three requests before awaiting any response.
        self.request_tx
            .send(PipelineRequest {
                messages: setup_msgs,
                collector: ResponseCollector::Drain,
                response_tx: setup_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: data_msgs,
                collector: ResponseCollector::Rows,
                response_tx: data_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: commit_buf,
                collector: ResponseCollector::Drain,
                response_tx: commit_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        // Await in submission order. An early error returns here, but the
        // reader task still consumes the later responses (their oneshot
        // receivers are simply dropped), keeping the stream in sync.
        setup_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        let data_resp = data_rx.await.map_err(|_| PgWireError::ConnectionClosed)??;

        commit_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        match data_resp {
            PipelineResponse::Rows { rows, .. } => Ok(rows),
            PipelineResponse::Done => Ok(Vec::new()),
        }
    }
913
    /// Extended-protocol query with all-text parameter and result formats.
    ///
    /// Thin wrapper over [`Self::query_with_formats`] passing empty format
    /// slices, which default to text.
    pub async fn query(
        &self,
        sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        stmt_name: &[u8],
        needs_parse: bool,
    ) -> Result<Vec<RawRow>, PgWireError> {
        self.query_with_formats(sql, params, param_oids, &[], &[], stmt_name, needs_parse)
            .await
    }
926
927 #[allow(clippy::too_many_arguments)]
938 pub async fn query_with_formats(
939 &self,
940 sql: &str,
941 params: &[Option<&[u8]>],
942 param_oids: &[u32],
943 param_formats: &[FormatCode],
944 result_formats: &[FormatCode],
945 stmt_name: &[u8],
946 needs_parse: bool,
947 ) -> Result<Vec<RawRow>, PgWireError> {
948 let mut buf = BytesMut::with_capacity(512);
949
950 let text_param_fmts: Vec<FormatCode>;
952 let param_fmts_slice: &[FormatCode] = if param_formats.is_empty() {
953 text_param_fmts = vec![FormatCode::Text; params.len().max(1)];
954 &text_param_fmts[..params.len()]
955 } else {
956 param_formats
957 };
958 let default_result_fmts = [FormatCode::Text];
959 let result_fmts_slice: &[FormatCode] = if result_formats.is_empty() {
960 &default_result_fmts
961 } else {
962 result_formats
963 };
964
965 if needs_parse {
966 frontend::encode_message(
967 &FrontendMsg::Parse {
968 name: stmt_name,
969 sql: sql.as_bytes(),
970 param_oids,
971 },
972 &mut buf,
973 );
974 }
975
976 frontend::encode_message(
977 &FrontendMsg::Bind {
978 portal: b"",
979 statement: stmt_name,
980 param_formats: param_fmts_slice,
981 params,
982 result_formats: result_fmts_slice,
983 },
984 &mut buf,
985 );
986
987 frontend::encode_message(
988 &FrontendMsg::Execute {
989 portal: b"",
990 max_rows: 0,
991 },
992 &mut buf,
993 );
994
995 frontend::encode_message(&FrontendMsg::Sync, &mut buf);
996
997 let resp = self.submit(buf, ResponseCollector::Rows).await?;
998 match resp {
999 PipelineResponse::Rows { rows, .. } => Ok(rows),
1000 PipelineResponse::Done => Ok(Vec::new()),
1001 }
1002 }
1003
1004 pub async fn exec_query_with_formats(
1007 &self,
1008 sql: &str,
1009 params: &[Option<&[u8]>],
1010 param_oids: &[u32],
1011 param_formats: &[FormatCode],
1012 result_formats: &[FormatCode],
1013 ) -> Result<Vec<RawRow>, PgWireError> {
1014 let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
1015 match self
1016 .query_with_formats(
1017 sql,
1018 params,
1019 param_oids,
1020 param_formats,
1021 result_formats,
1022 &stmt_name,
1023 needs_parse,
1024 )
1025 .await
1026 {
1027 Ok(rows) => {
1028 if needs_parse {
1029 self.cache_statement(sql, &stmt_name);
1030 }
1031 Ok(rows)
1032 }
1033 Err(PgWireError::Pg(ref pg_err))
1034 if !needs_parse && is_stale_statement_error(pg_err) =>
1035 {
1036 tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
1037 self.invalidate_statement(sql);
1038 let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
1039 let result = self
1040 .query_with_formats(
1041 sql,
1042 params,
1043 param_oids,
1044 param_formats,
1045 result_formats,
1046 &stmt_name,
1047 true,
1048 )
1049 .await;
1050 if result.is_ok() {
1051 self.cache_statement(sql, &stmt_name);
1052 }
1053 result
1054 }
1055 Err(e) => Err(e),
1056 }
1057 }
1058}
1059
/// Writer half of the connection: receives queued requests, coalesces
/// whatever is immediately available into one socket write, then registers
/// the matching pending responses for the reader task.
///
/// Pending entries are registered only AFTER a successful write + flush, so
/// the reader never waits for a response whose request never hit the wire.
async fn writer_task(
    mut rx: mpsc::Receiver<PipelineRequest>,
    mut stream: tokio::io::WriteHalf<crate::tls::MaybeTlsStream>,
    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
    pending_notify: Arc<tokio::sync::Notify>,
) {
    let mut write_buf = BytesMut::with_capacity(8192);

    loop {
        // Block for the first request; channel closure means shutdown.
        let first = match rx.recv().await {
            Some(req) => req,
            None => {
                drain_pending_on_exit(&pending).await;
                return;
            }
        };

        write_buf.clear();
        write_buf.extend_from_slice(&first.messages);

        let mut batch: Vec<PendingResponse> = vec![PendingResponse {
            collector: first.collector,
            response_tx: first.response_tx,
        }];

        // Opportunistically drain whatever else is already queued so a
        // single syscall can carry several pipelined requests.
        while let Ok(req) = rx.try_recv() {
            write_buf.extend_from_slice(&req.messages);
            batch.push(PendingResponse {
                collector: req.collector,
                response_tx: req.response_tx,
            });
        }

        let write_result = stream.write_all(&write_buf).await;
        let write_err = match write_result {
            Ok(_) => stream.flush().await.err(),
            Err(e) => Some(e),
        };

        // On failure: fail this batch AND everything the reader was still
        // expecting, then exit (the owner then flips `alive` to false).
        if let Some(e) = write_err {
            tracing::error!("Writer error: {e}");
            let msg = e.to_string();
            for p in batch {
                let _ = p.response_tx.send(Err(PgWireError::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    msg.clone(),
                ))));
            }
            drain_pending_on_exit(&pending).await;
            return;
        }

        // Publish the batch to the reader only after the bytes are out.
        {
            let mut pq = pending.lock().await;
            for p in batch {
                pq.push_back(p);
            }
        }
        pending_notify.notify_one();
    }
}
1135
1136async fn drain_pending_on_exit(pending: &Arc<Mutex<VecDeque<PendingResponse>>>) {
1139 let mut pq = pending.lock().await;
1140 while let Some(pr) = pq.pop_front() {
1141 let _ = pr.response_tx.send(Err(PgWireError::ConnectionClosed));
1142 }
1143}
1144
1145async fn reader_task(
1150 mut stream: tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1151 pending: Arc<Mutex<VecDeque<PendingResponse>>>,
1152 pending_notify: Arc<tokio::sync::Notify>,
1153 notification_tx: mpsc::Sender<BackendMsg>,
1154 state_mutated: Arc<std::sync::atomic::AtomicBool>,
1155 dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
1156) {
1157 let mut recv_buf = BytesMut::with_capacity(32 * 1024);
1158
1159 loop {
1160 let pr = loop {
1162 {
1163 let mut pq = pending.lock().await;
1164 if let Some(pr) = pq.pop_front() {
1165 break pr;
1166 }
1167 }
1168 pending_notify.notified().await;
1170 };
1171
1172 let result = match pr.collector {
1174 ResponseCollector::Rows => {
1175 collect_rows(
1176 &mut stream,
1177 &mut recv_buf,
1178 ¬ification_tx,
1179 &state_mutated,
1180 &dropped_notifications,
1181 )
1182 .await
1183 }
1184 ResponseCollector::Drain => {
1185 drain_until_ready(&mut stream, &mut recv_buf, Some(&state_mutated))
1186 .await
1187 .map(|_| PipelineResponse::Done)
1188 }
1189 ResponseCollector::Stream { header_tx, row_tx } => {
1190 stream_rows(
1191 &mut stream,
1192 &mut recv_buf,
1193 header_tx,
1194 row_tx,
1195 ¬ification_tx,
1196 &state_mutated,
1197 &dropped_notifications,
1198 )
1199 .await;
1200 Ok(PipelineResponse::Done)
1201 }
1202 ResponseCollector::CopyIn { .. } => {
1203 collect_copy_in_response(&mut stream, &mut recv_buf, &state_mutated).await
1204 }
1205 ResponseCollector::CopyOut => {
1206 collect_copy_out(&mut stream, &mut recv_buf, &state_mutated).await
1207 }
1208 };
1209
1210 let _ = pr.response_tx.send(result);
1212 }
1213}
1214
1215async fn read_msg(
1216 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1217 buf: &mut BytesMut,
1218) -> Result<BackendMsg, PgWireError> {
1219 loop {
1220 if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1221 return Ok(msg);
1222 }
1223 let n = stream.read_buf(buf).await?;
1224 if n == 0 {
1225 if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1229 return Ok(msg);
1230 }
1231 return Err(PgWireError::ConnectionClosed);
1232 }
1233 }
1234}
1235
/// Record whether a ReadyForQuery transaction-status byte implies mutated
/// session state: anything other than `b'I'` (idle) sets the flag.
fn note_rfq_status(status: u8, state_mutated: &std::sync::atomic::AtomicBool) {
    match status {
        b'I' => {}
        _ => state_mutated.store(true, std::sync::atomic::Ordering::Release),
    }
}
1244
/// Buffer a complete extended-protocol response: row description, data
/// rows, and command tag, terminated by `ReadyForQuery`.
///
/// On `ErrorResponse`, the remaining messages up to `ReadyForQuery` are
/// drained before the error is returned, keeping the stream in sync.
async fn collect_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) -> Result<PipelineResponse, PgWireError> {
    let mut rows = Vec::new();
    let mut fields = Vec::new();
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            BackendMsg::DataRow(row) => rows.push(row),
            BackendMsg::RowDescription { fields: f } => fields = f,
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            // ReadyForQuery ends the response.
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields,
                    rows,
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            // Async LISTEN/NOTIFY payloads are forwarded best-effort; a full
            // channel drops the message and bumps the counter.
            msg @ BackendMsg::NotificationResponse { .. } => {
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // Acknowledgement / informational messages carry no row data.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1290
1291async fn drain_until_ready(
1292 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1293 buf: &mut BytesMut,
1294 state_mutated: Option<&std::sync::atomic::AtomicBool>,
1295) -> Result<(), PgWireError> {
1296 loop {
1297 let msg = read_msg(stream, buf).await?;
1298 if let BackendMsg::ReadyForQuery { status } = msg {
1299 if let Some(sm) = state_mutated {
1300 note_rfq_status(status, sm);
1301 }
1302 return Ok(());
1303 }
1304 if let BackendMsg::ErrorResponse { ref fields } = msg {
1305 tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
1306 }
1307 }
1308}
1309
/// Stream a response: deliver the header once known, then forward each row.
///
/// The header is sent lazily — on the first data row, on CommandComplete,
/// or at ReadyForQuery for row-less statements — with whatever fields
/// RowDescription provided by then. If the row receiver is dropped, the
/// remaining response is drained to keep the stream in sync. This function
/// never returns an error; failures are delivered through the channels.
async fn stream_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
    row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) {
    // `Some` until the header has been delivered; errors before that point
    // go to the header channel, afterwards to the row channel.
    let mut header_tx = Some(header_tx);
    let mut fields = Vec::new();
    loop {
        let msg = match read_msg(stream, buf).await {
            Ok(msg) => msg,
            Err(e) => {
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(e));
                } else {
                    let _ = row_tx.send(Err(e)).await;
                }
                return;
            }
        };
        match msg {
            BackendMsg::RowDescription { fields: f } => {
                fields = f;
            }
            BackendMsg::DataRow(row) => {
                // First row: release the header to the caller.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: fields.clone(),
                    }));
                }
                // Receiver gone: consume the rest of the response quietly.
                if row_tx.send(Ok(row)).await.is_err() {
                    let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                    return;
                }
            }
            // Row-less statement: the header still must be delivered.
            BackendMsg::CommandComplete { .. } => {
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
            }
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
                return;
            }
            BackendMsg::ErrorResponse { fields: err } => {
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(PgWireError::Pg(err)));
                } else {
                    let _ = row_tx.send(Err(PgWireError::Pg(err))).await;
                }
                let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                return;
            }
            // Async LISTEN/NOTIFY payloads are forwarded best-effort.
            msg @ BackendMsg::NotificationResponse { .. } => {
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // Acknowledgement / informational messages carry no row data.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::PortalSuspended
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1391
1392async fn collect_copy_in_response(
1395 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1396 buf: &mut BytesMut,
1397 state_mutated: &std::sync::atomic::AtomicBool,
1398) -> Result<PipelineResponse, PgWireError> {
1399 let mut command_tag = String::new();
1400 loop {
1401 let msg = read_msg(stream, buf).await?;
1402 match msg {
1403 BackendMsg::CopyInResponse { .. } => {}
1404 BackendMsg::CommandComplete { tag } => command_tag = tag,
1405 BackendMsg::ReadyForQuery { status } => {
1406 note_rfq_status(status, state_mutated);
1407 return Ok(PipelineResponse::Rows {
1408 fields: Vec::new(),
1409 rows: Vec::new(),
1410 command_tag,
1411 });
1412 }
1413 BackendMsg::ErrorResponse { fields } => {
1414 drain_until_ready(stream, buf, Some(state_mutated)).await?;
1415 return Err(PgWireError::Pg(fields));
1416 }
1417 _ => {}
1418 }
1419 }
1420}
1421
1422async fn collect_copy_out(
1424 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1425 buf: &mut BytesMut,
1426 state_mutated: &std::sync::atomic::AtomicBool,
1427) -> Result<PipelineResponse, PgWireError> {
1428 let mut data_chunks: Vec<RawRow> = Vec::new();
1429 let mut command_tag = String::new();
1430 loop {
1431 let msg = read_msg(stream, buf).await?;
1432 match msg {
1433 BackendMsg::CopyOutResponse { .. } => {}
1434 BackendMsg::CopyData { data } => {
1435 let body = bytes::Bytes::from(data);
1436 data_chunks.push(RawRow::from_full_body(body));
1437 }
1438 BackendMsg::CopyDone => {}
1439 BackendMsg::CommandComplete { tag } => command_tag = tag,
1440 BackendMsg::ReadyForQuery { status } => {
1441 note_rfq_status(status, state_mutated);
1442 return Ok(PipelineResponse::Rows {
1443 fields: Vec::new(),
1444 rows: data_chunks,
1445 command_tag,
1446 });
1447 }
1448 BackendMsg::ErrorResponse { fields } => {
1449 drain_until_ready(stream, buf, Some(state_mutated)).await?;
1450 return Err(PgWireError::Pg(fields));
1451 }
1452 _ => {}
1453 }
1454 }
1455}
1456
1457fn is_stale_statement_error(err: &crate::protocol::types::PgError) -> bool {
1461 matches!(err.code.as_str(), "26000" | "0A000")
1462}
1463
/// Extract the row count from a `COPY <n>` command tag; `0` for any tag
/// with a different shape or a non-numeric suffix.
fn parse_copy_count(tag: &str) -> u64 {
    match tag.strip_prefix("COPY ") {
        Some(rest) => rest.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}
1470
impl WireConn {
    /// Consume the connection, yielding the underlying (maybe-TLS) stream.
    pub(crate) fn into_stream(self) -> crate::tls::MaybeTlsStream {
        self.stream
    }
}
1477
#[cfg(test)]
mod tests {
    use super::*;

    /// A bounded channel eventually rejects `try_send`; the helper must
    /// report that as `false` rather than blocking.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_full() {
        let (tx, _rx) = mpsc::channel::<PipelineRequest>(2);
        let mut filled = false;
        for _ in 0..16 {
            if !try_enqueue_rollback(&tx) {
                filled = true;
                break;
            }
        }
        assert!(
            filled,
            "expected try_enqueue_rollback to eventually return false on a full channel"
        );
        assert!(
            !try_enqueue_rollback(&tx),
            "subsequent calls on a full channel must keep returning false"
        );
    }

    /// Dropping the receiver closes the channel; the helper must notice.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_closed() {
        let (tx, rx) = mpsc::channel::<PipelineRequest>(8);
        drop(rx);
        assert!(
            !try_enqueue_rollback(&tx),
            "try_enqueue_rollback must return false when the receiver has been dropped"
        );
    }

    /// Happy path: the queued request is a simple Query carrying ROLLBACK.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_true_and_enqueues_query() {
        let (tx, mut rx) = mpsc::channel::<PipelineRequest>(2);
        assert!(try_enqueue_rollback(&tx));
        let req = rx.recv().await.expect("request should be received");
        // b'Q' is the simple-protocol Query message type byte.
        assert_eq!(
            req.messages.first().copied(),
            Some(b'Q'),
            "queued request should be a simple Query message"
        );
        assert!(
            req.messages.windows(8).any(|w| w == b"ROLLBACK"),
            "queued request should contain the ROLLBACK statement text"
        );
    }
}