use std::{collections::VecDeque, task::Poll};
use derive_where::derive_where;
use futures_core::Stream as AsyncStream;
use futures_util::{stream::StreamExt, FutureExt};
use serde::{de::DeserializeOwned, Deserialize};
use crate::{
bson::{RawDocument, RawDocumentBuf},
error::{Error, Result},
BoxFuture,
};
use super::raw_batch::RawBatch;
/// Async stream adapter that turns a source of raw cursor batches (`Raw`) into a stream of
/// deserialized `T` values (see its `AsyncStream` implementation).
///
/// `T` is tracked only at the type level; all runtime data lives in `state`.
#[derive_where(Debug)]
pub(super) struct Stream<'a, Raw, T> {
    // Current position in the Idle -> Advance -> Idle state machine driven by `poll_next`.
    state: StreamState<'a, Raw>,
    // `fn() -> T` records `T` as an output type without owning a `T`, so auto traits such as
    // `Send`/`Sync` of this struct do not depend on `T`.
    _phantom: std::marker::PhantomData<fn() -> T>,
}
impl<'a, Raw, T> Stream<'a, Raw, T> {
    /// Creates an idle stream over `raw` with an empty document buffer.
    pub(super) fn new(raw: Raw) -> Self {
        let buffer = BatchBuffer::new(raw);
        Self::from_cursor(buffer)
    }

    /// Creates an idle stream that takes ownership of an existing buffer.
    pub(super) fn from_cursor(cs: BatchBuffer<Raw>) -> Self {
        let state = StreamState::Idle(cs);
        Self {
            _phantom: std::marker::PhantomData,
            state,
        }
    }

    /// Borrows the underlying buffer.
    ///
    /// # Panics
    ///
    /// Panics if the stream is not idle (an advance is in flight or the buffer was taken).
    pub(super) fn buffer(&self) -> &BatchBuffer<Raw> {
        let StreamState::Idle(idle_buffer) = &self.state else {
            panic!("state access while streaming");
        };
        idle_buffer
    }

    /// Mutably borrows the underlying buffer.
    ///
    /// # Panics
    ///
    /// Panics if the stream is not idle (an advance is in flight or the buffer was taken).
    pub(super) fn buffer_mut(&mut self) -> &mut BatchBuffer<Raw> {
        let StreamState::Idle(idle_buffer) = &mut self.state else {
            panic!("state access while streaming");
        };
        idle_buffer
    }

    /// Moves the buffer out of the stream, leaving the stream in the `Polling` state.
    ///
    /// # Panics
    ///
    /// Panics if the stream is not idle.
    pub(super) fn take_buffer(&mut self) -> BatchBuffer<Raw> {
        let previous = std::mem::replace(&mut self.state, StreamState::Polling);
        let StreamState::Idle(idle_buffer) = previous else {
            panic!("state access while streaming");
        };
        idle_buffer
    }

    /// Reinterprets the stream as yielding `D` instead of `T`, keeping the current state.
    pub(super) fn with_type<D>(self) -> Stream<'a, Raw, D> {
        let Self { state, .. } = self;
        Stream {
            _phantom: std::marker::PhantomData,
            state,
        }
    }
}
/// States of the polling state machine driven by `Stream::poll_next`.
#[derive_where(Debug)]
enum StreamState<'a, Raw> {
    /// No advance in progress; the buffer can be accessed directly.
    Idle(BatchBuffer<Raw>),
    /// Placeholder swapped in while the real state is temporarily moved out (and left behind
    /// by `Stream::take_buffer`). `poll_next` reports an internal error if it observes it.
    Polling,
    /// An advance is in flight. The boxed future owns the buffer and hands it back through
    /// `AdvanceDone`. The future is excluded from the `Debug` output.
    Advance(#[derive_where(skip)] BoxFuture<'a, AdvanceDone<Raw>>),
}
/// Output of an in-flight advance future: hands the buffer back together with the result.
#[derive_where(Debug)]
struct AdvanceDone<Raw> {
    // The buffer that was moved into the advance future.
    buffer: BatchBuffer<Raw>,
    // Result of `BatchBuffer::advance`: `Ok(true)` when a current document is available,
    // `Ok(false)` when the source is exhausted.
    out: Result<bool>,
}
/// Buffered documents from the most recently fetched batch, plus the source of further batches.
#[derive_where(Debug)]
pub(super) struct BatchBuffer<Raw> {
    // Source of further batches; excluded from the `Debug` output.
    #[derive_where(skip)]
    pub(super) raw: Raw,
    // Buffered documents; when non-empty, the front element is the "current" document.
    batch: VecDeque<RawDocumentBuf>,
}
impl<Raw> BatchBuffer<Raw> {
    /// Creates a buffer over `raw` with no documents buffered yet.
    pub(super) fn new(raw: Raw) -> Self {
        Self {
            raw,
            batch: VecDeque::new(),
        }
    }

