// pg_wired/async_conn.rs
//! Async split sender/receiver connection.
//! Inspired by hsqlx's PgWire.Async architecture.
//!
//! A single TCP connection is shared by many concurrent handler tasks.
//! The writer task coalesces messages from multiple requests into one write().
//! The reader task parses responses and dispatches them to waiting handlers via FIFO.

use std::collections::VecDeque;
use std::sync::Arc;

use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, Mutex};

use crate::connection::WireConn;
use crate::error::PgWireError;
use crate::protocol::backend;
use crate::protocol::frontend;
use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
20
// ---------------------------------------------------------------------------
// Request types
// ---------------------------------------------------------------------------
24
/// A request to execute on the connection. Internal plumbing between the
/// public `submit` / `submit_batch` API and the writer task.
pub(crate) struct PipelineRequest {
    /// Pre-encoded frontend protocol messages, written to the socket verbatim.
    pub(crate) messages: BytesMut,
    /// Strategy for gathering the backend messages this request produces.
    pub(crate) collector: ResponseCollector,
    /// One-shot channel completed with the request's final outcome.
    pub(crate) response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
32
/// How to collect response messages for a request.
///
/// Stored in the pending FIFO alongside the request; presumably interpreted
/// by the reader task when matching backend messages to the request at the
/// head of the queue (reader task body not visible here — confirm).
#[allow(dead_code)]
#[non_exhaustive]
pub enum ResponseCollector {
    /// Collect DataRows until ReadyForQuery (for SELECT queries).
    Rows,
    /// Just drain until ReadyForQuery (for setup commands like BEGIN, SET ROLE).
    Drain,
    /// Stream rows one at a time via channels. Sends header first, then individual rows.
    Stream {
        /// One-shot channel for the row description (sent once before any rows).
        header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
        /// Bounded channel for individual rows; closed on completion or error.
        row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    },
    /// COPY IN: after receiving CopyInResponse, send the provided data then CopyDone.
    CopyIn {
        /// The data to send after CopyInResponse.
        data: Vec<u8>,
    },
    /// COPY OUT: collect CopyData messages until CopyDone.
    CopyOut,
}
56
/// Response from a pipeline request.
///
/// COPY paths reuse the `Rows` variant: `copy_out` unpacks each `RawRow`
/// cell as a raw COPY data chunk rather than a table row.
#[non_exhaustive]
pub enum PipelineResponse {
    /// A query that produced a row set (`SELECT`, `RETURNING`, etc.).
    Rows {
        /// Column metadata from RowDescription (empty if no RowDescription received).
        fields: Vec<crate::protocol::types::FieldDescription>,
        /// Row data.
        rows: Vec<RawRow>,
        /// CommandComplete tag (e.g. "SELECT 3", "INSERT 0 1").
        command_tag: String,
    },
    /// A statement that produced no row set (e.g., `BEGIN`, `SET ROLE`,
    /// non-RETURNING DML).
    Done,
}
73
/// Metadata sent at the start of a streaming response.
///
/// Delivered once over `ResponseCollector::Stream::header_tx` before any row.
#[derive(Debug, Clone)]
pub struct StreamHeader {
    /// Column descriptions (name, OID, format) for the streamed result set.
    pub fields: Vec<crate::protocol::types::FieldDescription>,
}
80
/// A single streamed row, delivered over `ResponseCollector::Stream::row_tx`.
pub type StreamedRow = RawRow;
83
// ---------------------------------------------------------------------------
// Async connection
// ---------------------------------------------------------------------------
87
/// A shared async connection that multiplexes requests from many tasks.
pub struct AsyncConn {
    // Bounded queue feeding the writer task; every public API enqueues here.
    request_tx: mpsc::Sender<PipelineRequest>,
    // SQL text -> (server-side statement name, allocation counter). The
    // counter orders entries for LRU eviction in `cache_statement`.
    stmt_cache: std::sync::Mutex<std::collections::HashMap<String, (String, u64)>>,
    // Monotonic source of unique statement names ("s0", "s1", ...).
    stmt_counter: std::sync::atomic::AtomicU64,
    // Set true at construction; cleared by the writer/reader tasks on exit.
    alive: Arc<std::sync::atomic::AtomicBool>,
    // Backend key data from the startup handshake; used by `cancel_token`.
    backend_pid: i32,
    backend_secret: i32,
    // Server peer address (empty string if it could not be determined).
    addr: String,
    /// Channel for async notifications received during query execution.
    /// Notifications are NOT silently dropped, they're forwarded here.
    #[allow(dead_code)]
    notification_tx: mpsc::Sender<crate::protocol::types::BackendMsg>,
    // Receiver half, handed out at most once via `take_notification_receiver`.
    notification_rx: std::sync::Mutex<Option<mpsc::Receiver<crate::protocol::types::BackendMsg>>>,
    /// True if any operation since the last `take_state_mutated()` may have
    /// left the session in a non-default state (open transaction, SET
    /// without LOCAL, advisory lock, temp table, prepared cursor, etc.).
    ///
    /// Set explicitly by callers issuing such operations
    /// (`mark_state_mutated`), and automatically by the reader task whenever
    /// ReadyForQuery reports a non-idle transaction status. Callers that
    /// only run self-contained Bind/Execute/Sync queries leave this `false`,
    /// allowing pools to skip an expensive DISCARD ALL on return.
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    /// True if a caller has declared the connection unusable (e.g., a
    /// transaction was dropped without commit/rollback, leaving the session
    /// in an unknown state). The reader/writer tasks may still be running, so
    /// `is_alive()` is true, but pools should treat the connection as broken
    /// and destroy it on return rather than reusing it.
    broken: Arc<std::sync::atomic::AtomicBool>,
    /// Cumulative count of asynchronous notifications dropped because the
    /// notification channel was full or no application code was draining it.
    /// Surfaced via [`AsyncConn::dropped_notifications`] so callers can detect
    /// missed `LISTEN` events.
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
}
124
125impl std::fmt::Debug for AsyncConn {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.debug_struct("AsyncConn")
128            .field("addr", &self.addr)
129            .field("backend_pid", &self.backend_pid)
130            .field("alive", &self.is_alive())
131            .finish()
132    }
133}
134
135impl AsyncConn {
136    /// Check if the connection is still alive (writer/reader tasks running).
137    pub fn is_alive(&self) -> bool {
138        self.alive.load(std::sync::atomic::Ordering::Relaxed)
139    }
140
141    /// Backend process ID assigned by the server.
142    pub fn backend_pid(&self) -> i32 {
143        self.backend_pid
144    }
145
146    /// Server address this connection is talking to.
147    pub fn addr(&self) -> &str {
148        &self.addr
149    }
150
151    /// Produce a cancel token for the running session on this connection.
152    pub fn cancel_token(&self) -> crate::cancel::CancelToken {
153        crate::cancel::CancelToken::new(self.addr.clone(), self.backend_pid, self.backend_secret)
154    }
155
156    /// Mark the connection as having mutated session state since the last
157    /// reset. Pools call `take_state_mutated()` on return to decide whether
158    /// to issue `DISCARD ALL`. Callers issuing `BEGIN`, `SET` (without
159    /// `LOCAL`), advisory locks, temp tables, etc., should call this before
160    /// submitting.
161    pub fn mark_state_mutated(&self) {
162        self.state_mutated
163            .store(true, std::sync::atomic::Ordering::Release);
164    }
165
166    /// Atomically read and clear the state-mutated flag. Returns the
167    /// previous value: `true` means the caller should issue a reset.
168    pub fn take_state_mutated(&self) -> bool {
169        self.state_mutated
170            .swap(false, std::sync::atomic::Ordering::AcqRel)
171    }
172
173    /// Read the state-mutated flag without clearing it.
174    pub fn is_state_mutated(&self) -> bool {
175        self.state_mutated
176            .load(std::sync::atomic::Ordering::Acquire)
177    }
178
179    /// Mark the connection as broken. The reader/writer tasks may still be
180    /// running, but the session is in an indeterminate state (for example,
181    /// a transaction was dropped without commit or rollback) and the
182    /// connection must not be reused. Pool integrations check
183    /// [`AsyncConn::is_broken`] on return and destroy the connection
184    /// instead of returning it to the idle set.
185    pub fn mark_broken(&self) {
186        self.broken
187            .store(true, std::sync::atomic::Ordering::Release);
188    }
189
190    /// True if the connection has been declared broken by a caller via
191    /// [`AsyncConn::mark_broken`]. Independent of [`AsyncConn::is_alive`],
192    /// which only reflects whether the reader/writer tasks are still running.
193    pub fn is_broken(&self) -> bool {
194        self.broken.load(std::sync::atomic::Ordering::Acquire)
195    }
196
197    /// Test-only helper that flips the `alive` flag to `false` without
198    /// actually exiting the writer task. Used by pg-wired's own tests and
199    /// by downstream crates' integration tests (e.g. resolute) to exercise
200    /// the dead-conn branch of [`AsyncConn::enqueue_rollback`] (and any
201    /// other code that gates on `is_alive`) without racing against the
202    /// real task-exit timing. Not part of the stable API: the `__` prefix
203    /// and `#[doc(hidden)]` mark this as off-limits for production use.
204    #[doc(hidden)]
205    pub fn __force_mark_dead_for_test(&self) {
206        self.alive
207            .store(false, std::sync::atomic::Ordering::Release);
208    }
209
210    /// Fire-and-forget enqueue of a `ROLLBACK` simple-query, intended to be
211    /// callable from a synchronous `Drop`. Returns `true` if the request was
212    /// queued on the writer task, `false` if the connection is not alive or
213    /// the channel was full/closed (in which case the caller should fall
214    /// back to [`AsyncConn::mark_broken`] so the connection is discarded
215    /// by the pool).
216    ///
217    /// PostgreSQL accepts `ROLLBACK` from any in-transaction state — including
218    /// the aborted state (`25P02`) that a failed query leaves behind — so this
219    /// reliably restores the session to idle. The response is drained and
220    /// discarded; ordering on the writer queue is preserved, so any
221    /// subsequent request (e.g., the pool's `DISCARD ALL` reset) sees a clean
222    /// connection.
223    pub fn enqueue_rollback(&self) -> bool {
224        if !self.is_alive() {
225            return false;
226        }
227        try_enqueue_rollback(&self.request_tx)
228    }
229}
230
231/// Inner helper for [`AsyncConn::enqueue_rollback`]: encodes a `ROLLBACK`
232/// simple-query and tries to push it onto the writer's request channel.
233/// Extracted so the channel-full and channel-closed branches can be unit
234/// tested without instantiating a real `AsyncConn`.
235fn try_enqueue_rollback(request_tx: &mpsc::Sender<PipelineRequest>) -> bool {
236    let mut buf = BytesMut::with_capacity(16);
237    frontend::encode_message(&FrontendMsg::Query(b"ROLLBACK"), &mut buf);
238    let (tx, _rx) = oneshot::channel();
239    request_tx
240        .try_send(PipelineRequest {
241            messages: buf,
242            collector: ResponseCollector::Drain,
243            response_tx: tx,
244        })
245        .is_ok()
246}
247
/// A queued request's collection strategy plus its completion channel,
/// held in the pending FIFO shared between the writer and reader tasks
/// (see `AsyncConn::new`).
struct PendingResponse {
    // How the reader gathers this request's backend messages.
    collector: ResponseCollector,
    // Completed exactly once with the request's final outcome.
    response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
}
252
253impl AsyncConn {
    /// Create a new async connection from a raw WireConn.
    /// Spawns writer and reader tasks.
    ///
    /// The two tasks share: a bounded request queue (writer input), a FIFO
    /// of `PendingResponse` entries (writer pushes, reader pops), a `Notify`
    /// for FIFO handoff, and the `alive` flag — either task clears `alive`
    /// when it exits, so `is_alive()` reflects both halves.
    pub fn new(conn: WireConn) -> Self {
        let backend_pid = conn.pid;
        let backend_secret = conn.secret;
        // Extract peer address before consuming the stream.
        let addr = conn
            .stream
            .peer_addr()
            .map(|a| a.to_string())
            .unwrap_or_default();

        // Channel capacities: 4096 buffered notifications, 256 in-flight
        // pipeline requests (overflow back-pressures `submit`).
        let (notification_tx, notification_rx) = mpsc::channel(4096);
        let (request_tx, request_rx) = mpsc::channel::<PipelineRequest>(256);
        let pending: Arc<Mutex<VecDeque<PendingResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
        let pending_notify = Arc::new(tokio::sync::Notify::new());
        let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
        let state_mutated = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let broken = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let dropped_notifications = Arc::new(std::sync::atomic::AtomicU64::new(0));

        // Split the stream into independently-owned read/write halves, one
        // per task.
        let (stream_read, stream_write) = tokio::io::split(conn.into_stream());

        // Spawn writer task — sets alive=false on exit.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive = Arc::clone(&alive);
            tokio::spawn(async move {
                writer_task(request_rx, stream_write, pending, pending_notify).await;
                alive.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired writer task exited");
            });
        }

        // Spawn reader task — sets alive=false on exit.
        {
            let pending = Arc::clone(&pending);
            let pending_notify = Arc::clone(&pending_notify);
            let alive_clone = Arc::clone(&alive);
            let state_mutated = Arc::clone(&state_mutated);
            let ntf_tx = notification_tx.clone();
            let dropped = Arc::clone(&dropped_notifications);
            tokio::spawn(async move {
                reader_task(
                    stream_read,
                    pending,
                    pending_notify,
                    ntf_tx,
                    state_mutated,
                    dropped,
                )
                .await;
                alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!("pg-wired reader task exited");
            });
        }

        Self {
            request_tx,
            stmt_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
            stmt_counter: std::sync::atomic::AtomicU64::new(0),
            alive,
            backend_pid,
            backend_secret,
            addr,
            notification_tx,
            notification_rx: std::sync::Mutex::new(Some(notification_rx)),
            state_mutated,
            broken,
            dropped_notifications,
        }
    }
