mod common;
pub(crate) mod session;
use std::{
pin::Pin,
task::{Context, Poll},
};
use bson::RawDocument;
use futures_core::{future::BoxFuture, Stream};
use serde::{de::DeserializeOwned, Deserialize};
#[cfg(test)]
use tokio::sync::oneshot;
use crate::{
change_stream::event::ResumeToken,
client::options::ServerAddress,
cmap::conn::PinnedConnectionHandle,
error::{Error, Result},
operation::GetMore,
results::GetMoreResult,
Client,
ClientSession,
};
use common::{kill_cursor, GenericCursor, GetMoreProvider, GetMoreProviderResult};
pub(crate) use common::{
stream_poll_next,
BatchValue,
CursorInformation,
CursorSpecification,
CursorStream,
NextInBatchFuture,
PinnedConnection,
};
/// A cursor over the results of a server operation. As a `Stream` it yields `Result<T>`.
///
/// Wraps an `ImplicitSessionCursor<T>` that owns the session (if any) used for getMores.
/// When dropped while the wrapped cursor is not exhausted, it calls `kill_cursor`.
#[derive(Debug)]
pub struct Cursor<T> {
/// Client used to build the wrapped cursor and passed to `kill_cursor` on drop.
client: Client,
/// The underlying cursor. It is an `Option` so that `with_type` can move it out. `Drop` does
/// nothing when this is `None`.
wrapped_cursor: Option<ImplicitSessionCursor<T>>,
/// Address passed to `kill_cursor` on drop; `None` unless `set_drop_address` was called.
drop_address: Option<ServerAddress>,
/// Test-only sender handed to `kill_cursor` on drop.
#[cfg(test)]
kill_watcher: Option<oneshot::Sender<()>>,
/// Ties the deserialization target `T` to the cursor without storing a value of it.
_phantom: std::marker::PhantomData<T>,
}
impl<T> Cursor<T> {
pub(crate) fn new(
client: Client,
spec: CursorSpecification,
session: Option<ClientSession>,
pin: Option<PinnedConnectionHandle>,
) -> Self {
let provider = ImplicitSessionGetMoreProvider::new(&spec, session);
Self {
client: client.clone(),
wrapped_cursor: Some(ImplicitSessionCursor::new(
client,
spec,
PinnedConnection::new(pin),
provider,
)),
drop_address: None,
#[cfg(test)]
kill_watcher: None,
_phantom: Default::default(),
}
}
pub(crate) fn post_batch_resume_token(&self) -> Option<&ResumeToken> {
self.wrapped_cursor
.as_ref()
.and_then(|c| c.post_batch_resume_token())
}
pub(crate) fn is_exhausted(&self) -> bool {
self.wrapped_cursor.as_ref().unwrap().is_exhausted()
}
pub(crate) fn client(&self) -> &Client {
&self.client
}
pub(crate) fn address(&self) -> &ServerAddress {
self.wrapped_cursor.as_ref().unwrap().address()
}
pub(crate) fn set_drop_address(&mut self, address: ServerAddress) {
self.drop_address = Some(address);
}
pub(crate) fn take_implicit_session(&mut self) -> Option<ClientSession> {
self.wrapped_cursor
.as_mut()
.and_then(|c| c.provider_mut().take_implicit_session())
}
pub async fn advance(&mut self) -> Result<bool> {
self.wrapped_cursor.as_mut().unwrap().advance().await
}
pub fn current(&self) -> &RawDocument {
self.wrapped_cursor.as_ref().unwrap().current().unwrap()
}
pub fn deserialize_current<'a>(&'a self) -> Result<T>
where
T: Deserialize<'a>,
{
bson::from_slice(self.current().as_bytes()).map_err(Error::from)
}
pub fn with_type<'a, D>(mut self) -> Cursor<D>
where
D: Deserialize<'a>,
{
Cursor {
client: self.client.clone(),
wrapped_cursor: self.wrapped_cursor.take().map(|c| c.with_type()),
drop_address: self.drop_address.take(),
#[cfg(test)]
kill_watcher: self.kill_watcher.take(),
_phantom: Default::default(),
}
}
#[cfg(test)]
pub(crate) fn set_kill_watcher(&mut self, tx: oneshot::Sender<()>) {
assert!(
self.kill_watcher.is_none(),
"cursor already has a kill_watcher"
);
self.kill_watcher = Some(tx);
}
}
impl<T> CursorStream for Cursor<T>
where
    T: DeserializeOwned + Unpin + Send + Sync,
{
    /// Polls the wrapped cursor for the next value of the current batch.
    fn poll_next_in_batch(&mut self, cx: &mut Context<'_>) -> Poll<Result<BatchValue>> {
        let inner = self.wrapped_cursor.as_mut().unwrap();
        inner.poll_next_in_batch(cx)
    }
}
impl<T> Stream for Cursor<T>
where
    T: DeserializeOwned + Unpin + Send + Sync,
{
    type Item = Result<T>;

    /// Polls the wrapped cursor for its next deserialized item.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Cursor<T>` is `Unpin` under these bounds, so it can be unpinned directly.
        let this = self.get_mut();
        let inner = this.wrapped_cursor.as_mut().unwrap();
        Pin::new(inner).poll_next(cx)
    }
}
impl<T> Drop for Cursor<T> {
    /// Calls `kill_cursor` unless there is nothing to clean up.
    ///
    /// Returns early when the wrapped cursor has been taken (as `with_type` does) or reports
    /// itself exhausted. Otherwise it passes the cursor's namespace, id, a replica of its
    /// pinned connection, and the drop address to `kill_cursor`.
    fn drop(&mut self) {
        let cursor = match self.wrapped_cursor.as_ref() {
            Some(live) if !live.is_exhausted() => live,
            _ => return,
        };

        kill_cursor(
            self.client.clone(),
            cursor.namespace(),
            cursor.id(),
            cursor.pinned_connection().replicate(),
            self.drop_address.take(),
            #[cfg(test)]
            self.kill_watcher.take(),
        );
    }
}
/// The cursor wrapped by `Cursor`: a `GenericCursor` using `ImplicitSessionGetMoreProvider`
/// as its getMore provider.
type ImplicitSessionCursor<T> = GenericCursor<ImplicitSessionGetMoreProvider, T>;
/// Output of the getMore future built by `ImplicitSessionGetMoreProvider::start_execution`.
/// It carries the session that was moved into the future back out alongside the result.
struct ImplicitSessionGetMoreResult {
/// Result of executing the getMore operation.
get_more_result: Result<GetMoreResult>,
/// The session the getMore ran with. Presumably returned to the provider through
/// `clear_execution` by the generic cursor; confirm against `GenericCursor`.
session: Option<Box<ClientSession>>,
}
impl GetMoreProviderResult for ImplicitSessionGetMoreResult {
    type Session = Option<Box<ClientSession>>;

