qm_mongodb/
db.rs

1use futures::stream::StreamExt;
2use mongodb::bson::doc;
3use mongodb::bson::Document;
4use mongodb::options::{FindOneAndUpdateOptions, IndexOptions};
5use mongodb::{options::ClientOptions, Client, ClientSession, Database, IndexModel};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use crate::config::Config as MongoDbConfig;
10
11async fn collections(client: &Client, database: &str) -> mongodb::error::Result<Arc<[Arc<str>]>> {
12    Ok(client
13        .database(database)
14        .list_collection_names()
15        .await?
16        .into_iter()
17        .map(Arc::from)
18        .collect())
19}
20
21struct Inner {
22    db_name: Arc<str>,
23    admin_db_name: Arc<str>,
24    client: Client,
25    admin: Client,
26    is_sharded: bool,
27    collections: RwLock<Arc<[Arc<str>]>>,
28}
29
30#[derive(serde::Deserialize)]
31// #[serde(rename_all = "camelCase")]
32pub struct DbUser {
33    // #[serde(rename = "_id")]
34    // id: String,
35    // user_id: Uuid,
36    user: String,
37    db: String,
38}
39
40#[derive(serde::Deserialize)]
41pub struct DbUsers {
42    users: Vec<DbUser>,
43}
44
45#[derive(Clone)]
46pub struct DB {
47    inner: Arc<Inner>,
48}
49
50impl DB {
51    pub async fn new(app_name: &str, cfg: &MongoDbConfig) -> mongodb::error::Result<Self> {
52        tracing::info!("'{app_name}' -> connects to mongodb '{}'", cfg.database());
53        let mut client_options = ClientOptions::parse(cfg.root_address()).await?;
54        client_options.app_name = Some(app_name.to_string());
55        let admin = Client::with_options(client_options)?;
56        let collections = RwLock::new(collections(&admin, cfg.database()).await?);
57        if let (Some(username), Some(password)) = (cfg.username(), cfg.password()) {
58            let db_users = mongodb::bson::from_document::<DbUsers>(
59                admin
60                    .database(cfg.database())
61                    .run_command(doc! {
62                        "usersInfo": [{
63                            "db": cfg.database(),
64                            "user": username,
65                        }],
66                        "showPrivileges": false,
67                        "showCredentials": false,
68                    })
69                    .await?,
70            )
71            .ok();
72            if !db_users
73                .map(|u| {
74                    u.users
75                        .iter()
76                        .any(|u: &DbUser| u.db == cfg.database() && u.user == username)
77                })
78                .unwrap_or(false)
79            {
80                tracing::info!(
81                    "{app_name} -> create user {} for db {}",
82                    username,
83                    cfg.database()
84                );
85                admin
86                    .database(cfg.database())
87                    .run_command(doc! {
88                        "createUser": username,
89                        "pwd": password,
90                        "roles": [
91                            {
92                                "role": "readWrite",
93                                "db": cfg.database(),
94                            }
95                        ]
96                    })
97                    .await?;
98            }
99        }
100        let mut client_options = ClientOptions::parse(cfg.address()).await?;
101        client_options.app_name = Some(app_name.to_string());
102        let client = Client::with_options(client_options)?;
103        let is_sharded = cfg.sharded();
104        let db = Self {
105            inner: Arc::new(Inner {
106                db_name: Arc::from(cfg.database()),
107                admin_db_name: Arc::from(cfg.root_database()),
108                client,
109                admin,
110                is_sharded,
111                collections,
112            }),
113        };
114        db.setup(cfg).await?;
115        Ok(db)
116    }
117
118    pub fn is_sharded(&self) -> bool {
119        self.inner.is_sharded
120    }
121
122    pub async fn session(&self) -> mongodb::error::Result<ClientSession> {
123        self.inner.client.start_session().await
124    }
125
126    pub fn get(&self) -> Database {
127        self.inner.client.database(&self.inner.db_name)
128    }
129
130    pub fn get_admin(&self) -> Database {
131        self.inner.admin.database(&self.inner.admin_db_name)
132    }
133
134    pub fn db_name(&self) -> &str {
135        &self.inner.db_name
136    }
137
138    pub async fn setup(&self, cfg: &MongoDbConfig) -> mongodb::error::Result<()> {
139        if self.is_sharded() {
140            self.get_admin()
141                .run_command(doc! {
142                    "enableSharding": cfg.database()
143                })
144                .await?;
145        }
146        for col in self.inner.collections.read().await.as_ref().iter() {
147            tracing::debug!("found collection: {}", col);
148        }
149        Ok(())
150    }
151
152    pub async fn collections(&self) -> Arc<[Arc<str>]> {
153        self.inner.collections.read().await.clone()
154    }
155
156    pub async fn update_collections(&self) -> mongodb::error::Result<()> {
157        *self.inner.collections.write().await =
158            collections(&self.inner.client, self.db_name()).await?;
159        Ok(())
160    }
161
162    pub async fn ensure_collection_with_sharding(
163        &self,
164        collections: &[String],
165        name: &str,
166        shard_key: &str,
167    ) -> mongodb::error::Result<()> {
168        if !collections.iter().any(|c| c == name) {
169            self.get().create_collection(name).await.ok();
170            self.get()
171                .collection::<()>(name)
172                .create_index(IndexModel::builder().keys(doc! { shard_key: 1 }).build())
173                .await?;
174            if self.is_sharded() {
175                self.get_admin()
176                    .run_command(doc! {
177                        "shardCollection": &format!("{}.{}", self.inner.db_name, name),
178                        "key": { shard_key: "hashed" },
179                    })
180                    .await?;
181            }
182        }
183        Ok(())
184    }
185
186    pub async fn ensure_collection_with_indexes(
187        &self,
188        collections: &[String],
189        name: &str,
190        indexes: Vec<(Document, bool)>,
191    ) -> mongodb::error::Result<bool> {
192        if !collections.iter().any(|c| c == name) {
193            self.get().create_collection(name).await?;
194            for index in indexes {
195                self.get()
196                    .collection::<()>(name)
197                    .create_index(
198                        IndexModel::builder()
199                            .keys(index.0)
200                            .options(IndexOptions::builder().unique(index.1).build())
201                            .build(),
202                    )
203                    .await?;
204            }
205            return Ok(true);
206        }
207        Ok(false)
208    }
209
210    pub async fn cleanup(&self) -> mongodb::error::Result<()> {
211        for collection in self
212            .inner
213            .admin
214            .database(self.db_name())
215            .list_collection_names()
216            .await?
217        {
218            if &collection != "api_jwt_secrets" {
219                self.inner
220                    .admin
221                    .database(self.db_name())
222                    .collection::<Document>(&collection)
223                    .delete_many(doc! {})
224                    .await?;
225            }
226        }
227        Ok(())
228    }
229}
230
231pub async fn parse_vec<T>(cursor: mongodb::Cursor<Document>) -> Vec<T>
232where
233    T: serde::de::DeserializeOwned,
234{
235    cursor
236        .filter_map(|v| async {
237            v.ok().and_then(|v| {
238                mongodb::bson::from_document::<T>(v)
239                    .map_err(|e| {
240                        tracing::error!("Error while parsing MongoDB document: {e:#?}");
241                        e
242                    })
243                    .ok()
244            })
245        })
246        .collect()
247        .await
248}
249
250pub fn insert_always_opts() -> Option<FindOneAndUpdateOptions> {
251    let mut opts = FindOneAndUpdateOptions::default();
252    opts.upsert = Some(true);
253    Some(opts)
254}
255
/// Implements `AsRef<qm::mongodb::DB>` for the given storage type.
///
/// The generated impl returns `&self.inner.db`. The type must therefore expose
/// an `inner` field through which a `db` field of type `qm::mongodb::DB` can be
/// reached.
#[macro_export]
macro_rules! db {
    ($target:ty) => {
        impl AsRef<qm::mongodb::DB> for $target {
            fn as_ref(&self) -> &qm::mongodb::DB {
                &self.inner.db
            }
        }
    };
}