327
328    /// Cumulative number of `NotificationResponse` messages this connection
329    /// has discarded since it was created.
330    ///
331    /// Notifications are dropped when (a) the application has not called
332    /// [`AsyncConn::take_notification_receiver`] yet, or (b) the receiver is
333    /// not draining fast enough and the bounded channel fills up. Compare
334    /// successive readings to detect missed `LISTEN` events.
335    pub fn dropped_notifications(&self) -> u64 {
336        self.dropped_notifications
337            .load(std::sync::atomic::Ordering::Relaxed)
338    }
339
340    /// Take the notification receiver. Call once to get a channel that
341    /// receives `NotificationResponse` messages that arrive during queries.
342    pub fn take_notification_receiver(
343        &self,
344    ) -> Option<mpsc::Receiver<crate::protocol::types::BackendMsg>> {
345        self.notification_rx
346            .lock()
347            .ok()
348            .and_then(|mut guard| guard.take())
349    }
350
351    /// Look up or allocate a statement name.
352    ///
353    /// Cache hit: returns the cached name with `needs_parse=false`. The
354    /// caller submits only `Bind/Execute/Sync`.
355    ///
356    /// Cache miss: allocates a fresh, unique name from the connection's
357    /// statement counter and returns `(name, needs_parse=true)`. The name
358    /// is NOT yet published in the cache: the caller MUST include a
359    /// `Parse` for the new name in the same atomic submit as
360    /// `Bind/Execute/Sync` (so the Parse runs inside whatever
361    /// role-switched transaction the caller has framed, e.g. `BEGIN; SET
362    /// LOCAL ROLE …; …`), and then call [`Self::cache_statement`] to
363    /// publish the name only after the Parse has succeeded on the wire.
364    ///
365    /// Why publish-after-success: an earlier version pre-queued the
366    /// Parse as its own writer request and published the cache entry
367    /// up-front to avoid a race where a concurrent caller saw the
368    /// cached name and submitted a Bind-only request that races ahead
369    /// of the Parse. That eliminated the race, but ran the Parse
370    /// outside any transaction, under the connection's persistent role
371    /// (e.g. PostgREST's `authenticator`). SQL that references objects
372    /// only reachable after `SET LOCAL ROLE` to a user role failed
373    /// with `42501 permission denied` during Parse, while every
374    /// subsequent Bind for the same name failed with `26000: prepared
375    /// statement "sN" does not exist`. Publishing only after a
376    /// successful Parse keeps caching role-correct: each first-time
377    /// concurrent caller pays for its own Parse (rather than sharing a
378    /// pre-queued one), and `cache_statement` uses first-publisher-wins
379    /// semantics so the losing names become session-bounded orphans
380    /// (bounded by the 256-entry LRU on this connection).
381    pub fn lookup_or_alloc(&self, sql: &str, _param_oids: &[u32]) -> (Vec<u8>, bool) {
382        let cache = match self.stmt_cache.lock() {
383            Ok(c) => c,
384            Err(poisoned) => poisoned.into_inner(),
385        };
386        if let Some((name, _)) = cache.get(sql) {
387            return (name.as_bytes().to_vec(), false);
388        }
389        // Allocate a unique name. Counters never collide, so concurrent
390        // misses get distinct names. The cache stays empty for `sql`
391        // until the caller calls `cache_statement` after a successful
392        // Parse.
393        let n = self
394            .stmt_counter
395            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
396        let name = format!("s{n}");
397        (name.into_bytes(), true)
398    }
399
    /// Publish a freshly Parsed statement in the cache so subsequent
    /// lookups for the same SQL skip the Parse step.
    ///
    /// Called by the high-level `exec_*` helpers (and any external
    /// caller of [`Self::lookup_or_alloc`]) after the writer submit that
    /// included `Parse` for `name` returned successfully. Skipping this
    /// step doesn't cause correctness problems; the next lookup just
    /// misses and re-Parses.
    ///
    /// First-publisher-wins: if another concurrent miss already
    /// published a different name for the same SQL, that name stays in
    /// the cache and the caller's name becomes a server-side orphan
    /// (cleaned up at session end; bounded by LRU eviction during the
    /// session).
    ///
    /// LRU eviction: when the cache reaches its 256-entry cap, the
    /// oldest entry by counter is removed and a `Close + Sync` is
    /// fire-and-forget queued to free the server-side prepared
    /// statement.
    pub fn cache_statement(&self, sql: &str, name: &[u8]) {
        // Statement names are ASCII ("sN"); a non-UTF-8 name is unexpected,
        // so skip caching rather than panic.
        let Ok(name_str) = std::str::from_utf8(name) else {
            return;
        };
        // Recover the allocation counter from the "sN" name for LRU
        // ordering. Names in another format fall back to the current
        // counter value, i.e. they are treated as the newest entry.
        let counter = name_str
            .strip_prefix('s')
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or_else(|| self.stmt_counter.load(std::sync::atomic::Ordering::Relaxed));
        let mut cache = match self.stmt_cache.lock() {
            Ok(c) => c,
            Err(poisoned) => poisoned.into_inner(),
        };
        // First-publisher-wins: keep the existing entry; the caller's name
        // becomes a session-bounded orphan on the server.
        if cache.contains_key(sql) {
            return;
        }
        if cache.len() >= 256 {
            // At capacity: evict the entry with the smallest counter and
            // fire-and-forget a Close+Sync so the server frees the prepared
            // statement. A failed try_send is ignored — worst case the
            // orphaned statement lives until session end.
            if let Some((oldest_key, oldest_name)) = cache
                .iter()
                .min_by_key(|(_, (_, counter))| *counter)
                .map(|(k, (name, _))| (k.clone(), name.clone()))
            {
                cache.remove(&oldest_key);
                let mut close_buf = BytesMut::with_capacity(32);
                frontend::encode_message(
                    &FrontendMsg::Close {
                        kind: b'S',
                        name: oldest_name.as_bytes(),
                    },
                    &mut close_buf,
                );
                frontend::encode_message(&FrontendMsg::Sync, &mut close_buf);
                let (tx, _rx) = oneshot::channel();
                let _ = self.request_tx.try_send(PipelineRequest {
                    messages: close_buf,
                    collector: ResponseCollector::Drain,
                    response_tx: tx,
                });
            }
        }
        cache.insert(sql.to_string(), (name_str.to_string(), counter));
    }
