use std::{
pin::Pin,
task::{ready, Context, Poll},
};
use crate::{
bson::{RawArray, RawDocument},
cursor::common::CursorSpecification,
operation::GetMore,
};
use futures_core::{future::BoxFuture, Future, Stream};
#[cfg(test)]
use tokio::sync::oneshot;
use crate::{
bson::RawDocumentBuf,
change_stream::event::ResumeToken,
client::{options::ServerAddress, AsyncDropToken},
cmap::conn::PinnedConnectionHandle,
cursor::common::{kill_cursor, PinnedConnection},
error::{Error, ErrorKind, Result},
Client,
ClientSession,
};
use super::common::CursorInformation;
/// Key of the cursor subdocument inside a server reply.
const CURSOR: &str = "cursor";
/// Key of the batch array under `cursor`; looked up first.
const FIRST_BATCH: &str = "firstBatch";
/// Key of the batch array under `cursor`; used when `firstBatch` is absent.
const NEXT_BATCH: &str = "nextBatch";
/// One batch of results kept in raw BSON form.
///
/// Wraps a complete server reply (either the cursor's initial reply or a `getMore` reply);
/// the documents of the batch are reachable through [`RawBatch::doc_slices`].
#[derive(Clone, Debug)]
pub struct RawBatch {
    reply: RawDocumentBuf,
}

impl RawBatch {
    /// Wraps a raw server reply.
    pub(crate) fn new(reply: RawDocumentBuf) -> Self {
        Self { reply }
    }

    /// Returns the array of documents in this batch.
    ///
    /// The array is read from `cursor.firstBatch` when present, otherwise from
    /// `cursor.nextBatch`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-response error if the reply has no `cursor` subdocument, if it has
    /// neither `firstBatch` nor `nextBatch`, or if the batch field is not an array. Errors
    /// raised while reading either batch field are propagated as-is.
    pub fn doc_slices(&self) -> Result<&RawArray> {
        let root = self.reply.as_ref();
        let cursor = root
            .get_document(CURSOR)
            .map_err(|_| Error::invalid_response("missing cursor subdocument"))?;
        // Both lookups use `?` so a malformed document is reported as such, rather than the
        // `nextBatch` read error being swallowed and mis-reported as a missing field.
        let docs = match cursor.get(FIRST_BATCH)? {
            Some(docs) => docs,
            None => cursor.get(NEXT_BATCH)?.ok_or_else(|| {
                Error::invalid_response(format!("missing {FIRST_BATCH}/{NEXT_BATCH}"))
            })?,
        };
        docs.as_array()
            .ok_or_else(|| Error::invalid_response(format!("invalid {FIRST_BATCH}/{NEXT_BATCH}")))
    }

