use crate::error::SqlError;
use crate::guard::SizeGuards;
use crate::value::{ColumnInfo, Row};
use futures_util::stream::{Stream, StreamExt};
use std::pin::Pin;
/// Default row capacity for cursors.
///
/// [`RowCursor::next_batch`] uses it to cap how many row slots it preallocates per call;
/// batches larger than this still work, the buffer simply grows as rows arrive.
pub const DEFAULT_CURSOR_CAPACITY: usize = 1 << 10;
/// Type-erased, pinned, `Send` stream of rows (or errors) that may borrow for `'a`.
///
/// This is the row source a [`RowCursor`] pulls from.
pub type BoxRowStream<'a> =
    Pin<Box<dyn Stream<Item = Result<Row, SqlError>> + Send + 'a>>;
/// Synchronous cursor over an asynchronous row stream.
///
/// Rows are pulled with [`RowCursor::next_batch`] (or one at a time through its
/// `Iterator` impl); each call blocks on the borrowed Tokio runtime, and every row is
/// passed through the cursor's size guards before it is handed out.
pub struct RowCursor<'a> {
    /// Result-set metadata, exposed by `columns()` and passed to every row check.
    columns: Vec<ColumnInfo>,
    /// Runtime used to drive `stream` from synchronous code via `block_on`.
    rt: &'a tokio::runtime::Runtime,
    /// Source of rows; not polled again once `exhausted` is set.
    stream: BoxRowStream<'a>,
    /// Per-row validation (`check_row`) applied before a row is returned.
    guards: SizeGuards,
    /// Ordinal of the next row to be fetched, i.e. the number of rows returned so far.
    row_ordinal: u64,
    /// Set once the stream has ended or an error has been reported.
    exhausted: bool,
}
impl<'a> RowCursor<'a> {
pub(crate) fn new(
columns: Vec<ColumnInfo>,
rt: &'a tokio::runtime::Runtime,
stream: BoxRowStream<'a>,
guards: SizeGuards,
) -> Self {
Self {
columns,
rt,
stream,
guards,
row_ordinal: 0,
exhausted: false,
}
}
#[must_use]
pub fn columns(&self) -> &[ColumnInfo] {
&self.columns
}
pub fn next_batch(&mut self, n: usize) -> Result<Vec<Row>, SqlError> {
if n == 0 || self.exhausted {
return Ok(Vec::new());
}
let columns = &self.columns;
let guards = &self.guards;
let stream = &mut self.stream;
let start_ordinal = self.row_ordinal;
let cap = n.min(DEFAULT_CURSOR_CAPACITY);
let result: Result<(Vec<Row>, bool), SqlError> = self.rt.block_on(async move {
let mut out = Vec::with_capacity(cap);
let mut ordinal = start_ordinal;
for _ in 0..n {
match stream.next().await {
Some(Ok(row)) => {
guards.check_row(ordinal, &row, columns)?;
ordinal += 1;
out.push(row);
}
Some(Err(e)) => return Err(e),
None => return Ok((out, true)),
}
}
Ok((out, false))
});
match result {
Ok((out, reached_end)) => {
self.row_ordinal += out.len() as u64;
self.exhausted = reached_end;
Ok(out)
}
Err(e) => {
self.exhausted = true;
Err(e)
}
}
}
}
/// Yields rows one at a time by calling [`RowCursor::next_batch`] with `n == 1`.
///
/// An error is yielded once as `Some(Err(_))`; iteration then ends, as it does when the
/// stream is exhausted.
impl Iterator for RowCursor<'_> {
    type Item = Result<Row, SqlError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.exhausted {
            return None;
        }
        // Ok(Some(row)) -> Some(Ok(row)), Ok(None) -> None (end of stream),
        // Err(e) -> Some(Err(e)).
        self.next_batch(1).map(|mut rows| rows.pop()).transpose()
    }
}

// `next` returns `None` only when the cursor is exhausted: either the flag was already
// set, or the `next_batch(1)` call that found an empty stream just set it. The flag only
// ever moves from false to true, so the iterator is fused by construction.
impl std::iter::FusedIterator for RowCursor<'_> {}
/// Adapts a Tokio mpsc receiver into a boxed, `'static` row stream.
///
/// Items are yielded in the order the sender pushed them. The stream ends when the
/// receiver's `poll_recv` reports the channel closed and drained.
pub(crate) fn channel_stream(
    rx: tokio::sync::mpsc::Receiver<Result<Row, SqlError>>,
) -> BoxRowStream<'static> {
    let adapted = tokio_stream_from_channel(rx);
    Box::pin(adapted)
}
/// Exposes `rx` as a `Stream` by forwarding every poll straight to `Receiver::poll_recv`.
///
/// The receiver is moved into the stream, so dropping the stream drops the receiver.
fn tokio_stream_from_channel(
    rx: tokio::sync::mpsc::Receiver<Result<Row, SqlError>>,
) -> impl Stream<Item = Result<Row, SqlError>> + Send + 'static {
    let mut receiver = rx;
    futures_util::stream::poll_fn(move |context| receiver.poll_recv(context))
}