460
461    /// Execute COPY FROM STDIN: sends the COPY command, then data in chunks, then CopyDone.
462    /// Returns the number of rows copied (from CommandComplete tag).
463    ///
464    /// Data is sent in chunks of up to 1MB to avoid buffering the entire payload
465    /// in a single BytesMut. For small payloads (< 1MB), this is a single write.
466    pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, PgWireError> {
467        use crate::protocol::types::FrontendMsg;
468        const CHUNK_SIZE: usize = 1024 * 1024; // 1MB chunks
469
470        // Build the message buffer: Query + chunked CopyData + CopyDone.
471        let mut buf = BytesMut::with_capacity(copy_sql.len() + data.len().min(CHUNK_SIZE) + 64);
472        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
473
474        // Send data in chunks to avoid a single huge allocation.
475        for chunk in data.chunks(CHUNK_SIZE) {
476            frontend::encode_message(&FrontendMsg::CopyData(chunk), &mut buf);
477        }
478        // Empty data is valid (0 rows copied).
479        frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
480
481        let resp = self
482            .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
483            .await?;
484        match resp {
485            PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
486            PipelineResponse::Done => Ok(0),
487        }
488    }
489
490    /// Execute COPY FROM STDIN with streaming: sends the COPY command, then
491    /// reads data from an async reader in chunks, avoiding buffering the entire
492    /// payload in memory.
493    ///
494    /// ```no_run
495    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
496    /// # let conn: pg_wired::AsyncConn = unimplemented!();
497    /// use tokio::fs::File;
498    /// let file = File::open("data.csv").await?;
499    /// let _count = conn.copy_in_stream("COPY users FROM STDIN WITH (FORMAT csv)", file).await?;
500    /// # Ok(()) }
501    /// ```
502    pub async fn copy_in_stream<R: tokio::io::AsyncRead + Unpin>(
503        &self,
504        copy_sql: &str,
505        mut reader: R,
506    ) -> Result<u64, PgWireError> {
507        use tokio::io::AsyncReadExt;
508        const CHUNK_SIZE: usize = 1024 * 1024; // 1MB chunks
509
510        // Send the COPY command.
511        let mut buf = BytesMut::with_capacity(copy_sql.len() + 16);
512        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
513
514        // Read and send data in chunks.
515        let mut chunk = vec![0u8; CHUNK_SIZE];
516        loop {
517            let n = reader.read(&mut chunk).await?;
518            if n == 0 {
519                break;
520            }
521            frontend::encode_message(&FrontendMsg::CopyData(&chunk[..n]), &mut buf);
522        }
523        frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
524
525        let resp = self
526            .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
527            .await?;
528        match resp {
529            PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
530            PipelineResponse::Done => Ok(0),
531        }
532    }
533
534    /// Execute COPY TO STDOUT: sends the COPY command, collects all CopyData.
535    pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, PgWireError> {
536        use crate::protocol::types::FrontendMsg;
537        let mut buf = BytesMut::new();
538        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
539
540        let resp = self.submit(buf, ResponseCollector::CopyOut).await?;
541        match resp {
542            PipelineResponse::Rows { rows, .. } => {
543                // For CopyOut, we reuse the Rows variant but each `RawRow` carries
544                // one cell which is the raw COPY data chunk (see `collect_copy_out`).
545                let mut result = Vec::new();
546                for row in rows {
547                    for data in row.iter().flatten() {
548                        result.extend_from_slice(data);
549                    }
550                }
551                Ok(result)
552            }
553            PipelineResponse::Done => Ok(Vec::new()),
554        }
555    }
556
557    /// Evict a SQL statement from the cache, forcing re-parse on next use.
558    /// Used for prepared statement invalidation after schema changes.
559    pub fn invalidate_statement(&self, sql: &str) {
560        let mut cache = match self.stmt_cache.lock() {
561            Ok(c) => c,
562            Err(poisoned) => poisoned.into_inner(),
563        };
564        cache.remove(sql);
565    }
566
567    /// Clear the entire statement cache. Must be called after `DISCARD ALL`
568    /// which destroys server-side prepared statements.
569    pub fn clear_statement_cache(&self) {
570        let mut cache = match self.stmt_cache.lock() {
571            Ok(c) => c,
572            Err(poisoned) => poisoned.into_inner(),
573        };
574        cache.clear();
575    }
576
    /// Execute a pipelined transaction with automatic statement caching.
    ///
    /// `setup_sql` frames the transaction (e.g. `BEGIN; SET LOCAL ROLE …`),
    /// `query_sql` is the parameterized statement, and `params`/`param_oids`
    /// are its Bind values and parameter type OIDs.
    ///
    /// On a successful Parse the new statement name is published in the
    /// cache via [`Self::cache_statement`]. If a cached statement turns
    /// out to be invalid (PG error 26000 or 0A000), the cache entry is
    /// evicted and the transaction is retried once with a fresh Parse.
    /// This handles schema changes invalidating cached plans after their
    /// initial Parse.
    pub async fn exec_transaction(
        &self,
        setup_sql: &str,
        query_sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
    ) -> Result<Vec<RawRow>, PgWireError> {
        let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql, param_oids);
        match self
            .pipeline_transaction(
                setup_sql,
                query_sql,
                params,
                param_oids,
                &stmt_name,
                needs_parse,
            )
            .await
        {
            Ok(rows) => {
                // Publish the name only after the Parse succeeded on the
                // wire (see `lookup_or_alloc` for why).
                if needs_parse {
                    self.cache_statement(query_sql, &stmt_name);
                }
                Ok(rows)
            }
            // Retry only when a *cached* statement went stale — a failure on
            // a fresh Parse would just fail again.
            Err(PgWireError::Pg(ref pg_err))
                if !needs_parse && is_stale_statement_error(pg_err) =>
            {
                tracing::debug!(
                    sql = query_sql,
                    "prepared statement invalidated — re-parsing in transaction"
                );
                self.invalidate_statement(query_sql);
                let (stmt_name, _) = self.lookup_or_alloc(query_sql, param_oids);
                let result = self
                    .pipeline_transaction(
                        setup_sql, query_sql, params, param_oids, &stmt_name, true,
                    )
                    .await;
                if result.is_ok() {
                    self.cache_statement(query_sql, &stmt_name);
                }
                result
            }
            Err(e) => Err(e),
        }
    }
