use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
vec,
};
use futures_core::ready;
use futures_util::Stream;
use serde_json::Value;
use sqlx_core::{ext::ustr::UStr, logger::QueryLogger, Either, HashMap};
use crate::{
column::ExaColumn,
connection::websocket::{
future::{
CloseResultSets, Execute, ExecuteBatch, ExecutePrepared, FetchChunk, WebSocketFuture,
},
ExaWebSocket,
},
error::ExaProtocolError,
query_result::ExaQueryResult,
responses::{DataChunk, MultiResults, QueryResult, ResultSet, ResultSetOutput, SingleResult},
row::ExaRow,
SqlxError, SqlxResult,
};
/// Stream over the results of executing a query: row counts ([`ExaQueryResult`])
/// and rows ([`ExaRow`]).
///
/// It first drives one of the execute futures held in [`ResultStreamState`], then
/// yields the results that execution produced. Result-set handles recorded when
/// execution finished are handed to the websocket as a pending [`CloseResultSets`]
/// when the stream is dropped.
pub struct ResultStream<'ws> {
/// Connection used to drive the execute future and to fetch further row chunks.
ws: &'ws mut ExaWebSocket,
/// Records rows affected / rows returned for every yielded item.
logger: QueryLogger,
/// Result-set handles captured when execution finished; closed on drop.
result_set_handles: Vec<u16>,
/// Current phase: executing or streaming.
state: ResultStreamState,
/// Set once an error has been yielded; afterwards the stream only returns `None`.
had_err: bool,
}
impl<'ws> ResultStream<'ws> {
    /// Creates a stream that first drives `future` (one of the execute futures)
    /// to completion and then yields the results it produced.
    pub fn new<F>(ws: &'ws mut ExaWebSocket, logger: QueryLogger, future: F) -> Self
    where
        ResultStreamState: From<F>,
    {
        let state = ResultStreamState::from(future);
        Self {
            state,
            ws,
            logger,
            had_err: false,
            result_set_handles: Vec::new(),
        }
    }

    /// Advances the stream's state machine.
    ///
    /// While an execute future is pending it is polled. Once it resolves, the
    /// result-set handles are recorded so they can be closed on drop, and the
    /// state switches to streaming. Every yielded item is counted in the logger.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<<Self as Stream>::Item>> {
        loop {
            let multi_stream = match &mut self.state {
                ResultStreamState::Execute(fut) => ready!(fut.poll_unpin(cx, self.ws))?,
                ResultStreamState::ExecuteBatch(fut) => ready!(fut.poll_unpin(cx, self.ws))?,
                ResultStreamState::ExecutePrepared(fut) => ready!(fut.poll_unpin(cx, self.ws))?,
                ResultStreamState::Stream(stream) => {
                    let item = match ready!(stream.poll_next_unpin(cx, self.ws)) {
                        None => return Poll::Ready(None),
                        Some(Err(err)) => return Poll::Ready(Some(Err(err))),
                        Some(Ok(item)) => item,
                    };

                    if let Either::Left(query_result) = &item {
                        self.logger
                            .increase_rows_affected(query_result.rows_affected());
                    } else {
                        self.logger.increment_rows_returned();
                    }

                    return Poll::Ready(Some(Ok(item)));
                }
            };

            // Execution finished: remember the handles, then start streaming.
            self.result_set_handles = multi_stream.handles();
            self.state = ResultStreamState::Stream(multi_stream);
        }
    }
}
impl Stream for ResultStream<'_> {
    type Item = SqlxResult<Either<ExaQueryResult, ExaRow>>;

    /// Yields the next row count or row.
    ///
    /// After the first error has been yielded, the stream is fused and keeps
    /// returning `Ready(None)`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = Pin::into_inner(self);
        if !stream.had_err {
            let outcome = stream.poll(cx);
            // `had_err` is known to be false here, so plain assignment only ever sets it.
            stream.had_err = matches!(&outcome, Poll::Ready(Some(Err(_))));
            return outcome;
        }
        Poll::Ready(None)
    }
}
impl Drop for ResultStream<'_> {
    /// Hands any recorded result-set handles to the websocket as a pending
    /// `CloseResultSets` operation; does nothing if none were recorded.
    fn drop(&mut self) {
        if self.result_set_handles.is_empty() {
            return;
        }
        let handles = std::mem::take(&mut self.result_set_handles);
        self.ws.pending_close = Some(CloseResultSets::new(handles));
    }
}
/// Phase of a [`ResultStream`]: waiting on one of the execute futures, or streaming
/// the results that execution produced.
pub enum ResultStreamState {
/// Waiting on an [`Execute`] future.
Execute(Execute),
/// Waiting on an [`ExecuteBatch`] future.
ExecuteBatch(ExecuteBatch),
/// Waiting on an [`ExecutePrepared`] future.
ExecutePrepared(ExecutePrepared),
/// Execution finished; yielding row counts and rows.
Stream(MultiResultStream),
}
impl From<Execute> for ResultStreamState {
fn from(value: Execute) -> Self {
Self::Execute(value)
}
}
impl From<ExecuteBatch> for ResultStreamState {
fn from(value: ExecuteBatch) -> Self {
Self::ExecuteBatch(value)
}
}
impl From<ExecutePrepared> for ResultStreamState {
fn from(value: ExecutePrepared) -> Self {
Self::ExecutePrepared(value)
}
}
/// Stream-like polling interface shared by the result-stream building blocks in
/// this module.
///
/// It is like `Stream::poll_next`, except that the websocket is passed in on every
/// poll instead of being stored in the implementor. Implementors are `Unpin` and
/// are polled through `&mut self`.
trait WebsocketStream: Unpin {
/// Type of each yielded item.
type Item;
/// Polls for the next item, forwarding `ws` to whatever it needs to poll.
/// `Ready(None)` signals that the stream is exhausted.
fn poll_next_unpin(
&mut self,
cx: &mut Context<'_>,
ws: &mut ExaWebSocket,
) -> Poll<Option<Self::Item>>;
}
/// Streams the results of an execution, one [`QueryResult`] after another.
pub struct MultiResultStream {
/// Results that have not been started yet.
next_results: vec::IntoIter<QueryResult>,
/// Stream over the result currently being yielded.
stream: QueryResultStream,
}
impl MultiResultStream {
pub fn new(first_result: QueryResult, next_results: vec::IntoIter<QueryResult>) -> Self {
let stream = QueryResultStream::new(first_result);
Self {
next_results,
stream,
}
}
fn handles(&self) -> Vec<u16> {
let first_handle = match &self.stream {
QueryResultStream::RowStream(row_stream) => match &row_stream.chunk_stream {
ChunkStream::Multi(multi_chunk_stream) => Some(multi_chunk_stream.handle),
ChunkStream::Single(_) => None,
},
QueryResultStream::RowCount(_) => None,
};
let results_handles_iter = self
.next_results
.as_slice()
.iter()
.filter_map(QueryResult::handle);
first_handle
.into_iter()
.chain(results_handles_iter)
.collect()
}
}
impl WebsocketStream for MultiResultStream {
    type Item = SqlxResult<Either<ExaQueryResult, ExaRow>>;

