use tokio::sync::MutexGuard;
use tracing::warn;
use crate::protocol::message::backend::Message;
use super::async_connection::AsyncRawConnection;
use super::async_stream::AsyncStream;
use super::cancel::Cancellable;
use super::connection::parse_error_response;
use super::error::Result;
use super::row::StreamRow;
use super::statement::{Column, ColumnFormat};
/// Streams the rows of a query result from a locked async connection in chunks.
///
/// The stream owns the connection's mutex guard (`conn`) until the server's
/// response has been consumed, so the connection stays locked to this stream
/// in the meantime. Dropping a stream that has not finished cancels the query
/// through `canceller` and marks the connection desynchronized (see the `Drop`
/// impl).
pub struct AsyncQueryStream<'a> {
    /// Locked connection the rows are read from; reset to `None` once the
    /// response has been fully handled, which releases the lock.
    conn: Option<MutexGuard<'a, AsyncRawConnection<AsyncStream>>>,
    /// Used by `Drop` to cancel the in-flight query when the stream is
    /// abandoned early.
    canceller: &'a dyn Cancellable,
    /// Set once `ReadyForQuery` or an `ErrorResponse` has been processed;
    /// afterwards `next_chunk` only returns `Ok(None)`.
    finished: bool,
    /// Maximum number of rows per chunk (clamped to at least 1 by `new`).
    chunk_size: usize,
    /// Column metadata taken from the first `RowDescription`, if one was read.
    schema: Option<Vec<Column>>,
    /// Whether a `RowDescription` has already been turned into `schema`;
    /// later `RowDescription` messages are ignored.
    schema_read: bool,
}
impl std::fmt::Debug for AsyncQueryStream<'_> {
    /// Prints the stream's progress state (`finished`, `chunk_size`,
    /// `schema_read`). The connection guard, canceller and schema are not
    /// printed; `finish_non_exhaustive` renders them as a trailing `..`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("AsyncQueryStream");
        builder.field("finished", &self.finished);
        builder.field("chunk_size", &self.chunk_size);
        builder.field("schema_read", &self.schema_read);
        builder.finish_non_exhaustive()
    }
}
impl<'a> AsyncQueryStream<'a> {
pub(crate) fn new(
conn: MutexGuard<'a, AsyncRawConnection<AsyncStream>>,
canceller: &'a dyn Cancellable,
chunk_size: usize,
) -> Self {
Self {
conn: Some(conn),
canceller,
finished: false,
chunk_size: chunk_size.max(1),
schema: None,
schema_read: false,
}
}
#[must_use]
pub fn schema(&self) -> Option<&[Column]> {
self.schema.as_deref()
}
pub async fn next_chunk(&mut self) -> Result<Option<Vec<StreamRow>>> {
if self.finished {
return Ok(None);
}
let Some(conn) = self.conn.as_mut() else {
return Ok(None);
};
let mut rows = Vec::with_capacity(self.chunk_size);
while rows.len() < self.chunk_size {
let msg = conn.read_message().await?;
match msg {
Message::RowDescription(desc) if !self.schema_read => {
let mut cols = Vec::new();
for f in desc.fields().filter_map(std::result::Result::ok) {
cols.push(Column::new(
f.name().to_string(),
f.type_oid(),
f.type_modifier(),
ColumnFormat::from_code(f.format()),
));
}
self.schema = Some(cols);
self.schema_read = true;
}
Message::DataRow(data) => {
rows.push(StreamRow::new(data));
if rows.len() >= self.chunk_size {
return Ok(Some(rows));
}
}
Message::ReadyForQuery(_) => {
self.finished = true;
self.conn = None;
return if rows.is_empty() {
Ok(None)
} else {
Ok(Some(rows))
};
}
Message::ErrorResponse(body) => {
self.finished = true;
let err = match self.conn {
Some(ref mut c) => c.consume_error(&body).await,
None => parse_error_response(&body),
};
self.conn = None;
return Err(err);
}
_ => {}
}
}
Ok(Some(rows))
}
}
impl Drop for AsyncQueryStream<'_> {
fn drop(&mut self) {
if self.finished {
return;
}
self.canceller.cancel();
if let Some(ref mut conn) = self.conn {
conn.mark_desynchronized();
warn!(
target: "hyperdb_api_core::client",
"AsyncQueryStream dropped before completion; \
connection marked desynchronized — discard and reconnect",
);
}
}
}