632
633    /// Execute a parameterized query with automatic statement caching.
634    /// If a cached statement is invalidated by a schema change (PG error 26000
635    /// or 0A000), automatically evicts the cache entry, re-parses, and retries once.
636    pub async fn exec_query(
637        &self,
638        sql: &str,
639        params: &[Option<&[u8]>],
640        param_oids: &[u32],
641    ) -> Result<Vec<RawRow>, PgWireError> {
642        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
643        match self
644            .query(sql, params, param_oids, &stmt_name, needs_parse)
645            .await
646        {
647            Ok(rows) => {
648                if needs_parse {
649                    self.cache_statement(sql, &stmt_name);
650                }
651                Ok(rows)
652            }
653            Err(PgWireError::Pg(ref pg_err))
654                if !needs_parse && is_stale_statement_error(pg_err) =>
655            {
656                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
657                self.invalidate_statement(sql);
658                let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
659                let result = self.query(sql, params, param_oids, &stmt_name, true).await;
660                if result.is_ok() {
661                    self.cache_statement(sql, &stmt_name);
662                }
663                result
664            }
665            Err(e) => Err(e),
666        }
667    }
668
669    /// Maximum time to wait for a response from the reader task.
670    /// Prevents hanging forever if the reader/writer task dies mid-request.
671    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
672
673    /// Submit a request to the connection. Returns a future that resolves
674    /// when the response is available. Times out after 5 minutes to prevent
675    /// hanging forever if the reader/writer task dies.
676    pub async fn submit(
677        &self,
678        messages: BytesMut,
679        collector: ResponseCollector,
680    ) -> Result<PipelineResponse, PgWireError> {
681        let (response_tx, response_rx) = oneshot::channel();
682        let req = PipelineRequest {
683            messages,
684            collector,
685            response_tx,
686        };
687        self.request_tx
688            .send(req)
689            .await
690            .map_err(|_| PgWireError::ConnectionClosed)?;
691        match tokio::time::timeout(Self::REQUEST_TIMEOUT, response_rx).await {
692            Ok(Ok(result)) => result,
693            Ok(Err(_)) => Err(PgWireError::ConnectionClosed),
694            Err(_elapsed) => {
695                tracing::error!(
696                    "request timed out after {:?} — reader/writer task may be dead",
697                    Self::REQUEST_TIMEOUT
698                );
699                Err(PgWireError::ConnectionClosed)
700            }
701        }
702    }
703
704    /// Submit a batch of requests in FIFO order. All requests are queued
705    /// before any response is awaited, so the writer task sees them together
706    /// and coalesces them into a single write() syscall. The server then
707    /// pipelines the N responses back-to-back, giving one network round-trip
708    /// for all N queries.
709    ///
710    /// Returns one `Result<PipelineResponse, PgWireError>` per input item,
711    /// in the same order. The outer `Result` fails only if queueing fails
712    /// (channel closed). Each inner `Result` reflects the per-query outcome.
713    pub async fn submit_batch(
714        &self,
715        items: Vec<(BytesMut, ResponseCollector)>,
716    ) -> Result<Vec<Result<PipelineResponse, PgWireError>>, PgWireError> {
717        let mut receivers = Vec::with_capacity(items.len());
718        for (messages, collector) in items {
719            let (response_tx, response_rx) = oneshot::channel();
720            self.request_tx
721                .send(PipelineRequest {
722                    messages,
723                    collector,
724                    response_tx,
725                })
726                .await
727                .map_err(|_| PgWireError::ConnectionClosed)?;
728            receivers.push(response_rx);
729        }
730        let mut results = Vec::with_capacity(receivers.len());
731        for rx in receivers {
732            match tokio::time::timeout(Self::REQUEST_TIMEOUT, rx).await {
733                Ok(Ok(r)) => results.push(r),
734                Ok(Err(_)) => results.push(Err(PgWireError::ConnectionClosed)),
735                Err(_) => {
736                    tracing::error!(
737                        "submit_batch request timed out after {:?}",
738                        Self::REQUEST_TIMEOUT
739                    );
740                    results.push(Err(PgWireError::ConnectionClosed));
741                }
742            }
743        }
744        Ok(results)
745    }
746
747    /// Send a Terminate message to the server and wait for the writer/reader
748    /// tasks to exit. After this returns, the connection is unusable; further
749    /// calls fail with `ConnectionClosed`. Idempotent: calling `close` on an
750    /// already-closed connection is a no-op and returns `Ok`.
751    pub async fn close(&self) -> Result<(), PgWireError> {
752        if !self.is_alive() {
753            return Ok(());
754        }
755        let mut buf = BytesMut::with_capacity(5);
756        frontend::encode_message(&FrontendMsg::Terminate, &mut buf);
757        // Submit Terminate through the writer so ordering is preserved wrt
758        // any in-flight requests ahead of us. The server replies with nothing
759        // and closes the socket, so we expect `ConnectionClosed` back from
760        // the drain collector — treat that as a successful close.
761        match self.submit(buf, ResponseCollector::Drain).await {
762            Ok(_) | Err(PgWireError::ConnectionClosed) => Ok(()),
763            Err(PgWireError::Io(e)) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
764            Err(e) => Err(e),
765        }
766    }
767
768    /// Submit a streaming request. Returns the column header and an mpsc receiver
769    /// that yields rows one at a time.
770    pub async fn submit_stream(
771        &self,
772        messages: BytesMut,
773        row_buffer: usize,
774    ) -> Result<
775        (
776            StreamHeader,
777            mpsc::Receiver<Result<StreamedRow, PgWireError>>,
778        ),
779        PgWireError,
780    > {
781        let (header_tx, header_rx) = oneshot::channel();
782        let (row_tx, row_rx) = mpsc::channel(row_buffer);
783        let (response_tx, _response_rx) = oneshot::channel();
784        let req = PipelineRequest {
785            messages,
786            collector: ResponseCollector::Stream { header_tx, row_tx },
787            response_tx,
788        };
789        self.request_tx
790            .send(req)
791            .await
792            .map_err(|_| PgWireError::ConnectionClosed)?;
793        let header = header_rx
794            .await
795            .map_err(|_| PgWireError::ConnectionClosed)??;
796        Ok((header, row_rx))
797    }
798
    /// Execute a pipelined transaction:
    /// setup (simple query) + data query (extended protocol) + COMMIT (simple query)
    /// All coalesced into one TCP write. Binary-safe parameterized data query.
    ///
    /// The three pieces are submitted as three separate `PipelineRequest`s so
    /// each gets its own collector, but they are queued back-to-back so the
    /// writer task can coalesce them into a single write() syscall.
    ///
    /// NOTE(review): unlike `submit`, the three response awaits below have no
    /// `REQUEST_TIMEOUT` guard; a wedged reader task that keeps the oneshot
    /// senders alive could block this call indefinitely — confirm intended.
    pub async fn pipeline_transaction(
        &self,
        setup_sql: &str,
        query_sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        stmt_name: &[u8],
        needs_parse: bool,
    ) -> Result<Vec<RawRow>, PgWireError> {
        let mut buf = BytesMut::with_capacity(1024);

        // 1. Simple query for setup (BEGIN + SET ROLE + set_config).
        frontend::encode_message(&FrontendMsg::Query(setup_sql.as_bytes()), &mut buf);

        // Submit setup as Drain — we don't care about its response data.
        let setup_msgs = buf.split();

        // 2. Extended query for data. All params go in text format; the
        // `.max(1)` padding is defensive only — slicing `[..0]` of an empty
        // Vec would also be valid.
        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
        let result_fmts = [FormatCode::Text];

        // Parse is skipped when the statement is already prepared on this
        // session (statement caching handled by the caller).
        if needs_parse {
            frontend::encode_message(
                &FrontendMsg::Parse {
                    name: stmt_name,
                    sql: query_sql.as_bytes(),
                    param_oids,
                },
                &mut buf,
            );
        }

        frontend::encode_message(
            &FrontendMsg::Bind {
                portal: b"",
                statement: stmt_name,
                param_formats: &text_fmts[..params.len()],
                params,
                result_formats: &result_fmts,
            },
            &mut buf,
        );

        // max_rows: 0 requests all rows from the portal.
        frontend::encode_message(
            &FrontendMsg::Execute {
                portal: b"",
                max_rows: 0,
            },
            &mut buf,
        );

        frontend::encode_message(&FrontendMsg::Sync, &mut buf);

        let data_msgs = buf.split();

        // 3. Simple query for COMMIT — in its own buffer so each request
        // carries exactly the bytes that produce its ReadyForQuery response.
        let mut commit_buf = BytesMut::with_capacity(32);
        frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut commit_buf);

        // Submit all three as separate requests with different collectors.
        // They'll be coalesced by the writer into one write() syscall.
        let (setup_tx, setup_rx) = oneshot::channel();
        let (data_tx, data_rx) = oneshot::channel();
        let (commit_tx, commit_rx) = oneshot::channel();

        // Send all three requests to the writer channel.
        // The writer drains the channel and writes them all at once.
        self.request_tx
            .send(PipelineRequest {
                messages: setup_msgs,
                collector: ResponseCollector::Drain,
                response_tx: setup_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: data_msgs,
                collector: ResponseCollector::Rows,
                response_tx: data_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        self.request_tx
            .send(PipelineRequest {
                messages: commit_buf,
                collector: ResponseCollector::Drain,
                response_tx: commit_tx,
            })
            .await
            .map_err(|_| PgWireError::ConnectionClosed)?;

        // Wait for all responses. An early `?` return on setup/data failure
        // simply drops the later receivers; the reader still consumes those
        // response cycles and its sends fail harmlessly.
        setup_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        let data_resp = data_rx.await.map_err(|_| PgWireError::ConnectionClosed)??;

        commit_rx
            .await
            .map_err(|_| PgWireError::ConnectionClosed)??;

        match data_resp {
            PipelineResponse::Rows { rows, .. } => Ok(rows),
            // Rows collector shouldn't produce Done, but map it to "no rows".
            PipelineResponse::Done => Ok(Vec::new()),
        }
    }
913
914    /// Execute a simple parameterized query (no transaction).
915    pub async fn query(
916        &self,
917        sql: &str,
918        params: &[Option<&[u8]>],
919        param_oids: &[u32],
920        stmt_name: &[u8],
921        needs_parse: bool,
922    ) -> Result<Vec<RawRow>, PgWireError> {
923        self.query_with_formats(sql, params, param_oids, &[], &[], stmt_name, needs_parse)
924            .await
925    }
926
927    /// Execute a parameterized query with explicit per-param and per-result
928    /// format codes (text = 0, binary = 1).
929    ///
930    /// `param_formats` is interpreted per PostgreSQL wire protocol rules:
931    /// - empty: all params are text
932    /// - length 1: the single code applies to every param
933    /// - length N (== params.len()): one code per param
934    ///
935    /// Same rules apply to `result_formats` for output columns (empty → all
936    /// text; single code → applies to all columns; per-column list otherwise).
937    #[allow(clippy::too_many_arguments)]
938    pub async fn query_with_formats(
939        &self,
940        sql: &str,
941        params: &[Option<&[u8]>],
942        param_oids: &[u32],
943        param_formats: &[FormatCode],
944        result_formats: &[FormatCode],
945        stmt_name: &[u8],
946        needs_parse: bool,
947    ) -> Result<Vec<RawRow>, PgWireError> {
948        let mut buf = BytesMut::with_capacity(512);
949
950        // Default to all-text if caller passes empty slices.
951        let text_param_fmts: Vec<FormatCode>;
952        let param_fmts_slice: &[FormatCode] = if param_formats.is_empty() {
953            text_param_fmts = vec![FormatCode::Text; params.len().max(1)];
954            &text_param_fmts[..params.len()]
955        } else {
956            param_formats
957        };
958        let default_result_fmts = [FormatCode::Text];
959        let result_fmts_slice: &[FormatCode] = if result_formats.is_empty() {
960            &default_result_fmts
961        } else {
962            result_formats
963        };
964
965        if needs_parse {
966            frontend::encode_message(
967                &FrontendMsg::Parse {
968                    name: stmt_name,
969                    sql: sql.as_bytes(),
970                    param_oids,
971                },
972                &mut buf,
973            );
974        }
975
976        frontend::encode_message(
977            &FrontendMsg::Bind {
978                portal: b"",
979                statement: stmt_name,
980                param_formats: param_fmts_slice,
981                params,
982                result_formats: result_fmts_slice,
983            },
984            &mut buf,
985        );
986
987        frontend::encode_message(
988            &FrontendMsg::Execute {
989                portal: b"",
990                max_rows: 0,
991            },
992            &mut buf,
993        );
994
995        frontend::encode_message(&FrontendMsg::Sync, &mut buf);
996
997        let resp = self.submit(buf, ResponseCollector::Rows).await?;
998        match resp {
999            PipelineResponse::Rows { rows, .. } => Ok(rows),
1000            PipelineResponse::Done => Ok(Vec::new()),
1001        }
1002    }
1003
    /// Variant of `exec_query` with per-param and per-result format codes.
    /// See `query_with_formats` for format code semantics.
    ///
    /// Uses the session statement cache: `lookup_or_alloc` returns either a
    /// cached statement name (`needs_parse == false`) or a fresh one that must
    /// be parsed. On success of a fresh parse the statement is cached. If a
    /// *cached* statement turns out stale on the server (see
    /// `is_stale_statement_error`), it is invalidated and the query retried
    /// exactly once with a forced re-parse.
    pub async fn exec_query_with_formats(
        &self,
        sql: &str,
        params: &[Option<&[u8]>],
        param_oids: &[u32],
        param_formats: &[FormatCode],
        result_formats: &[FormatCode],
    ) -> Result<Vec<RawRow>, PgWireError> {
        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql, param_oids);
        match self
            .query_with_formats(
                sql,
                params,
                param_oids,
                param_formats,
                result_formats,
                &stmt_name,
                needs_parse,
            )
            .await
        {
            Ok(rows) => {
                // Only cache after the server confirmed the Parse succeeded.
                if needs_parse {
                    self.cache_statement(sql, &stmt_name);
                }
                Ok(rows)
            }
            // Retry path: only for errors on a *cached* statement that signal
            // server-side invalidation — never for a freshly parsed one.
            Err(PgWireError::Pg(ref pg_err))
                if !needs_parse && is_stale_statement_error(pg_err) =>
            {
                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
                self.invalidate_statement(sql);
                let (stmt_name, _) = self.lookup_or_alloc(sql, param_oids);
                let result = self
                    .query_with_formats(
                        sql,
                        params,
                        param_oids,
                        param_formats,
                        result_formats,
                        &stmt_name,
                        true,
                    )
                    .await;
                if result.is_ok() {
                    self.cache_statement(sql, &stmt_name);
                }
                result
            }
            Err(e) => Err(e),
        }
    }
