1use crate::{Database, Message, AMQP};
3
4use deadqueue::limited::Queue;
5use once_cell::sync::Lazy;
6use revolt_config::capture_message;
7use revolt_models::v0::PushNotification;
8use std::{
9 collections::{HashMap, HashSet},
10 time::Duration,
11};
12use validator::HasLen;
13
14use revolt_result::Result;
15
16use super::DelayedTask;
17use crate::Channel::{TextChannel, VoiceChannel};
18
19#[derive(Debug, Eq, PartialEq)]
21pub enum AckEvent {
22 ProcessMessage {
24 messages: Vec<(Option<PushNotification>, Message, Vec<String>, bool)>,
26 },
27
28 AckMessage {
30 id: String,
32 },
33}
34
35struct Data {
37 channel: String,
39 user: Option<String>,
41 event: AckEvent,
43}
44
45#[derive(Debug)]
46struct Task {
47 event: AckEvent,
48}
49
50static Q: Lazy<Queue<Data>> = Lazy::new(|| Queue::new(10_000));
51
52pub async fn queue_ack(channel: String, user: String, event: AckEvent) {
54 Q.try_push(Data {
55 channel,
56 user: Some(user),
57 event,
58 })
59 .ok();
60
61 info!(
62 "Queue is using {} slots from {}. Queued type: ACK",
63 Q.len(),
64 Q.capacity()
65 );
66}
67
68pub async fn queue_message(channel: String, event: AckEvent) {
70 Q.try_push(Data {
71 channel,
72 user: None,
73 event,
74 })
75 .ok();
76
77 info!(
78 "Queue is using {} slots from {}. Queued type: MENTION",
79 Q.len(),
80 Q.capacity()
81 );
82}
83
84pub async fn handle_ack_event(
85 event: &AckEvent,
86 db: &Database,
87 amqp: &AMQP,
88 user: &Option<String>,
89 channel: &str,
90) -> Result<()> {
91 match &event {
92 #[allow(clippy::disallowed_methods)] AckEvent::AckMessage { id } => {
94 let user = user.as_ref().unwrap();
95 let user: &str = user.as_str();
96
97 let unread = db.fetch_unread(user, channel).await?;
98 let updated = db.acknowledge_message(channel, user, id).await?;
99
100 if let (Some(before), Some(after)) = (unread, updated) {
101 let before_mentions = before.mentions.unwrap_or_default().len();
102 let after_mentions = after.mentions.unwrap_or_default().len();
103
104 let mentions_acked = before_mentions - after_mentions;
105
106 if mentions_acked > 0 {
107 if let Err(err) = amqp
108 .ack_message(user.to_string(), channel.to_string(), id.to_owned())
109 .await
110 {
111 revolt_config::capture_error(&err);
112 }
113 };
114 }
115 }
116 AckEvent::ProcessMessage { messages } => {
117 let mut users: HashSet<&String> = HashSet::new();
118 info!(
119 "Processing {} messages from channel {}",
120 messages.len(),
121 messages[0].1.channel
122 );
123
124 messages.iter().for_each(|(_, _, recipents, _)| {
126 users.extend(recipents.iter());
127 });
128
129 info!("Found {} users to notify.", users.len());
130
131 for user in users {
132 let message_ids: Vec<String> = messages
133 .iter()
134 .filter_map(|(_, message, recipients, _)| {
135 if recipients.contains(user) {
136 Some(message.id.clone())
137 } else {
138 None
139 }
140 })
141 .collect();
142
143 if !message_ids.is_empty() {
144 db.add_mention_to_unread(channel, user, &message_ids)
145 .await?;
146 }
147 info!("Added {} mentions for user {}", message_ids.len(), &user);
148 }
149
150 let mut mass_mentions = vec![];
151
152 for (push, message, recipients, silenced) in messages {
153 if *silenced
154 || push.is_none()
155 || (recipients.is_empty() && !message.contains_mass_push_mention())
156 {
157 debug!(
158 "Rejecting push: silenced: {}, recipient count: {}, push exists: {:?}",
159 *silenced,
160 recipients.length(),
161 push.is_some()
162 );
163 continue;
164 }
165
166 debug!(
167 "Sending push event to AMQP; message {} for {} users",
168 push.as_ref().unwrap().message.id,
169 recipients.len()
170 );
171 if let Err(err) = amqp
172 .message_sent(recipients.clone(), push.clone().unwrap())
173 .await
174 {
175 revolt_config::capture_error(&err);
176 }
177
178 if message.contains_mass_push_mention() {
179 mass_mentions.push(push.clone().unwrap());
180 }
181 }
182
183 if !mass_mentions.is_empty() {
184 debug!(
185 "Sending mass mention push event to AMQP; channel {}",
186 &mass_mentions[0].message.channel
187 );
188
189 let channel = db
190 .fetch_channel(&mass_mentions[0].message.channel)
191 .await
192 .expect("Failed to fetch channel from db");
193
194 match channel {
195 TextChannel { server, .. } | VoiceChannel { server, .. } => {
196 if let Err(err) =
197 amqp.mass_mention_message_sent(server, mass_mentions).await
198 {
199 revolt_config::capture_error(&err);
200 }
201 }
202 _ => {
203 panic!("Unknown channel type when sending mass mention event");
204 }
205 }
206 }
207 }
208 };
209
210 Ok(())
211}
212
213pub async fn worker(db: Database, amqp: AMQP) {
215 let mut tasks = HashMap::<(Option<String>, String, u8), DelayedTask<Task>>::new();
216 let mut keys: Vec<(Option<String>, String, u8)> = vec![];
217
218 loop {
219 for (key, task) in &tasks {
221 if task.should_run() {
222 keys.push(key.clone());
223 }
224 }
225
226 for key in &keys {
228 if let Some(task) = tasks.remove(key) {
229 let Task { event } = task.data;
230 let (user, channel, _) = key;
231
232 if let Err(err) = handle_ack_event(&event, &db, &amqp, user, channel).await {
233 revolt_config::capture_error(&err);
234 error!("{err:?} for {event:?}. ({user:?}, {channel})");
235 } else {
236 info!("User {user:?} ack in {channel} with {event:?}");
237 }
238 }
239 }
240
241 keys.clear();
243
244 while let Some(Data {
246 channel,
247 user,
248 mut event,
249 }) = Q.try_pop()
250 {
251 info!("Took next ack from queue, now {} remaining", Q.len());
252
253 let key: (Option<String>, String, u8) = (
254 user,
255 channel,
256 match &event {
257 AckEvent::AckMessage { .. } => 0,
258 AckEvent::ProcessMessage { .. } => 1,
259 },
260 );
261 if let Some(task) = tasks.get_mut(&key) {
262 match &mut event {
263 AckEvent::ProcessMessage { messages: new_data } => {
264 if let AckEvent::ProcessMessage { messages: existing } =
265 &mut task.data.event
266 {
267 if let Some(new_event) = new_data.pop() {
268 if new_event.1.contains_mass_push_mention() {
270 existing.push(new_event);
272 task.run_immediately();
273 continue;
274 }
275
276 existing.push(new_event);
277
278 if (existing.length() as u16)
280 < revolt_config::config()
281 .await
282 .features
283 .advanced
284 .process_message_delay_limit
285 {
286 task.delay();
287 }
288 } else {
289 let err_msg = format!("Got zero-length message event: {event:?}");
290 capture_message(&err_msg, revolt_config::Level::Warning);
291 info!("{err_msg}")
292 }
293 } else {
294 panic!("Somehow got an ack message in the add mention arm");
295 }
296 }
297 AckEvent::AckMessage { .. } => {
298 task.data.event = event;
300 task.delay();
301 }
302 }
303 } else {
304 tasks.insert(key, DelayedTask::new(Task { event }));
305 }
306 }
307
308 async_std::task::sleep(Duration::from_secs(1)).await;
310 }
311}