revolt_database/models/messages/ops/reference.rs
1use std::collections::HashMap;
2use futures::future::try_join_all;
3use indexmap::IndexSet;
4use revolt_result::Result;
5use std::time::SystemTime;
6use ulid::Ulid;
7use crate::{AppendMessage, FieldsMessage, Message, MessageQuery, PartialMessage, ReferenceDb};
8
9use super::AbstractMessages;
10
11#[async_trait]
12impl AbstractMessages for ReferenceDb {
13 /// Insert a new message into the database
14 async fn insert_message(&self, message: &Message) -> Result<()> {
15 let mut messages = self.messages.lock().await;
16 if messages.contains_key(&message.id) {
17 Err(create_database_error!("insert", "message"))
18 } else {
19 messages.insert(message.id.to_string(), message.clone());
20 Ok(())
21 }
22 }
23
24 /// Fetch a message by its id
25 async fn fetch_message(&self, id: &str) -> Result<Message> {
26 let messages = self.messages.lock().await;
27 messages
28 .get(id)
29 .cloned()
30 .ok_or_else(|| create_error!(NotFound))
31 }
32
33 /// Fetch multiple messages by given query
34 async fn fetch_messages(&self, query: MessageQuery) -> Result<Vec<Message>> {
35 let messages = self.messages.lock().await;
36 let matched_messages = messages
37 .values()
38 .filter(|message| {
39 if let Some(channel) = &query.filter.channel {
40 if &message.channel != channel {
41 return false;
42 }
43 }
44
45 if let Some(author) = &query.filter.author {
46 if &message.author != author {
47 return false;
48 }
49 }
50
51 if let Some(query) = &query.filter.query {
52 if let Some(content) = &message.content {
53 if !content.to_lowercase().contains(query) {
54 return false;
55 }
56 } else {
57 return false;
58 }
59 }
60
61 if let Some(pinned) = query.filter.pinned {
62 if message.pinned.unwrap_or_default() == pinned {
63 return false
64 }
65 }
66
67 true
68 })
69 .cloned()
70 .collect();
71
72 // FIXME: sorting, etc (will be required for tests)
73
74 Ok(matched_messages)
75
76 /*
77 // 2. Find query limit
78 let limit = query.limit.unwrap_or(50);
79
80 // 3. Apply message time period
81 match query.time_period {
82 MessageTimePeriod::Relative { nearby } => {
83 // 3.1. Prepare filters
84 let mut older_message_filter = filter.clone();
85 let mut newer_message_filter = filter;
86
87 older_message_filter.insert(
88 "_id",
89 doc! {
90 "$lt": &nearby
91 },
92 );
93
94 newer_message_filter.insert(
95 "_id",
96 doc! {
97 "$gte": &nearby
98 },
99 );
100
101 // 3.2. Execute in both directions
102 let (a, b) = try_join!(
103 self.find_with_options::<_, Message>(
104 COL,
105 newer_message_filter,
106 FindOptions::builder()
107 .limit(limit / 2 + 1)
108 .sort(doc! {
109 "_id": 1_i32
110 })
111 .build(),
112 ),
113 self.find_with_options::<_, Message>(
114 COL,
115 older_message_filter,
116 FindOptions::builder()
117 .limit(limit / 2)
118 .sort(doc! {
119 "_id": -1_i32
120 })
121 .build(),
122 )
123 )
124 .map_err(|_| create_database_error!("find", COL))?;
125
126 Ok([a, b].concat())
127 }
128 MessageTimePeriod::Absolute {
129 before,
130 after,
131 sort,
132 } => {
133 // 3.1. Apply message ID filter
134 if let Some(doc) = match (before, after) {
135 (Some(before), Some(after)) => Some(doc! {
136 "$lt": before,
137 "$gt": after
138 }),
139 (Some(before), _) => Some(doc! {
140 "$lt": before
141 }),
142 (_, Some(after)) => Some(doc! {
143 "$gt": after
144 }),
145 _ => None,
146 } {
147 filter.insert("_id", doc);
148 }
149
150 // 3.2. Execute with given message sort
151 self.find_with_options(
152 COL,
153 filter,
154 FindOptions::builder()
155 .limit(limit)
156 .sort(match sort.unwrap_or(MessageSort::Latest) {
157 // Sort by relevance, fallback to latest
158 MessageSort::Relevance => {
159 if is_search_query {
160 doc! {
161 "score": {
162 "$meta": "textScore"
163 }
164 }
165 } else {
166 doc! {
167 "_id": -1_i32
168 }
169 }
170 }
171 // Sort by latest first
172 MessageSort::Latest => doc! {
173 "_id": -1_i32
174 },
175 // Sort by oldest first
176 MessageSort::Oldest => doc! {
177 "_id": 1_i32
178 },
179 })
180 .build(),
181 )
182 .await
183 .map_err(|_| create_database_error!("find", COL))
184 }
185 }*/
186 }
187
188 /// Fetch multiple messages by given IDs
189 async fn fetch_messages_by_id(&self, ids: &[String]) -> Result<Vec<Message>> {
190 try_join_all(ids.iter().map(|id| self.fetch_message(id))).await
191 }
192
193 /// Update a given message with new information
194 async fn update_message(&self, id: &str, message: &PartialMessage, remove: Vec<FieldsMessage>) -> Result<()> {
195 let mut messages = self.messages.lock().await;
196 if let Some(message_data) = messages.get_mut(id) {
197 message_data.apply_options(message.to_owned());
198
199 for field in remove {
200 #[allow(clippy::disallowed_methods)]
201 message_data.remove_field(&field);
202 }
203 Ok(())
204 } else {
205 Err(create_error!(NotFound))
206 }
207 }
208
209 /// Append information to a given message
210 async fn append_message(&self, id: &str, append: &AppendMessage) -> Result<()> {
211 let mut messages = self.messages.lock().await;
212 if let Some(message_data) = messages.get_mut(id) {
213 if let Some(embeds) = &append.embeds {
214 if !embeds.is_empty() {
215 if let Some(embeds_data) = &mut message_data.embeds {
216 embeds_data.extend(embeds.clone());
217 } else {
218 message_data.embeds = Some(embeds.clone());
219 }
220 }
221 }
222
223 Ok(())
224 } else {
225 Err(create_error!(NotFound))
226 }
227 }
228
229 /// Add a new reaction to a message
230 async fn add_reaction(&self, id: &str, emoji: &str, user: &str) -> Result<()> {
231 let mut messages = self.messages.lock().await;
232 if let Some(message) = messages.get_mut(id) {
233 if let Some(users) = message.reactions.get_mut(emoji) {
234 users.insert(user.to_string());
235 } else {
236 message
237 .reactions
238 .insert(emoji.to_string(), IndexSet::from([user.to_string()]));
239 }
240
241 Ok(())
242 } else {
243 Err(create_error!(NotFound))
244 }
245 }
246
247 /// Remove a reaction from a message
248 async fn remove_reaction(&self, id: &str, emoji: &str, user: &str) -> Result<()> {
249 let mut messages = self.messages.lock().await;
250 if let Some(message) = messages.get_mut(id) {
251 if let Some(users) = message.reactions.get_mut(emoji) {
252 users.swap_remove(&user.to_string());
253 }
254
255 Ok(())
256 } else {
257 Err(create_error!(NotFound))
258 }
259 }
260
261 /// Remove reaction from a message
262 async fn clear_reaction(&self, id: &str, emoji: &str) -> Result<()> {
263 let mut messages = self.messages.lock().await;
264 if let Some(message) = messages.get_mut(id) {
265 message.reactions.swap_remove(emoji);
266 Ok(())
267 } else {
268 Err(create_error!(NotFound))
269 }
270 }
271
272 /// Delete a message from the database by its id
273 async fn delete_message(&self, id: &str) -> Result<()> {
274 let mut messages = self.messages.lock().await;
275 if messages.remove(id).is_some() {
276 Ok(())
277 } else {
278 Err(create_error!(NotFound))
279 }
280 }
281
282 /// Delete messages from a channel by their ids and corresponding channel id
283 async fn delete_messages(&self, channel: &str, ids: &[String]) -> Result<()> {
284 self.messages
285 .lock()
286 .await
287 .retain(|id, message| message.channel != channel && !ids.contains(id));
288
289 Ok(())
290 }
291
292 /// Delete all messages from a specific author in a list of channels from a certain ULID onwards
293 async fn delete_messages_by_author_since(
294 &self,
295 channels: &[String],
296 author: &str,
297 since: SystemTime
298 ) -> Result<HashMap<String, Vec<String>>> {
299 let threshold_ulid = Ulid::from_datetime(since).to_string();
300 let mut deleted_messages: HashMap<String, Vec<String>> = HashMap::new();
301 let mut attachment_ids: Vec<String> = Vec::new();
302
303 let messages = self.messages.lock().await;
304
305 // First pass: collect attachment IDs and message IDs to delete
306 for (id, message) in messages.iter() {
307 let should_delete = message.author == author
308 && channels.contains(&message.channel)
309 && id.as_str() >= threshold_ulid.as_str();
310
311 if should_delete {
312 // Collect attachment IDs
313 if let Some(attachments) = &message.attachments {
314 for attachment in attachments {
315 attachment_ids.push(attachment.id.clone());
316 }
317 }
318
319 deleted_messages
320 .entry(message.channel.clone())
321 .or_default()
322 .push(id.clone());
323 }
324 }
325 drop(messages);
326
327 // Mark attachments as deleted
328 if !attachment_ids.is_empty() {
329 let mut files = self.files.lock().await;
330 for attachment_id in attachment_ids {
331 if let Some(file) = files.get_mut(&attachment_id) {
332 file.deleted = Some(true);
333 }
334 }
335 }
336
337 // Delete the messages
338 self.messages
339 .lock()
340 .await
341 .retain(|id, message| {
342 let should_keep = !(message.author == author
343 && channels.contains(&message.channel)
344 && id.as_str() >= threshold_ulid.as_str());
345 should_keep
346 });
347
348 Ok(deleted_messages)
349 }
350}