    /// Yields every item of the current result. When that result is exhausted,
    /// moves on to the next pending one; ends once no results remain.
    fn poll_next_unpin(
        &mut self,
        cx: &mut Context<'_>,
        ws: &mut ExaWebSocket,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match ready!(self.stream.poll_next_unpin(cx, ws)) {
                Some(item) => return Poll::Ready(Some(item)),
                None => match self.next_results.next() {
                    Some(next_result) => self.stream = QueryResultStream::new(next_result),
                    None => return Poll::Ready(None),
                },
            }
        }
    }
}
impl From<SingleResult> for MultiResultStream {
fn from(value: SingleResult) -> Self {
Self::new(value.into(), Vec::new().into_iter())
}
}
impl TryFrom<MultiResults> for MultiResultStream {
    type Error = SqlxError;

    /// Builds a stream over all results of a multi-result response.
    ///
    /// # Errors
    ///
    /// Returns [`ExaProtocolError::NoResponse`] (converted into [`SqlxError`]) when
    /// the response contains no results at all.
    fn try_from(value: MultiResults) -> Result<Self, Self::Error> {
        let mut results = value.results.into_iter();
        // `?` converts the protocol error into `SqlxError`; no unreachable `return` needed.
        let first_result = results.next().ok_or(ExaProtocolError::NoResponse)?;
        Ok(Self::new(first_result, results))
    }
}
/// Stream over a single [`QueryResult`].
pub enum QueryResultStream {
/// A result set, yielded row by row.
RowStream(RowStream),
/// A row count, yielded once; `None` after it has been taken.
RowCount(Option<ExaQueryResult>),
}
impl QueryResultStream {
    /// Chooses the stream shape for `query_result`: rows for a result set, or a
    /// single [`ExaQueryResult`] for a row count.
    fn new(query_result: QueryResult) -> Self {
        match query_result {
            QueryResult::RowCount { row_count } => {
                Self::RowCount(Some(ExaQueryResult::new(row_count)))
            }
            QueryResult::ResultSet { result_set } => Self::RowStream(RowStream::new(result_set)),
        }
    }
}
impl WebsocketStream for QueryResultStream {
    type Item = SqlxResult<Either<ExaQueryResult, ExaRow>>;