    /// Borrows the getMore outcome without consuming this value.
    fn as_ref(&self) -> std::result::Result<&GetMoreResult, &Error> {
        std::result::Result::as_ref(&self.get_more_result)
    }

    /// Splits this value into the getMore outcome and the session it ran with.
    fn into_parts(self) -> (Result<GetMoreResult>, Self::Session) {
        let Self {
            get_more_result,
            session,
        } = self;
        (get_more_result, session)
    }
}
/// State machine that issues getMores for a `Cursor`, holding the session (if any) they run on.
enum ImplicitSessionGetMoreProvider {
/// A getMore is in flight; the boxed future owns the session until it completes.
Executing(BoxFuture<'static, ImplicitSessionGetMoreResult>),
/// No getMore is running; holds the session (if any) for the next one.
Idle(Option<Box<ClientSession>>),
/// No further getMores: `start_execution` leaves this state unchanged and `execute` returns
/// an error. Entered when the cursor id is 0 at creation or when cleared as exhausted.
Done,
}
impl ImplicitSessionGetMoreProvider {
fn new(spec: &CursorSpecification, session: Option<ClientSession>) -> Self {
let session = session.map(Box::new);
if spec.id() == 0 {
Self::Done
} else {
Self::Idle(session)
}
}
fn take_implicit_session(&mut self) -> Option<ClientSession> {
match self {
ImplicitSessionGetMoreProvider::Idle(session) => session.take().map(|s| *s),
_ => None,
}
}
}
impl GetMoreProvider for ImplicitSessionGetMoreProvider {
    type ResultType = ImplicitSessionGetMoreResult;
    type GetMoreFuture = BoxFuture<'static, ImplicitSessionGetMoreResult>;

