use std::{
collections::VecDeque,
pin::Pin,
task::{Context, Poll},
};
use futures_core::{future::BoxFuture, Stream};
use futures_util::StreamExt;
use serde::de::DeserializeOwned;
use super::common::{CursorInformation, GenericCursor, GetMoreProvider, GetMoreProviderResult};
use crate::{
bson::Document,
cursor::CursorSpecification,
error::{Error, Result},
operation::GetMore,
results::GetMoreResult,
Client,
ClientSession,
RUNTIME,
};
/// A cursor over results produced by an operation that ran with an explicit [`ClientSession`].
///
/// Unlike a session-less cursor, advancing it requires handing the session back in on every
/// call (`next` / `stream`). Buffered results live here between those calls.
#[derive(Debug)]
pub struct SessionCursor<T: DeserializeOwned + Unpin> {
    /// Set when no further batches may be requested from the server; a cursor whose initial id
    /// is 0 starts out in this state.
    exhausted: bool,
    /// Client used to build getMore and killCursors requests.
    client: Client,
    /// Metadata (namespace, id, ...) describing the server-side cursor.
    info: CursorInformation,
    /// Results already received from the server but not yet yielded to the caller.
    buffer: VecDeque<T>,
}
impl<T> SessionCursor<T>
where
    T: DeserializeOwned + Unpin + Send + Sync,
{
    /// Creates a cursor from the initial response described by `spec`.
    ///
    /// A cursor id of `0` means there is nothing more to fetch, so such a cursor starts out
    /// exhausted (and will not schedule a killCursors when dropped).
    pub(crate) fn new(client: Client, spec: CursorSpecification<T>) -> Self {
        let exhausted = spec.id() == 0;
        Self {
            exhausted,
            client,
            info: spec.info,
            buffer: spec.initial_buffer,
        }
    }

    /// Returns a [`SessionCursorStream`] that iterates this cursor, running any getMore
    /// commands it needs through `session`.
    ///
    /// The buffered results are moved into the returned stream for the duration of the borrow.
    /// When the stream is dropped, the remaining buffer and exhaustion state are written back
    /// to this cursor.
    pub fn stream<'session>(
        &mut self,
        session: &'session mut ClientSession,
    ) -> SessionCursorStream<'_, 'session, T> {
        let get_more_provider = ExplicitSessionGetMoreProvider::new(session);

        // `self.info.id` keeps the original server cursor id even after an earlier stream has
        // exhausted the cursor. The generic cursor presumably derives its own exhaustion state
        // from the id, the same way `new` does above (NOTE(review): confirm in
        // `GenericCursor::new`). Without this adjustment, a new stream would:
        //   - issue a getMore for a cursor the server has already closed, and
        //   - on drop, reset `exhausted` to false, causing a spurious killCursors later.
        // Reporting id 0 keeps the new handle exhausted.
        let mut info = self.info.clone();
        if self.exhausted {
            info.id = 0;
        }

        // The buffer is handed to the stream for iteration and is returned by its `Drop`
        // implementation.
        let spec = CursorSpecification {
            info,
            initial_buffer: std::mem::take(&mut self.buffer),
        };
        SessionCursorStream {
            generic_cursor: ExplicitSessionCursor::new(
                self.client.clone(),
                spec,
                get_more_provider,
            ),
            session_cursor: self,
        }
    }

    /// Retrieves the next result, using `session` to fetch another batch from the server when
    /// the local buffer is empty.
    ///
    /// Returns `None` once no results remain, or `Some(Err(_))` if fetching a batch failed.
    pub async fn next(&mut self, session: &mut ClientSession) -> Option<Result<T>> {
        self.stream(session).next().await
    }
}
impl<T> Drop for SessionCursor<T>
where
    T: DeserializeOwned + Unpin,
{
    /// Schedules a killCursors for the server-side cursor unless it is already exhausted.
    ///
    /// The kill is spawned on `RUNTIME` and not awaited here. NOTE(review): it goes through a
    /// plain collection handle, not through the explicit session the cursor was created with.
    fn drop(&mut self) {
        if !self.exhausted {
            let cursor_id = self.info.id;
            let namespace = &self.info.ns;
            let collection = self
                .client
                .database(namespace.db.as_str())
                .collection::<Document>(namespace.coll.as_str());
            RUNTIME.execute(async move { collection.kill_cursor(cursor_id).await });
        }
    }
}
/// The shared generic cursor machinery, specialized so that its getMore commands run through a
/// borrowed explicit session (see `ExplicitSessionGetMoreProvider`).
type ExplicitSessionCursor<'session, T> =
    GenericCursor<ExplicitSessionGetMoreProvider<'session, T>, T>;
/// A [`Stream`] view over a [`SessionCursor`], bound to one mutable session borrow.
///
/// Obtained from `SessionCursor::stream`. While alive it holds both the parent cursor and the
/// session mutably. When dropped, it writes the unread buffer and the exhaustion flag back into
/// the parent cursor.
pub struct SessionCursorStream<
    'cursor,
    'session,
    T: DeserializeOwned + Unpin + Send + Sync = Document,
