//! Generic MongoDB repository layer (`mongo_orm/repo.rs`).

use std::fmt;

use async_trait::async_trait;
use futures::TryStreamExt;
use mongodb::options::FindOptions;
use mongodb::{
    bson::{doc, oid::ObjectId, Document},
    Collection,
};
use typed_builder::TypedBuilder;

use crate::entity::MongoEntity;

12#[async_trait]
13pub trait MongoRepository<E> {
14    async fn find_by_id(&self, id: &str) -> Option<E>;
15
16    async fn insert(&self, entity: &E) -> String;
17
18    async fn insert_many(&self, entities: &[E]);
19
20    async fn update_by_id(&self, id: &str, update: Document);
21
22    async fn update_one_by(&self, filter: Document, update: Document);
23
24    async fn update_many(&self, filter: Document, update: Document) -> (u64, u64);
25
26    async fn delete_by_id(&self, id: &str) -> bool;
27
28    async fn find(&self, option: FindOption) -> Vec<E>;
29
30    async fn find_one_by(&self, filter: Document) -> Option<E>;
31
32    async fn count(&self, filter: Option<Document>) -> u64;
33
34    async fn exists(&self, filter: Document) -> bool;
35
36    async fn delete_by(&self, filter: Document) -> u64;
37
38    async fn aggregate(&self, stages: Vec<Document>) -> Vec<Document>;
39}
40
41impl<E: MongoEntity> MongoRepo<E> {
42    pub fn new(collection: Collection<Document>) -> Self {
43        Self {
44            collection,
45            entity: Default::default(),
46        }
47    }
48}
49#[derive(Debug)]
50pub struct MongoRepo<E: MongoEntity> {
51    collection: Collection<Document>,
52    entity: std::marker::PhantomData<E>,
53}
54
55#[derive(Debug, TypedBuilder, Clone)]
56pub struct FindOption {
57    #[builder(default, setter(strip_option))] // Optional field, allows None or Some(Document)
58    pub filter: Option<Document>,
59
60    #[builder(default, setter(strip_option))] // Optional field
61    pub sort: Option<Document>,
62
63    #[builder(default, setter(strip_option))] // Optional field
64    pub limit: Option<i64>,
65
66    #[builder(default, setter(strip_option))] // Optional field
67    pub skip: Option<u64>,
68}
69
70impl FindOption {
71    pub fn all() -> Self {
72        FindOption {
73            filter: None,
74            sort: None,
75            limit: None,
76            skip: None,
77        }
78    }
79}
80
81impl Into<FindOptions> for FindOption {
82    fn into(self) -> FindOptions {
83        let mut builder = FindOptions::default();
84
85        builder.sort = self.sort;
86        builder.skip = self.skip;
87        builder.limit = self.limit;
88
89        builder
90    }
91}
92
93#[async_trait]
94impl<E> MongoRepository<E> for MongoRepo<E>
95where
96    E: MongoEntity,
97{
98    async fn find_by_id(&self, id: &str) -> Option<E> {
99        let object_id = ObjectId::parse_str(id);
100
101        let find = match object_id {
102            Ok(oid) => self.collection.find_one(doc! { "_id": oid }),
103            Err(_) => self.collection.find_one(doc! { "_id": id }),
104        };
105
106        let res = find.await.expect("MongoDB error");
107
108        res.map(E::from_document)
109    }
110
111    async fn insert(&self, entity: &E) -> String {
112        let id = self
113            .collection
114            .insert_one(entity.to_document())
115            .await
116            .unwrap()
117            .inserted_id;
118
119        if let Some(oid) = id.as_object_id() {
120            oid.to_hex()
121        } else {
122            id.as_str().unwrap().to_string()
123        }
124    }
125
126    async fn insert_many(&self, entities: &[E]) {
127        let documents: Vec<Document> = entities.iter().map(|e| e.to_document()).collect();
128        self.collection.insert_many(documents).await.unwrap();
129    }
130
131    async fn update_by_id(&self, id: &str, update: Document) {
132        let f = ObjectId::parse_str(id)
133            .map(|oid| doc! { "_id": oid })
134            .unwrap_or(doc! { "_id": id });
135        self.collection
136            .update_one(f, doc! { "$set": update })
137            .await
138            .unwrap();
139    }
140
141    async fn update_one_by(&self, filter: Document, update: Document) {
142        self.collection
143            .update_one(filter, doc! { "$set": update })
144            .await
145            .unwrap();
146    }
147
148    async fn update_many(&self, filter: Document, update: Document) -> (u64, u64) {
149        let result = self
150            .collection
151            .update_many(filter, doc! { "$set": update })
152            .await
153            .unwrap();
154        (result.matched_count, result.modified_count)
155    }
156
157    async fn delete_by_id(&self, id: &str) -> bool {
158        let object_id = ObjectId::parse_str(id).unwrap();
159        let result = self
160            .collection
161            .delete_one(doc! { "_id": object_id })
162            .await
163            .unwrap();
164        result.deleted_count > 0
165    }
166
167    async fn find(&self, option: FindOption) -> Vec<E> {
168        let options: FindOptions = option.clone().into();
169        let mut cursor = self
170            .collection
171            .find(option.filter.unwrap_or(doc! {}))
172            .with_options(options)
173            .await
174            .unwrap();
175        let mut results = Vec::new();
176        while let Some(doc) = cursor.try_next().await.unwrap() {
177            results.push(E::from_document(doc));
178        }
179        results
180    }
181
182    async fn find_one_by(&self, filter: Document) -> Option<E> {
183        self.collection
184            .find_one(filter)
185            .await
186            .unwrap()
187            .map(E::from_document)
188    }
189
190    async fn count(&self, filter: Option<Document>) -> u64 {
191        self.collection
192            .count_documents(filter.unwrap_or_default())
193            .await
194            .unwrap()
195    }
196
197    async fn exists(&self, filter: Document) -> bool {
198        self.collection.count_documents(filter).await.unwrap() > 0
199    }
200
201    async fn delete_by(&self, filter: Document) -> u64 {
202        let result = self.collection.delete_many(filter).await.unwrap();
203        result.deleted_count
204    }
205
206    async fn aggregate(&self, stages: Vec<Document>) -> Vec<Document> {
207        let cursor = self.collection.aggregate(stages).await.unwrap();
208        cursor.try_collect().await.unwrap()
209    }
210}