    /// Returns the full server reply this batch was taken from.
    pub fn as_raw_document(&self) -> &RawDocument {
        self.reply.as_ref()
    }
}
/// A cursor that yields each server reply as a [`RawBatch`], issuing `getMore`s on the
/// implicit session it owns (if any).
pub struct RawBatchCursor {
    /// Client used for `getMore`s and handed to `kill_cursor` on drop.
    client: Client,
    /// Obtained from `client.register_async_drop()`; passed to `kill_cursor` on drop.
    drop_token: AsyncDropToken,
    /// Cursor id, namespace and address; id and namespace are refreshed from `getMore` replies.
    info: CursorInformation,
    /// Exhaustion flag, pinning, resume token, session provider and buffered reply.
    state: RawBatchCursorState,
    /// Optional address forwarded to `kill_cursor` on drop; set via `set_drop_address`.
    drop_address: Option<ServerAddress>,
    /// Test-only hook forwarded to `kill_cursor` on drop.
    #[cfg(test)]
    kill_watcher: Option<oneshot::Sender<()>>,
}
// Compile-time check that `RawBatchCursor` is `Unpin`: its `Stream` impl mutates fields
// through `Pin<&mut Self>`, which is only possible for `Unpin` types. The closure is never
// called; naming the type through a turbofish avoids fabricating a value with `todo!()`
// and therefore needs no lint suppressions.
const _: fn() = || {
    fn assert_unpin<T: Unpin>() {}
    assert_unpin::<RawBatchCursor>();
};
/// Mutable iteration state of a [`RawBatchCursor`].
struct RawBatchCursorState {
    /// Set when the cursor id is `0` at creation, when a `getMore` reports exhaustion, or on
    /// command error 43/237; no further `getMore` is started once set.
    exhausted: bool,
    /// Connection the cursor is pinned to, if any. Reset to `Unpinned` on exhaustion and
    /// invalidated after a network error (which also stops further `getMore`s).
    pinned_connection: PinnedConnection,
    /// Post-batch resume token from the initial reply or the most recent `getMore`.
    post_batch_resume_token: Option<ResumeToken>,
    /// Holds the implicit session while idle, or the in-flight `getMore` future.
    provider: GetMoreRawProvider<'static, ImplicitClientSessionHandle>,
    /// Reply received but not yet yielded; starts out as the initial reply.
    buffered_reply: Option<RawDocumentBuf>,
}
impl crate::cursor::NewCursor for RawBatchCursor {
    /// Builds a raw batch cursor that keeps `implicit_session` (if any) for its `getMore`s.
    /// Construction itself cannot fail.
    fn generic_new(
        client: Client,
        spec: CursorSpecification,
        implicit_session: Option<ClientSession>,
        pinned: Option<PinnedConnectionHandle>,
    ) -> Result<Self> {
        let cursor = Self::new(client, spec, implicit_session, pinned);
        Ok(cursor)
    }
}
impl RawBatchCursor {
    /// Creates a cursor from the initial reply described by `spec`.
    ///
    /// The initial reply is buffered so it is yielded as the first batch. A cursor id of `0`
    /// marks the cursor exhausted from the start: no `getMore` is ever issued and `session`
    /// is not retained. Otherwise `session` (if any) is kept for the `getMore`s.
    fn new(
        client: Client,
        spec: CursorSpecification,
        session: Option<ClientSession>,
        pin: Option<PinnedConnectionHandle>,
    ) -> Self {
        let exhausted = spec.info.id == 0;
        Self {
            // Register for async drop before `client` is moved into the struct, so the
            // client does not have to be cloned just for this call.
            drop_token: client.register_async_drop(),
            client,
            info: spec.info,
            drop_address: None,
            #[cfg(test)]
            kill_watcher: None,
            state: RawBatchCursorState {
                exhausted,
                pinned_connection: PinnedConnection::new(pin),
                post_batch_resume_token: spec.post_batch_resume_token,
                provider: if exhausted {
                    GetMoreRawProvider::Done
                } else {
                    GetMoreRawProvider::Idle(Box::new(ImplicitClientSessionHandle(session)))
                },
                buffered_reply: Some(spec.initial_reply),
            },
        }
    }

    /// Returns whether no further `getMore` will be issued for this cursor.
    pub(crate) fn is_exhausted(&self) -> bool {
        self.state.exhausted
    }

    /// Returns whether another batch may still be yielded.
    ///
    /// `true` while the cursor is not exhausted. Once exhausted, `true` only if a buffered
    /// reply remains whose `firstBatch`/`nextBatch` array is non-empty.
    pub(crate) fn has_next(&self) -> bool {
        if !self.is_exhausted() {
            return true;
        }
        let Some(batch) = self
            .state
            .buffered_reply
            .as_ref()
            .and_then(|reply| reply.get_document(CURSOR).ok())
            .and_then(|cursor| {
                cursor
                    .get_array(FIRST_BATCH)
                    .or_else(|_| cursor.get_array(NEXT_BATCH))
                    .ok()
            })
        else {
            return false;
        };
        !batch.is_empty()
    }

    /// Returns the post-batch resume token from the most recently received reply, if any.
    pub(crate) fn post_batch_resume_token(&self) -> Option<&ResumeToken> {
        self.state.post_batch_resume_token.as_ref()
    }

