//! revolt_database/models/messages/ops/reference.rs
//!
//! In-memory `ReferenceDb` implementation of the `AbstractMessages` operations.

use std::collections::HashMap;
use std::time::SystemTime;

use futures::future::try_join_all;
use indexmap::IndexSet;
use revolt_result::Result;
use ulid::Ulid;

use crate::{AppendMessage, FieldsMessage, Message, MessageQuery, PartialMessage, ReferenceDb};

use super::AbstractMessages;

11#[async_trait]
12impl AbstractMessages for ReferenceDb {
13    /// Insert a new message into the database
14    async fn insert_message(&self, message: &Message) -> Result<()> {
15        let mut messages = self.messages.lock().await;
16        if messages.contains_key(&message.id) {
17            Err(create_database_error!("insert", "message"))
18        } else {
19            messages.insert(message.id.to_string(), message.clone());
20            Ok(())
21        }
22    }
23
24    /// Fetch a message by its id
25    async fn fetch_message(&self, id: &str) -> Result<Message> {
26        let messages = self.messages.lock().await;
27        messages
28            .get(id)
29            .cloned()
30            .ok_or_else(|| create_error!(NotFound))
31    }
32
33    /// Fetch multiple messages by given query
34    async fn fetch_messages(&self, query: MessageQuery) -> Result<Vec<Message>> {
35        let messages = self.messages.lock().await;
36        let matched_messages = messages
37            .values()
38            .filter(|message| {
39                if let Some(channel) = &query.filter.channel {
40                    if &message.channel != channel {
41                        return false;
42                    }
43                }
44
45                if let Some(author) = &query.filter.author {
46                    if &message.author != author {
47                        return false;
48                    }
49                }
50
51                if let Some(query) = &query.filter.query {
52                    if let Some(content) = &message.content {
53                        if !content.to_lowercase().contains(query) {
54                            return false;
55                        }
56                    } else {
57                        return false;
58                    }
59                }
60
61                if let Some(pinned) = query.filter.pinned {
62                    if message.pinned.unwrap_or_default() == pinned {
63                        return false
64                    }
65                }
66
67                true
68            })
69            .cloned()
70            .collect();
71
72        // FIXME: sorting, etc (will be required for tests)
73
74        Ok(matched_messages)
75
76        /*
77        // 2. Find query limit
78        let limit = query.limit.unwrap_or(50);
79
80        // 3. Apply message time period
81        match query.time_period {
82            MessageTimePeriod::Relative { nearby } => {
83                // 3.1. Prepare filters
84                let mut older_message_filter = filter.clone();
85                let mut newer_message_filter = filter;
86
87                older_message_filter.insert(
88                    "_id",
89                    doc! {
90                        "$lt": &nearby
91                    },
92                );
93
94                newer_message_filter.insert(
95                    "_id",
96                    doc! {
97                        "$gte": &nearby
98                    },
99                );
100
101                // 3.2. Execute in both directions
102                let (a, b) = try_join!(
103                    self.find_with_options::<_, Message>(
104                        COL,
105                        newer_message_filter,
106                        FindOptions::builder()
107                            .limit(limit / 2 + 1)
108                            .sort(doc! {
109                                "_id": 1_i32
110                            })
111                            .build(),
112                    ),
113                    self.find_with_options::<_, Message>(
114                        COL,
115                        older_message_filter,
116                        FindOptions::builder()
117                            .limit(limit / 2)
118                            .sort(doc! {
119                                "_id": -1_i32
120                            })
121                            .build(),
122                    )
123                )
124                .map_err(|_| create_database_error!("find", COL))?;
125
126                Ok([a, b].concat())
127            }
128            MessageTimePeriod::Absolute {
129                before,
130                after,
131                sort,
132            } => {
133                // 3.1. Apply message ID filter
134                if let Some(doc) = match (before, after) {
135                    (Some(before), Some(after)) => Some(doc! {
136                        "$lt": before,
137                        "$gt": after
138                    }),
139                    (Some(before), _) => Some(doc! {
140                        "$lt": before
141                    }),
142                    (_, Some(after)) => Some(doc! {
143                        "$gt": after
144                    }),
145                    _ => None,
146                } {
147                    filter.insert("_id", doc);
148                }
149
150                // 3.2. Execute with given message sort
151                self.find_with_options(
152                    COL,
153                    filter,
154                    FindOptions::builder()
155                        .limit(limit)
156                        .sort(match sort.unwrap_or(MessageSort::Latest) {
157                            // Sort by relevance, fallback to latest
158                            MessageSort::Relevance => {
159                                if is_search_query {
160                                    doc! {
161                                        "score": {
162                                            "$meta": "textScore"
163                                        }
164                                    }
165                                } else {
166                                    doc! {
167                                        "_id": -1_i32
168                                    }
169                                }
170                            }
171                            // Sort by latest first
172                            MessageSort::Latest => doc! {
173                                "_id": -1_i32
174                            },
175                            // Sort by oldest first
176                            MessageSort::Oldest => doc! {
177                                "_id": 1_i32
178                            },
179                        })
180                        .build(),
181                )
182                .await
183                .map_err(|_| create_database_error!("find", COL))
184            }
185        }*/
186    }
187
188    /// Fetch multiple messages by given IDs
189    async fn fetch_messages_by_id(&self, ids: &[String]) -> Result<Vec<Message>> {
190        try_join_all(ids.iter().map(|id| self.fetch_message(id))).await
191    }
192
193    /// Update a given message with new information
194    async fn update_message(&self, id: &str, message: &PartialMessage, remove: Vec<FieldsMessage>) -> Result<()> {
195        let mut messages = self.messages.lock().await;
196        if let Some(message_data) = messages.get_mut(id) {
197            message_data.apply_options(message.to_owned());
198
199            for field in remove {
200                #[allow(clippy::disallowed_methods)]
201                message_data.remove_field(&field);
202            }
203            Ok(())
204        } else {
205            Err(create_error!(NotFound))
206        }
207    }
208
209    /// Append information to a given message
210    async fn append_message(&self, id: &str, append: &AppendMessage) -> Result<()> {
211        let mut messages = self.messages.lock().await;
212        if let Some(message_data) = messages.get_mut(id) {
213            if let Some(embeds) = &append.embeds {
214                if !embeds.is_empty() {
215                    if let Some(embeds_data) = &mut message_data.embeds {
216                        embeds_data.extend(embeds.clone());
217                    } else {
218                        message_data.embeds = Some(embeds.clone());
219                    }
220                }
221            }
222
223            Ok(())
224        } else {
225            Err(create_error!(NotFound))
226        }
227    }
228
229    /// Add a new reaction to a message
230    async fn add_reaction(&self, id: &str, emoji: &str, user: &str) -> Result<()> {
231        let mut messages = self.messages.lock().await;
232        if let Some(message) = messages.get_mut(id) {
233            if let Some(users) = message.reactions.get_mut(emoji) {
234                users.insert(user.to_string());
235            } else {
236                message
237                    .reactions
238                    .insert(emoji.to_string(), IndexSet::from([user.to_string()]));
239            }
240
241            Ok(())
242        } else {
243            Err(create_error!(NotFound))
244        }
245    }
246
247    /// Remove a reaction from a message
248    async fn remove_reaction(&self, id: &str, emoji: &str, user: &str) -> Result<()> {
249        let mut messages = self.messages.lock().await;
250        if let Some(message) = messages.get_mut(id) {
251            if let Some(users) = message.reactions.get_mut(emoji) {
252                users.swap_remove(&user.to_string());
253            }
254
255            Ok(())
256        } else {
257            Err(create_error!(NotFound))
258        }
259    }
260
261    /// Remove reaction from a message
262    async fn clear_reaction(&self, id: &str, emoji: &str) -> Result<()> {
263        let mut messages = self.messages.lock().await;
264        if let Some(message) = messages.get_mut(id) {
265            message.reactions.swap_remove(emoji);
266            Ok(())
267        } else {
268            Err(create_error!(NotFound))
269        }
270    }
271
272    /// Delete a message from the database by its id
273    async fn delete_message(&self, id: &str) -> Result<()> {
274        let mut messages = self.messages.lock().await;
275        if messages.remove(id).is_some() {
276            Ok(())
277        } else {
278            Err(create_error!(NotFound))
279        }
280    }
281
282    /// Delete messages from a channel by their ids and corresponding channel id
283    async fn delete_messages(&self, channel: &str, ids: &[String]) -> Result<()> {
284        self.messages
285            .lock()
286            .await
287            .retain(|id, message| message.channel != channel && !ids.contains(id));
288
289        Ok(())
290    }
291
292    /// Delete all messages from a specific author in a list of channels from a certain ULID onwards
293    async fn delete_messages_by_author_since(
294        &self,
295        channels: &[String],
296        author: &str,
297        since: SystemTime
298    ) -> Result<HashMap<String, Vec<String>>> {
299        let threshold_ulid = Ulid::from_datetime(since).to_string();
300        let mut deleted_messages: HashMap<String, Vec<String>> = HashMap::new();
301        let mut attachment_ids: Vec<String> = Vec::new();
302
303        let messages = self.messages.lock().await;
304
305        // First pass: collect attachment IDs and message IDs to delete
306        for (id, message) in messages.iter() {
307            let should_delete = message.author == author
308                && channels.contains(&message.channel)
309                && id.as_str() >= threshold_ulid.as_str();
310
311            if should_delete {
312                // Collect attachment IDs
313                if let Some(attachments) = &message.attachments {
314                    for attachment in attachments {
315                        attachment_ids.push(attachment.id.clone());
316                    }
317                }
318
319                deleted_messages
320                    .entry(message.channel.clone())
321                    .or_default()
322                    .push(id.clone());
323            }
324        }
325        drop(messages);
326
327        // Mark attachments as deleted
328        if !attachment_ids.is_empty() {
329            let mut files = self.files.lock().await;
330            for attachment_id in attachment_ids {
331                if let Some(file) = files.get_mut(&attachment_id) {
332                    file.deleted = Some(true);
333                }
334            }
335        }
336
337        // Delete the messages
338        self.messages
339            .lock()
340            .await
341            .retain(|id, message| {
342                let should_keep = !(message.author == author
343                    && channels.contains(&message.channel)
344                    && id.as_str() >= threshold_ulid.as_str());
345                should_keep
346            });
347
348        Ok(deleted_messages)
349    }
350}