1use futures::stream::StreamExt;
2use mongodb::bson::doc;
3use mongodb::bson::Document;
4use mongodb::options::{FindOneAndUpdateOptions, IndexOptions};
5use mongodb::{options::ClientOptions, Client, ClientSession, Database, IndexModel};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use crate::config::Config as MongoDbConfig;
10
11async fn collections(client: &Client, database: &str) -> mongodb::error::Result<Arc<[Arc<str>]>> {
12 Ok(client
13 .database(database)
14 .list_collection_names()
15 .await?
16 .into_iter()
17 .map(Arc::from)
18 .collect())
19}
20
/// Shared state behind a `DB` handle; all clones of a `DB` point at one `Inner` via `Arc`.
struct Inner {
    /// Application database name, taken from `cfg.database()` in `DB::new`.
    db_name: Arc<str>,
    /// Database opened by `DB::get_admin`, taken from `cfg.root_database()` in `DB::new`.
    admin_db_name: Arc<str>,
    /// Client built from `cfg.address()`; backs `DB::get`, `DB::session` and
    /// `DB::update_collections`.
    client: Client,
    /// Client built from `cfg.root_address()`; used for `createUser`, for the sharding
    /// commands (via `DB::get_admin`) and by `DB::cleanup`.
    admin: Client,
    /// Copied from `cfg.sharded()`; gates the `enableSharding` / `shardCollection` calls.
    is_sharded: bool,
    /// Cached collection names of the application database. Filled in `DB::new` and
    /// replaced only by `DB::update_collections`.
    collections: RwLock<Arc<[Arc<str>]>>,
}
29
/// One entry of the `users` array in the reply to the `usersInfo` command sent by `DB::new`.
///
/// Only the two fields compared there are deserialized; other reply fields are ignored.
#[derive(serde::Deserialize)]
pub struct DbUser {
    /// User name; compared against the configured username.
    user: String,
    /// Database the user entry belongs to; compared against `cfg.database()`.
    db: String,
}
39
/// Reply to the `usersInfo` command sent by `DB::new`; only the `users` array is read.
#[derive(serde::Deserialize)]
pub struct DbUsers {
    /// Users matched by the `usersInfo` query.
    users: Vec<DbUser>,
}
44
/// Handle to the application database and its admin connection.
///
/// Cloning only bumps the reference count of the shared `Inner` state; every clone talks
/// through the same clients and sees the same collection-name cache.
#[derive(Clone)]
pub struct DB {
    inner: Arc<Inner>,
}
49
50impl DB {
51 pub async fn new(app_name: &str, cfg: &MongoDbConfig) -> mongodb::error::Result<Self> {
52 tracing::info!("'{app_name}' -> connects to mongodb '{}'", cfg.database());
53 let mut client_options = ClientOptions::parse(cfg.root_address()).await?;
54 client_options.app_name = Some(app_name.to_string());
55 let admin = Client::with_options(client_options)?;
56 let collections = RwLock::new(collections(&admin, cfg.database()).await?);
57 if let (Some(username), Some(password)) = (cfg.username(), cfg.password()) {
58 let db_users = mongodb::bson::from_document::<DbUsers>(
59 admin
60 .database(cfg.database())
61 .run_command(doc! {
62 "usersInfo": [{
63 "db": cfg.database(),
64 "user": username,
65 }],
66 "showPrivileges": false,
67 "showCredentials": false,
68 })
69 .await?,
70 )
71 .ok();
72 if !db_users
73 .map(|u| {
74 u.users
75 .iter()
76 .any(|u: &DbUser| u.db == cfg.database() && u.user == username)
77 })
78 .unwrap_or(false)
79 {
80 tracing::info!(
81 "{app_name} -> create user {} for db {}",
82 username,
83 cfg.database()
84 );
85 admin
86 .database(cfg.database())
87 .run_command(doc! {
88 "createUser": username,
89 "pwd": password,
90 "roles": [
91 {
92 "role": "readWrite",
93 "db": cfg.database(),
94 }
95 ]
96 })
97 .await?;
98 }
99 }
100 let mut client_options = ClientOptions::parse(cfg.address()).await?;
101 client_options.app_name = Some(app_name.to_string());
102 let client = Client::with_options(client_options)?;
103 let is_sharded = cfg.sharded();
104 let db = Self {
105 inner: Arc::new(Inner {
106 db_name: Arc::from(cfg.database()),
107 admin_db_name: Arc::from(cfg.root_database()),
108 client,
109 admin,
110 is_sharded,
111 collections,
112 }),
113 };
114 db.setup(cfg).await?;
115 Ok(db)
116 }
117
118 pub fn is_sharded(&self) -> bool {
119 self.inner.is_sharded
120 }
121
122 pub async fn session(&self) -> mongodb::error::Result<ClientSession> {
123 self.inner.client.start_session().await
124 }
125
126 pub fn get(&self) -> Database {
127 self.inner.client.database(&self.inner.db_name)
128 }
129
130 pub fn get_admin(&self) -> Database {
131 self.inner.admin.database(&self.inner.admin_db_name)
132 }
133
134 pub fn db_name(&self) -> &str {
135 &self.inner.db_name
136 }
137
138 pub async fn setup(&self, cfg: &MongoDbConfig) -> mongodb::error::Result<()> {
139 if self.is_sharded() {
140 self.get_admin()
141 .run_command(doc! {
142 "enableSharding": cfg.database()
143 })
144 .await?;
145 }
146 for col in self.inner.collections.read().await.as_ref().iter() {
147 tracing::debug!("found collection: {}", col);
148 }
149 Ok(())
150 }
151
152 pub async fn collections(&self) -> Arc<[Arc<str>]> {
153 self.inner.collections.read().await.clone()
154 }
155
156 pub async fn update_collections(&self) -> mongodb::error::Result<()> {
157 *self.inner.collections.write().await =
158 collections(&self.inner.client, self.db_name()).await?;
159 Ok(())
160 }
161
162 pub async fn ensure_collection_with_sharding(
163 &self,
164 collections: &[String],
165 name: &str,
166 shard_key: &str,
167 ) -> mongodb::error::Result<()> {
168 if !collections.iter().any(|c| c == name) {
169 self.get().create_collection(name).await.ok();
170 self.get()
171 .collection::<()>(name)
172 .create_index(IndexModel::builder().keys(doc! { shard_key: 1 }).build())
173 .await?;
174 if self.is_sharded() {
175 self.get_admin()
176 .run_command(doc! {
177 "shardCollection": &format!("{}.{}", self.inner.db_name, name),
178 "key": { shard_key: "hashed" },
179 })
180 .await?;
181 }
182 }
183 Ok(())
184 }
185
186 pub async fn ensure_collection_with_indexes(
187 &self,
188 collections: &[String],
189 name: &str,
190 indexes: Vec<(Document, bool)>,
191 ) -> mongodb::error::Result<bool> {
192 if !collections.iter().any(|c| c == name) {
193 self.get().create_collection(name).await?;
194 for index in indexes {
195 self.get()
196 .collection::<()>(name)
197 .create_index(
198 IndexModel::builder()
199 .keys(index.0)
200 .options(IndexOptions::builder().unique(index.1).build())
201 .build(),
202 )
203 .await?;
204 }
205 return Ok(true);
206 }
207 Ok(false)
208 }
209
210 pub async fn cleanup(&self) -> mongodb::error::Result<()> {
211 for collection in self
212 .inner
213 .admin
214 .database(self.db_name())
215 .list_collection_names()
216 .await?
217 {
218 if &collection != "api_jwt_secrets" {
219 self.inner
220 .admin
221 .database(self.db_name())
222 .collection::<Document>(&collection)
223 .delete_many(doc! {})
224 .await?;
225 }
226 }
227 Ok(())
228 }
229}
230
231pub async fn parse_vec<T>(cursor: mongodb::Cursor<Document>) -> Vec<T>
232where
233 T: serde::de::DeserializeOwned,
234{
235 cursor
236 .filter_map(|v| async {
237 v.ok().and_then(|v| {
238 mongodb::bson::from_document::<T>(v)
239 .map_err(|e| {
240 tracing::error!("Error while parsing MongoDB document: {e:#?}");
241 e
242 })
243 .ok()
244 })
245 })
246 .collect()
247 .await
248}
249
250pub fn insert_always_opts() -> Option<FindOneAndUpdateOptions> {
251 let mut opts = FindOneAndUpdateOptions::default();
252 opts.upsert = Some(true);
253 Some(opts)
254}
255
256#[macro_export]
257macro_rules! db {
258 ($storage:ty) => {
259 impl AsRef<qm::mongodb::DB> for $storage {
260 fn as_ref(&self) -> &qm::mongodb::DB {
261 &self.inner.db
262 }
263 }
264 };
265}