    /// Returns the server address recorded in the cursor information.
    pub(crate) fn address(&self) -> &ServerAddress {
        &self.info.address
    }

    /// Sets the address forwarded to `kill_cursor` when this cursor is dropped unexhausted.
    pub(crate) fn set_drop_address(&mut self, address: ServerAddress) {
        self.drop_address = Some(address);
    }

    /// Returns the client this cursor runs its operations with.
    pub(crate) fn client(&self) -> &Client {
        &self.client
    }

    /// Marks the cursor exhausted and releases any pinned connection.
    fn mark_exhausted(&mut self) {
        self.state.exhausted = true;
        self.state.pinned_connection = PinnedConnection::Unpinned;
    }

    /// Installs a test hook that is forwarded to `kill_cursor` on drop.
    ///
    /// # Panics
    ///
    /// Panics if a kill watcher has already been set.
    #[cfg(test)]
    pub(crate) fn set_kill_watcher(&mut self, tx: oneshot::Sender<()>) {
        assert!(
            self.kill_watcher.is_none(),
            "cursor already has a kill_watcher"
        );
        self.kill_watcher = Some(tx);
    }

    /// Takes the implicit session out of the cursor.
    ///
    /// Returns `None` while a `getMore` is in flight, after the cursor has finished with
    /// its session, or if the session was already taken.
    pub(crate) fn take_implicit_session(&mut self) -> Option<ClientSession> {
        self.state.provider.take_implicit_session()
    }
}
impl Stream for RawBatchCursor {
    type Item = Result<RawBatch>;

