hyperdb_api_core/client/async_connection.rs
1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Async low-level connection handling.
5//!
6//! This module provides [`AsyncRawConnection`], the async version of [`RawConnection`](super::connection::RawConnection).
7//! It uses tokio's async I/O traits for non-blocking network operations.
8
9use std::collections::HashMap;
10
11use bytes::BytesMut;
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tracing::{debug, info, warn};
14
15use crate::protocol::message::{backend::Message, frontend};
16
17use super::auth::{self, AuthState};
18use super::error::{Error, Result};
19
20/// An async raw connection to a Hyper server.
21///
22/// This is the async equivalent of [`RawConnection`](super::connection::RawConnection),
23/// using tokio's async I/O traits instead of std's sync I/O.
24///
25/// The connection is generic over the stream type `S`, allowing it to work
26/// with different transport mechanisms (`TcpStream`, `TlsStream`, etc.) as long as they
27/// implement `AsyncRead + AsyncWrite + Unpin`.
28#[derive(Debug)]
29pub struct AsyncRawConnection<S> {
30 /// The underlying async I/O stream.
31 stream: S,
32 /// Buffer for reading incoming messages from the server.
33 read_buf: BytesMut,
34 /// Buffer for writing outgoing messages to the server.
35 write_buf: BytesMut,
36 /// Backend process ID (for cancel requests).
37 process_id: i32,
38 /// Secret key for authenticating cancel requests.
39 secret_key: i32,
40 /// Server parameters received during startup.
41 server_params: HashMap<String, String>,
42 /// Set by `AsyncCopyInWriter::Drop` when a COPY session is abandoned.
43 /// The `CopyFail` message has been written to `write_buf` but not flushed.
44 /// The next async operation must flush and drain the server response
45 /// (`ErrorResponse` + `ReadyForQuery`) before proceeding.
46 pending_copy_cancel: bool,
47 /// Sticky flag mirroring
48 /// [`RawConnection`](super::connection::RawConnection)'s
49 /// `desynchronized` field. Set when a bounded drain exhausts its
50 /// budget or hits a mid-drain I/O error; never cleared. See
51 /// [`Self::is_healthy`] and [`Self::ensure_healthy`] for the
52 /// consumer-facing API.
53 desynchronized: bool,
54}
55
56impl<S> AsyncRawConnection<S>
57where
58 S: AsyncRead + AsyncWrite + Unpin,
59{
60 /// Creates a new async raw connection from a stream.
61 ///
62 /// Initializes read and write buffers with default capacity (64 KB each).
63 /// The connection is not yet authenticated - call `startup()` to begin
64 /// the connection handshake.
65 pub fn new(stream: S) -> Self {
66 AsyncRawConnection {
67 stream,
68 read_buf: BytesMut::with_capacity(64 * 1024),
69 write_buf: BytesMut::with_capacity(64 * 1024),
70 process_id: 0,
71 secret_key: 0,
72 server_params: HashMap::new(),
73 pending_copy_cancel: false,
74 desynchronized: false,
75 }
76 }
77
78 /// Returns `true` if this connection is still in a known-good state
79 /// and safe to use for new requests. See
80 /// [`super::connection::RawConnection::is_healthy`] for the full
81 /// semantics — this is the async mirror with identical behavior.
82 pub fn is_healthy(&self) -> bool {
83 !self.desynchronized
84 }
85
86 /// Marks this connection as desynchronized.
87 ///
88 /// Used by async result streams that are dropped mid-iteration: the
89 /// [`Drop`] impl cannot `await` to drain trailing `ErrorResponse +
90 /// ReadyForQuery` messages after sending a cancel, so it flags the
91 /// connection so the next operation short-circuits with a clear error
92 /// rather than hanging or misinterpreting stale server output.
93 pub fn mark_desynchronized(&mut self) {
94 self.desynchronized = true;
95 }
96
97 /// Async mirror of
98 /// [`super::connection::RawConnection::ensure_healthy`]. Called from
99 /// the entry point of every `pub async fn` that initiates a new
100 /// server request to short-circuit operations on a desynchronized
101 /// connection before any bytes hit the wire.
102 pub(crate) fn ensure_healthy(&self) -> Result<()> {
103 if self.desynchronized {
104 return Err(Error::new(
105 super::error::ErrorKind::Connection,
106 "connection is desynchronized from the server and cannot be reused; \
107 discard it and open a new one",
108 ));
109 }
110 Ok(())
111 }
112
113 /// Returns the process ID assigned by the server.
114 pub fn process_id(&self) -> i32 {
115 self.process_id
116 }
117
118 /// Returns the secret key for cancel requests.
119 pub fn secret_key(&self) -> i32 {
120 self.secret_key
121 }
122
123 /// Returns a reference to the underlying stream.
124 pub fn stream(&self) -> &S {
125 &self.stream
126 }
127
128 /// Returns a mutable reference to the underlying stream.
129 pub fn stream_mut(&mut self) -> &mut S {
130 &mut self.stream
131 }
132
133 /// Returns a server parameter value by name.
134 pub fn parameter_status(&self, name: &str) -> Option<&str> {
135 self.server_params
136 .get(name)
137 .map(std::string::String::as_str)
138 }
139
140 /// Queues a `CopyFail` message in the write buffer (synchronous).
141 ///
142 /// Called from `AsyncCopyInWriter::Drop` when a COPY session is abandoned
143 /// without `finish()` or `cancel()`. The `CopyFail` is written to the buffer
144 /// but NOT flushed (we can't do async I/O from `Drop`). The next async
145 /// operation will call [`drain_pending_copy_cancel`](Self::drain_pending_copy_cancel) to flush and drain
146 /// the server's `ErrorResponse` + `ReadyForQuery` before proceeding.
147 pub fn queue_copy_fail(&mut self, reason: &str) {
148 frontend::copy_fail(reason, &mut self.write_buf);
149 self.pending_copy_cancel = true;
150 }
151
152 /// Drains a pending COPY cancel that was queued by `queue_copy_fail()`.
153 ///
154 /// If `pending_copy_cancel` is set, this flushes the `CopyFail` message to
155 /// the server and reads messages until `ReadyForQuery`, restoring the
156 /// connection to a usable state. Called automatically at the start of
157 /// new operations (`simple_query`, `query_binary`, `start_copy_in*`).
158 ///
159 /// # Errors
160 ///
161 /// Returns [`Error`] (I/O) if flushing the queued `CopyFail` or
162 /// reading the server's drain responses fails. A successful drain
163 /// clears `pending_copy_cancel`.
164 pub async fn drain_pending_copy_cancel(&mut self) -> Result<()> {
165 if !self.pending_copy_cancel {
166 return Ok(());
167 }
168
169 // Flush the queued CopyFail message
170 self.flush().await?;
171
172 // Drain messages until the connection is back in ReadyForQuery state
173 loop {
174 let msg = self.read_message().await?;
175 match msg {
176 Message::ReadyForQuery(_) => {
177 self.pending_copy_cancel = false;
178 debug!(
179 target: "hyperdb_api_core::client",
180 "drained pending COPY cancel — connection restored"
181 );
182 return Ok(());
183 }
184 Message::ErrorResponse(_) => {
185 // Expected — server confirms the cancel
186 }
187 _ => {
188 // Ignore other messages (e.g., NoticeResponse)
189 }
190 }
191 }
192 }
193
194 /// Sends a startup message and performs initial handshake (async).
195 ///
196 /// # Errors
197 ///
198 /// - Returns [`Error`] (auth) when the server requests an
199 /// auth method and no password is supplied, when the offered
200 /// SASL mechanisms exclude SCRAM-SHA-256, or when SCRAM state
201 /// is missing at the SASL-continue / SASL-final step.
202 /// - Returns [`Error`] (server) when the server sends an `ErrorResponse`
203 /// during startup (unknown user, unknown database, etc.).
204 /// - Returns [`Error`] (protocol) if a message arrives out of
205 /// sequence.
206 /// - Returns [`Error`] (I/O) on transport read/write failure.
207 pub async fn startup(&mut self, params: &[(&str, &str)], password: Option<&str>) -> Result<()> {
208 // Send startup message
209 frontend::startup_message(params, &mut self.write_buf)?;
210 self.flush().await?;
211
212 // Handle authentication
213 let mut auth_state: Option<AuthState> = None;
214
215 loop {
216 let msg = self.read_message().await?;
217 match msg {
218 Message::AuthenticationOk => {
219 info!(target: "hyperdb_api", "connection-auth-success");
220 }
221 Message::AuthenticationCleartextPassword => {
222 debug!(target: "hyperdb_api", method = "cleartext", "connection-auth-method");
223 let password = password.ok_or_else(|| {
224 Error::authentication(
225 "server requested cleartext password but none provided",
226 )
227 })?;
228 frontend::password_message(password, &mut self.write_buf)?;
229 self.flush().await?;
230 }
231 Message::AuthenticationMd5Password(body) => {
232 debug!(target: "hyperdb_api", method = "MD5", "connection-auth-method");
233 let password = password.ok_or_else(|| {
234 Error::authentication("server requested MD5 password but none provided")
235 })?;
236 let user = params
237 .iter()
238 .find(|(k, _)| *k == "user")
239 .map_or("", |(_, v)| *v);
240
241 let md5_response = auth::compute_md5_password(user, password, &body.salt());
242 frontend::password_message(&md5_response, &mut self.write_buf)?;
243 self.flush().await?;
244 }
245 Message::AuthenticationSasl(body) => {
246 debug!(target: "hyperdb_api", method = "SCRAM-SHA-256", "connection-auth-method");
247 let password = password.ok_or_else(|| {
248 Error::authentication(
249 "server requested SASL authentication but no password provided",
250 )
251 })?;
252
253 let mechanisms: Vec<&str> = body.mechanisms().collect();
254 if !mechanisms.contains(&"SCRAM-SHA-256") {
255 return Err(Error::authentication(format!(
256 "server offered unsupported SASL mechanisms: {mechanisms:?}"
257 )));
258 }
259
260 let (state, client_first) = auth::scram_client_first(password)?;
261 auth_state = Some(state);
262
263 frontend::sasl_initial_response(
264 "SCRAM-SHA-256",
265 &client_first,
266 &mut self.write_buf,
267 )?;
268 self.flush().await?;
269 }
270 Message::AuthenticationSaslContinue(body) => {
271 let state = auth_state.take().ok_or_else(|| {
272 Error::authentication("received SASL continue without initial state")
273 })?;
274
275 let server_first = body.data();
276 let (new_state, client_final) = auth::scram_client_final(state, server_first)?;
277 auth_state = Some(new_state);
278
279 frontend::sasl_response(&client_final, &mut self.write_buf)?;
280 self.flush().await?;
281 }
282 Message::AuthenticationSaslFinal(body) => {
283 let state = auth_state.take().ok_or_else(|| {
284 Error::authentication("received SASL final without state")
285 })?;
286 auth::scram_verify_server(state, body.data())?;
287 }
288 Message::BackendKeyData(data) => {
289 self.process_id = data.process_id();
290 self.secret_key = data.secret_key();
291 }
292 Message::ParameterStatus(body) => {
293 if let (Ok(name), Ok(value)) = (body.name(), body.value()) {
294 self.server_params
295 .insert(name.to_string(), value.to_string());
296 }
297 }
298 Message::ReadyForQuery(_) => {
299 return Ok(());
300 }
301 Message::ErrorResponse(body) => {
302 return Err(self.consume_error(&body).await);
303 }
304 _ => {
305 return Err(Error::protocol("unexpected message during startup"));
306 }
307 }
308 }
309 }
310
311 /// Sends a simple query and returns all messages until `ReadyForQuery` (async).
312 ///
313 /// # Errors
314 ///
315 /// - Returns [`Error`] (connection) if the connection has been
316 /// marked unhealthy.
317 /// - Returns [`Error`] (server) when the server emits an
318 /// `ErrorResponse` (SQL error, constraint violation, etc.).
319 /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
320 /// read/write failure.
321 /// - Propagates any error from
322 /// [`Self::drain_pending_copy_cancel`] when a queued `CopyFail`
323 /// needs to be flushed first.
324 pub async fn simple_query(&mut self, query: &str) -> Result<Vec<Message>> {
325 self.ensure_healthy()?;
326 self.drain_pending_copy_cancel().await?;
327 frontend::query(query, &mut self.write_buf)?;
328 self.flush().await?;
329
330 let mut messages = Vec::new();
331 loop {
332 let msg = self.read_message().await?;
333 match &msg {
334 Message::ReadyForQuery(_) => {
335 messages.push(msg);
336 return Ok(messages);
337 }
338 Message::ErrorResponse(body) => {
339 return Err(self.consume_error(body).await);
340 }
341 _ => {
342 messages.push(msg);
343 }
344 }
345 }
346 }
347
348 /// Sends a query using extended protocol with binary format results (async).
349 ///
350 /// # Errors
351 ///
352 /// Same failure modes as [`Self::simple_query`].
353 pub async fn query_binary(&mut self, query: &str) -> Result<Vec<Message>> {
354 self.ensure_healthy()?;
355 self.drain_pending_copy_cancel().await?;
356 const HYPER_BINARY_FORMAT: i16 = 2;
357
358 frontend::parse("", query, &[], &mut self.write_buf)?;
359 frontend::bind(
360 "",
361 "",
362 &[],
363 &[],
364 &[HYPER_BINARY_FORMAT],
365 &mut self.write_buf,
366 )?;
367 frontend::describe(b'P', "", &mut self.write_buf)?;
368 frontend::execute("", 0, &mut self.write_buf)?;
369 frontend::sync(&mut self.write_buf);
370
371 self.flush().await?;
372
373 let mut messages = Vec::new();
374 loop {
375 let msg = self.read_message().await?;
376 match &msg {
377 Message::ReadyForQuery(_) => {
378 messages.push(msg);
379 return Ok(messages);
380 }
381 Message::ErrorResponse(body) => {
382 return Err(self.consume_error(body).await);
383 }
384 _ => {
385 messages.push(msg);
386 }
387 }
388 }
389 }
390
391 /// Starts a binary query but leaves result consumption to the caller (async).
392 ///
393 /// # Errors
394 ///
395 /// - Returns [`Error`] (connection) if the connection is unhealthy.
396 /// - Returns [`Error`] (I/O) on transport write failure.
397 /// - Propagates any error from [`Self::drain_pending_copy_cancel`].
398 pub async fn start_query_binary(&mut self, query: &str) -> Result<()> {
399 self.ensure_healthy()?;
400 // Drain any CopyFail queued by `AsyncCopyInWriter::Drop` before
401 // writing the extended-query bytes. Without this, the flush at
402 // the end of this method would send [CopyFail | Parse | Bind |
403 // Describe | Execute | Sync] in a single buffer and the server
404 // would answer with CopyFail's ErrorResponse+ReadyForQuery
405 // interleaved with our query's responses — the read loop would
406 // then misattribute the COPY error to this query.
407 self.drain_pending_copy_cancel().await?;
408 const HYPER_BINARY_FORMAT: i16 = 2;
409
410 frontend::parse("", query, &[], &mut self.write_buf)?;
411 frontend::bind(
412 "",
413 "",
414 &[],
415 &[],
416 &[HYPER_BINARY_FORMAT],
417 &mut self.write_buf,
418 )?;
419 frontend::describe(b'P', "", &mut self.write_buf)?;
420 frontend::execute("", 0, &mut self.write_buf)?;
421 frontend::sync(&mut self.write_buf);
422
423 self.flush().await
424 }
425
426 /// Starts a simple query but leaves result consumption to the caller (async).
427 ///
428 /// # Errors
429 ///
430 /// Same failure modes as [`Self::start_query_binary`].
431 pub async fn start_simple_query(&mut self, query: &str) -> Result<()> {
432 self.ensure_healthy()?;
433 // See `start_query_binary` for why the pending-copy-cancel drain
434 // is required before writing any new query bytes.
435 self.drain_pending_copy_cancel().await?;
436 frontend::query(query, &mut self.write_buf)?;
437 self.flush().await
438 }
439
440 /// Starts an **execute** of a prepared statement but leaves result
441 /// consumption to the caller (async).
442 ///
443 /// Async mirror of
444 /// [`super::connection::RawConnection::start_execute_prepared`]. See
445 /// that method's docs for the split format-code rationale (params
446 /// use `1` = PG binary/BE, results use `2` = HyperBinary/LE).
447 ///
448 /// # Errors
449 ///
450 /// - Returns [`Error`] (connection) if the connection is unhealthy.
451 /// - Returns [`Error`] (I/O) on transport write failure.
452 /// - Propagates any error from [`Self::drain_pending_copy_cancel`].
453 pub async fn start_execute_prepared(
454 &mut self,
455 statement_name: &str,
456 params: &[Option<&[u8]>],
457 column_count: usize,
458 ) -> Result<()> {
459 self.ensure_healthy()?;
460 // Same rationale as `start_query_binary` for draining a pending
461 // CopyFail before writing new extended-query bytes.
462 self.drain_pending_copy_cancel().await?;
463
464 const PG_BINARY_FORMAT: i16 = 1;
465 const HYPER_BINARY_FORMAT: i16 = 2;
466 let param_formats: Vec<i16> = vec![PG_BINARY_FORMAT; params.len()];
467 let result_formats: Vec<i16> = vec![HYPER_BINARY_FORMAT; column_count];
468
469 frontend::bind(
470 "", // unnamed portal
471 statement_name,
472 ¶m_formats,
473 params,
474 &result_formats,
475 &mut self.write_buf,
476 )?;
477 frontend::execute("", 0, &mut self.write_buf)?;
478 frontend::sync(&mut self.write_buf);
479
480 self.flush().await
481 }
482
483 /// Reads a single message from the server (async).
484 ///
485 /// # Errors
486 ///
487 /// - Returns [`Error`] (I/O) if reading from the transport fails or
488 /// if [`Message::parse`] reports a malformed frame.
489 /// - Returns [`Error`] (closed) when the transport reaches EOF
490 /// (server closed the connection).
491 pub async fn read_message(&mut self) -> Result<Message> {
492 loop {
493 if let Some(msg) = Message::parse(&mut self.read_buf).map_err(Error::io)? {
494 return Ok(msg);
495 }
496
497 // Need more data — read directly into the spare capacity of
498 // `read_buf`, no temporary buffer or `extend_from_slice` memcpy.
499 // See the sync mirror in
500 // [`super::connection::RawConnection::read_message`] for the
501 // full rationale on the 64 KiB ceiling and Windows-loopback
502 // syscall amplification.
503 let prev_len = self.read_buf.len();
504 self.read_buf.resize(prev_len + 64 * 1024, 0);
505 let n = self.stream.read(&mut self.read_buf[prev_len..]).await?;
506 if n == 0 {
507 self.read_buf.truncate(prev_len);
508 warn!(target: "hyperdb_api", "connection-closed");
509 return Err(Error::closed());
510 }
511 self.read_buf.truncate(prev_len + n);
512 }
513 }
514
515 /// Async equivalent of
516 /// [`super::connection::RawConnection::drain_until_ready`]. Unbounded;
517 /// prefer [`drain_until_ready_bounded`](Self::drain_until_ready_bounded)
518 /// in destructors and other code paths where blocking indefinitely is
519 /// unacceptable. Drain errors are logged via `tracing::warn!` and then
520 /// swallowed.
521 pub async fn drain_until_ready(&mut self) {
522 let _ = self.drain_until_ready_bounded(usize::MAX).await;
523 }
524
525 /// Async equivalent of
526 /// [`super::connection::RawConnection::drain_until_ready_bounded`].
527 /// See that function's docs for the full semantics, including why we do
528 /// **not** send a `Sync` before draining (it would produce an extra
529 /// `ReadyForQuery` on the wire and corrupt the next query's response).
530 pub async fn drain_until_ready_bounded(&mut self, max_messages: usize) -> bool {
531 for i in 0..max_messages {
532 match self.read_message().await {
533 Ok(Message::ReadyForQuery(_)) => return true,
534 Ok(_) => {}
535 Err(e) => {
536 warn!(
537 target: "hyperdb_api_core::client",
538 error = %e,
539 messages_read = i,
540 "drain_until_ready: read error mid-drain (likely closed connection); \
541 connection marked desynchronized",
542 );
543 // Mirror of sync path: any mid-drain read error leaves
544 // the connection in unknown state. See
545 // `super::connection::RawConnection::drain_until_ready_bounded`
546 // for the full rationale.
547 self.desynchronized = true;
548 return false;
549 }
550 }
551 }
552 warn!(
553 target: "hyperdb_api_core::client",
554 max_messages,
555 "drain_until_ready_bounded: exhausted budget without seeing ReadyForQuery; \
556 connection marked desynchronized and should not be reused",
557 );
558 self.desynchronized = true;
559 false
560 }
561
562 /// Async equivalent of
563 /// [`super::connection::RawConnection::consume_error`]. Parse the error
564 /// body and drain the rest of the response in one call. Semantics are
565 /// identical to the sync version, including the
566 /// [`POST_ERROR_DRAIN_CAP`](super::connection::POST_ERROR_DRAIN_CAP)
567 /// safety valve — see that function's docs for the rationale. Unbounded
568 /// drain would be particularly dangerous here because a stalled read
569 /// on the underlying async stream would hang the caller's future
570 /// indefinitely with no observable symptom; the bounded drain turns
571 /// that into a loud `tracing::warn!` plus a connection marked for
572 /// reconnect on next use.
573 pub async fn consume_error(
574 &mut self,
575 body: &crate::protocol::message::backend::ErrorResponseBody,
576 ) -> Error {
577 let err = super::connection::parse_error_response(body);
578 let _ = self
579 .drain_until_ready_bounded(super::connection::POST_ERROR_DRAIN_CAP)
580 .await;
581 err
582 }
583
584 /// Flushes the write buffer to the server (async).
585 ///
586 /// # Errors
587 ///
588 /// Returns [`Error`] (I/O) if writing the buffered bytes or flushing
589 /// the underlying async transport fails.
590 pub async fn flush(&mut self) -> Result<()> {
591 if !self.write_buf.is_empty() {
592 self.stream.write_all(&self.write_buf).await?;
593 self.stream.flush().await?;
594 self.write_buf.clear();
595 }
596 Ok(())
597 }
598
599 /// Sends a terminate message and closes the connection (async).
600 ///
601 /// # Errors
602 ///
603 /// Returns [`Error`] (I/O) if writing the `Terminate` frame or
604 /// flushing the async transport fails.
605 pub async fn terminate(&mut self) -> Result<()> {
606 frontend::terminate(&mut self.write_buf);
607 self.flush().await
608 }
609
610 /// Returns a mutable reference to the write buffer.
611 pub fn write_buf(&mut self) -> &mut BytesMut {
612 &mut self.write_buf
613 }
614
615 /// Initiates a COPY IN operation with `HyperBinary` format (async).
616 ///
617 /// # Errors
618 ///
619 /// Same failure modes as [`Self::start_copy_in_with_format`].
620 pub async fn start_copy_in(&mut self, table_name: &str, columns: &[&str]) -> Result<()> {
621 self.start_copy_in_with_format(table_name, columns, "HYPERBINARY")
622 .await
623 }
624
625 /// Initiates a COPY IN operation with a specified format (async).
626 ///
627 /// # Errors
628 ///
629 /// - Returns [`Error`] (connection) if the connection has been
630 /// marked unhealthy.
631 /// - Returns [`Error`] (server) if the server rejects the generated
632 /// `COPY ... FROM STDIN` statement.
633 /// - Returns [`Error`] (I/O) on transport read/write failure.
634 /// - Propagates any error from [`Self::drain_pending_copy_cancel`].
635 pub async fn start_copy_in_with_format(
636 &mut self,
637 table_name: &str,
638 columns: &[&str],
639 format: &str,
640 ) -> Result<()> {
641 self.ensure_healthy()?;
642 self.drain_pending_copy_cancel().await?;
643 let column_list = if columns.is_empty() {
644 String::new()
645 } else {
646 format!(
647 " ({})",
648 columns
649 .iter()
650 .map(|c| format!("\"{}\"", c.replace('"', "\"\"")))
651 .collect::<Vec<_>>()
652 .join(", ")
653 )
654 };
655
656 let query = format!("COPY {table_name}{column_list} FROM STDIN WITH (FORMAT {format})");
657
658 frontend::query(&query, &mut self.write_buf)?;
659 self.flush().await?;
660
661 loop {
662 let msg = self.read_message().await?;
663 match msg {
664 Message::CopyInResponse(_) => {
665 return Ok(());
666 }
667 Message::ErrorResponse(body) => {
668 return Err(self.consume_error(&body).await);
669 }
670 _ => {}
671 }
672 }
673 }
674
675 /// Sends COPY data to the server (sync - just buffers).
676 ///
677 /// # Errors
678 ///
679 /// Currently infallible — frame construction is pure. The `Result`
680 /// return type is preserved for forward compatibility.
681 pub fn send_copy_data(&mut self, data: &[u8]) -> Result<()> {
682 frontend::copy_data(data, &mut self.write_buf);
683 Ok(())
684 }
685
686 /// Sends COPY data directly to the stream without internal buffering (async).
687 ///
688 /// This writes the `CopyData` message directly to the TCP stream, letting
689 /// the kernel's TCP stack handle buffering. Use `flush_stream()` periodically
690 /// to ensure data is sent.
691 ///
692 /// # Errors
693 ///
694 /// - Returns [`Error`] (protocol) if `data.len() + 4` exceeds
695 /// `u32::MAX` (PostgreSQL's per-message length cap).
696 /// - Returns [`Error`] (I/O) if flushing buffered bytes or writing
697 /// the header / payload to the async transport fails.
698 pub async fn send_copy_data_direct(&mut self, data: &[u8]) -> Result<()> {
699 // First flush any pending buffered data
700 if !self.write_buf.is_empty() {
701 self.stream.write_all(&self.write_buf).await?;
702 self.write_buf.clear();
703 }
704
705 // Write CopyData message header + data directly to stream
706 // Message format: 'd' (1 byte) + length (4 bytes BigEndian) + data
707 let msg_len = u32::try_from(4 + data.len())
708 .map_err(|_| Error::protocol("CopyData payload exceeds u32::MAX bytes"))?;
709 let len_be = msg_len.to_be_bytes();
710 let header = [b'd', len_be[0], len_be[1], len_be[2], len_be[3]];
711 self.stream.write_all(&header).await?;
712 self.stream.write_all(data).await?;
713 Ok(())
714 }
715
716 /// Flushes the TCP stream without clearing the write buffer (async).
717 ///
718 /// Use this with `send_copy_data_direct()` to periodically ensure
719 /// data is sent to the server.
720 ///
721 /// # Errors
722 ///
723 /// Returns [`Error`] (I/O) if flushing the underlying async transport
724 /// fails.
725 pub async fn flush_stream(&mut self) -> Result<()> {
726 self.stream.flush().await?;
727 Ok(())
728 }
729
730 /// Finishes a COPY IN operation successfully (async).
731 ///
732 /// # Errors
733 ///
734 /// - Returns [`Error`] (server) when the server emits an
735 /// `ErrorResponse` during finalization (for example, a
736 /// constraint violation from the accumulated rows).
737 /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
738 /// read/write failure.
739 pub async fn finish_copy(&mut self) -> Result<u64> {
740 self.flush().await?;
741
742 frontend::copy_done(&mut self.write_buf);
743 self.flush().await?;
744
745 let mut row_count = 0u64;
746 loop {
747 let msg = self.read_message().await?;
748 match msg {
749 Message::CommandComplete(body) => {
750 if let Ok(tag) = body.tag() {
751 if let Some(count_str) = tag.strip_prefix("COPY ") {
752 if let Ok(count) = count_str.trim().parse() {
753 row_count = count;
754 }
755 }
756 }
757 }
758 Message::ReadyForQuery(_) => {
759 return Ok(row_count);
760 }
761 Message::ErrorResponse(body) => {
762 return Err(self.consume_error(&body).await);
763 }
764 _ => {}
765 }
766 }
767 }
768
769 /// Cancels a COPY IN operation (async).
770 ///
771 /// # Errors
772 ///
773 /// Returns [`Error`] (I/O) if flushing the buffer or writing the
774 /// `CopyFail` frame fails, or [`Error`] (closed) if the server
775 /// drops the connection before returning `ReadyForQuery`.
776 pub async fn cancel_copy(&mut self, reason: &str) -> Result<()> {
777 self.flush().await?;
778
779 frontend::copy_fail(reason, &mut self.write_buf);
780 self.flush().await?;
781
782 loop {
783 let msg = self.read_message().await?;
784 match msg {
785 Message::ReadyForQuery(_) => {
786 return Ok(());
787 }
788 Message::ErrorResponse(_) => {}
789 _ => {}
790 }
791 }
792 }
793
794 /// Executes a COPY ... TO STDOUT query and returns all output data (async).
795 ///
796 /// # Errors
797 ///
798 /// - Returns [`Error`] (connection) if the connection is unhealthy.
799 /// - Returns [`Error`] (server) when the server rejects the COPY TO
800 /// STDOUT statement via `ErrorResponse`.
801 /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
802 /// read/write failure.
803 pub async fn copy_out(&mut self, query: &str) -> Result<Vec<u8>> {
804 self.ensure_healthy()?;
805 self.drain_pending_copy_cancel().await?;
806 frontend::query(query, &mut self.write_buf)?;
807 self.flush().await?;
808
809 let mut data = Vec::new();
810 let mut in_copy_out = false;
811
812 loop {
813 let msg = self.read_message().await?;
814 match msg {
815 Message::CopyOutResponse(_) => {
816 in_copy_out = true;
817 }
818 Message::CopyData(body) if in_copy_out => {
819 data.extend_from_slice(body.data());
820 }
821 Message::CopyDone => {
822 in_copy_out = false;
823 }
824 Message::CommandComplete(_) => {}
825 Message::ReadyForQuery(_) => {
826 return Ok(data);
827 }
828 Message::ErrorResponse(body) => {
829 return Err(self.consume_error(&body).await);
830 }
831 _ => {}
832 }
833 }
834 }
835
836 /// Prepares a statement using the extended query protocol (async).
837 ///
838 /// # Errors
839 ///
840 /// - Returns [`Error`] (connection) if the connection is unhealthy.
841 /// - Returns [`Error`] (server) if the server rejects the `Parse`
842 /// request (SQL syntax error, unknown type OIDs, etc.).
843 /// - Returns [`Error`] (I/O) on transport read/write failure.
844 pub async fn prepare(
845 &mut self,
846 name: &str,
847 query: &str,
848 param_types: &[crate::types::Oid],
849 ) -> Result<(Vec<crate::types::Oid>, Vec<super::statement::Column>)> {
850 use super::statement::{Column, ColumnFormat};
851
852 self.ensure_healthy()?;
853 self.drain_pending_copy_cancel().await?;
854
855 // Send Parse message
856 frontend::parse(name, query, param_types, &mut self.write_buf)?;
857
858 // Send Describe message for the statement
859 frontend::describe(b'S', name, &mut self.write_buf)?;
860
861 // Send Sync to get responses
862 frontend::sync(&mut self.write_buf);
863 self.flush().await?;
864
865 // Process responses
866 let mut parsed_params = Vec::new();
867 let mut parsed_columns = Vec::new();
868
869 loop {
870 let msg = self.read_message().await?;
871 match msg {
872 Message::ParseComplete => {}
873 Message::ParameterDescription(desc) => {
874 for oid in desc.parameters().filter_map(std::result::Result::ok) {
875 parsed_params.push(oid);
876 }
877 }
878 Message::RowDescription(desc) => {
879 for f in desc.fields().filter_map(std::result::Result::ok) {
880 parsed_columns.push(Column::new(
881 f.name().to_string(),
882 f.type_oid(),
883 f.type_modifier(),
884 ColumnFormat::from_code(f.format()),
885 ));
886 }
887 }
888 Message::NoData => {}
889 Message::ReadyForQuery(_) => {
890 break;
891 }
892 Message::ErrorResponse(body) => {
893 return Err(self.consume_error(&body).await);
894 }
895 _ => {}
896 }
897 }
898
899 Ok((parsed_params, parsed_columns))
900 }
901
902 /// Executes a prepared statement with parameters (async).
903 ///
904 /// # Errors
905 ///
906 /// - Returns [`Error`] (connection) if the connection is unhealthy.
907 /// - Returns [`Error`] (server) if `Bind` / `Execute` fails on the
908 /// server (parameter type mismatch, constraint violation, etc.).
909 /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
910 /// read/write failure.
911 /// - Propagates row-construction errors from
912 /// `super::row::Row::new` if a `DataRow` cannot be decoded
913 /// against the reported `RowDescription`.
914 pub async fn execute_prepared(
915 &mut self,
916 statement_name: &str,
917 params: &[Option<&[u8]>],
918 column_count: usize,
919 ) -> Result<Vec<super::row::Row>> {
920 use super::statement::Column;
921 use std::sync::Arc;
922
923 self.ensure_healthy()?;
924 // Prepared-statement execution writes Bind/Execute/Sync into the
925 // buffer and flushes at the end; a pending CopyFail would be
926 // flushed together with our bind bytes and corrupt the response
927 // stream. See `start_query_binary` for the full argument.
928 self.drain_pending_copy_cancel().await?;
929 // Bind parameters (all in binary format)
930 let param_formats: Vec<i16> = vec![1; params.len()];
931 let result_formats: Vec<i16> = vec![1; column_count];
932
933 frontend::bind(
934 "",
935 statement_name,
936 ¶m_formats,
937 params,
938 &result_formats,
939 &mut self.write_buf,
940 )?;
941
942 frontend::execute("", 0, &mut self.write_buf)?;
943 frontend::sync(&mut self.write_buf);
944 self.flush().await?;
945
946 let mut rows = Vec::new();
947 let mut columns: Option<Arc<Vec<Column>>> = None;
948
949 loop {
950 let msg = self.read_message().await?;
951 match msg {
952 Message::BindComplete => {}
953 Message::RowDescription(desc) => {
954 let mut cols = Vec::new();
955 for f in desc.fields().filter_map(std::result::Result::ok) {
956 cols.push(Column::new(
957 f.name().to_string(),
958 f.type_oid(),
959 f.type_modifier(),
960 super::statement::ColumnFormat::from_code(f.format()),
961 ));
962 }
963 columns = Some(Arc::new(cols));
964 }
965 Message::DataRow(data) => {
966 if let Some(ref cols) = columns {
967 rows.push(super::row::Row::new(Arc::clone(cols), data)?);
968 }
969 }
970 Message::CommandComplete(_) => {}
971 Message::EmptyQueryResponse => {}
972 Message::ReadyForQuery(_) => {
973 break;
974 }
975 Message::ErrorResponse(body) => {
976 return Err(self.consume_error(&body).await);
977 }
978 _ => {}
979 }
980 }
981
982 Ok(rows)
983 }
984
985 /// Executes a prepared statement that doesn't return rows (async).
986 ///
987 /// # Errors
988 ///
989 /// Same failure modes as [`Self::execute_prepared`] (excluding
990 /// row-construction errors — this path never builds rows).
991 pub async fn execute_prepared_no_result(
992 &mut self,
993 statement_name: &str,
994 params: &[Option<&[u8]>],
995 ) -> Result<u64> {
996 self.ensure_healthy()?;
997 // See `execute_prepared` and `start_query_binary` for why we must
998 // drain any pending COPY cancel before writing new bytes.
999 self.drain_pending_copy_cancel().await?;
1000 let param_formats: Vec<i16> = vec![1; params.len()];
1001 let result_formats: Vec<i16> = vec![];
1002
1003 frontend::bind(
1004 "",
1005 statement_name,
1006 ¶m_formats,
1007 params,
1008 &result_formats,
1009 &mut self.write_buf,
1010 )?;
1011
1012 frontend::execute("", 0, &mut self.write_buf)?;
1013 frontend::sync(&mut self.write_buf);
1014 self.flush().await?;
1015
1016 let mut affected_rows = 0u64;
1017
1018 loop {
1019 let msg = self.read_message().await?;
1020 match msg {
1021 Message::BindComplete => {}
1022 Message::CommandComplete(body) => {
1023 if let Ok(tag) = body.tag() {
1024 // Parse formats like "INSERT 0 5", "UPDATE 10", "DELETE 3"
1025 let parts: Vec<&str> = tag.split_whitespace().collect();
1026 match parts.first() {
1027 Some(&"INSERT") => {
1028 if let Some(count) = parts.get(2) {
1029 affected_rows = count.parse().unwrap_or(0);
1030 }
1031 }
1032 Some(&"UPDATE" | &"DELETE" | &"SELECT" | &"COPY") => {
1033 if let Some(count) = parts.get(1) {
1034 affected_rows = count.parse().unwrap_or(0);
1035 }
1036 }
1037 _ => {}
1038 }
1039 }
1040 }
1041 Message::EmptyQueryResponse => {}
1042 Message::ReadyForQuery(_) => {
1043 break;
1044 }
1045 Message::ErrorResponse(body) => {
1046 return Err(self.consume_error(&body).await);
1047 }
1048 _ => {}
1049 }
1050 }
1051
1052 Ok(affected_rows)
1053 }
1054
1055 /// Closes a prepared statement (async).
1056 ///
1057 /// # Errors
1058 ///
1059 /// - Returns [`Error`] (connection) if the connection is unhealthy.
1060 /// - Returns [`Error`] (server) if the server reports an `ErrorResponse`
1061 /// during `Close`/`Sync`.
1062 /// - Returns [`Error`] (I/O) / [`Error`] (closed) on transport
1063 /// read/write failure.
1064 /// - Propagates any error from [`Self::drain_pending_copy_cancel`].
1065 pub async fn close_statement(&mut self, statement_name: &str) -> Result<()> {
1066 self.ensure_healthy()?;
1067 // Close + Sync get flushed together; a pending CopyFail would
1068 // share the flush and corrupt the response stream. See
1069 // `start_query_binary` for the full argument.
1070 self.drain_pending_copy_cancel().await?;
1071 frontend::close(b'S', statement_name, &mut self.write_buf)?;
1072 frontend::sync(&mut self.write_buf);
1073 self.flush().await?;
1074
1075 loop {
1076 let msg = self.read_message().await?;
1077 match msg {
1078 Message::CloseComplete => {}
1079 Message::ReadyForQuery(_) => {
1080 return Ok(());
1081 }
1082 Message::ErrorResponse(body) => {
1083 return Err(self.consume_error(&body).await);
1084 }
1085 _ => {}
1086 }
1087 }
1088 }
1089}