use crate::entity::MongoEntity;
use async_trait::async_trait;
use futures::TryStreamExt;
use mongodb::options::FindOptions;
use mongodb::{
bson::{doc, oid::ObjectId, Document},
Collection,
};
use typed_builder::TypedBuilder;
#[async_trait]
/// Asynchronous repository interface over a MongoDB collection of entities `E`.
///
/// None of the methods return a `Result`, so an implementation has no channel for
/// reporting driver errors other than panicking. The per-method notes describe the
/// behaviour of the `MongoRepo` implementation in this file.
pub trait MongoRepository<E> {
/// Looks up a single entity by its `_id`, returning `None` when nothing matches.
async fn find_by_id(&self, id: &str) -> Option<E>;
/// Inserts one entity and returns the id of the inserted document as a string.
async fn insert(&self, entity: &E) -> String;
/// Inserts every entity of the slice.
async fn insert_many(&self, entities: &[E]);
/// Updates the document with the given `_id`; `MongoRepo` wraps `update` in `$set`,
/// so `update` is a document of fields to assign, not an operator document.
async fn update_by_id(&self, id: &str, update: Document);
/// Updates one document matching `filter`; `MongoRepo` wraps `update` in `$set`.
async fn update_one_by(&self, filter: Document, update: Document);
/// Updates all documents matching `filter` (`update` is wrapped in `$set` by `MongoRepo`)
/// and returns `(matched_count, modified_count)`.
async fn update_many(&self, filter: Document, update: Document) -> (u64, u64);
/// Deletes the document with the given `_id`; returns `true` when a document was deleted.
async fn delete_by_id(&self, id: &str) -> bool;
/// Returns the entities selected by `option` (filter, sort, limit and skip).
async fn find(&self, option: FindOption) -> Vec<E>;
/// Returns one entity matching `filter`, or `None` when nothing matches.
async fn find_one_by(&self, filter: Document) -> Option<E>;
/// Counts the documents matching `filter`; `None` is treated as an empty filter.
async fn count(&self, filter: Option<Document>) -> u64;
/// Returns `true` when at least one document matches `filter`.
async fn exists(&self, filter: Document) -> bool;
/// Deletes every document matching `filter` and returns the number deleted.
async fn delete_by(&self, filter: Document) -> u64;
/// Runs an aggregation pipeline and returns the resulting raw documents.
async fn aggregate(&self, stages: Vec<Document>) -> Vec<Document>;
}
impl<E: MongoEntity> MongoRepo<E> {
    /// Wraps an untyped driver collection in a repository for entities of type `E`.
    ///
    /// No I/O happens here; the collection handle is only stored for later use.
    pub fn new(collection: Collection<Document>) -> Self {
        // `E` is never stored, so a zero-sized marker ties the type parameter to the struct.
        let entity = std::marker::PhantomData;
        Self { collection, entity }
    }
}
/// `MongoRepository` implementation backed by an untyped `Collection<Document>`.
///
/// Entities are converted to and from BSON documents through the `MongoEntity`
/// trait; the entity type `E` itself is only tracked at the type level.
#[derive(Debug)]
pub struct MongoRepo<E: MongoEntity> {
/// Driver collection handle that holds the raw documents.
collection: Collection<Document>,
/// Marker binding the repository to its entity type `E` without storing a value.
entity: std::marker::PhantomData<E>,
}
/// Query parameters accepted by `MongoRepository::find`.
///
/// Every field is optional. The derived builder defaults each one to `None`, and
/// `strip_option` lets the setters take the inner value directly.
#[derive(Debug, TypedBuilder, Clone)]
pub struct FindOption {
/// Query filter; `find` substitutes an empty document (match everything) when `None`.
#[builder(default, setter(strip_option))] pub filter: Option<Document>,
/// Sort specification, forwarded to the driver's `FindOptions::sort`.
#[builder(default, setter(strip_option))] pub sort: Option<Document>,
/// Result limit, forwarded to the driver's `FindOptions::limit`.
#[builder(default, setter(strip_option))] pub limit: Option<i64>,
/// Number of documents to skip, forwarded to the driver's `FindOptions::skip`.
#[builder(default, setter(strip_option))] pub skip: Option<u64>,
}
impl FindOption {
pub fn all() -> Self {
FindOption {
filter: None,
sort: None,
limit: None,
skip: None,
}
}
}
/// Converts the repository-level [`FindOption`] into the driver's [`FindOptions`].
///
/// Only `sort`, `skip` and `limit` are carried over. `filter` is not part of the
/// driver options (it is passed to `find` separately) and is dropped here.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl still provides
/// `FindOption: Into<FindOptions>`, so existing `.into()` call sites keep working.
impl From<FindOption> for FindOptions {
    fn from(option: FindOption) -> Self {
        // Start from the defaults and assign field by field, leaving every other
        // driver option at its default value.
        let mut options = FindOptions::default();
        options.sort = option.sort;
        options.skip = option.skip;
        options.limit = option.limit;
        options
    }
}
/// [`MongoRepository`] implementation that talks to the wrapped driver collection.
///
/// # Panics
///
/// The trait offers no error channel, so every method panics when the driver
/// reports an error.
#[async_trait]
impl<E> MongoRepository<E> for MongoRepo<E>
where
    E: MongoEntity,
{
    /// Fetches the document whose `_id` matches `id` and converts it into `E`.
    ///
    /// `id` is matched as an `ObjectId` when it parses as one and as a plain string
    /// `_id` otherwise. Returns `None` when no document matches.
    async fn find_by_id(&self, id: &str) -> Option<E> {
        self.collection
            .find_one(id_filter(id))
            .await
            .unwrap()
            .map(E::from_document)
    }

    /// Inserts `entity` and returns the `_id` of the new document as a string.
    ///
    /// `ObjectId`s are returned in hex form and string ids verbatim, so the result can
    /// be passed straight back to the `*_by_id` methods. Any other BSON id type is
    /// rendered through its `Display` implementation.
    async fn insert(&self, entity: &E) -> String {
        let inserted_id = self
            .collection
            .insert_one(entity.to_document())
            .await
            .unwrap()
            .inserted_id;
        if let Some(oid) = inserted_id.as_object_id() {
            oid.to_hex()
        } else if let Some(raw) = inserted_id.as_str() {
            // `Bson`'s `Display` wraps strings in double quotes, which would not match
            // the stored `_id` when the returned value is used for a lookup.
            raw.to_owned()
        } else {
            inserted_id.to_string()
        }
    }

    /// Inserts all `entities` in one driver call. An empty slice is a no-op.
    async fn insert_many(&self, entities: &[E]) {
        // The driver rejects an empty batch with an error, which would turn a harmless
        // "nothing to insert" call into a panic below.
        if entities.is_empty() {
            return;
        }
        let documents: Vec<Document> = entities
            .iter()
            .map(|entity| entity.to_document())
            .collect();
        self.collection.insert_many(documents).await.unwrap();
    }

    /// Applies `update` as a `$set` to the document whose `_id` matches `id`.
    ///
    /// `id` is interpreted the same way as in [`find_by_id`](Self::find_by_id).
    async fn update_by_id(&self, id: &str, update: Document) {
        self.collection
            .update_one(id_filter(id), doc! { "$set": update })
            .await
            .unwrap();
    }

    /// Applies `update` as a `$set` to one document matching `filter`.
    async fn update_one_by(&self, filter: Document, update: Document) {
        self.collection
            .update_one(filter, doc! { "$set": update })
            .await
            .unwrap();
    }

    /// Applies `update` as a `$set` to every document matching `filter`.
    ///
    /// Returns `(matched_count, modified_count)` as reported by the driver.
    async fn update_many(&self, filter: Document, update: Document) -> (u64, u64) {
        let result = self
            .collection
            .update_many(filter, doc! { "$set": update })
            .await
            .unwrap();
        (result.matched_count, result.modified_count)
    }

    /// Deletes the document whose `_id` matches `id`.
    ///
    /// `id` is interpreted the same way as in [`find_by_id`](Self::find_by_id), so an
    /// id that is not a valid `ObjectId` is matched as a string `_id` instead of
    /// causing a panic. Returns `true` when a document was deleted.
    async fn delete_by_id(&self, id: &str) -> bool {
        let result = self.collection.delete_one(id_filter(id)).await.unwrap();
        result.deleted_count > 0
    }

    /// Returns all entities selected by `option`.
    ///
    /// A missing filter matches every document; sort, skip and limit are forwarded
    /// to the driver.
    async fn find(&self, option: FindOption) -> Vec<E> {
        // Move the filter out first so the rest of the option can be converted
        // without cloning the filter document.
        let mut option = option;
        let filter = option.filter.take().unwrap_or_default();
        let options: FindOptions = option.into();
        self.collection
            .find(filter)
            .with_options(options)
            .await
            .unwrap()
            .map_ok(E::from_document)
            .try_collect()
            .await
            .unwrap()
    }

    /// Returns one entity matching `filter`, or `None` when nothing matches.
    async fn find_one_by(&self, filter: Document) -> Option<E> {
        self.collection
            .find_one(filter)
            .await
            .unwrap()
            .map(E::from_document)
    }

    /// Counts the documents matching `filter`; `None` counts the whole collection.
    async fn count(&self, filter: Option<Document>) -> u64 {
        self.collection
            .count_documents(filter.unwrap_or_default())
            .await
            .unwrap()
    }

    /// Returns `true` when at least one document matches `filter`.
    async fn exists(&self, filter: Document) -> bool {
        self.collection.count_documents(filter).await.unwrap() > 0
    }

    /// Deletes every document matching `filter` and returns how many were deleted.
    async fn delete_by(&self, filter: Document) -> u64 {
        self.collection
            .delete_many(filter)
            .await
            .unwrap()
            .deleted_count
    }

    /// Runs the aggregation pipeline `stages` and collects the raw result documents.
    async fn aggregate(&self, stages: Vec<Document>) -> Vec<Document> {
        let cursor = self.collection.aggregate(stages).await.unwrap();
        cursor.try_collect().await.unwrap()
    }
}

/// Builds the `_id` filter shared by the `*_by_id` operations: an `ObjectId` match
/// when `id` parses as one, otherwise a match on the raw string.
fn id_filter(id: &str) -> Document {
    match ObjectId::parse_str(id) {
        Ok(oid) => doc! { "_id": oid },
        Err(_) => doc! { "_id": id },
    }
}