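//! MongoDB data-access helpers (`cargo_smith/database/mongo.rs`): pagination
//! filter/sort builders, a typed collection wrapper, and transaction helpers.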

1use mongodb::{Client, ClientSession, Collection, bson::{Bson, Document, doc}};
2use serde::{de::DeserializeOwned, Serialize};
3use futures::TryStreamExt;
4use crate::{AppError, common::{PaginatedResponse, PaginationQuery, extractors::pagination::SortDirection}};
5
6pub trait MongoPaginationExt {
7    fn to_filter(&self, searchable_fields: &[&str]) -> Document;
8    fn to_sort(&self) -> Document;
9}
10
11impl MongoPaginationExt for PaginationQuery {
12    fn to_filter(&self, searchable_fields: &[&str]) -> Document {
13        let mut filter = doc! {};
14        
15        // Text Search Logic
16        if let Some(ref query) = self.q {
17            if !query.is_empty() {
18                let regex = doc! { "$regex": query, "$options": "i" };
19                let or: Vec<Document> = searchable_fields.iter()
20                    .map(|&f| doc! { f: regex.clone() })
21                    .collect();
22                if !or.is_empty() { filter.insert("$or", or); }
23            }
24        }
25
26        // Dynamic Filters Logic
27        for (key, value) in &self.filters {
28            let bson_val = match value.to_lowercase().as_str() {
29                "true" => Bson::Boolean(true),
30                "false" => Bson::Boolean(false),
31                _ => {
32                    // Try to parse as Int/Float, fallback to String
33                    if let Ok(i) = value.parse::<i64>() { Bson::Int64(i) }
34                    else if let Ok(f) = value.parse::<f64>() { Bson::Double(f) }
35                    else { Bson::String(value.to_string()) }
36                }
37            };
38            filter.insert(key, bson_val);
39        }
40        filter
41    }
42
43    fn to_sort(&self) -> Document {
44        let field = self.sort_by.as_deref().unwrap_or("_id");
45        let order = self.sort_order.as_ref().unwrap_or(&SortDirection::Desc).to_int();
46        doc! { field: order }
47    }
48}
49
50pub struct MongoCollection<T>
51where
52    T: Serialize + DeserializeOwned + Unpin + Send + Sync,
53{
54    pub collection: Collection<T>,
55}
56
57impl<T> MongoCollection<T>
58where
59    T: Serialize + DeserializeOwned + Unpin + Send + Sync,
60{
61    pub async fn new(uri: &str, db_name: &str, collection_name: &str) -> Self {
62        let client = Client::with_uri_str(uri)
63            .await
64            .expect("Failed to connect to MongoDB");
65        let collection = client
66            .database(db_name)
67            .collection::<T>(collection_name);
68        Self { collection }
69    }
70
71    pub async fn insert(
72        &self, 
73        doc: &T, 
74        session: Option<&mut ClientSession>
75    ) -> Result<(), AppError> {
76        let mut action = self.collection.insert_one(doc);
77        
78        // If a session is provided, attach it to the action
79        if let Some(s) = session { action = action.session(s); }
80
81        action.await.map_err(AppError::Database)?;
82        Ok(())
83    }
84
85    pub fn parse_id(&self, id: &str) -> Result<mongodb::bson::oid::ObjectId, AppError> {
86        mongodb::bson::oid::ObjectId::parse_str(id)
87            .map_err(|_| AppError::BadRequest("Invalid ID format".into()))
88    }
89
90    pub async fn find_by_id(
91        &self, 
92        id: &str, 
93        session: Option<&mut ClientSession>
94    ) -> Result<T, AppError> {
95        let obj_id = self.parse_id(id)?;
96        let filter = doc! { "_id": obj_id };
97        
98        let mut action = self.collection.find_one(filter);
99        if let Some(s) = session { action = action.session(s); }
100
101        action.await
102            .map_err(AppError::Database)?
103            .ok_or_else(|| AppError::NotFound(format!("Entity {} not found", id)))
104    }
105
106    pub async fn find_one(&self, filter: Document) -> Result<T, AppError> {
107        self.collection
108            .find_one(filter)
109            .await
110            .map_err(AppError::Database)?
111            .ok_or_else(|| AppError::NotFound("Resource not found".into()))
112    }
113
114    pub async fn find_all(&self) -> Result<Vec<T>, AppError> {
115        self.collection
116            .find(doc! {})
117            .await
118            .map_err(AppError::Database)?
119            .try_collect()
120            .await
121            .map_err(AppError::Database)
122    }
123
124    pub async fn find_many(&self, filter: Document) -> Result<Vec<T>, AppError> {
125        self.collection
126            .find(filter)
127            .await
128            .map_err(AppError::Database)?
129            .try_collect()
130            .await
131            .map_err(AppError::Database)
132    }
133
134    pub async fn aggregate(&self, pipeline: Vec<Document>) -> Result<Vec<Document>, AppError> {
135        self.collection
136            .aggregate(pipeline)
137            .await
138            .map_err(AppError::Database)?
139            .try_collect()
140            .await
141            .map_err(AppError::Database)
142    }
143
144    pub async fn update_one(
145        &self, 
146        filter: Document, 
147        update: Document, 
148        session: Option<&mut ClientSession>
149    ) -> Result<bool, AppError> {
150        let mut action = self.collection.update_one(filter, update);
151        if let Some(s) = session { action = action.session(s); }
152
153        let result = action.await.map_err(AppError::Database)?;
154        Ok(result.matched_count > 0)
155    }
156
157    pub async fn update_many(&self, filter: Document, update: Document) -> Result<u64, AppError> {
158        let result = self.collection
159            .update_many(filter, update)
160            .await
161            .map_err(AppError::Database)?;
162        Ok(result.modified_count)
163    }
164
165    pub async fn delete_one(
166        &self, 
167        filter: Document, 
168        session: Option<&mut ClientSession>
169    ) -> Result<bool, AppError> {
170        let mut action = self.collection.delete_one(filter);
171        if let Some(s) = session { action = action.session(s); }
172
173        let result = action.await.map_err(AppError::Database)?;
174        Ok(result.deleted_count > 0)
175    }
176
177    pub async fn find_paginated(
178        &self,
179        query: PaginationQuery,
180        searchable_fields: Vec<&str>,
181    ) -> Result<PaginatedResponse<T>, AppError> {
182        let filter = query.to_filter(&searchable_fields);
183        let sort = query.to_sort();
184
185        let total = self.collection
186            .count_documents(filter.clone())
187            .await
188            .map_err(AppError::Database)?;
189
190        let mut cursor = self.collection
191            .find(filter)
192            .limit(query.get_limit())
193            .skip(query.skip())
194            .sort(sort)
195            .await
196            .map_err(AppError::Database)?;
197
198        let mut items = Vec::new();
199        while let Ok(Some(item)) = cursor.try_next().await {
200            items.push(item);
201        }
202
203        Ok(PaginatedResponse::new(items, total, &query))
204    }
205
206    pub fn raw(&self) -> &Collection<T> {
207        &self.collection
208    }
209
210}
211
212pub struct Db {
213    pub client: Client,
214    pub db_name: String,
215}
216
217impl Db {
218    pub async fn connect(uri: &str, db_name: &str) -> Self {
219        let client = Client::with_uri_str(uri)
220            .await
221            .expect("Failed to connect to MongoDB");
222        Self { client, db_name: db_name.to_string() }
223    }
224
225    /// Get a typed collection — call this once per handler or store in Data<>
226    pub fn collection<T>(&self, name: &str) -> MongoCollection<T>
227    where
228        T: Serialize + DeserializeOwned + Unpin + Send + Sync,
229    {
230        MongoCollection {
231            collection: self.client
232                .database(&self.db_name)
233                .collection::<T>(name),
234        }
235    }
236
237    pub async fn start_transaction(&self) -> Result<mongodb::ClientSession, AppError> {
238        let mut session = self.client
239            .start_session()
240            .await
241            .map_err(AppError::Database)?;
242            
243        session.start_transaction()
244            .await
245            .map_err(AppError::Database)?;
246            
247        Ok(session)
248    }
249
250    pub async fn commit_transaction(&self, mut session: mongodb::ClientSession) -> Result<(), AppError> {
251        session.commit_transaction()
252            .await
253            .map_err(AppError::Database)
254    }
255
256    pub async fn abort_transaction(&self, mut session: mongodb::ClientSession) -> Result<(), AppError> {
257        session.abort_transaction()
258            .await
259            .map_err(AppError::Database)
260    }
261}