pg_wired/async_conn.rs

1//! Async split sender/receiver connection.
2//! Inspired by hsqlx's PgWire.Async architecture.
3//!
4//! A single TCP connection is shared by many concurrent handler tasks.
5//! The writer task coalesces messages from multiple requests into one write().
6//! The reader task parses responses and dispatches them to waiting handlers via FIFO.
7
8use std::collections::VecDeque;
9use std::sync::Arc;
10
11use bytes::BytesMut;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::sync::{mpsc, oneshot, Mutex};
14
15use crate::connection::WireConn;
16use crate::error::PgWireError;
17use crate::protocol::backend;
18use crate::protocol::frontend;
19use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
20
21// ---------------------------------------------------------------------------
22// Request types
23// ---------------------------------------------------------------------------
24
25/// A request to execute on the connection. Internal plumbing between the
26/// public `submit` / `submit_batch` API and the writer task.
27pub(crate) struct PipelineRequest {
28    pub(crate) messages: BytesMut,
29    pub(crate) collector: ResponseCollector,
30    pub(crate) response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
31}
32
33/// How to collect response messages for a request.
34#[allow(dead_code)]
35#[non_exhaustive]
36pub enum ResponseCollector {
37    /// Collect DataRows until ReadyForQuery (for SELECT queries).
38    Rows,
39    /// Just drain until ReadyForQuery (for setup commands like BEGIN, SET ROLE).
40    Drain,
41    /// Stream rows one at a time via channels. Sends header first, then individual rows.
42    Stream {
43        /// One-shot channel for the row description (sent once before any rows).
44        header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
45        /// Bounded channel for individual rows; closed on completion or error.
46        row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
47    },
    /// COPY IN: the CopyData frames and trailing CopyDone are pre-encoded into
    /// the request buffer by the caller; the reader skips CopyInResponse and
    /// waits for CommandComplete + ReadyForQuery (see `collect_copy_in_response`).
    CopyIn {
        /// Currently unused: callers pre-encode the COPY payload into the
        /// request buffer and pass an empty Vec here.
        data: Vec<u8>,
    },
53    /// COPY OUT: collect CopyData messages until CopyDone.
54    CopyOut,
55}
56
57/// Response from a pipeline request.
58#[non_exhaustive]
59pub enum PipelineResponse {
60    /// A query that produced a row set (`SELECT`, `RETURNING`, etc.).
61    Rows {
62        /// Column metadata from RowDescription (empty if no RowDescription received).
63        fields: Vec<crate::protocol::types::FieldDescription>,
64        /// Row data.
65        rows: Vec<RawRow>,
66        /// CommandComplete tag (e.g. "SELECT 3", "INSERT 0 1").
67        command_tag: String,
68    },
69    /// A statement that produced no row set (e.g., `BEGIN`, `SET ROLE`,
70    /// non-RETURNING DML).
71    Done,
72}
73
74/// Metadata sent at the start of a streaming response.
75#[derive(Debug, Clone)]
76pub struct StreamHeader {
77    /// Column descriptions (name, OID, format) for the streamed result set.
78    pub fields: Vec<crate::protocol::types::FieldDescription>,
79}
80
81/// A single streamed row.
82pub type StreamedRow = RawRow;
83
84// ---------------------------------------------------------------------------
85// Async connection
86// ---------------------------------------------------------------------------
87
88/// A shared async connection that multiplexes requests from many tasks.
89pub struct AsyncConn {
90    request_tx: mpsc::Sender<PipelineRequest>,
91    stmt_cache: std::sync::Mutex<std::collections::HashMap<String, (String, u64)>>,
92    stmt_counter: std::sync::atomic::AtomicU64,
93    alive: Arc<std::sync::atomic::AtomicBool>,
94    backend_pid: i32,
95    backend_secret: i32,
96    addr: String,
    /// Channel for async notifications received during query execution.
    /// Notifications are forwarded here rather than discarded; if the channel
    /// is full they are dropped and counted in `dropped_notifications`.
99    #[allow(dead_code)]
100    notification_tx: mpsc::Sender<crate::protocol::types::BackendMsg>,
101    notification_rx: std::sync::Mutex<Option<mpsc::Receiver<crate::protocol::types::BackendMsg>>>,
102    /// True if any operation since the last `take_state_mutated()` may have
103    /// left the session in a non-default state (open transaction, SET
104    /// without LOCAL, advisory lock, temp table, prepared cursor, etc.).
105    ///
106    /// Set explicitly by callers issuing such operations
107    /// (`mark_state_mutated`), and automatically by the reader task whenever
108    /// ReadyForQuery reports a non-idle transaction status. Callers that
109    /// only run self-contained Bind/Execute/Sync queries leave this `false`,
110    /// allowing pools to skip an expensive DISCARD ALL on return.
111    state_mutated: Arc<std::sync::atomic::AtomicBool>,
112    /// True if a caller has declared the connection unusable (e.g., a
113    /// transaction was dropped without commit/rollback, leaving the session
114    /// in an unknown state). The reader/writer tasks may still be running, so
115    /// `is_alive()` is true, but pools should treat the connection as broken
116    /// and destroy it on return rather than reusing it.
117    broken: Arc<std::sync::atomic::AtomicBool>,
118    /// Cumulative count of asynchronous notifications dropped because the
119    /// notification channel was full or no application code was draining it.
120    /// Surfaced via [`AsyncConn::dropped_notifications`] so callers can detect
121    /// missed `LISTEN` events.
122    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
123}
124
125impl std::fmt::Debug for AsyncConn {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.debug_struct("AsyncConn")
128            .field("addr", &self.addr)
129            .field("backend_pid", &self.backend_pid)
130            .field("alive", &self.is_alive())
131            .finish()
132    }
133}
134
135impl AsyncConn {
136    /// Check if the connection is still alive (writer/reader tasks running).
137    pub fn is_alive(&self) -> bool {
138        self.alive.load(std::sync::atomic::Ordering::Relaxed)
139    }
140
141    /// Backend process ID assigned by the server.
142    pub fn backend_pid(&self) -> i32 {
143        self.backend_pid
144    }
145
146    /// Server address this connection is talking to.
147    pub fn addr(&self) -> &str {
148        &self.addr
149    }
150
151    /// Produce a cancel token for the running session on this connection.
152    pub fn cancel_token(&self) -> crate::cancel::CancelToken {
153        crate::cancel::CancelToken::new(self.addr.clone(), self.backend_pid, self.backend_secret)
154    }
155
156    /// Mark the connection as having mutated session state since the last
157    /// reset. Pools call `take_state_mutated()` on return to decide whether
158    /// to issue `DISCARD ALL`. Callers issuing `BEGIN`, `SET` (without
159    /// `LOCAL`), advisory locks, temp tables, etc., should call this before
160    /// submitting.
161    pub fn mark_state_mutated(&self) {
162        self.state_mutated
163            .store(true, std::sync::atomic::Ordering::Release);
164    }
165
166    /// Atomically read and clear the state-mutated flag. Returns the
167    /// previous value: `true` means the caller should issue a reset.
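    ///
    /// A minimal sketch of how a pool might consult the flag on check-in; the
    /// `DISCARD ALL` reset shown here is illustrative, not a prescribed policy:
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// if conn.take_state_mutated() {
    ///     // Session state may have leaked (SET without LOCAL, temp table, ...).
    ///     conn.exec_query("DISCARD ALL", &[], &[]).await?;
    ///     // DISCARD ALL destroys server-side prepared statements, so the
    ///     // local statement cache must be cleared too.
    ///     conn.clear_statement_cache();
    /// }
    /// # Ok(()) }
    /// ```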
168    pub fn take_state_mutated(&self) -> bool {
169        self.state_mutated
170            .swap(false, std::sync::atomic::Ordering::AcqRel)
171    }
172
173    /// Read the state-mutated flag without clearing it.
174    pub fn is_state_mutated(&self) -> bool {
175        self.state_mutated
176            .load(std::sync::atomic::Ordering::Acquire)
177    }
178
179    /// Mark the connection as broken. The reader/writer tasks may still be
180    /// running, but the session is in an indeterminate state (for example,
181    /// a transaction was dropped without commit or rollback) and the
182    /// connection must not be reused. Pool integrations check
183    /// [`AsyncConn::is_broken`] on return and destroy the connection
184    /// instead of returning it to the idle set.
185    pub fn mark_broken(&self) {
186        self.broken
187            .store(true, std::sync::atomic::Ordering::Release);
188    }
189
190    /// True if the connection has been declared broken by a caller via
191    /// [`AsyncConn::mark_broken`]. Independent of [`AsyncConn::is_alive`],
192    /// which only reflects whether the reader/writer tasks are still running.
193    pub fn is_broken(&self) -> bool {
194        self.broken.load(std::sync::atomic::Ordering::Acquire)
195    }
196
197    /// Test-only helper that flips the `alive` flag to `false` without
198    /// actually exiting the writer task. Used by pg-wired's own tests and
199    /// by downstream crates' integration tests (e.g. resolute) to exercise
200    /// the dead-conn branch of [`AsyncConn::enqueue_rollback`] (and any
201    /// other code that gates on `is_alive`) without racing against the
202    /// real task-exit timing. Not part of the stable API: the `__` prefix
203    /// and `#[doc(hidden)]` mark this as off-limits for production use.
204    #[doc(hidden)]
205    pub fn __force_mark_dead_for_test(&self) {
206        self.alive
207            .store(false, std::sync::atomic::Ordering::Release);
208    }
209
210    /// Fire-and-forget enqueue of a `ROLLBACK` simple-query, intended to be
211    /// callable from a synchronous `Drop`. Returns `true` if the request was
212    /// queued on the writer task, `false` if the connection is not alive or
213    /// the channel was full/closed (in which case the caller should fall
214    /// back to [`AsyncConn::mark_broken`] so the connection is discarded
215    /// by the pool).
216    ///
217    /// PostgreSQL accepts `ROLLBACK` from any in-transaction state — including
218    /// the aborted state (`25P02`) that a failed query leaves behind — so this
219    /// reliably restores the session to idle. The response is drained and
220    /// discarded; ordering on the writer queue is preserved, so any
221    /// subsequent request (e.g., the pool's `DISCARD ALL` reset) sees a clean
222    /// connection.
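    ///
    /// A sketch of the intended use from a synchronous `Drop`; the guard type
    /// below is illustrative, not part of this crate:
    ///
    /// ```no_run
    /// use std::sync::Arc;
    ///
    /// struct TxGuard {
    ///     conn: Arc<pg_wired::AsyncConn>,
    ///     committed: bool,
    /// }
    ///
    /// impl Drop for TxGuard {
    ///     fn drop(&mut self) {
    ///         if !self.committed && !self.conn.enqueue_rollback() {
    ///             // Could not queue the ROLLBACK: make sure the pool destroys
    ///             // this connection instead of reusing it.
    ///             self.conn.mark_broken();
    ///         }
    ///     }
    /// }
    /// ```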
223    pub fn enqueue_rollback(&self) -> bool {
224        if !self.is_alive() {
225            return false;
226        }
227        try_enqueue_rollback(&self.request_tx)
228    }
229}
230
231/// Inner helper for [`AsyncConn::enqueue_rollback`]: encodes a `ROLLBACK`
232/// simple-query and tries to push it onto the writer's request channel.
233/// Extracted so the channel-full and channel-closed branches can be unit
234/// tested without instantiating a real `AsyncConn`.
235fn try_enqueue_rollback(request_tx: &mpsc::Sender<PipelineRequest>) -> bool {
236    let mut buf = BytesMut::with_capacity(16);
237    frontend::encode_message(&FrontendMsg::Query(b"ROLLBACK"), &mut buf);
238    let (tx, _rx) = oneshot::channel();
239    request_tx
240        .try_send(PipelineRequest {
241            messages: buf,
242            collector: ResponseCollector::Drain,
243            response_tx: tx,
244        })
245        .is_ok()
246}
247
248struct PendingResponse {
249    collector: ResponseCollector,
250    response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
251}
252
253impl AsyncConn {
254    /// Create a new async connection from a raw WireConn.
255    /// Spawns writer and reader tasks.
256    pub fn new(conn: WireConn) -> Self {
257        let backend_pid = conn.pid;
258        let backend_secret = conn.secret;
259        // Extract peer address before consuming the stream.
260        let addr = conn
261            .stream
262            .peer_addr()
263            .map(|a| a.to_string())
264            .unwrap_or_default();
265
266        let (notification_tx, notification_rx) = mpsc::channel(4096);
267        let (request_tx, request_rx) = mpsc::channel::<PipelineRequest>(256);
268        let pending: Arc<Mutex<VecDeque<PendingResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
269        let pending_notify = Arc::new(tokio::sync::Notify::new());
270        let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
271        let state_mutated = Arc::new(std::sync::atomic::AtomicBool::new(false));
272        let broken = Arc::new(std::sync::atomic::AtomicBool::new(false));
273        let dropped_notifications = Arc::new(std::sync::atomic::AtomicU64::new(0));
274
275        let (stream_read, stream_write) = tokio::io::split(conn.into_stream());
276
277        // Spawn writer task — sets alive=false on exit.
278        {
279            let pending = Arc::clone(&pending);
280            let pending_notify = Arc::clone(&pending_notify);
281            let alive = Arc::clone(&alive);
282            tokio::spawn(async move {
283                writer_task(request_rx, stream_write, pending, pending_notify).await;
284                alive.store(false, std::sync::atomic::Ordering::Relaxed);
285                tracing::warn!("pg-wired writer task exited");
286            });
287        }
288
289        // Spawn reader task — sets alive=false on exit.
290        {
291            let pending = Arc::clone(&pending);
292            let pending_notify = Arc::clone(&pending_notify);
293            let alive_clone = Arc::clone(&alive);
294            let state_mutated = Arc::clone(&state_mutated);
295            let ntf_tx = notification_tx.clone();
296            let dropped = Arc::clone(&dropped_notifications);
297            tokio::spawn(async move {
298                reader_task(
299                    stream_read,
300                    pending,
301                    pending_notify,
302                    ntf_tx,
303                    state_mutated,
304                    dropped,
305                )
306                .await;
307                alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
308                tracing::warn!("pg-wired reader task exited");
309            });
310        }
311
312        Self {
313            request_tx,
314            stmt_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
315            stmt_counter: std::sync::atomic::AtomicU64::new(0),
316            alive,
317            backend_pid,
318            backend_secret,
319            addr,
320            notification_tx,
321            notification_rx: std::sync::Mutex::new(Some(notification_rx)),
322            state_mutated,
323            broken,
324            dropped_notifications,
325        }
326    }
327
328    /// Cumulative number of `NotificationResponse` messages this connection
329    /// has discarded since it was created.
330    ///
331    /// Notifications are dropped when (a) the application has not called
332    /// [`AsyncConn::take_notification_receiver`] yet, or (b) the receiver is
333    /// not draining fast enough and the bounded channel fills up. Compare
334    /// successive readings to detect missed `LISTEN` events.
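    ///
    /// A small sketch of the polling pattern:
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// let before = conn.dropped_notifications();
    /// // ... run queries and drain the notification receiver ...
    /// let lost = conn.dropped_notifications() - before;
    /// if lost > 0 {
    ///     eprintln!("missed {lost} LISTEN notifications");
    /// }
    /// # Ok(()) }
    /// ```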
335    pub fn dropped_notifications(&self) -> u64 {
336        self.dropped_notifications
337            .load(std::sync::atomic::Ordering::Relaxed)
338    }
339
340    /// Take the notification receiver. Call once to get a channel that
341    /// receives `NotificationResponse` messages that arrive during queries.
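    ///
    /// A minimal sketch of LISTEN usage; the channel name is illustrative:
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// let mut rx = conn
    ///     .take_notification_receiver()
    ///     .expect("notification receiver already taken");
    /// conn.exec_query("LISTEN price_updates", &[], &[]).await?;
    /// // Notifications are only read off the socket while some request is
    /// // being processed, so keep issuing queries on this connection.
    /// while let Some(notification) = rx.recv().await {
    ///     // `notification` is a BackendMsg::NotificationResponse { .. }.
    ///     let _ = notification;
    /// }
    /// # Ok(()) }
    /// ```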
342    pub fn take_notification_receiver(
343        &self,
344    ) -> Option<mpsc::Receiver<crate::protocol::types::BackendMsg>> {
345        self.notification_rx
346            .lock()
347            .ok()
348            .and_then(|mut guard| guard.take())
349    }
350
    /// Look up or allocate a statement name.
    /// Eviction is oldest-first by insertion order (not true LRU, since lookups
    /// do not refresh the counter): when the cache is full, the oldest entry is
    /// removed and a Close message is queued to free the server-side statement.
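    ///
    /// A sketch of the manual prepare-and-execute path that `exec_query`
    /// wraps; the SQL and OID are illustrative (23 is `int4`):
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// let sql = "SELECT $1::int4 + 1";
    /// let (stmt_name, needs_parse) = conn.lookup_or_alloc(sql);
    /// let rows = conn
    ///     .query(sql, &[Some(&b"41"[..])], &[23], &stmt_name, needs_parse)
    ///     .await?;
    /// assert_eq!(rows.len(), 1);
    /// # Ok(()) }
    /// ```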
355    pub fn lookup_or_alloc(&self, sql: &str) -> (Vec<u8>, bool) {
356        let mut cache = match self.stmt_cache.lock() {
357            Ok(c) => c,
358            Err(poisoned) => poisoned.into_inner(),
359        };
360        if let Some((name, _)) = cache.get(sql) {
361            return (name.as_bytes().to_vec(), false);
362        }
        // Oldest-first eviction: remove the entry with the lowest counter
        // (insertion order) and send a Close to free the server-side statement.
365        if cache.len() >= 256 {
366            if let Some((oldest_key, oldest_name)) = cache
367                .iter()
368                .min_by_key(|(_, (_, counter))| *counter)
369                .map(|(k, (name, _))| (k.clone(), name.clone()))
370            {
371                cache.remove(&oldest_key);
372                // Queue a Close + Sync to free the server-side statement.
373                // Fire-and-forget: if the channel is full or closed, skip it.
374                let mut close_buf = BytesMut::with_capacity(32);
375                frontend::encode_message(
376                    &FrontendMsg::Close {
377                        kind: b'S',
378                        name: oldest_name.as_bytes(),
379                    },
380                    &mut close_buf,
381                );
382                frontend::encode_message(&FrontendMsg::Sync, &mut close_buf);
383                let (tx, _rx) = oneshot::channel();
384                let _ = self.request_tx.try_send(PipelineRequest {
385                    messages: close_buf,
386                    collector: ResponseCollector::Drain,
387                    response_tx: tx,
388                });
389            }
390        }
391        let n = self
392            .stmt_counter
393            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
394        let name = format!("s{n}");
395        cache.insert(sql.to_string(), (name.clone(), n));
396        (name.into_bytes(), true)
397    }
398
    /// Execute COPY FROM STDIN: sends the COPY command, the data as CopyData
    /// messages, then CopyDone. Returns the number of rows copied (from the
    /// CommandComplete tag).
    ///
    /// The payload is framed into CopyData messages of at most 1MB each; all
    /// frames are encoded into a single request buffer and coalesced into one
    /// write by the writer task.
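    ///
    /// A minimal sketch; the table and CSV payload are illustrative:
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// let csv = b"1,alice\n2,bob\n";
    /// let copied = conn
    ///     .copy_in("COPY users (id, name) FROM STDIN WITH (FORMAT csv)", csv)
    ///     .await?;
    /// assert_eq!(copied, 2);
    /// # Ok(()) }
    /// ```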
404    pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, PgWireError> {
405        use crate::protocol::types::FrontendMsg;
406        const CHUNK_SIZE: usize = 1024 * 1024; // 1MB chunks
407
408        // Build the message buffer: Query + chunked CopyData + CopyDone.
409        let mut buf = BytesMut::with_capacity(copy_sql.len() + data.len().min(CHUNK_SIZE) + 64);
410        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
411
412        // Send data in chunks to avoid a single huge allocation.
413        for chunk in data.chunks(CHUNK_SIZE) {
414            frontend::encode_message(&FrontendMsg::CopyData(chunk), &mut buf);
415        }
416        // Empty data is valid (0 rows copied).
417        frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
418
419        let resp = self
420            .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
421            .await?;
422        match resp {
423            PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
424            PipelineResponse::Done => Ok(0),
425        }
426    }
427
    /// Execute COPY FROM STDIN from an async reader: sends the COPY command,
    /// then reads the source in 1MB chunks and frames each chunk as a CopyData
    /// message. Note that the framed payload is still accumulated into one
    /// request buffer before it is written, so peak memory use is proportional
    /// to the payload size.
431    ///
432    /// ```no_run
433    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
434    /// # let conn: pg_wired::AsyncConn = unimplemented!();
435    /// use tokio::fs::File;
436    /// let file = File::open("data.csv").await?;
437    /// let _count = conn.copy_in_stream("COPY users FROM STDIN WITH (FORMAT csv)", file).await?;
438    /// # Ok(()) }
439    /// ```
440    pub async fn copy_in_stream<R: tokio::io::AsyncRead + Unpin>(
441        &self,
442        copy_sql: &str,
443        mut reader: R,
444    ) -> Result<u64, PgWireError> {
445        use tokio::io::AsyncReadExt;
446        const CHUNK_SIZE: usize = 1024 * 1024; // 1MB chunks
447
448        // Send the COPY command.
449        let mut buf = BytesMut::with_capacity(copy_sql.len() + 16);
450        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
451
452        // Read and send data in chunks.
453        let mut chunk = vec![0u8; CHUNK_SIZE];
454        loop {
455            let n = reader.read(&mut chunk).await?;
456            if n == 0 {
457                break;
458            }
459            frontend::encode_message(&FrontendMsg::CopyData(&chunk[..n]), &mut buf);
460        }
461        frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
462
463        let resp = self
464            .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
465            .await?;
466        match resp {
467            PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
468            PipelineResponse::Done => Ok(0),
469        }
470    }
471
472    /// Execute COPY TO STDOUT: sends the COPY command, collects all CopyData.
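    ///
    /// A minimal sketch; the query is illustrative:
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// let bytes = conn
    ///     .copy_out("COPY (SELECT id, name FROM users) TO STDOUT WITH (FORMAT csv)")
    ///     .await?;
    /// let text = String::from_utf8(bytes)?;
    /// println!("{text}");
    /// # Ok(()) }
    /// ```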
473    pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, PgWireError> {
474        use crate::protocol::types::FrontendMsg;
475        let mut buf = BytesMut::new();
476        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
477
478        let resp = self.submit(buf, ResponseCollector::CopyOut).await?;
479        match resp {
480            PipelineResponse::Rows { rows, .. } => {
481                // For CopyOut, we reuse the Rows variant but each `RawRow` carries
482                // one cell which is the raw COPY data chunk (see `collect_copy_out`).
483                let mut result = Vec::new();
484                for row in rows {
485                    for data in row.iter().flatten() {
486                        result.extend_from_slice(data);
487                    }
488                }
489                Ok(result)
490            }
491            PipelineResponse::Done => Ok(Vec::new()),
492        }
493    }
494
495    /// Evict a SQL statement from the cache, forcing re-parse on next use.
496    /// Used for prepared statement invalidation after schema changes.
497    pub fn invalidate_statement(&self, sql: &str) {
498        let mut cache = match self.stmt_cache.lock() {
499            Ok(c) => c,
500            Err(poisoned) => poisoned.into_inner(),
501        };
502        cache.remove(sql);
503    }
504
505    /// Clear the entire statement cache. Must be called after `DISCARD ALL`
506    /// which destroys server-side prepared statements.
507    pub fn clear_statement_cache(&self) {
508        let mut cache = match self.stmt_cache.lock() {
509            Ok(c) => c,
510            Err(poisoned) => poisoned.into_inner(),
511        };
512        cache.clear();
513    }
514
515    /// Execute a pipelined transaction with automatic statement caching.
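    ///
    /// A sketch of a single-round-trip write; the SQL, role, and OID are
    /// illustrative (25 is `text`):
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// let rows = conn
    ///     .exec_transaction(
    ///         "BEGIN; SET LOCAL ROLE app_writer",
    ///         "INSERT INTO audit_log (message) VALUES ($1) RETURNING id",
    ///         &[Some(&b"user signed in"[..])],
    ///         &[25],
    ///     )
    ///     .await?;
    /// assert_eq!(rows.len(), 1);
    /// # Ok(()) }
    /// ```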
516    pub async fn exec_transaction(
517        &self,
518        setup_sql: &str,
519        query_sql: &str,
520        params: &[Option<&[u8]>],
521        param_oids: &[u32],
522    ) -> Result<Vec<RawRow>, PgWireError> {
523        let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql);
524        self.pipeline_transaction(
525            setup_sql,
526            query_sql,
527            params,
528            param_oids,
529            &stmt_name,
530            needs_parse,
531        )
532        .await
533    }
534
535    /// Execute a parameterized query with automatic statement caching.
536    /// If a cached statement is invalidated by a schema change (PG error 26000
537    /// or 0A000), automatically evicts the cache entry, re-parses, and retries once.
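    ///
    /// A minimal sketch; the query and OID are illustrative (23 is `int4`,
    /// and the parameter is passed in text format since this method sends all
    /// params as text):
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// let rows = conn
    ///     .exec_query(
    ///         "SELECT name FROM users WHERE id = $1",
    ///         &[Some(&b"42"[..])],
    ///         &[23],
    ///     )
    ///     .await?;
    /// // Each returned RawRow holds the text-format cells of one row.
    /// let _ = rows;
    /// # Ok(()) }
    /// ```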
538    pub async fn exec_query(
539        &self,
540        sql: &str,
541        params: &[Option<&[u8]>],
542        param_oids: &[u32],
543    ) -> Result<Vec<RawRow>, PgWireError> {
544        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
545        match self
546            .query(sql, params, param_oids, &stmt_name, needs_parse)
547            .await
548        {
549            Ok(rows) => Ok(rows),
550            Err(PgWireError::Pg(ref pg_err))
551                if !needs_parse && is_stale_statement_error(pg_err) =>
552            {
553                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
554                self.invalidate_statement(sql);
555                let (stmt_name, _) = self.lookup_or_alloc(sql);
556                self.query(sql, params, param_oids, &stmt_name, true).await
557            }
558            Err(e) => Err(e),
559        }
560    }
561
562    /// Maximum time to wait for a response from the reader task.
563    /// Prevents hanging forever if the reader/writer task dies mid-request.
564    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
565
566    /// Submit a request to the connection. Returns a future that resolves
567    /// when the response is available. Times out after 5 minutes to prevent
568    /// hanging forever if the reader/writer task dies.
569    pub async fn submit(
570        &self,
571        messages: BytesMut,
572        collector: ResponseCollector,
573    ) -> Result<PipelineResponse, PgWireError> {
574        let (response_tx, response_rx) = oneshot::channel();
575        let req = PipelineRequest {
576            messages,
577            collector,
578            response_tx,
579        };
580        self.request_tx
581            .send(req)
582            .await
583            .map_err(|_| PgWireError::ConnectionClosed)?;
584        match tokio::time::timeout(Self::REQUEST_TIMEOUT, response_rx).await {
585            Ok(Ok(result)) => result,
586            Ok(Err(_)) => Err(PgWireError::ConnectionClosed),
587            Err(_elapsed) => {
588                tracing::error!(
589                    "request timed out after {:?} — reader/writer task may be dead",
590                    Self::REQUEST_TIMEOUT
591                );
592                Err(PgWireError::ConnectionClosed)
593            }
594        }
595    }
596
597    /// Submit a batch of requests in FIFO order. All requests are queued
598    /// before any response is awaited, so the writer task sees them together
599    /// and coalesces them into a single write() syscall. The server then
600    /// pipelines the N responses back-to-back, giving one network round-trip
601    /// for all N queries.
602    ///
603    /// Returns one `Result<PipelineResponse, PgWireError>` per input item,
604    /// in the same order. The outer `Result` fails only if queueing fails
605    /// (channel closed). Each inner `Result` reflects the per-query outcome.
606    pub async fn submit_batch(
607        &self,
608        items: Vec<(BytesMut, ResponseCollector)>,
609    ) -> Result<Vec<Result<PipelineResponse, PgWireError>>, PgWireError> {
610        let mut receivers = Vec::with_capacity(items.len());
611        for (messages, collector) in items {
612            let (response_tx, response_rx) = oneshot::channel();
613            self.request_tx
614                .send(PipelineRequest {
615                    messages,
616                    collector,
617                    response_tx,
618                })
619                .await
620                .map_err(|_| PgWireError::ConnectionClosed)?;
621            receivers.push(response_rx);
622        }
623        let mut results = Vec::with_capacity(receivers.len());
624        for rx in receivers {
625            match tokio::time::timeout(Self::REQUEST_TIMEOUT, rx).await {
626                Ok(Ok(r)) => results.push(r),
627                Ok(Err(_)) => results.push(Err(PgWireError::ConnectionClosed)),
628                Err(_) => {
629                    tracing::error!(
630                        "submit_batch request timed out after {:?}",
631                        Self::REQUEST_TIMEOUT
632                    );
633                    results.push(Err(PgWireError::ConnectionClosed));
634                }
635            }
636        }
637        Ok(results)
638    }
639
    /// Send a Terminate message to the server, which closes the connection in
    /// response. After this returns the connection is unusable; further calls
    /// fail with `ConnectionClosed`. Idempotent: calling `close` on an
    /// already-closed connection is a no-op and returns `Ok`.
644    pub async fn close(&self) -> Result<(), PgWireError> {
645        if !self.is_alive() {
646            return Ok(());
647        }
648        let mut buf = BytesMut::with_capacity(5);
649        frontend::encode_message(&FrontendMsg::Terminate, &mut buf);
650        // Submit Terminate through the writer so ordering is preserved wrt
651        // any in-flight requests ahead of us. The server replies with nothing
652        // and closes the socket, so we expect `ConnectionClosed` back from
653        // the drain collector — treat that as a successful close.
654        match self.submit(buf, ResponseCollector::Drain).await {
655            Ok(_) | Err(PgWireError::ConnectionClosed) => Ok(()),
656            Err(PgWireError::Io(e)) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
657            Err(e) => Err(e),
658        }
659    }
660
661    /// Submit a streaming request. Returns the column header and an mpsc receiver
662    /// that yields rows one at a time.
663    pub async fn submit_stream(
664        &self,
665        messages: BytesMut,
666        row_buffer: usize,
667    ) -> Result<
668        (
669            StreamHeader,
670            mpsc::Receiver<Result<StreamedRow, PgWireError>>,
671        ),
672        PgWireError,
673    > {
674        let (header_tx, header_rx) = oneshot::channel();
675        let (row_tx, row_rx) = mpsc::channel(row_buffer);
676        let (response_tx, _response_rx) = oneshot::channel();
677        let req = PipelineRequest {
678            messages,
679            collector: ResponseCollector::Stream { header_tx, row_tx },
680            response_tx,
681        };
682        self.request_tx
683            .send(req)
684            .await
685            .map_err(|_| PgWireError::ConnectionClosed)?;
686        let header = header_rx
687            .await
688            .map_err(|_| PgWireError::ConnectionClosed)??;
689        Ok((header, row_rx))
690    }
691
692    /// Execute a pipelined transaction:
693    /// setup (simple query) + data query (extended protocol) + COMMIT (simple query)
694    /// All coalesced into one TCP write. Binary-safe parameterized data query.
695    pub async fn pipeline_transaction(
696        &self,
697        setup_sql: &str,
698        query_sql: &str,
699        params: &[Option<&[u8]>],
700        param_oids: &[u32],
701        stmt_name: &[u8],
702        needs_parse: bool,
703    ) -> Result<Vec<RawRow>, PgWireError> {
704        let mut buf = BytesMut::with_capacity(1024);
705
706        // 1. Simple query for setup (BEGIN + SET ROLE + set_config).
707        frontend::encode_message(&FrontendMsg::Query(setup_sql.as_bytes()), &mut buf);
708
709        // Submit setup as Drain — we don't care about its response data.
710        let setup_msgs = buf.split();
711
712        // 2. Extended query for data.
713        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
714        let result_fmts = [FormatCode::Text];
715
716        if needs_parse {
717            frontend::encode_message(
718                &FrontendMsg::Parse {
719                    name: stmt_name,
720                    sql: query_sql.as_bytes(),
721                    param_oids,
722                },
723                &mut buf,
724            );
725        }
726
727        frontend::encode_message(
728            &FrontendMsg::Bind {
729                portal: b"",
730                statement: stmt_name,
731                param_formats: &text_fmts[..params.len()],
732                params,
733                result_formats: &result_fmts,
734            },
735            &mut buf,
736        );
737
738        frontend::encode_message(
739            &FrontendMsg::Execute {
740                portal: b"",
741                max_rows: 0,
742            },
743            &mut buf,
744        );
745
746        frontend::encode_message(&FrontendMsg::Sync, &mut buf);
747
748        let data_msgs = buf.split();
749
750        // 3. Simple query for COMMIT — in its own buffer so each request
751        // carries exactly the bytes that produce its ReadyForQuery response.
752        let mut commit_buf = BytesMut::with_capacity(32);
753        frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut commit_buf);
754
755        // Submit all three as separate requests with different collectors.
756        // They'll be coalesced by the writer into one write() syscall.
757        let (setup_tx, setup_rx) = oneshot::channel();
758        let (data_tx, data_rx) = oneshot::channel();
759        let (commit_tx, commit_rx) = oneshot::channel();
760
761        // Send all three requests to the writer channel.
762        // The writer drains the channel and writes them all at once.
763        self.request_tx
764            .send(PipelineRequest {
765                messages: setup_msgs,
766                collector: ResponseCollector::Drain,
767                response_tx: setup_tx,
768            })
769            .await
770            .map_err(|_| PgWireError::ConnectionClosed)?;
771
772        self.request_tx
773            .send(PipelineRequest {
774                messages: data_msgs,
775                collector: ResponseCollector::Rows,
776                response_tx: data_tx,
777            })
778            .await
779            .map_err(|_| PgWireError::ConnectionClosed)?;
780
781        self.request_tx
782            .send(PipelineRequest {
783                messages: commit_buf,
784                collector: ResponseCollector::Drain,
785                response_tx: commit_tx,
786            })
787            .await
788            .map_err(|_| PgWireError::ConnectionClosed)?;
789
790        // Wait for all responses.
791        setup_rx
792            .await
793            .map_err(|_| PgWireError::ConnectionClosed)??;
794
795        let data_resp = data_rx.await.map_err(|_| PgWireError::ConnectionClosed)??;
796
797        commit_rx
798            .await
799            .map_err(|_| PgWireError::ConnectionClosed)??;
800
801        match data_resp {
802            PipelineResponse::Rows { rows, .. } => Ok(rows),
803            PipelineResponse::Done => Ok(Vec::new()),
804        }
805    }
806
807    /// Execute a simple parameterized query (no transaction).
808    pub async fn query(
809        &self,
810        sql: &str,
811        params: &[Option<&[u8]>],
812        param_oids: &[u32],
813        stmt_name: &[u8],
814        needs_parse: bool,
815    ) -> Result<Vec<RawRow>, PgWireError> {
816        self.query_with_formats(sql, params, param_oids, &[], &[], stmt_name, needs_parse)
817            .await
818    }
819
820    /// Execute a parameterized query with explicit per-param and per-result
821    /// format codes (text = 0, binary = 1).
822    ///
823    /// `param_formats` is interpreted per PostgreSQL wire protocol rules:
824    /// - empty: all params are text
825    /// - length 1: the single code applies to every param
826    /// - length N (== params.len()): one code per param
827    ///
828    /// Same rules apply to `result_formats` for output columns (empty → all
829    /// text; single code → applies to all columns; per-column list otherwise).
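    ///
    /// A sketch of a binary round trip. It assumes `FormatCode` is re-exported
    /// at the crate root and has a `Binary` variant alongside the `Text` used
    /// elsewhere in this file; 23 is the `int4` OID and the parameter is its
    /// 4-byte big-endian encoding:
    ///
    /// ```no_run
    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
    /// # let conn: pg_wired::AsyncConn = unimplemented!();
    /// use pg_wired::FormatCode;
    ///
    /// let sql = "SELECT $1::int4 * 2";
    /// let (stmt_name, needs_parse) = conn.lookup_or_alloc(sql);
    /// let param = 21i32.to_be_bytes();
    /// let rows = conn
    ///     .query_with_formats(
    ///         sql,
    ///         &[Some(&param[..])],
    ///         &[23],
    ///         &[FormatCode::Binary], // single code applies to every param
    ///         &[FormatCode::Binary], // request binary result columns too
    ///         &stmt_name,
    ///         needs_parse,
    ///     )
    ///     .await?;
    /// let _ = rows;
    /// # Ok(()) }
    /// ```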
830    #[allow(clippy::too_many_arguments)]
831    pub async fn query_with_formats(
832        &self,
833        sql: &str,
834        params: &[Option<&[u8]>],
835        param_oids: &[u32],
836        param_formats: &[FormatCode],
837        result_formats: &[FormatCode],
838        stmt_name: &[u8],
839        needs_parse: bool,
840    ) -> Result<Vec<RawRow>, PgWireError> {
841        let mut buf = BytesMut::with_capacity(512);
842
843        // Default to all-text if caller passes empty slices.
844        let text_param_fmts: Vec<FormatCode>;
845        let param_fmts_slice: &[FormatCode] = if param_formats.is_empty() {
846            text_param_fmts = vec![FormatCode::Text; params.len().max(1)];
847            &text_param_fmts[..params.len()]
848        } else {
849            param_formats
850        };
851        let default_result_fmts = [FormatCode::Text];
852        let result_fmts_slice: &[FormatCode] = if result_formats.is_empty() {
853            &default_result_fmts
854        } else {
855            result_formats
856        };
857
858        if needs_parse {
859            frontend::encode_message(
860                &FrontendMsg::Parse {
861                    name: stmt_name,
862                    sql: sql.as_bytes(),
863                    param_oids,
864                },
865                &mut buf,
866            );
867        }
868
869        frontend::encode_message(
870            &FrontendMsg::Bind {
871                portal: b"",
872                statement: stmt_name,
873                param_formats: param_fmts_slice,
874                params,
875                result_formats: result_fmts_slice,
876            },
877            &mut buf,
878        );
879
880        frontend::encode_message(
881            &FrontendMsg::Execute {
882                portal: b"",
883                max_rows: 0,
884            },
885            &mut buf,
886        );
887
888        frontend::encode_message(&FrontendMsg::Sync, &mut buf);
889
890        let resp = self.submit(buf, ResponseCollector::Rows).await?;
891        match resp {
892            PipelineResponse::Rows { rows, .. } => Ok(rows),
893            PipelineResponse::Done => Ok(Vec::new()),
894        }
895    }
896
897    /// Variant of `exec_query` with per-param and per-result format codes.
898    /// See `query_with_formats` for format code semantics.
899    pub async fn exec_query_with_formats(
900        &self,
901        sql: &str,
902        params: &[Option<&[u8]>],
903        param_oids: &[u32],
904        param_formats: &[FormatCode],
905        result_formats: &[FormatCode],
906    ) -> Result<Vec<RawRow>, PgWireError> {
907        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
908        match self
909            .query_with_formats(
910                sql,
911                params,
912                param_oids,
913                param_formats,
914                result_formats,
915                &stmt_name,
916                needs_parse,
917            )
918            .await
919        {
920            Ok(rows) => Ok(rows),
921            Err(PgWireError::Pg(ref pg_err))
922                if !needs_parse && is_stale_statement_error(pg_err) =>
923            {
924                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
925                self.invalidate_statement(sql);
926                let (stmt_name, _) = self.lookup_or_alloc(sql);
927                self.query_with_formats(
928                    sql,
929                    params,
930                    param_oids,
931                    param_formats,
932                    result_formats,
933                    &stmt_name,
934                    true,
935                )
936                .await
937            }
938            Err(e) => Err(e),
939        }
940    }
941}
942
943// ---------------------------------------------------------------------------
944// Writer task
945// ---------------------------------------------------------------------------
946
947async fn writer_task(
948    mut rx: mpsc::Receiver<PipelineRequest>,
949    mut stream: tokio::io::WriteHalf<crate::tls::MaybeTlsStream>,
950    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
951    pending_notify: Arc<tokio::sync::Notify>,
952) {
953    let mut write_buf = BytesMut::with_capacity(8192);
954
955    loop {
956        // Wait for the first request.
957        let first = match rx.recv().await {
958            Some(req) => req,
959            None => {
960                // Channel closed — drain any pending responses with ConnectionClosed.
961                drain_pending_on_exit(&pending).await;
962                return;
963            }
964        };
965
966        // Drain any additional queued requests (batch coalescing).
967        write_buf.clear();
968        write_buf.extend_from_slice(&first.messages);
969
970        let mut batch: Vec<PendingResponse> = vec![PendingResponse {
971            collector: first.collector,
972            response_tx: first.response_tx,
973        }];
974
975        // Non-blocking drain of all queued requests.
976        while let Ok(req) = rx.try_recv() {
977            write_buf.extend_from_slice(&req.messages);
978            batch.push(PendingResponse {
979                collector: req.collector,
980                response_tx: req.response_tx,
981            });
982        }
983
984        // ONE write() syscall for all coalesced messages.
985        // Write BEFORE enqueuing pending responses — if the write fails,
986        // we send errors to callers instead of leaving them hanging.
987        let write_result = stream.write_all(&write_buf).await;
988        let write_err = match write_result {
989            Ok(_) => stream.flush().await.err(),
990            Err(e) => Some(e),
991        };
992
993        if let Some(e) = write_err {
994            tracing::error!("Writer error: {e}");
995            let msg = e.to_string();
996            for p in batch {
997                let _ = p.response_tx.send(Err(PgWireError::Io(std::io::Error::new(
998                    std::io::ErrorKind::BrokenPipe,
999                    msg.clone(),
1000                ))));
1001            }
1002            // Drain any already-pending responses so the reader doesn't hang.
1003            drain_pending_on_exit(&pending).await;
1004            return;
1005        }
1006
1007        // Write succeeded — enqueue pending responses for the reader.
1008        {
1009            let mut pq = pending.lock().await;
1010            for p in batch {
1011                pq.push_back(p);
1012            }
1013        }
1014        // Wake the reader task to process the newly enqueued responses.
1015        pending_notify.notify_one();
1016    }
1017}
1018
1019/// On writer exit, drain all pending responses with ConnectionClosed errors
1020/// so callers don't wait for the 5-minute timeout.
1021async fn drain_pending_on_exit(pending: &Arc<Mutex<VecDeque<PendingResponse>>>) {
1022    let mut pq = pending.lock().await;
1023    while let Some(pr) = pq.pop_front() {
1024        let _ = pr.response_tx.send(Err(PgWireError::ConnectionClosed));
1025    }
1026}
1027
1028// ---------------------------------------------------------------------------
1029// Reader task
1030// ---------------------------------------------------------------------------
1031
1032async fn reader_task(
1033    mut stream: tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1034    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
1035    pending_notify: Arc<tokio::sync::Notify>,
1036    notification_tx: mpsc::Sender<BackendMsg>,
1037    state_mutated: Arc<std::sync::atomic::AtomicBool>,
1038    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
1039) {
1040    let mut recv_buf = BytesMut::with_capacity(32 * 1024);
1041
1042    loop {
1043        // Wait for a pending response to become available.
1044        let pr = loop {
1045            {
1046                let mut pq = pending.lock().await;
1047                if let Some(pr) = pq.pop_front() {
1048                    break pr;
1049                }
1050            }
1051            // No pending — wait for the writer to signal.
1052            pending_notify.notified().await;
1053        };
1054
1055        // Collect the response based on the collector type.
1056        let result = match pr.collector {
1057            ResponseCollector::Rows => {
1058                collect_rows(
1059                    &mut stream,
1060                    &mut recv_buf,
1061                    &notification_tx,
1062                    &state_mutated,
1063                    &dropped_notifications,
1064                )
1065                .await
1066            }
1067            ResponseCollector::Drain => {
1068                drain_until_ready(&mut stream, &mut recv_buf, Some(&state_mutated))
1069                    .await
1070                    .map(|_| PipelineResponse::Done)
1071            }
1072            ResponseCollector::Stream { header_tx, row_tx } => {
1073                stream_rows(
1074                    &mut stream,
1075                    &mut recv_buf,
1076                    header_tx,
1077                    row_tx,
1078                    &notification_tx,
1079                    &state_mutated,
1080                    &dropped_notifications,
1081                )
1082                .await;
1083                Ok(PipelineResponse::Done)
1084            }
1085            ResponseCollector::CopyIn { .. } => {
1086                collect_copy_in_response(&mut stream, &mut recv_buf, &state_mutated).await
1087            }
1088            ResponseCollector::CopyOut => {
1089                collect_copy_out(&mut stream, &mut recv_buf, &state_mutated).await
1090            }
1091        };
1092
1093        // Send the response back to the caller.
1094        let _ = pr.response_tx.send(result);
1095    }
1096}
1097
1098async fn read_msg(
1099    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1100    buf: &mut BytesMut,
1101) -> Result<BackendMsg, PgWireError> {
1102    loop {
1103        if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1104            return Ok(msg);
1105        }
1106        let n = stream.read_buf(buf).await?;
1107        if n == 0 {
1108            // EOF — try to parse any remaining data in the buffer before giving up.
1109            // This handles the case where the last message arrived just before the
1110            // connection closed and is already fully buffered.
1111            if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1112                return Ok(msg);
1113            }
1114            return Err(PgWireError::ConnectionClosed);
1115        }
1116    }
1117}
1118
1119/// If the ReadyForQuery status byte is anything other than `I` (idle),
1120/// flag the connection as state-mutated. `T` (in transaction) and `E`
1121/// (failed transaction) both leave session state that needs DISCARD ALL.
1122fn note_rfq_status(status: u8, state_mutated: &std::sync::atomic::AtomicBool) {
1123    if status != b'I' {
1124        state_mutated.store(true, std::sync::atomic::Ordering::Release);
1125    }
1126}
1127
1128async fn collect_rows(
1129    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1130    buf: &mut BytesMut,
1131    notification_tx: &mpsc::Sender<BackendMsg>,
1132    state_mutated: &std::sync::atomic::AtomicBool,
1133    dropped_notifications: &std::sync::atomic::AtomicU64,
1134) -> Result<PipelineResponse, PgWireError> {
1135    let mut rows = Vec::new();
1136    let mut fields = Vec::new();
1137    let mut command_tag = String::new();
1138    loop {
1139        let msg = read_msg(stream, buf).await?;
1140        match msg {
1141            BackendMsg::DataRow(row) => rows.push(row),
1142            BackendMsg::RowDescription { fields: f } => fields = f,
1143            BackendMsg::CommandComplete { tag } => command_tag = tag,
1144            BackendMsg::ReadyForQuery { status } => {
1145                note_rfq_status(status, state_mutated);
1146                return Ok(PipelineResponse::Rows {
1147                    fields,
1148                    rows,
1149                    command_tag,
1150                });
1151            }
1152            BackendMsg::ErrorResponse { fields } => {
1153                drain_until_ready(stream, buf, Some(state_mutated)).await?;
1154                return Err(PgWireError::Pg(fields));
1155            }
1156            msg @ BackendMsg::NotificationResponse { .. } => {
1157                // Forward notification instead of dropping.
1158                #[allow(clippy::collapsible_match)]
1159                if notification_tx.try_send(msg).is_err() {
1160                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1161                    tracing::warn!("notification channel full, dropping notification");
1162                }
1163            }
1164            BackendMsg::ParseComplete
1165            | BackendMsg::BindComplete
1166            | BackendMsg::NoData
1167            | BackendMsg::NoticeResponse { .. }
1168            | BackendMsg::EmptyQueryResponse => {}
1169            _ => {}
1170        }
1171    }
1172}
1173
1174async fn drain_until_ready(
1175    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1176    buf: &mut BytesMut,
1177    state_mutated: Option<&std::sync::atomic::AtomicBool>,
1178) -> Result<(), PgWireError> {
1179    loop {
1180        let msg = read_msg(stream, buf).await?;
1181        if let BackendMsg::ReadyForQuery { status } = msg {
1182            if let Some(sm) = state_mutated {
1183                note_rfq_status(status, sm);
1184            }
1185            return Ok(());
1186        }
1187        if let BackendMsg::ErrorResponse { ref fields } = msg {
1188            tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
1189        }
1190    }
1191}
1192
1193/// Stream rows one at a time, sending header first, then individual rows.
1194async fn stream_rows(
1195    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1196    buf: &mut BytesMut,
1197    header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
1198    row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
1199    notification_tx: &mpsc::Sender<BackendMsg>,
1200    state_mutated: &std::sync::atomic::AtomicBool,
1201    dropped_notifications: &std::sync::atomic::AtomicU64,
1202) {
1203    let mut header_tx = Some(header_tx);
1204    let mut fields = Vec::new();
1205    loop {
1206        let msg = match read_msg(stream, buf).await {
1207            Ok(msg) => msg,
1208            Err(e) => {
1209                if let Some(htx) = header_tx.take() {
1210                    let _ = htx.send(Err(e));
1211                } else {
1212                    let _ = row_tx.send(Err(e)).await;
1213                }
1214                return;
1215            }
1216        };
1217        match msg {
1218            BackendMsg::RowDescription { fields: f } => {
1219                fields = f;
1220            }
1221            BackendMsg::DataRow(row) => {
1222                if let Some(htx) = header_tx.take() {
1223                    let _ = htx.send(Ok(StreamHeader {
1224                        fields: fields.clone(),
1225                    }));
1226                }
1227                if row_tx.send(Ok(row)).await.is_err() {
1228                    let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
1229                    return;
1230                }
1231            }
1232            BackendMsg::CommandComplete { .. } => {
1233                if let Some(htx) = header_tx.take() {
1234                    let _ = htx.send(Ok(StreamHeader {
1235                        fields: std::mem::take(&mut fields),
1236                    }));
1237                }
1238            }
1239            BackendMsg::ReadyForQuery { status } => {
1240                note_rfq_status(status, state_mutated);
1241                if let Some(htx) = header_tx.take() {
1242                    let _ = htx.send(Ok(StreamHeader {
1243                        fields: std::mem::take(&mut fields),
1244                    }));
1245                }
1246                return;
1247            }
1248            BackendMsg::ErrorResponse { fields: err } => {
1249                if let Some(htx) = header_tx.take() {
1250                    let _ = htx.send(Err(PgWireError::Pg(err)));
1251                } else {
1252                    let _ = row_tx.send(Err(PgWireError::Pg(err))).await;
1253                }
1254                let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
1255                return;
1256            }
1257            msg @ BackendMsg::NotificationResponse { .. } => {
1258                #[allow(clippy::collapsible_match)]
1259                if notification_tx.try_send(msg).is_err() {
1260                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1261                    tracing::warn!("notification channel full, dropping notification");
1262                }
1263            }
1264            BackendMsg::ParseComplete
1265            | BackendMsg::BindComplete
1266            | BackendMsg::NoData
1267            | BackendMsg::PortalSuspended
1268            | BackendMsg::NoticeResponse { .. }
1269            | BackendMsg::EmptyQueryResponse => {}
1270            _ => {}
1271        }
1272    }
1273}
1274
1275/// Handle COPY IN response: skip CopyInResponse, wait for CommandComplete + ReadyForQuery.
1276/// The actual CopyData + CopyDone were pre-buffered in the write, so PG processes them.
1277async fn collect_copy_in_response(
1278    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1279    buf: &mut BytesMut,
1280    state_mutated: &std::sync::atomic::AtomicBool,
1281) -> Result<PipelineResponse, PgWireError> {
1282    let mut command_tag = String::new();
1283    loop {
1284        let msg = read_msg(stream, buf).await?;
1285        match msg {
1286            BackendMsg::CopyInResponse { .. } => {}
1287            BackendMsg::CommandComplete { tag } => command_tag = tag,
1288            BackendMsg::ReadyForQuery { status } => {
1289                note_rfq_status(status, state_mutated);
1290                return Ok(PipelineResponse::Rows {
1291                    fields: Vec::new(),
1292                    rows: Vec::new(),
1293                    command_tag,
1294                });
1295            }
1296            BackendMsg::ErrorResponse { fields } => {
1297                drain_until_ready(stream, buf, Some(state_mutated)).await?;
1298                return Err(PgWireError::Pg(fields));
1299            }
1300            _ => {}
1301        }
1302    }
1303}
1304
1305/// Collect COPY OUT data: CopyOutResponse → CopyData* → CopyDone → CommandComplete → ReadyForQuery.
1306async fn collect_copy_out(
1307    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1308    buf: &mut BytesMut,
1309    state_mutated: &std::sync::atomic::AtomicBool,
1310) -> Result<PipelineResponse, PgWireError> {
1311    let mut data_chunks: Vec<RawRow> = Vec::new();
1312    let mut command_tag = String::new();
1313    loop {
1314        let msg = read_msg(stream, buf).await?;
1315        match msg {
1316            BackendMsg::CopyOutResponse { .. } => {}
1317            BackendMsg::CopyData { data } => {
1318                let body = bytes::Bytes::from(data);
1319                data_chunks.push(RawRow::from_full_body(body));
1320            }
1321            BackendMsg::CopyDone => {}
1322            BackendMsg::CommandComplete { tag } => command_tag = tag,
1323            BackendMsg::ReadyForQuery { status } => {
1324                note_rfq_status(status, state_mutated);
1325                return Ok(PipelineResponse::Rows {
1326                    fields: Vec::new(),
1327                    rows: data_chunks,
1328                    command_tag,
1329                });
1330            }
1331            BackendMsg::ErrorResponse { fields } => {
1332                drain_until_ready(stream, buf, Some(state_mutated)).await?;
1333                return Err(PgWireError::Pg(fields));
1334            }
1335            _ => {}
1336        }
1337    }
1338}
1339
1340/// Check if a PostgreSQL error indicates a stale/invalidated prepared statement.
1341/// Error codes: 26000 (invalid_sql_statement_name), 0A000 (feature_not_supported
1342/// — used when cached plan changes type).
1343fn is_stale_statement_error(err: &crate::protocol::types::PgError) -> bool {
1344    matches!(err.code.as_str(), "26000" | "0A000")
1345}
1346
1347fn parse_copy_count(tag: &str) -> u64 {
1348    // COPY tag format: "COPY 123"
1349    tag.strip_prefix("COPY ")
1350        .and_then(|s| s.parse::<u64>().ok())
1351        .unwrap_or(0)
1352}
1353
1354// Extension to WireConn to extract the underlying stream.
1355impl WireConn {
1356    pub(crate) fn into_stream(self) -> crate::tls::MaybeTlsStream {
1357        self.stream
1358    }
1359}
1360
1361#[cfg(test)]
1362mod tests {
1363    use super::*;
1364
1365    /// Channel-full branch: when the request channel has no spare capacity,
1366    /// `try_enqueue_rollback` returns `false` instead of blocking.
1367    #[tokio::test]
1368    async fn try_enqueue_rollback_returns_false_when_channel_full() {
1369        let (tx, _rx) = mpsc::channel::<PipelineRequest>(2);
        // Fill the channel by reusing the same helper. With capacity = 2 the
        // third try_send should fail with `Full`; loop a bounded number of
        // times rather than hard-coding that count, and record the false return.
1373        let mut filled = false;
1374        for _ in 0..16 {
1375            if !try_enqueue_rollback(&tx) {
1376                filled = true;
1377                break;
1378            }
1379        }
1380        assert!(
1381            filled,
1382            "expected try_enqueue_rollback to eventually return false on a full channel"
1383        );
1384        assert!(
1385            !try_enqueue_rollback(&tx),
1386            "subsequent calls on a full channel must keep returning false"
1387        );
1388    }
1389
1390    /// Channel-closed branch: dropping the receiver makes `try_send` fail
1391    /// with `Closed`, which `try_enqueue_rollback` reports as `false`.
1392    #[tokio::test]
1393    async fn try_enqueue_rollback_returns_false_when_channel_closed() {
1394        let (tx, rx) = mpsc::channel::<PipelineRequest>(8);
1395        drop(rx);
1396        assert!(
1397            !try_enqueue_rollback(&tx),
1398            "try_enqueue_rollback must return false when the receiver has been dropped"
1399        );
1400    }
1401
1402    /// Happy path: with a live receiver and free capacity, the helper
1403    /// reports success and the receiver observes a queued request whose
1404    /// payload starts with the simple-query opcode `'Q'`.
1405    #[tokio::test]
1406    async fn try_enqueue_rollback_returns_true_and_enqueues_query() {
1407        let (tx, mut rx) = mpsc::channel::<PipelineRequest>(2);
1408        assert!(try_enqueue_rollback(&tx));
1409        let req = rx.recv().await.expect("request should be received");
1410        assert_eq!(
1411            req.messages.first().copied(),
1412            Some(b'Q'),
1413            "queued request should be a simple Query message"
1414        );
1415        // Body should mention ROLLBACK (text follows length prefix and is
1416        // null-terminated; just substring-search to keep the test simple).
1417        assert!(
1418            req.messages.windows(8).any(|w| w == b"ROLLBACK"),
1419            "queued request should contain the ROLLBACK statement text"
1420        );
1421    }
1422}