Skip to main content

pg_wired/
async_conn.rs

1//! Async split sender/receiver connection.
2//! Inspired by hsqlx's PgWire.Async architecture.
3//!
4//! A single TCP connection is shared by many concurrent handler tasks.
5//! The writer task coalesces messages from multiple requests into one write().
6//! The reader task parses responses and dispatches them to waiting handlers via FIFO.
7
8use std::collections::VecDeque;
9use std::sync::Arc;
10
11use bytes::BytesMut;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::sync::{mpsc, oneshot, Mutex};
14
15use crate::connection::WireConn;
16use crate::error::PgWireError;
17use crate::protocol::backend;
18use crate::protocol::frontend;
19use crate::protocol::types::{BackendMsg, FormatCode, FrontendMsg, RawRow};
20
21// ---------------------------------------------------------------------------
22// Request types
23// ---------------------------------------------------------------------------
24
25/// A request to execute on the connection. Internal plumbing between the
26/// public `submit` / `submit_batch` API and the writer task.
27pub(crate) struct PipelineRequest {
28    pub(crate) messages: BytesMut,
29    pub(crate) collector: ResponseCollector,
30    pub(crate) response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
31}
32
33/// How to collect response messages for a request.
34#[allow(dead_code)]
35#[non_exhaustive]
36pub enum ResponseCollector {
37    /// Collect DataRows until ReadyForQuery (for SELECT queries).
38    Rows,
39    /// Just drain until ReadyForQuery (for setup commands like BEGIN, SET ROLE).
40    Drain,
41    /// Stream rows one at a time via channels. Sends header first, then individual rows.
42    Stream {
43        /// One-shot channel for the row description (sent once before any rows).
44        header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
45        /// Bounded channel for individual rows; closed on completion or error.
46        row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
47    },
48    /// COPY IN: after receiving CopyInResponse, send the provided data then CopyDone.
49    CopyIn {
50        /// The data to send after CopyInResponse.
51        data: Vec<u8>,
52    },
53    /// COPY OUT: collect CopyData messages until CopyDone.
54    CopyOut,
55}
56
57/// Response from a pipeline request.
58#[non_exhaustive]
59pub enum PipelineResponse {
60    /// A query that produced a row set (`SELECT`, `RETURNING`, etc.).
61    Rows {
62        /// Column metadata from RowDescription (empty if no RowDescription received).
63        fields: Vec<crate::protocol::types::FieldDescription>,
64        /// Row data.
65        rows: Vec<RawRow>,
66        /// CommandComplete tag (e.g. "SELECT 3", "INSERT 0 1").
67        command_tag: String,
68    },
69    /// A statement that produced no row set (e.g., `BEGIN`, `SET ROLE`,
70    /// non-RETURNING DML).
71    Done,
72}
73
74/// Metadata sent at the start of a streaming response.
75#[derive(Debug, Clone)]
76pub struct StreamHeader {
77    /// Column descriptions (name, OID, format) for the streamed result set.
78    pub fields: Vec<crate::protocol::types::FieldDescription>,
79}
80
81/// A single streamed row.
82pub type StreamedRow = RawRow;
83
84// ---------------------------------------------------------------------------
85// Async connection
86// ---------------------------------------------------------------------------
87
88/// A shared async connection that multiplexes requests from many tasks.
89pub struct AsyncConn {
90    request_tx: mpsc::Sender<PipelineRequest>,
91    stmt_cache: std::sync::Mutex<std::collections::HashMap<String, (String, u64)>>,
92    stmt_counter: std::sync::atomic::AtomicU64,
93    alive: Arc<std::sync::atomic::AtomicBool>,
94    backend_pid: i32,
95    backend_secret: i32,
96    addr: String,
97    /// Channel for async notifications received during query execution.
98    /// Notifications are NOT silently dropped, they're forwarded here.
99    #[allow(dead_code)]
100    notification_tx: mpsc::Sender<crate::protocol::types::BackendMsg>,
101    notification_rx: std::sync::Mutex<Option<mpsc::Receiver<crate::protocol::types::BackendMsg>>>,
102    /// True if any operation since the last `take_state_mutated()` may have
103    /// left the session in a non-default state (open transaction, SET
104    /// without LOCAL, advisory lock, temp table, prepared cursor, etc.).
105    ///
106    /// Set explicitly by callers issuing such operations
107    /// (`mark_state_mutated`), and automatically by the reader task whenever
108    /// ReadyForQuery reports a non-idle transaction status. Callers that
109    /// only run self-contained Bind/Execute/Sync queries leave this `false`,
110    /// allowing pools to skip an expensive DISCARD ALL on return.
111    state_mutated: Arc<std::sync::atomic::AtomicBool>,
112    /// Cumulative count of asynchronous notifications dropped because the
113    /// notification channel was full or no application code was draining it.
114    /// Surfaced via [`AsyncConn::dropped_notifications`] so callers can detect
115    /// missed `LISTEN` events.
116    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
117}
118
119impl std::fmt::Debug for AsyncConn {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        f.debug_struct("AsyncConn")
122            .field("addr", &self.addr)
123            .field("backend_pid", &self.backend_pid)
124            .field("alive", &self.is_alive())
125            .finish()
126    }
127}
128
129impl AsyncConn {
130    /// Check if the connection is still alive (writer/reader tasks running).
131    pub fn is_alive(&self) -> bool {
132        self.alive.load(std::sync::atomic::Ordering::Relaxed)
133    }
134
135    /// Backend process ID assigned by the server.
136    pub fn backend_pid(&self) -> i32 {
137        self.backend_pid
138    }
139
140    /// Server address this connection is talking to.
141    pub fn addr(&self) -> &str {
142        &self.addr
143    }
144
145    /// Produce a cancel token for the running session on this connection.
146    pub fn cancel_token(&self) -> crate::cancel::CancelToken {
147        crate::cancel::CancelToken::new(self.addr.clone(), self.backend_pid, self.backend_secret)
148    }
149
150    /// Mark the connection as having mutated session state since the last
151    /// reset. Pools call `take_state_mutated()` on return to decide whether
152    /// to issue `DISCARD ALL`. Callers issuing `BEGIN`, `SET` (without
153    /// `LOCAL`), advisory locks, temp tables, etc., should call this before
154    /// submitting.
155    pub fn mark_state_mutated(&self) {
156        self.state_mutated
157            .store(true, std::sync::atomic::Ordering::Release);
158    }
159
160    /// Atomically read and clear the state-mutated flag. Returns the
161    /// previous value: `true` means the caller should issue a reset.
162    pub fn take_state_mutated(&self) -> bool {
163        self.state_mutated
164            .swap(false, std::sync::atomic::Ordering::AcqRel)
165    }
166
167    /// Read the state-mutated flag without clearing it.
168    pub fn is_state_mutated(&self) -> bool {
169        self.state_mutated
170            .load(std::sync::atomic::Ordering::Acquire)
171    }
172}
173
174struct PendingResponse {
175    collector: ResponseCollector,
176    response_tx: oneshot::Sender<Result<PipelineResponse, PgWireError>>,
177}
178
179impl AsyncConn {
180    /// Create a new async connection from a raw WireConn.
181    /// Spawns writer and reader tasks.
182    pub fn new(conn: WireConn) -> Self {
183        let backend_pid = conn.pid;
184        let backend_secret = conn.secret;
185        // Extract peer address before consuming the stream.
186        let addr = conn
187            .stream
188            .peer_addr()
189            .map(|a| a.to_string())
190            .unwrap_or_default();
191
192        let (notification_tx, notification_rx) = mpsc::channel(4096);
193        let (request_tx, request_rx) = mpsc::channel::<PipelineRequest>(256);
194        let pending: Arc<Mutex<VecDeque<PendingResponse>>> = Arc::new(Mutex::new(VecDeque::new()));
195        let pending_notify = Arc::new(tokio::sync::Notify::new());
196        let alive = Arc::new(std::sync::atomic::AtomicBool::new(true));
197        let state_mutated = Arc::new(std::sync::atomic::AtomicBool::new(false));
198        let dropped_notifications = Arc::new(std::sync::atomic::AtomicU64::new(0));
199
200        let (stream_read, stream_write) = tokio::io::split(conn.into_stream());
201
202        // Spawn writer task — sets alive=false on exit.
203        {
204            let pending = Arc::clone(&pending);
205            let pending_notify = Arc::clone(&pending_notify);
206            let alive = Arc::clone(&alive);
207            tokio::spawn(async move {
208                writer_task(request_rx, stream_write, pending, pending_notify).await;
209                alive.store(false, std::sync::atomic::Ordering::Relaxed);
210                tracing::warn!("pg-wired writer task exited");
211            });
212        }
213
214        // Spawn reader task — sets alive=false on exit.
215        {
216            let pending = Arc::clone(&pending);
217            let pending_notify = Arc::clone(&pending_notify);
218            let alive_clone = Arc::clone(&alive);
219            let state_mutated = Arc::clone(&state_mutated);
220            let ntf_tx = notification_tx.clone();
221            let dropped = Arc::clone(&dropped_notifications);
222            tokio::spawn(async move {
223                reader_task(
224                    stream_read,
225                    pending,
226                    pending_notify,
227                    ntf_tx,
228                    state_mutated,
229                    dropped,
230                )
231                .await;
232                alive_clone.store(false, std::sync::atomic::Ordering::Relaxed);
233                tracing::warn!("pg-wired reader task exited");
234            });
235        }
236
237        Self {
238            request_tx,
239            stmt_cache: std::sync::Mutex::new(std::collections::HashMap::new()),
240            stmt_counter: std::sync::atomic::AtomicU64::new(0),
241            alive,
242            backend_pid,
243            backend_secret,
244            addr,
245            notification_tx,
246            notification_rx: std::sync::Mutex::new(Some(notification_rx)),
247            state_mutated,
248            dropped_notifications,
249        }
250    }
251
252    /// Cumulative number of `NotificationResponse` messages this connection
253    /// has discarded since it was created.
254    ///
255    /// Notifications are dropped when (a) the application has not called
256    /// [`AsyncConn::take_notification_receiver`] yet, or (b) the receiver is
257    /// not draining fast enough and the bounded channel fills up. Compare
258    /// successive readings to detect missed `LISTEN` events.
259    pub fn dropped_notifications(&self) -> u64 {
260        self.dropped_notifications
261            .load(std::sync::atomic::Ordering::Relaxed)
262    }
263
264    /// Take the notification receiver. Call once to get a channel that
265    /// receives `NotificationResponse` messages that arrive during queries.
266    pub fn take_notification_receiver(
267        &self,
268    ) -> Option<mpsc::Receiver<crate::protocol::types::BackendMsg>> {
269        self.notification_rx
270            .lock()
271            .ok()
272            .and_then(|mut guard| guard.take())
273    }
274
275    /// Look up or allocate a statement name.
276    /// Uses an LRU-style eviction: when the cache is full, the oldest entry
277    /// (by insertion order / counter) is removed and a Close message is queued
278    /// to free the server-side prepared statement.
279    pub fn lookup_or_alloc(&self, sql: &str) -> (Vec<u8>, bool) {
280        let mut cache = match self.stmt_cache.lock() {
281            Ok(c) => c,
282            Err(poisoned) => poisoned.into_inner(),
283        };
284        if let Some((name, _)) = cache.get(sql) {
285            return (name.as_bytes().to_vec(), false);
286        }
287        // LRU eviction: remove the entry with the lowest counter value
288        // and send a Close message to free the server-side prepared statement.
289        if cache.len() >= 256 {
290            if let Some((oldest_key, oldest_name)) = cache
291                .iter()
292                .min_by_key(|(_, (_, counter))| *counter)
293                .map(|(k, (name, _))| (k.clone(), name.clone()))
294            {
295                cache.remove(&oldest_key);
296                // Queue a Close + Sync to free the server-side statement.
297                // Fire-and-forget: if the channel is full or closed, skip it.
298                let mut close_buf = BytesMut::with_capacity(32);
299                frontend::encode_message(
300                    &FrontendMsg::Close {
301                        kind: b'S',
302                        name: oldest_name.as_bytes(),
303                    },
304                    &mut close_buf,
305                );
306                frontend::encode_message(&FrontendMsg::Sync, &mut close_buf);
307                let (tx, _rx) = oneshot::channel();
308                let _ = self.request_tx.try_send(PipelineRequest {
309                    messages: close_buf,
310                    collector: ResponseCollector::Drain,
311                    response_tx: tx,
312                });
313            }
314        }
315        let n = self
316            .stmt_counter
317            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
318        let name = format!("s{n}");
319        cache.insert(sql.to_string(), (name.clone(), n));
320        (name.into_bytes(), true)
321    }
322
323    /// Execute COPY FROM STDIN: sends the COPY command, then data in chunks, then CopyDone.
324    /// Returns the number of rows copied (from CommandComplete tag).
325    ///
326    /// Data is sent in chunks of up to 1MB to avoid buffering the entire payload
327    /// in a single BytesMut. For small payloads (< 1MB), this is a single write.
328    pub async fn copy_in(&self, copy_sql: &str, data: &[u8]) -> Result<u64, PgWireError> {
329        use crate::protocol::types::FrontendMsg;
330        const CHUNK_SIZE: usize = 1024 * 1024; // 1MB chunks
331
332        // Build the message buffer: Query + chunked CopyData + CopyDone.
333        let mut buf = BytesMut::with_capacity(copy_sql.len() + data.len().min(CHUNK_SIZE) + 64);
334        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
335
336        // Send data in chunks to avoid a single huge allocation.
337        for chunk in data.chunks(CHUNK_SIZE) {
338            frontend::encode_message(&FrontendMsg::CopyData(chunk), &mut buf);
339        }
340        // Empty data is valid (0 rows copied).
341        frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
342
343        let resp = self
344            .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
345            .await?;
346        match resp {
347            PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
348            PipelineResponse::Done => Ok(0),
349        }
350    }
351
352    /// Execute COPY FROM STDIN with streaming: sends the COPY command, then
353    /// reads data from an async reader in chunks, avoiding buffering the entire
354    /// payload in memory.
355    ///
356    /// ```no_run
357    /// # async fn _doctest() -> Result<(), Box<dyn std::error::Error>> {
358    /// # let conn: pg_wired::AsyncConn = unimplemented!();
359    /// use tokio::fs::File;
360    /// let file = File::open("data.csv").await?;
361    /// let _count = conn.copy_in_stream("COPY users FROM STDIN WITH (FORMAT csv)", file).await?;
362    /// # Ok(()) }
363    /// ```
364    pub async fn copy_in_stream<R: tokio::io::AsyncRead + Unpin>(
365        &self,
366        copy_sql: &str,
367        mut reader: R,
368    ) -> Result<u64, PgWireError> {
369        use tokio::io::AsyncReadExt;
370        const CHUNK_SIZE: usize = 1024 * 1024; // 1MB chunks
371
372        // Send the COPY command.
373        let mut buf = BytesMut::with_capacity(copy_sql.len() + 16);
374        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
375
376        // Read and send data in chunks.
377        let mut chunk = vec![0u8; CHUNK_SIZE];
378        loop {
379            let n = reader.read(&mut chunk).await?;
380            if n == 0 {
381                break;
382            }
383            frontend::encode_message(&FrontendMsg::CopyData(&chunk[..n]), &mut buf);
384        }
385        frontend::encode_message(&FrontendMsg::CopyDone, &mut buf);
386
387        let resp = self
388            .submit(buf, ResponseCollector::CopyIn { data: Vec::new() })
389            .await?;
390        match resp {
391            PipelineResponse::Rows { command_tag, .. } => Ok(parse_copy_count(&command_tag)),
392            PipelineResponse::Done => Ok(0),
393        }
394    }
395
396    /// Execute COPY TO STDOUT: sends the COPY command, collects all CopyData.
397    pub async fn copy_out(&self, copy_sql: &str) -> Result<Vec<u8>, PgWireError> {
398        use crate::protocol::types::FrontendMsg;
399        let mut buf = BytesMut::new();
400        frontend::encode_message(&FrontendMsg::Query(copy_sql.as_bytes()), &mut buf);
401
402        let resp = self.submit(buf, ResponseCollector::CopyOut).await?;
403        match resp {
404            PipelineResponse::Rows { rows, .. } => {
405                // For CopyOut, we reuse the Rows variant but each `RawRow` carries
406                // one cell which is the raw COPY data chunk (see `collect_copy_out`).
407                let mut result = Vec::new();
408                for row in rows {
409                    for data in row.iter().flatten() {
410                        result.extend_from_slice(data);
411                    }
412                }
413                Ok(result)
414            }
415            PipelineResponse::Done => Ok(Vec::new()),
416        }
417    }
418
419    /// Evict a SQL statement from the cache, forcing re-parse on next use.
420    /// Used for prepared statement invalidation after schema changes.
421    pub fn invalidate_statement(&self, sql: &str) {
422        let mut cache = match self.stmt_cache.lock() {
423            Ok(c) => c,
424            Err(poisoned) => poisoned.into_inner(),
425        };
426        cache.remove(sql);
427    }
428
429    /// Clear the entire statement cache. Must be called after `DISCARD ALL`
430    /// which destroys server-side prepared statements.
431    pub fn clear_statement_cache(&self) {
432        let mut cache = match self.stmt_cache.lock() {
433            Ok(c) => c,
434            Err(poisoned) => poisoned.into_inner(),
435        };
436        cache.clear();
437    }
438
439    /// Execute a pipelined transaction with automatic statement caching.
440    pub async fn exec_transaction(
441        &self,
442        setup_sql: &str,
443        query_sql: &str,
444        params: &[Option<&[u8]>],
445        param_oids: &[u32],
446    ) -> Result<Vec<RawRow>, PgWireError> {
447        let (stmt_name, needs_parse) = self.lookup_or_alloc(query_sql);
448        self.pipeline_transaction(
449            setup_sql,
450            query_sql,
451            params,
452            param_oids,
453            &stmt_name,
454            needs_parse,
455        )
456        .await
457    }
458
459    /// Execute a parameterized query with automatic statement caching.
460    /// If a cached statement is invalidated by a schema change (PG error 26000
461    /// or 0A000), automatically evicts the cache entry, re-parses, and retries once.
462    pub async fn exec_query(
463        &self,
464        sql: &str,
465        params: &[Option<&[u8]>],
466        param_oids: &[u32],
467    ) -> Result<Vec<RawRow>, PgWireError> {
468        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
469        match self
470            .query(sql, params, param_oids, &stmt_name, needs_parse)
471            .await
472        {
473            Ok(rows) => Ok(rows),
474            Err(PgWireError::Pg(ref pg_err))
475                if !needs_parse && is_stale_statement_error(pg_err) =>
476            {
477                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
478                self.invalidate_statement(sql);
479                let (stmt_name, _) = self.lookup_or_alloc(sql);
480                self.query(sql, params, param_oids, &stmt_name, true).await
481            }
482            Err(e) => Err(e),
483        }
484    }
485
/// Longest a caller waits for the reader task to answer a request (five
/// minutes), so a dead reader or writer task cannot block callers forever.
const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5 * 60);
489
490    /// Submit a request to the connection. Returns a future that resolves
491    /// when the response is available. Times out after 5 minutes to prevent
492    /// hanging forever if the reader/writer task dies.
493    pub async fn submit(
494        &self,
495        messages: BytesMut,
496        collector: ResponseCollector,
497    ) -> Result<PipelineResponse, PgWireError> {
498        let (response_tx, response_rx) = oneshot::channel();
499        let req = PipelineRequest {
500            messages,
501            collector,
502            response_tx,
503        };
504        self.request_tx
505            .send(req)
506            .await
507            .map_err(|_| PgWireError::ConnectionClosed)?;
508        match tokio::time::timeout(Self::REQUEST_TIMEOUT, response_rx).await {
509            Ok(Ok(result)) => result,
510            Ok(Err(_)) => Err(PgWireError::ConnectionClosed),
511            Err(_elapsed) => {
512                tracing::error!(
513                    "request timed out after {:?} — reader/writer task may be dead",
514                    Self::REQUEST_TIMEOUT
515                );
516                Err(PgWireError::ConnectionClosed)
517            }
518        }
519    }
520
521    /// Submit a batch of requests in FIFO order. All requests are queued
522    /// before any response is awaited, so the writer task sees them together
523    /// and coalesces them into a single write() syscall. The server then
524    /// pipelines the N responses back-to-back, giving one network round-trip
525    /// for all N queries.
526    ///
527    /// Returns one `Result<PipelineResponse, PgWireError>` per input item,
528    /// in the same order. The outer `Result` fails only if queueing fails
529    /// (channel closed). Each inner `Result` reflects the per-query outcome.
530    pub async fn submit_batch(
531        &self,
532        items: Vec<(BytesMut, ResponseCollector)>,
533    ) -> Result<Vec<Result<PipelineResponse, PgWireError>>, PgWireError> {
534        let mut receivers = Vec::with_capacity(items.len());
535        for (messages, collector) in items {
536            let (response_tx, response_rx) = oneshot::channel();
537            self.request_tx
538                .send(PipelineRequest {
539                    messages,
540                    collector,
541                    response_tx,
542                })
543                .await
544                .map_err(|_| PgWireError::ConnectionClosed)?;
545            receivers.push(response_rx);
546        }
547        let mut results = Vec::with_capacity(receivers.len());
548        for rx in receivers {
549            match tokio::time::timeout(Self::REQUEST_TIMEOUT, rx).await {
550                Ok(Ok(r)) => results.push(r),
551                Ok(Err(_)) => results.push(Err(PgWireError::ConnectionClosed)),
552                Err(_) => {
553                    tracing::error!(
554                        "submit_batch request timed out after {:?}",
555                        Self::REQUEST_TIMEOUT
556                    );
557                    results.push(Err(PgWireError::ConnectionClosed));
558                }
559            }
560        }
561        Ok(results)
562    }
563
564    /// Send a Terminate message to the server and wait for the writer/reader
565    /// tasks to exit. After this returns, the connection is unusable; further
566    /// calls fail with `ConnectionClosed`. Idempotent: calling `close` on an
567    /// already-closed connection is a no-op and returns `Ok`.
568    pub async fn close(&self) -> Result<(), PgWireError> {
569        if !self.is_alive() {
570            return Ok(());
571        }
572        let mut buf = BytesMut::with_capacity(5);
573        frontend::encode_message(&FrontendMsg::Terminate, &mut buf);
574        // Submit Terminate through the writer so ordering is preserved wrt
575        // any in-flight requests ahead of us. The server replies with nothing
576        // and closes the socket, so we expect `ConnectionClosed` back from
577        // the drain collector — treat that as a successful close.
578        match self.submit(buf, ResponseCollector::Drain).await {
579            Ok(_) | Err(PgWireError::ConnectionClosed) => Ok(()),
580            Err(PgWireError::Io(e)) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
581            Err(e) => Err(e),
582        }
583    }
584
585    /// Submit a streaming request. Returns the column header and an mpsc receiver
586    /// that yields rows one at a time.
587    pub async fn submit_stream(
588        &self,
589        messages: BytesMut,
590        row_buffer: usize,
591    ) -> Result<
592        (
593            StreamHeader,
594            mpsc::Receiver<Result<StreamedRow, PgWireError>>,
595        ),
596        PgWireError,
597    > {
598        let (header_tx, header_rx) = oneshot::channel();
599        let (row_tx, row_rx) = mpsc::channel(row_buffer);
600        let (response_tx, _response_rx) = oneshot::channel();
601        let req = PipelineRequest {
602            messages,
603            collector: ResponseCollector::Stream { header_tx, row_tx },
604            response_tx,
605        };
606        self.request_tx
607            .send(req)
608            .await
609            .map_err(|_| PgWireError::ConnectionClosed)?;
610        let header = header_rx
611            .await
612            .map_err(|_| PgWireError::ConnectionClosed)??;
613        Ok((header, row_rx))
614    }
615
616    /// Execute a pipelined transaction:
617    /// setup (simple query) + data query (extended protocol) + COMMIT (simple query)
618    /// All coalesced into one TCP write. Binary-safe parameterized data query.
619    pub async fn pipeline_transaction(
620        &self,
621        setup_sql: &str,
622        query_sql: &str,
623        params: &[Option<&[u8]>],
624        param_oids: &[u32],
625        stmt_name: &[u8],
626        needs_parse: bool,
627    ) -> Result<Vec<RawRow>, PgWireError> {
628        let mut buf = BytesMut::with_capacity(1024);
629
630        // 1. Simple query for setup (BEGIN + SET ROLE + set_config).
631        frontend::encode_message(&FrontendMsg::Query(setup_sql.as_bytes()), &mut buf);
632
633        // Submit setup as Drain — we don't care about its response data.
634        let setup_msgs = buf.split();
635
636        // 2. Extended query for data.
637        let text_fmts: Vec<FormatCode> = vec![FormatCode::Text; params.len().max(1)];
638        let result_fmts = [FormatCode::Text];
639
640        if needs_parse {
641            frontend::encode_message(
642                &FrontendMsg::Parse {
643                    name: stmt_name,
644                    sql: query_sql.as_bytes(),
645                    param_oids,
646                },
647                &mut buf,
648            );
649        }
650
651        frontend::encode_message(
652            &FrontendMsg::Bind {
653                portal: b"",
654                statement: stmt_name,
655                param_formats: &text_fmts[..params.len()],
656                params,
657                result_formats: &result_fmts,
658            },
659            &mut buf,
660        );
661
662        frontend::encode_message(
663            &FrontendMsg::Execute {
664                portal: b"",
665                max_rows: 0,
666            },
667            &mut buf,
668        );
669
670        frontend::encode_message(&FrontendMsg::Sync, &mut buf);
671
672        let data_msgs = buf.split();
673
674        // 3. Simple query for COMMIT — in its own buffer so each request
675        // carries exactly the bytes that produce its ReadyForQuery response.
676        let mut commit_buf = BytesMut::with_capacity(32);
677        frontend::encode_message(&FrontendMsg::Query(b"COMMIT"), &mut commit_buf);
678
679        // Submit all three as separate requests with different collectors.
680        // They'll be coalesced by the writer into one write() syscall.
681        let (setup_tx, setup_rx) = oneshot::channel();
682        let (data_tx, data_rx) = oneshot::channel();
683        let (commit_tx, commit_rx) = oneshot::channel();
684
685        // Send all three requests to the writer channel.
686        // The writer drains the channel and writes them all at once.
687        self.request_tx
688            .send(PipelineRequest {
689                messages: setup_msgs,
690                collector: ResponseCollector::Drain,
691                response_tx: setup_tx,
692            })
693            .await
694            .map_err(|_| PgWireError::ConnectionClosed)?;
695
696        self.request_tx
697            .send(PipelineRequest {
698                messages: data_msgs,
699                collector: ResponseCollector::Rows,
700                response_tx: data_tx,
701            })
702            .await
703            .map_err(|_| PgWireError::ConnectionClosed)?;
704
705        self.request_tx
706            .send(PipelineRequest {
707                messages: commit_buf,
708                collector: ResponseCollector::Drain,
709                response_tx: commit_tx,
710            })
711            .await
712            .map_err(|_| PgWireError::ConnectionClosed)?;
713
714        // Wait for all responses.
715        setup_rx
716            .await
717            .map_err(|_| PgWireError::ConnectionClosed)??;
718
719        let data_resp = data_rx.await.map_err(|_| PgWireError::ConnectionClosed)??;
720
721        commit_rx
722            .await
723            .map_err(|_| PgWireError::ConnectionClosed)??;
724
725        match data_resp {
726            PipelineResponse::Rows { rows, .. } => Ok(rows),
727            PipelineResponse::Done => Ok(Vec::new()),
728        }
729    }
730
731    /// Execute a simple parameterized query (no transaction).
732    pub async fn query(
733        &self,
734        sql: &str,
735        params: &[Option<&[u8]>],
736        param_oids: &[u32],
737        stmt_name: &[u8],
738        needs_parse: bool,
739    ) -> Result<Vec<RawRow>, PgWireError> {
740        self.query_with_formats(sql, params, param_oids, &[], &[], stmt_name, needs_parse)
741            .await
742    }
743
744    /// Execute a parameterized query with explicit per-param and per-result
745    /// format codes (text = 0, binary = 1).
746    ///
747    /// `param_formats` is interpreted per PostgreSQL wire protocol rules:
748    /// - empty: all params are text
749    /// - length 1: the single code applies to every param
750    /// - length N (== params.len()): one code per param
751    ///
752    /// Same rules apply to `result_formats` for output columns (empty → all
753    /// text; single code → applies to all columns; per-column list otherwise).
754    #[allow(clippy::too_many_arguments)]
755    pub async fn query_with_formats(
756        &self,
757        sql: &str,
758        params: &[Option<&[u8]>],
759        param_oids: &[u32],
760        param_formats: &[FormatCode],
761        result_formats: &[FormatCode],
762        stmt_name: &[u8],
763        needs_parse: bool,
764    ) -> Result<Vec<RawRow>, PgWireError> {
765        let mut buf = BytesMut::with_capacity(512);
766
767        // Default to all-text if caller passes empty slices.
768        let text_param_fmts: Vec<FormatCode>;
769        let param_fmts_slice: &[FormatCode] = if param_formats.is_empty() {
770            text_param_fmts = vec![FormatCode::Text; params.len().max(1)];
771            &text_param_fmts[..params.len()]
772        } else {
773            param_formats
774        };
775        let default_result_fmts = [FormatCode::Text];
776        let result_fmts_slice: &[FormatCode] = if result_formats.is_empty() {
777            &default_result_fmts
778        } else {
779            result_formats
780        };
781
782        if needs_parse {
783            frontend::encode_message(
784                &FrontendMsg::Parse {
785                    name: stmt_name,
786                    sql: sql.as_bytes(),
787                    param_oids,
788                },
789                &mut buf,
790            );
791        }
792
793        frontend::encode_message(
794            &FrontendMsg::Bind {
795                portal: b"",
796                statement: stmt_name,
797                param_formats: param_fmts_slice,
798                params,
799                result_formats: result_fmts_slice,
800            },
801            &mut buf,
802        );
803
804        frontend::encode_message(
805            &FrontendMsg::Execute {
806                portal: b"",
807                max_rows: 0,
808            },
809            &mut buf,
810        );
811
812        frontend::encode_message(&FrontendMsg::Sync, &mut buf);
813
814        let resp = self.submit(buf, ResponseCollector::Rows).await?;
815        match resp {
816            PipelineResponse::Rows { rows, .. } => Ok(rows),
817            PipelineResponse::Done => Ok(Vec::new()),
818        }
819    }
820
821    /// Variant of `exec_query` with per-param and per-result format codes.
822    /// See `query_with_formats` for format code semantics.
823    pub async fn exec_query_with_formats(
824        &self,
825        sql: &str,
826        params: &[Option<&[u8]>],
827        param_oids: &[u32],
828        param_formats: &[FormatCode],
829        result_formats: &[FormatCode],
830    ) -> Result<Vec<RawRow>, PgWireError> {
831        let (stmt_name, needs_parse) = self.lookup_or_alloc(sql);
832        match self
833            .query_with_formats(
834                sql,
835                params,
836                param_oids,
837                param_formats,
838                result_formats,
839                &stmt_name,
840                needs_parse,
841            )
842            .await
843        {
844            Ok(rows) => Ok(rows),
845            Err(PgWireError::Pg(ref pg_err))
846                if !needs_parse && is_stale_statement_error(pg_err) =>
847            {
848                tracing::debug!(sql = sql, "prepared statement invalidated — re-parsing");
849                self.invalidate_statement(sql);
850                let (stmt_name, _) = self.lookup_or_alloc(sql);
851                self.query_with_formats(
852                    sql,
853                    params,
854                    param_oids,
855                    param_formats,
856                    result_formats,
857                    &stmt_name,
858                    true,
859                )
860                .await
861            }
862            Err(e) => Err(e),
863        }
864    }
865}
866
867// ---------------------------------------------------------------------------
868// Writer task
869// ---------------------------------------------------------------------------
870
871async fn writer_task(
872    mut rx: mpsc::Receiver<PipelineRequest>,
873    mut stream: tokio::io::WriteHalf<crate::tls::MaybeTlsStream>,
874    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
875    pending_notify: Arc<tokio::sync::Notify>,
876) {
877    let mut write_buf = BytesMut::with_capacity(8192);
878
879    loop {
880        // Wait for the first request.
881        let first = match rx.recv().await {
882            Some(req) => req,
883            None => {
884                // Channel closed — drain any pending responses with ConnectionClosed.
885                drain_pending_on_exit(&pending).await;
886                return;
887            }
888        };
889
890        // Drain any additional queued requests (batch coalescing).
891        write_buf.clear();
892        write_buf.extend_from_slice(&first.messages);
893
894        let mut batch: Vec<PendingResponse> = vec![PendingResponse {
895            collector: first.collector,
896            response_tx: first.response_tx,
897        }];
898
899        // Non-blocking drain of all queued requests.
900        while let Ok(req) = rx.try_recv() {
901            write_buf.extend_from_slice(&req.messages);
902            batch.push(PendingResponse {
903                collector: req.collector,
904                response_tx: req.response_tx,
905            });
906        }
907
908        // ONE write() syscall for all coalesced messages.
909        // Write BEFORE enqueuing pending responses — if the write fails,
910        // we send errors to callers instead of leaving them hanging.
911        let write_result = stream.write_all(&write_buf).await;
912        let write_err = match write_result {
913            Ok(_) => stream.flush().await.err(),
914            Err(e) => Some(e),
915        };
916
917        if let Some(e) = write_err {
918            tracing::error!("Writer error: {e}");
919            let msg = e.to_string();
920            for p in batch {
921                let _ = p.response_tx.send(Err(PgWireError::Io(std::io::Error::new(
922                    std::io::ErrorKind::BrokenPipe,
923                    msg.clone(),
924                ))));
925            }
926            // Drain any already-pending responses so the reader doesn't hang.
927            drain_pending_on_exit(&pending).await;
928            return;
929        }
930
931        // Write succeeded — enqueue pending responses for the reader.
932        {
933            let mut pq = pending.lock().await;
934            for p in batch {
935                pq.push_back(p);
936            }
937        }
938        // Wake the reader task to process the newly enqueued responses.
939        pending_notify.notify_one();
940    }
941}
942
943/// On writer exit, drain all pending responses with ConnectionClosed errors
944/// so callers don't wait for the 5-minute timeout.
945async fn drain_pending_on_exit(pending: &Arc<Mutex<VecDeque<PendingResponse>>>) {
946    let mut pq = pending.lock().await;
947    while let Some(pr) = pq.pop_front() {
948        let _ = pr.response_tx.send(Err(PgWireError::ConnectionClosed));
949    }
950}
951
952// ---------------------------------------------------------------------------
953// Reader task
954// ---------------------------------------------------------------------------
955
956async fn reader_task(
957    mut stream: tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
958    pending: Arc<Mutex<VecDeque<PendingResponse>>>,
959    pending_notify: Arc<tokio::sync::Notify>,
960    notification_tx: mpsc::Sender<BackendMsg>,
961    state_mutated: Arc<std::sync::atomic::AtomicBool>,
962    dropped_notifications: Arc<std::sync::atomic::AtomicU64>,
963) {
964    let mut recv_buf = BytesMut::with_capacity(32 * 1024);
965
966    loop {
967        // Wait for a pending response to become available.
968        let pr = loop {
969            {
970                let mut pq = pending.lock().await;
971                if let Some(pr) = pq.pop_front() {
972                    break pr;
973                }
974            }
975            // No pending — wait for the writer to signal.
976            pending_notify.notified().await;
977        };
978
979        // Collect the response based on the collector type.
980        let result = match pr.collector {
981            ResponseCollector::Rows => {
982                collect_rows(
983                    &mut stream,
984                    &mut recv_buf,
985                    &notification_tx,
986                    &state_mutated,
987                    &dropped_notifications,
988                )
989                .await
990            }
991            ResponseCollector::Drain => {
992                drain_until_ready(&mut stream, &mut recv_buf, Some(&state_mutated))
993                    .await
994                    .map(|_| PipelineResponse::Done)
995            }
996            ResponseCollector::Stream { header_tx, row_tx } => {
997                stream_rows(
998                    &mut stream,
999                    &mut recv_buf,
1000                    header_tx,
1001                    row_tx,
1002                    &notification_tx,
1003                    &state_mutated,
1004                    &dropped_notifications,
1005                )
1006                .await;
1007                Ok(PipelineResponse::Done)
1008            }
1009            ResponseCollector::CopyIn { .. } => {
1010                collect_copy_in_response(&mut stream, &mut recv_buf, &state_mutated).await
1011            }
1012            ResponseCollector::CopyOut => {
1013                collect_copy_out(&mut stream, &mut recv_buf, &state_mutated).await
1014            }
1015        };
1016
1017        // Send the response back to the caller.
1018        let _ = pr.response_tx.send(result);
1019    }
1020}
1021
1022async fn read_msg(
1023    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1024    buf: &mut BytesMut,
1025) -> Result<BackendMsg, PgWireError> {
1026    loop {
1027        if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1028            return Ok(msg);
1029        }
1030        let n = stream.read_buf(buf).await?;
1031        if n == 0 {
1032            // EOF — try to parse any remaining data in the buffer before giving up.
1033            // This handles the case where the last message arrived just before the
1034            // connection closed and is already fully buffered.
1035            if let Some(msg) = backend::parse_message(buf).map_err(PgWireError::Protocol)? {
1036                return Ok(msg);
1037            }
1038            return Err(PgWireError::ConnectionClosed);
1039        }
1040    }
1041}
1042
/// Record a ReadyForQuery transaction-status byte.
///
/// `I` (idle) leaves the flag untouched. Any other status, such as `T` (in a
/// transaction block) or `E` (in a failed transaction), sets `state_mutated`,
/// because the session needs a DISCARD ALL before reuse. The flag is never
/// cleared here.
fn note_rfq_status(status: u8, state_mutated: &std::sync::atomic::AtomicBool) {
    match status {
        b'I' => {}
        _ => state_mutated.store(true, std::sync::atomic::Ordering::Release),
    }
}
1051
1052async fn collect_rows(
1053    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1054    buf: &mut BytesMut,
1055    notification_tx: &mpsc::Sender<BackendMsg>,
1056    state_mutated: &std::sync::atomic::AtomicBool,
1057    dropped_notifications: &std::sync::atomic::AtomicU64,
1058) -> Result<PipelineResponse, PgWireError> {
1059    let mut rows = Vec::new();
1060    let mut fields = Vec::new();
1061    let mut command_tag = String::new();
1062    loop {
1063        let msg = read_msg(stream, buf).await?;
1064        match msg {
1065            BackendMsg::DataRow(row) => rows.push(row),
1066            BackendMsg::RowDescription { fields: f } => fields = f,
1067            BackendMsg::CommandComplete { tag } => command_tag = tag,
1068            BackendMsg::ReadyForQuery { status } => {
1069                note_rfq_status(status, state_mutated);
1070                return Ok(PipelineResponse::Rows {
1071                    fields,
1072                    rows,
1073                    command_tag,
1074                });
1075            }
1076            BackendMsg::ErrorResponse { fields } => {
1077                drain_until_ready(stream, buf, Some(state_mutated)).await?;
1078                return Err(PgWireError::Pg(fields));
1079            }
1080            msg @ BackendMsg::NotificationResponse { .. } => {
1081                // Forward notification instead of dropping.
1082                #[allow(clippy::collapsible_match)]
1083                if notification_tx.try_send(msg).is_err() {
1084                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1085                    tracing::warn!("notification channel full, dropping notification");
1086                }
1087            }
1088            BackendMsg::ParseComplete
1089            | BackendMsg::BindComplete
1090            | BackendMsg::NoData
1091            | BackendMsg::NoticeResponse { .. }
1092            | BackendMsg::EmptyQueryResponse => {}
1093            _ => {}
1094        }
1095    }
1096}
1097
1098async fn drain_until_ready(
1099    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1100    buf: &mut BytesMut,
1101    state_mutated: Option<&std::sync::atomic::AtomicBool>,
1102) -> Result<(), PgWireError> {
1103    loop {
1104        let msg = read_msg(stream, buf).await?;
1105        if let BackendMsg::ReadyForQuery { status } = msg {
1106            if let Some(sm) = state_mutated {
1107                note_rfq_status(status, sm);
1108            }
1109            return Ok(());
1110        }
1111        if let BackendMsg::ErrorResponse { ref fields } = msg {
1112            tracing::warn!("Error in drain: {}: {}", fields.code, fields.message);
1113        }
1114    }
1115}
1116
1117/// Stream rows one at a time, sending header first, then individual rows.
1118async fn stream_rows(
1119    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1120    buf: &mut BytesMut,
1121    header_tx: oneshot::Sender<Result<StreamHeader, PgWireError>>,
1122    row_tx: mpsc::Sender<Result<StreamedRow, PgWireError>>,
1123    notification_tx: &mpsc::Sender<BackendMsg>,
1124    state_mutated: &std::sync::atomic::AtomicBool,
1125    dropped_notifications: &std::sync::atomic::AtomicU64,
1126) {
1127    let mut header_tx = Some(header_tx);
1128    let mut fields = Vec::new();
1129    loop {
1130        let msg = match read_msg(stream, buf).await {
1131            Ok(msg) => msg,
1132            Err(e) => {
1133                if let Some(htx) = header_tx.take() {
1134                    let _ = htx.send(Err(e));
1135                } else {
1136                    let _ = row_tx.send(Err(e)).await;
1137                }
1138                return;
1139            }
1140        };
1141        match msg {
1142            BackendMsg::RowDescription { fields: f } => {
1143                fields = f;
1144            }
1145            BackendMsg::DataRow(row) => {
1146                if let Some(htx) = header_tx.take() {
1147                    let _ = htx.send(Ok(StreamHeader {
1148                        fields: fields.clone(),
1149                    }));
1150                }
1151                if row_tx.send(Ok(row)).await.is_err() {
1152                    let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
1153                    return;
1154                }
1155            }
1156            BackendMsg::CommandComplete { .. } => {
1157                if let Some(htx) = header_tx.take() {
1158                    let _ = htx.send(Ok(StreamHeader {
1159                        fields: std::mem::take(&mut fields),
1160                    }));
1161                }
1162            }
1163            BackendMsg::ReadyForQuery { status } => {
1164                note_rfq_status(status, state_mutated);
1165                if let Some(htx) = header_tx.take() {
1166                    let _ = htx.send(Ok(StreamHeader {
1167                        fields: std::mem::take(&mut fields),
1168                    }));
1169                }
1170                return;
1171            }
1172            BackendMsg::ErrorResponse { fields: err } => {
1173                if let Some(htx) = header_tx.take() {
1174                    let _ = htx.send(Err(PgWireError::Pg(err)));
1175                } else {
1176                    let _ = row_tx.send(Err(PgWireError::Pg(err))).await;
1177                }
1178                let _ = drain_until_ready(stream, buf, Some(state_mutated)).await;
1179                return;
1180            }
1181            msg @ BackendMsg::NotificationResponse { .. } => {
1182                #[allow(clippy::collapsible_match)]
1183                if notification_tx.try_send(msg).is_err() {
1184                    dropped_notifications.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1185                    tracing::warn!("notification channel full, dropping notification");
1186                }
1187            }
1188            BackendMsg::ParseComplete
1189            | BackendMsg::BindComplete
1190            | BackendMsg::NoData
1191            | BackendMsg::PortalSuspended
1192            | BackendMsg::NoticeResponse { .. }
1193            | BackendMsg::EmptyQueryResponse => {}
1194            _ => {}
1195        }
1196    }
1197}
1198
1199/// Handle COPY IN response: skip CopyInResponse, wait for CommandComplete + ReadyForQuery.
1200/// The actual CopyData + CopyDone were pre-buffered in the write, so PG processes them.
1201async fn collect_copy_in_response(
1202    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1203    buf: &mut BytesMut,
1204    state_mutated: &std::sync::atomic::AtomicBool,
1205) -> Result<PipelineResponse, PgWireError> {
1206    let mut command_tag = String::new();
1207    loop {
1208        let msg = read_msg(stream, buf).await?;
1209        match msg {
1210            BackendMsg::CopyInResponse { .. } => {}
1211            BackendMsg::CommandComplete { tag } => command_tag = tag,
1212            BackendMsg::ReadyForQuery { status } => {
1213                note_rfq_status(status, state_mutated);
1214                return Ok(PipelineResponse::Rows {
1215                    fields: Vec::new(),
1216                    rows: Vec::new(),
1217                    command_tag,
1218                });
1219            }
1220            BackendMsg::ErrorResponse { fields } => {
1221                drain_until_ready(stream, buf, Some(state_mutated)).await?;
1222                return Err(PgWireError::Pg(fields));
1223            }
1224            _ => {}
1225        }
1226    }
1227}
1228
1229/// Collect COPY OUT data: CopyOutResponse → CopyData* → CopyDone → CommandComplete → ReadyForQuery.
1230async fn collect_copy_out(
1231    stream: &mut tokio::io::ReadHalf<crate::tls::MaybeTlsStream>,
1232    buf: &mut BytesMut,
1233    state_mutated: &std::sync::atomic::AtomicBool,
1234) -> Result<PipelineResponse, PgWireError> {
1235    let mut data_chunks: Vec<RawRow> = Vec::new();
1236    let mut command_tag = String::new();
1237    loop {
1238        let msg = read_msg(stream, buf).await?;
1239        match msg {
1240            BackendMsg::CopyOutResponse { .. } => {}
1241            BackendMsg::CopyData { data } => {
1242                let body = bytes::Bytes::from(data);
1243                data_chunks.push(RawRow::from_full_body(body));
1244            }
1245            BackendMsg::CopyDone => {}
1246            BackendMsg::CommandComplete { tag } => command_tag = tag,
1247            BackendMsg::ReadyForQuery { status } => {
1248                note_rfq_status(status, state_mutated);
1249                return Ok(PipelineResponse::Rows {
1250                    fields: Vec::new(),
1251                    rows: data_chunks,
1252                    command_tag,
1253                });
1254            }
1255            BackendMsg::ErrorResponse { fields } => {
1256                drain_until_ready(stream, buf, Some(state_mutated)).await?;
1257                return Err(PgWireError::Pg(fields));
1258            }
1259            _ => {}
1260        }
1261    }
1262}
1263
1264/// Check if a PostgreSQL error indicates a stale/invalidated prepared statement.
1265/// Error codes: 26000 (invalid_sql_statement_name), 0A000 (feature_not_supported
1266/// — used when cached plan changes type).
1267fn is_stale_statement_error(err: &crate::protocol::types::PgError) -> bool {
1268    matches!(err.code.as_str(), "26000" | "0A000")
1269}
1270
/// Extract the row count from a COPY command tag such as `"COPY 123"`.
///
/// Returns 0 when the tag lacks the `"COPY "` prefix or the remainder is not a
/// valid `u64`.
fn parse_copy_count(tag: &str) -> u64 {
    match tag.strip_prefix("COPY ") {
        Some(count) => count.parse().unwrap_or(0),
        None => 0,
    }
}
1277
1278// Extension to WireConn to extract the underlying stream.
1279impl WireConn {
1280    pub(crate) fn into_stream(self) -> crate::tls::MaybeTlsStream {
1281        self.stream
1282    }
1283}