// pg_wired/async_conn.rs
1//! Async split sender/receiver connection.
2//! Inspired by hsqlx's PgWire.Async architecture.
3//!
4//! A single TCP connection is shared by many concurrent handler tasks.
5//! The writer task coalesces messages from multiple requests into one write().
6//! The reader task parses responses and dispatches them to waiting handlers via FIFO.
7
8use std::collections::VecDeque;
9use std::sync::Arc;
10
11use bytes::BytesMut;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::sync::{mpsc, oneshot, Mutex};
14
15use crate::connection::WireConn;
16use crate::error::PgWireError;
17use crate::protocol::backend;
18use crate::protocol::frontend;
19use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
20
21// ---------------------------------------------------------------------------
22// Request types
23// ---------------------------------------------------------------------------
24
/// A request to execute on the connection. Internal plumbing between the
/// public `submit` / `submit_batch` API and the writer task.
pub(crate) struct PipelineRequest {
    /// Pre-encoded frontend messages, written to the socket verbatim.
    pub(crate) messages: BytesMut,
    /// How the reader task should gather this request's response messages.
    pub(crate) collector: ResponseCollector,
    /// One-shot channel on which the final outcome is delivered to the submitter.
    pub(crate) response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
32
/// How to collect response messages for a request.
///
/// Chosen by the submitter; interpreted by the reader task when it
/// dispatches backend messages for the request at the head of the FIFO.
#[allow(dead_code)]
#[non_exhaustive]
pub enum ResponseCollector {
    /// Collect DataRows until ReadyForQuery (for SELECT queries).
    Rows,
    /// Just drain until ReadyForQuery (for setup commands like BEGIN, SET ROLE).
    Drain,
    /// Stream rows one at a time via channels. Sends header first, then individual rows.
    Stream {
        /// One-shot channel for the row description (sent once before any rows).
        header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
        /// Bounded channel for individual rows; closed on completion or error.
        row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    },
    /// COPY IN: after receiving CopyInResponse, send the provided data then CopyDone.
    ///
    /// NOTE(review): `copy_in`/`copy_in_stream` currently pre-encode the
    /// payload into the request buffer and pass an empty `data` here.
    CopyIn {
        /// The data to send after CopyInResponse.
        data: Vec<u8>,
    },
    /// COPY OUT: collect CopyData messages until CopyDone.
    CopyOut,
}
56
/// Response from a pipeline request.
#[non_exhaustive]
pub enum PipelineResponse {
    /// A query that produced a row set (`SELECT`, `RETURNING`, etc.).
    ///
    /// Also reused for COPY OUT, where each `RawRow` carries one raw
    /// CopyData chunk (see `AsyncConn::copy_out`).
    Rows {
        /// Column metadata from RowDescription (empty if no RowDescription received).
        fields: Vec<crate::protocol::types::FieldDescription>,
        /// Row data.
        rows: Vec<RawRow>,
        /// CommandComplete tag (e.g. "SELECT 3", "INSERT 0 1").
        command_tag: String,
    },
    /// A statement that produced no row set (e.g., `BEGIN`, `SET ROLE`,
    /// non-RETURNING DML).
    Done,
}
73
/// Metadata sent at the start of a streaming response, before any row
/// is delivered on the row channel.
#[derive(Debug, Clone)]
pub struct StreamHeader {
    /// Column descriptions (name, OID, format) for the streamed result set.
    pub fields: Vec<crate::protocol::types::FieldDescription>,
}

/// A single streamed row (same raw representation as buffered rows).
pub type StreamedRow = RawRow;
83
84// ---------------------------------------------------------------------------
85// Async connection
86// ---------------------------------------------------------------------------
87
/// A shared async connection that multiplexes requests from many tasks.
pub struct AsyncConn {
    /// Queue into the writer task; each item is one pre-encoded request.
    request_tx: mpsc::Sender<PipelineRequest>,
    /// Prepared-statement cache: SQL text -> (statement name, allocation
    /// counter used as the LRU eviction key). See `lookup_or_alloc`.
    stmt_cache: std::sync::Mutex<std::collections::HashMap<String, (String, u64)>>,
    /// Monotonic counter minting unique statement names ("s0", "s1", ...).
    stmt_counter: std::sync::atomic::AtomicU64,
    /// True while the writer/reader tasks run; cleared when either exits.
    alive: Arc<std::sync::atomic::AtomicBool>,
    /// Backend process ID from startup; used for query cancellation.
    backend_pid: i32,
    /// Backend secret key from startup; paired with `backend_pid` in
    /// `cancel_token`.
    backend_secret: i32,
    /// Peer address as a string (empty if it could not be determined).
    addr: String,
    /// Channel for async notifications received during query execution.
    /// Notifications are NOT silently dropped, they're forwarded here.
    #[allow(dead_code)]
    notification_tx: mpsc::Sender<crate::protocol::types::BackendMsg>,
    /// Receiver half, handed out once via `take_notification_receiver`.
    notification_rx: std::sync::Mutex<Option<mpsc::Receiver<crate::protocol::types::BackendMsg>>>,
    /// True if any operation since the last `take_state_mutated()` may have
    /// left the session in a non-default state (open transaction, SET
    /// without LOCAL, advisory lock, temp table, prepared cursor, etc.).
    ///
    /// Set explicitly by callers issuing such operations
    /// (`mark_state_mutated`), and automatically by the reader task whenever
    /// ReadyForQuery reports a non-idle transaction status. Callers that
    /// only run self-contained Bind/Execute/Sync queries leave this `false`,
    /// allowing pools to skip an expensive DISCARD ALL on return.
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    /// True if a caller has declared the connection unusable (e.g., a
    /// transaction was dropped without commit/rollback, leaving the session
    /// in an unknown state). The reader/writer tasks may still be running, so
    /// `is_alive()` is true, but pools should treat the connection as broken
    /// and destroy it on return rather than reusing it.
    broken: Arc<std::sync::atomic::AtomicBool>,
    /// Cumulative count of asynchronous notifications dropped because the
    /// notification channel was full or no application code was draining it.
    /// Surfaced via [`AsyncConn::dropped_notifications`] so callers can detect
    /// missed `LISTEN` events.
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
}
124
125impl std::fmt::Debug for AsyncConn {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.debug_struct("AsyncConn")
128            .field("addr", &self.addr)
129            .field("backend_pid", &self.backend_pid)
130            .field("alive", &self.is_alive())
131            .finish()
132    }
133}
134
135impl AsyncConn {
136    /// Check if the connection is still alive (writer/reader tasks running).
137    pub fn is_alive(&self) -> bool {
138        self.alive.load(std::sync::atomic::Ordering::Relaxed)
139    }
140
141    /// Backend process ID assigned by the server.
142    pub fn backend_pid(&self) -> i32 {
143        self.backend_pid
144    }
145
146    /// Server address this connection is talking to.
147    pub fn addr(&self) -> &str {
148        &self.addr
149    }
150
151    /// Produce a cancel token for the running session on this connection.
152    pub fn cancel_token(&self) -> crate::cancel::CancelToken {
153        crate::cancel::CancelToken::new(self.addr.clone(), self.backend_pid, self.backend_secret)
154    }
155
156    /// Mark the connection as having mutated session state since the last
157    /// reset. Pools call `take_state_mutated()` on return to decide whether
158    /// to issue `DISCARD ALL`. Callers issuing `BEGIN`, `SET` (without
159    /// `LOCAL`), advisory locks, temp tables, etc., should call this before
160    /// submitting.
161    pub fn mark_state_mutated(&self) {
162        self.state_mutated
163            .store(true, std::sync::atomic::Ordering::Release);
164    }
165
166    /// Atomically read and clear the state-mutated flag. Returns the
167    /// previous value: `true` means the caller should issue a reset.
168    pub fn take_state_mutated(&self) -> bool {
169        self.state_mutated
170            .swap(false, std::sync::atomic::Ordering::AcqRel)
171    }
172
173    /// Read the state-mutated flag without clearing it.
174    pub fn is_state_mutated(&self) -> bool {
175        self.state_mutated
176            .load(std::sync::atomic::Ordering::Acquire)
177    }
178
179    /// Mark the connection as broken. The reader/writer tasks may still be
180    /// running, but the session is in an indeterminate state (for example,
181    /// a transaction was dropped without commit or rollback) and the
182    /// connection must not be reused. Pool integrations check
183    /// [`AsyncConn::is_broken`] on return and destroy the connection
184    /// instead of returning it to the idle set.
185    pub fn mark_broken(&self) {
186        self.broken
187            .store(true, std::sync::atomic::Ordering::Release);
188    }
189
190    /// True if the connection has been declared broken by a caller via
191    /// [`AsyncConn::mark_broken`]. Independent of [`AsyncConn::is_alive`],
192    /// which only reflects whether the reader/writer tasks are still running.
193    pub fn is_broken(&self) -> bool {
194        self.broken.load(std::sync::atomic::Ordering::Acquire)
195    }
196
197    /// Test-only helper that flips the `alive` flag to `false` without
198    /// actually exiting the writer task. Used by pg-wired's own tests and
199    /// by downstream crates' integration tests (e.g. resolute) to exercise
200    /// the dead-conn branch of [`AsyncConn::enqueue_rollback`] (and any
201    /// other code that gates on `is_alive`) without racing against the
202    /// real task-exit timing. Not part of the stable API: the `__` prefix
203    /// and `#[doc(hidden)]` mark this as off-limits for production use.
204    #[doc(hidden)]
205    pub fn __force_mark_dead_for_test(&self) {
206        self.alive
207            .store(false, std::sync::atomic::Ordering::Release);
208    }
209
210    /// Fire-and-forget enqueue of a `ROLLBACK` simple-query, intended to be
211    /// callable from a synchronous `Drop`. Returns `true` if the request was
212    /// queued on the writer task, `false` if the connection is not alive or
213    /// the channel was full/closed (in which case the caller should fall
214    /// back to [`AsyncConn::mark_broken`] so the connection is discarded
215    /// by the pool).
216    ///
217    /// PostgreSQL accepts `ROLLBACK` from any in-transaction state — including
218    /// the aborted state (`25P02`) that a failed query leaves behind — so this
219    /// reliably restores the session to idle. The response is drained and
220    /// discarded; ordering on the writer queue is preserved, so any
221    /// subsequent request (e.g., the pool's `DISCARD ALL` reset) sees a clean
222    /// connection.
223    pub fn enqueue_rollback(&self) -> bool {
224        if !self.is_alive() {
225            return false;
226        }
227        try_enqueue_rollback(&self.request_tx)
228    }
229}
230
231/// Inner helper for [`AsyncConn::enqueue_rollback`]: encodes a `ROLLBACK`
232/// simple-query and tries to push it onto the writer's request channel.
233/// Extracted so the channel-full and channel-closed branches can be unit
234/// tested without instantiating a real `AsyncConn`.
235fn try_enqueue_rollback(request_tx: &mpsc::Sender<PipelineRequest>) -> bool {
236    let mut buf = BytesMut::with_capacity(16);
237    frontend::encode_message(&FrontendMsg::Query(b"ROLLBACK"), &mut buf);
238    let (tx, _rx) = oneshot::channel();
239    request_tx
240        .try_send(PipelineRequest {
241            messages: buf,
242            collector: ResponseCollector::Drain,
243            response_tx: tx,
244        })
245        .is_ok()
246}
247
/// Reader-task bookkeeping for one in-flight request: how to collect its
/// response messages and where to deliver the final outcome.
struct PendingResponse {
    /// Collection strategy chosen by the submitter.
    collector: ResponseCollector,
    /// Completion channel back to the awaiting caller.
    response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
252
253impl AsyncConn {
    /// Create a new async connection from a raw WireConn.
    /// Spawns writer and reader tasks; whichever exits first clears the
    /// shared `alive` flag, marking the connection dead.
    pub fn new(conn: WireConn) -> Self {
        // Capture cancellation credentials before the stream is consumed.
        let backend_pid = conn.pid;
        let backend_secret = conn.secret;
        // Extract peer address before consuming the stream.
        let addr = conn
            .stream
            .peer_addr()
            .map(|a| a.to_string())
            .unwrap_or_default();

        // Bounded channels: 4096 queued notifications, 256 queued requests.
        let (notification_tx, notification_rx) = mpsc::channel(4096);
        let (request_tx, request_rx) = mpsc::channel::<PipelineRequest>(256);
        // FIFO of in-flight requests shared by the writer and reader tasks;
        // `pending_notify` presumably wakes the reader when a new entry is
        // queued — confirm against writer_task/reader_task.
        let pending: Arc<Mutex<VecDeque<PendingResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
        let pending_notify = Arc::new(tokio::sync::Notify::new());
        let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
        let state_mutated = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let broken = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let dropped_notifications = Arc::new(std::sync::atomic::AtomicU64::new(0));

        // Split the stream so the writer and reader tasks own their halves.
        let (stream_read, stream_write) = tokio::io::split(conn.into_stream());

        // Spawn writer task — sets alive=false on exit.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive = Arc::clone(&alive);
            tokio::spawn(async move {
                writer_task(request_rx, stream_write, pending, pending_notify).await;
                alive.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired writer task exited");
            });
        }

        // Spawn reader task — sets alive=false on exit.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive_clone = Arc::clone(&alive);
            let state_mutated = Arc::clone(&state_mutated);
            let ntf_tx = notification_tx.clone();
            let dropped = Arc::clone(&dropped_notifications);
            tokio::spawn(async move {
                reader_task(
                    stream_read,
                    pending,
                    pending_notify,
                    ntf_tx,
                    state_mutated,
                    dropped,
                )
                .await;
                alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired reader task exited");
            });
        }

