use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use bson::RawDocument;
use futures_core::Stream;
use futures_util::StreamExt;
use serde::{de::DeserializeOwned, Deserialize};
#[cfg(test)]
use tokio::sync::oneshot;
use super::{
common::{
kill_cursor,
CursorBuffer,
CursorInformation,
CursorState,
GenericCursor,
PinnedConnection,
},
stream_poll_next,
BatchValue,
CursorStream,
};
use crate::{
bson::Document,
change_stream::event::ResumeToken,
client::{options::ServerAddress, AsyncDropToken},
cmap::conn::PinnedConnectionHandle,
cursor::{common::ExplicitClientSessionHandle, CursorSpecification},
error::{Error, Result},
Client,
ClientSession,
};
/// A cursor whose results are fetched using an explicitly supplied [`ClientSession`].
///
/// Every iteration method (`stream`, `next`, `advance`) takes the session as an argument rather
/// than storing it, so the same session can be used for other operations between calls.
#[derive(Debug)]
pub struct SessionCursor<T> {
    /// Client used to build the generic cursor and handed to `kill_cursor` on drop.
    client: Client,
    /// Token handed to `kill_cursor` on drop; moved to the new cursor by `with_type`.
    drop_token: AsyncDropToken,
    /// Cursor metadata (id, namespace, address). An id of `0` marks the cursor exhausted at
    /// construction time.
    info: CursorInformation,
    /// Iteration state. It is `None` while lent to a `SessionCursorStream` (restored when the
    /// stream drops) and after `with_type` has moved it out.
    state: Option<CursorState>,
    /// Optional address override passed to `kill_cursor` on drop; set via `set_drop_address`.
    drop_address: Option<ServerAddress>,
    /// Marks the result type `T` without storing a value of it.
    _phantom: PhantomData<T>,
    /// Test-only sender handed to `kill_cursor` on drop so tests can observe the kill.
    #[cfg(test)]
    kill_watcher: Option<oneshot::Sender<()>>,
}
impl<T> SessionCursor<T> {
pub(crate) fn new(
client: Client,
spec: CursorSpecification,
pinned: Option<PinnedConnectionHandle>,
) -> Self {
let exhausted = spec.info.id == 0;
Self {
drop_token: client.register_async_drop(),
client,
info: spec.info,
drop_address: None,
_phantom: Default::default(),
#[cfg(test)]
kill_watcher: None,
state: CursorState {
buffer: CursorBuffer::new(spec.initial_buffer),
exhausted,
post_batch_resume_token: None,
pinned_connection: PinnedConnection::new(pinned),
}
.into(),
}
}
}
impl<T> SessionCursor<T>
where
    T: DeserializeOwned,
{
    /// Returns a [`SessionCursorStream`] that yields this cursor's results, deserialized into
    /// `T`, using `session` for any further fetches.
    ///
    /// The stream mutably borrows both the cursor and the session while it is alive. When the
    /// stream is dropped, the iteration state is handed back to this cursor.
    pub fn stream<'session>(
        &mut self,
        session: &'session mut ClientSession,
    ) -> SessionCursorStream<'_, 'session, T> {
        self.make_stream(session)
    }

    /// Fetches the next result, deserialized into `T`, using `session`.
    ///
    /// Returns `None` once the underlying stream reports no further items.
    pub async fn next(&mut self, session: &mut ClientSession) -> Option<Result<T>> {
        let mut results = self.stream(session);
        results.next().await
    }
}
impl<T> SessionCursor<T> {
fn make_stream<'session>(
&mut self,
session: &'session mut ClientSession,
) -> SessionCursorStream<'_, 'session, T> {
SessionCursorStream {
generic_cursor: ExplicitSessionCursor::with_explicit_session(
self.take_state(),
self.client.clone(),
self.info.clone(),
ExplicitClientSessionHandle(session),
),
session_cursor: self,
}
}
fn take_state(&mut self) -> CursorState {
self.state.take().unwrap()
}
pub async fn advance(&mut self, session: &mut ClientSession) -> Result<bool> {
self.make_stream(session).generic_cursor.advance().await
}
#[cfg(test)]
pub(crate) async fn try_advance(&mut self, session: &mut ClientSession) -> Result<()> {
self.make_stream(session)
.generic_cursor
.try_advance()
.await
.map(|_| ())
}
pub fn current(&self) -> &RawDocument {
self.state.as_ref().unwrap().buffer.current().unwrap()
}
pub fn deserialize_current<'a>(&'a self) -> Result<T>
where
T: Deserialize<'a>,
{
bson::from_slice(self.current().as_bytes()).map_err(Error::from)
}
pub fn with_type<'a, D>(mut self) -> SessionCursor<D>
where
D: Deserialize<'a>,
{
SessionCursor {
client: self.client.clone(),
drop_token: self.drop_token.take(),
info: self.info.clone(),
state: Some(self.take_state()),
drop_address: self.drop_address.take(),
_phantom: Default::default(),
#[cfg(test)]
kill_watcher: self.kill_watcher.take(),
}
}
pub(crate) fn address(&self) -> &ServerAddress {
&self.info.address
}
pub(crate) fn set_drop_address(&mut self, address: ServerAddress) {
self.drop_address = Some(address);
}
#[cfg(test)]
pub(crate) fn set_kill_watcher(&mut self, tx: oneshot::Sender<()>) {
assert!(
self.kill_watcher.is_none(),
"cursor already has a kill_watcher"
);
self.kill_watcher = Some(tx);
}
}
impl<T> SessionCursor<T> {
pub(crate) fn is_exhausted(&self) -> bool {
self.state.as_ref().map_or(true, |state| state.exhausted)
}
#[cfg(test)]
pub(crate) fn client(&self) -> &Client {
&self.client
}
}
impl<T> Drop for SessionCursor<T> {
    /// Calls `kill_cursor` for this cursor unless it is exhausted.
    ///
    /// A cursor whose state has been moved out (for example by `with_type`) counts as exhausted
    /// and is skipped.
    fn drop(&mut self) {
        let pinned_connection = match &self.state {
            Some(state) if !state.exhausted => state.pinned_connection.replicate(),
            // Missing state or an exhausted cursor: nothing to kill.
            _ => return,
        };

        kill_cursor(
            self.client.clone(),
            &mut self.drop_token,
            &self.info.ns,
            self.info.id,
            pinned_connection,
            self.drop_address.take(),
            #[cfg(test)]
            self.kill_watcher.take(),
        );
    }
}
/// A [`GenericCursor`] driven by a session borrowed for `'session` through an
/// [`ExplicitClientSessionHandle`].
type ExplicitSessionCursor<'session> =
    GenericCursor<'session, ExplicitClientSessionHandle<'session>>;