1058}
1059
1060// ---------------------------------------------------------------------------
1061// Writer task
1062// ---------------------------------------------------------------------------
1063
/// Writer half of the connection: owns the TCP write side.
///
/// Receives `PipelineRequest`s from the shared channel, coalesces every
/// request already queued into one buffer, performs a single
/// `write_all` + `flush`, and only then publishes the requests' collectors
/// onto the `pending` FIFO consumed by the reader task. Publishing *after*
/// the write guarantees a failed write yields immediate errors to callers
/// instead of FIFO entries the reader would wait on forever.
async fn writer_task(
    mut rx: mpsc::Receiver<PipelineRequest>,
    mut stream: tokio::io::WriteHalf<crate::tls::MaybeTlsStream>,
    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
    pending_notify: Arc<tokio::sync::Notify>,
) {
    // Reused across iterations to avoid reallocating per batch.
    let mut write_buf = BytesMut::with_capacity(8192);

    loop {
        // Wait for the first request.
        let first = match rx.recv().await {
            Some(req) => req,
            None => {
                // Channel closed — drain any pending responses with ConnectionClosed.
                drain_pending_on_exit(&pending).await;
                return;
            }
        };

        // Drain any additional queued requests (batch coalescing).
        write_buf.clear();
        write_buf.extend_from_slice(&first.messages);

        let mut batch: Vec<PendingResponse> = vec![PendingResponse {
            collector: first.collector,
            response_tx: first.response_tx,
        }];

        // Non-blocking drain of all queued requests.
        while let Ok(req) = rx.try_recv() {
            write_buf.extend_from_slice(&req.messages);
            batch.push(PendingResponse {
                collector: req.collector,
                response_tx: req.response_tx,
            });
        }

        // ONE write() syscall for all coalesced messages.
        // Write BEFORE enqueuing pending responses — if the write fails,
        // we send errors to callers instead of leaving them hanging.
        let write_result = stream.write_all(&write_buf).await;
        let write_err = match write_result {
            // A successful write can still fail on flush (e.g. TLS close).
            Ok(_) => stream.flush().await.err(),
            Err(e) => Some(e),
        };

        if let Some(e) = write_err {
            tracing::error!("Writer error: {e}");
            let msg = e.to_string();
            for p in batch {
                let _ = p.response_tx.send(Err(PgWireError::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    msg.clone(),
                ))));
            }
            // Drain any already-pending responses so the reader doesn't hang.
            drain_pending_on_exit(&pending).await;
            return;
        }

        // Write succeeded — enqueue pending responses for the reader.
        {
            let mut pq = pending.lock().await;
            for p in batch {
                pq.push_back(p);
            }
        }
        // Wake the reader task to process the newly enqueued responses.
        // One notify_one per batch suffices: Notify stores a permit if the
        // reader isn't waiting, and the reader re-checks the queue before
        // each wait.
        pending_notify.notify_one();
    }
}
1135
1136/// On writer exit, drain all pending responses with ConnectionClosed errors
1137/// so callers don't wait for the 5-minute timeout.
1138async fn drain_pending_on_exit(pending: &Arc<Mutex<VecDeque<PendingResponse>>>) {
1139    let mut pq = pending.lock().await;
1140    while let Some(pr) = pq.pop_front() {
1141        let _ = pr.response_tx.send(Err(PgWireError::ConnectionClosed));
1142    }
1143}
1144
1145// ---------------------------------------------------------------------------
1146// Reader task
1147// ---------------------------------------------------------------------------
1148
/// Reader half of the connection: owns the TCP read side.
///
/// Pops one `PendingResponse` at a time from the FIFO shared with the writer
/// and parses exactly one response cycle (through ReadyForQuery) for it,
/// using the collector the caller chose. FIFO dispatch is sound because the
/// server answers pipelined requests strictly in submission order.
///
/// NOTE(review): this loop has no exit path — after a fatal read error each
/// subsequent pending entry just receives an Err, and with the queue empty
/// the task parks on `notified()` indefinitely. Presumably the owner aborts
/// the task when the connection is dropped; confirm against the spawn site.
async fn reader_task(
    mut stream: tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
    pending_notify: Arc<tokio::sync::Notify>,
    notification_tx: mpsc::Sender<BackendMsg>,
    state_mutated: Arc<std::sync::atomic::AtomicBool>,
    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
) {
    let mut recv_buf = BytesMut::with_capacity(32 * 1024);

    loop {
        // Wait for a pending response to become available.
        let pr = loop {
            {
                // Inner scope drops the lock before awaiting the notify.
                let mut pq = pending.lock().await;
                if let Some(pr) = pq.pop_front() {
                    break pr;
                }
            }
            // No pending — wait for the writer to signal.
            pending_notify.notified().await;
        };

        // Collect the response based on the collector type.
        let result = match pr.collector {
            ResponseCollector::Rows => {
                collect_rows(
                    &mut stream,
                    &mut recv_buf,
                    &notification_tx,
                    &state_mutated,
                    &dropped_notifications,
                )
                .await
            }
            ResponseCollector::Drain => {
                drain_until_ready(&mut stream, &mut recv_buf, Some(&state_mutated))
                    .await
                    .map(|_| PipelineResponse::Done)
            }
            ResponseCollector::Stream { header_tx, row_tx } => {
                // stream_rows reports errors through header_tx/row_tx itself,
                // so its outcome here is always Done.
                stream_rows(
                    &mut stream,
                    &mut recv_buf,
                    header_tx,
                    row_tx,
                    &notification_tx,
                    &state_mutated,
                    &dropped_notifications,
                )
                .await;
                Ok(PipelineResponse::Done)
            }
            ResponseCollector::CopyIn { .. } => {
                collect_copy_in_response(&mut stream, &mut recv_buf, &state_mutated).await
            }
            ResponseCollector::CopyOut => {
                collect_copy_out(&mut stream, &mut recv_buf, &state_mutated).await
            }
        };

        // Send the response back to the caller (ignore a dropped receiver).
        let _ = pr.response_tx.send(result);
    }
}
1214
/// Read one complete backend message from the stream, buffering partial
/// frames in `buf` across calls.
///
/// Returns `PgWireError::Protocol` for a malformed frame and
/// `PgWireError::ConnectionClosed` on EOF with no complete message buffered.
async fn read_msg(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
) -> Result<BackendMsg, PgWireError> {
    loop {
        if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
            return Ok(msg);
        }
        let n = stream.read_buf(buf).await?;
        if n == 0 {
            // EOF — try to parse any remaining data in the buffer before giving up.
            // This handles the case where the last message arrived just before the
            // connection closed and is already fully buffered.
            // NOTE(review): the buffer is unchanged since the parse attempt at
            // the top of this iteration, so this retry only matters if
            // parse_message is stateful — confirm; otherwise it is harmless
            // defensiveness.
            if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
                return Ok(msg);
            }
            return Err(PgWireError::ConnectionClosed);
        }
    }
}
1235
/// Record whether a ReadyForQuery status byte left session state behind.
///
/// `I` (idle) means the session is clean and the flag is left untouched.
/// Any other status — `T` (inside a transaction block) or `E` (failed
/// transaction) — leaves session state that needs DISCARD ALL, so the
/// `state_mutated` flag is raised.
fn note_rfq_status(status: u8, state_mutated: &std::sync::atomic::AtomicBool) {
    let is_idle = status == b'I';
    if !is_idle {
        state_mutated.store(true, std::sync::atomic::Ordering::Release);
    }
}
1244
/// Collect an entire result set for the `Rows` collector.
///
/// Accumulates DataRows (plus the RowDescription fields and the
/// CommandComplete tag) until ReadyForQuery, which terminates the response
/// cycle. On ErrorResponse the rest of the cycle is drained so the stream
/// stays aligned for the next pending response. Async notifications that
/// arrive mid-cycle are forwarded on `notification_tx` (best-effort).
async fn collect_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) -> Result<PipelineResponse, PgWireError> {
    let mut rows = Vec::new();
    let mut fields = Vec::new();
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            BackendMsg::DataRow(row) => rows.push(row),
            BackendMsg::RowDescription { fields: f } => fields = f,
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            BackendMsg::ReadyForQuery { status } => {
                // End of cycle — record non-idle session state, then return.
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields,
                    rows,
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                // Realign the stream to the next ReadyForQuery before failing.
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            msg @ BackendMsg::NotificationResponse { .. } => {
                // Forward notification instead of dropping.
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // Protocol bookkeeping messages carry no data we need here.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1290
1291async fn drain_until_ready(
1292    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1293    buf: &mut BytesMut,
1294    state_mutated: Option<&std::sync::atomic::AtomicBool>,
1295) -> Result<(), PgWireError> {
1296    loop {
1297        let msg = read_msg(stream, buf).await?;
1298        if let BackendMsg::ReadyForQuery { status } = msg {
1299            if let Some(sm) = state_mutated {
1300                note_rfq_status(status, sm);
1301            }
1302            return Ok(());
1303        }
1304        if let BackendMsg::ErrorResponse { ref fields } = msg {
1305            tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
1306        }
1307    }
1308}
1309
/// Stream rows one at a time, sending header first, then individual rows.
///
/// `header_tx` is consumed exactly once: on the first DataRow (so callers
/// get the header as soon as columns are known), or on CommandComplete /
/// ReadyForQuery for zero-row results, or with an Err on failure. After the
/// header is sent, errors flow through `row_tx` instead. If the row receiver
/// is dropped mid-stream, the remaining cycle is drained to keep the
/// connection aligned. This function never returns an error — all outcomes
/// are reported through the two channels.
async fn stream_rows(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
    row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
    notification_tx: &mpsc::Sender<BackendMsg>,
    state_mutated: &std::sync::atomic::AtomicBool,
    dropped_notifications: &std::sync::atomic::AtomicU64,
) {
    // Option wrapper lets us `.take()` to enforce the send-once invariant.
    let mut header_tx = Some(header_tx);
    let mut fields = Vec::new();
    loop {
        let msg = match read_msg(stream, buf).await {
            Ok(msg) => msg,
            Err(e) => {
                // Route the error to whichever channel the caller is awaiting.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(e));
                } else {
                    let _ = row_tx.send(Err(e)).await;
                }
                return;
            }
        };
        match msg {
            BackendMsg::RowDescription { fields: f } => {
                fields = f;
            }
            BackendMsg::DataRow(row) => {
                // First row: release the header (clone — later rows may still
                // need `fields` if CommandComplete/RFQ re-checks it).
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: fields.clone(),
                    }));
                }
                if row_tx.send(Ok(row)).await.is_err() {
                    // Receiver dropped — discard the rest of the cycle so the
                    // connection stays usable.
                    let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                    return;
                }
            }
            BackendMsg::CommandComplete { .. } => {
                // Zero-row result: the header has not been sent yet.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
            }
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                // Safety net in case CommandComplete was never seen.
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Ok(StreamHeader {
                        fields: std::mem::take(&mut fields),
                    }));
                }
                return;
            }
            BackendMsg::ErrorResponse { fields: err } => {
                if let Some(htx) = header_tx.take() {
                    let _ = htx.send(Err(PgWireError::Pg(err)));
                } else {
                    let _ = row_tx.send(Err(PgWireError::Pg(err))).await;
                }
                let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
                return;
            }
            msg @ BackendMsg::NotificationResponse { .. } => {
                #[allow(clippy::collapsible_match)]
                if notification_tx.try_send(msg).is_err() {
                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    tracing::warn!("notification channel full, dropping notification");
                }
            }
            // Protocol bookkeeping messages carry no data we need here.
            BackendMsg::ParseComplete
            | BackendMsg::BindComplete
            | BackendMsg::NoData
            | BackendMsg::PortalSuspended
            | BackendMsg::NoticeResponse { .. }
            | BackendMsg::EmptyQueryResponse => {}
            _ => {}
        }
    }
}
1391
/// Handle COPY IN response: skip CopyInResponse, wait for CommandComplete + ReadyForQuery.
/// The actual CopyData + CopyDone were pre-buffered in the write, so PG processes them.
///
/// Returns a `Rows` response with empty fields/rows; only the command tag
/// (e.g. "COPY 123") carries information for the caller.
async fn collect_copy_in_response(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    state_mutated: &std::sync::atomic::AtomicBool,
) -> Result<PipelineResponse, PgWireError> {
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            // Expected acknowledgement; the COPY payload was already written.
            BackendMsg::CopyInResponse { .. } => {}
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields: Vec::new(),
                    rows: Vec::new(),
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                // Realign the stream to the next ReadyForQuery before failing.
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            _ => {}
        }
    }
}
1421
/// Collect COPY OUT data: CopyOutResponse → CopyData* → CopyDone → CommandComplete → ReadyForQuery.
///
/// Each CopyData payload is wrapped in a `RawRow` (via `from_full_body`) so
/// the result reuses the `Rows` response shape; `fields` stays empty because
/// COPY OUT has no RowDescription.
async fn collect_copy_out(
    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
    buf: &mut BytesMut,
    state_mutated: &std::sync::atomic::AtomicBool,
) -> Result<PipelineResponse, PgWireError> {
    let mut data_chunks: Vec<RawRow> = Vec::new();
    let mut command_tag = String::new();
    loop {
        let msg = read_msg(stream, buf).await?;
        match msg {
            BackendMsg::CopyOutResponse { .. } => {}
            BackendMsg::CopyData { data } => {
                let body = bytes::Bytes::from(data);
                data_chunks.push(RawRow::from_full_body(body));
            }
            BackendMsg::CopyDone => {}
            BackendMsg::CommandComplete { tag } => command_tag = tag,
            BackendMsg::ReadyForQuery { status } => {
                note_rfq_status(status, state_mutated);
                return Ok(PipelineResponse::Rows {
                    fields: Vec::new(),
                    rows: data_chunks,
                    command_tag,
                });
            }
            BackendMsg::ErrorResponse { fields } => {
                // Realign the stream to the next ReadyForQuery before failing.
                drain_until_ready(stream, buf, Some(state_mutated)).await?;
                return Err(PgWireError::Pg(fields));
            }
            _ => {}
        }
    }
}
1456
1457/// Check if a PostgreSQL error indicates a stale/invalidated prepared statement.
1458/// Error codes: 26000 (invalid_sql_statement_name), 0A000 (feature_not_supported
1459/// — used when cached plan changes type).
1460fn is_stale_statement_error(err: &crate::protocol::types::PgError) -> bool {
1461    matches!(err.code.as_str(), "26000" | "0A000")
1462}
1463
/// Extract the row count from a COPY command tag ("COPY 123" → 123).
/// Non-COPY or unparseable tags yield 0.
fn parse_copy_count(tag: &str) -> u64 {
    match tag.strip_prefix("COPY ") {
        Some(count) => count.parse().unwrap_or(0),
        None => 0,
    }
}
1470
// Extension to WireConn to extract the underlying stream.
impl WireConn {
    /// Consume the handshake-phase connection and take ownership of the
    /// underlying (possibly TLS-wrapped) stream so it can be split into
    /// reader/writer halves for the async tasks.
    pub(crate) fn into_stream(self) -> crate::tls::MaybeTlsStream {
        self.stream
    }
}
1477
/// Unit tests for `try_enqueue_rollback` (defined elsewhere in this file):
/// the non-blocking helper that queues a ROLLBACK simple query onto the
/// request channel, reporting `false` when the channel is full or closed.
#[cfg(test)]
mod tests {
    use super::*;