        Self {
            request_tx,
            stmt_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            stmt_counter: std::sync::atomic::AtomicU64::new(0),
            alive,
            backend_pid,
            backend_secret,
            addr,
            notification_tx,
            // Receiver parked here until `take_notification_receiver`.
            notification_rx: std::sync::Mutex::new(Some(notification_rx)),
            state_mutated,
            broken,
            dropped_notifications,
        }
    }
327
328    /// Cumulative number of `NotificationResponse` messages this connection
329    /// has discarded since it was created.
330    ///
331    /// Notifications are dropped when (a) the application has not called
332    /// [`AsyncConn::take_notification_receiver`] yet, or (b) the receiver is
333    /// not draining fast enough and the bounded channel fills up. Compare
334    /// successive readings to detect missed `LISTEN` events.
335    pub fn dropped_notifications(&self) -> u64 {
336        self.dropped_notifications
337            .load(std::sync::atomic::Ordering::Relaxed)
338    }
339
340    /// Take the notification receiver. Call once to get a channel that
341    /// receives `NotificationResponse` messages that arrive during queries.
342    pub fn take_notification_receiver(
343        &self,
344    ) -> Option<mpsc::Receiver<crate::protocol::types::BackendMsg>> {
345        self.notification_rx
346            .lock()
347            .ok()
348            .and_then(|mut guard| guard.take())
349    }
350
351    /// Look up or allocate a statement name.
352    ///
353    /// On a cache miss we MUST queue a `Parse` for the new name into the
354    /// writer FIFO **before** publishing the name in the shared cache, or
355    /// another concurrent caller could observe the cached name and submit
356    /// a Bind-only request that races ahead of our Parse. The server would
357    /// then reject the Bind with `26000: prepared statement "sN" does not
358    /// exist`.
359    ///
360    /// The fast path is unchanged: if `sql` is already in the cache,
361    /// return its name and `needs_parse=false`. On a miss we build a
362    /// `Parse + Sync` message and `try_send` it under the cache lock. If
363    /// the channel accepts it, we insert into the cache and return
364    /// `needs_parse=false` (the caller sends only `Bind/Execute`). If the
365    /// writer channel is full, we DO NOT insert: the freshly-allocated
366    /// name is unique (counter was already advanced) and is returned with
367    /// `needs_parse=true` so the caller includes its own `Parse` in the
368    /// same submit. No other lookup can race in because the entry was
369    /// never published.
370    ///
371    /// LRU eviction: when the cache is full, the oldest entry (by
372    /// insertion order / counter) is removed and a Close message is
373    /// queued to free the server-side prepared statement.
374    pub fn lookup_or_alloc(&self, sql: &str, param_oids: &[u32]) -> (Vec<u8>, bool) {
375        let mut cache = match self.stmt_cache.lock() {
376            Ok(c) => c,
377            Err(poisoned) => poisoned.into_inner(),
378        };
379        if let Some((name, _)) = cache.get(sql) {
380            return (name.as_bytes().to_vec(), false);
381        }
382        // LRU eviction: remove the entry with the lowest counter value
383        // and send a Close message to free the server-side prepared statement.
384        if cache.len() >= 256 {
385            if let Some((oldest_key, oldest_name)) = cache
386                .iter()
387                .min_by_key(|(_, (_, counter))| *counter)
388                .map(|(k, (name, _))| (k.clone(), name.clone()))
389            {
390                cache.remove(&oldest_key);
391                // Queue a Close + Sync to free the server-side statement.
392                // Fire-and-forget: if the channel is full or closed, skip it.
393                let mut close_buf = BytesMut::with_capacity(32);
394                frontend::encode_message(
395                    &FrontendMsg::Close {
396                        kind: b'S',
397                        name: oldest_name.as_bytes(),
398                    },
399                    &mut close_buf,
400                );
401                frontend::encode_message(&FrontendMsg::Sync, &mut close_buf);
402                let (tx, _rx) = oneshot::channel();
403                let _ = self.request_tx.try_send(PipelineRequest {
404                    messages: close_buf,
405                    collector: ResponseCollector::Drain,
406                    response_tx: tx,
407                });
408            }
409        }
410        let n = self
411            .stmt_counter
412            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
413        let name = format!("s{n}");
414
415        // Queue a Parse + Sync for the new name BEFORE publishing it in
416        // the cache. This guarantees that any concurrent caller who
417        // subsequently sees `name` in the cache will only submit
418        // Bind/Execute requests AFTER this Parse in the writer's FIFO.
419        let mut parse_buf = BytesMut::with_capacity(32 + sql.len());
420        frontend::encode_message(
421            &FrontendMsg::Parse {
422                name: name.as_bytes(),
423                sql: sql.as_bytes(),
424                param_oids,
425            },
426            &mut parse_buf,
427        );
428        frontend::encode_message(&FrontendMsg::Sync, &mut parse_buf);
429        let (parse_tx, _parse_rx) = oneshot::channel();
430        match self.request_tx.try_send(PipelineRequest {
431            messages: parse_buf,
432            collector: ResponseCollector::Drain,
433            response_tx: parse_tx,
434        }) {
435            Ok(()) => {
436                cache.insert(sql.to_string(), (name.clone(), n));
437                (name.into_bytes(), false)
438            }
439            Err(_) => {
440                // Writer channel full or closed. Don't publish the entry:
441                // the unique name is returned with `needs_parse=true` so
442                // the caller emits Parse atomically with Bind/Execute. No
443                // other lookup can pick up `name` because it was never
444                // inserted. Server-side leak per failed try_send is
445                // bounded by the channel-full rate, which is rare.
446                (name.into_bytes(), true)
447            }
448        }
449    }
450
451    /// Execute COPY FROM STDIN: sends the COPY command, then data in chunks, then CopyDone.
452    /// Returns the number of rows copied (from CommandComplete tag).
453    ///
454    /// Data is sent in chunks of up to 1MB to avoid buffering the entire payload
455    /// in a single BytesMut. For small payloads (< 1MB), this is a single write.
456    pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, PgWireError> {
457        use crate::protocol::types::FrontendMsg;
458        const CHUNK_SIZE: usize = 1024 * 1024; // 1MB chunks
459
460        // Build the message buffer: Query + chunked CopyData + CopyDone.
461        let mut buf = BytesMut::with_capacity(copy_sql.len() + data.len().min(CHUNK_SIZE) + 64);
462        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
463
464        // Send data in chunks to avoid a single huge allocation.
465        for chunk in data.chunks(CHUNK_SIZE) {
466            frontend::encode_message(&FrontendMsg::CopyData(chunk), &mut buf);
467        }
468        // Empty data is valid (0 rows copied).
469        frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
470
471        let resp = self
472            .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
473            .await?;
474        match resp {
475            PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
476            PipelineResponse::Done => Ok(0),
477        }
478    }
479
    /// Execute COPY FROM STDIN reading the payload from an async reader
    /// in 1MB chunks.
    ///
    /// NOTE(review): although input is *read* in chunks, every encoded
    /// CopyData message is currently accumulated into a single buffer and
    /// submitted in one request, so peak memory is still proportional to
    /// the payload size. True streaming to the socket would require
    /// incremental writer-task support.
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// use tokio::fs::File;
    /// let file = File::open("data.csv").await?;
    /// let _count = conn.copy_in_stream("COPY users FROM STDIN WITH (FORMAT csv)", file).await?;
    /// # Ok(()) }
    /// ```
    pub async fn copy_in_stream<R: tokio::io::AsyncRead + Unpin>(
        &self,
        copy_sql: &str,
        mut reader: R,
    ) -> Result<u64, PgWireError> {
        // NOTE: redundant with the module-level import; harmless.
        use tokio::io::AsyncReadExt;
        const CHUNK_SIZE: usize = 1024 * 1024; // 1MB chunks

        // Send the COPY command.
        let mut buf = BytesMut::with_capacity(copy_sql.len() + 16);
        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);

        // Read and encode data in chunks, reusing one scratch buffer.
        let mut chunk = vec![0u8; CHUNK_SIZE];
        loop {
            let n = reader.read(&mut chunk).await?;
            if n == 0 {
                break;
            }
            frontend::encode_message(&FrontendMsg::CopyData(&chunk[..n]), &mut buf);
        }
        frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);

        let resp = self
            .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
            .await?;
        match resp {
            PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
            PipelineResponse::Done => Ok(0),
        }
    }
