1use crate::entity::MongoEntity;
2use async_trait::async_trait;
3use futures::TryStreamExt;
4use mongodb::options::FindOptions;
5use mongodb::{
6 bson::{doc, oid::ObjectId, Document},
7 Collection,
8};
9
10use typed_builder::TypedBuilder;
11
12#[async_trait]
13pub trait MongoRepository<E> {
14 async fn find_by_id(&self, id: &str) -> Option<E>;
15
16 async fn insert(&self, entity: &E) -> String;
17
18 async fn insert_many(&self, entities: &[E]);
19
20 async fn update_by_id(&self, id: &str, update: Document);
21
22 async fn update_one_by(&self, filter: Document, update: Document);
23
24 async fn update_many(&self, filter: Document, update: Document) -> (u64, u64);
25
26 async fn delete_by_id(&self, id: &str) -> bool;
27
28 async fn find(&self, option: FindOption) -> Vec<E>;
29
30 async fn find_one_by(&self, filter: Document) -> Option<E>;
31
32 async fn count(&self, filter: Option<Document>) -> u64;
33
34 async fn exists(&self, filter: Document) -> bool;
35
36 async fn delete_by(&self, filter: Document) -> u64;
37
38 async fn aggregate(&self, stages: Vec<Document>) -> Vec<Document>;
39}
40
41impl<E: MongoEntity> MongoRepo<E> {
42 pub fn new(collection: Collection<Document>) -> Self {
43 Self {
44 collection,
45 entity: Default::default(),
46 }
47 }
48}
49#[derive(Debug)]
50pub struct MongoRepo<E: MongoEntity> {
51 collection: Collection<Document>,
52 entity: std::marker::PhantomData<E>,
53}
54
55#[derive(Debug, TypedBuilder, Clone)]
56pub struct FindOption {
57 #[builder(default, setter(strip_option))] pub filter: Option<Document>,
59
60 #[builder(default, setter(strip_option))] pub sort: Option<Document>,
62
63 #[builder(default, setter(strip_option))] pub limit: Option<i64>,
65
66 #[builder(default, setter(strip_option))] pub skip: Option<u64>,
68}
69
70impl FindOption {
71 pub fn all() -> Self {
72 FindOption {
73 filter: None,
74 sort: None,
75 limit: None,
76 skip: None,
77 }
78 }
79}
80
81impl Into<FindOptions> for FindOption {
82 fn into(self) -> FindOptions {
83 let mut builder = FindOptions::default();
84
85 builder.sort = self.sort;
86 builder.skip = self.skip;
87 builder.limit = self.limit;
88
89 builder
90 }
91}
92
93#[async_trait]
94impl<E> MongoRepository<E> for MongoRepo<E>
95where
96 E: MongoEntity,
97{
98 async fn find_by_id(&self, id: &str) -> Option<E> {
99 let object_id = ObjectId::parse_str(id);
100
101 let find = match object_id {
102 Ok(oid) => self.collection.find_one(doc! { "_id": oid }),
103 Err(_) => self.collection.find_one(doc! { "_id": id }),
104 };
105
106 let res = find.await.expect("MongoDB error");
107
108 res.map(E::from_document)
109 }
110
111 async fn insert(&self, entity: &E) -> String {
112 let id = self
113 .collection
114 .insert_one(entity.to_document())
115 .await
116 .unwrap()
117 .inserted_id;
118
119 if let Some(oid) = id.as_object_id() {
120 oid.to_hex()
121 } else {
122 id.as_str().unwrap().to_string()
123 }
124 }
125
126 async fn insert_many(&self, entities: &[E]) {
127 let documents: Vec<Document> = entities.iter().map(|e| e.to_document()).collect();
128 self.collection.insert_many(documents).await.unwrap();
129 }
130
131 async fn update_by_id(&self, id: &str, update: Document) {
132 let f = ObjectId::parse_str(id)
133 .map(|oid| doc! { "_id": oid })
134 .unwrap_or(doc! { "_id": id });
135 self.collection
136 .update_one(f, doc! { "$set": update })
137 .await
138 .unwrap();
139 }
140
141 async fn update_one_by(&self, filter: Document, update: Document) {
142 self.collection
143 .update_one(filter, doc! { "$set": update })
144 .await
145 .unwrap();
146 }
147
148 async fn update_many(&self, filter: Document, update: Document) -> (u64, u64) {
149 let result = self
150 .collection
151 .update_many(filter, doc! { "$set": update })
152 .await
153 .unwrap();
154 (result.matched_count, result.modified_count)
155 }
156
157 async fn delete_by_id(&self, id: &str) -> bool {
158 let object_id = ObjectId::parse_str(id).unwrap();
159 let result = self
160 .collection
161 .delete_one(doc! { "_id": object_id })
162 .await
163 .unwrap();
164 result.deleted_count > 0
165 }
166
167 async fn find(&self, option: FindOption) -> Vec<E> {
168 let options: FindOptions = option.clone().into();
169 let mut cursor = self
170 .collection
171 .find(option.filter.unwrap_or(doc! {}))
172 .with_options(options)
173 .await
174 .unwrap();
175 let mut results = Vec::new();
176 while let Some(doc) = cursor.try_next().await.unwrap() {
177 results.push(E::from_document(doc));
178 }
179 results
180 }
181
182 async fn find_one_by(&self, filter: Document) -> Option<E> {
183 self.collection
184 .find_one(filter)
185 .await
186 .unwrap()
187 .map(E::from_document)
188 }
189
190 async fn count(&self, filter: Option<Document>) -> u64 {
191 self.collection
192 .count_documents(filter.unwrap_or_default())
193 .await
194 .unwrap()
195 }
196
197 async fn exists(&self, filter: Document) -> bool {
198 self.collection.count_documents(filter).await.unwrap() > 0
199 }
200
201 async fn delete_by(&self, filter: Document) -> u64 {
202 let result = self.collection.delete_many(filter).await.unwrap();
203 result.deleted_count
204 }
205
206 async fn aggregate(&self, stages: Vec<Document>) -> Vec<Document> {
207 let cursor = self.collection.aggregate(stages).await.unwrap();
208 cursor.try_collect().await.unwrap()
209 }
210}