//! MongoDB implementation of the message database operations (`AbstractMessages`).
//!
//! File: `revolt_database/models/messages/ops/mongodb.rs`

use bson::{to_bson, Document};
use futures::try_join;
use futures::StreamExt;
use mongodb::options::FindOptions;
use revolt_models::v0::MessageSort;
use revolt_result::Result;
use std::collections::{HashMap, HashSet};
use std::time::SystemTime;
use ulid::Ulid;

use crate::{
    AppendMessage, DocumentId, FieldsMessage, IntoDocumentPath, Message, MessageQuery,
    MessageTimePeriod, MongoDb, PartialMessage,
};

use super::AbstractMessages;

/// Name of the MongoDB collection that stores message documents.
static COL: &str = "messages";

20#[async_trait]
21impl AbstractMessages for MongoDb {
22    /// Insert a new message into the database
23    async fn insert_message(&self, message: &Message) -> Result<()> {
24        query!(self, insert_one, COL, &message).map(|_| ())
25    }
26
27    /// Fetch a message by its id
28    async fn fetch_message(&self, id: &str) -> Result<Message> {
29        query!(self, find_one_by_id, COL, id)?.ok_or_else(|| create_error!(NotFound))
30    }
31
32    /// Fetch multiple messages by given query
33    async fn fetch_messages(&self, query: MessageQuery) -> Result<Vec<Message>> {
34        let mut filter = doc! {};
35
36        // 1. Apply message filters
37        if let Some(channel) = query.filter.channel {
38            filter.insert("channel", channel);
39        }
40
41        if let Some(author) = query.filter.author {
42            filter.insert("author", author);
43        }
44
45        let is_search_query = if let Some(query) = query.filter.query {
46            filter.insert(
47                "$text",
48                doc! {
49                    "$search": query
50                },
51            );
52
53            true
54        } else {
55            false
56        };
57
58        if let Some(pinned) = query.filter.pinned {
59            filter.insert("pinned", pinned);
60        };
61
62        // 2. Find query limit
63        let limit = query.limit.unwrap_or(50);
64
65        // 3. Apply message time period
66        match query.time_period {
67            MessageTimePeriod::Relative { nearby } => {
68                // 3.1. Prepare filters
69                let mut older_message_filter = filter.clone();
70                let mut newer_message_filter = filter;
71
72                older_message_filter.insert(
73                    "_id",
74                    doc! {
75                        "$lt": &nearby
76                    },
77                );
78
79                newer_message_filter.insert(
80                    "_id",
81                    doc! {
82                        "$gte": &nearby
83                    },
84                );
85
86                // 3.2. Execute in both directions
87                let (a, b) = try_join!(
88                    self.find_with_options::<_, Message>(
89                        COL,
90                        newer_message_filter,
91                        FindOptions::builder()
92                            .limit(limit / 2 + 1)
93                            .sort(doc! {
94                                "_id": 1_i32
95                            })
96                            .build(),
97                    ),
98                    self.find_with_options::<_, Message>(
99                        COL,
100                        older_message_filter,
101                        FindOptions::builder()
102                            .limit(limit / 2 + 1)
103                            .sort(doc! {
104                                "_id": -1_i32
105                            })
106                            .build(),
107                    )
108                )
109                .map_err(|_| create_database_error!("find", COL))?;
110
111                Ok([a, b].concat())
112            }
113            MessageTimePeriod::Absolute {
114                before,
115                after,
116                sort,
117            } => {
118                // 3.1. Apply message ID filter
119                if let Some(doc) = match (before, after) {
120                    (Some(before), Some(after)) => Some(doc! {
121                        "$lt": before,
122                        "$gt": after
123                    }),
124                    (Some(before), _) => Some(doc! {
125                        "$lt": before
126                    }),
127                    (_, Some(after)) => Some(doc! {
128                        "$gt": after
129                    }),
130                    _ => None,
131                } {
132                    filter.insert("_id", doc);
133                }
134
135                // 3.2. Execute with given message sort
136                self.find_with_options(
137                    COL,
138                    filter,
139                    FindOptions::builder()
140                        .limit(limit)
141                        .sort(match sort.unwrap_or(MessageSort::Latest) {
142                            // Sort by relevance, fallback to latest
143                            MessageSort::Relevance => {
144                                if is_search_query {
145                                    doc! {
146                                        "score": {
147                                            "$meta": "textScore"
148                                        }
149                                    }
150                                } else {
151                                    doc! {
152                                        "_id": -1_i32
153                                    }
154                                }
155                            }
156                            // Sort by latest first
157                            MessageSort::Latest => doc! {
158                                "_id": -1_i32
159                            },
160                            // Sort by oldest first
161                            MessageSort::Oldest => doc! {
162                                "_id": 1_i32
163                            },
164                        })
165                        .build(),
166                )
167                .await
168                .map_err(|_| create_database_error!("find", COL))
169            }
170        }
171    }
172
173    /// Fetch multiple messages by given IDs
174    async fn fetch_messages_by_id(&self, ids: &[String]) -> Result<Vec<Message>> {
175        self.find_with_options(
176            COL,
177            doc! {
178                "_id": {
179                    "$in": ids
180                }
181            },
182            None,
183        )
184        .await
185        .map_err(|_| create_database_error!("find", COL))
186    }
187
188    /// Update a given message with new information
189    async fn update_message(
190        &self,
191        id: &str,
192        message: &PartialMessage,
193        remove: Vec<FieldsMessage>,
194    ) -> Result<()> {
195        query!(
196            self,
197            update_one_by_id,
198            COL,
199            id,
200            message,
201            remove.iter().map(|x| x as &dyn IntoDocumentPath).collect(),
202            None
203        )
204        .map(|_| ())
205    }
206
207    /// Append information to a given message
208    async fn append_message(&self, id: &str, append: &AppendMessage) -> Result<()> {
209        let mut query = doc! {};
210
211        if let Some(embeds) = &append.embeds {
212            if !embeds.is_empty() {
213                query.insert(
214                    "$push",
215                    doc! {
216                        "embeds": {
217                            "$each": to_bson(embeds)
218                                .map_err(|_| create_database_error!("to_bson", "embeds"))?
219                        }
220                    },
221                );
222            }
223        }
224
225        if query.is_empty() {
226            return Ok(());
227        }
228
229        self.col::<Document>(COL)
230            .update_one(
231                doc! {
232                    "_id": id
233                },
234                query,
235            )
236            .await
237            .map(|_| ())
238            .map_err(|_| create_database_error!("update_one", COL))
239    }
240
241    /// Add a new reaction to a message
242    async fn add_reaction(&self, id: &str, emoji: &str, user: &str) -> Result<()> {
243        self.col::<Document>(COL)
244            .update_one(
245                doc! {
246                    "_id": id
247                },
248                doc! {
249                    "$addToSet": {
250                        format!("reactions.{emoji}"): user
251                    }
252                },
253            )
254            .await
255            .map(|_| ())
256            .map_err(|_| create_database_error!("update_one", COL))
257    }
258
259    /// Remove a reaction from a message
260    async fn remove_reaction(&self, id: &str, emoji: &str, user: &str) -> Result<()> {
261        self.col::<Document>(COL)
262            .update_one(
263                doc! {
264                    "_id": id
265                },
266                doc! {
267                    "$pull": {
268                        format!("reactions.{emoji}"): user
269                    }
270                },
271            )
272            .await
273            .map(|_| ())
274            .map_err(|_| create_database_error!("update_one", COL))
275    }
276
277    /// Remove reaction from a message
278    async fn clear_reaction(&self, id: &str, emoji: &str) -> Result<()> {
279        self.col::<Document>(COL)
280            .update_one(
281                doc! {
282                    "_id": id
283                },
284                doc! {
285                    "$unset": {
286                        format!("reactions.{emoji}"): 1
287                    }
288                },
289            )
290            .await
291            .map(|_| ())
292            .map_err(|_| create_database_error!("update_one", COL))
293    }
294
295    /// Delete a message from the database by its id
296    async fn delete_message(&self, id: &str) -> Result<()> {
297        query!(self, delete_one_by_id, COL, id).map(|_| ())
298    }
299
300    /// Delete messages from a channel by their ids and corresponding channel id
301    async fn delete_messages(&self, channel: &str, ids: &[String]) -> Result<()> {
302        self.col::<Document>(COL)
303            .delete_many(doc! {
304                "channel": channel,
305                "_id": {
306                    "$in": ids
307                }
308            })
309            .await
310            .map(|_| ())
311            .map_err(|_| create_database_error!("delete_many", COL))
312    }
313
314    /// Delete all messages from a specific author in a server from a certain ULID onwards
315    async fn delete_messages_by_author_since(
316        &self,
317        channels: &[String],
318        author: &str,
319        since: SystemTime,
320    ) -> Result<HashMap<String, Vec<String>>> {
321        let threshold_ulid = Ulid::from_datetime(since).to_string();
322
323        let filter = doc! {
324            "author": author,
325            "channel": { "$in": channels },
326            "_id": { "$gte": &threshold_ulid }
327        };
328
329        let pipeline = vec![
330            doc! { "$match": filter.clone() },
331            doc! {
332                "$project": {
333                    "channel": 1_i32,
334                    "message_id": "$_id",
335                    "attachment_ids": {
336                        "$map": {
337                            "input": { "$ifNull": ["$attachments", Vec::<bson::Bson>::new()] },
338                            "as": "a",
339                            "in": "$$a._id"
340                        }
341                    }
342                }
343            },
344            doc! {
345                "$group": {
346                    "_id": "$channel",
347                    "message_ids": { "$push": "$message_id" },
348                    "attachment_ids_nested": { "$push": "$attachment_ids" }
349                }
350            },
351            doc! {
352                "$project": {
353                    "message_ids": 1_i32,
354                    "attachment_ids": {
355                        "$reduce": {
356                            "input": "$attachment_ids_nested",
357                            "initialValue": Vec::<bson::Bson>::new(),
358                            "in": { "$setUnion": ["$$value", "$$this"] }
359                        }
360                    }
361                }
362            },
363        ];
364
365        #[derive(serde::Deserialize)]
366        struct AggregatedChannel {
367            #[serde(rename = "_id")]
368            channel: String,
369            message_ids: Vec<String>,
370            #[serde(default)]
371            attachment_ids: Vec<String>,
372        }
373
374        let mut cursor = self
375            .col::<Document>(COL)
376            .aggregate(pipeline)
377            .await
378            .map_err(|_| create_database_error!("aggregate", COL))?
379            .with_type::<AggregatedChannel>();
380
381        let mut deleted_messages: HashMap<String, Vec<String>> = HashMap::new();
382        let mut attachment_ids: HashSet<String> = HashSet::new();
383
384        while let Some(result) = cursor.next().await {
385            if let Ok(item) = result {
386                for id in item.attachment_ids {
387                    attachment_ids.insert(id);
388                }
389                deleted_messages.insert(item.channel, item.message_ids);
390            }
391        }
392
393        // Mark attachments as deleted before deleting messages
394        if !attachment_ids.is_empty() {
395            self.col::<Document>("attachments")
396                .update_many(
397                    doc! {
398                        "_id": {
399                            "$in": attachment_ids.into_iter().collect::<Vec<String>>()
400                        }
401                    },
402                    doc! {
403                        "$set": {
404                            "deleted": true
405                        }
406                    },
407                )
408                .await
409                .map_err(|_| create_database_error!("update_many", "attachments"))?;
410        }
411
412        self.col::<Document>(COL)
413            .delete_many(filter)
414            .await
415            .map_err(|_| create_database_error!("delete_many", COL))?;
416
417        Ok(deleted_messages)
418    }
419}
420
421impl IntoDocumentPath for FieldsMessage {
422    fn as_path(&self) -> Option<&'static str> {
423        Some(match self {
424            FieldsMessage::Pinned => "pinned",
425        })
426    }
427}
428
429impl MongoDb {
430    pub async fn delete_bulk_messages(&self, projection: Document) -> Result<()> {
431        let mut for_attachments = projection.clone();
432        for_attachments.insert(
433            "attachments",
434            doc! {
435                "$exists": 1_i32
436            },
437        );
438
439        // Check if there are any attachments we need to delete.
440        let message_ids_with_attachments = self
441            .find_with_options::<_, DocumentId>(
442                COL,
443                for_attachments,
444                FindOptions::builder()
445                    .projection(doc! { "_id": 1_i32 })
446                    .build(),
447            )
448            .await
449            .map_err(|_| create_database_error!("find_many", "attachments"))?
450            .into_iter()
451            .map(|x| x.id)
452            .collect::<Vec<String>>();
453
454        // If we found any, mark them as deleted.
455        if !message_ids_with_attachments.is_empty() {
456            self.col::<Document>("attachments")
457                .update_many(
458                    doc! {
459                        "message_id": {
460                            "$in": message_ids_with_attachments
461                        }
462                    },
463                    doc! {
464                        "$set": {
465                            "deleted": true
466                        }
467                    },
468                )
469                .await
470                .map_err(|_| create_database_error!("update_many", "attachments"))?;
471        }
472
473        // And then delete said messages.
474        self.col::<Document>(COL)
475            .delete_many(projection)
476            .await
477            .map(|_| ())
478            .map_err(|_| create_database_error!("delete_many", COL))
479    }
480}