523
524    /// Execute COPY TO STDOUT: sends the COPY command, collects all CopyData.
525    pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, PgWireError> {
526        use crate::protocol::types::FrontendMsg;
527        let mut buf = BytesMut::new();
528        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
529
530        let resp = self.submit(buf, ResponseCollector::CopyOut).await?;
531        match resp {
532            PipelineResponse::Rows { rows, .. } => {
533                // For CopyOut, we reuse the Rows variant but each `RawRow` carries
534                // one cell which is the raw COPY data chunk (see `collect_copy_out`).
535                let mut result = Vec::new();
536                for row in rows {
537                    for data in row.iter().flatten() {
538                        result.extend_from_slice(data);
539                    }
540                }
541                Ok(result)
542            }
543            PipelineResponse::Done => Ok(Vec::new()),
544        }
545    }
546
547    /// Evict a SQL statement from the cache, forcing re-parse on next use.
548    /// Used for prepared statement invalidation after schema changes.
549    pub fn invalidate_statement(&self, sql: &str) {
550        let mut cache = match self.stmt_cache.lock() {
551            Ok(c) => c,
552            Err(poisoned) => poisoned.into_inner(),
553        };
554        cache.remove(sql);
555    }
556
557    /// Clear the entire statement cache. Must be called after `DISCARD ALL`
558    /// which destroys server-side prepared statements.
559    pub fn clear_statement_cache(&self) {
560        let mut cache = match self.stmt_cache.lock() {
561            Ok(c) => c,
562            Err(poisoned) => poisoned.into_inner(),
563        };
564        cache.clear();
565    }
566
567    /// Execute a pipelined transaction with automatic statement caching.
568    pub async fn exec_transaction(
569        &self,
570        setup_sql: &str,
571        query_sql: &str,
572        params: &[Option<&[u8]>],
573        param_oids: &[u32],
574    ) -> Result<Vec<RawRow>, PgWireError> {
575        let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql, param_oids);
576        self.pipeline_transaction(
577            setup_sql,
578            query_sql,
579            params,
580            param_oids,
581            &stmt_name,
582            needs_parse,
583        )
584        .await
585    }
586
587    /// Execute a parameterized query with automatic statement caching.
588    /// If a cached statement is invalidated by a schema change (PG error 26000
589    /// or 0A000), automatically evicts the cache entry, re-parses, and retries once.
590    pub async fn exec_query(
591        &self,
592        sql: &str,
593        params: &[Option<&[u8]>],
594        param_oids: &[u32],
595    ) -> Result<Vec<RawRow>, PgWireError> {
596        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
597        match self
598            .query(sql, params, param_oids, &stmt_name, needs_parse)
599            .await
600        {
601            Ok(rows) => Ok(rows),
602            Err(PgWireError::Pg(ref pg_err))
603                if !needs_parse && is_stale_statement_error(pg_err) =>
604            {
605                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
606                self.invalidate_statement(sql);
607                let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
608                self.query(sql, params, param_oids, &stmt_name, true).await
609            }
610            Err(e) => Err(e),
611        }
612    }
613
614    /// Maximum time to wait for a response from the reader task.
615    /// Prevents hanging forever if the reader/writer task dies mid-request.
616    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
617
618    /// Submit a request to the connection. Returns a future that resolves
619    /// when the response is available. Times out after 5 minutes to prevent
620    /// hanging forever if the reader/writer task dies.
621    pub async fn submit(
622        &self,
623        messages: BytesMut,
624        collector: ResponseCollector,
625    ) -> Result<PipelineResponse, PgWireError> {
626        let (response_tx, response_rx) = oneshot::channel();
627        let req = PipelineRequest {
628            messages,
629            collector,
630            response_tx,
631        };
632        self.request_tx
633            .send(req)
634            .await
635            .map_err(|_| PgWireError::ConnectionClosed)?;
636        match tokio::time::timeout(Self::REQUEST_TIMEOUT, response_rx).await {
637            Ok(Ok(result)) => result,
638            Ok(Err(_)) => Err(PgWireError::ConnectionClosed),
639            Err(_elapsed) => {
640                tracing::error!(
641                    "request timed out after {:?} — reader/writer task may be dead",
642                    Self::REQUEST_TIMEOUT
643                );
644                Err(PgWireError::ConnectionClosed)
645            }
646        }
647    }
648
649    /// Submit a batch of requests in FIFO order. All requests are queued
650    /// before any response is awaited, so the writer task sees them together
651    /// and coalesces them into a single write() syscall. The server then
652    /// pipelines the N responses back-to-back, giving one network round-trip
653    /// for all N queries.
654    ///
655    /// Returns one `Result<PipelineResponse, PgWireError>` per input item,
656    /// in the same order. The outer `Result` fails only if queueing fails
657    /// (channel closed). Each inner `Result` reflects the per-query outcome.
658    pub async fn submit_batch(
659        &self,
660        items: Vec<(BytesMut, ResponseCollector)>,
661    ) -> Result<Vec<Result<PipelineResponse, PgWireError>>, PgWireError> {
662        let mut receivers = Vec::with_capacity(items.len());
663        for (messages, collector) in items {
664            let (response_tx, response_rx) = oneshot::channel();
665            self.request_tx
666                .send(PipelineRequest {
667                    messages,
668                    collector,
669                    response_tx,
670                })
671                .await
672                .map_err(|_| PgWireError::ConnectionClosed)?;
673            receivers.push(response_rx);
674        }
675        let mut results = Vec::with_capacity(receivers.len());
676        for rx in receivers {
677            match tokio::time::timeout(Self::REQUEST_TIMEOUT, rx).await {
678                Ok(Ok(r)) => results.push(r),
679                Ok(Err(_)) => results.push(Err(PgWireError::ConnectionClosed)),
680                Err(_) => {
681                    tracing::error!(
682                        "submit_batch request timed out after {:?}",
683                        Self::REQUEST_TIMEOUT
684                    );
685                    results.push(Err(PgWireError::ConnectionClosed));
686                }
687            }
688        }
689        Ok(results)
690    }
691
692    /// Send a Terminate message to the server and wait for the writer/reader
693    /// tasks to exit. After this returns, the connection is unusable; further
694    /// calls fail with `ConnectionClosed`. Idempotent: calling `close` on an
695    /// already-closed connection is a no-op and returns `Ok`.
696    pub async fn close(&self) -> Result<(), PgWireError> {
697        if !self.is_alive() {
698            return Ok(());
699        }
700        let mut buf = BytesMut::with_capacity(5);
701        frontend::encode_message(&FrontendMsg::Terminate, &mut buf);
702        // Submit Terminate through the writer so ordering is preserved wrt
703        // any in-flight requests ahead of us. The server replies with nothing
704        // and closes the socket, so we expect `ConnectionClosed` back from
705        // the drain collector — treat that as a successful close.
706        match self.submit(buf, ResponseCollector::Drain).await {
707            Ok(_) | Err(PgWireError::ConnectionClosed) => Ok(()),
708            Err(PgWireError::Io(e)) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
709            Err(e) => Err(e),
710        }
711    }
712
713    /// Submit a streaming request. Returns the column header and an mpsc receiver
714    /// that yields rows one at a time.
715    pub async fn submit_stream(
716        &self,
717        messages: BytesMut,
718        row_buffer: usize,
719    ) -> Result<
720        (
721            StreamHeader,
722            mpsc::Receiver<Result<StreamedRow, PgWireError>>,
723        ),
724        PgWireError,
725    > {
726        let (header_tx, header_rx) = oneshot::channel();
727        let (row_tx, row_rx) = mpsc::channel(row_buffer);
728        let (response_tx, _response_rx) = oneshot::channel();
729        let req = PipelineRequest {
730            messages,
731            collector: ResponseCollector::Stream { header_tx, row_tx },
732            response_tx,
733        };
734        self.request_tx
735            .send(req)
736            .await
737            .map_err(|_| PgWireError::ConnectionClosed)?;
738        let header = header_rx
739            .await
740            .map_err(|_| PgWireError::ConnectionClosed)??;
741        Ok((header, row_rx))
742    }
743
    /// Execute a pipelined transaction:
    /// setup (simple query) + data query (extended protocol) + COMMIT (simple query)
    /// All coalesced into one TCP write. Binary-safe parameterized data query.
    ///
    /// The three pieces are submitted as three separate `PipelineRequest`s
    /// (collectors Drain / Rows / Drain) back-to-back, so the writer task can
    /// coalesce them into a single write() and the reader answers them in FIFO
    /// order — one round trip for the whole transaction.
    ///
    /// NOTE(review): COMMIT is enqueued unconditionally. If the data query
    /// errors, the server is in a failed transaction when COMMIT arrives, which
    /// PostgreSQL treats as a rollback — presumably intentional; confirm.
    pub async fn pipeline_transaction(
        &self,
        setup_sql: &str,
        query_sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        stmt_name: &[u8],
        needs_parse: bool,
    ) -> Result<Vec<RawRow>, PgWireError> {
        let mut buf = BytesMut::with_capacity(1024);

        // 1. Simple query for setup (BEGIN + SET ROLE + set_config).
        frontend::encode_message(&FrontendMsg::Query(setup_sql.as_bytes()), &mut buf);

        // Submit setup as Drain — we don't care about its response data.
        // `split()` takes the bytes written so far, leaving `buf` reusable.
        let setup_msgs = buf.split();

        // 2. Extended query for data. All params are sent in text format;
        // the `max(1)` keeps the vec non-empty, and the `[..params.len()]`
        // slice below trims it back to exactly one code per param.
        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
        let result_fmts = [FormatCode::Text];

        // Parse only when the statement is not already prepared on this
        // connection (caller decides via its statement cache).
        if needs_parse {
            frontend::encode_message(
                &FrontendMsg::Parse {
                    name: stmt_name,
                    sql: query_sql.as_bytes(),
                    param_oids,
                },
                &mut buf,
            );
        }

        // Bind the statement to the unnamed portal ("") with text results.
        frontend::encode_message(
            &FrontendMsg::Bind {
                portal: b"",
                statement: stmt_name,
                param_formats: &text_fmts[..params.len()],
                params,
                result_formats: &result_fmts,
            },
            &mut buf,
        );

        // max_rows = 0 → fetch all rows, no portal suspension.
        frontend::encode_message(
            &FrontendMsg::Execute {
                portal: b"",
                max_rows: 0,
            },
            &mut buf,
        );

        // Sync elicits the ReadyForQuery that terminates this request's response.
        frontend::encode_message(&FrontendMsg::Sync, &mut buf);

        let data_msgs = buf.split();

        // 3. Simple query for COMMIT — in its own buffer so each request
        // carries exactly the bytes that produce its ReadyForQuery response.
        let mut commit_buf = BytesMut::with_capacity(32);
        frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut commit_buf);

        // Submit all three as separate requests with different collectors.
        // They'll be coalesced by the writer into one write() syscall.
        let (setup_tx, setup_rx) = oneshot::channel();
        let (data_tx, data_rx) = oneshot::channel();
        let (commit_tx, commit_rx) = oneshot::channel();

        // Send all three requests to the writer channel.
        // The writer drains the channel and writes them all at once.
        self.request_tx
            .send(PipelineRequest {
                messages: setup_msgs,
                collector: ResponseCollector::Drain,
                response_tx: setup_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: data_msgs,
                collector: ResponseCollector::Rows,
                response_tx: data_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: commit_buf,
                collector: ResponseCollector::Drain,
                response_tx: commit_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        // Wait for all responses. A setup/data failure returns early here;
        // the remaining oneshot receivers are simply dropped.
        setup_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        let data_resp = data_rx.await.map_err(|_| PgWireError::ConnectionClosed)??;

        commit_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        match data_resp {
            PipelineResponse::Rows { rows, .. } => Ok(rows),
            // A Rows collector never yields Done today; treat it as "no rows".
            PipelineResponse::Done => Ok(Vec::new()),
        }
    }
858
859    /// Execute a simple parameterized query (no transaction).
860    pub async fn query(
861        &self,
862        sql: &str,
863        params: &[Option<&[u8]>],
864        param_oids: &[u32],
865        stmt_name: &[u8],
866        needs_parse: bool,
867    ) -> Result<Vec<RawRow>, PgWireError> {
868        self.query_with_formats(sql, params, param_oids, &[], &[], stmt_name, needs_parse)
869            .await
870    }
871
872    /// Execute a parameterized query with explicit per-param and per-result
873    /// format codes (text = 0, binary = 1).
874    ///
875    /// `param_formats` is interpreted per PostgreSQL wire protocol rules:
876    /// - empty: all params are text
877    /// - length 1: the single code applies to every param
878    /// - length N (== params.len()): one code per param
879    ///
880    /// Same rules apply to `result_formats` for output columns (empty → all
881    /// text; single code → applies to all columns; per-column list otherwise).
882    #[allow(clippy::too_many_arguments)]
883    pub async fn query_with_formats(
884        &self,
885        sql: &str,
886        params: &[Option<&[u8]>],
887        param_oids: &[u32],
888        param_formats: &[FormatCode],
889        result_formats: &[FormatCode],
890        stmt_name: &[u8],
891        needs_parse: bool,
892    ) -> Result<Vec<RawRow>, PgWireError> {
893        let mut buf = BytesMut::with_capacity(512);
894
895        // Default to all-text if caller passes empty slices.
896        let text_param_fmts: Vec<FormatCode>;
897        let param_fmts_slice: &[FormatCode] = if param_formats.is_empty() {
898            text_param_fmts = vec![FormatCode::Text; params.len().max(1)];
899            &text_param_fmts[..params.len()]
900        } else {
901            param_formats
902        };
903        let default_result_fmts = [FormatCode::Text];
904        let result_fmts_slice: &[FormatCode] = if result_formats.is_empty() {
905            &default_result_fmts
906        } else {
907            result_formats
908        };
909
910        if needs_parse {
911            frontend::encode_message(
912                &FrontendMsg::Parse {
913                    name: stmt_name,
914                    sql: sql.as_bytes(),
915                    param_oids,
916                },
917                &mut buf,
918            );
919        }
920
921        frontend::encode_message(
922            &FrontendMsg::Bind {
923                portal: b"",
924                statement: stmt_name,
925                param_formats: param_fmts_slice,
926                params,
927                result_formats: result_fmts_slice,
928            },
929            &mut buf,
930        );
931
932        frontend::encode_message(
933            &FrontendMsg::Execute {
934                portal: b"",
935                max_rows: 0,
936            },
937            &mut buf,
938        );
939
940        frontend::encode_message(&FrontendMsg::Sync, &mut buf);
941
942        let resp = self.submit(buf, ResponseCollector::Rows).await?;
943        match resp {
944            PipelineResponse::Rows { rows, .. } => Ok(rows),
945            PipelineResponse::Done => Ok(Vec::new()),
946        }
947    }
948
    /// Variant of `exec_query` with per-param and per-result format codes.
    /// See `query_with_formats` for format code semantics.
    ///
    /// Looks up (or allocates) a cached prepared-statement name for `sql` and
    /// transparently retries exactly once when the server reports the cached
    /// statement as stale (codes 26000 / 0A000), forcing a fresh Parse.
    pub async fn exec_query_with_formats(
        &self,
        sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        param_formats: &[FormatCode],
        result_formats: &[FormatCode],
    ) -> Result<Vec<RawRow>, PgWireError> {
        // `needs_parse` is false when the cache says this sql is already
        // prepared under `stmt_name` on this connection.
        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
        match self
            .query_with_formats(
                sql,
                params,
                param_oids,
                param_formats,
                result_formats,
                &stmt_name,
                needs_parse,
            )
            .await
        {
            Ok(rows) => Ok(rows),
            // Retry only when we *skipped* the Parse (trusted the cache) and
            // the error says the statement is gone/invalidated — an error on a
            // fresh Parse would mean something else entirely.
            Err(PgWireError::Pg(ref pg_err))
                if !needs_parse && is_stale_statement_error(pg_err) =>
            {
                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
                self.invalidate_statement(sql);
                let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
                // Force a Parse this time; a second stale error propagates as-is.
                self.query_with_formats(
                    sql,
                    params,
                    param_oids,
                    param_formats,
                    result_formats,
                    &stmt_name,
                    true,
                )
                .await
            }
            Err(e) => Err(e),
        }
    }
993}
994
995// ---------------------------------------------------------------------------
996// Writer task
997// ---------------------------------------------------------------------------
998
/// Writer half of the split connection.
///
/// Pulls `PipelineRequest`s off `rx`, coalesces every immediately-available
/// request into one buffer, performs ONE write()+flush, and only then enqueues
/// the matching `PendingResponse`s for the reader. The write-before-enqueue
/// order is the correctness invariant: a failed write reports the error
/// directly to the batch's callers instead of leaving the reader waiting for
/// responses that will never arrive.
async fn writer_task(
    mut rx: mpsc::Receiver<PipelineRequest>,
    mut stream: tokio::io::WriteHalf<crate::tls::MaybeTlsStream>,
    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
    pending_notify: Arc<tokio::sync::Notify>,
) {
    // Reused across iterations to avoid reallocating per batch.
    let mut write_buf = BytesMut::with_capacity(8192);

    loop {
        // Wait for the first request.
        let first = match rx.recv().await {
            Some(req) => req,
            None => {
                // Channel closed — drain any pending responses with ConnectionClosed.
                drain_pending_on_exit(&pending).await;
                return;
            }
        };

        // Drain any additional queued requests (batch coalescing).
        write_buf.clear();
        write_buf.extend_from_slice(&first.messages);

        // Collectors/response channels for everything in this batch, in the
        // same order their bytes appear in `write_buf` — FIFO with the server.
        let mut batch: Vec<PendingResponse> = vec![PendingResponse {
            collector: first.collector,
            response_tx: first.response_tx,
        }];

        // Non-blocking drain of all queued requests.
        while let Ok(req) = rx.try_recv() {
            write_buf.extend_from_slice(&req.messages);
            batch.push(PendingResponse {
                collector: req.collector,
                response_tx: req.response_tx,
            });
        }

        // ONE write() syscall for all coalesced messages.
        // Write BEFORE enqueuing pending responses — if the write fails,
        // we send errors to callers instead of leaving them hanging.
        let write_result = stream.write_all(&write_buf).await;
        let write_err = match write_result {
            Ok(_) => stream.flush().await.err(),
            Err(e) => Some(e),
        };

        if let Some(e) = write_err {
            tracing::error!("Writer error: {e}");
            let msg = e.to_string();
            // io::Error isn't Clone; rebuild a BrokenPipe error per caller
            // (NOTE(review): the original kind is discarded — close() relies
            // on seeing BrokenPipe here, so keep it that way).
            for p in batch {
                let _ = p.response_tx.send(Err(PgWireError::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    msg.clone(),
                ))));
            }
            // Drain any already-pending responses so the reader doesn't hang.
            drain_pending_on_exit(&pending).await;
            return;
        }

        // Write succeeded — enqueue pending responses for the reader.
        {
            let mut pq = pending.lock().await;
            for p in batch {
                pq.push_back(p);
            }
        }
        // Wake the reader task to process the newly enqueued responses.
        pending_notify.notify_one();
    }
}
1070
1071/// On writer exit, drain all pending responses with ConnectionClosed errors
1072/// so callers don't wait for the 5-minute timeout.
1073async fn drain_pending_on_exit(pending: &Arc<Mutex<VecDeque<PendingResponse>>>) {
1074    let mut pq = pending.lock().await;
1075    while let Some(pr) = pq.pop_front() {
1076        let _ = pr.response_tx.send(Err(PgWireError::ConnectionClosed));
1077    }
1078}
1079
1080// ---------------------------------------------------------------------------
1081// Reader task
1082// ---------------------------------------------------------------------------
1083
/// Reader half of the split connection.
///
/// Pops pending responses in FIFO order (the order the writer flushed their
/// request bytes), parses the server's reply stream according to each
/// response's collector, and delivers the result on its oneshot channel.
/// Sleeps on `pending_notify` when the pending queue is empty.
async fn reader_task(
    mut stream: tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
    pending_notify: Arc<tokio::sync::Notify>,
    notification_tx: mpsc::Sender<BackendMsg>,
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
) {
    // Read buffer reused for the lifetime of the connection.
    let mut recv_buf = BytesMut::with_capacity(32 * 1024);

    loop {
        // Wait for a pending response to become available.
        // Check-then-wait is race-free: Notify stores a permit if the writer
        // calls notify_one between our lock release and `notified().await`.
        let pr = loop {
            {
                let mut pq = pending.lock().await;
                if let Some(pr) = pq.pop_front() {
                    break pr;
                }
            }
            // No pending — wait for the writer to signal.
            pending_notify.notified().await;
        };

        // Collect the response based on the collector type. Each helper reads
        // exactly up to (and including) one ReadyForQuery, keeping the stream
        // aligned with the next pending response.
        let result = match pr.collector {
            ResponseCollector::Rows => {
                collect_rows(
                    &mut stream,
                    &mut recv_buf,
                    &notification_tx,
                    &state_mutated,
                    &dropped_notifications,
                )
                .await
            }
            ResponseCollector::Drain => {
                drain_until_ready(&mut stream, &mut recv_buf, Some(&state_mutated))
                    .await
                    .map(|_| PipelineResponse::Done)
            }
            ResponseCollector::Stream { header_tx, row_tx } => {
                // Streaming delivers rows via its own channels; errors are
                // reported on those channels, so this arm always yields Done.
                stream_rows(
                    &mut stream,
                    &mut recv_buf,
                    header_tx,
                    row_tx,
                    &notification_tx,
                    &state_mutated,
                    &dropped_notifications,
                )
                .await;
                Ok(PipelineResponse::Done)
            }
            ResponseCollector::CopyIn { .. } => {
                collect_copy_in_response(&mut stream, &mut recv_buf, &state_mutated).await
            }
            ResponseCollector::CopyOut => {
                collect_copy_out(&mut stream, &mut recv_buf, &state_mutated).await
            }
        };

        // Send the response back to the caller. A closed receiver just means
        // the caller stopped waiting (e.g. timed out) — ignore it.
        let _ = pr.response_tx.send(result);
    }
}
1149
1150async fn read_msg(
1151    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1152    buf: &mut BytesMut,
1153) -> Result<BackendMsg, PgWireError> {
1154    loop {
1155        if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1156            return Ok(msg);
1157        }
1158        let n = stream.read_buf(buf).await?;
1159        if n == 0 {
1160            // EOF — try to parse any remaining data in the buffer before giving up.
1161            // This handles the case where the last message arrived just before the
1162            // connection closed and is already fully buffered.
1163            if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1164                return Ok(msg);
1165            }
1166            return Err(PgWireError::ConnectionClosed);
1167        }
1168    }
1169}
1170
/// If the ReadyForQuery status byte is anything other than `I` (idle),
/// flag the connection as state-mutated. `T` (in transaction) and `E`
/// (failed transaction) both leave session state that needs DISCARD ALL.
fn note_rfq_status(status: u8, state_mutated: &std::sync::atomic::AtomicBool) {
    let idle = status == b'I';
    if !idle {
        // Publish with Release so observers loading with Acquire see the flag.
        state_mutated.store(true, std::sync::atomic::Ordering::Release);
    }
}
1179
/// Collect a full row-returning response: accumulates RowDescription,
/// DataRows, and the CommandComplete tag until ReadyForQuery terminates the
/// response. On ErrorResponse, drains to ReadyForQuery (keeping the stream
/// aligned for the next pipelined response) and returns the error.
async fn collect_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) -> Result<PipelineResponse, PgWireError> {
    let mut rows = Vec::new();
    let mut fields = Vec::new();
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            BackendMsg::DataRow(row) => rows.push(row),
            BackendMsg::RowDescription { fields: f } => fields = f,
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            // ReadyForQuery ends the response; its status byte tells us
            // whether the session still carries transaction state.
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields,
                    rows,
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            msg @ BackendMsg::NotificationResponse { .. } => {
                // Forward notification instead of dropping. try_send keeps
                // the reader from blocking on a slow notification consumer.
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // Protocol bookkeeping messages with nothing to collect.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1225
1226async fn drain_until_ready(
1227    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1228    buf: &mut BytesMut,
1229    state_mutated: Option<&std::sync::atomic::AtomicBool>,
1230) -> Result<(), PgWireError> {
1231    loop {
1232        let msg = read_msg(stream, buf).await?;
1233        if let BackendMsg::ReadyForQuery { status } = msg {
1234            if let Some(sm) = state_mutated {
1235                note_rfq_status(status, sm);
1236            }
1237            return Ok(());
1238        }
1239        if let BackendMsg::ErrorResponse { ref fields } = msg {
1240            tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
1241        }
1242    }
1243}
1244
/// Stream rows one at a time, sending header first, then individual rows.
///
/// The header is sent lazily: on the first DataRow, or — for queries that
/// return no rows — at CommandComplete / ReadyForQuery. If the row receiver
/// is dropped mid-stream, the rest of the response is drained so the stream
/// stays aligned for the next pipelined request. Errors are reported on the
/// header channel if it is still open, otherwise on the row channel.
async fn stream_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
    row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) {
    // Some(..) until the header has been sent; take() enforces exactly-once.
    let mut header_tx = Some(header_tx);
    let mut fields = Vec::new();
    loop {
        let msg = match read_msg(stream, buf).await {
            Ok(msg) => msg,
            Err(e) => {
                // Read failure: report on whichever channel is still live.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(e));
                } else {
                    let _ = row_tx.send(Err(e)).await;
                }
                return;
            }
        };
        match msg {
            BackendMsg::RowDescription { fields: f } => {
                fields = f;
            }
            BackendMsg::DataRow(row) => {
                // First row triggers the one-time header send.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: fields.clone(),
                    }));
                }
                if row_tx.send(Ok(row)).await.is_err() {
                    // Consumer gone — drain the rest to keep stream alignment.
                    let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                    return;
                }
            }
            BackendMsg::CommandComplete { .. } => {
                // Zero-row query: the header still has to go out.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
            }
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                // Defensive: guarantee the header is sent before returning.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
                return;
            }
            BackendMsg::ErrorResponse { fields: err } => {
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(PgWireError::Pg(err)));
                } else {
                    let _ = row_tx.send(Err(PgWireError::Pg(err))).await;
                }
                let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                return;
            }
            msg @ BackendMsg::NotificationResponse { .. } => {
                // Non-blocking forward; count drops instead of stalling reads.
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // Protocol bookkeeping with nothing to stream.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::PortalSuspended
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1326
/// Handle COPY IN response: skip CopyInResponse, wait for CommandComplete + ReadyForQuery.
/// The actual CopyData + CopyDone were pre-buffered in the write, so PG processes them.
///
/// Returns a `Rows` response with empty fields/rows; only `command_tag`
/// (e.g. "COPY 123") carries information for the caller.
async fn collect_copy_in_response(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    state_mutated: &std::sync::atomic::AtomicBool,
) -> Result<PipelineResponse, PgWireError> {
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            // Acknowledgement that the server entered COPY mode; nothing to do.
            BackendMsg::CopyInResponse { .. } => {}
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields: Vec::new(),
                    rows: Vec::new(),
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                // Re-align the stream on the terminating ReadyForQuery before
                // surfacing the error.
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            _ => {}
        }
    }
}
1356
1357/// Collect COPY OUT data: CopyOutResponse → CopyData* → CopyDone → CommandComplete → ReadyForQuery.
1358async fn collect_copy_out(
1359    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1360    buf: &mut BytesMut,
1361    state_mutated: &std::sync::atomic::AtomicBool,
1362) -> Result<PipelineResponse, PgWireError> {
1363    let mut data_chunks: Vec<RawRow> = Vec::new();
1364    let mut command_tag = String::new();
1365    loop {
1366        let msg = read_msg(stream, buf).await?;
1367        match msg {
1368            BackendMsg::CopyOutResponse { .. } => {}
1369            BackendMsg::CopyData { data } => {
1370                let body = bytes::Bytes::from(data);
1371                data_chunks.push(RawRow::from_full_body(body));
1372            }
1373            BackendMsg::CopyDone => {}
1374            BackendMsg::CommandComplete { tag } => command_tag = tag,
1375            BackendMsg::ReadyForQuery { status } => {
1376                note_rfq_status(status, state_mutated);
1377                return Ok(PipelineResponse::Rows {
1378                    fields: Vec::new(),
1379                    rows: data_chunks,
1380                    command_tag,
1381                });
1382            }
1383            BackendMsg::ErrorResponse { fields } => {
1384                drain_until_ready(stream, buf, Some(state_mutated)).await?;
1385                return Err(PgWireError::Pg(fields));
1386            }
1387            _ => {}
1388        }
1389    }
1390}
1391
1392/// Check if a PostgreSQL error indicates a stale/invalidated prepared statement.
1393/// Error codes: 26000 (invalid_sql_statement_name), 0A000 (feature_not_supported
1394/// — used when cached plan changes type).
1395fn is_stale_statement_error(err: &crate::protocol::types::PgError) -> bool {
1396    matches!(err.code.as_str(), "26000" | "0A000")
1397}
1398
/// Extract the affected-row count from a COPY command tag.
/// Tag format is "COPY 123"; any other shape (wrong prefix, unparsable
/// count) yields 0.
fn parse_copy_count(tag: &str) -> u64 {
    match tag.strip_prefix("COPY ") {
        Some(count) => count.parse::<u64>().unwrap_or(0),
        None => 0,
    }
}
1405
// Extension to WireConn to extract the underlying stream.
impl WireConn {
    /// Consume the connection wrapper and return the inner
    /// (possibly TLS-wrapped) stream — used when handing the raw stream to
    /// the split reader/writer tasks.
    pub(crate) fn into_stream(self) -> crate::tls::MaybeTlsStream {
        self.stream
    }
}
1412
#[cfg(test)]
mod tests {
    use super::*;

