1use std::collections::VecDeque;
9use std::sync::Arc;
10
11use bytes::BytesMut;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::sync::{mpsc, oneshot, Mutex};
14
15use crate::connection::WireConn;
16use crate::error::PgWireError;
17use crate::protocol::backend;
18use crate::protocol::frontend;
19use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
20
/// A unit of work handed to the writer task: pre-encoded frontend messages plus the
/// collector and channel used for the server's reply.
pub(crate) struct PipelineRequest {
    /// Already-framed frontend messages; the writer appends them to its output buffer as-is.
    pub(crate) messages: BytesMut,
    /// How the reader task should consume the backend messages answering this request.
    pub(crate) collector: ResponseCollector,
    /// Receives the collected response (or error) once the reader is done with it.
    pub(crate) response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
32
/// Tells the reader task how to consume the backend messages answering one request.
#[allow(dead_code)]
#[non_exhaustive]
pub enum ResponseCollector {
    /// Buffer `RowDescription` / `DataRow` / `CommandComplete` until `ReadyForQuery`.
    /// A server error is returned as `PgWireError::Pg` after draining to `ReadyForQuery`.
    Rows,
    /// Discard everything up to `ReadyForQuery`. Server errors are only logged, not
    /// returned; the request resolves to `PipelineResponse::Done`.
    Drain,
    /// Forward rows as they arrive instead of buffering them.
    Stream {
        /// Receives the column descriptions (or an error raised before any row) exactly once.
        header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
        /// Receives each row; an error raised after the header is sent here as well.
        row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    },
    /// Await the outcome of a `COPY ... FROM STDIN` whose data was sent in the same request.
    CopyIn {
        /// Not read by the reader task (it matches `CopyIn { .. }`); the callers in this
        /// file pass an empty vector.
        data: Vec<u8>,
    },
    /// Collect every `CopyData` payload of a `COPY ... TO STDOUT`.
    CopyOut,
}
56
/// Outcome of a pipelined request, as produced by its `ResponseCollector`.
#[non_exhaustive]
pub enum PipelineResponse {
    /// Buffered result of a `Rows`, `CopyIn` or `CopyOut` collector.
    Rows {
        /// Column descriptions from the last `RowDescription` seen (empty for COPY).
        fields: Vec<crate::protocol::types::FieldDescription>,
        /// Data rows; for `CopyOut`, one entry per `CopyData` message.
        rows: Vec<RawRow>,
        /// Tag from the last `CommandComplete` (empty if none arrived).
        command_tag: String,
    },
    /// Completed without a buffered result (`Drain` and `Stream` collectors).
    Done,
}
73
/// Column metadata delivered to a streaming caller before any rows.
#[derive(Debug, Clone)]
pub struct StreamHeader {
    /// Column descriptions; empty when the statement produced no `RowDescription`.
    pub fields: Vec<crate::protocol::types::FieldDescription>,
}
80
/// A single row delivered through a streaming query's row channel.
pub type StreamedRow = RawRow;
83
/// A pipelined PostgreSQL connection driven by a background writer task and reader task.
///
/// Requests are queued on a bounded channel, written to the socket by the writer task and
/// answered in FIFO order by the reader task.
pub struct AsyncConn {
    /// Queue feeding the writer task.
    request_tx: mpsc::Sender<PipelineRequest>,
    /// Prepared statements keyed by SQL text: (server-side name, counter stamp used when
    /// choosing an entry to evict).
    stmt_cache: std::sync::Mutex<std::collections::HashMap<String, (String, u64)>>,
    /// Source of statement-name numbers and cache stamps.
    stmt_counter: std::sync::atomic::AtomicU64,
    /// Set to false when the writer or reader task exits.
    alive: Arc<std::sync::atomic::AtomicBool>,
    /// Backend process id from startup; used to build cancel tokens.
    backend_pid: i32,
    /// Backend secret key from startup; used to build cancel tokens.
    backend_secret: i32,
    /// Peer address as a string (empty if it could not be determined).
    addr: String,
    /// Sender retained by the connection itself; the reader task uses its own clone.
    #[allow(dead_code)]
    notification_tx: mpsc::Sender<crate::protocol::types::BackendMsg>,
    /// Receiver for asynchronous notifications; handed out at most once.
    notification_rx: std::sync::Mutex<Option<mpsc::Receiver<crate::protocol::types::BackendMsg>>>,
    /// Set when a `ReadyForQuery` reports a non-idle status or `mark_state_mutated` is called.
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    /// Number of notifications dropped because the notification channel was full.
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
}
118
119impl std::fmt::Debug for AsyncConn {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.debug_struct("AsyncConn")
122 .field("addr", &self.addr)
123 .field("backend_pid", &self.backend_pid)
124 .field("alive", &self.is_alive())
125 .finish()
126 }
127}
128
129impl AsyncConn {
130 pub fn is_alive(&self) -> bool {
132 self.alive.load(std::sync::atomic::Ordering::Relaxed)
133 }
134
135 pub fn backend_pid(&self) -> i32 {
137 self.backend_pid
138 }
139
140 pub fn addr(&self) -> &str {
142 &self.addr
143 }
144
145 pub fn cancel_token(&self) -> crate::cancel::CancelToken {
147 crate::cancel::CancelToken::new(self.addr.clone(), self.backend_pid, self.backend_secret)
148 }
149
150 pub fn mark_state_mutated(&self) {
156 self.state_mutated
157 .store(true, std::sync::atomic::Ordering::Release);
158 }
159
160 pub fn take_state_mutated(&self) -> bool {
163 self.state_mutated
164 .swap(false, std::sync::atomic::Ordering::AcqRel)
165 }
166
167 pub fn is_state_mutated(&self) -> bool {
169 self.state_mutated
170 .load(std::sync::atomic::Ordering::Acquire)
171 }
172}
173
/// One entry of the writer-to-reader queue: how to collect a single reply and where to
/// deliver it. The reader pops entries in FIFO order, matching the order they were queued.
struct PendingResponse {
    /// Strategy for consuming the backend messages of this reply.
    collector: ResponseCollector,
    /// Receives the collected response (or error).
    response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
178
179impl AsyncConn {
180 pub fn new(conn: WireConn) -> Self {
183 let backend_pid = conn.pid;
184 let backend_secret = conn.secret;
185 let addr = conn
187 .stream
188 .peer_addr()
189 .map(|a| a.to_string())
190 .unwrap_or_default();
191
192 let (notification_tx, notification_rx) = mpsc::channel(4096);
193 let (request_tx, request_rx) = mpsc::channel::<PipelineRequest>(256);
194 let pending: Arc<Mutex<VecDeque<PendingResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
195 let pending_notify = Arc::new(tokio::sync::Notify::new());
196 let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
197 let state_mutated = Arc::new(std::sync::atomic::AtomicBool::new(false));
198 let dropped_notifications = Arc::new(std::sync::atomic::AtomicU64::new(0));
199
200 let (stream_read, stream_write) = tokio::io::split(conn.into_stream());
201
202 {
204 let pending = Arc::clone(&pending);
205 let pending_notify = Arc::clone(&pending_notify);
206 let alive = Arc::clone(&alive);
207 tokio::spawn(async move {
208 writer_task(request_rx, stream_write, pending, pending_notify).await;
209 alive.store(false, std::sync::atomic::Ordering::Relaxed);
210 tracing::warn!("pg-wired writer task exited");
211 });
212 }
213
214 {
216 let pending = Arc::clone(&pending);
217 let pending_notify = Arc::clone(&pending_notify);
218 let alive_clone = Arc::clone(&alive);
219 let state_mutated = Arc::clone(&state_mutated);
220 let ntf_tx = notification_tx.clone();
221 let dropped = Arc::clone(&dropped_notifications);
222 tokio::spawn(async move {
223 reader_task(
224 stream_read,
225 pending,
226 pending_notify,
227 ntf_tx,
228 state_mutated,
229 dropped,
230 )
231 .await;
232 alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
233 tracing::warn!("pg-wired reader task exited");
234 });
235 }
236
237 Self {
238 request_tx,
239 stmt_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
240 stmt_counter: std::sync::atomic::AtomicU64::new(0),
241 alive,
242 backend_pid,
243 backend_secret,
244 addr,
245 notification_tx,
246 notification_rx: std::sync::Mutex::new(Some(notification_rx)),
247 state_mutated,
248 dropped_notifications,
249 }
250 }
251
252 pub fn dropped_notifications(&self) -> u64 {
260 self.dropped_notifications
261 .load(std::sync::atomic::Ordering::Relaxed)
262 }
263
264 pub fn take_notification_receiver(
267 &self,
268 ) -> Option<mpsc::Receiver<crate::protocol::types::BackendMsg>> {
269 self.notification_rx
270 .lock()
271 .ok()
272 .and_then(|mut guard| guard.take())
273 }
274
275 pub fn lookup_or_alloc(&self, sql: &str) -> (Vec<u8>, bool) {
280 let mut cache = match self.stmt_cache.lock() {
281 Ok(c) => c,
282 Err(poisoned) => poisoned.into_inner(),
283 };
284 if let Some((name, _)) = cache.get(sql) {
285 return (name.as_bytes().to_vec(), false);
286 }
287 if cache.len() >= 256 {
290 if let Some((oldest_key, oldest_name)) = cache
291 .iter()
292 .min_by_key(|(_, (_, counter))| *counter)
293 .map(|(k, (name, _))| (k.clone(), name.clone()))
294 {
295 cache.remove(&oldest_key);
296 let mut close_buf = BytesMut::with_capacity(32);
299 frontend::encode_message(
300 &FrontendMsg::Close {
301 kind: b'S',
302 name: oldest_name.as_bytes(),
303 },
304 &mut close_buf,
305 );
306 frontend::encode_message(&FrontendMsg::Sync, &mut close_buf);
307 let (tx, _rx) = oneshot::channel();
308 let _ = self.request_tx.try_send(PipelineRequest {
309 messages: close_buf,
310 collector: ResponseCollector::Drain,
311 response_tx: tx,
312 });
313 }
314 }
315 let n = self
316 .stmt_counter
317 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
318 let name = format!("s{n}");
319 cache.insert(sql.to_string(), (name.clone(), n));
320 (name.into_bytes(), true)
321 }
322
323 pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, PgWireError> {
329 use crate::protocol::types::FrontendMsg;
330 const CHUNK_SIZE: usize = 1024 * 1024; let mut buf = BytesMut::with_capacity(copy_sql.len() + data.len().min(CHUNK_SIZE) + 64);
334 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
335
336 for chunk in data.chunks(CHUNK_SIZE) {
338 frontend::encode_message(&FrontendMsg::CopyData(chunk), &mut buf);
339 }
340 frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
342
343 let resp = self
344 .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
345 .await?;
346 match resp {
347 PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
348 PipelineResponse::Done => Ok(0),
349 }
350 }
351
352 pub async fn copy_in_stream<R: tokio::io::AsyncRead + Unpin>(
365 &self,
366 copy_sql: &str,
367 mut reader: R,
368 ) -> Result<u64, PgWireError> {
369 use tokio::io::AsyncReadExt;
370 const CHUNK_SIZE: usize = 1024 * 1024; let mut buf = BytesMut::with_capacity(copy_sql.len() + 16);
374 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
375
376 let mut chunk = vec![0u8; CHUNK_SIZE];
378 loop {
379 let n = reader.read(&mut chunk).await?;
380 if n == 0 {
381 break;
382 }
383 frontend::encode_message(&FrontendMsg::CopyData(&chunk[..n]), &mut buf);
384 }
385 frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
386
387 let resp = self
388 .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
389 .await?;
390 match resp {
391 PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
392 PipelineResponse::Done => Ok(0),
393 }
394 }
395
396 pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, PgWireError> {
398 use crate::protocol::types::FrontendMsg;
399 let mut buf = BytesMut::new();
400 frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
401
402 let resp = self.submit(buf, ResponseCollector::CopyOut).await?;
403 match resp {
404 PipelineResponse::Rows { rows, .. } => {
405 let mut result = Vec::new();
408 for row in rows {
409 for data in row.iter().flatten() {
410 result.extend_from_slice(data);
411 }
412 }
413 Ok(result)
414 }
415 PipelineResponse::Done => Ok(Vec::new()),
416 }
417 }
418
419 pub fn invalidate_statement(&self, sql: &str) {
422 let mut cache = match self.stmt_cache.lock() {
423 Ok(c) => c,
424 Err(poisoned) => poisoned.into_inner(),
425 };
426 cache.remove(sql);
427 }
428
429 pub fn clear_statement_cache(&self) {
432 let mut cache = match self.stmt_cache.lock() {
433 Ok(c) => c,
434 Err(poisoned) => poisoned.into_inner(),
435 };
436 cache.clear();
437 }
438
439 pub async fn exec_transaction(
441 &self,
442 setup_sql: &str,
443 query_sql: &str,
444 params: &[Option<&[u8]>],
445 param_oids: &[u32],
446 ) -> Result<Vec<RawRow>, PgWireError> {
447 let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql);
448 self.pipeline_transaction(
449 setup_sql,
450 query_sql,
451 params,
452 param_oids,
453 &stmt_name,
454 needs_parse,
455 )
456 .await
457 }
458
459 pub async fn exec_query(
463 &self,
464 sql: &str,
465 params: &[Option<&[u8]>],
466 param_oids: &[u32],
467 ) -> Result<Vec<RawRow>, PgWireError> {
468 let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
469 match self
470 .query(sql, params, param_oids, &stmt_name, needs_parse)
471 .await
472 {
473 Ok(rows) => Ok(rows),
474 Err(PgWireError::Pg(ref pg_err))
475 if !needs_parse && is_stale_statement_error(pg_err) =>
476 {
477 tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
478 self.invalidate_statement(sql);
479 let (stmt_name, _) = self.lookup_or_alloc(sql);
480 self.query(sql, params, param_oids, &stmt_name, true).await
481 }
482 Err(e) => Err(e),
483 }
484 }
485
/// How long a caller waits for one request's reply (five minutes).
const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5 * 60);
489
490 pub async fn submit(
494 &self,
495 messages: BytesMut,
496 collector: ResponseCollector,
497 ) -> Result<PipelineResponse, PgWireError> {
498 let (response_tx, response_rx) = oneshot::channel();
499 let req = PipelineRequest {
500 messages,
501 collector,
502 response_tx,
503 };
504 self.request_tx
505 .send(req)
506 .await
507 .map_err(|_| PgWireError::ConnectionClosed)?;
508 match tokio::time::timeout(Self::REQUEST_TIMEOUT, response_rx).await {
509 Ok(Ok(result)) => result,
510 Ok(Err(_)) => Err(PgWireError::ConnectionClosed),
511 Err(_elapsed) => {
512 tracing::error!(
513 "request timed out after {:?} — reader/writer task may be dead",
514 Self::REQUEST_TIMEOUT
515 );
516 Err(PgWireError::ConnectionClosed)
517 }
518 }
519 }
520
521 pub async fn submit_batch(
531 &self,
532 items: Vec<(BytesMut, ResponseCollector)>,
533 ) -> Result<Vec<Result<PipelineResponse, PgWireError>>, PgWireError> {
534 let mut receivers = Vec::with_capacity(items.len());
535 for (messages, collector) in items {
536 let (response_tx, response_rx) = oneshot::channel();
537 self.request_tx
538 .send(PipelineRequest {
539 messages,
540 collector,
541 response_tx,
542 })
543 .await
544 .map_err(|_| PgWireError::ConnectionClosed)?;
545 receivers.push(response_rx);
546 }
547 let mut results = Vec::with_capacity(receivers.len());
548 for rx in receivers {
549 match tokio::time::timeout(Self::REQUEST_TIMEOUT, rx).await {
550 Ok(Ok(r)) => results.push(r),
551 Ok(Err(_)) => results.push(Err(PgWireError::ConnectionClosed)),
552 Err(_) => {
553 tracing::error!(
554 "submit_batch request timed out after {:?}",
555 Self::REQUEST_TIMEOUT
556 );
557 results.push(Err(PgWireError::ConnectionClosed));
558 }
559 }
560 }
561 Ok(results)
562 }
563
564 pub async fn close(&self) -> Result<(), PgWireError> {
569 if !self.is_alive() {
570 return Ok(());
571 }
572 let mut buf = BytesMut::with_capacity(5);
573 frontend::encode_message(&FrontendMsg::Terminate, &mut buf);
574 match self.submit(buf, ResponseCollector::Drain).await {
579 Ok(_) | Err(PgWireError::ConnectionClosed) => Ok(()),
580 Err(PgWireError::Io(e)) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
581 Err(e) => Err(e),
582 }
583 }
584
585 pub async fn submit_stream(
588 &self,
589 messages: BytesMut,
590 row_buffer: usize,
591 ) -> Result<
592 (
593 StreamHeader,
594 mpsc::Receiver<Result<StreamedRow, PgWireError>>,
595 ),
596 PgWireError,
597 > {
598 let (header_tx, header_rx) = oneshot::channel();
599 let (row_tx, row_rx) = mpsc::channel(row_buffer);
600 let (response_tx, _response_rx) = oneshot::channel();
601 let req = PipelineRequest {
602 messages,
603 collector: ResponseCollector::Stream { header_tx, row_tx },
604 response_tx,
605 };
606 self.request_tx
607 .send(req)
608 .await
609 .map_err(|_| PgWireError::ConnectionClosed)?;
610 let header = header_rx
611 .await
612 .map_err(|_| PgWireError::ConnectionClosed)??;
613 Ok((header, row_rx))
614 }
615
616 pub async fn pipeline_transaction(
620 &self,
621 setup_sql: &str,
622 query_sql: &str,
623 params: &[Option<&[u8]>],
624 param_oids: &[u32],
625 stmt_name: &[u8],
626 needs_parse: bool,
627 ) -> Result<Vec<RawRow>, PgWireError> {
628 let mut buf = BytesMut::with_capacity(1024);
629
630 frontend::encode_message(&FrontendMsg::Query(setup_sql.as_bytes()), &mut buf);
632
633 let setup_msgs = buf.split();
635
636 let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
638 let result_fmts = [FormatCode::Text];
639
640 if needs_parse {
641 frontend::encode_message(
642 &FrontendMsg::Parse {
643 name: stmt_name,
644 sql: query_sql.as_bytes(),
645 param_oids,
646 },
647 &mut buf,
648 );
649 }
650
651 frontend::encode_message(
652 &FrontendMsg::Bind {
653 portal: b"",
654 statement: stmt_name,
655 param_formats: &text_fmts[..params.len()],
656 params,
657 result_formats: &result_fmts,
658 },
659 &mut buf,
660 );
661
662 frontend::encode_message(
663 &FrontendMsg::Execute {
664 portal: b"",
665 max_rows: 0,
666 },
667 &mut buf,
668 );
669
670 frontend::encode_message(&FrontendMsg::Sync, &mut buf);
671
672 let data_msgs = buf.split();
673
674 let mut commit_buf = BytesMut::with_capacity(32);
677 frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut commit_buf);
678
679 let (setup_tx, setup_rx) = oneshot::channel();
682 let (data_tx, data_rx) = oneshot::channel();
683 let (commit_tx, commit_rx) = oneshot::channel();
684
685 self.request_tx
688 .send(PipelineRequest {
689 messages: setup_msgs,
690 collector: ResponseCollector::Drain,
691 response_tx: setup_tx,
692 })
693 .await
694 .map_err(|_| PgWireError::ConnectionClosed)?;
695
696 self.request_tx
697 .send(PipelineRequest {
698 messages: data_msgs,
699 collector: ResponseCollector::Rows,
700 response_tx: data_tx,
701 })
702 .await
703 .map_err(|_| PgWireError::ConnectionClosed)?;
704
705 self.request_tx
706 .send(PipelineRequest {
707 messages: commit_buf,
708 collector: ResponseCollector::Drain,
709 response_tx: commit_tx,
710 })
711 .await
712 .map_err(|_| PgWireError::ConnectionClosed)?;
713
714 setup_rx
716 .await
717 .map_err(|_| PgWireError::ConnectionClosed)??;
718
719 let data_resp = data_rx.await.map_err(|_| PgWireError::ConnectionClosed)??;
720
721 commit_rx
722 .await
723 .map_err(|_| PgWireError::ConnectionClosed)??;
724
725 match data_resp {
726 PipelineResponse::Rows { rows, .. } => Ok(rows),
727 PipelineResponse::Done => Ok(Vec::new()),
728 }
729 }
730
731 pub async fn query(
733 &self,
734 sql: &str,
735 params: &[Option<&[u8]>],
736 param_oids: &[u32],
737 stmt_name: &[u8],
738 needs_parse: bool,
739 ) -> Result<Vec<RawRow>, PgWireError> {
740 self.query_with_formats(sql, params, param_oids, &[], &[], stmt_name, needs_parse)
741 .await
742 }
743
744 #[allow(clippy::too_many_arguments)]
755 pub async fn query_with_formats(
756 &self,
757 sql: &str,
758 params: &[Option<&[u8]>],
759 param_oids: &[u32],
760 param_formats: &[FormatCode],
761 result_formats: &[FormatCode],
762 stmt_name: &[u8],
763 needs_parse: bool,
764 ) -> Result<Vec<RawRow>, PgWireError> {
765 let mut buf = BytesMut::with_capacity(512);
766
767 let text_param_fmts: Vec<FormatCode>;
769 let param_fmts_slice: &[FormatCode] = if param_formats.is_empty() {
770 text_param_fmts = vec![FormatCode::Text; params.len().max(1)];
771 &text_param_fmts[..params.len()]
772 } else {
773 param_formats
774 };
775 let default_result_fmts = [FormatCode::Text];
776 let result_fmts_slice: &[FormatCode] = if result_formats.is_empty() {
777 &default_result_fmts
778 } else {
779 result_formats
780 };
781
782 if needs_parse {
783 frontend::encode_message(
784 &FrontendMsg::Parse {
785 name: stmt_name,
786 sql: sql.as_bytes(),
787 param_oids,
788 },
789 &mut buf,
790 );
791 }
792
793 frontend::encode_message(
794 &FrontendMsg::Bind {
795 portal: b"",
796 statement: stmt_name,
797 param_formats: param_fmts_slice,
798 params,
799 result_formats: result_fmts_slice,
800 },
801 &mut buf,
802 );
803
804 frontend::encode_message(
805 &FrontendMsg::Execute {
806 portal: b"",
807 max_rows: 0,
808 },
809 &mut buf,
810 );
811
812 frontend::encode_message(&FrontendMsg::Sync, &mut buf);
813
814 let resp = self.submit(buf, ResponseCollector::Rows).await?;
815 match resp {
816 PipelineResponse::Rows { rows, .. } => Ok(rows),
817 PipelineResponse::Done => Ok(Vec::new()),
818 }
819 }
820
821 pub async fn exec_query_with_formats(
824 &self,
825 sql: &str,
826 params: &[Option<&[u8]>],
827 param_oids: &[u32],
828 param_formats: &[FormatCode],
829 result_formats: &[FormatCode],
830 ) -> Result<Vec<RawRow>, PgWireError> {
831 let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
832 match self
833 .query_with_formats(
834 sql,
835 params,
836 param_oids,
837 param_formats,
838 result_formats,
839 &stmt_name,
840 needs_parse,
841 )
842 .await
843 {
844 Ok(rows) => Ok(rows),
845 Err(PgWireError::Pg(ref pg_err))
846 if !needs_parse && is_stale_statement_error(pg_err) =>
847 {
848 tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
849 self.invalidate_statement(sql);
850 let (stmt_name, _) = self.lookup_or_alloc(sql);
851 self.query_with_formats(
852 sql,
853 params,
854 param_oids,
855 param_formats,
856 result_formats,
857 &stmt_name,
858 true,
859 )
860 .await
861 }
862 Err(e) => Err(e),
863 }
864 }
865}
866
867async fn writer_task(
872 mut rx: mpsc::Receiver<PipelineRequest>,
873 mut stream: tokio::io::WriteHalf<crate::tls::MaybeTlsStream>,
874 pending: Arc<Mutex<VecDeque<PendingResponse>>>,
875 pending_notify: Arc<tokio::sync::Notify>,
876) {
877 let mut write_buf = BytesMut::with_capacity(8192);
878
879 loop {
880 let first = match rx.recv().await {
882 Some(req) => req,
883 None => {
884 drain_pending_on_exit(&pending).await;
886 return;
887 }
888 };
889
890 write_buf.clear();
892 write_buf.extend_from_slice(&first.messages);
893
894 let mut batch: Vec<PendingResponse> = vec![PendingResponse {
895 collector: first.collector,
896 response_tx: first.response_tx,
897 }];
898
899 while let Ok(req) = rx.try_recv() {
901 write_buf.extend_from_slice(&req.messages);
902 batch.push(PendingResponse {
903 collector: req.collector,
904 response_tx: req.response_tx,
905 });
906 }
907
908 let write_result = stream.write_all(&write_buf).await;
912 let write_err = match write_result {
913 Ok(_) => stream.flush().await.err(),
914 Err(e) => Some(e),
915 };
916
917 if let Some(e) = write_err {
918 tracing::error!("Writer error: {e}");
919 let msg = e.to_string();
920 for p in batch {
921 let _ = p.response_tx.send(Err(PgWireError::Io(std::io::Error::new(
922 std::io::ErrorKind::BrokenPipe,
923 msg.clone(),
924 ))));
925 }
926 drain_pending_on_exit(&pending).await;
928 return;
929 }
930
931 {
933 let mut pq = pending.lock().await;
934 for p in batch {
935 pq.push_back(p);
936 }
937 }
938 pending_notify.notify_one();
940 }
941}
942
943async fn drain_pending_on_exit(pending: &Arc<Mutex<VecDeque<PendingResponse>>>) {
946 let mut pq = pending.lock().await;
947 while let Some(pr) = pq.pop_front() {
948 let _ = pr.response_tx.send(Err(PgWireError::ConnectionClosed));
949 }
950}
951
952async fn reader_task(
957 mut stream: tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
958 pending: Arc<Mutex<VecDeque<PendingResponse>>>,
959 pending_notify: Arc<tokio::sync::Notify>,
960 notification_tx: mpsc::Sender<BackendMsg>,
961 state_mutated: Arc<std::sync::atomic::AtomicBool>,
962 dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
963) {
964 let mut recv_buf = BytesMut::with_capacity(32 * 1024);
965
966 loop {
967 let pr = loop {
969 {
970 let mut pq = pending.lock().await;
971 if let Some(pr) = pq.pop_front() {
972 break pr;
973 }
974 }
975 pending_notify.notified().await;
977 };
978
979 let result = match pr.collector {
981 ResponseCollector::Rows => {
982 collect_rows(
983 &mut stream,
984 &mut recv_buf,
985 ¬ification_tx,
986 &state_mutated,
987 &dropped_notifications,
988 )
989 .await
990 }
991 ResponseCollector::Drain => {
992 drain_until_ready(&mut stream, &mut recv_buf, Some(&state_mutated))
993 .await
994 .map(|_| PipelineResponse::Done)
995 }
996 ResponseCollector::Stream { header_tx, row_tx } => {
997 stream_rows(
998 &mut stream,
999 &mut recv_buf,
1000 header_tx,
1001 row_tx,
1002 ¬ification_tx,
1003 &state_mutated,
1004 &dropped_notifications,
1005 )
1006 .await;
1007 Ok(PipelineResponse::Done)
1008 }
1009 ResponseCollector::CopyIn { .. } => {
1010 collect_copy_in_response(&mut stream, &mut recv_buf, &state_mutated).await
1011 }
1012 ResponseCollector::CopyOut => {
1013 collect_copy_out(&mut stream, &mut recv_buf, &state_mutated).await
1014 }
1015 };
1016
1017 let _ = pr.response_tx.send(result);
1019 }
1020}
1021
1022async fn read_msg(
1023 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1024 buf: &mut BytesMut,
1025) -> Result<BackendMsg, PgWireError> {
1026 loop {
1027 if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1028 return Ok(msg);
1029 }
1030 let n = stream.read_buf(buf).await?;
1031 if n == 0 {
1032 if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1036 return Ok(msg);
1037 }
1038 return Err(PgWireError::ConnectionClosed);
1039 }
1040 }
1041}
1042
/// Records a `ReadyForQuery` transaction status. Any status other than `b'I'` (idle) sets
/// `state_mutated`; the flag is never cleared here.
fn note_rfq_status(status: u8, state_mutated: &std::sync::atomic::AtomicBool) {
    match status {
        b'I' => {}
        _ => state_mutated.store(true, std::sync::atomic::Ordering::Release),
    }
}
1051
1052async fn collect_rows(
1053 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1054 buf: &mut BytesMut,
1055 notification_tx: &mpsc::Sender<BackendMsg>,
1056 state_mutated: &std::sync::atomic::AtomicBool,
1057 dropped_notifications: &std::sync::atomic::AtomicU64,
1058) -> Result<PipelineResponse, PgWireError> {
1059 let mut rows = Vec::new();
1060 let mut fields = Vec::new();
1061 let mut command_tag = String::new();
1062 loop {
1063 let msg = read_msg(stream, buf).await?;
1064 match msg {
1065 BackendMsg::DataRow(row) => rows.push(row),
1066 BackendMsg::RowDescription { fields: f } => fields = f,
1067 BackendMsg::CommandComplete { tag } => command_tag = tag,
1068 BackendMsg::ReadyForQuery { status } => {
1069 note_rfq_status(status, state_mutated);
1070 return Ok(PipelineResponse::Rows {
1071 fields,
1072 rows,
1073 command_tag,
1074 });
1075 }
1076 BackendMsg::ErrorResponse { fields } => {
1077 drain_until_ready(stream, buf, Some(state_mutated)).await?;
1078 return Err(PgWireError::Pg(fields));
1079 }
1080 msg @ BackendMsg::NotificationResponse { .. } => {
1081 #[allow(clippy::collapsible_match)]
1083 if notification_tx.try_send(msg).is_err() {
1084 dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1085 tracing::warn!("notification channel full, dropping notification");
1086 }
1087 }
1088 BackendMsg::ParseComplete
1089 | BackendMsg::BindComplete
1090 | BackendMsg::NoData
1091 | BackendMsg::NoticeResponse { .. }
1092 | BackendMsg::EmptyQueryResponse => {}
1093 _ => {}
1094 }
1095 }
1096}
1097
1098async fn drain_until_ready(
1099 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1100 buf: &mut BytesMut,
1101 state_mutated: Option<&std::sync::atomic::AtomicBool>,
1102) -> Result<(), PgWireError> {
1103 loop {
1104 let msg = read_msg(stream, buf).await?;
1105 if let BackendMsg::ReadyForQuery { status } = msg {
1106 if let Some(sm) = state_mutated {
1107 note_rfq_status(status, sm);
1108 }
1109 return Ok(());
1110 }
1111 if let BackendMsg::ErrorResponse { ref fields } = msg {
1112 tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
1113 }
1114 }
1115}
1116
1117async fn stream_rows(
1119 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1120 buf: &mut BytesMut,
1121 header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
1122 row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
1123 notification_tx: &mpsc::Sender<BackendMsg>,
1124 state_mutated: &std::sync::atomic::AtomicBool,
1125 dropped_notifications: &std::sync::atomic::AtomicU64,
1126) {
1127 let mut header_tx = Some(header_tx);
1128 let mut fields = Vec::new();
1129 loop {
1130 let msg = match read_msg(stream, buf).await {
1131 Ok(msg) => msg,
1132 Err(e) => {
1133 if let Some(htx) = header_tx.take() {
1134 let _ = htx.send(Err(e));
1135 } else {
1136 let _ = row_tx.send(Err(e)).await;
1137 }
1138 return;
1139 }
1140 };
1141 match msg {
1142 BackendMsg::RowDescription { fields: f } => {
1143 fields = f;
1144 }
1145 BackendMsg::DataRow(row) => {
1146 if let Some(htx) = header_tx.take() {
1147 let _ = htx.send(Ok(StreamHeader {
1148 fields: fields.clone(),
1149 }));
1150 }
1151 if row_tx.send(Ok(row)).await.is_err() {
1152 let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
1153 return;
1154 }
1155 }
1156 BackendMsg::CommandComplete { .. } => {
1157 if let Some(htx) = header_tx.take() {
1158 let _ = htx.send(Ok(StreamHeader {
1159 fields: std::mem::take(&mut fields),
1160 }));
1161 }
1162 }
1163 BackendMsg::ReadyForQuery { status } => {
1164 note_rfq_status(status, state_mutated);
1165 if let Some(htx) = header_tx.take() {
1166 let _ = htx.send(Ok(StreamHeader {
1167 fields: std::mem::take(&mut fields),
1168 }));
1169 }
1170 return;
1171 }
1172 BackendMsg::ErrorResponse { fields: err } => {
1173 if let Some(htx) = header_tx.take() {
1174 let _ = htx.send(Err(PgWireError::Pg(err)));
1175 } else {
1176 let _ = row_tx.send(Err(PgWireError::Pg(err))).await;
1177 }
1178 let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
1179 return;
1180 }
1181 msg @ BackendMsg::NotificationResponse { .. } => {
1182 #[allow(clippy::collapsible_match)]
1183 if notification_tx.try_send(msg).is_err() {
1184 dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1185 tracing::warn!("notification channel full, dropping notification");
1186 }
1187 }
1188 BackendMsg::ParseComplete
1189 | BackendMsg::BindComplete
1190 | BackendMsg::NoData
1191 | BackendMsg::PortalSuspended
1192 | BackendMsg::NoticeResponse { .. }
1193 | BackendMsg::EmptyQueryResponse => {}
1194 _ => {}
1195 }
1196 }
1197}
1198
1199async fn collect_copy_in_response(
1202 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1203 buf: &mut BytesMut,
1204 state_mutated: &std::sync::atomic::AtomicBool,
1205) -> Result<PipelineResponse, PgWireError> {
1206 let mut command_tag = String::new();
1207 loop {
1208 let msg = read_msg(stream, buf).await?;
1209 match msg {
1210 BackendMsg::CopyInResponse { .. } => {}
1211 BackendMsg::CommandComplete { tag } => command_tag = tag,
1212 BackendMsg::ReadyForQuery { status } => {
1213 note_rfq_status(status, state_mutated);
1214 return Ok(PipelineResponse::Rows {
1215 fields: Vec::new(),
1216 rows: Vec::new(),
1217 command_tag,
1218 });
1219 }
1220 BackendMsg::ErrorResponse { fields } => {
1221 drain_until_ready(stream, buf, Some(state_mutated)).await?;
1222 return Err(PgWireError::Pg(fields));
1223 }
1224 _ => {}
1225 }
1226 }
1227}
1228
1229async fn collect_copy_out(
1231 stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1232 buf: &mut BytesMut,
1233 state_mutated: &std::sync::atomic::AtomicBool,
1234) -> Result<PipelineResponse, PgWireError> {
1235 let mut data_chunks: Vec<RawRow> = Vec::new();
1236 let mut command_tag = String::new();
1237 loop {
1238 let msg = read_msg(stream, buf).await?;
1239 match msg {
1240 BackendMsg::CopyOutResponse { .. } => {}
1241 BackendMsg::CopyData { data } => {
1242 let body = bytes::Bytes::from(data);
1243 data_chunks.push(RawRow::from_full_body(body));
1244 }
1245 BackendMsg::CopyDone => {}
1246 BackendMsg::CommandComplete { tag } => command_tag = tag,
1247 BackendMsg::ReadyForQuery { status } => {
1248 note_rfq_status(status, state_mutated);
1249 return Ok(PipelineResponse::Rows {
1250 fields: Vec::new(),
1251 rows: data_chunks,
1252 command_tag,
1253 });
1254 }
1255 BackendMsg::ErrorResponse { fields } => {
1256 drain_until_ready(stream, buf, Some(state_mutated)).await?;
1257 return Err(PgWireError::Pg(fields));
1258 }
1259 _ => {}
1260 }
1261 }
1262}
1263
1264fn is_stale_statement_error(err: &crate::protocol::types::PgError) -> bool {
1268 matches!(err.code.as_str(), "26000" | "0A000")
1269}
1270
/// Extracts the row count from a `COPY n` command tag; any other shape yields 0.
fn parse_copy_count(tag: &str) -> u64 {
    match tag.strip_prefix("COPY ") {
        Some(count) => count.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}
1277
1278impl WireConn {
1280 pub(crate) fn into_stream(self) -> crate::tls::MaybeTlsStream {
1281 self.stream
1282 }
1283}