    /// Returns the in-flight getMore future, or `None` when idle or done.
    fn executing_future(&mut self) -> Option<&mut Self::GetMoreFuture> {
        match self {
            Self::Executing(future) => Some(future),
            Self::Idle(_) | Self::Done => None,
        }
    }

    /// Leaves the executing state after a getMore has finished.
    ///
    /// `session` is the session the finished getMore ran with. If the cursor is now
    /// `exhausted`, the provider becomes `Done` and the session is dropped. Otherwise it is
    /// kept in `Idle` for the next getMore.
    fn clear_execution(&mut self, session: Option<Box<ClientSession>>, exhausted: bool) {
        *self = if exhausted {
            Self::Done
        } else {
            Self::Idle(session)
        };
    }

    /// Starts a getMore as a boxed `'static` future when the provider is idle.
    ///
    /// The session is moved into the future and returned inside the
    /// `ImplicitSessionGetMoreResult`. The pinned connection, if any, is replicated so the
    /// future owns its own handle. When already executing or done, the state is left
    /// unchanged.
    fn start_execution(
        &mut self,
        info: CursorInformation,
        client: Client,
        pinned_connection: Option<&PinnedConnectionHandle>,
    ) {
        // `take_mut::take` lets the closure consume the current state by value, so the
        // owned session can be moved into the future.
        take_mut::take(self, |state| match state {
            Self::Idle(mut session) => {
                let pinned_connection = pinned_connection.map(|c| c.replicate());
                let future = Box::pin(async move {
                    let get_more = GetMore::new(info, pinned_connection.as_ref());
                    let get_more_result = client
                        .execute_operation(get_more, session.as_mut().map(|b| b.as_mut()))
                        .await;
                    // Hand the session back alongside the result so it can be reused.
                    ImplicitSessionGetMoreResult {
                        get_more_result,
                        session,
                    }
                });
                Self::Executing(future)
            }
            Self::Executing(_) | Self::Done => state,
        })
    }

    /// Runs a getMore directly, borrowing the idle provider's session for the duration.
    ///
    /// # Errors
    ///
    /// - Propagates the error from executing the getMore when idle.
    /// - Returns an internal error if a streamed getMore is still in progress (`Executing`).
    /// - Returns an internal error if the provider is `Done`.
    fn execute(
        &mut self,
        info: CursorInformation,
        client: Client,
        pinned_connection: PinnedConnection,
    ) -> BoxFuture<'_, Result<GetMoreResult>> {
        match self {
            Self::Idle(session) => Box::pin(async move {
                let get_more = GetMore::new(info, pinned_connection.handle());
                client
                    .execute_operation(get_more, session.as_mut().map(|b| b.as_mut()))
                    .await
            }),
            Self::Executing(_) => Box::pin(async {
                Err(Error::internal(
                    "streaming the cursor was cancelled while a request was in progress and must \
                     be continued before iterating manually",
                ))
            }),
            Self::Done => {
                Box::pin(async { Err(Error::internal("cursor iterated after already exhausted")) })
            }
        }
    }
}