    /// Yields the next raw batch, issuing a `getMore` when nothing is buffered.
    ///
    /// Any in-flight `getMore` is first driven to completion and its outcome recorded. Then
    /// the buffered reply, if present, is handed out. When nothing is buffered and the
    /// cursor can still make progress, a new `getMore` is started and polled on the next
    /// loop iteration; otherwise the stream ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `RawBatchCursor` is `Unpin` (asserted at compile time), so a plain `&mut` is fine.
        let this = self.get_mut();
        loop {
            if let Some(in_flight) = this.state.provider.executing_future() {
                let completed = ready!(Pin::new(in_flight).poll(cx));
                let failure = match completed.result {
                    Ok(out) => {
                        this.state.buffered_reply = Some(out.raw_reply);
                        this.state.post_batch_resume_token = out.post_batch_resume_token;
                        if out.exhausted {
                            this.mark_exhausted();
                        }
                        if out.id != 0 {
                            this.info.id = out.id;
                        }
                        this.info.ns = out.ns;
                        None
                    }
                    Err(e) => {
                        // Command errors 43/237 stop further getMores. NOTE(review): these
                        // look like CursorNotFound / CursorKilled; confirm.
                        if matches!(*e.kind, ErrorKind::Command(ref ce) if ce.code == 43 || ce.code == 237)
                        {
                            this.mark_exhausted();
                        }
                        // An invalidated pin prevents any further getMore below.
                        if e.is_network_error() {
                            this.state.pinned_connection.invalidate();
                        }
                        Some(e)
                    }
                };
                // Hand the session back to the provider on both success and failure.
                let exhausted_now = this.state.exhausted;
                this.state
                    .provider
                    .clear_execution(completed.session, exhausted_now);
                if let Some(e) = failure {
                    return Poll::Ready(Some(Err(e)));
                }
            }

            if let Some(reply) = this.state.buffered_reply.take() {
                return Poll::Ready(Some(Ok(RawBatch::new(reply))));
            }

            let can_fetch_more = !this.state.exhausted
                && !matches!(this.state.pinned_connection, PinnedConnection::Invalid(_));
            if !can_fetch_more {
                return Poll::Ready(None);
            }

            let info = this.info.clone();
            let client = this.client.clone();
            let state = &mut this.state;
            state
                .provider
                .start_execution(info, client, state.pinned_connection.handle());
        }
    }
}
impl Drop for RawBatchCursor {
    /// Unless the cursor is exhausted, hands it to `kill_cursor` for server-side cleanup.
    fn drop(&mut self) {
        if !self.is_exhausted() {
            let pinned = self.state.pinned_connection.replicate();
            let address = self.drop_address.take();
            kill_cursor(
                self.client.clone(),
                &mut self.drop_token,
                &self.info.ns,
                self.info.id,
                pinned,
                address,
                #[cfg(test)]
                self.kill_watcher.take(),
            );
        }
    }
}
/// A raw batch cursor whose `getMore`s run on an explicit session supplied to
/// [`SessionRawBatchCursor::stream`].
#[derive(Debug)]
pub struct SessionRawBatchCursor {
    /// Client used for `getMore`s and handed to `kill_cursor` on drop.
    client: Client,
    /// Obtained from `client.register_async_drop()`; passed to `kill_cursor` on drop.
    drop_token: AsyncDropToken,
    /// Cursor id, namespace and address; id and namespace are refreshed from `getMore` replies.
    info: CursorInformation,
    /// Set when the cursor id is `0` at creation, when a `getMore` reports exhaustion, or on
    /// command error 43/237.
    exhausted: bool,
    /// Connection the cursor is pinned to, if any; reset on exhaustion, invalidated after a
    /// network error.
    pinned_connection: PinnedConnection,
    /// Post-batch resume token from the initial reply or the most recent `getMore`.
    post_batch_resume_token: Option<ResumeToken>,
    /// Reply received but not yet yielded; starts out as the initial reply.
    buffered_reply: Option<RawDocumentBuf>,
    /// Optional address forwarded to `kill_cursor` on drop; set via `set_drop_address`.
    drop_address: Option<ServerAddress>,
    /// Test-only hook forwarded to `kill_cursor` on drop.
    #[cfg(test)]
    kill_watcher: Option<oneshot::Sender<()>>,
}
impl super::NewCursor for SessionRawBatchCursor {
    /// Builds a session-bound raw batch cursor.
    ///
    /// Any implicit session passed in is not kept, because `getMore`s run on the explicit
    /// session given to `stream`.
    fn generic_new(
        client: Client,
        spec: CursorSpecification,
        _implicit_session: Option<ClientSession>,
        pinned: Option<PinnedConnectionHandle>,
    ) -> Result<Self> {
        let cursor = Self::new(client, spec, pinned);
        Ok(cursor)
    }
}
impl SessionRawBatchCursor {
    /// Creates a cursor from the initial reply described by `spec`.
    ///
    /// The initial reply is buffered as the first batch. A cursor id of `0` marks the cursor
    /// exhausted from the start.
    fn new(
        client: Client,
        spec: CursorSpecification,
        pinned: Option<PinnedConnectionHandle>,
    ) -> Self {
        let drop_token = client.register_async_drop();
        let exhausted = spec.info.id == 0;
        Self {
            client,
            drop_token,
            info: spec.info,
            exhausted,
            pinned_connection: PinnedConnection::new(pinned),
            post_batch_resume_token: spec.post_batch_resume_token,
            buffered_reply: Some(spec.initial_reply),
            drop_address: None,
            #[cfg(test)]
            kill_watcher: None,
        }
    }