    /// Returns the current document, i.e. the front of the buffered batch.
    ///
    /// # Panics
    ///
    /// Panics if no document is buffered, i.e. if the most recent advance did not report
    /// a document (or no advance has happened yet).
    pub(super) fn current(&self) -> &RawDocument {
        self.batch
            .front()
            .expect("no current document: the buffer must be advanced successfully first")
    }

    /// Deserializes the current document into `V`.
    ///
    /// # Errors
    ///
    /// Returns an error if the document bytes cannot be deserialized into `V`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`current`](Self::current).
    pub(super) fn deserialize_current<'a, V>(&'a self) -> Result<V>
    where
        V: Deserialize<'a>,
    {
        crate::bson_compat::deserialize_from_slice(self.current().as_bytes()).map_err(Error::from)
    }

    /// Transforms the batch source with `f`, keeping the already-buffered documents.
    pub(super) fn map<G>(self, f: impl FnOnce(Raw) -> G) -> BatchBuffer<G> {
        BatchBuffer {
            raw: f(self.raw),
            batch: self.batch,
        }
    }

    /// Returns the buffered documents; the front element (if any) is the current document.
    pub(crate) fn batch(&self) -> &VecDeque<RawDocumentBuf> {
        &self.batch
    }
}
impl<Raw: AsyncStream<Item = Result<RawBatch>> + Unpin> BatchBuffer<Raw> {
pub(super) async fn advance(&mut self) -> Result<bool> {
loop {
match self.advance_internal().await? {
AdvanceResult::Advanced => return Ok(true),
AdvanceResult::Exhausted => return Ok(false),
AdvanceResult::Waiting => continue,
}
}
}
pub(super) async fn try_advance(&mut self) -> Result<bool> {
self.advance_internal()
.await
.map(|ar| matches!(ar, AdvanceResult::Advanced))
}
async fn advance_internal(&mut self) -> Result<AdvanceResult> {
self.batch.pop_front();
if !self.batch.is_empty() {
return Ok(AdvanceResult::Advanced);
}
let Some(raw_batch) = self.raw.next().await else {
return Ok(AdvanceResult::Exhausted);
};
let raw_batch = raw_batch?;
for item in raw_batch.doc_slices()? {
self.batch.push_back(
item?
.as_document()
.ok_or_else(|| Error::invalid_response("invalid cursor batch item"))?
.to_owned(),
);
}
return Ok(if self.batch.is_empty() {
AdvanceResult::Waiting
} else {
AdvanceResult::Advanced
});
}
}
/// Outcome of a single advance step over a `BatchBuffer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AdvanceResult {
    /// A document is now current.
    Advanced,
    /// The batch source has no more batches.
    Exhausted,
    /// A batch was fetched but contained no documents; the source may still yield more.
    Waiting,
}
impl<'a, Raw: 'a + AsyncStream<Item = Result<RawBatch>> + Send + Unpin, T: DeserializeOwned>
    AsyncStream for Stream<'a, Raw, T>
{
    type Item = Result<T>;

    /// Drives the `Idle -> Advance -> Idle` state machine.
    ///
    /// An idle buffer is moved into a boxed advance future, and that future is polled. When it
    /// completes, the buffer is restored to `Idle` and one of the following is yielded:
    /// - the deserialized current document;
    /// - the advance error;
    /// - `None` once the source is exhausted.
    ///
    /// Because the state returns to `Idle` after `None` or an error, polling again starts a
    /// fresh advance. Observing the `Polling` placeholder yields an internal error. That
    /// happens after `take_buffer`, or if a previous poll of the advance future panicked.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            // Move the state out, leaving the `Polling` placeholder until it is put back.
            match std::mem::replace(&mut self.state, StreamState::Polling) {
                StreamState::Idle(mut buffer) => {
                    // The future owns the buffer and returns it via `AdvanceDone`.
                    self.state = StreamState::Advance(
                        async move {
                            let out = buffer.advance().await;
                            AdvanceDone { buffer, out }
                        }
                        .boxed(),
                    );
                    // Loop around to poll the freshly created future.
                }
                StreamState::Advance(mut fut) => {
                    let done = match fut.poll_unpin(cx) {
                        Poll::Ready(done) => done,
                        Poll::Pending => {
                            self.state = StreamState::Advance(fut);
                            return Poll::Pending;
                        }
                    };
                    let item = match done.out {
                        Err(e) => Some(Err(e)),
                        Ok(false) => None,
                        Ok(true) => Some(done.buffer.deserialize_current()),
                    };
                    self.state = StreamState::Idle(done.buffer);
                    return Poll::Ready(item);
                }
                StreamState::Polling => {
                    return Poll::Ready(Some(Err(Error::internal(
                        "attempt to poll cursor already in polling state",
                    ))));
                }
            }
        }
    }
}