use std::time::{Duration, Instant};
use crate::error::Result;
use crate::schema::{ColumnHeader, QueryResult, Row};
use crate::stream::DynRowStream;
/// Upper bound on the number of rows `collect_with_limit` reserves space for
/// before reading anything, so a very large `limit` does not translate into a
/// very large up-front allocation.
const COLLECT_PREALLOC_CAP: usize = 1 << 10;
/// Whole milliseconds in `d`, clamped to `u64::MAX` when the `u128` count
/// returned by [`Duration::as_millis`] does not fit in a `u64`.
fn elapsed_ms_saturating(d: Duration) -> u64 {
    let millis = d.as_millis();
    match u64::try_from(millis) {
        Ok(ms) => ms,
        Err(_overflow) => u64::MAX,
    }
}
/// A forward-only cursor over the rows produced by a query.
///
/// Wraps a boxed [`DynRowStream`] and adds bookkeeping:
/// - how many rows have been handed out;
/// - when the stream was created;
/// - whether the inner stream has already reported end-of-stream or an error.
///
/// Once the inner stream reports either of those, [`QueryStream::next_row`]
/// always returns `None`.
///
/// Dropping a `QueryStream` does not call the inner stream's `close`. Use
/// [`QueryStream::close`] or one of the `collect_*` methods to release it
/// explicitly.
pub struct QueryStream {
    inner: Box<dyn DynRowStream>,
    /// Captured in `new`; the reference point for `elapsed` / `elapsed_ms`.
    started: Instant,
    /// Rows returned through `next_row` (rows read by the look-ahead peek are
    /// not counted).
    rows_yielded: usize,
    /// `true` once the inner stream returned end-of-stream or an error.
    drained: bool,
}

impl QueryStream {
    /// Warning text logged when closing the inner stream after a read error fails.
    const CLOSE_AFTER_ERROR: &'static str = "close-after-error failed (possible cursor leak)";

    /// Wraps `inner`, starting the elapsed-time clock now.
    #[must_use]
    pub fn new(inner: Box<dyn DynRowStream>) -> Self {
        Self {
            inner,
            started: Instant::now(),
            rows_yielded: 0,
            drained: false,
        }
    }

    /// Column headers reported by the inner stream.
    #[must_use]
    pub fn columns(&self) -> &[ColumnHeader] {
        self.inner.columns()
    }

    /// Number of rows returned by [`QueryStream::next_row`] so far.
    #[must_use]
    pub const fn rows_yielded(&self) -> usize {
        self.rows_yielded
    }