/// A stream over a [`SessionCursor`]'s results that borrows both the cursor and the session.
///
/// While this stream exists, the cursor's state lives in `generic_cursor`. The `Drop` impl
/// moves it back into `session_cursor`.
pub struct SessionCursorStream<'cursor, 'session, T = Document> {
    /// The cursor whose state was lent to this stream; receives the state back on drop.
    session_cursor: &'cursor mut SessionCursor<T>,
    /// Generic cursor that performs the actual iteration using the borrowed session.
    generic_cursor: ExplicitSessionCursor<'session>,
}
impl<'cursor, 'session, T> SessionCursorStream<'cursor, 'session, T>
where
T: DeserializeOwned,
{
pub(crate) fn post_batch_resume_token(&self) -> Option<&ResumeToken> {
self.generic_cursor.post_batch_resume_token()
}
pub(crate) fn client(&self) -> &Client {
&self.session_cursor.client
}
}
impl<'cursor, 'session, T> Stream for SessionCursorStream<'cursor, 'session, T>
where
    T: DeserializeOwned,
{
    type Item = Result<T>;

    /// Polls the generic cursor for the next document and deserializes it into `T`
    /// (delegating to `stream_poll_next`).
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The stream holds only a reference and a generic cursor, so `Unpin` lets us unwrap
        // the pin directly.
        let this = self.get_mut();
        stream_poll_next(&mut this.generic_cursor, cx)
    }
}
impl<'cursor, 'session, T> CursorStream for SessionCursorStream<'cursor, 'session, T>
where
    T: DeserializeOwned,
{
    /// Forwards batch-level polling to the underlying generic cursor.
    fn poll_next_in_batch(&mut self, cx: &mut Context<'_>) -> Poll<Result<BatchValue>> {
        let cursor = &mut self.generic_cursor;
        cursor.poll_next_in_batch(cx)
    }
}
impl<'cursor, 'session, T> Drop for SessionCursorStream<'cursor, 'session, T> {
    /// Returns the iteration state to the borrowed [`SessionCursor`], so that a later `stream`,
    /// `next`, or `advance` call can take it again.
    fn drop(&mut self) {
        let reclaimed = self.generic_cursor.take_state();
        self.session_cursor.state = Some(reclaimed);
    }
}