revolt_database/tasks/ack.rs

1// Queue Type: Debounced
2use crate::{Database, Message, AMQP};
3
4use deadqueue::limited::Queue;
5use once_cell::sync::Lazy;
6use revolt_config::capture_message;
7use revolt_models::v0::PushNotification;
8use std::{
9    collections::{HashMap, HashSet},
10    time::Duration,
11};
12use validator::HasLen;
13
14use revolt_result::Result;
15
16use super::DelayedTask;
17use crate::Channel::{TextChannel, VoiceChannel};
18
19/// Enumeration of possible events
20#[derive(Debug, Eq, PartialEq)]
21pub enum AckEvent {
22    /// Add mentions for a channel
23    ProcessMessage {
24        /// push notification, message, recipients, push silenced
25        messages: Vec<(Option<PushNotification>, Message, Vec<String>, bool)>,
26    },
27
28    /// Acknowledge message in a channel for a user
29    AckMessage {
30        /// Message ID
31        id: String,
32    },
33}
34
35/// Task information
36struct Data {
37    /// Channel to ack
38    channel: String,
39    /// User to ack for
40    user: Option<String>,
41    /// Event
42    event: AckEvent,
43}
44
45#[derive(Debug)]
46struct Task {
47    event: AckEvent,
48}
49
50static Q: Lazy<Queue<Data>> = Lazy::new(|| Queue::new(10_000));
51
52/// Queue a new task for a worker
53pub async fn queue_ack(channel: String, user: String, event: AckEvent) {
54    Q.try_push(Data {
55        channel,
56        user: Some(user),
57        event,
58    })
59    .ok();
60
61    info!(
62        "Queue is using {} slots from {}. Queued type: ACK",
63        Q.len(),
64        Q.capacity()
65    );
66}
67
68/// Do not add more than one message per event.
69pub async fn queue_message(channel: String, event: AckEvent) {
70    Q.try_push(Data {
71        channel,
72        user: None,
73        event,
74    })
75    .ok();
76
77    info!(
78        "Queue is using {} slots from {}. Queued type: MENTION",
79        Q.len(),
80        Q.capacity()
81    );
82}
83
84pub async fn handle_ack_event(
85    event: &AckEvent,
86    db: &Database,
87    amqp: &AMQP,
88    user: &Option<String>,
89    channel: &str,
90) -> Result<()> {
91    match &event {
92        #[allow(clippy::disallowed_methods)] // event is sent by higher level function
93        AckEvent::AckMessage { id } => {
94            let user = user.as_ref().unwrap();
95            let user: &str = user.as_str();
96
97            let unread = db.fetch_unread(user, channel).await?;
98            let updated = db.acknowledge_message(channel, user, id).await?;
99
100            if let (Some(before), Some(after)) = (unread, updated) {
101                let before_mentions = before.mentions.unwrap_or_default().len();
102                let after_mentions = after.mentions.unwrap_or_default().len();
103
104                let mentions_acked = before_mentions - after_mentions;
105
106                if mentions_acked > 0 {
107                    if let Err(err) = amqp
108                        .ack_message(user.to_string(), channel.to_string(), id.to_owned())
109                        .await
110                    {
111                        revolt_config::capture_error(&err);
112                    }
113                };
114            }
115        }
116        AckEvent::ProcessMessage { messages } => {
117            let mut users: HashSet<&String> = HashSet::new();
118            info!(
119                "Processing {} messages from channel {}",
120                messages.len(),
121                messages[0].1.channel
122            );
123
124            // find all the users we'll be notifying
125            messages.iter().for_each(|(_, _, recipents, _)| {
126                users.extend(recipents.iter());
127            });
128
129            info!("Found {} users to notify.", users.len());
130
131            for user in users {
132                let message_ids: Vec<String> = messages
133                    .iter()
134                    .filter_map(|(_, message, recipients, _)| {
135                        if recipients.contains(user) {
136                            Some(message.id.clone())
137                        } else {
138                            None
139                        }
140                    })
141                    .collect();
142
143                if !message_ids.is_empty() {
144                    db.add_mention_to_unread(channel, user, &message_ids)
145                        .await?;
146                }
147                info!("Added {} mentions for user {}", message_ids.len(), &user);
148            }
149
150            let mut mass_mentions = vec![];
151
152            for (push, message, recipients, silenced) in messages {
153                if *silenced
154                    || push.is_none()
155                    || (recipients.is_empty() && !message.contains_mass_push_mention())
156                {
157                    debug!(
158                        "Rejecting push: silenced: {}, recipient count: {}, push exists: {:?}",
159                        *silenced,
160                        recipients.length(),
161                        push.is_some()
162                    );
163                    continue;
164                }
165
166                debug!(
167                    "Sending push event to AMQP; message {} for {} users",
168                    push.as_ref().unwrap().message.id,
169                    recipients.len()
170                );
171                if let Err(err) = amqp
172                    .message_sent(recipients.clone(), push.clone().unwrap())
173                    .await
174                {
175                    revolt_config::capture_error(&err);
176                }
177
178                if message.contains_mass_push_mention() {
179                    mass_mentions.push(push.clone().unwrap());
180                }
181            }
182
183            if !mass_mentions.is_empty() {
184                debug!(
185                    "Sending mass mention push event to AMQP; channel {}",
186                    &mass_mentions[0].message.channel
187                );
188
189                let channel = db
190                    .fetch_channel(&mass_mentions[0].message.channel)
191                    .await
192                    .expect("Failed to fetch channel from db");
193
194                match channel {
195                    TextChannel { server, .. } | VoiceChannel { server, .. } => {
196                        if let Err(err) =
197                            amqp.mass_mention_message_sent(server, mass_mentions).await
198                        {
199                            revolt_config::capture_error(&err);
200                        }
201                    }
202                    _ => {
203                        panic!("Unknown channel type when sending mass mention event");
204                    }
205                }
206            }
207        }
208    };
209
210    Ok(())
211}
212
213/// Start a new worker
214pub async fn worker(db: Database, amqp: AMQP) {
215    let mut tasks = HashMap::<(Option<String>, String, u8), DelayedTask<Task>>::new();
216    let mut keys: Vec<(Option<String>, String, u8)> = vec![];
217
218    loop {
219        // Find due tasks.
220        for (key, task) in &tasks {
221            if task.should_run() {
222                keys.push(key.clone());
223            }
224        }
225
226        // Commit any due tasks to the database.
227        for key in &keys {
228            if let Some(task) = tasks.remove(key) {
229                let Task { event } = task.data;
230                let (user, channel, _) = key;
231
232                if let Err(err) = handle_ack_event(&event, &db, &amqp, user, channel).await {
233                    revolt_config::capture_error(&err);
234                    error!("{err:?} for {event:?}. ({user:?}, {channel})");
235                } else {
236                    info!("User {user:?} ack in {channel} with {event:?}");
237                }
238            }
239        }
240
241        // Clear keys
242        keys.clear();
243
244        // Queue incoming tasks.
245        while let Some(Data {
246            channel,
247            user,
248            mut event,
249        }) = Q.try_pop()
250        {
251            info!("Took next ack from queue, now {} remaining", Q.len());
252
253            let key: (Option<String>, String, u8) = (
254                user,
255                channel,
256                match &event {
257                    AckEvent::AckMessage { .. } => 0,
258                    AckEvent::ProcessMessage { .. } => 1,
259                },
260            );
261            if let Some(task) = tasks.get_mut(&key) {
262                match &mut event {
263                    AckEvent::ProcessMessage { messages: new_data } => {
264                        if let AckEvent::ProcessMessage { messages: existing } =
265                            &mut task.data.event
266                        {
267                            if let Some(new_event) = new_data.pop() {
268                                // if the message contains a mass mention, do not delay it any further.
269                                if new_event.1.contains_mass_push_mention() {
270                                    // add the new message to the list of messages to be processed.
271                                    existing.push(new_event);
272                                    task.run_immediately();
273                                    continue;
274                                }
275
276                                existing.push(new_event);
277
278                                // put a cap on the amount of messages that can be queued, for particularly active channels
279                                if (existing.length() as u16)
280                                    < revolt_config::config()
281                                        .await
282                                        .features
283                                        .advanced
284                                        .process_message_delay_limit
285                                {
286                                    task.delay();
287                                }
288                            } else {
289                                let err_msg = format!("Got zero-length message event: {event:?}");
290                                capture_message(&err_msg, revolt_config::Level::Warning);
291                                info!("{err_msg}")
292                            }
293                        } else {
294                            panic!("Somehow got an ack message in the add mention arm");
295                        }
296                    }
297                    AckEvent::AckMessage { .. } => {
298                        // replace the last acked message with the new acked message
299                        task.data.event = event;
300                        task.delay();
301                    }
302                }
303            } else {
304                tasks.insert(key, DelayedTask::new(Task { event }));
305            }
306        }
307
308        // Sleep for an arbitrary amount of time.
309        async_std::task::sleep(Duration::from_secs(1)).await;
310    }
311}