    /// Row counts are yielded once as `Either::Left`; rows are yielded as
    /// `Either::Right`.
    fn poll_next_unpin(
        &mut self,
        cx: &mut Context<'_>,
        ws: &mut ExaWebSocket,
    ) -> Poll<Option<Self::Item>> {
        match self {
            QueryResultStream::RowCount(pending) => {
                let taken = pending.take();
                Poll::Ready(taken.map(|query_result| Ok(Either::Left(query_result))))
            }
            QueryResultStream::RowStream(rows) => {
                let polled = rows.poll_next_unpin(cx, ws);
                polled.map(|next| next.map(|res| res.map(Either::Right)))
            }
        }
    }
}
/// Streams the rows of one result set.
pub struct RowStream {
/// Source of the result set's data chunks.
chunk_stream: ChunkStream,
/// Turns the current chunk into rows.
chunk_iter: ChunkIter,
}
impl RowStream {
    /// Splits `rs` into column metadata (for the row iterator) and a data source.
    ///
    /// The data source is either a handle to fetch chunks through, or data that
    /// arrived inline with the result set.
    fn new(rs: ResultSet) -> Self {
        let ResultSet {
            total_rows_num,
            total_rows_pos,
            output,
            columns,
        } = rs;

        let chunk_stream = match output {
            ResultSetOutput::Data(data) => ChunkStream::Single(Some(DataChunk {
                num_rows: total_rows_pos,
                data,
            })),
            ResultSetOutput::Handle(handle) => {
                ChunkStream::Multi(MultiChunkStream::new(handle, total_rows_num))
            }
        };

        Self {
            chunk_iter: ChunkIter::new(columns),
            chunk_stream,
        }
    }
}
impl WebsocketStream for RowStream {
    type Item = SqlxResult<ExaRow>;

    /// Yields rows from the current chunk, pulling the next chunk whenever the
    /// current one runs dry. Ends when the chunk source is exhausted.
    fn poll_next_unpin(
        &mut self,
        cx: &mut Context<'_>,
        ws: &mut ExaWebSocket,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.chunk_iter.next() {
                Some(row) => return Poll::Ready(Some(Ok(row))),
                None => match ready!(self.chunk_stream.poll_next_unpin(cx, ws)) {
                    Some(Ok(chunk)) => self.chunk_iter.renew(chunk),
                    Some(Err(err)) => return Poll::Ready(Some(Err(err))),
                    None => return Poll::Ready(None),
                },
            }
        }
    }
}
/// Source of the data chunks of one result set.
enum ChunkStream {
/// Chunks fetched from the server through a result-set handle.
Multi(MultiChunkStream),
/// Data that arrived inline with the result set; yielded once, then `None`.
Single(Option<DataChunk>),
}
impl WebsocketStream for ChunkStream {
    type Item = SqlxResult<DataChunk>;

