use std::{
collections::VecDeque,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use derivative::Derivative;
use futures_core::{Future, Stream};
use serde::de::DeserializeOwned;
use crate::{
error::{Error, ErrorKind, Result},
operation,
options::ServerAddress,
results::GetMoreResult,
Client,
Namespace,
};
#[derive(Derivative)]
#[derivative(Debug)]
pub(super) struct GenericCursor<P, T>
where
P: GetMoreProvider<DocumentType = T>,
{
#[derivative(Debug = "ignore")]
provider: P,
client: Client,
info: CursorInformation,
buffer: VecDeque<T>,
exhausted: bool,
}
impl<P, T> GenericCursor<P, T>
where
P: GetMoreProvider<DocumentType = T>,
T: DeserializeOwned,
{
pub(super) fn new(client: Client, spec: CursorSpecification<T>, get_more_provider: P) -> Self {
let exhausted = spec.id() == 0;
Self {
exhausted,
client,
provider: get_more_provider,
buffer: spec.initial_buffer,
info: spec.info,
}
}
pub(super) fn take_buffer(&mut self) -> VecDeque<T> {
std::mem::take(&mut self.buffer)
}
pub(super) fn is_exhausted(&self) -> bool {
self.exhausted
}
pub(super) fn id(&self) -> i64 {
self.info.id
}
pub(super) fn namespace(&self) -> &Namespace {
&self.info.ns
}
fn start_get_more(&mut self) {
let info = self.info.clone();
let client = self.client.clone();
self.provider.start_execution(info, client);
}
}
impl<P, T> Stream for GenericCursor<P, T>
where
P: GetMoreProvider<DocumentType = T>,
T: DeserializeOwned + Unpin,
{
type Item = Result<T>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
if let Some(future) = self.provider.executing_future() {
match Pin::new(future).poll(cx) {
Poll::Ready(get_more_result) => {
let exhausted = get_more_result.exhausted();
let (result, session) = get_more_result.into_parts();
self.exhausted = exhausted;
self.provider.clear_execution(session, exhausted);
self.buffer = result?.batch;
}
Poll::Pending => return Poll::Pending,
}
}
match self.buffer.pop_front() {
Some(doc) => {
if self.buffer.is_empty() && !self.exhausted {
self.start_get_more();
}
return Poll::Ready(Some(Ok(doc)));
}
None if !self.exhausted => {
self.start_get_more();
}
None => return Poll::Ready(None),
}
}
}
}
pub(super) trait GetMoreProvider: Unpin {
type DocumentType;
type ResultType: GetMoreProviderResult<DocumentType = Self::DocumentType>;
type GetMoreFuture: Future<Output = Self::ResultType> + Unpin;
fn executing_future(&mut self) -> Option<&mut Self::GetMoreFuture>;
fn clear_execution(
&mut self,
session: <Self::ResultType as GetMoreProviderResult>::Session,
exhausted: bool,
);
fn start_execution(&mut self, spec: CursorInformation, client: Client);
}
pub(super) trait GetMoreProviderResult {
type Session;
type DocumentType;
fn as_ref(&self) -> std::result::Result<&GetMoreResult<Self::DocumentType>, &Error>;
fn into_parts(self) -> (Result<GetMoreResult<Self::DocumentType>>, Self::Session);
fn exhausted(&self) -> bool {
match self.as_ref() {
Ok(res) => res.exhausted,
Err(e) => {
matches!(*e.kind, ErrorKind::Command(ref e) if e.code == 43 || e.code == 237)
}
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct CursorSpecification<T> {
pub(crate) info: CursorInformation,
pub(crate) initial_buffer: VecDeque<T>,
}
impl<T> CursorSpecification<T> {
pub(crate) fn new(
info: operation::CursorInfo<T>,
address: ServerAddress,
batch_size: impl Into<Option<u32>>,
max_time: impl Into<Option<Duration>>,
) -> Self {
Self {
info: CursorInformation {
ns: info.ns,
id: info.id,
address,
batch_size: batch_size.into(),
max_time: max_time.into(),
},
initial_buffer: info.first_batch,
}
}
pub(crate) fn id(&self) -> i64 {
self.info.id
}
#[cfg(test)]
pub(crate) fn address(&self) -> &ServerAddress {
&self.info.address
}
#[cfg(test)]
pub(crate) fn batch_size(&self) -> Option<u32> {
self.info.batch_size
}
#[cfg(test)]
pub(crate) fn max_time(&self) -> Option<Duration> {
self.info.max_time
}
}
#[derive(Clone, Debug)]
pub(crate) struct CursorInformation {
pub(crate) ns: Namespace,
pub(crate) address: ServerAddress,
pub(crate) id: i64,
pub(crate) batch_size: Option<u32>,
pub(crate) max_time: Option<Duration>,
}