use std::sync::Arc;
use crate::protocol::message::backend::Message;
use tokio::sync::MutexGuard;
use tracing::warn;
use super::async_connection::AsyncRawConnection;
use super::async_stream::AsyncStream;
use super::cancel::Cancellable;
use super::connection::parse_error_response;
use super::error::Result;
use super::row::StreamRow;
use super::statement::Column;
/// Chunked reader over the rows produced by a prepared-statement execution.
///
/// The stream holds the connection's `tokio` mutex guard while rows are still
/// pending, so nothing else can use the connection until the stream finishes
/// (or is dropped).
pub struct AsyncPreparedQueryStream<'a> {
/// Locked connection; set to `None` once `ReadyForQuery` or an `ErrorResponse`
/// has been handled, which releases the lock.
conn: Option<MutexGuard<'a, AsyncRawConnection<AsyncStream>>>,
/// Invoked to cancel the query if the stream is dropped before completion.
canceller: &'a dyn Cancellable,
/// `true` once the result has been fully processed (success or server error).
finished: bool,
/// Maximum number of rows returned per chunk; clamped to at least 1 on construction.
chunk_size: usize,
/// Column descriptions of the result set, shared via `Arc`.
columns: Arc<Vec<Column>>,
}
impl std::fmt::Debug for AsyncPreparedQueryStream<'_> {
    /// Debug view showing stream progress and shape only; the connection guard
    /// and canceller are intentionally omitted (hence the non-exhaustive finish).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let column_count = self.columns.len();
        let mut out = f.debug_struct("AsyncPreparedQueryStream");
        out.field("finished", &self.finished);
        out.field("chunk_size", &self.chunk_size);
        out.field("column_count", &column_count);
        out.finish_non_exhaustive()
    }
}
impl<'a> AsyncPreparedQueryStream<'a> {
    /// Wraps a locked connection from which prepared-statement result rows are read.
    ///
    /// * `conn` — the locked connection; the guard is kept until the stream sees
    ///   `ReadyForQuery`, handles an `ErrorResponse`, or is dropped.
    /// * `canceller` — invoked from `Drop` if the stream is abandoned early.
    /// * `chunk_size` — maximum rows per chunk; `0` is clamped to `1`.
    /// * `columns` — result-set schema, exposed via [`Self::schema`].
    pub(crate) fn new(
        conn: MutexGuard<'a, AsyncRawConnection<AsyncStream>>,
        canceller: &'a dyn Cancellable,
        chunk_size: usize,
        columns: Arc<Vec<Column>>,
    ) -> Self {
        Self {
            conn: Some(conn),
            canceller,
            finished: false,
            // A chunk size of 0 would make `next_chunk` return empty chunks forever.
            chunk_size: chunk_size.max(1),
            columns,
        }
    }

    /// Returns the column descriptions of the result set.
    #[must_use]
    pub fn schema(&self) -> &[Column] {
        &self.columns
    }

    /// Reads the next chunk of at most `chunk_size` rows.
    ///
    /// Returns `Ok(Some(rows))` with a non-empty chunk, or `Ok(None)` once
    /// `ReadyForQuery` has been received with no rows left to hand out. Every
    /// call after the stream has finished also returns `Ok(None)`. Receiving
    /// `ReadyForQuery` or an `ErrorResponse` releases the connection guard.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading a message. If the server sends an
    /// `ErrorResponse`, returns the error produced by `consume_error`; the
    /// stream is then finished and later calls return `Ok(None)`.
    ///
    /// # Cancel safety
    ///
    /// Not cancel-safe: rows already read into the current chunk are lost if
    /// the returned future is dropped before it completes.
    pub async fn next_chunk(&mut self) -> Result<Option<Vec<StreamRow>>> {
        // Cap on the up-front reservation so a huge `chunk_size` (e.g.
        // `usize::MAX` meaning "everything at once") cannot cause a
        // capacity-overflow panic or a giant allocation for a small result.
        const MAX_PREALLOCATED_ROWS: usize = 1024;

        if self.finished {
            return Ok(None);
        }
        let Some(conn) = self.conn.as_mut() else {
            return Ok(None);
        };
        let mut rows = Vec::with_capacity(self.chunk_size.min(MAX_PREALLOCATED_ROWS));
        while rows.len() < self.chunk_size {
            match conn.read_message().await? {
                Message::DataRow(data) => rows.push(StreamRow::new(data)),
                Message::ReadyForQuery(_) => {
                    // Result fully consumed: mark done and release the connection guard.
                    self.finished = true;
                    self.conn = None;
                    return if rows.is_empty() {
                        Ok(None)
                    } else {
                        Ok(Some(rows))
                    };
                }
                Message::ErrorResponse(body) => {
                    // `finished` must only become true after `consume_error`
                    // returns: if this future is dropped while that await is
                    // pending, the connection's protocol state is unknown, and
                    // `Drop` has to see `finished == false` to mark it
                    // desynchronized instead of silently releasing it.
                    let err = match self.conn {
                        Some(ref mut c) => c.consume_error(&body).await,
                        None => parse_error_response(&body),
                    };
                    self.finished = true;
                    self.conn = None;
                    return Err(err);
                }
                // BindComplete, CommandComplete, EmptyQueryResponse and any other
                // message carry no row data for this stream and are skipped.
                _ => {}
            }
        }
        // The chunk is full; more rows may follow on the next call.
        Ok(Some(rows))
    }
}
impl Drop for AsyncPreparedQueryStream<'_> {
fn drop(&mut self) {
if self.finished {
return;
}
self.canceller.cancel();
if let Some(ref mut conn) = self.conn {
conn.mark_desynchronized();
warn!(
target: "hyperdb_api_core::client",
"AsyncPreparedQueryStream dropped before completion; \
connection marked desynchronized — discard and reconnect",
);
}
}
}