Skip to main content

hyperdb_api_core/client/
async_connection.rs

1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Async low-level connection handling.
5//!
6//! This module provides [`AsyncRawConnection`], the async version of [`RawConnection`](super::connection::RawConnection).
7//! It uses tokio's async I/O traits for non-blocking network operations.
8
9use std::collections::HashMap;
10
11use bytes::BytesMut;
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tracing::{debug, info, warn};
14
15use crate::protocol::message::{backend::Message, frontend};
16
17use super::auth::{self, AuthState};
18use super::error::{Error, Result};
19
/// An async raw connection to a Hyper server.
///
/// This is the async equivalent of [`RawConnection`](super::connection::RawConnection),
/// using tokio's async I/O traits instead of std's sync I/O.
///
/// The connection is generic over the stream type `S`, allowing it to work
/// with different transport mechanisms (`TcpStream`, `TlsStream`, etc.) as long as they
/// implement `AsyncRead + AsyncWrite + Unpin`.
///
/// The struct definition itself places no bound on `S`; the
/// `AsyncRead + AsyncWrite + Unpin` requirement is imposed by the `impl`
/// block that provides the connection methods.
#[derive(Debug)]
pub struct AsyncRawConnection<S> {
    /// The underlying async I/O stream.
    stream: S,
    /// Buffer for reading incoming messages from the server. Bytes read from
    /// `stream` are appended here by `read_message` and handed to
    /// `Message::parse` until a complete message is available.
    read_buf: BytesMut,
    /// Buffer for writing outgoing messages to the server. Frontend messages
    /// are encoded into it and written to `stream` by `flush`.
    write_buf: BytesMut,
    /// Backend process ID (for cancel requests). Filled in from
    /// `BackendKeyData` during `startup`; `0` until then.
    process_id: i32,
    /// Secret key for authenticating cancel requests. Filled in from
    /// `BackendKeyData` during `startup`; `0` until then.
    secret_key: i32,
    /// Server parameters received during startup (from `ParameterStatus`
    /// messages), keyed by parameter name.
    server_params: HashMap<String, String>,
    /// Set by `AsyncCopyInWriter::Drop` when a COPY session is abandoned.
    /// The `CopyFail` message has been written to `write_buf` but not flushed.
    /// The next async operation must flush and drain the server response
    /// (`ErrorResponse` + `ReadyForQuery`) before proceeding.
    pending_copy_cancel: bool,
    /// Sticky flag mirroring
    /// [`RawConnection`](super::connection::RawConnection)'s
    /// `desynchronized` field. Set when a bounded drain exhausts its
    /// budget or hits a mid-drain I/O error; never cleared. See
    /// [`Self::is_healthy`] and [`Self::ensure_healthy`] for the
    /// consumer-facing API.
    desynchronized: bool,
}
55
56impl<S> AsyncRawConnection<S>
57where
58    S: AsyncRead + AsyncWrite + Unpin,
59{
60    /// Creates a new async raw connection from a stream.
61    ///
62    /// Initializes read and write buffers with default capacity (64 KB each).
63    /// The connection is not yet authenticated - call `startup()` to begin
64    /// the connection handshake.
65    pub fn new(stream: S) -> Self {
66        AsyncRawConnection {
67            stream,
68            read_buf: BytesMut::with_capacity(64 * 1024),
69            write_buf: BytesMut::with_capacity(64 * 1024),
70            process_id: 0,
71            secret_key: 0,
72            server_params: HashMap::new(),
73            pending_copy_cancel: false,
74            desynchronized: false,
75        }
76    }
77
78    /// Returns `true` if this connection is still in a known-good state
79    /// and safe to use for new requests. See
80    /// [`super::connection::RawConnection::is_healthy`] for the full
81    /// semantics — this is the async mirror with identical behavior.
82    pub fn is_healthy(&self) -> bool {
83        !self.desynchronized
84    }
85
86    /// Marks this connection as desynchronized.
87    ///
88    /// Used by async result streams that are dropped mid-iteration: the
89    /// [`Drop`] impl cannot `await` to drain trailing `ErrorResponse +
90    /// ReadyForQuery` messages after sending a cancel, so it flags the
91    /// connection so the next operation short-circuits with a clear error
92    /// rather than hanging or misinterpreting stale server output.
93    pub fn mark_desynchronized(&mut self) {
94        self.desynchronized = true;
95    }
96
97    /// Async mirror of
98    /// [`super::connection::RawConnection::ensure_healthy`]. Called from
99    /// the entry point of every `pub async fn` that initiates a new
100    /// server request to short-circuit operations on a desynchronized
101    /// connection before any bytes hit the wire.
102    pub(crate) fn ensure_healthy(&self) -> Result<()> {
103        if self.desynchronized {
104            return Err(Error::new(
105                super::error::ErrorKind::Connection,
106                "connection is desynchronized from the server and cannot be reused; \
107                 discard it and open a new one",
108            ));
109        }
110        Ok(())
111    }
112
113    /// Returns the process ID assigned by the server.
114    pub fn process_id(&self) -> i32 {
115        self.process_id
116    }
117
118    /// Returns the secret key for cancel requests.
119    pub fn secret_key(&self) -> i32 {
120        self.secret_key
121    }
122
123    /// Returns a reference to the underlying stream.
124    pub fn stream(&self) -> &S {
125        &self.stream
126    }
127
128    /// Returns a mutable reference to the underlying stream.
129    pub fn stream_mut(&mut self) -> &mut S {
130        &mut self.stream
131    }
132
133    /// Returns a server parameter value by name.
134    pub fn parameter_status(&self, name: &str) -> Option<&str> {
135        self.server_params
136            .get(name)
137            .map(std::string::String::as_str)
138    }
139
140    /// Queues a `CopyFail` message in the write buffer (synchronous).
141    ///
142    /// Called from `AsyncCopyInWriter::Drop` when a COPY session is abandoned
143    /// without `finish()` or `cancel()`. The `CopyFail` is written to the buffer
144    /// but NOT flushed (we can't do async I/O from `Drop`). The next async
145    /// operation will call [`drain_pending_copy_cancel`](Self::drain_pending_copy_cancel) to flush and drain
146    /// the server's `ErrorResponse` + `ReadyForQuery` before proceeding.
147    pub fn queue_copy_fail(&mut self, reason: &str) {
148        frontend::copy_fail(reason, &mut self.write_buf);
149        self.pending_copy_cancel = true;
150    }
151
152    /// Drains a pending COPY cancel that was queued by `queue_copy_fail()`.
153    ///
154    /// If `pending_copy_cancel` is set, this flushes the `CopyFail` message to
155    /// the server and reads messages until `ReadyForQuery`, restoring the
156    /// connection to a usable state. Called automatically at the start of
157    /// new operations (`simple_query`, `query_binary`, `start_copy_in*`).
158    ///
159    /// # Errors
160    ///
161    /// Returns [`Error`] (I/O) if flushing the queued `CopyFail` or
162    /// reading the server's drain responses fails. A successful drain
163    /// clears `pending_copy_cancel`.
164    pub async fn drain_pending_copy_cancel(&mut self) -> Result<()> {
165        if !self.pending_copy_cancel {
166            return Ok(());
167        }
168
169        // Flush the queued CopyFail message
170        self.flush().await?;
171
172        // Drain messages until the connection is back in ReadyForQuery state
173        loop {
174            let msg = self.read_message().await?;
175            match msg {
176                Message::ReadyForQuery(_) => {
177                    self.pending_copy_cancel = false;
178                    debug!(
179                        target: "hyperdb_api_core::client",
180                        "drained pending COPY cancel — connection restored"
181                    );
182                    return Ok(());
183                }
184                Message::ErrorResponse(_) => {
185                    // Expected — server confirms the cancel
186                }
187                _ => {
188                    // Ignore other messages (e.g., NoticeResponse)
189                }
190            }
191        }
192    }
193
194    /// Sends a startup message and performs initial handshake (async).
195    ///
196    /// # Errors
197    ///
198    /// - Returns [`Error`] (auth) when the server requests an
199    ///   auth method and no password is supplied, when the offered
200    ///   SASL mechanisms exclude SCRAM-SHA-256, or when SCRAM state
201    ///   is missing at the SASL-continue / SASL-final step.
202    /// - Returns [`Error`] (server) when the server sends an `ErrorResponse`
203    ///   during startup (unknown user, unknown database, etc.).
204    /// - Returns [`Error`] (protocol) if a message arrives out of
205    ///   sequence.
206    /// - Returns [`Error`] (I/O) on transport read/write failure.
207    pub async fn startup(&mut self, params: &[(&str, &str)], password: Option<&str>) -> Result<()> {
208        // Send startup message
209        frontend::startup_message(params, &mut self.write_buf)?;
210        self.flush().await?;
211
212        // Handle authentication
213        let mut auth_state: Option<AuthState> = None;
214
215        loop {
216            let msg = self.read_message().await?;
217            match msg {
218                Message::AuthenticationOk => {
219                    info!(target: "hyperdb_api", "connection-auth-success");
220                }
221                Message::AuthenticationCleartextPassword => {
222                    debug!(target: "hyperdb_api", method = "cleartext", "connection-auth-method");
223                    let password = password.ok_or_else(|| {
224                        Error::authentication(
225                            "server requested cleartext password but none provided",
226                        )
227                    })?;
228                    frontend::password_message(password, &mut self.write_buf)?;
229                    self.flush().await?;
230                }
231                Message::AuthenticationMd5Password(body) => {
232                    debug!(target: "hyperdb_api", method = "MD5", "connection-auth-method");
233                    let password = password.ok_or_else(|| {
234                        Error::authentication("server requested MD5 password but none provided")
235                    })?;
236                    let user = params
237                        .iter()
238                        .find(|(k, _)| *k == "user")
239                        .map_or("", |(_, v)| *v);
240
241                    let md5_response = auth::compute_md5_password(user, password, &body.salt());
242                    frontend::password_message(&md5_response, &mut self.write_buf)?;
243                    self.flush().await?;
244                }
245                Message::AuthenticationSasl(body) => {
246                    debug!(target: "hyperdb_api", method = "SCRAM-SHA-256", "connection-auth-method");
247                    let password = password.ok_or_else(|| {
248                        Error::authentication(
249                            "server requested SASL authentication but no password provided",
250                        )
251                    })?;
252
253                    let mechanisms: Vec<&str> = body.mechanisms().collect();
254                    if !mechanisms.contains(&"SCRAM-SHA-256") {
255                        return Err(Error::authentication(format!(
256                            "server offered unsupported SASL mechanisms: {mechanisms:?}"
257                        )));
258                    }
259
260                    let (state, client_first) = auth::scram_client_first(password)?;
261                    auth_state = Some(state);
262
263                    frontend::sasl_initial_response(
264                        "SCRAM-SHA-256",
265                        &client_first,
266                        &mut self.write_buf,
267                    )?;
268                    self.flush().await?;
269                }
270                Message::AuthenticationSaslContinue(body) => {
271                    let state = auth_state.take().ok_or_else(|| {
272                        Error::authentication("received SASL continue without initial state")
273                    })?;
274
275                    let server_first = body.data();
276                    let (new_state, client_final) = auth::scram_client_final(state, server_first)?;
277                    auth_state = Some(new_state);
278
279                    frontend::sasl_response(&client_final, &mut self.write_buf)?;
280                    self.flush().await?;
281                }
282                Message::AuthenticationSaslFinal(body) => {
283                    let state = auth_state.take().ok_or_else(|| {
284                        Error::authentication("received SASL final without state")
285                    })?;
286                    auth::scram_verify_server(state, body.data())?;
287                }
288                Message::BackendKeyData(data) => {
289                    self.process_id = data.process_id();
290                    self.secret_key = data.secret_key();
291                }
292                Message::ParameterStatus(body) => {
293                    if let (Ok(name), Ok(value)) = (body.name(), body.value()) {
294                        self.server_params
295                            .insert(name.to_string(), value.to_string());
296                    }
297                }
298                Message::ReadyForQuery(_) => {
299                    return Ok(());
300                }
301                Message::ErrorResponse(body) => {
302                    return Err(self.consume_error(&body).await);
303                }
304                _ => {
305                    return Err(Error::protocol("unexpected message during startup"));
306                }
307            }
308        }
309    }
310
311    /// Sends a simple query and returns all messages until `ReadyForQuery` (async).
312    ///
313    /// # Errors
314    ///
315    /// - Returns [`Error`] (connection) if the connection has been
316    ///   marked unhealthy.
317    /// - Returns [`Error`] (server) when the server emits an
318    ///   `ErrorResponse` (SQL error, constraint violation, etc.).
319    /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
320    ///   read/write failure.
321    /// - Propagates any error from
322    ///   [`Self::drain_pending_copy_cancel`] when a queued `CopyFail`
323    ///   needs to be flushed first.
324    pub async fn simple_query(&mut self, query: &str) -> Result<Vec<Message>> {
325        self.ensure_healthy()?;
326        self.drain_pending_copy_cancel().await?;
327        frontend::query(query, &mut self.write_buf)?;
328        self.flush().await?;
329
330        let mut messages = Vec::new();
331        loop {
332            let msg = self.read_message().await?;
333            match &msg {
334                Message::ReadyForQuery(_) => {
335                    messages.push(msg);
336                    return Ok(messages);
337                }
338                Message::ErrorResponse(body) => {
339                    return Err(self.consume_error(body).await);
340                }
341                _ => {
342                    messages.push(msg);
343                }
344            }
345        }
346    }
347
348    /// Sends a query using extended protocol with binary format results (async).
349    ///
350    /// # Errors
351    ///
352    /// Same failure modes as [`Self::simple_query`].
353    pub async fn query_binary(&mut self, query: &str) -> Result<Vec<Message>> {
354        self.ensure_healthy()?;
355        self.drain_pending_copy_cancel().await?;
356        const HYPER_BINARY_FORMAT: i16 = 2;
357
358        frontend::parse("", query, &[], &mut self.write_buf)?;
359        frontend::bind(
360            "",
361            "",
362            &[],
363            &[],
364            &[HYPER_BINARY_FORMAT],
365            &mut self.write_buf,
366        )?;
367        frontend::describe(b'P', "", &mut self.write_buf)?;
368        frontend::execute("", 0, &mut self.write_buf)?;
369        frontend::sync(&mut self.write_buf);
370
371        self.flush().await?;
372
373        let mut messages = Vec::new();
374        loop {
375            let msg = self.read_message().await?;
376            match &msg {
377                Message::ReadyForQuery(_) => {
378                    messages.push(msg);
379                    return Ok(messages);
380                }
381                Message::ErrorResponse(body) => {
382                    return Err(self.consume_error(body).await);
383                }
384                _ => {
385                    messages.push(msg);
386                }
387            }
388        }
389    }
390
391    /// Starts a binary query but leaves result consumption to the caller (async).
392    ///
393    /// # Errors
394    ///
395    /// - Returns [`Error`] (connection) if the connection is unhealthy.
396    /// - Returns [`Error`] (I/O) on transport write failure.
397    /// - Propagates any error from [`Self::drain_pending_copy_cancel`].
398    pub async fn start_query_binary(&mut self, query: &str) -> Result<()> {
399        self.ensure_healthy()?;
400        // Drain any CopyFail queued by `AsyncCopyInWriter::Drop` before
401        // writing the extended-query bytes. Without this, the flush at
402        // the end of this method would send [CopyFail | Parse | Bind |
403        // Describe | Execute | Sync] in a single buffer and the server
404        // would answer with CopyFail's ErrorResponse+ReadyForQuery
405        // interleaved with our query's responses — the read loop would
406        // then misattribute the COPY error to this query.
407        self.drain_pending_copy_cancel().await?;
408        const HYPER_BINARY_FORMAT: i16 = 2;
409
410        frontend::parse("", query, &[], &mut self.write_buf)?;
411        frontend::bind(
412            "",
413            "",
414            &[],
415            &[],
416            &[HYPER_BINARY_FORMAT],
417            &mut self.write_buf,
418        )?;
419        frontend::describe(b'P', "", &mut self.write_buf)?;
420        frontend::execute("", 0, &mut self.write_buf)?;
421        frontend::sync(&mut self.write_buf);
422
423        self.flush().await
424    }
425
426    /// Starts a simple query but leaves result consumption to the caller (async).
427    ///
428    /// # Errors
429    ///
430    /// Same failure modes as [`Self::start_query_binary`].
431    pub async fn start_simple_query(&mut self, query: &str) -> Result<()> {
432        self.ensure_healthy()?;
433        // See `start_query_binary` for why the pending-copy-cancel drain
434        // is required before writing any new query bytes.
435        self.drain_pending_copy_cancel().await?;
436        frontend::query(query, &mut self.write_buf)?;
437        self.flush().await
438    }
439
440    /// Starts an **execute** of a prepared statement but leaves result
441    /// consumption to the caller (async).
442    ///
443    /// Async mirror of
444    /// [`super::connection::RawConnection::start_execute_prepared`]. See
445    /// that method's docs for the split format-code rationale (params
446    /// use `1` = PG binary/BE, results use `2` = HyperBinary/LE).
447    ///
448    /// # Errors
449    ///
450    /// - Returns [`Error`] (connection) if the connection is unhealthy.
451    /// - Returns [`Error`] (I/O) on transport write failure.
452    /// - Propagates any error from [`Self::drain_pending_copy_cancel`].
453    pub async fn start_execute_prepared(
454        &mut self,
455        statement_name: &str,
456        params: &[Option<&[u8]>],
457        column_count: usize,
458    ) -> Result<()> {
459        self.ensure_healthy()?;
460        // Same rationale as `start_query_binary` for draining a pending
461        // CopyFail before writing new extended-query bytes.
462        self.drain_pending_copy_cancel().await?;
463
464        const PG_BINARY_FORMAT: i16 = 1;
465        const HYPER_BINARY_FORMAT: i16 = 2;
466        let param_formats: Vec<i16> = vec![PG_BINARY_FORMAT; params.len()];
467        let result_formats: Vec<i16> = vec![HYPER_BINARY_FORMAT; column_count];
468
469        frontend::bind(
470            "", // unnamed portal
471            statement_name,
472            &param_formats,
473            params,
474            &result_formats,
475            &mut self.write_buf,
476        )?;
477        frontend::execute("", 0, &mut self.write_buf)?;
478        frontend::sync(&mut self.write_buf);
479
480        self.flush().await
481    }
482
483    /// Reads a single message from the server (async).
484    ///
485    /// # Errors
486    ///
487    /// - Returns [`Error`] (I/O) if reading from the transport fails or
488    ///   if [`Message::parse`] reports a malformed frame.
489    /// - Returns [`Error`] (closed) when the transport reaches EOF
490    ///   (server closed the connection).
491    pub async fn read_message(&mut self) -> Result<Message> {
492        loop {
493            if let Some(msg) = Message::parse(&mut self.read_buf).map_err(Error::io)? {
494                return Ok(msg);
495            }
496
497            // Need more data — read directly into the spare capacity of
498            // `read_buf`, no temporary buffer or `extend_from_slice` memcpy.
499            // See the sync mirror in
500            // [`super::connection::RawConnection::read_message`] for the
501            // full rationale on the 64 KiB ceiling and Windows-loopback
502            // syscall amplification.
503            let prev_len = self.read_buf.len();
504            self.read_buf.resize(prev_len + 64 * 1024, 0);
505            let n = self.stream.read(&mut self.read_buf[prev_len..]).await?;
506            if n == 0 {
507                self.read_buf.truncate(prev_len);
508                warn!(target: "hyperdb_api", "connection-closed");
509                return Err(Error::closed());
510            }
511            self.read_buf.truncate(prev_len + n);
512        }
513    }
514
515    /// Async equivalent of
516    /// [`super::connection::RawConnection::drain_until_ready`]. Unbounded;
517    /// prefer [`drain_until_ready_bounded`](Self::drain_until_ready_bounded)
518    /// in destructors and other code paths where blocking indefinitely is
519    /// unacceptable. Drain errors are logged via `tracing::warn!` and then
520    /// swallowed.
521    pub async fn drain_until_ready(&mut self) {
522        let _ = self.drain_until_ready_bounded(usize::MAX).await;
523    }
524
525    /// Async equivalent of
526    /// [`super::connection::RawConnection::drain_until_ready_bounded`].
527    /// See that function's docs for the full semantics, including why we do
528    /// **not** send a `Sync` before draining (it would produce an extra
529    /// `ReadyForQuery` on the wire and corrupt the next query's response).
530    pub async fn drain_until_ready_bounded(&mut self, max_messages: usize) -> bool {
531        for i in 0..max_messages {
532            match self.read_message().await {
533                Ok(Message::ReadyForQuery(_)) => return true,
534                Ok(_) => {}
535                Err(e) => {
536                    warn!(
537                        target: "hyperdb_api_core::client",
538                        error = %e,
539                        messages_read = i,
540                        "drain_until_ready: read error mid-drain (likely closed connection); \
541                         connection marked desynchronized",
542                    );
543                    // Mirror of sync path: any mid-drain read error leaves
544                    // the connection in unknown state. See
545                    // `super::connection::RawConnection::drain_until_ready_bounded`
546                    // for the full rationale.
547                    self.desynchronized = true;
548                    return false;
549                }
550            }
551        }
552        warn!(
553            target: "hyperdb_api_core::client",
554            max_messages,
555            "drain_until_ready_bounded: exhausted budget without seeing ReadyForQuery; \
556             connection marked desynchronized and should not be reused",
557        );
558        self.desynchronized = true;
559        false
560    }
561
562    /// Async equivalent of
563    /// [`super::connection::RawConnection::consume_error`]. Parse the error
564    /// body and drain the rest of the response in one call. Semantics are
565    /// identical to the sync version, including the
566    /// [`POST_ERROR_DRAIN_CAP`](super::connection::POST_ERROR_DRAIN_CAP)
567    /// safety valve — see that function's docs for the rationale. Unbounded
568    /// drain would be particularly dangerous here because a stalled read
569    /// on the underlying async stream would hang the caller's future
570    /// indefinitely with no observable symptom; the bounded drain turns
571    /// that into a loud `tracing::warn!` plus a connection marked for
572    /// reconnect on next use.
573    pub async fn consume_error(
574        &mut self,
575        body: &crate::protocol::message::backend::ErrorResponseBody,
576    ) -> Error {
577        let err = super::connection::parse_error_response(body);
578        let _ = self
579            .drain_until_ready_bounded(super::connection::POST_ERROR_DRAIN_CAP)
580            .await;
581        err
582    }
583
584    /// Flushes the write buffer to the server (async).
585    ///
586    /// # Errors
587    ///
588    /// Returns [`Error`] (I/O) if writing the buffered bytes or flushing
589    /// the underlying async transport fails.
590    pub async fn flush(&mut self) -> Result<()> {
591        if !self.write_buf.is_empty() {
592            self.stream.write_all(&self.write_buf).await?;
593            self.stream.flush().await?;
594            self.write_buf.clear();
595        }
596        Ok(())
597    }
598
599    /// Sends a terminate message and closes the connection (async).
600    ///
601    /// # Errors
602    ///
603    /// Returns [`Error`] (I/O) if writing the `Terminate` frame or
604    /// flushing the async transport fails.
605    pub async fn terminate(&mut self) -> Result<()> {
606        frontend::terminate(&mut self.write_buf);
607        self.flush().await
608    }
609
610    /// Returns a mutable reference to the write buffer.
611    pub fn write_buf(&mut self) -> &mut BytesMut {
612        &mut self.write_buf
613    }
614
615    /// Initiates a COPY IN operation with `HyperBinary` format (async).
616    ///
617    /// # Errors
618    ///
619    /// Same failure modes as [`Self::start_copy_in_with_format`].
620    pub async fn start_copy_in(&mut self, table_name: &str, columns: &[&str]) -> Result<()> {
621        self.start_copy_in_with_format(table_name, columns, "HYPERBINARY")
622            .await
623    }
624
625    /// Initiates a COPY IN operation with a specified format (async).
626    ///
627    /// # Errors
628    ///
629    /// - Returns [`Error`] (connection) if the connection has been
630    ///   marked unhealthy.
631    /// - Returns [`Error`] (server) if the server rejects the generated
632    ///   `COPY ... FROM STDIN` statement.
633    /// - Returns [`Error`] (I/O) on transport read/write failure.
634    /// - Propagates any error from [`Self::drain_pending_copy_cancel`].
635    pub async fn start_copy_in_with_format(
636        &mut self,
637        table_name: &str,
638        columns: &[&str],
639        format: &str,
640    ) -> Result<()> {
641        self.ensure_healthy()?;
642        self.drain_pending_copy_cancel().await?;
643        let column_list = if columns.is_empty() {
644            String::new()
645        } else {
646            format!(
647                " ({})",
648                columns
649                    .iter()
650                    .map(|c| format!("\"{}\"", c.replace('"', "\"\"")))
651                    .collect::<Vec<_>>()
652                    .join(", ")
653            )
654        };
655
656        let query = format!("COPY {table_name}{column_list} FROM STDIN WITH (FORMAT {format})");
657
658        frontend::query(&query, &mut self.write_buf)?;
659        self.flush().await?;
660
661        loop {
662            let msg = self.read_message().await?;
663            match msg {
664                Message::CopyInResponse(_) => {
665                    return Ok(());
666                }
667                Message::ErrorResponse(body) => {
668                    return Err(self.consume_error(&body).await);
669                }
670                _ => {}
671            }
672        }
673    }
674
675    /// Sends COPY data to the server (sync - just buffers).
676    ///
677    /// # Errors
678    ///
679    /// Currently infallible — frame construction is pure. The `Result`
680    /// return type is preserved for forward compatibility.
681    pub fn send_copy_data(&mut self, data: &[u8]) -> Result<()> {
682        frontend::copy_data(data, &mut self.write_buf);
683        Ok(())
684    }
685
686    /// Sends COPY data directly to the stream without internal buffering (async).
687    ///
688    /// This writes the `CopyData` message directly to the TCP stream, letting
689    /// the kernel's TCP stack handle buffering. Use `flush_stream()` periodically
690    /// to ensure data is sent.
691    ///
692    /// # Errors
693    ///
694    /// - Returns [`Error`] (protocol) if `data.len() + 4` exceeds
695    ///   `u32::MAX` (PostgreSQL's per-message length cap).
696    /// - Returns [`Error`] (I/O) if flushing buffered bytes or writing
697    ///   the header / payload to the async transport fails.
698    pub async fn send_copy_data_direct(&mut self, data: &[u8]) -> Result<()> {
699        // First flush any pending buffered data
700        if !self.write_buf.is_empty() {
701            self.stream.write_all(&self.write_buf).await?;
702            self.write_buf.clear();
703        }
704
705        // Write CopyData message header + data directly to stream
706        // Message format: 'd' (1 byte) + length (4 bytes BigEndian) + data
707        let msg_len = u32::try_from(4 + data.len())
708            .map_err(|_| Error::protocol("CopyData payload exceeds u32::MAX bytes"))?;
709        let len_be = msg_len.to_be_bytes();
710        let header = [b'd', len_be[0], len_be[1], len_be[2], len_be[3]];
711        self.stream.write_all(&header).await?;
712        self.stream.write_all(data).await?;
713        Ok(())
714    }
715
716    /// Flushes the TCP stream without clearing the write buffer (async).
717    ///
718    /// Use this with `send_copy_data_direct()` to periodically ensure
719    /// data is sent to the server.
720    ///
721    /// # Errors
722    ///
723    /// Returns [`Error`] (I/O) if flushing the underlying async transport
724    /// fails.
725    pub async fn flush_stream(&mut self) -> Result<()> {
726        self.stream.flush().await?;
727        Ok(())
728    }
729
730    /// Finishes a COPY IN operation successfully (async).
731    ///
732    /// # Errors
733    ///
734    /// - Returns [`Error`] (server) when the server emits an
735    ///   `ErrorResponse` during finalization (for example, a
736    ///   constraint violation from the accumulated rows).
737    /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
738    ///   read/write failure.
739    pub async fn finish_copy(&mut self) -> Result<u64> {
740        self.flush().await?;
741
742        frontend::copy_done(&mut self.write_buf);
743        self.flush().await?;
744
745        let mut row_count = 0u64;
746        loop {
747            let msg = self.read_message().await?;
748            match msg {
749                Message::CommandComplete(body) => {
750                    if let Ok(tag) = body.tag() {
751                        if let Some(count_str) = tag.strip_prefix("COPY ") {
752                            if let Ok(count) = count_str.trim().parse() {
753                                row_count = count;
754                            }
755                        }
756                    }
757                }
758                Message::ReadyForQuery(_) => {
759                    return Ok(row_count);
760                }
761                Message::ErrorResponse(body) => {
762                    return Err(self.consume_error(&body).await);
763                }
764                _ => {}
765            }
766        }
767    }
768
769    /// Cancels a COPY IN operation (async).
770    ///
771    /// # Errors
772    ///
773    /// Returns [`Error`] (I/O) if flushing the buffer or writing the
774    /// `CopyFail` frame fails, or [`Error`] (closed) if the server
775    /// drops the connection before returning `ReadyForQuery`.
776    pub async fn cancel_copy(&mut self, reason: &str) -> Result<()> {
777        self.flush().await?;
778
779        frontend::copy_fail(reason, &mut self.write_buf);
780        self.flush().await?;
781
782        loop {
783            let msg = self.read_message().await?;
784            match msg {
785                Message::ReadyForQuery(_) => {
786                    return Ok(());
787                }
788                Message::ErrorResponse(_) => {}
789                _ => {}
790            }
791        }
792    }
793
794    /// Executes a COPY ... TO STDOUT query and returns all output data (async).
795    ///
796    /// # Errors
797    ///
798    /// - Returns [`Error`] (connection) if the connection is unhealthy.
799    /// - Returns [`Error`] (server) when the server rejects the COPY TO
800    ///   STDOUT statement via `ErrorResponse`.
801    /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
802    ///   read/write failure.
803    pub async fn copy_out(&mut self, query: &str) -> Result<Vec<u8>> {
804        self.ensure_healthy()?;
805        self.drain_pending_copy_cancel().await?;
806        frontend::query(query, &mut self.write_buf)?;
807        self.flush().await?;
808
809        let mut data = Vec::new();
810        let mut in_copy_out = false;
811
812        loop {
813            let msg = self.read_message().await?;
814            match msg {
815                Message::CopyOutResponse(_) => {
816                    in_copy_out = true;
817                }
818                Message::CopyData(body) if in_copy_out => {
819                    data.extend_from_slice(body.data());
820                }
821                Message::CopyDone => {
822                    in_copy_out = false;
823                }
824                Message::CommandComplete(_) => {}
825                Message::ReadyForQuery(_) => {
826                    return Ok(data);
827                }
828                Message::ErrorResponse(body) => {
829                    return Err(self.consume_error(&body).await);
830                }
831                _ => {}
832            }
833        }
834    }
835
836    /// Prepares a statement using the extended query protocol (async).
837    ///
838    /// # Errors
839    ///
840    /// - Returns [`Error`] (connection) if the connection is unhealthy.
841    /// - Returns [`Error`] (server) if the server rejects the `Parse`
842    ///   request (SQL syntax error, unknown type OIDs, etc.).
843    /// - Returns [`Error`] (I/O) on transport read/write failure.
844    pub async fn prepare(
845        &mut self,
846        name: &str,
847        query: &str,
848        param_types: &[crate::types::Oid],
849    ) -> Result<(Vec<crate::types::Oid>, Vec<super::statement::Column>)> {
850        use super::statement::{Column, ColumnFormat};
851
852        self.ensure_healthy()?;
853        self.drain_pending_copy_cancel().await?;
854
855        // Send Parse message
856        frontend::parse(name, query, param_types, &mut self.write_buf)?;
857
858        // Send Describe message for the statement
859        frontend::describe(b'S', name, &mut self.write_buf)?;
860
861        // Send Sync to get responses
862        frontend::sync(&mut self.write_buf);
863        self.flush().await?;
864
865        // Process responses
866        let mut parsed_params = Vec::new();
867        let mut parsed_columns = Vec::new();
868
869        loop {
870            let msg = self.read_message().await?;
871            match msg {
872                Message::ParseComplete => {}
873                Message::ParameterDescription(desc) => {
874                    for oid in desc.parameters().filter_map(std::result::Result::ok) {
875                        parsed_params.push(oid);
876                    }
877                }
878                Message::RowDescription(desc) => {
879                    for f in desc.fields().filter_map(std::result::Result::ok) {
880                        parsed_columns.push(Column::new(
881                            f.name().to_string(),
882                            f.type_oid(),
883                            f.type_modifier(),
884                            ColumnFormat::from_code(f.format()),
885                        ));
886                    }
887                }
888                Message::NoData => {}
889                Message::ReadyForQuery(_) => {
890                    break;
891                }
892                Message::ErrorResponse(body) => {
893                    return Err(self.consume_error(&body).await);
894                }
895                _ => {}
896            }
897        }
898
899        Ok((parsed_params, parsed_columns))
900    }
901
902    /// Executes a prepared statement with parameters (async).
903    ///
904    /// # Errors
905    ///
906    /// - Returns [`Error`] (connection) if the connection is unhealthy.
907    /// - Returns [`Error`] (server) if `Bind` / `Execute` fails on the
908    ///   server (parameter type mismatch, constraint violation, etc.).
909    /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
910    ///   read/write failure.
911    /// - Propagates row-construction errors from
912    ///   `super::row::Row::new` if a `DataRow` cannot be decoded
913    ///   against the reported `RowDescription`.
914    pub async fn execute_prepared(
915        &mut self,
916        statement_name: &str,
917        params: &[Option<&[u8]>],
918        column_count: usize,
919    ) -> Result<Vec<super::row::Row>> {
920        use super::statement::Column;
921        use std::sync::Arc;
922
923        self.ensure_healthy()?;
924        // Prepared-statement execution writes Bind/Execute/Sync into the
925        // buffer and flushes at the end; a pending CopyFail would be
926        // flushed together with our bind bytes and corrupt the response
927        // stream. See `start_query_binary` for the full argument.
928        self.drain_pending_copy_cancel().await?;
929        // Bind parameters (all in binary format)
930        let param_formats: Vec<i16> = vec![1; params.len()];
931        let result_formats: Vec<i16> = vec![1; column_count];
932
933        frontend::bind(
934            "",
935            statement_name,
936            &param_formats,
937            params,
938            &result_formats,
939            &mut self.write_buf,
940        )?;
941
942        frontend::execute("", 0, &mut self.write_buf)?;
943        frontend::sync(&mut self.write_buf);
944        self.flush().await?;
945
946        let mut rows = Vec::new();
947        let mut columns: Option<Arc<Vec<Column>>> = None;
948
949        loop {
950            let msg = self.read_message().await?;
951            match msg {
952                Message::BindComplete => {}
953                Message::RowDescription(desc) => {
954                    let mut cols = Vec::new();
955                    for f in desc.fields().filter_map(std::result::Result::ok) {
956                        cols.push(Column::new(
957                            f.name().to_string(),
958                            f.type_oid(),
959                            f.type_modifier(),
960                            super::statement::ColumnFormat::from_code(f.format()),
961                        ));
962                    }
963                    columns = Some(Arc::new(cols));
964                }
965                Message::DataRow(data) => {
966                    if let Some(ref cols) = columns {
967                        rows.push(super::row::Row::new(Arc::clone(cols), data)?);
968                    }
969                }
970                Message::CommandComplete(_) => {}
971                Message::EmptyQueryResponse => {}
972                Message::ReadyForQuery(_) => {
973                    break;
974                }
975                Message::ErrorResponse(body) => {
976                    return Err(self.consume_error(&body).await);
977                }
978                _ => {}
979            }
980        }
981
982        Ok(rows)
983    }
984
985    /// Executes a prepared statement that doesn't return rows (async).
986    ///
987    /// # Errors
988    ///
989    /// Same failure modes as [`Self::execute_prepared`] (excluding
990    /// row-construction errors — this path never builds rows).
991    pub async fn execute_prepared_no_result(
992        &mut self,
993        statement_name: &str,
994        params: &[Option<&[u8]>],
995    ) -> Result<u64> {
996        self.ensure_healthy()?;
997        // See `execute_prepared` and `start_query_binary` for why we must
998        // drain any pending COPY cancel before writing new bytes.
999        self.drain_pending_copy_cancel().await?;
1000        let param_formats: Vec<i16> = vec![1; params.len()];
1001        let result_formats: Vec<i16> = vec![];
1002
1003        frontend::bind(
1004            "",
1005            statement_name,
1006            &param_formats,
1007            params,
1008            &result_formats,
1009            &mut self.write_buf,
1010        )?;
1011
1012        frontend::execute("", 0, &mut self.write_buf)?;
1013        frontend::sync(&mut self.write_buf);
1014        self.flush().await?;
1015
1016        let mut affected_rows = 0u64;
1017
1018        loop {
1019            let msg = self.read_message().await?;
1020            match msg {
1021                Message::BindComplete => {}
1022                Message::CommandComplete(body) => {
1023                    if let Ok(tag) = body.tag() {
1024                        // Parse formats like "INSERT 0 5", "UPDATE 10", "DELETE 3"
1025                        let parts: Vec<&str> = tag.split_whitespace().collect();
1026                        match parts.first() {
1027                            Some(&"INSERT") => {
1028                                if let Some(count) = parts.get(2) {
1029                                    affected_rows = count.parse().unwrap_or(0);
1030                                }
1031                            }
1032                            Some(&"UPDATE" | &"DELETE" | &"SELECT" | &"COPY") => {
1033                                if let Some(count) = parts.get(1) {
1034                                    affected_rows = count.parse().unwrap_or(0);
1035                                }
1036                            }
1037                            _ => {}
1038                        }
1039                    }
1040                }
1041                Message::EmptyQueryResponse => {}
1042                Message::ReadyForQuery(_) => {
1043                    break;
1044                }
1045                Message::ErrorResponse(body) => {
1046                    return Err(self.consume_error(&body).await);
1047                }
1048                _ => {}
1049            }
1050        }
1051
1052        Ok(affected_rows)
1053    }
1054
1055    /// Closes a prepared statement (async).
1056    ///
1057    /// # Errors
1058    ///
1059    /// - Returns [`Error`] (connection) if the connection is unhealthy.
1060    /// - Returns [`Error`] (server) if the server reports an `ErrorResponse`
1061    ///   during `Close`/`Sync`.
1062    /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
1063    ///   read/write failure.
1064    /// - Propagates any error from [`Self::drain_pending_copy_cancel`].
1065    pub async fn close_statement(&mut self, statement_name: &str) -> Result<()> {
1066        self.ensure_healthy()?;
1067        // Close + Sync get flushed together; a pending CopyFail would
1068        // share the flush and corrupt the response stream. See
1069        // `start_query_binary` for the full argument.
1070        self.drain_pending_copy_cancel().await?;
1071        frontend::close(b'S', statement_name, &mut self.write_buf)?;
1072        frontend::sync(&mut self.write_buf);
1073        self.flush().await?;
1074
1075        loop {
1076            let msg = self.read_message().await?;
1077            match msg {
1078                Message::CloseComplete => {}
1079                Message::ReadyForQuery(_) => {
1080                    return Ok(());
1081                }
1082                Message::ErrorResponse(body) => {
1083                    return Err(self.consume_error(&body).await);
1084                }
1085                _ => {}
1086            }
1087        }
1088    }
1089}