> {
    /// The parent cursor, which receives this handle's state on drop.
    session_cursor: &'cursor mut SessionCursor<T>,
    /// The cursor that actually performs iteration and getMore calls.
    generic_cursor: ExplicitSessionCursor<'session, T>,
}
impl<'cursor, 'session, T> Stream for SessionCursorStream<'cursor, 'session, T>
where
    T: DeserializeOwned + Unpin + Send + Sync,
{
    type Item = Result<T>;

    /// Forwards the poll to the wrapped generic cursor.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // This type is `Unpin` (the original projection through `Pin<&mut Self>` relied on that
        // too), so the pin can be unwrapped and the inner cursor polled by plain `&mut`.
        let this = self.get_mut();
        this.generic_cursor.poll_next_unpin(cx)
    }
}
impl<'cursor, 'session, T> Drop for SessionCursorStream<'cursor, 'session, T>
where
    T: DeserializeOwned + Unpin + Send + Sync,
{
    /// Writes the iteration state of this handle back into the parent [`SessionCursor`].
    ///
    /// The parent receives whatever is still buffered and whether the generic cursor
    /// considers itself exhausted.
    fn drop(&mut self) {
        let Self {
            session_cursor,
            generic_cursor,
        } = self;
        session_cursor.buffer = generic_cursor.take_buffer();
        session_cursor.exhausted = generic_cursor.is_exhausted();
    }
}
/// Issues getMore commands for a session cursor using a borrowed explicit session.
///
/// The session is either parked here while idle, or moved into the in-flight getMore future,
/// which returns it inside its [`ExecutionResult`].
enum ExplicitSessionGetMoreProvider<'session, T> {
    /// No getMore is running; the session is available for the next one.
    Idle(MutableSessionReference<'session>),
    /// A getMore is in flight and owns the session until it resolves.
    Executing(BoxFuture<'session, ExecutionResult<'session, T>>),
}
impl<'session, T> ExplicitSessionGetMoreProvider<'session, T> {
fn new(session: &'session mut ClientSession) -> Self {
Self::Idle(MutableSessionReference { reference: session })
}
}
impl<'session, T: Send + Sync + DeserializeOwned> GetMoreProvider
    for ExplicitSessionGetMoreProvider<'session, T>
{
    type DocumentType = T;
    type ResultType = ExecutionResult<'session, T>;
    type GetMoreFuture = BoxFuture<'session, ExecutionResult<'session, T>>;

    /// Returns the in-flight getMore future, or `None` while idle.
    fn executing_future(&mut self) -> Option<&mut Self::GetMoreFuture> {
        match self {
            Self::Idle(_) => None,
            Self::Executing(future) => Some(future),
        }
    }

    /// Puts the provider back into the idle state, holding the session that the finished
    /// getMore handed back.
    fn clear_execution(&mut self, session: &'session mut ClientSession, _exhausted: bool) {
        *self = Self::new(session);
    }

    /// Starts a getMore for `info` if the provider is idle. If a getMore is already running,
    /// this does nothing.
    fn start_execution(&mut self, info: CursorInformation, client: Client) {
        // The session has to be moved *by value* out of the `Idle` variant into the future, so
        // `self` is temporarily taken. Per the take_mut docs, a panic inside this closure
        // aborts the process.
        take_mut::take(self, |state| match state {
            Self::Idle(idle) => {
                let future = async move {
                    let operation = GetMore::new(info);
                    let get_more_result = client
                        .execute_operation(operation, Some(&mut *idle.reference))
                        .await;
                    // Hand the session back alongside the outcome.
                    ExecutionResult {
                        get_more_result,
                        session: idle.reference,
                    }
                };
                Self::Executing(Box::pin(future))
            }
            executing @ Self::Executing(_) => executing,
        });
    }
}
/// Outcome of one getMore, bundled with the session lent to it so the session can be handed
/// back afterwards (see `into_parts`).
struct ExecutionResult<'session, T> {
    /// The explicit session the getMore ran on.
    session: &'session mut ClientSession,
    /// The server's reply to the getMore, or the error that occurred.
    get_more_result: Result<GetMoreResult<T>>,
}
impl<'session, T> GetMoreProviderResult for ExecutionResult<'session, T> {
type Session = &'session mut ClientSession;
type DocumentType = T;
fn as_ref(&self) -> std::result::Result<&GetMoreResult<T>, &Error> {
self.get_more_result.as_ref()
}
fn into_parts(self) -> (Result<GetMoreResult<T>>, Self::Session) {
(self.get_more_result, self.session)
}
}
/// Wrapper around the mutable session borrow held by an idle getMore provider.
struct MutableSessionReference<'session> {
    /// The borrowed explicit session.
    reference: &'session mut ClientSession,
}