1use std::{
43 pin::Pin,
44 task::{ready, Context, Poll},
45};
46
47use crate::{
48 bson::{RawArray, RawDocument},
49 cursor::common::CursorSpecification,
50 operation::GetMore,
51};
52use futures_core::{future::BoxFuture, Future, Stream};
53#[cfg(test)]
54use tokio::sync::oneshot;
55
56use crate::{
57 bson::RawDocumentBuf,
58 change_stream::event::ResumeToken,
59 client::{options::ServerAddress, AsyncDropToken},
60 cmap::conn::PinnedConnectionHandle,
61 cursor::common::{kill_cursor, PinnedConnection},
62 error::{Error, ErrorKind, Result},
63 Client,
64 ClientSession,
65};
66
67use super::common::CursorInformation;
68
69const CURSOR: &str = "cursor";
70const FIRST_BATCH: &str = "firstBatch";
71const NEXT_BATCH: &str = "nextBatch";
72
73#[derive(Clone, Debug)]
78pub struct RawBatch {
79 reply: RawDocumentBuf,
80}
81
82impl RawBatch {
83 pub(crate) fn new(reply: RawDocumentBuf) -> Self {
84 Self { reply }
85 }
86
87 pub fn doc_slices(&self) -> Result<&RawArray> {
92 let root = self.reply.as_ref();
93 let cursor = root
94 .get_document(CURSOR)
95 .map_err(|_| Error::invalid_response("missing cursor subdocument"))?;
96
97 let docs = cursor
98 .get(FIRST_BATCH)?
99 .or_else(|| cursor.get(NEXT_BATCH).ok().flatten())
100 .ok_or_else(|| {
101 Error::invalid_response(format!("missing {FIRST_BATCH}/{NEXT_BATCH}"))
102 })?;
103
104 docs.as_array()
105 .ok_or_else(|| Error::invalid_response(format!("invalid {FIRST_BATCH}/{NEXT_BATCH}")))
106 }
107
108 pub fn as_raw_document(&self) -> &RawDocument {
113 self.reply.as_ref()
114 }
115}
116
/// A cursor that yields each server reply as a [`RawBatch`] instead of individual documents.
///
/// Any implicit session the cursor was created with is held by the cursor and used for its
/// `getMore` operations. If the cursor is dropped before it is exhausted, `kill_cursor` is
/// invoked for it.
pub struct RawBatchCursor {
    /// Client used to run `getMore`s and to kill the cursor on drop.
    client: Client,
    /// Token handed to `kill_cursor` on drop; obtained from `client.register_async_drop()`.
    drop_token: AsyncDropToken,
    /// Cursor id, namespace and address; the id and namespace are updated from `getMore`
    /// replies.
    info: CursorInformation,
    /// Mutable iteration state (buffered reply, exhaustion flag, pinned connection, provider).
    state: RawBatchCursorState,
    /// Address passed to `kill_cursor` on drop, if set via `set_drop_address`.
    drop_address: Option<ServerAddress>,
    /// Test-only channel passed to `kill_cursor` on drop.
    #[cfg(test)]
    kill_watcher: Option<oneshot::Sender<()>>,
}
127
128#[allow(dead_code, unreachable_code, clippy::diverging_sub_expression)]
129const _: fn() = || {
130 fn assert_unpin<T: Unpin>(_t: T) {}
131
132 let _rb: RawBatchCursor = todo!();
133 assert_unpin(_rb);
134};
135
/// Iteration state of a [`RawBatchCursor`].
struct RawBatchCursorState {
    /// Set once no further `getMore`s should be issued: the cursor id was `0`, a `getMore`
    /// reported exhaustion, or the server reported the cursor as not found / killed.
    exhausted: bool,
    /// Connection the cursor is pinned to, if any; reset to `Unpinned` when the cursor is
    /// marked exhausted and invalidated after a network error.
    pinned_connection: PinnedConnection,
    /// Post-batch resume token from the most recent reply.
    post_batch_resume_token: Option<ResumeToken>,
    /// Runs `getMore`s and holds the implicit session between them.
    provider: GetMoreRawProvider<'static, ImplicitClientSessionHandle>,
    /// Reply not yet yielded by the stream; initially the reply that opened the cursor.
    buffered_reply: Option<RawDocumentBuf>,
}
143
144impl crate::cursor::NewCursor for RawBatchCursor {
145 fn generic_new(
146 client: Client,
147 spec: CursorSpecification,
148 implicit_session: Option<ClientSession>,
149 pinned: Option<PinnedConnectionHandle>,
150 ) -> Result<Self> {
151 Ok(Self::new(client, spec, implicit_session, pinned))
152 }
153}
154
155impl RawBatchCursor {
156 fn new(
157 client: Client,
158 spec: CursorSpecification,
159 session: Option<ClientSession>,
160 pin: Option<PinnedConnectionHandle>,
161 ) -> Self {
162 let exhausted = spec.info.id == 0;
163 Self {
164 client: client.clone(),
165 drop_token: client.register_async_drop(),
166 info: spec.info,
167 drop_address: None,
168 #[cfg(test)]
169 kill_watcher: None,
170 state: RawBatchCursorState {
171 exhausted,
172 pinned_connection: PinnedConnection::new(pin),
173 post_batch_resume_token: spec.post_batch_resume_token,
174 provider: if exhausted {
175 GetMoreRawProvider::Done
176 } else {
177 GetMoreRawProvider::Idle(Box::new(ImplicitClientSessionHandle(session)))
178 },
179 buffered_reply: Some(spec.initial_reply),
180 },
181 }
182 }
183
184 pub(crate) fn is_exhausted(&self) -> bool {
185 self.state.exhausted
186 }
187
188 pub(crate) fn has_next(&self) -> bool {
189 if !self.is_exhausted() {
190 return true;
191 }
192 let Some(batch) = self
193 .state
194 .buffered_reply
195 .as_ref()
196 .and_then(|reply| reply.get_document(CURSOR).ok())
197 .and_then(|cursor| {
198 cursor
199 .get_array(FIRST_BATCH)
200 .or_else(|_| cursor.get_array(NEXT_BATCH))
201 .ok()
202 })
203 else {
204 return false;
205 };
206 !batch.is_empty()
207 }
208
209 pub(crate) fn post_batch_resume_token(&self) -> Option<&ResumeToken> {
210 self.state.post_batch_resume_token.as_ref()
211 }
212
213 pub(crate) fn address(&self) -> &ServerAddress {
214 &self.info.address
215 }
216
217 pub(crate) fn set_drop_address(&mut self, address: ServerAddress) {
218 self.drop_address = Some(address);
219 }
220
221 pub(crate) fn client(&self) -> &Client {
222 &self.client
223 }
224
225 fn mark_exhausted(&mut self) {
226 self.state.exhausted = true;
227 self.state.pinned_connection = PinnedConnection::Unpinned;
228 }
229
230 #[cfg(test)]
231 pub(crate) fn set_kill_watcher(&mut self, tx: oneshot::Sender<()>) {
232 assert!(
233 self.kill_watcher.is_none(),
234 "cursor already has a kill_watcher"
235 );
236 self.kill_watcher = Some(tx);
237 }
238
239 pub(crate) fn take_implicit_session(&mut self) -> Option<ClientSession> {
241 self.state.provider.take_implicit_session()
242 }
243}
244
245impl Stream for RawBatchCursor {
246 type Item = Result<RawBatch>;
247
248 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
249 loop {
250 if let Some(future) = self.state.provider.executing_future() {
252 let get_more_out = ready!(Pin::new(future).poll(cx));
253 match get_more_out.result {
254 Ok(out) => {
255 self.state.buffered_reply = Some(out.raw_reply);
256 self.state.post_batch_resume_token = out.post_batch_resume_token;
257 if out.exhausted {
258 self.mark_exhausted();
259 }
260 if out.id != 0 {
261 self.info.id = out.id;
262 }
263 self.info.ns = out.ns;
264 }
265 Err(e) => {
266 if matches!(*e.kind, ErrorKind::Command(ref ce) if ce.code == 43 || ce.code == 237)
267 {
268 self.mark_exhausted();
269 }
270 if e.is_network_error() {
271 self.state.pinned_connection.invalidate();
274 }
275 let exhausted_now = self.state.exhausted;
276 self.state
277 .provider
278 .clear_execution(get_more_out.session, exhausted_now);
279 return Poll::Ready(Some(Err(e)));
280 }
281 }
282 let exhausted_now = self.state.exhausted;
283 self.state
284 .provider
285 .clear_execution(get_more_out.session, exhausted_now);
286 }
287
288 if let Some(reply) = self.state.buffered_reply.take() {
290 return Poll::Ready(Some(Ok(RawBatch::new(reply))));
291 }
292
293 if !self.state.exhausted
295 && !matches!(self.state.pinned_connection, PinnedConnection::Invalid(_))
296 {
297 let info = self.info.clone();
298 let client = self.client.clone();
299 let state = &mut self.state;
300 state
301 .provider
302 .start_execution(info, client, state.pinned_connection.handle());
303 continue;
304 }
305
306 return Poll::Ready(None);
308 }
309 }
310}
311
312impl Drop for RawBatchCursor {
313 fn drop(&mut self) {
314 if self.is_exhausted() {
315 return;
316 }
317 kill_cursor(
318 self.client.clone(),
319 &mut self.drop_token,
320 &self.info.ns,
321 self.info.id,
322 self.state.pinned_connection.replicate(),
323 self.drop_address.take(),
324 #[cfg(test)]
325 self.kill_watcher.take(),
326 );
327 }
328}
329
/// A raw batch cursor that does not hold a session of its own.
///
/// Batches are consumed through [`SessionRawBatchCursor::stream`], which issues `getMore`s with
/// a caller-supplied session. If the cursor is dropped before it is exhausted, `kill_cursor` is
/// invoked for it.
#[derive(Debug)]
pub struct SessionRawBatchCursor {
    /// Client used to run `getMore`s and to kill the cursor on drop.
    client: Client,
    /// Token handed to `kill_cursor` on drop; obtained from `client.register_async_drop()`.
    drop_token: AsyncDropToken,
    /// Cursor id, namespace and address; the id and namespace are updated from `getMore`
    /// replies.
    info: CursorInformation,
    /// Set once no further `getMore`s should be issued.
    exhausted: bool,
    /// Connection the cursor is pinned to, if any.
    pinned_connection: PinnedConnection,
    /// Post-batch resume token from the most recent reply.
    post_batch_resume_token: Option<ResumeToken>,
    /// Reply not yet yielded; initially the reply that opened the cursor.
    buffered_reply: Option<RawDocumentBuf>,
    /// Address passed to `kill_cursor` on drop, if set via `set_drop_address`.
    drop_address: Option<ServerAddress>,
    /// Test-only channel passed to `kill_cursor` on drop.
    #[cfg(test)]
    kill_watcher: Option<oneshot::Sender<()>>,
}
344
345impl super::NewCursor for SessionRawBatchCursor {
346 fn generic_new(
347 client: Client,
348 spec: CursorSpecification,
349 _implicit_session: Option<ClientSession>,
350 pinned: Option<PinnedConnectionHandle>,
351 ) -> Result<Self> {
352 Ok(Self::new(client, spec, pinned))
353 }
354}
355
356impl SessionRawBatchCursor {
357 fn new(
358 client: Client,
359 spec: CursorSpecification,
360 pinned: Option<PinnedConnectionHandle>,
361 ) -> Self {
362 let exhausted = spec.info.id == 0;
363 Self {
364 drop_token: client.register_async_drop(),
365 client,
366 info: spec.info,
367 exhausted,
368 pinned_connection: PinnedConnection::new(pinned),
369 post_batch_resume_token: spec.post_batch_resume_token,
370 buffered_reply: Some(spec.initial_reply),
371 drop_address: None,
372 #[cfg(test)]
373 kill_watcher: None,
374 }
375 }
376
377 pub fn stream<'session>(
380 &mut self,
381 session: &'session mut ClientSession,
382 ) -> SessionRawBatchCursorStream<'_, 'session> {
383 SessionRawBatchCursorStream {
384 parent: self,
385 provider: GetMoreRawProvider::Idle(Box::new(ExplicitClientSessionHandle(session))),
386 }
387 }
388
389 pub(crate) fn address(&self) -> &ServerAddress {
390 &self.info.address
391 }
392
393 pub(crate) fn set_drop_address(&mut self, address: ServerAddress) {
394 self.drop_address = Some(address);
395 }
396
397 fn mark_exhausted(&mut self) {
398 self.exhausted = true;
399 self.pinned_connection = PinnedConnection::Unpinned;
400 }
401
402 pub(crate) fn is_exhausted(&self) -> bool {
403 self.exhausted
404 }
405
406 pub(crate) fn post_batch_resume_token(&self) -> Option<&ResumeToken> {
407 self.post_batch_resume_token.as_ref()
408 }
409
410 #[cfg(test)]
411 pub(crate) fn set_kill_watcher(&mut self, tx: oneshot::Sender<()>) {
412 assert!(
413 self.kill_watcher.is_none(),
414 "cursor already has a kill_watcher"
415 );
416 self.kill_watcher = Some(tx);
417 }
418
419 pub(crate) fn client(&self) -> &Client {
420 &self.client
421 }
422}
423
424impl Drop for SessionRawBatchCursor {
425 fn drop(&mut self) {
426 if self.is_exhausted() {
427 return;
428 }
429 kill_cursor(
430 self.client.clone(),
431 &mut self.drop_token,
432 &self.info.ns,
433 self.info.id,
434 self.pinned_connection.replicate(),
435 self.drop_address.take(),
436 #[cfg(test)]
437 self.kill_watcher.take(),
438 );
439 }
440}
441
/// Stream over the batches of a [`SessionRawBatchCursor`], created by
/// [`SessionRawBatchCursor::stream`].
///
/// Borrows both the cursor and the session used for its `getMore`s. Progress (buffered reply,
/// cursor id, exhaustion) is recorded on the parent cursor.
pub struct SessionRawBatchCursorStream<'cursor, 'session> {
    /// The cursor whose state this stream advances.
    parent: &'cursor mut SessionRawBatchCursor,
    /// Runs `getMore`s with the borrowed explicit session.
    provider: GetMoreRawProvider<'session, ExplicitClientSessionHandle<'session>>,
}
448
449impl Stream for SessionRawBatchCursorStream<'_, '_> {
450 type Item = Result<RawBatch>;
451
452 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
453 loop {
454 if let Some(future) = self.provider.executing_future() {
456 let get_more_out = ready!(Pin::new(future).poll(cx));
457 match get_more_out.result {
458 Ok(out) => {
459 if out.exhausted {
460 self.parent.mark_exhausted();
461 }
462 if out.id != 0 {
463 self.parent.info.id = out.id;
464 }
465 self.parent.info.ns = out.ns;
466 self.parent.post_batch_resume_token = out.post_batch_resume_token;
467 self.parent.buffered_reply = Some(out.raw_reply);
469 }
470 Err(e) => {
471 if matches!(*e.kind, ErrorKind::Command(ref ce) if ce.code == 43 || ce.code == 237)
472 {
473 self.parent.mark_exhausted();
474 }
475 if e.is_network_error() {
476 self.parent.pinned_connection.invalidate();
479 }
480 let exhausted_now = self.parent.exhausted;
481 self.provider
482 .clear_execution(get_more_out.session, exhausted_now);
483 return Poll::Ready(Some(Err(e)));
484 }
485 }
486 let exhausted_now = self.parent.exhausted;
487 self.provider
488 .clear_execution(get_more_out.session, exhausted_now);
489 }
490
491 if let Some(reply) = self.parent.buffered_reply.take() {
493 return Poll::Ready(Some(Ok(RawBatch::new(reply))));
494 }
495
496 if !self.parent.exhausted
498 && !matches!(self.parent.pinned_connection, PinnedConnection::Invalid(_))
499 {
500 let info = self.parent.info.clone();
501 let client = self.parent.client.clone();
502 let pinned_owned = self
503 .parent
504 .pinned_connection
505 .handle()
506 .map(|c| c.replicate());
507 let pinned_ref = pinned_owned.as_ref();
508 self.provider.start_execution(info, client, pinned_ref);
509 continue;
510 }
511
512 return Poll::Ready(None);
514 }
515 }
516}
517
/// Output of a `getMore` future: the operation's result together with the session handle.
///
/// The future takes ownership of the handle while it runs and hands it back here.
#[derive(Debug)]
struct GetMoreRawResultAndSession<S> {
    /// Outcome of the `getMore` operation.
    result: Result<crate::results::GetMoreResult>,
    /// Session handle to return to the provider via `clear_execution`.
    session: S,
}
523
/// State machine that runs `getMore` operations for a raw batch cursor.
///
/// The session handle is owned directly while `Idle` and by the in-flight future while
/// `Executing` (it comes back in the future's output).
enum GetMoreRawProvider<'s, S> {
    /// A `getMore` is in flight.
    Executing(BoxFuture<'s, GetMoreRawResultAndSession<S>>),
    /// No `getMore` is running; holds the session handle for the next one.
    Idle(Box<S>),
    /// No further `getMore`s will be issued. Reached for implicit sessions once the cursor is
    /// exhausted, and used at construction for an already-exhausted `RawBatchCursor`.
    Done,
}
529
530impl GetMoreRawProvider<'static, ImplicitClientSessionHandle> {
531 fn take_implicit_session(&mut self) -> Option<ClientSession> {
534 match self {
535 Self::Idle(session) => session.take_implicit_session(),
536 Self::Executing(..) | Self::Done => None,
537 }
538 }
539}
540
541impl<'s, S: ClientSessionHandle<'s>> GetMoreRawProvider<'s, S> {
542 fn executing_future(&mut self) -> Option<&mut BoxFuture<'s, GetMoreRawResultAndSession<S>>> {
543 if let Self::Executing(future) = self {
544 Some(future)
545 } else {
546 None
547 }
548 }
549
550 fn clear_execution(&mut self, session: S, exhausted: bool) {
551 if exhausted && session.is_implicit() {
552 *self = Self::Done
553 } else {
554 *self = Self::Idle(Box::new(session))
555 }
556 }
557
558 fn start_execution(
559 &mut self,
560 info: CursorInformation,
561 client: Client,
562 pinned_connection: Option<&PinnedConnectionHandle>,
563 ) {
564 take_mut::take(self, |this| {
565 if let Self::Idle(mut session) = this {
566 let pinned = pinned_connection.map(|c| c.replicate());
567 let fut = Box::pin(async move {
568 let get_more = GetMore::new(info, pinned.as_ref());
569 let res = client
570 .execute_operation(get_more, session.borrow_mut())
571 .await;
572 GetMoreRawResultAndSession {
573 result: res,
574 session: *session,
575 }
576 });
577 Self::Executing(fut)
578 } else {
579 this
580 }
581 })
582 }
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bson::{doc, Document};

    /// Serializes `document` into an owned raw document.
    fn to_raw(document: Document) -> RawDocumentBuf {
        let mut bytes = Vec::new();
        document.to_writer(&mut bytes).unwrap();
        RawDocumentBuf::from_bytes(bytes).unwrap()
    }

    /// Collects the `x` field of every document in the batch.
    fn xs(batch: &RawBatch) -> Vec<i32> {
        batch
            .doc_slices()
            .unwrap()
            .into_iter()
            .map(|item| item.unwrap().as_document().unwrap().get_i32("x").unwrap())
            .collect()
    }

    #[test]
    fn raw_batch_into_docs_works() {
        let reply = to_raw(doc! {
            "ok": 1,
            "cursor": {
                "id": 0_i64,
                "ns": "db.coll",
                "firstBatch": [
                    { "x": 1 },
                    { "x": 2 }
                ]
            }
        });

        let batch = RawBatch::new(reply);
        assert_eq!(xs(&batch), vec![1, 2]);
    }

    #[test]
    fn raw_batch_reads_next_batch() {
        let reply = to_raw(doc! {
            "ok": 1,
            "cursor": { "id": 5_i64, "ns": "db.coll", "nextBatch": [ { "x": 3 } ] }
        });

        assert_eq!(xs(&RawBatch::new(reply)), vec![3]);
    }

    #[test]
    fn raw_batch_empty_batch_yields_no_docs() {
        let reply = to_raw(doc! {
            "ok": 1,
            "cursor": { "id": 0_i64, "ns": "db.coll", "firstBatch": [] }
        });

        assert!(xs(&RawBatch::new(reply)).is_empty());
    }

    #[test]
    fn raw_batch_missing_cursor_is_error() {
        let batch = RawBatch::new(to_raw(doc! { "ok": 1 }));
        assert!(batch.doc_slices().is_err());
    }

    #[test]
    fn raw_batch_missing_batch_is_error() {
        let reply = to_raw(doc! { "ok": 1, "cursor": { "id": 0_i64, "ns": "db.coll" } });
        assert!(RawBatch::new(reply).doc_slices().is_err());
    }

    #[test]
    fn raw_batch_non_array_batch_is_error() {
        let reply = to_raw(doc! {
            "ok": 1,
            "cursor": { "id": 0_i64, "ns": "db.coll", "firstBatch": 5 }
        });
        assert!(RawBatch::new(reply).doc_slices().is_err());
    }

    #[test]
    fn raw_batch_exposes_full_reply() {
        let reply = to_raw(doc! {
            "ok": 1,
            "cursor": { "id": 0_i64, "ns": "db.coll", "firstBatch": [] }
        });
        let batch = RawBatch::new(reply);
        assert_eq!(batch.as_raw_document().get_i32("ok").unwrap(), 1);
    }
}
612
613#[derive(Debug)]
614pub(super) struct ImplicitClientSessionHandle(pub(super) Option<ClientSession>);
615
616impl ImplicitClientSessionHandle {
617 fn take_implicit_session(&mut self) -> Option<ClientSession> {
618 self.0.take()
619 }
620}
621
622impl ClientSessionHandle<'_> for ImplicitClientSessionHandle {
623 fn is_implicit(&self) -> bool {
624 true
625 }
626
627 fn borrow_mut(&mut self) -> Option<&mut ClientSession> {
628 self.0.as_mut()
629 }
630}
631
632pub(super) struct ExplicitClientSessionHandle<'a>(pub(super) &'a mut ClientSession);
633
634impl<'a> ClientSessionHandle<'a> for ExplicitClientSessionHandle<'a> {
635 fn is_implicit(&self) -> bool {
636 false
637 }
638
639 fn borrow_mut(&mut self) -> Option<&mut ClientSession> {
640 Some(self.0)
641 }
642}
643
/// Abstraction over how a raw batch cursor holds the session used for its `getMore`s.
pub(super) trait ClientSessionHandle<'a>: Send + 'a {
    /// Whether this handle wraps an implicit session held by the cursor itself. Implicit
    /// handles are dropped by `GetMoreRawProvider::clear_execution` once the cursor is
    /// exhausted; non-implicit handles are kept.
    fn is_implicit(&self) -> bool;

    /// Mutable access to the session, or `None` if the handle holds no session.
    fn borrow_mut(&mut self) -> Option<&mut ClientSession>;
}