    /// Returns a stream of raw batches whose `getMore`s run on `session`.
    ///
    /// The stream mutably borrows both this cursor and `session` for as long as it lives;
    /// progress is recorded on the cursor itself.
    pub fn stream<'session>(
        &mut self,
        session: &'session mut ClientSession,
    ) -> SessionRawBatchCursorStream<'_, 'session> {
        let handle = Box::new(ExplicitClientSessionHandle(session));
        SessionRawBatchCursorStream {
            parent: self,
            provider: GetMoreRawProvider::Idle(handle),
        }
    }

    /// Returns the client this cursor runs its operations with.
    pub(crate) fn client(&self) -> &Client {
        &self.client
    }

    /// Returns the server address recorded in the cursor information.
    pub(crate) fn address(&self) -> &ServerAddress {
        &self.info.address
    }

    /// Returns whether no further `getMore` will be issued for this cursor.
    pub(crate) fn is_exhausted(&self) -> bool {
        self.exhausted
    }

    /// Returns the post-batch resume token from the most recently received reply, if any.
    pub(crate) fn post_batch_resume_token(&self) -> Option<&ResumeToken> {
        self.post_batch_resume_token.as_ref()
    }

    /// Sets the address forwarded to `kill_cursor` when this cursor is dropped unexhausted.
    pub(crate) fn set_drop_address(&mut self, address: ServerAddress) {
        self.drop_address = Some(address);
    }

    /// Marks the cursor exhausted and releases any pinned connection.
    fn mark_exhausted(&mut self) {
        self.exhausted = true;
        self.pinned_connection = PinnedConnection::Unpinned;
    }

    /// Installs a test hook that is forwarded to `kill_cursor` on drop.
    ///
    /// # Panics
    ///
    /// Panics if a kill watcher has already been set.
    #[cfg(test)]
    pub(crate) fn set_kill_watcher(&mut self, tx: oneshot::Sender<()>) {
        assert!(
            self.kill_watcher.is_none(),
            "cursor already has a kill_watcher"
        );
        self.kill_watcher = Some(tx);
    }
}
impl Drop for SessionRawBatchCursor {
    /// Unless the cursor is exhausted, hands it to `kill_cursor` for server-side cleanup.
    fn drop(&mut self) {
        if !self.is_exhausted() {
            let pinned = self.pinned_connection.replicate();
            let address = self.drop_address.take();
            kill_cursor(
                self.client.clone(),
                &mut self.drop_token,
                &self.info.ns,
                self.info.id,
                pinned,
                address,
                #[cfg(test)]
                self.kill_watcher.take(),
            );
        }
    }
}
/// Stream of raw batches over a [`SessionRawBatchCursor`], issuing `getMore`s on a borrowed
/// explicit session.
// NOTE(review): the in-flight `getMore` future lives here, not in the cursor, so dropping
// the stream mid-`getMore` discards that future and the cursor never records its result.
pub struct SessionRawBatchCursorStream<'cursor, 'session> {
    /// Cursor whose buffered reply, id, namespace and exhaustion state are read and updated.
    parent: &'cursor mut SessionRawBatchCursor,
    /// Holds the borrowed session while idle, or the in-flight `getMore` future.
    provider: GetMoreRawProvider<'session, ExplicitClientSessionHandle<'session>>,
}
impl Stream for SessionRawBatchCursorStream<'_, '_> {
    type Item = Result<RawBatch>;

