mongodb/cursor/session.rs
1use std::{
2 marker::PhantomData,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use crate::bson::RawDocument;
8use futures_core::Stream;
9use futures_util::StreamExt;
10use serde::{de::DeserializeOwned, Deserialize};
11#[cfg(test)]
12use tokio::sync::oneshot;
13
14use super::{
15 common::{
16 kill_cursor,
17 CursorBuffer,
18 CursorInformation,
19 CursorState,
20 GenericCursor,
21 PinnedConnection,
22 },
23 stream_poll_next,
24 BatchValue,
25 CursorStream,
26};
27use crate::{
28 bson::Document,
29 change_stream::event::ResumeToken,
30 client::{options::ServerAddress, AsyncDropToken},
31 cmap::conn::PinnedConnectionHandle,
32 cursor::{common::ExplicitClientSessionHandle, CursorSpecification},
33 error::{Error, Result},
34 Client,
35 ClientSession,
36};
37
/// A [`SessionCursor`] is a cursor that was created with a [`ClientSession`] that must be iterated
/// using one. To iterate, use [`SessionCursor::next`] or retrieve a [`SessionCursorStream`] using
/// [`SessionCursor::stream`]:
///
/// ```rust
/// # use mongodb::{bson::{Document, doc}, Client, error::Result, ClientSession, SessionCursor};
/// #
/// # async fn do_stuff() -> Result<()> {
/// # let client = Client::with_uri_str("mongodb://example.com").await?;
/// # let mut session = client.start_session().await?;
/// # let coll = client.database("foo").collection::<Document>("bar");
/// #
/// // iterate using next()
/// let mut cursor = coll.find(doc! {}).session(&mut session).await?;
/// while let Some(doc) = cursor.next(&mut session).await.transpose()? {
///     println!("{}", doc)
/// }
///
/// // iterate using `Stream`:
/// use futures::stream::TryStreamExt;
///
/// let mut cursor = coll.find(doc! {}).session(&mut session).await?;
/// let results: Vec<_> = cursor.stream(&mut session).try_collect().await?;
/// #
/// # Ok(())
/// # }
/// ```
///
/// The type parameter `T` is the type each result is deserialized into; use
/// [`SessionCursor::with_type`] to change it.
///
/// If a [`SessionCursor`] is still open (not exhausted) when it goes out of scope, it will
/// automatically be closed via an asynchronous
/// [killCursors](https://www.mongodb.com/docs/manual/reference/command/killCursors/) command
/// executed from its [`Drop`](https://doc.rust-lang.org/std/ops/trait.Drop.html) implementation.
#[derive(Debug)]
pub struct SessionCursor<T> {
    /// Client handed to each stream created from this cursor, and used to issue `killCursors`
    /// when the cursor is dropped while still open.
    client: Client,
    /// Token registered with the client in `new` and passed to `kill_cursor` on drop.
    drop_token: AsyncDropToken,
    /// Server-side cursor information (namespace, cursor id, server address).
    info: CursorInformation,
    /// Iteration state. This is `None` while a [`SessionCursorStream`] holds it (the stream
    /// restores it when dropped) and after [`SessionCursor::with_type`] has moved it out.
    state: Option<CursorState>,
    /// Optional address passed to `kill_cursor` on drop; set via `set_drop_address`.
    drop_address: Option<ServerAddress>,
    /// Records the result type `T` without storing a value of it.
    _phantom: PhantomData<T>,
    /// Test-only channel passed to `kill_cursor` on drop; see `set_kill_watcher`.
    #[cfg(test)]
    kill_watcher: Option<oneshot::Sender<()>>,
}
80
81impl<T> SessionCursor<T> {
82 pub(crate) fn new(
83 client: Client,
84 spec: CursorSpecification,
85 pinned: Option<PinnedConnectionHandle>,
86 ) -> Self {
87 let exhausted = spec.info.id == 0;
88
89 Self {
90 drop_token: client.register_async_drop(),
91 client,
92 info: spec.info,
93 drop_address: None,
94 _phantom: Default::default(),
95 #[cfg(test)]
96 kill_watcher: None,
97 state: CursorState {
98 buffer: CursorBuffer::new(spec.initial_buffer),
99 exhausted,
100 post_batch_resume_token: None,
101 pinned_connection: PinnedConnection::new(pinned),
102 }
103 .into(),
104 }
105 }
106}
107
108impl<T> SessionCursor<T>
109where
110 T: DeserializeOwned,
111{
112 /// Retrieves a [`SessionCursorStream`] to iterate this cursor. The session provided must be the
113 /// same session used to create the cursor.
114 ///
115 /// Note that the borrow checker will not allow the session to be reused in between iterations
116 /// of this stream. In order to do that, either use [`SessionCursor::next`] instead or drop
117 /// the stream before using the session.
118 ///
119 /// ```
120 /// # use mongodb::{Client, bson::{doc, Document}, error::Result};
121 /// # fn main() {
122 /// # async {
123 /// # let client = Client::with_uri_str("foo").await?;
124 /// # let coll = client.database("foo").collection::<Document>("bar");
125 /// # let other_coll = coll.clone();
126 /// # let mut session = client.start_session().await?;
127 /// #
128 /// use futures::stream::TryStreamExt;
129 ///
130 /// // iterate over the results
131 /// let mut cursor = coll.find(doc! { "x": 1 }).session(&mut session).await?;
132 /// while let Some(doc) = cursor.stream(&mut session).try_next().await? {
133 /// println!("{}", doc);
134 /// }
135 ///
136 /// // collect the results
137 /// let mut cursor1 = coll.find(doc! { "x": 1 }).session(&mut session).await?;
138 /// let v: Vec<Document> = cursor1.stream(&mut session).try_collect().await?;
139 ///
140 /// // use session between iterations
141 /// let mut cursor2 = coll.find(doc! { "x": 1 }).session(&mut session).await?;
142 /// loop {
143 /// let doc = match cursor2.stream(&mut session).try_next().await? {
144 /// Some(d) => d,
145 /// None => break,
146 /// };
147 /// other_coll.insert_one(doc).session(&mut session).await?;
148 /// }
149 /// # Ok::<(), mongodb::error::Error>(())
150 /// # };
151 /// # }
152 /// ```
153 pub fn stream<'session>(
154 &mut self,
155 session: &'session mut ClientSession,
156 ) -> SessionCursorStream<'_, 'session, T> {
157 self.make_stream(session)
158 }
159
160 /// Retrieve the next result from the cursor.
161 /// The session provided must be the same session used to create the cursor.
162 ///
163 /// Use this method when the session needs to be used again between iterations or when the added
164 /// functionality of `Stream` is not needed.
165 ///
166 /// ```
167 /// # use mongodb::{Client, bson::{doc, Document}};
168 /// # fn main() {
169 /// # async {
170 /// # let client = Client::with_uri_str("foo").await?;
171 /// # let coll = client.database("foo").collection::<Document>("bar");
172 /// # let other_coll = coll.clone();
173 /// # let mut session = client.start_session().await?;
174 /// let mut cursor = coll.find(doc! { "x": 1 }).session(&mut session).await?;
175 /// while let Some(doc) = cursor.next(&mut session).await.transpose()? {
176 /// other_coll.insert_one(doc).session(&mut session).await?;
177 /// }
178 /// # Ok::<(), mongodb::error::Error>(())
179 /// # };
180 /// # }
181 /// ```
182 pub async fn next(&mut self, session: &mut ClientSession) -> Option<Result<T>> {
183 self.stream(session).next().await
184 }
185}
186
187impl<T> SessionCursor<T> {
188 fn make_stream<'session>(
189 &mut self,
190 session: &'session mut ClientSession,
191 ) -> SessionCursorStream<'_, 'session, T> {
192 // Pass the state into this cursor handle for iteration.
193 // It will be returned in the handle's `Drop` implementation.
194 SessionCursorStream {
195 generic_cursor: ExplicitSessionCursor::with_explicit_session(
196 self.take_state(),
197 self.client.clone(),
198 self.info.clone(),
199 ExplicitClientSessionHandle(session),
200 ),
201 session_cursor: self,
202 }
203 }
204
205 fn take_state(&mut self) -> CursorState {
206 self.state.take().unwrap()
207 }
208
209 /// Move the cursor forward, potentially triggering requests to the database for more results
210 /// if the local buffer has been exhausted.
211 ///
212 /// This will keep requesting data from the server until either the cursor is exhausted
213 /// or batch with results in it has been received.
214 ///
215 /// The return value indicates whether new results were successfully returned (true) or if
216 /// the cursor has been closed (false).
217 ///
218 /// Note: [`SessionCursor::current`] and [`SessionCursor::deserialize_current`] must only be
219 /// called after [`SessionCursor::advance`] returned `Ok(true)`. It is an error to call
220 /// either of them without calling [`SessionCursor::advance`] first or after
221 /// [`SessionCursor::advance`] returns an error / false.
222 ///
223 /// ```
224 /// # use mongodb::{Client, bson::{doc, Document}, error::Result};
225 /// # async fn foo() -> Result<()> {
226 /// # let client = Client::with_uri_str("mongodb://localhost:27017").await?;
227 /// # let mut session = client.start_session().await?;
228 /// # let coll = client.database("stuff").collection::<Document>("stuff");
229 /// let mut cursor = coll.find(doc! {}).session(&mut session).await?;
230 /// while cursor.advance(&mut session).await? {
231 /// println!("{:?}", cursor.current());
232 /// }
233 /// # Ok(())
234 /// # }
235 /// ```
236 pub async fn advance(&mut self, session: &mut ClientSession) -> Result<bool> {
237 self.make_stream(session).generic_cursor.advance().await
238 }
239
240 #[cfg(test)]
241 pub(crate) async fn try_advance(&mut self, session: &mut ClientSession) -> Result<()> {
242 self.make_stream(session)
243 .generic_cursor
244 .try_advance()
245 .await
246 .map(|_| ())
247 }
248
249 /// Returns a reference to the current result in the cursor.
250 ///
251 /// # Panics
252 /// [`SessionCursor::advance`] must return `Ok(true)` before [`SessionCursor::current`] can be
253 /// invoked. Calling [`SessionCursor::current`] after [`SessionCursor::advance`] does not return
254 /// true or without calling [`SessionCursor::advance`] at all may result in a panic.
255 ///
256 /// ```
257 /// # use mongodb::{Client, bson::{Document, doc}, error::Result};
258 /// # async fn foo() -> Result<()> {
259 /// # let client = Client::with_uri_str("mongodb://localhost:27017").await?;
260 /// # let mut session = client.start_session().await?;
261 /// # let coll = client.database("stuff").collection::<Document>("stuff");
262 /// let mut cursor = coll.find(doc! {}).session(&mut session).await?;
263 /// while cursor.advance(&mut session).await? {
264 /// println!("{:?}", cursor.current());
265 /// }
266 /// # Ok(())
267 /// # }
268 /// ```
269 pub fn current(&self) -> &RawDocument {
270 self.state.as_ref().unwrap().buffer.current().unwrap()
271 }
272
273 /// Deserialize the current result to the generic type associated with this cursor.
274 ///
275 /// # Panics
276 /// [`SessionCursor::advance`] must return `Ok(true)` before
277 /// [`SessionCursor::deserialize_current`] can be invoked. Calling
278 /// [`SessionCursor::deserialize_current`] after [`SessionCursor::advance`] does not return
279 /// true or without calling [`SessionCursor::advance`] at all may result in a panic.
280 ///
281 /// ```
282 /// # use mongodb::{Client, error::Result, bson::doc};
283 /// # async fn foo() -> Result<()> {
284 /// # let client = Client::with_uri_str("mongodb://localhost:27017").await?;
285 /// # let mut session = client.start_session().await?;
286 /// # let db = client.database("foo");
287 /// use serde::Deserialize;
288 ///
289 /// #[derive(Debug, Deserialize)]
290 /// struct Cat<'a> {
291 /// #[serde(borrow)]
292 /// name: &'a str
293 /// }
294 ///
295 /// let coll = db.collection::<Cat>("cat");
296 /// let mut cursor = coll.find(doc! {}).session(&mut session).await?;
297 /// while cursor.advance(&mut session).await? {
298 /// println!("{:?}", cursor.deserialize_current()?);
299 /// }
300 /// # Ok(())
301 /// # }
302 /// ```
303 pub fn deserialize_current<'a>(&'a self) -> Result<T>
304 where
305 T: Deserialize<'a>,
306 {
307 crate::bson_compat::deserialize_from_slice(self.current().as_bytes()).map_err(Error::from)
308 }
309
310 /// Update the type streamed values will be parsed as.
311 pub fn with_type<'a, D>(mut self) -> SessionCursor<D>
312 where
313 D: Deserialize<'a>,
314 {
315 SessionCursor {
316 client: self.client.clone(),
317 drop_token: self.drop_token.take(),
318 info: self.info.clone(),
319 state: Some(self.take_state()),
320 drop_address: self.drop_address.take(),
321 _phantom: Default::default(),
322 #[cfg(test)]
323 kill_watcher: self.kill_watcher.take(),
324 }
325 }
326
327 pub(crate) fn address(&self) -> &ServerAddress {
328 &self.info.address
329 }
330
331 pub(crate) fn set_drop_address(&mut self, address: ServerAddress) {
332 self.drop_address = Some(address);
333 }
334
335 /// Some tests need to be able to observe the events generated by `killCommand` execution;
336 /// however, because that happens asynchronously on `drop`, the test runner can conclude before
337 /// the event is published. To fix that, tests can set a "kill watcher" on cursors - a
338 /// one-shot channel with a `()` value pushed after `killCommand` is run that the test can wait
339 /// on.
340 #[cfg(test)]
341 pub(crate) fn set_kill_watcher(&mut self, tx: oneshot::Sender<()>) {
342 assert!(
343 self.kill_watcher.is_none(),
344 "cursor already has a kill_watcher"
345 );
346 self.kill_watcher = Some(tx);
347 }
348}
349
350impl<T> SessionCursor<T> {
351 pub(crate) fn is_exhausted(&self) -> bool {
352 self.state.as_ref().is_none_or(|state| state.exhausted)
353 }
354
355 #[cfg(test)]
356 pub(crate) fn client(&self) -> &Client {
357 &self.client
358 }
359}
360
361impl<T> Drop for SessionCursor<T> {
362 fn drop(&mut self) {
363 if self.is_exhausted() {
364 return;
365 }
366
367 kill_cursor(
368 self.client.clone(),
369 &mut self.drop_token,
370 &self.info.ns,
371 self.info.id,
372 self.state.as_ref().unwrap().pinned_connection.replicate(),
373 self.drop_address.take(),
374 #[cfg(test)]
375 self.kill_watcher.take(),
376 );
377 }
378}
379
380/// A `GenericCursor` that borrows its session.
381/// This is to be used with cursors associated with explicit sessions borrowed from the user.
382type ExplicitSessionCursor<'session> =
383 GenericCursor<'session, ExplicitClientSessionHandle<'session>>;
384
/// A type that implements [`Stream`](https://docs.rs/futures/latest/futures/stream/index.html) which can be used to
/// stream the results of a [`SessionCursor`]. Returned from [`SessionCursor::stream`].
///
/// When dropped, this writes its iteration state (including the buffer) back into the parent
/// [`SessionCursor`]. [`SessionCursor::next`] or any further streams created from
/// [`SessionCursor::stream`] will pick up where this one left off.
pub struct SessionCursorStream<'cursor, 'session, T = Document> {
    /// The cursor this stream was created from; its state is restored here on drop.
    session_cursor: &'cursor mut SessionCursor<T>,
    /// Holds the cursor state taken from `session_cursor` and drives it with the borrowed session.
    generic_cursor: ExplicitSessionCursor<'session>,
}
394
395impl<T> SessionCursorStream<'_, '_, T>
396where
397 T: DeserializeOwned,
398{
399 pub(crate) fn post_batch_resume_token(&self) -> Option<&ResumeToken> {
400 self.generic_cursor.post_batch_resume_token()
401 }
402
403 pub(crate) fn client(&self) -> &Client {
404 &self.session_cursor.client
405 }
406}
407
408impl<T> Stream for SessionCursorStream<'_, '_, T>
409where
410 T: DeserializeOwned,
411{
412 type Item = Result<T>;
413
414 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
415 stream_poll_next(&mut self.generic_cursor, cx)
416 }
417}
418
419impl<T> CursorStream for SessionCursorStream<'_, '_, T>
420where
421 T: DeserializeOwned,
422{
423 fn poll_next_in_batch(&mut self, cx: &mut Context<'_>) -> Poll<Result<BatchValue>> {
424 self.generic_cursor.poll_next_in_batch(cx)
425 }
426}
427
428impl<T> Drop for SessionCursorStream<'_, '_, T> {
429 fn drop(&mut self) {
430 // Update the parent cursor's state based on any iteration performed on this handle.
431 self.session_cursor.state = Some(self.generic_cursor.take_state());
432 }
433}