pub mod error;
mod options;
use bson::{doc, oid::ObjectId, Bson, Document};
use error::CursorError;
use log::warn;
use mongodb::{options::FindOptions, Collection};
use options::CursorOptions;
use serde::{Deserialize, Serialize};
use std::io::Cursor;
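/// Relay-style page metadata: whether more results exist in either direction
/// and the cursors bounding the current page.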
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PageInfo {
pub has_next_page: bool,
pub has_previous_page: bool,
pub start_cursor: Option<String>,
pub next_cursor: Option<String>,
}
#[cfg(feature = "graphql")]
#[juniper::object]
impl PageInfo {
fn has_next_page(&self) -> bool {
self.has_next_page
}
fn has_previous_page(&self) -> bool {
self.has_previous_page
}
fn start_cursor(&self) -> Option<String> {
self.start_cursor.to_owned()
}
fn next_cursor(&self) -> Option<String> {
self.next_cursor.to_owned()
}
}
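/// A single edge in the connection: the opaque cursor identifying one
/// returned document.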
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Edge {
pub cursor: String,
}
#[cfg(feature = "graphql")]
#[juniper::object]
impl Edge {
fn cursor(&self) -> String {
self.cursor.to_owned()
}
}
#[cfg(feature = "graphql")]
impl From<&Edge> for Edge {
fn from(edge: &Edge) -> Edge {
Edge {
cursor: edge.cursor.clone(),
}
}
}
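/// The result of a paginated `find`: the decoded documents, their edges,
/// the total number of matching documents, and the page metadata.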
#[derive(Debug)]
pub struct FindResult<T> {
pub page_info: PageInfo,
pub edges: Vec<Edge>,
pub total_count: i64,
pub items: Vec<T>,
}
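/// The direction to page in, relative to the supplied cursor.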
#[derive(Clone, Debug, PartialEq)]
pub enum CursorDirections {
Previous,
Next,
}
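/// Builds and executes cursor-based (keyset) pagination queries against a
/// MongoDB collection, deriving opaque cursors from the query's sort keys.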
#[derive(Debug)]
pub struct PaginatedCursor {
has_cursor: bool,
cursor_doc: Document,
direction: CursorDirections,
options: CursorOptions,
}
impl<'a> PaginatedCursor {
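    /// Creates a new `PaginatedCursor` from optional `FindOptions`, an optional
    /// base64 cursor returned by a previous query, and a paging direction
    /// (defaults to `CursorDirections::Next`).
    ///
    /// # Panics
    ///
    /// Panics if the supplied cursor is not valid base64-encoded BSON.
    ///
    /// A minimal usage sketch (not compiled here; the `collection` handle and the
    /// `MyModel` type are placeholders):
    ///
    /// ```ignore
    /// use bson::doc;
    /// use mongodb::options::FindOptions;
    ///
    /// let options = FindOptions::builder()
    ///     .limit(10)
    ///     .sort(doc! { "name": 1 })
    ///     .build();
    /// let cursor = PaginatedCursor::new(Some(options), None, None);
    /// let result = cursor.find::<MyModel>(&collection, None)?;
    /// println!("total: {}", result.total_count);
    /// ```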
pub fn new(
options: Option<FindOptions>,
cursor: Option<String>,
direction: Option<CursorDirections>,
) -> Self {
PaginatedCursor {
has_cursor: cursor.is_some(),
            cursor_doc: cursor
                .map(|b64| map_from_base64(b64).expect("Unable to parse cursor"))
                .unwrap_or_else(Document::new),
            direction: direction.unwrap_or(CursorDirections::Next),
options: CursorOptions::from(options),
}
}
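    /// Returns the estimated number of documents in the collection, based on
    /// collection metadata rather than a full scan.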
pub fn estimated_document_count(&self, collection: &Collection) -> Result<i64, CursorError> {
let count_options = self.options.clone();
        let total_count: i64 = collection
            .estimated_document_count(count_options)
            .map_err(|e| CursorError::Unknown(e.to_string()))?;
Ok(total_count)
}
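    /// Counts the documents matching `query`, ignoring any limit or skip so the
    /// total reflects the whole result set.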
pub fn count_documents(
&self,
collection: &Collection,
query: Option<&Document>,
) -> Result<i64, CursorError> {
let mut count_options = self.options.clone();
count_options.limit = None;
count_options.skip = None;
let count_query = if let Some(q) = query {
q.clone()
} else {
Document::new()
};
        let total_count: i64 = collection
            .count_documents(count_query, count_options)
            .map_err(|e| CursorError::Unknown(e.to_string()))?;
Ok(total_count)
}
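    /// Executes the paginated query and returns the decoded documents together
    /// with their edges, the total count, and the resulting page info. The
    /// cursors in the result can be fed back into `PaginatedCursor::new` to
    /// fetch the next or previous page.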
pub fn find<T>(
&self,
collection: &Collection,
filter: Option<&Document>,
) -> Result<FindResult<T>, CursorError>
where
T: Deserialize<'a>,
{
        let total_count: i64 = self.count_documents(collection, filter)?;
let mut items: Vec<T> = vec![];
let mut edges: Vec<Edge> = vec![];
let mut has_next_page = false;
let mut has_previous_page = false;
let mut has_skip = false;
let mut start_cursor: Option<String> = None;
let mut next_cursor: Option<String> = None;
if total_count > 0 {
let query_doc = self.get_query(filter)?;
let mut options = self.options.clone();
            let skip_value: i64 = options.skip.unwrap_or(0);
if self.has_cursor || skip_value == 0 {
options.skip = None;
} else {
has_skip = true;
}
let is_previous_query = self.has_cursor && self.direction == CursorDirections::Previous;
if is_previous_query {
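                // Paging backwards: invert every sort direction so the previous
                // page can be fetched with a forward scan; the original order is
                // restored by reversing the items below.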
if let Some(sort) = options.sort {
let keys: Vec<&String> = sort.keys().collect();
let mut new_sort = Document::new();
for key in keys {
let bson_value = sort.get(key).unwrap();
match bson_value {
Bson::I32(value) => {
new_sort.insert(key, Bson::I32(-value));
}
Bson::I64(value) => {
new_sort.insert(key, Bson::I64(-value));
}
_ => {}
};
}
options.sort = Some(new_sort);
}
}
            let cursor = collection
                .find(query_doc, options)
                .map_err(|e| CursorError::Unknown(e.to_string()))?;
for result in cursor {
match result {
Ok(doc) => {
                        let item = bson::from_bson(bson::Bson::Document(doc.clone()))
                            .map_err(|e| CursorError::Unknown(e.to_string()))?;
edges.push(Edge {
cursor: self.create_from_doc(&doc),
});
items.push(item);
}
Err(error) => {
                        warn!("Error while reading document: {}", error);
}
}
}
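            // One extra document beyond the requested page size is fetched (the
            // limit in `CursorOptions` is assumed to include it), so a full batch
            // means at least one more page exists; the extra row is dropped
            // before the results are returned.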
let has_more: bool;
if has_skip {
has_more = (items.len() as i64 + skip_value) < total_count;
has_previous_page = true;
has_next_page = has_more;
} else {
has_more = items.len() > (self.options.limit.unwrap() - 1) as usize;
has_previous_page = (self.has_cursor && self.direction == CursorDirections::Next)
|| (is_previous_query && has_more);
has_next_page = (self.direction == CursorDirections::Next && has_more)
|| (is_previous_query && self.has_cursor);
}
if is_previous_query {
items.reverse();
edges.reverse();
}
if has_more && !is_previous_query {
items.pop();
edges.pop();
} else if has_more {
items.remove(0);
edges.remove(0);
}
if !items.is_empty() && edges.len() == items.len() {
start_cursor = Some(edges[0].cursor.to_owned());
next_cursor = Some(edges[items.len() - 1].cursor.to_owned());
}
}
let page_info = PageInfo {
has_next_page,
has_previous_page,
start_cursor,
next_cursor,
};
Ok(FindResult {
page_info,
total_count,
edges,
items,
})
}
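    /// Looks up a possibly dotted key (e.g. `"author.name"`) in a BSON document,
    /// descending into embedded documents one path segment at a time.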
fn get_value_from_doc(&self, key: &str, doc: Bson) -> Option<(String, Bson)> {
        let parts: Vec<&str> = key.splitn(2, '.').collect();
        match doc {
            Bson::Document(d) => match d.get(parts[0]) {
                Some(Bson::Document(nested)) => {
                    self.get_value_from_doc(parts[1], Bson::Document(nested.clone()))
                }
                Some(value) => Some((parts[0].to_string(), value.clone())),
                None => None,
            },
            _ => Some((parts[0].to_string(), doc)),
        }
}
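    /// Builds the opaque cursor for a document: the values of the sort keys are
    /// collected into a small BSON document, serialized, and base64-encoded.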
fn create_from_doc(&self, doc: &Document) -> String {
let mut only_sort_keys = Document::new();
if let Some(sort) = &self.options.sort {
for key in sort.keys() {
if let Some((_, value)) = self.get_value_from_doc(key, Bson::Document(doc.clone())) {
only_sort_keys.insert(key, value);
}
}
let mut buf = Vec::new();
bson::encode_document(&mut buf, &only_sort_keys).unwrap();
base64::encode(&buf)
} else {
"".to_owned()
}
}
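    /// Combines the caller's filter with the decoded cursor document, turning
    /// each sort key into a range condition so the query resumes exactly after
    /// (or before) the cursor position.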
fn get_query(&self, query: Option<&Document>) -> Result<Document, CursorError> {
let mut query_doc = match query {
Some(doc) => doc.clone(),
None => Document::new(),
};
if self.cursor_doc.is_empty() {
return Ok(query_doc);
} else if let Some(sort) = &self.options.sort {
if sort.len() > 1 {
let keys: Vec<&String> = sort.keys().collect();
let mut queries: Vec<Document> = Vec::new();
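                // Build one branch per sort key: equality on all earlier keys and
                // a range condition on the current key, mirroring a tuple
                // comparison across the compound sort.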
#[allow(clippy::needless_range_loop)]
for i in 0..keys.len() {
let mut query = query_doc.clone();
#[allow(clippy::needless_range_loop)]
for j in 0..i {
let value = self.cursor_doc.get(keys[j]).unwrap_or(&Bson::Null);
query.insert(keys[j], value.clone());
}
let value = self.cursor_doc.get(keys[i]).unwrap_or(&Bson::Null);
                    let direction = self.get_direction_from_key(sort, keys[i]);
query.insert(keys[i], doc! { direction: value.clone() });
queries.push(query);
}
if queries.len() > 1 {
query_doc = doc! { "$or": [] };
let or_array = query_doc.get_array_mut("$or").map_err(|_| CursorError::Unknown("Unable to process".into()))?;
for d in queries.iter() {
or_array.push(Bson::Document(d.clone()));
}
} else {
query_doc = queries[0].clone();
}
} else {
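                // With a single sort key, the cursor is assumed to hold the `_id`
                // of the last document seen.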
let object_id = self.cursor_doc.get("_id").unwrap().clone();
                let direction = self.get_direction_from_key(sort, "_id");
query_doc.insert("_id", doc! { direction: object_id });
}
}
Ok(query_doc)
}
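    /// Maps a sort direction to the comparison operator that continues the scan:
    /// ascending sorts page forward with `$gt` and backward with `$lt`;
    /// descending sorts use the opposite operators.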
fn get_direction_from_key(&self, sort: &Document, key: &str) -> &'static str {
let value = sort.get(key).unwrap().as_i32().unwrap();
match self.direction {
CursorDirections::Next => {
if value >= 0 {
"$gt"
} else {
"$lt"
}
}
CursorDirections::Previous => {
if value >= 0 {
"$lt"
} else {
"$gt"
}
}
}
}
}
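/// Decodes a base64 cursor back into the BSON document of sort-key values it
/// was created from.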
fn map_from_base64(base64_string: String) -> Result<Document, CursorError> {
let decoded = base64::decode(&base64_string)?;
    let cursor_doc = bson::decode_document(&mut Cursor::new(&decoded))
        .map_err(|e| CursorError::Unknown(e.to_string()))?;
Ok(cursor_doc)
}
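/// Parses a hexadecimal string into an `ObjectId`, returning
/// `CursorError::InvalidId` when the string is not a valid id.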
pub fn get_object_id(id: &str) -> Result<ObjectId, CursorError> {
    ObjectId::with_string(id).map_err(|_| CursorError::InvalidId(id.to_string()))
}