    /// Yields the next raw batch, issuing a `getMore` on the explicit session when the
    /// parent cursor has nothing buffered.
    ///
    /// The outcome of a completed `getMore` (id, namespace, resume token, buffered reply,
    /// exhaustion) is written back to the parent cursor. The stream ends once the cursor is
    /// exhausted or its pinned connection has been invalidated.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The stream is `Unpin` (a `&mut` plus boxed future/handle), so take a plain `&mut`
        // once. The borrow checker can then see `provider` and `parent` as disjoint fields,
        // so the pinned handle can be lent to `start_execution` without replicating it first.
        let this = self.get_mut();
        loop {
            if let Some(future) = this.provider.executing_future() {
                let get_more_out = ready!(Pin::new(future).poll(cx));
                match get_more_out.result {
                    Ok(out) => {
                        if out.exhausted {
                            this.parent.mark_exhausted();
                        }
                        if out.id != 0 {
                            this.parent.info.id = out.id;
                        }
                        this.parent.info.ns = out.ns;
                        this.parent.post_batch_resume_token = out.post_batch_resume_token;
                        this.parent.buffered_reply = Some(out.raw_reply);
                    }
                    Err(e) => {
                        // Command errors 43/237 stop further getMores. NOTE(review): these
                        // look like CursorNotFound / CursorKilled; confirm.
                        if matches!(*e.kind, ErrorKind::Command(ref ce) if ce.code == 43 || ce.code == 237)
                        {
                            this.parent.mark_exhausted();
                        }
                        // An invalidated pin prevents any further getMore below.
                        if e.is_network_error() {
                            this.parent.pinned_connection.invalidate();
                        }
                        let exhausted_now = this.parent.exhausted;
                        this.provider
                            .clear_execution(get_more_out.session, exhausted_now);
                        return Poll::Ready(Some(Err(e)));
                    }
                }
                let exhausted_now = this.parent.exhausted;
                this.provider
                    .clear_execution(get_more_out.session, exhausted_now);
            }
            if let Some(reply) = this.parent.buffered_reply.take() {
                return Poll::Ready(Some(Ok(RawBatch::new(reply))));
            }
            if !this.parent.exhausted
                && !matches!(this.parent.pinned_connection, PinnedConnection::Invalid(_))
            {
                let info = this.parent.info.clone();
                let client = this.parent.client.clone();
                // `start_execution` replicates the handle itself, so lending it is enough.
                this.provider.start_execution(
                    info,
                    client,
                    this.parent.pinned_connection.handle(),
                );
                continue;
            }
            return Poll::Ready(None);
        }
    }
}
/// Output of a `getMore` future: the operation result together with the session handle,
/// which is moved back out of the future so the provider can reuse or drop it.
#[derive(Debug)]
struct GetMoreRawResultAndSession<S> {
    result: Result<crate::results::GetMoreResult>,
    session: S,
}
/// Tracks where the session used for `getMore`s currently lives.
enum GetMoreRawProvider<'s, S> {
    /// A `getMore` is in flight; the session handle was moved into this future.
    Executing(BoxFuture<'s, GetMoreRawResultAndSession<S>>),
    /// No `getMore` is running; the session handle is available for the next one.
    Idle(Box<S>),
    /// No session is held: used for a cursor created already exhausted, and set by
    /// `clear_execution` once an implicit session's cursor is exhausted.
    Done,
}
impl GetMoreRawProvider<'static, ImplicitClientSessionHandle> {
    /// Takes the implicit session out of an idle provider.
    ///
    /// Returns `None` while a `getMore` is in flight (the session is inside the future) or
    /// once the provider is `Done`.
    fn take_implicit_session(&mut self) -> Option<ClientSession> {
        match self {
            Self::Executing(_) | Self::Done => None,
            Self::Idle(handle) => handle.take_implicit_session(),
        }
    }
}
impl<'s, S: ClientSessionHandle<'s>> GetMoreRawProvider<'s, S> {
    /// Returns the in-flight `getMore` future, if one is running.
    fn executing_future(&mut self) -> Option<&mut BoxFuture<'s, GetMoreRawResultAndSession<S>>> {
        match self {
            Self::Executing(future) => Some(future),
            Self::Idle(_) | Self::Done => None,
        }
    }

    /// Stores the session handed back by a finished `getMore`.
    ///
    /// If the cursor is exhausted and the session is implicit, the session is dropped and
    /// the provider becomes `Done`. Otherwise the session is kept (`Idle`) for the next
    /// `getMore`.
    fn clear_execution(&mut self, session: S, exhausted: bool) {
        if exhausted && session.is_implicit() {
            *self = Self::Done
        } else {
            *self = Self::Idle(Box::new(session))
        }
    }

    /// Starts a `getMore` for `info` if the provider is idle; otherwise it is left unchanged.
    ///
    /// The session handle moves into the new future and is returned with the result, and
    /// `pinned_connection` (if any) is replicated so the future owns its own handle.
    fn start_execution(
        &mut self,
        info: CursorInformation,
        client: Client,
        pinned_connection: Option<&PinnedConnectionHandle>,
    ) {
        // Park the provider in `Done` to take ownership of the idle session. Unlike
        // `take_mut::take`, which aborts the process if anything panics mid-swap, a panic
        // here simply leaves the provider `Done`.
        let mut session = match std::mem::replace(self, Self::Done) {
            Self::Idle(session) => session,
            busy_or_done @ (Self::Executing(_) | Self::Done) => {
                *self = busy_or_done;
                return;
            }
        };
        let pinned = pinned_connection.map(|c| c.replicate());
        *self = Self::Executing(Box::pin(async move {
            let get_more = GetMore::new(info, pinned.as_ref());
            let result = client
                .execute_operation(get_more, session.borrow_mut())
                .await;
            GetMoreRawResultAndSession {
                result,
                session: *session,
            }
        }));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bson::{doc, Document};

    /// Serializes `document` into an owned raw BSON document.
    fn to_raw(document: &Document) -> RawDocumentBuf {
        let mut bytes = Vec::new();
        document.to_writer(&mut bytes).unwrap();
        RawDocumentBuf::from_bytes(bytes).unwrap()
    }

    #[test]
    fn raw_batch_into_docs_works() {
        let reply_doc: Document = doc! {
            "ok": 1,
            "cursor": {
                "id": 0_i64,
                "ns": "db.coll",
                "firstBatch": [
                    { "x": 1 },
                    { "x": 2 }
                ]
            }
        };
        let batch = RawBatch::new(to_raw(&reply_doc));
        let docs: Vec<_> = batch.doc_slices().unwrap().into_iter().collect();
        assert_eq!(docs.len(), 2);
    }

    #[test]
    fn raw_batch_reads_next_batch() {
        let reply_doc: Document = doc! {
            "ok": 1,
            "cursor": {
                "id": 0_i64,
                "ns": "db.coll",
                "nextBatch": [ { "x": 3 } ]
            }
        };
        let batch = RawBatch::new(to_raw(&reply_doc));
        assert_eq!(batch.doc_slices().unwrap().into_iter().count(), 1);
    }

    #[test]
    fn raw_batch_without_cursor_is_an_error() {
        let batch = RawBatch::new(to_raw(&doc! { "ok": 1 }));
        assert!(batch.doc_slices().is_err());
    }

    #[test]
    fn raw_batch_with_non_array_batch_is_an_error() {
        let batch = RawBatch::new(to_raw(&doc! { "ok": 1, "cursor": { "firstBatch": 5 } }));
        assert!(batch.doc_slices().is_err());
    }
}
/// Session handle owned by a cursor; holds its implicit session, or `None` if there is
/// none or it has already been taken.
#[derive(Debug)]
pub(super) struct ImplicitClientSessionHandle(pub(super) Option<ClientSession>);