    // These tests exercise `try_enqueue_rollback`, defined elsewhere in this
    // file, through a bare mpsc channel — no live connection is needed.

    /// Channel-full branch: when the request channel has no spare capacity,
    /// `try_enqueue_rollback` returns `false` instead of blocking.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_full() {
        let (tx, _rx) = mpsc::channel::<PipelineRequest>(2);
        // Fill the channel by reusing the same helper. capacity=2 plus the
        // single buffered slot tokio reserves means we may need to push
        // until try_send fails; loop until we observe the false return.
        let mut filled = false;
        for _ in 0..16 {
            if !try_enqueue_rollback(&tx) {
                filled = true;
                break;
            }
        }
        assert!(
            filled,
            "expected try_enqueue_rollback to eventually return false on a full channel"
        );
        assert!(
            !try_enqueue_rollback(&tx),
            "subsequent calls on a full channel must keep returning false"
        );
    }

    /// Channel-closed branch: dropping the receiver makes `try_send` fail
    /// with `Closed`, which `try_enqueue_rollback` reports as `false`.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_closed() {
        let (tx, rx) = mpsc::channel::<PipelineRequest>(8);
        drop(rx);
        assert!(
            !try_enqueue_rollback(&tx),
            "try_enqueue_rollback must return false when the receiver has been dropped"
        );
    }

    /// Happy path: with a live receiver and free capacity, the helper
    /// reports success and the receiver observes a queued request whose
    /// payload starts with the simple-query opcode `'Q'`.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_true_and_enqueues_query() {
        let (tx, mut rx) = mpsc::channel::<PipelineRequest>(2);
        assert!(try_enqueue_rollback(&tx));
        let req = rx.recv().await.expect("request should be received");
        assert_eq!(
            req.messages.first().copied(),
            Some(b'Q'),
            "queued request should be a simple Query message"
        );
        // Body should mention ROLLBACK (text follows length prefix and is
        // null-terminated; just substring-search to keep the test simple).
        assert!(
            req.messages.windows(8).any(|w| w == b"ROLLBACK"),
            "queued request should contain the ROLLBACK statement text"
        );
    }
}