revolt_database/models/messages/ops/
reference.rs

1use futures::future::try_join_all;
2use indexmap::IndexSet;
3use revolt_result::Result;
4
5use crate::{AppendMessage, FieldsMessage, Message, MessageQuery, PartialMessage, ReferenceDb};
6
7use super::AbstractMessages;
8
9#[async_trait]
10impl AbstractMessages for ReferenceDb {
11    /// Insert a new message into the database
12    async fn insert_message(&self, message: &Message) -> Result<()> {
13        let mut messages = self.messages.lock().await;
14        if messages.contains_key(&message.id) {
15            Err(create_database_error!("insert", "message"))
16        } else {
17            messages.insert(message.id.to_string(), message.clone());
18            Ok(())
19        }
20    }
21
22    /// Fetch a message by its id
23    async fn fetch_message(&self, id: &str) -> Result<Message> {
24        let messages = self.messages.lock().await;
25        messages
26            .get(id)
27            .cloned()
28            .ok_or_else(|| create_error!(NotFound))
29    }
30
31    /// Fetch multiple messages by given query
32    async fn fetch_messages(&self, query: MessageQuery) -> Result<Vec<Message>> {
33        let messages = self.messages.lock().await;
34        let matched_messages = messages
35            .values()
36            .filter(|message| {
37                if let Some(channel) = &query.filter.channel {
38                    if &message.channel != channel {
39                        return false;
40                    }
41                }
42
43                if let Some(author) = &query.filter.author {
44                    if &message.author != author {
45                        return false;
46                    }
47                }
48
49                if let Some(query) = &query.filter.query {
50                    if let Some(content) = &message.content {
51                        if !content.to_lowercase().contains(query) {
52                            return false;
53                        }
54                    } else {
55                        return false;
56                    }
57                }
58
59                if let Some(pinned) = query.filter.pinned {
60                    if message.pinned.unwrap_or_default() == pinned {
61                        return false
62                    }
63                }
64
65                true
66            })
67            .cloned()
68            .collect();
69
70        // FIXME: sorting, etc (will be required for tests)
71
72        Ok(matched_messages)
73
74        /*
75        // 2. Find query limit
76        let limit = query.limit.unwrap_or(50);
77
78        // 3. Apply message time period
79        match query.time_period {
80            MessageTimePeriod::Relative { nearby } => {
81                // 3.1. Prepare filters
82                let mut older_message_filter = filter.clone();
83                let mut newer_message_filter = filter;
84
85                older_message_filter.insert(
86                    "_id",
87                    doc! {
88                        "$lt": &nearby
89                    },
90                );
91
92                newer_message_filter.insert(
93                    "_id",
94                    doc! {
95                        "$gte": &nearby
96                    },
97                );
98
99                // 3.2. Execute in both directions
100                let (a, b) = try_join!(
101                    self.find_with_options::<_, Message>(
102                        COL,
103                        newer_message_filter,
104                        FindOptions::builder()
105                            .limit(limit / 2 + 1)
106                            .sort(doc! {
107                                "_id": 1_i32
108                            })
109                            .build(),
110                    ),
111                    self.find_with_options::<_, Message>(
112                        COL,
113                        older_message_filter,
114                        FindOptions::builder()
115                            .limit(limit / 2)
116                            .sort(doc! {
117                                "_id": -1_i32
118                            })
119                            .build(),
120                    )
121                )
122                .map_err(|_| create_database_error!("find", COL))?;
123
124                Ok([a, b].concat())
125            }
126            MessageTimePeriod::Absolute {
127                before,
128                after,
129                sort,
130            } => {
131                // 3.1. Apply message ID filter
132                if let Some(doc) = match (before, after) {
133                    (Some(before), Some(after)) => Some(doc! {
134                        "$lt": before,
135                        "$gt": after
136                    }),
137                    (Some(before), _) => Some(doc! {
138                        "$lt": before
139                    }),
140                    (_, Some(after)) => Some(doc! {
141                        "$gt": after
142                    }),
143                    _ => None,
144                } {
145                    filter.insert("_id", doc);
146                }
147
148                // 3.2. Execute with given message sort
149                self.find_with_options(
150                    COL,
151                    filter,
152                    FindOptions::builder()
153                        .limit(limit)
154                        .sort(match sort.unwrap_or(MessageSort::Latest) {
155                            // Sort by relevance, fallback to latest
156                            MessageSort::Relevance => {
157                                if is_search_query {
158                                    doc! {
159                                        "score": {
160                                            "$meta": "textScore"
161                                        }
162                                    }
163                                } else {
164                                    doc! {
165                                        "_id": -1_i32
166                                    }
167                                }
168                            }
169                            // Sort by latest first
170                            MessageSort::Latest => doc! {
171                                "_id": -1_i32
172                            },
173                            // Sort by oldest first
174                            MessageSort::Oldest => doc! {
175                                "_id": 1_i32
176                            },
177                        })
178                        .build(),
179                )
180                .await
181                .map_err(|_| create_database_error!("find", COL))
182            }
183        }*/
184    }
185
186    /// Fetch multiple messages by given IDs
187    async fn fetch_messages_by_id(&self, ids: &[String]) -> Result<Vec<Message>> {
188        try_join_all(ids.iter().map(|id| self.fetch_message(id))).await
189    }
190
191    /// Update a given message with new information
192    async fn update_message(&self, id: &str, message: &PartialMessage, remove: Vec<FieldsMessage>) -> Result<()> {
193        let mut messages = self.messages.lock().await;
194        if let Some(message_data) = messages.get_mut(id) {
195            message_data.apply_options(message.to_owned());
196
197            for field in remove {
198                #[allow(clippy::disallowed_methods)]
199                message_data.remove_field(&field);
200            }
201            Ok(())
202        } else {
203            Err(create_error!(NotFound))
204        }
205    }
206
207    /// Append information to a given message
208    async fn append_message(&self, id: &str, append: &AppendMessage) -> Result<()> {
209        let mut messages = self.messages.lock().await;
210        if let Some(message_data) = messages.get_mut(id) {
211            if let Some(embeds) = &append.embeds {
212                if !embeds.is_empty() {
213                    if let Some(embeds_data) = &mut message_data.embeds {
214                        embeds_data.extend(embeds.clone());
215                    } else {
216                        message_data.embeds = Some(embeds.clone());
217                    }
218                }
219            }
220
221            Ok(())
222        } else {
223            Err(create_error!(NotFound))
224        }
225    }
226
227    /// Add a new reaction to a message
228    async fn add_reaction(&self, id: &str, emoji: &str, user: &str) -> Result<()> {
229        let mut messages = self.messages.lock().await;
230        if let Some(message) = messages.get_mut(id) {
231            if let Some(users) = message.reactions.get_mut(emoji) {
232                users.insert(user.to_string());
233            } else {
234                message
235                    .reactions
236                    .insert(emoji.to_string(), IndexSet::from([user.to_string()]));
237            }
238
239            Ok(())
240        } else {
241            Err(create_error!(NotFound))
242        }
243    }
244
245    /// Remove a reaction from a message
246    async fn remove_reaction(&self, id: &str, emoji: &str, user: &str) -> Result<()> {
247        let mut messages = self.messages.lock().await;
248        if let Some(message) = messages.get_mut(id) {
249            if let Some(users) = message.reactions.get_mut(emoji) {
250                users.remove(&user.to_string());
251            }
252
253            Ok(())
254        } else {
255            Err(create_error!(NotFound))
256        }
257    }
258
259    /// Remove reaction from a message
260    async fn clear_reaction(&self, id: &str, emoji: &str) -> Result<()> {
261        let mut messages = self.messages.lock().await;
262        if let Some(message) = messages.get_mut(id) {
263            message.reactions.remove(emoji);
264            Ok(())
265        } else {
266            Err(create_error!(NotFound))
267        }
268    }
269
270    /// Delete a message from the database by its id
271    async fn delete_message(&self, id: &str) -> Result<()> {
272        let mut messages = self.messages.lock().await;
273        if messages.remove(id).is_some() {
274            Ok(())
275        } else {
276            Err(create_error!(NotFound))
277        }
278    }
279
280    /// Delete messages from a channel by their ids and corresponding channel id
281    async fn delete_messages(&self, channel: &str, ids: &[String]) -> Result<()> {
282        self.messages
283            .lock()
284            .await
285            .retain(|id, message| message.channel != channel && !ids.contains(id));
286
287        Ok(())
288    }
289}