impl ImplicitClientSessionHandle {
    /// Moves the session out, leaving `None` behind.
    fn take_implicit_session(&mut self) -> Option<ClientSession> {
        std::mem::take(&mut self.0)
    }
}

impl ClientSessionHandle<'_> for ImplicitClientSessionHandle {
    /// Always `true`: this handle wraps an implicit session.
    fn is_implicit(&self) -> bool {
        true
    }

    /// Mutable access to the held session, if any.
    fn borrow_mut(&mut self) -> Option<&mut ClientSession> {
        match &mut self.0 {
            Some(session) => Some(session),
            None => None,
        }
    }
}
/// Session handle that borrows a caller-supplied session for lifetime `'a`.
pub(super) struct ExplicitClientSessionHandle<'a>(pub(super) &'a mut ClientSession);

impl<'a> ClientSessionHandle<'a> for ExplicitClientSessionHandle<'a> {
    /// Always `false`: the session belongs to the caller.
    fn is_implicit(&self) -> bool {
        false
    }

    /// Reborrows the caller's session; never `None`.
    fn borrow_mut(&mut self) -> Option<&mut ClientSession> {
        let session: &mut ClientSession = &mut *self.0;
        Some(session)
    }
}
/// Abstraction over the session a `getMore` runs on, so one provider type can drive both
/// owned (implicit) and borrowed (explicit) sessions.
pub(super) trait ClientSessionHandle<'a>: Send + 'a {
    /// Whether this handle wraps an implicit session; `clear_execution` drops implicit
    /// handles once the cursor is exhausted.
    fn is_implicit(&self) -> bool;
    /// Mutable access to the session to run the operation on, or `None` if none is held.
    fn borrow_mut(&mut self) -> Option<&mut ClientSession>;
}