Skip to main content

revolt_database/tasks/
ack.rs

1// Queue Type: Debounced
2use crate::{Database, Message, AMQP};
3
4use deadqueue::limited::Queue;
5use once_cell::sync::Lazy;
6use revolt_config::capture_message;
7use revolt_models::v0::PushNotification;
8use std::{
9    collections::{HashMap, HashSet},
10    time::Duration,
11};
12use validator::HasLen;
13
14use revolt_result::Result;
15
16use super::DelayedTask;
17use crate::Channel::TextChannel;
18
19/// Enumeration of possible events
20#[derive(Debug, Eq, PartialEq)]
21pub enum AckEvent {
22    /// Add mentions for a channel
23    ProcessMessage {
24        /// push notification, message, recipients, push silenced
25        messages: Vec<(Option<PushNotification>, Message, Vec<String>, bool)>,
26    },
27
28    /// Acknowledge message in a channel for a user
29    AckMessage {
30        /// Message ID
31        id: String,
32    },
33}
34
35/// Task information
36struct Data {
37    /// Channel to ack
38    channel: String,
39    /// User to ack for
40    user: Option<String>,
41    /// Event
42    event: AckEvent,
43}
44
45#[derive(Debug)]
46struct Task {
47    event: AckEvent,
48}
49
50static Q: Lazy<Queue<Data>> = Lazy::new(|| Queue::new(10_000));
51
52/// Queue a new task for a worker
53pub async fn queue_ack(channel: String, user: String, event: AckEvent) {
54    Q.try_push(Data {
55        channel,
56        user: Some(user),
57        event,
58    })
59    .ok();
60
61    info!(
62        "Queue is using {} slots from {}. Queued type: ACK",
63        Q.len(),
64        Q.capacity()
65    );
66}
67
68/// Do not add more than one message per event.
69pub async fn queue_message(channel: String, event: AckEvent) {
70    Q.try_push(Data {
71        channel,
72        user: None,
73        event,
74    })
75    .ok();
76
77    info!(
78        "Queue is using {} slots from {}. Queued type: MENTION",
79        Q.len(),
80        Q.capacity()
81    );
82}
83
84pub async fn handle_ack_event(
85    event: &AckEvent,
86    db: &Database,
87    amqp: &AMQP,
88    user: &Option<String>,
89    channel: &str,
90) -> Result<()> {
91    match &event {
92        #[allow(clippy::disallowed_methods)] // event is sent by higher level function
93        AckEvent::AckMessage { id } => {
94            let user = user.as_ref().unwrap();
95            let user: &str = user.as_str();
96
97            let unread = db.fetch_unread(user, channel).await?;
98            let updated = db.acknowledge_message(channel, user, id).await?;
99
100            if let (Some(before), Some(after)) = (unread, updated) {
101                let before_mentions = before.mentions.unwrap_or_default().len();
102                let after_mentions = after.mentions.unwrap_or_default().len();
103
104                let mentions_acked = before_mentions - after_mentions;
105
106                if mentions_acked > 0 {
107                    if let Err(err) = amqp
108                        .ack_notification_message(
109                            user.to_string(),
110                            channel.to_string(),
111                            id.to_owned(),
112                        )
113                        .await
114                    {
115                        revolt_config::capture_error(&err);
116                    }
117                };
118            }
119        }
120        AckEvent::ProcessMessage { messages } => {
121            let mut users: HashSet<&String> = HashSet::new();
122            info!(
123                "Processing {} messages from channel {}",
124                messages.len(),
125                messages[0].1.channel
126            );
127
128            // find all the users we'll be notifying
129            messages.iter().for_each(|(_, _, recipents, _)| {
130                users.extend(recipents.iter());
131            });
132
133            info!("Found {} users to notify.", users.len());
134
135            for user in users {
136                let message_ids: Vec<String> = messages
137                    .iter()
138                    .filter_map(|(_, message, recipients, _)| {
139                        if recipients.contains(user) {
140                            Some(message.id.clone())
141                        } else {
142                            None
143                        }
144                    })
145                    .collect();
146
147                if !message_ids.is_empty() {
148                    db.add_mention_to_unread(channel, user, &message_ids)
149                        .await?;
150                }
151                info!("Added {} mentions for user {}", message_ids.len(), &user);
152            }
153
154            let mut mass_mentions = vec![];
155
156            for (push, message, recipients, silenced) in messages {
157                if *silenced
158                    || push.is_none()
159                    || (recipients.is_empty() && !message.contains_mass_push_mention())
160                {
161                    debug!(
162                        "Rejecting push: silenced: {}, recipient count: {}, push exists: {:?}",
163                        *silenced,
164                        recipients.length(),
165                        push.is_some()
166                    );
167                    continue;
168                }
169
170                debug!(
171                    "Sending push event to AMQP; message {} for {} users",
172                    push.as_ref().unwrap().message.id,
173                    recipients.len()
174                );
175                if let Err(err) = amqp
176                    .message_sent(recipients.clone(), push.clone().unwrap())
177                    .await
178                {
179                    revolt_config::capture_error(&err);
180                }
181
182                if message.contains_mass_push_mention() {
183                    mass_mentions.push(push.clone().unwrap());
184                }
185            }
186
187            if !mass_mentions.is_empty() {
188                debug!(
189                    "Sending mass mention push event to AMQP; channel {}",
190                    &mass_mentions[0].message.channel
191                );
192
193                let channel = db
194                    .fetch_channel(&mass_mentions[0].message.channel)
195                    .await
196                    .expect("Failed to fetch channel from db");
197
198                if let TextChannel { server, .. } = channel {
199                    if let Err(err) = amqp.mass_mention_message_sent(server, mass_mentions).await {
200                        revolt_config::capture_error(&err);
201                    }
202                } else {
203                    panic!("Unknown channel type when sending mass mention event");
204                }
205            }
206        }
207    };
208
209    Ok(())
210}
211
212/// Start a new worker
213pub async fn worker(db: Database, amqp: AMQP) {
214    let mut tasks = HashMap::<(Option<String>, String, u8), DelayedTask<Task>>::new();
215    let mut keys: Vec<(Option<String>, String, u8)> = vec![];
216
217    loop {
218        // Find due tasks.
219        for (key, task) in &tasks {
220            if task.should_run() {
221                keys.push(key.clone());
222            }
223        }
224
225        // Commit any due tasks to the database.
226        for key in &keys {
227            if let Some(task) = tasks.remove(key) {
228                let Task { event } = task.data;
229                let (user, channel, _) = key;
230
231                if let Err(err) = handle_ack_event(&event, &db, &amqp, user, channel).await {
232                    revolt_config::capture_error(&err);
233                    error!("{err:?} for {event:?}. ({user:?}, {channel})");
234                } else {
235                    info!("User {user:?} ack in {channel} with {event:?}");
236                }
237            }
238        }
239
240        // Clear keys
241        keys.clear();
242
243        // Queue incoming tasks.
244        while let Some(Data {
245            channel,
246            user,
247            mut event,
248        }) = Q.try_pop()
249        {
250            info!("Took next ack from queue, now {} remaining", Q.len());
251
252            let key: (Option<String>, String, u8) = (
253                user,
254                channel,
255                match &event {
256                    AckEvent::AckMessage { .. } => 0,
257                    AckEvent::ProcessMessage { .. } => 1,
258                },
259            );
260            if let Some(task) = tasks.get_mut(&key) {
261                match &mut event {
262                    AckEvent::ProcessMessage { messages: new_data } => {
263                        if let AckEvent::ProcessMessage { messages: existing } =
264                            &mut task.data.event
265                        {
266                            if let Some(new_event) = new_data.pop() {
267                                // if the message contains a mass mention, do not delay it any further.
268                                if new_event.1.contains_mass_push_mention() {
269                                    // add the new message to the list of messages to be processed.
270                                    existing.push(new_event);
271                                    task.run_immediately();
272                                    continue;
273                                }
274
275                                existing.push(new_event);
276
277                                // put a cap on the amount of messages that can be queued, for particularly active channels
278                                if (existing.length() as u16)
279                                    < revolt_config::config()
280                                        .await
281                                        .features
282                                        .advanced
283                                        .process_message_delay_limit
284                                {
285                                    task.delay();
286                                }
287                            } else {
288                                let err_msg = format!("Got zero-length message event: {event:?}");
289                                capture_message(&err_msg, revolt_config::Level::Warning);
290                                info!("{err_msg}")
291                            }
292                        } else {
293                            panic!("Somehow got an ack message in the add mention arm");
294                        }
295                    }
296                    AckEvent::AckMessage { .. } => {
297                        // replace the last acked message with the new acked message
298                        task.data.event = event;
299                        task.delay();
300                    }
301                }
302            } else {
303                tasks.insert(key, DelayedTask::new(Task { event }));
304            }
305        }
306
307        // Sleep for an arbitrary amount of time.
308        async_std::task::sleep(Duration::from_secs(1)).await;
309    }
310}