mongodb_cursor_pagination/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3#![warn(
4    clippy::cast_lossless,
5    clippy::cast_possible_truncation,
6    clippy::cast_possible_wrap,
7    clippy::cast_precision_loss,
8    clippy::cast_sign_loss,
9    clippy::checked_conversions,
10    clippy::implicit_saturating_sub,
11    clippy::integer_arithmetic,
12    clippy::mod_module_files,
13    clippy::panic,
14    clippy::panic_in_result_fn,
15    clippy::unwrap_used,
16    missing_docs,
17    rust_2018_idioms,
18    unused_lifetimes,
19    unused_qualifications
20)]
21
22//! ### Usage:
//! The usage differs slightly from the Node.js version. See the examples for more details and a complete working example.
24//! ```rust
25//! use mongodb::{options::FindOptions, Client};
26//! use mongodb_cursor_pagination::{CursorDirections, FindResult, PaginatedCursor};
27//! use bson::doc;
28//! use serde::Deserialize;
29//!
30//! // Note that your data structure must derive Deserialize
31//! #[derive(Debug, Deserialize, PartialEq, Clone)]
32//! pub struct MyFruit {
33//!     name: String,
34//!     how_many: i32,
35//! }
36//! #  impl MyFruit {
37//! #     #[must_use]
38//! #     pub fn new(name: impl Into<String>, how_many: i32) -> Self {
39//! #         Self {
40//! #             name: name.into(),
41//! #             how_many,
42//! #         }
43//! #     }
44//! # }
45//!
46//! #[tokio::main]
47//! async fn main() {
48//!     let client = Client::with_uri_str("mongodb://localhost:27017/")
49//!         .await
50//!         .expect("Failed to initialize client.");
51//!     let db = client.database("mongodb_cursor_pagination");
52//!   #  db.collection::<MyFruit>("myfruits")
53//!   #      .drop(None)
54//!   #      .await
55//!   #      .expect("Failed to drop table");
56//!
57//!     let docs = vec![
58//!         doc! { "name": "Apple", "how_many": 5 },
59//!         doc! { "name": "Orange", "how_many": 3 },
60//!         doc! { "name": "Blueberry", "how_many": 25 },
61//!         doc! { "name": "Bananas", "how_many": 8 },
62//!         doc! { "name": "Grapes", "how_many": 12 },
63//!     ];
64//!
65//!     db.collection("myfruits")
66//!         .insert_many(docs, None)
67//!         .await
68//!         .expect("Unable to insert data");
69//!
70//!     // query page 1, 2 at a time
71//!     let options = FindOptions::builder()
72//!             .limit(2)
73//!             .sort(doc! { "name": 1 })
74//!             .build();
75//!
76//!     let mut find_results: FindResult<MyFruit> = PaginatedCursor::new(Some(options.clone()), None, None)
77//!         .find(&db.collection("myfruits"), None)
78//!         .await
79//!         .expect("Unable to find data");
80//!   #  assert_eq!(
81//!   #     find_results.items,
82//!   #     vec![MyFruit::new("Apple", 5), MyFruit::new("Bananas", 8),]
83//!   # );
84//!     println!("First page: {:?}", find_results);
85//!
86//!     // get the second page
87//!     let mut cursor = find_results.page_info.next_cursor;
88//!     find_results = PaginatedCursor::new(Some(options), cursor, Some(CursorDirections::Next))
89//!         .find(&db.collection("myfruits"), None)
90//!         .await
91//!         .expect("Unable to find data");
92//!   #  assert_eq!(
93//!   #    find_results.items,
94//!   #     vec![MyFruit::new("Blueberry", 25), MyFruit::new("Grapes", 12),]
95//!   # );
96//!     println!("Second page: {:?}", find_results);
97//! }
98//! ```
99//!
100//! ### Response
//! The response `FindResult<T>` contains the page info, one edge (cursor) per returned item, the total count of matching documents, and the items themselves.
102//! ```rust
103//! pub struct PageInfo {
104//!     pub has_next_page: bool,
105//!     pub has_previous_page: bool,
106//!     pub start_cursor: Option<String>,
107//!     pub next_cursor: Option<String>,
108//! }
109//!
110//! pub struct Edge {
111//!     pub cursor: String,
112//! }
113//!
114//! pub struct FindResult<T> {
115//!     pub page_info: PageInfo,
116//!     pub edges: Vec<Edge>,
//!     pub total_count: u64,
118//!     pub items: Vec<T>,
119//! }
120//! ```
121//!
122//! ## Features
//! GraphQL support (via [juniper](https://github.com/graphql-rust/juniper)) is available when the `graphql` feature is enabled. To use it, include `PageInfo` and `Edge` in your own connection type:
124//!
125//! ```ignore
126//! use mongodb_cursor_pagination::{PageInfo, Edge};
127//!
128//! #[derive(Serialize, Deserialize)]
129//! struct MyDataConnection {
130//!     page_info: PageInfo,
131//!     edges: Vec<Edge>,
132//!     data: Vec<MyData>,
133//!     total_count: i64,
134//! }
135//!
//! #[juniper::object]
//! impl MyDataConnection {
//!     fn page_info(&self) -> &PageInfo {
//!         &self.page_info
140//!     }
141//!
142//!     fn edges(&self) -> &Vec<Edge> {
143//!         &self.edges
144//!     }
145//! }
146//! ```
147
148pub mod error;
149mod options;
150
151use crate::options::CursorOptions;
152use base64::engine::general_purpose::STANDARD;
153use base64::Engine;
154use bson::{doc, oid::ObjectId, Bson, Document};
155use error::CursorError;
156use futures_util::stream::StreamExt;
157use log::warn;
158use mongodb::options::{CountOptions, EstimatedDocumentCountOptions};
159use mongodb::{options::FindOptions, Collection};
160use serde::de::DeserializeOwned;
161use serde::{Deserialize, Serialize};
162use std::ops::Neg;
163
164/// Provides details about if there are more pages and the cursor to the start of the list and end
165#[derive(Clone, Debug, Deserialize, Serialize, Default)]
166pub struct PageInfo {
167    pub has_next_page: bool,
168    pub has_previous_page: bool,
169    pub start_cursor: Option<String>,
170    pub next_cursor: Option<String>,
171}
172
173#[cfg(feature = "graphql")]
174#[juniper::object]
175impl PageInfo {
176    fn has_next_page(&self) -> bool {
177        self.has_next_page
178    }
179
180    fn has_previous_page(&self) -> bool {
181        self.has_previous_page
182    }
183
184    fn start_cursor(&self) -> Option<String> {
185        self.start_cursor.to_owned()
186    }
187
188    fn next_cursor(&self) -> Option<String> {
189        self.next_cursor.to_owned()
190    }
191}
192
193/// Edges are the cursors on all of the items in the return
194#[derive(Clone, Debug, Deserialize, Serialize)]
195pub struct Edge {
196    pub cursor: String,
197}
198
199#[cfg(feature = "graphql")]
200#[juniper::object]
201impl Edge {
202    fn cursor(&self) -> String {
203        self.cursor.to_owned()
204    }
205}
206// FIX: there's probably a better way to do this...but for now
207#[cfg(feature = "graphql")]
208impl From<&Edge> for Edge {
209    fn from(edge: &Edge) -> Edge {
210        Edge {
211            cursor: edge.cursor.clone(),
212        }
213    }
214}
215
216/// The result of a find method with the items, edges, pagination info, and total count of objects
217#[derive(Debug, Default)]
218pub struct FindResult<T> {
219    pub page_info: PageInfo,
220    pub edges: Vec<Edge>,
221    pub total_count: u64,
222    pub items: Vec<T>,
223}
224
225/// The direction of the list, ie. you are sending a cursor for the next or previous items. Defaults to Next
226#[derive(Clone, Debug, PartialEq, Eq)]
227pub enum CursorDirections {
228    Previous,
229    Next,
230}
231
232/// The main entry point for finding documents
233#[derive(Debug)]
234pub struct PaginatedCursor {
235    has_cursor: bool,
236    cursor_doc: Document,
237    direction: CursorDirections,
238    options: CursorOptions,
239}
240
241impl PaginatedCursor {
242    /// Updates or creates all of the find options to help with pagination and returns a `PaginatedCursor` object.
243    ///
244    /// # Arguments
245    /// * `options` - Optional find options that you would like to perform any searches with
246    /// * `cursor` - An optional existing cursor in base64. This would have come from a previous `FindResult<T>`
247    /// * `direction` - Determines whether the cursor supplied is for a previous page or the next page. Defaults to Next
248    ///
249    #[must_use]
250    pub fn new(
251        options: Option<FindOptions>,
252        cursor: Option<String>,
253        direction: Option<CursorDirections>,
254    ) -> Self {
255        Self {
256            // parse base64 for keys
257            has_cursor: cursor.is_some(),
258            cursor_doc: cursor.map_or_else(Document::new, |b64| {
259                map_from_base64(b64).expect("Unable to parse cursor")
260            }),
261            direction: direction.unwrap_or(CursorDirections::Next),
262            options: CursorOptions::from(options.unwrap_or_default()),
263        }
264    }
265
266    /// Estimates the number of documents in the collection using collection metadata.
267    pub async fn estimated_document_count<T>(
268        &self,
269        collection: &Collection<T>,
270    ) -> Result<u64, CursorError> {
271        let total_count = collection
272            .estimated_document_count(Some(EstimatedDocumentCountOptions::from(
273                self.options.clone(),
274            )))
275            .await
276            .unwrap();
277        Ok(total_count)
278    }
279
280    /// Gets the number of documents matching filter.
281    /// Note that using [`PaginatedCursor::estimated_document_count`](#method.estimated_document_count)
282    /// is recommended instead of this method is most cases.
283    pub async fn count_documents<T>(
284        &self,
285        collection: &Collection<T>,
286        query: Option<&Document>,
287    ) -> Result<u64, CursorError> {
288        let mut count_options = self.options.clone();
289        count_options.limit = None;
290        count_options.skip = None;
291        let count_query = query.map_or_else(Document::new, Clone::clone);
292        let total_count = collection
293            .count_documents(count_query, Some(CountOptions::from(count_options)))
294            .await
295            .unwrap();
296        Ok(total_count)
297    }
298
299    /// Finds the documents in the `collection` matching `filter`.
300    pub async fn find<T>(
301        &self,
302        collection: &Collection<Document>,
303        filter: Option<&Document>,
304    ) -> Result<FindResult<T>, CursorError>
305    where
306        T: DeserializeOwned + Sync + Send + Unpin + Clone,
307    {
308        // first count the docs
309        let total_count = self.count_documents(collection, filter).await.unwrap();
310
311        // setup defaults
312        let mut items: Vec<T> = vec![];
313        let mut edges: Vec<Edge> = vec![];
314        let mut has_next_page = false;
315        let mut has_previous_page = false;
316        let mut has_skip = false;
317        let mut start_cursor: Option<String> = None;
318        let mut next_cursor: Option<String> = None;
319
320        // return if we if have no docs
321        if total_count == 0 {
322            return Ok(FindResult {
323                page_info: PageInfo::default(),
324                edges: vec![],
325                total_count: 0,
326                items: vec![],
327            });
328        }
329
330        // build the cursor
331        let query_doc = self.get_query(filter.cloned());
332        let mut options = self.options.clone();
333        let skip_value = options.skip.unwrap_or(0);
334        if self.has_cursor || skip_value == 0 {
335            options.skip = None;
336        } else {
337            has_skip = true;
338        }
339        // let has_previous
340        let is_previous_query = self.has_cursor && self.direction == CursorDirections::Previous;
341        // if it's a previous query we need to reverse the sort we were doing
342        if is_previous_query {
343            if let Some(sort) = options.sort.as_mut() {
344                sort.iter_mut().for_each(|(_key, value)| {
345                    if let Bson::Int32(num) = value {
346                        *value = Bson::Int32(num.neg());
347                    }
348                    if let Bson::Int64(num) = value {
349                        *value = Bson::Int64(num.neg());
350                    }
351                });
352            }
353        }
354        let mut cursor = collection
355            .find(query_doc, Some(options.into()))
356            .await
357            .unwrap();
358        while let Some(result) = cursor.next().await {
359            match result {
360                Ok(doc) => {
361                    let item = bson::from_bson(Bson::Document(doc.clone())).unwrap();
362                    edges.push(Edge {
363                        cursor: self.create_from_doc(&doc),
364                    });
365                    items.push(item);
366                }
367                Err(error) => {
368                    warn!("Error to find doc: {}", error);
369                }
370            }
371        }
372        let has_more: bool;
373        if has_skip {
374            has_more = (items.len() as u64).saturating_add(skip_value) < total_count;
375            has_previous_page = true;
376            has_next_page = has_more;
377        } else {
378            has_more = items.len() as i64 > self.options.limit.unwrap().saturating_sub(1);
379            has_previous_page = (self.has_cursor && self.direction == CursorDirections::Next)
380                || (is_previous_query && has_more);
381            has_next_page = (self.direction == CursorDirections::Next && has_more)
382                || (is_previous_query && self.has_cursor);
383        }
384
385        // reorder if we are going backwards
386        if is_previous_query {
387            items.reverse();
388            edges.reverse();
389        }
390        // remove the extra item to check if we have more
391        if has_more && !is_previous_query {
392            items.pop();
393            edges.pop();
394        } else if has_more {
395            items.remove(0);
396            edges.remove(0);
397        }
398
399        // create the next cursor
400        if !items.is_empty() && edges.len() == items.len() {
401            start_cursor = Some(edges[0].cursor.clone());
402            next_cursor = Some(edges[items.len().saturating_sub(1)].cursor.clone());
403        }
404
405        let page_info = PageInfo {
406            has_next_page,
407            has_previous_page,
408            start_cursor,
409            next_cursor,
410        };
411        Ok(FindResult {
412            page_info,
413            edges,
414            total_count,
415            items,
416        })
417    }
418
419    fn get_value_from_doc(&self, key: &str, doc: Bson) -> Option<(String, Bson)> {
420        let parts: Vec<&str> = key.splitn(2, '.').collect();
421        match doc {
422            Bson::Document(d) => d.get(parts[0]).and_then(|value| match value {
423                Bson::Document(d) => self.get_value_from_doc(parts[1], Bson::Document(d.clone())),
424                _ => Some((parts[0].to_string(), value.clone())),
425            }),
426            _ => Some((parts[0].to_string(), doc)),
427        }
428    }
429
430    fn create_from_doc(&self, doc: &Document) -> String {
431        let mut only_sort_keys = Document::new();
432        self.options.sort.as_ref().map_or_else(String::new, |sort| {
433            for key in sort.keys() {
434                if let Some((_, value)) = self.get_value_from_doc(key, Bson::Document(doc.clone()))
435                {
436                    only_sort_keys.insert(key, value);
437                }
438            }
439            let buf = bson::to_vec(&only_sort_keys).unwrap();
440            STANDARD.encode(buf)
441        })
442    }
443
444    /*
445    $or: [{
446        launchDate: { $lt: nextLaunchDate }
447    }, {
448        // If the launchDate is an exact match, we need a tiebreaker, so we use the _id field from the cursor.
449        launchDate: nextLaunchDate,
450    _id: { $lt: nextId }
451    }]
452    */
453    fn get_query(&self, query: Option<Document>) -> Document {
454        // now create the filter
455        let mut query_doc = query.unwrap_or_default();
456
457        // Don't do anything if no cursor is provided
458        if self.cursor_doc.is_empty() {
459            return query_doc;
460        }
461        let Some(sort) = &self.options.sort else {
462            return query_doc;
463        };
464
465        // this is the simplest form, it's just a sort by _id
466        if sort.len() <= 1 {
467            let object_id = self.cursor_doc.get("_id").unwrap().clone();
468            let direction = self.get_direction_from_key(sort, "_id");
469            query_doc.insert("_id", doc! { direction: object_id });
470            return query_doc;
471        }
472
473        let mut queries: Vec<Document> = Vec::new();
474        let mut previous_conditions: Vec<(String, Bson)> = Vec::new();
475
476        // Add each sort condition with it's direction and all previous condition with fixed values
477        for key in sort.keys() {
478            let mut query = query_doc.clone();
479            query.extend(previous_conditions.clone().into_iter()); // Add previous conditions
480
481            let value = self.cursor_doc.get(key).unwrap_or(&Bson::Null);
482            let direction = self.get_direction_from_key(sort, key);
483            query.insert(key, doc! { direction: value.clone() });
484            previous_conditions.push((key.clone(), value.clone())); // Add self without direction to previous conditions
485
486            queries.push(query);
487        }
488
489        query_doc = if queries.len() > 1 {
490            doc! { "$or": queries.iter().as_ref() }
491        } else {
492            queries.pop().unwrap_or_default()
493        };
494        query_doc
495    }
496
497    fn get_direction_from_key(&self, sort: &Document, key: &str) -> &'static str {
498        let value = sort.get(key).and_then(Bson::as_i32).unwrap_or(0);
499        match self.direction {
500            CursorDirections::Next => {
501                if value >= 0 {
502                    "$gt"
503                } else {
504                    "$lt"
505                }
506            }
507            CursorDirections::Previous => {
508                if value >= 0 {
509                    "$lt"
510                } else {
511                    "$gt"
512                }
513            }
514        }
515    }
516}
517
518fn map_from_base64(base64_string: String) -> Result<Document, CursorError> {
519    // change from base64
520    let decoded = STANDARD.decode(base64_string)?;
521    // decode from bson
522    let cursor_doc = bson::from_slice(decoded.as_slice()).unwrap();
523    Ok(cursor_doc)
524}
525
526/// Converts an id into a `MongoDb` `ObjectId`
527pub fn get_object_id(id: &str) -> Result<ObjectId, CursorError> {
528    let object_id = match ObjectId::parse_str(id) {
529        Ok(object_id) => object_id,
530        Err(_e) => return Err(CursorError::InvalidId(id.to_string())),
531    };
532    Ok(object_id)
533}