    /// Time elapsed since this stream was constructed.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.started.elapsed()
    }

    /// Fetches the next row.
    ///
    /// Yields `Some(Ok(row))` for each row, then `None` at end-of-stream. If the
    /// inner stream fails, the error is returned once as `Some(Err(_))`. The
    /// stream is then fused: every later call returns `None` without touching
    /// the inner stream.
    pub async fn next_row(&mut self) -> Option<Result<Row>> {
        if self.drained {
            return None;
        }
        match self.inner.next_row().await {
            Ok(Some(row)) => {
                self.rows_yielded += 1;
                Some(Ok(row))
            }
            Ok(None) => {
                self.drained = true;
                None
            }
            Err(error) => {
                self.drained = true;
                Some(Err(error))
            }
        }
    }

    /// Reads every remaining row, closes the inner stream, and returns the
    /// materialised result. `rows_affected` is always `None`.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the inner stream. The inner stream
    /// is closed before the error is returned. Close failures, here and on the
    /// success path, are logged as warnings rather than returned.
    pub async fn collect_all(mut self) -> Result<QueryResult> {
        let mut rows = Vec::new();
        while let Some(item) = self.next_row().await {
            match item {
                Ok(row) => rows.push(row),
                Err(error) => {
                    Self::close_logged(self.inner, Self::CLOSE_AFTER_ERROR).await;
                    return Err(error);
                }
            }
        }
        Ok(self
            .finish(rows, "close after end-of-stream failed (possible cursor leak)")
            .await)
    }

    /// Reads at most `limit` rows, closes the inner stream, and reports whether
    /// more rows were available.
    ///
    /// The returned flag is `true` only when exactly `limit` rows were read and
    /// a further row was then observed. That extra row is read and discarded,
    /// so the result never holds more than `limit` rows. With a `limit` of `0`,
    /// no rows are collected; the call only checks whether the stream has any.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the inner stream, including an error
    /// from the look-ahead read. The inner stream is closed before any error is
    /// returned. Close failures are logged as warnings rather than returned.
    pub async fn collect_with_limit(mut self, limit: usize) -> Result<(QueryResult, bool)> {
        let mut rows = Vec::with_capacity(limit.min(COLLECT_PREALLOC_CAP));
        while rows.len() < limit {
            match self.next_row().await {
                Some(Ok(row)) => rows.push(row),
                Some(Err(error)) => {
                    Self::close_logged(self.inner, Self::CLOSE_AFTER_ERROR).await;
                    return Err(error);
                }
                None => break,
            }
        }

        // The limit was reached before end-of-stream was seen (this includes
        // `limit == 0`), so the only way to learn whether anything was cut off
        // is to read and discard one more row.
        let truncated = if rows.len() == limit && !self.drained {
            match self.peek_has_more().await {
                Ok(more) => more,
                Err(error) => {
                    // Close explicitly: returning early would otherwise drop the
                    // inner stream without closing it.
                    Self::close_logged(self.inner, Self::CLOSE_AFTER_ERROR).await;
                    return Err(error);
                }
            }
        } else {
            false
        };

        let close_context = if limit == 0 {
            "close after zero-limit peek failed (possible cursor leak)"
        } else {
            "close after limit drain failed (possible cursor leak)"
        };
        let result = self.finish(rows, close_context).await;
        Ok((result, truncated))
    }

    /// Reads one row from the inner stream purely to learn whether one exists.
    ///
    /// The row, if any, is discarded and not counted in `rows_yielded`. Like
    /// `next_row`, this marks the stream drained on end-of-stream or error.
    async fn peek_has_more(&mut self) -> Result<bool> {
        match self.inner.next_row().await {
            Ok(Some(_discarded)) => Ok(true),
            Ok(None) => {
                self.drained = true;
                Ok(false)
            }
            Err(error) => {
                self.drained = true;
                Err(error)
            }
        }
    }

    /// Finishes a successful collection. Captures the elapsed time and column
    /// headers, closes the inner stream, and assembles the [`QueryResult`].
    /// A close failure is logged with `close_context` as the warning message.
    async fn finish(self, rows: Vec<Row>, close_context: &'static str) -> QueryResult {
        let elapsed_ms = elapsed_ms_saturating(self.started.elapsed());
        let columns = self.inner.columns().to_vec();
        Self::close_logged(self.inner, close_context).await;
        QueryResult {
            columns,
            rows,
            rows_affected: None,
            elapsed_ms,
        }
    }

    /// Closes `inner` and logs a failure as a warning with message `context`.
    /// The failure is not returned, so it never replaces the caller's primary
    /// result or error.
    async fn close_logged(inner: Box<dyn DynRowStream>, context: &'static str) {
        if let Err(close_err) = inner.close().await {
            tracing::warn!(
                target: "narwhal::query_stream",
                error = %close_err,
                "{}",
                context
            );
        }
    }

    /// Closes the inner stream without reading any further rows.
    ///
    /// # Errors
    ///
    /// Returns whatever error the inner stream's `close` reports.
    pub async fn close(self) -> Result<()> {
        self.inner.close().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use crate::future::BoxFuture;
    use crate::schema::Row;
    use crate::stream::DynRowStream;
    use crate::value::Value;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// In-memory `DynRowStream`. It yields `rows` in order, returns `terminal`
    /// once (if set), and then reports end-of-stream. `close` sets a shared flag
    /// so tests can observe whether it was called.
    struct VecStream {
        columns: Vec<ColumnHeader>,
        rows: std::vec::IntoIter<Row>,
        terminal: Option<Error>,
        close_called: Arc<AtomicBool>,
    }

    impl VecStream {
        /// Builds a stream together with a handle to its "close was called" flag.
        fn new(
            columns: Vec<ColumnHeader>,
            rows: Vec<Row>,
            terminal: Option<Error>,
        ) -> (Self, Arc<AtomicBool>) {
            let close_called = Arc::new(AtomicBool::new(false));
            let stream = Self {
                columns,
                rows: rows.into_iter(),
                terminal,
                close_called: Arc::clone(&close_called),
            };
            (stream, close_called)
        }
    }

    impl DynRowStream for VecStream {
        fn columns(&self) -> &[ColumnHeader] {
            &self.columns
        }

        fn next_row(&mut self) -> BoxFuture<'_, Result<Option<Row>>> {
            Box::pin(async move {
                if let Some(row) = self.rows.next() {
                    return Ok(Some(row));
                }
                if let Some(error) = self.terminal.take() {
                    return Err(error);
                }
                Ok(None)
            })
        }

        fn close(self: Box<Self>) -> BoxFuture<'static, Result<()>> {
            self.close_called.store(true, Ordering::SeqCst);
            Box::pin(async { Ok(()) })
        }
    }

    /// Reads the flag returned by `VecStream::new`.
    fn was_closed(flag: &AtomicBool) -> bool {
        flag.load(Ordering::SeqCst)
    }

    fn col(name: &str) -> ColumnHeader {
        ColumnHeader {
            name: name.to_owned(),
            data_type: "TEXT".to_owned(),
        }
    }

    fn row(values: &[&str]) -> Row {
        Row(values
            .iter()
            .map(|s| Value::String((*s).to_owned()))
            .collect())
    }

    #[tokio::test]
    async fn next_row_yields_then_ends() {
        let (s, closed) = VecStream::new(vec![col("a")], vec![row(&["1"]), row(&["2"])], None);
        let mut qs = QueryStream::new(Box::new(s));
        assert_eq!(qs.rows_yielded(), 0);
        assert!(qs.next_row().await.unwrap().is_ok());
        assert_eq!(qs.rows_yielded(), 1);
        assert!(qs.next_row().await.unwrap().is_ok());
        assert!(qs.next_row().await.is_none());
        assert!(qs.next_row().await.is_none());
        assert!(!was_closed(&closed));
        qs.close().await.unwrap();
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_all_round_trips() {
        let (s, closed) = VecStream::new(
            vec![col("a"), col("b")],
            vec![row(&["1", "x"]), row(&["2", "y"]), row(&["3", "z"])],
            None,
        );
        let qs = QueryStream::new(Box::new(s));
        let qr = qs.collect_all().await.unwrap();
        assert_eq!(qr.columns.len(), 2);
        assert_eq!(qr.rows.len(), 3);
        assert!(qr.rows_affected.is_none());
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_all_on_empty_stream() {
        let (s, closed) = VecStream::new(vec![col("a")], vec![], None);
        let qs = QueryStream::new(Box::new(s));
        let qr = qs.collect_all().await.unwrap();
        assert!(qr.rows.is_empty());
        assert_eq!(qr.columns.len(), 1);
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_all_propagates_terminal_error() {
        let err = Error::Query("boom".into());
        let (s, closed) = VecStream::new(vec![col("a")], vec![row(&["only-row"])], Some(err));
        let qs = QueryStream::new(Box::new(s));
        let result = qs.collect_all().await;
        assert!(matches!(result, Err(Error::Query(_))));
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn next_row_fuses_after_error() {
        let err = Error::Query("boom".into());
        let (s, _) = VecStream::new(vec![col("a")], vec![], Some(err));
        let mut qs = QueryStream::new(Box::new(s));
        assert!(matches!(qs.next_row().await, Some(Err(_))));
        assert!(qs.next_row().await.is_none());
        assert!(qs.next_row().await.is_none());
    }

    #[tokio::test]
    async fn collect_with_limit_truncates() {
        let (s, closed) = VecStream::new(
            vec![col("a")],
            (0..10).map(|i| row(&[&i.to_string()])).collect(),
            None,
        );
        let qs = QueryStream::new(Box::new(s));
        let (qr, truncated) = qs.collect_with_limit(3).await.unwrap();
        assert_eq!(qr.rows.len(), 3);
        assert!(truncated, "expected truncated=true when engine has more");
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_with_limit_not_truncated_when_exact_fit() {
        let (s, closed) = VecStream::new(
            vec![col("a")],
            vec![row(&["1"]), row(&["2"]), row(&["3"])],
            None,
        );
        let qs = QueryStream::new(Box::new(s));
        let (qr, truncated) = qs.collect_with_limit(3).await.unwrap();
        assert_eq!(qr.rows.len(), 3);
        assert!(
            !truncated,
            "expected truncated=false when engine ends at limit"
        );
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_with_limit_not_truncated_when_under() {
        let (s, closed) = VecStream::new(vec![col("a")], vec![row(&["1"])], None);
        let qs = QueryStream::new(Box::new(s));
        let (qr, truncated) = qs.collect_with_limit(10).await.unwrap();
        assert_eq!(qr.rows.len(), 1);
        assert!(!truncated);
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_with_limit_zero_short_circuits_with_rows() {
        let (s, closed) = VecStream::new(vec![col("a")], vec![row(&["1"]), row(&["2"])], None);
        let qs = QueryStream::new(Box::new(s));
        let (qr, truncated) = qs.collect_with_limit(0).await.unwrap();
        assert!(qr.rows.is_empty());
        assert!(truncated, "engine had rows; truncated must be true");
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_with_limit_zero_on_empty_stream() {
        let (s, closed) = VecStream::new(vec![col("a")], vec![], None);
        let qs = QueryStream::new(Box::new(s));
        let (qr, truncated) = qs.collect_with_limit(0).await.unwrap();
        assert!(qr.rows.is_empty());
        assert!(!truncated, "empty stream is not truncated");
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_with_limit_zero_propagates_peek_error() {
        let err = Error::Query("boom".into());
        let (s, _) = VecStream::new(vec![col("a")], vec![], Some(err));
        let qs = QueryStream::new(Box::new(s));
        let result = qs.collect_with_limit(0).await;
        assert!(matches!(result, Err(Error::Query(_))));
    }

    #[tokio::test]
    async fn collect_with_limit_propagates_mid_stream_error_and_closes() {
        let err = Error::Query("boom".into());
        let (s, closed) = VecStream::new(vec![col("a")], vec![row(&["1"])], Some(err));
        let qs = QueryStream::new(Box::new(s));
        let result = qs.collect_with_limit(5).await;
        assert!(matches!(result, Err(Error::Query(_))));
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_with_limit_propagates_peek_error_and_closes() {
        let err = Error::Query("boom".into());
        let (s, closed) = VecStream::new(
            vec![col("a")],
            vec![row(&["1"]), row(&["2"])],
            Some(err),
        );
        let qs = QueryStream::new(Box::new(s));
        let result = qs.collect_with_limit(2).await;
        assert!(matches!(result, Err(Error::Query(_))));
        assert!(was_closed(&closed));
    }

    #[tokio::test]
    async fn collect_with_limit_truncated_yields_exactly_limit() {
        let (s, _) = VecStream::new(
            vec![col("a")],
            (0..10).map(|i| row(&[&i.to_string()])).collect(),
            None,
        );
        let qs = QueryStream::new(Box::new(s));
        let (qr, truncated) = qs.collect_with_limit(3).await.unwrap();
        assert_eq!(
            qr.rows.len(),
            3,
            "limit cap is hard — no over-collection from the peek"
        );
        assert!(truncated);
    }

    #[tokio::test]
    async fn columns_delegates_to_inner() {
        let inner_cols = vec![col("a"), col("b"), col("c")];
        let (s, _) = VecStream::new(inner_cols, vec![], None);
        let qs = QueryStream::new(Box::new(s));
        assert_eq!(qs.columns().len(), 3);
        assert_eq!(qs.columns()[0].name, "a");
        assert_eq!(qs.columns()[2].name, "c");
    }

    #[tokio::test]
    async fn collect_all_materialises_columns_from_inner() {
        let (s, _) = VecStream::new(
            vec![col("alpha"), col("beta")],
            vec![row(&["1", "x"])],
            None,
        );
        let qs = QueryStream::new(Box::new(s));
        let qr = qs.collect_all().await.unwrap();
        assert_eq!(qr.columns.len(), 2);
        assert_eq!(qr.columns[0].name, "alpha");
        assert_eq!(qr.columns[1].name, "beta");
    }

    #[tokio::test]
    async fn rows_yielded_tracks_correctly() {
        let (s, _) = VecStream::new(
            vec![col("a")],
            vec![row(&["1"]), row(&["2"]), row(&["3"])],
            None,
        );
        let mut qs = QueryStream::new(Box::new(s));
        let _ = qs.next_row().await;
        assert_eq!(qs.rows_yielded(), 1);
        let _ = qs.next_row().await;
        let _ = qs.next_row().await;
        assert_eq!(qs.rows_yielded(), 3);
        // End-of-stream must not bump the counter.
        let _ = qs.next_row().await;
        assert_eq!(qs.rows_yielded(), 3);
    }

    #[tokio::test]
    async fn drop_releases_without_close() {
        let (s, closed) = VecStream::new(
            vec![col("a")],
            (0..1000).map(|i| row(&[&i.to_string()])).collect(),
            None,
        );
        let mut qs = QueryStream::new(Box::new(s));
        let _ = qs.next_row().await;
        let _ = qs.next_row().await;
        drop(qs);
        assert!(!was_closed(&closed));
    }
}