    /// Channel-full branch: when the request channel has no spare capacity,
    /// `try_enqueue_rollback` returns `false` instead of blocking.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_full() {
        let (tx, _rx) = mpsc::channel::<PipelineRequest>(2);
        // Fill the channel by reusing the same helper. capacity=2 plus the
        // single buffered slot tokio reserves means we may need to push
        // until try_send fails; loop until we observe the false return.
        let mut filled = false;
        for _ in 0..16 {
            if !try_enqueue_rollback(&tx) {
                filled = true;
                break;
            }
        }
        assert!(
            filled,
            "expected try_enqueue_rollback to eventually return false on a full channel"
        );
        assert!(
            !try_enqueue_rollback(&tx),
            "subsequent calls on a full channel must keep returning false"
        );
    }

    /// Channel-closed branch: dropping the receiver makes `try_send` fail
    /// with `Closed`, which `try_enqueue_rollback` reports as `false`.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_false_when_channel_closed() {
        let (tx, rx) = mpsc::channel::<PipelineRequest>(8);
        drop(rx);
        assert!(
            !try_enqueue_rollback(&tx),
            "try_enqueue_rollback must return false when the receiver has been dropped"
        );
    }

    /// Happy path: with a live receiver and free capacity, the helper
    /// reports success and the receiver observes a queued request whose
    /// payload starts with the simple-query opcode `'Q'`.
    #[tokio::test]
    async fn try_enqueue_rollback_returns_true_and_enqueues_query() {
        let (tx, mut rx) = mpsc::channel::<PipelineRequest>(2);
        assert!(try_enqueue_rollback(&tx));
        let req = rx.recv().await.expect("request should be received");
        assert_eq!(
            req.messages.first().copied(),
            Some(b'Q'),
            "queued request should be a simple Query message"
        );
        // Body should mention ROLLBACK (text follows length prefix and is
        // null-terminated; just substring-search to keep the test simple).
        assert!(
            req.messages.windows(8).any(|w| w == b"ROLLBACK"),
            "queued request should contain the ROLLBACK statement text"
        );
    }
}