revolt_database/drivers/
mongodb.rs

1use std::collections::HashMap;
2use std::ops::Deref;
3
4use futures::StreamExt;
5use mongodb::bson::{doc, to_document, Document};
6use mongodb::error::Result;
7use mongodb::options::{FindOneOptions, FindOptions};
8use mongodb::results::{DeleteResult, InsertOneResult, UpdateResult};
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11
database_derived!(
    #[cfg(feature = "mongodb")]
    /// MongoDB implementation
    ///
    /// Field `0` is the driver client (exposed through `Deref`), and field `1` is the
    /// name of the database that `MongoDb::db` opens on that client.
    pub struct MongoDb(pub ::mongodb::Client, pub String);
);
17
18impl Deref for MongoDb {
19    type Target = mongodb::Client;
20
21    fn deref(&self) -> &Self::Target {
22        &self.0
23    }
24}
25
26#[allow(dead_code)]
27impl MongoDb {
28    /// Get the Revolt database
29    pub fn db(&self) -> mongodb::Database {
30        self.database(&self.1)
31    }
32
33    /// Get a collection by its name
34    pub fn col<T: Send + Sync>(&self, collection: &str) -> mongodb::Collection<T> {
35        self.db().collection(collection)
36    }
37
38    /// Insert one document into a collection
39    pub async fn insert_one<T: Serialize + Send + Sync>(
40        &self,
41        collection: &'static str,
42        document: T,
43    ) -> Result<InsertOneResult> {
44        self.col::<T>(collection).insert_one(document).await
45    }
46
47    /// Count documents by projection
48    pub async fn count_documents(
49        &self,
50        collection: &'static str,
51        projection: Document,
52    ) -> Result<u64> {
53        self.col::<Document>(collection)
54            .count_documents(projection)
55            .await
56    }
57
58    /// Find multiple documents in a collection with options
59    pub async fn find_with_options<O, T: DeserializeOwned + Unpin + Send + Sync>(
60        &self,
61        collection: &'static str,
62        projection: Document,
63        options: O,
64    ) -> Result<Vec<T>>
65    where
66        O: Into<Option<FindOptions>>,
67    {
68        Ok(self
69            .col::<T>(collection)
70            .find(projection)
71            .with_options(options)
72            .await?
73            .filter_map(|s| async {
74                if cfg!(debug_assertions) {
75                    // Hard fail on invalid documents
76                    Some(s.unwrap())
77                } else {
78                    s.ok()
79                }
80            })
81            .collect::<Vec<T>>()
82            .await)
83    }
84
85    /// Find multiple documents in a collection
86    pub async fn find<T: DeserializeOwned + Unpin + Send + Sync>(
87        &self,
88        collection: &'static str,
89        projection: Document,
90    ) -> Result<Vec<T>> {
91        self.find_with_options(collection, projection, None).await
92    }
93
94    /// Find one document with options
95    pub async fn find_one_with_options<O, T: DeserializeOwned + Unpin + Send + Sync>(
96        &self,
97        collection: &'static str,
98        projection: Document,
99        options: O,
100    ) -> Result<Option<T>>
101    where
102        O: Into<Option<FindOneOptions>>,
103    {
104        self.col::<T>(collection)
105            .find_one(projection)
106            .with_options(options)
107            .await
108    }
109
110    /// Find one document
111    pub async fn find_one<T: DeserializeOwned + Unpin + Send + Sync>(
112        &self,
113        collection: &'static str,
114        projection: Document,
115    ) -> Result<Option<T>> {
116        self.find_one_with_options(collection, projection, None)
117            .await
118    }
119
120    /// Find one document by its ID
121    pub async fn find_one_by_id<T: DeserializeOwned + Unpin + Send + Sync>(
122        &self,
123        collection: &'static str,
124        id: &str,
125    ) -> Result<Option<T>> {
126        self.find_one(
127            collection,
128            doc! {
129                "_id": id
130            },
131        )
132        .await
133    }
134
135    /// Update one document given a projection, partial document, and list of paths to unset
136    pub async fn update_one<P, T: Serialize>(
137        &self,
138        collection: &'static str,
139        projection: Document,
140        partial: T,
141        remove: Vec<&dyn IntoDocumentPath>,
142        prefix: P,
143    ) -> Result<UpdateResult>
144    where
145        P: Into<Option<String>>,
146    {
147        let prefix = prefix.into();
148
149        let mut unset = doc! {};
150        for field in remove {
151            if let Some(path) = field.as_path() {
152                if let Some(prefix) = &prefix {
153                    unset.insert(prefix.to_owned() + path, 1_i32);
154                } else {
155                    unset.insert(path, 1_i32);
156                }
157            }
158        }
159
160        let query = doc! {
161            "$unset": unset,
162            "$set": if let Some(prefix) = &prefix {
163                to_document(&prefix_keys(&partial, prefix))
164            } else {
165                to_document(&partial)
166            }?
167        };
168
169        self.col::<Document>(collection)
170            .update_one(projection, query)
171            .await
172    }
173
174    /// Update one document given an ID, partial document, and list of paths to unset
175    pub async fn update_one_by_id<P, T: Serialize>(
176        &self,
177        collection: &'static str,
178        id: &str,
179        partial: T,
180        remove: Vec<&dyn IntoDocumentPath>,
181        prefix: P,
182    ) -> Result<UpdateResult>
183    where
184        P: Into<Option<String>>,
185    {
186        self.update_one(
187            collection,
188            doc! {
189                "_id": id
190            },
191            partial,
192            remove,
193            prefix,
194        )
195        .await
196    }
197
198    /// Delete one document by the given projection
199    pub async fn delete_one(
200        &self,
201        collection: &'static str,
202        projection: Document,
203    ) -> Result<DeleteResult> {
204        self.col::<Document>(collection)
205            .delete_one(projection)
206            .await
207    }
208
209    /// Delete one document by the given ID
210    pub async fn delete_one_by_id(
211        &self,
212        collection: &'static str,
213        id: &str,
214    ) -> Result<DeleteResult> {
215        self.delete_one(
216            collection,
217            doc! {
218                "_id": id
219            },
220        )
221        .await
222    }
223}
224
225/// Just a string ID struct
226#[derive(Deserialize)]
227pub struct DocumentId {
228    #[serde(rename = "_id")]
229    pub id: String,
230}
231
/// A value that can name a document field to remove (`$unset`) in
/// `MongoDb::update_one` / `MongoDb::update_one_by_id`.
pub trait IntoDocumentPath: Send + Sync {
    /// Create JSON key path
    ///
    /// Returns the key path of the field to unset, or `None` when this value has no
    /// corresponding field (such entries are skipped by `update_one`). `update_one`
    /// appends the path verbatim to its optional `prefix`, so it is presumably in
    /// MongoDB dot notation (e.g. `a.b`) — confirm against implementors.
    fn as_path(&self) -> Option<&'static str>;
}
236
237/// Prefix keys on an arbitrary object
238pub fn prefix_keys<T: Serialize>(t: &T, prefix: &str) -> HashMap<String, serde_json::Value> {
239    let v: String = serde_json::to_string(t).unwrap();
240    let v: HashMap<String, serde_json::Value> = serde_json::from_str(&v).unwrap();
241    v.into_iter()
242        .filter(|(_k, v)| !v.is_null())
243        .map(|(k, v)| (format!("{}{}", prefix.to_owned(), k), v))
244        .collect()
245}