    /// Fetched chunks come from the inner [`MultiChunkStream`]; inline data is
    /// yielded exactly once, after which the stream is exhausted.
    fn poll_next_unpin(
        &mut self,
        cx: &mut Context<'_>,
        ws: &mut ExaWebSocket,
    ) -> Poll<Option<Self::Item>> {
        let inline = match self {
            Self::Multi(fetcher) => return fetcher.poll_next_unpin(cx, ws),
            Self::Single(inline) => inline,
        };
        Poll::Ready(inline.take().map(Ok))
    }
}
/// Fetches the data of a handle-backed result set, chunk by chunk.
struct MultiChunkStream {
/// Result-set handle passed to every `FetchChunk` request.
handle: u16,
/// Total number of rows in the result set; fetching stops once this is reached.
total_rows_num: usize,
/// Rows received so far; passed as the position argument of the next `FetchChunk`.
total_rows_pos: usize,
/// Current fetch state.
state: MultiChunkStreamState,
}
impl MultiChunkStream {
    /// Creates a fetcher for the result set behind `handle`, positioned at row 0.
    /// No request is built until the stream is first polled.
    fn new(handle: u16, total_rows_num: usize) -> Self {
        let state = MultiChunkStreamState::Initial;
        Self {
            state,
            total_rows_pos: 0,
            handle,
            total_rows_num,
        }
    }
}
impl WebsocketStream for MultiChunkStream {
type Item = SqlxResult<DataChunk>;
fn poll_next_unpin(
&mut self,
cx: &mut Context<'_>,
ws: &mut ExaWebSocket,
) -> Poll<Option<Self::Item>> {
loop {
match &mut self.state {
MultiChunkStreamState::Initial => {
let num_bytes = ws.attributes.fetch_size();
let future = FetchChunk::new(self.handle, self.total_rows_pos, num_bytes);
self.state = MultiChunkStreamState::Polling(future);
}
MultiChunkStreamState::Polling(future) => {
if self.total_rows_pos >= self.total_rows_num {
self.state = MultiChunkStreamState::Finished;
continue;
}
let num_bytes = ws.attributes.fetch_size();
let chunk = ready!(future.poll_unpin(cx, ws))?;
self.total_rows_pos += chunk.num_rows;
let future = FetchChunk::new(self.handle, self.total_rows_pos, num_bytes);
self.state = MultiChunkStreamState::Polling(future);
return Poll::Ready(Some(Ok(chunk)));
}
MultiChunkStreamState::Finished => return Poll::Ready(None),
}
}
}
}
/// State of a [`MultiChunkStream`].
enum MultiChunkStreamState {
/// No fetch request has been built yet.
Initial,
/// Holds the next `FetchChunk` future to poll.
Polling(FetchChunk),
/// All rows have been fetched; the stream only yields `None`.
Finished,
}
/// Iterator over the rows of the current data chunk of a result set.
struct ChunkIter {
/// Column name -> column index, shared by every produced row.
column_names: Arc<HashMap<UStr, usize>>,
/// Column metadata, shared by every produced row.
columns: Arc<[ExaColumn]>,
/// Row count reported for the current chunk; read by the debug assertion in `next`.
chunk_rows_total: usize,
/// Rows yielded from the current chunk; read by the debug assertion in `next`.
chunk_rows_pos: usize,
/// Remaining raw rows of the current chunk.
data: vec::IntoIter<Vec<Value>>,
}
impl ChunkIter {
fn new(columns: Arc<[ExaColumn]>) -> Self {
let column_names = columns
.iter()
.enumerate()
.map(|(i, c)| (c.name.clone(), i))
.collect();
Self {
column_names: Arc::new(column_names),
columns,
chunk_rows_total: 0,
chunk_rows_pos: 0,
data: Vec::new().into_iter(),
}
}
}
impl ChunkIter {
fn renew(&mut self, chunk: DataChunk) {
self.chunk_rows_pos = 0;
self.chunk_rows_total = chunk.num_rows;
self.data = chunk.data.into_iter();
}
}
impl Iterator for ChunkIter {
type Item = ExaRow;
fn next(&mut self) -> Option<Self::Item> {
debug_assert!(self.chunk_rows_pos <= self.chunk_rows_total);
let row = ExaRow::new(
self.data.next()?,
self.columns.clone(),
self.column_names.clone(),
);
self.chunk_rows_pos += 1;
Some(row)
}
}