1use std::marker::PhantomData;
4
5use bson::{Document, doc};
6use futures::TryStreamExt;
7use mongodb::Collection;
8use prax_query::QueryResult;
9use prax_query::filter::FilterValue;
10use prax_query::traits::{BoxFuture, Model, QueryEngine};
11use tracing::debug;
12
13use crate::client::MongoClient;
14use crate::error::MongoError;
15use crate::types::filter_value_to_bson;
16
17#[derive(Clone)]
22pub struct MongoEngine {
23 client: MongoClient,
24}
25
26impl MongoEngine {
27 pub fn new(client: MongoClient) -> Self {
29 Self { client }
30 }
31
32 pub fn client(&self) -> &MongoClient {
34 &self.client
35 }
36
37 pub fn collection<T>(&self) -> Collection<T>
39 where
40 T: Model + Send + Sync,
41 {
42 let collection_name = format!("{}s", T::MODEL_NAME.to_lowercase());
44 self.client.collection(&collection_name)
45 }
46
47 pub fn collection_by_name<T>(&self, name: &str) -> Collection<T>
49 where
50 T: Send + Sync,
51 {
52 self.client.collection(name)
53 }
54
55 fn build_filter(sql: &str, params: &[FilterValue]) -> MongoResult<Document> {
57 if sql.starts_with('{') {
60 let filter: Document = serde_json::from_str(sql)
62 .map_err(|e| MongoError::query(format!("invalid filter JSON: {}", e)))?;
63 Ok(filter)
64 } else if sql.is_empty() {
65 Ok(doc! {})
67 } else {
68 let mut filter = Document::new();
71
72 for part in sql.split(" AND ") {
74 let part = part.trim();
75 if let Some(eq_pos) = part.find('=') {
76 let field = part[..eq_pos].trim();
77 let value_placeholder = part[eq_pos + 1..].trim();
78
79 if let Some(stripped) = value_placeholder.strip_prefix('$') {
81 if let Ok(param_idx) = stripped.parse::<usize>() {
82 if param_idx > 0 && param_idx <= params.len() {
83 let bson_value = filter_value_to_bson(¶ms[param_idx - 1])?;
84 filter.insert(field, bson_value);
85 }
86 }
87 } else {
88 filter.insert(field, value_placeholder);
90 }
91 }
92 }
93
94 Ok(filter)
95 }
96 }
97}
98
99use crate::error::MongoResult;
100
101impl QueryEngine for MongoEngine {
102 fn query_many<T: Model + Send + 'static>(
103 &self,
104 sql: &str,
105 params: Vec<FilterValue>,
106 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
107 let sql = sql.to_string();
108 Box::pin(async move {
109 debug!(filter = %sql, "Executing query_many");
110
111 let filter = Self::build_filter(&sql, ¶ms)
112 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
113
114 let collection = self
115 .client
116 .collection_doc(&format!("{}s", T::MODEL_NAME.to_lowercase()));
117
118 let cursor = collection
119 .find(filter, None)
120 .await
121 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
122
123 let docs: Vec<Document> = cursor
124 .try_collect()
125 .await
126 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
127
128 let _ = docs;
131 Ok(Vec::new())
132 })
133 }
134
135 fn query_one<T: Model + Send + 'static>(
136 &self,
137 sql: &str,
138 params: Vec<FilterValue>,
139 ) -> BoxFuture<'_, QueryResult<T>> {
140 let sql = sql.to_string();
141 Box::pin(async move {
142 debug!(filter = %sql, "Executing query_one");
143
144 let filter = Self::build_filter(&sql, ¶ms)
145 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
146
147 let collection = self
148 .client
149 .collection_doc(&format!("{}s", T::MODEL_NAME.to_lowercase()));
150
151 let _doc = collection
152 .find_one(filter, None)
153 .await
154 .map_err(|e| prax_query::QueryError::database(e.to_string()))?
155 .ok_or_else(|| prax_query::QueryError::not_found(T::MODEL_NAME))?;
156
157 Err(prax_query::QueryError::internal(
159 "deserialization not yet implemented".to_string(),
160 ))
161 })
162 }
163
164 fn query_optional<T: Model + Send + 'static>(
165 &self,
166 sql: &str,
167 params: Vec<FilterValue>,
168 ) -> BoxFuture<'_, QueryResult<Option<T>>> {
169 let sql = sql.to_string();
170 Box::pin(async move {
171 debug!(filter = %sql, "Executing query_optional");
172
173 let filter = Self::build_filter(&sql, ¶ms)
174 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
175
176 let collection = self
177 .client
178 .collection_doc(&format!("{}s", T::MODEL_NAME.to_lowercase()));
179
180 let doc = collection
181 .find_one(filter, None)
182 .await
183 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
184
185 match doc {
186 Some(_doc) => {
187 Err(prax_query::QueryError::internal(
189 "deserialization not yet implemented".to_string(),
190 ))
191 }
192 None => Ok(None),
193 }
194 })
195 }
196
197 fn execute_insert<T: Model + Send + 'static>(
198 &self,
199 sql: &str,
200 params: Vec<FilterValue>,
201 ) -> BoxFuture<'_, QueryResult<T>> {
202 let sql = sql.to_string();
203 Box::pin(async move {
204 debug!(data = %sql, "Executing insert");
205
206 let doc: Document = if sql.starts_with('{') {
208 serde_json::from_str(&sql)
209 .map_err(|e| prax_query::QueryError::database(e.to_string()))?
210 } else {
211 let mut doc = Document::new();
213 for (i, param) in params.iter().enumerate() {
214 let bson_value = filter_value_to_bson(param)
215 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
216 doc.insert(format!("field{}", i), bson_value);
217 }
218 doc
219 };
220
221 let collection = self
222 .client
223 .collection_doc(&format!("{}s", T::MODEL_NAME.to_lowercase()));
224
225 let _result = collection
226 .insert_one(doc, None)
227 .await
228 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
229
230 Err(prax_query::QueryError::internal(
232 "insert returning not yet implemented".to_string(),
233 ))
234 })
235 }
236
237 fn execute_update<T: Model + Send + 'static>(
238 &self,
239 sql: &str,
240 _params: Vec<FilterValue>,
241 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
242 let sql = sql.to_string();
243 Box::pin(async move {
244 debug!(data = %sql, "Executing update");
245
246 let collection = self
248 .client
249 .collection_doc(&format!("{}s", T::MODEL_NAME.to_lowercase()));
250
251 let filter = doc! {};
253 let update = doc! { "$set": {} };
254
255 let _result = collection
256 .update_many(filter, update, None)
257 .await
258 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
259
260 Ok(Vec::new())
261 })
262 }
263
264 fn execute_delete(
265 &self,
266 sql: &str,
267 params: Vec<FilterValue>,
268 ) -> BoxFuture<'_, QueryResult<u64>> {
269 let sql = sql.to_string();
270 Box::pin(async move {
271 debug!(filter = %sql, "Executing delete");
272
273 let filter = Self::build_filter(&sql, ¶ms)
274 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
275
276 let collection = self.client.collection_doc("documents");
278
279 let result = collection
280 .delete_many(filter, None)
281 .await
282 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
283
284 Ok(result.deleted_count)
285 })
286 }
287
288 fn execute_raw(&self, sql: &str, _params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
289 let sql = sql.to_string();
290 Box::pin(async move {
291 debug!(command = %sql, "Executing raw command");
292
293 let command: Document = serde_json::from_str(&sql)
295 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
296
297 let _result = self
298 .client
299 .run_command(command)
300 .await
301 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
302
303 Ok(1)
304 })
305 }
306
307 fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
308 let sql = sql.to_string();
309 Box::pin(async move {
310 debug!(filter = %sql, "Executing count");
311
312 let filter = Self::build_filter(&sql, ¶ms)
313 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
314
315 let collection = self.client.collection_doc("documents");
317
318 let count = collection
319 .count_documents(filter, None)
320 .await
321 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
322
323 Ok(count)
324 })
325 }
326}
327
328pub struct MongoQueryBuilder<T: Model> {
330 engine: MongoEngine,
331 _marker: PhantomData<T>,
332}
333
334impl<T: Model> MongoQueryBuilder<T> {
335 pub fn new(engine: MongoEngine) -> Self {
337 Self {
338 engine,
339 _marker: PhantomData,
340 }
341 }
342
343 pub fn engine(&self) -> &MongoEngine {
345 &self.engine
346 }
347
348 pub fn collection(&self) -> Collection<T>
350 where
351 T: Send + Sync,
352 {
353 self.engine.collection::<T>()
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON object filter is deserialized as-is.
    #[test]
    fn test_build_filter_json() {
        let filter = MongoEngine::build_filter(r#"{"name": "Alice"}"#, &[]).unwrap();
        assert_eq!(filter.get_str("name").unwrap(), "Alice");
    }

    /// An empty filter matches everything.
    #[test]
    fn test_build_filter_empty() {
        let filter = MongoEngine::build_filter("", &[]).unwrap();
        assert!(filter.is_empty());
    }

    /// Literal equality clauses joined by ` AND ` become string-valued fields.
    #[test]
    fn test_build_filter_literal_equality() {
        let filter = MongoEngine::build_filter("name = Alice AND city = Paris", &[]).unwrap();
        assert_eq!(filter.len(), 2);
        assert_eq!(filter.get_str("name").unwrap(), "Alice");
        assert_eq!(filter.get_str("city").unwrap(), "Paris");
    }

    /// Malformed JSON is reported as an error, not treated as an empty filter.
    #[test]
    fn test_build_filter_invalid_json() {
        assert!(MongoEngine::build_filter("{not json", &[]).is_err());
    }
}