1use crate::{Database, Message, AMQP};
3
4use deadqueue::limited::Queue;
5use once_cell::sync::Lazy;
6use revolt_config::capture_message;
7use revolt_models::v0::PushNotification;
8use std::{
9 collections::{HashMap, HashSet},
10 time::Duration,
11};
12use validator::HasLen;
13
14use revolt_result::Result;
15
16use super::DelayedTask;
17use crate::Channel::TextChannel;
18
19#[derive(Debug, Eq, PartialEq)]
21pub enum AckEvent {
22 ProcessMessage {
24 messages: Vec<(Option<PushNotification>, Message, Vec<String>, bool)>,
26 },
27
28 AckMessage {
30 id: String,
32 },
33}
34
35struct Data {
37 channel: String,
39 user: Option<String>,
41 event: AckEvent,
43}
44
45#[derive(Debug)]
46struct Task {
47 event: AckEvent,
48}
49
50static Q: Lazy<Queue<Data>> = Lazy::new(|| Queue::new(10_000));
51
52pub async fn queue_ack(channel: String, user: String, event: AckEvent) {
54 Q.try_push(Data {
55 channel,
56 user: Some(user),
57 event,
58 })
59 .ok();
60
61 info!(
62 "Queue is using {} slots from {}. Queued type: ACK",
63 Q.len(),
64 Q.capacity()
65 );
66}
67
68pub async fn queue_message(channel: String, event: AckEvent) {
70 Q.try_push(Data {
71 channel,
72 user: None,
73 event,
74 })
75 .ok();
76
77 info!(
78 "Queue is using {} slots from {}. Queued type: MENTION",
79 Q.len(),
80 Q.capacity()
81 );
82}
83
84pub async fn handle_ack_event(
85 event: &AckEvent,
86 db: &Database,
87 amqp: &AMQP,
88 user: &Option<String>,
89 channel: &str,
90) -> Result<()> {
91 match &event {
92 #[allow(clippy::disallowed_methods)] AckEvent::AckMessage { id } => {
94 let user = user.as_ref().unwrap();
95 let user: &str = user.as_str();
96
97 let unread = db.fetch_unread(user, channel).await?;
98 let updated = db.acknowledge_message(channel, user, id).await?;
99
100 if let (Some(before), Some(after)) = (unread, updated) {
101 let before_mentions = before.mentions.unwrap_or_default().len();
102 let after_mentions = after.mentions.unwrap_or_default().len();
103
104 let mentions_acked = before_mentions - after_mentions;
105
106 if mentions_acked > 0 {
107 if let Err(err) = amqp
108 .ack_message(user.to_string(), channel.to_string(), id.to_owned())
109 .await
110 {
111 revolt_config::capture_error(&err);
112 }
113 };
114 }
115 }
116 AckEvent::ProcessMessage { messages } => {
117 let mut users: HashSet<&String> = HashSet::new();
118 info!(
119 "Processing {} messages from channel {}",
120 messages.len(),
121 messages[0].1.channel
122 );
123
124 messages.iter().for_each(|(_, _, recipents, _)| {
126 users.extend(recipents.iter());
127 });
128
129 info!("Found {} users to notify.", users.len());
130
131 for user in users {
132 let message_ids: Vec<String> = messages
133 .iter()
134 .filter_map(|(_, message, recipients, _)| {
135 if recipients.contains(user) {
136 Some(message.id.clone())
137 } else {
138 None
139 }
140 })
141 .collect();
142
143 if !message_ids.is_empty() {
144 db.add_mention_to_unread(channel, user, &message_ids)
145 .await?;
146 }
147 info!("Added {} mentions for user {}", message_ids.len(), &user);
148 }
149
150 let mut mass_mentions = vec![];
151
152 for (push, message, recipients, silenced) in messages {
153 if *silenced
154 || push.is_none()
155 || (recipients.is_empty() && !message.contains_mass_push_mention())
156 {
157 debug!(
158 "Rejecting push: silenced: {}, recipient count: {}, push exists: {:?}",
159 *silenced,
160 recipients.length(),
161 push.is_some()
162 );
163 continue;
164 }
165
166 debug!(
167 "Sending push event to AMQP; message {} for {} users",
168 push.as_ref().unwrap().message.id,
169 recipients.len()
170 );
171 if let Err(err) = amqp
172 .message_sent(recipients.clone(), push.clone().unwrap())
173 .await
174 {
175 revolt_config::capture_error(&err);
176 }
177
178 if message.contains_mass_push_mention() {
179 mass_mentions.push(push.clone().unwrap());
180 }
181 }
182
183 if !mass_mentions.is_empty() {
184 debug!(
185 "Sending mass mention push event to AMQP; channel {}",
186 &mass_mentions[0].message.channel
187 );
188
189 let channel = db
190 .fetch_channel(&mass_mentions[0].message.channel)
191 .await
192 .expect("Failed to fetch channel from db");
193
194 if let TextChannel { server, .. } = channel {
195 if let Err(err) =
196 amqp.mass_mention_message_sent(server, mass_mentions).await
197 {
198 revolt_config::capture_error(&err);
199 }
200 } else {
201 panic!("Unknown channel type when sending mass mention event");
202 }
203 }
204 }
205 };
206
207 Ok(())
208}
209
210pub async fn worker(db: Database, amqp: AMQP) {
212 let mut tasks = HashMap::<(Option<String>, String, u8), DelayedTask<Task>>::new();
213 let mut keys: Vec<(Option<String>, String, u8)> = vec![];
214
215 loop {
216 for (key, task) in &tasks {
218 if task.should_run() {
219 keys.push(key.clone());
220 }
221 }
222
223 for key in &keys {
225 if let Some(task) = tasks.remove(key) {
226 let Task { event } = task.data;
227 let (user, channel, _) = key;
228
229 if let Err(err) = handle_ack_event(&event, &db, &amqp, user, channel).await {
230 revolt_config::capture_error(&err);
231 error!("{err:?} for {event:?}. ({user:?}, {channel})");
232 } else {
233 info!("User {user:?} ack in {channel} with {event:?}");
234 }
235 }
236 }
237
238 keys.clear();
240
241 while let Some(Data {
243 channel,
244 user,
245 mut event,
246 }) = Q.try_pop()
247 {
248 info!("Took next ack from queue, now {} remaining", Q.len());
249
250 let key: (Option<String>, String, u8) = (
251 user,
252 channel,
253 match &event {
254 AckEvent::AckMessage { .. } => 0,
255 AckEvent::ProcessMessage { .. } => 1,
256 },
257 );
258 if let Some(task) = tasks.get_mut(&key) {
259 match &mut event {
260 AckEvent::ProcessMessage { messages: new_data } => {
261 if let AckEvent::ProcessMessage { messages: existing } =
262 &mut task.data.event
263 {
264 if let Some(new_event) = new_data.pop() {
265 if new_event.1.contains_mass_push_mention() {
267 existing.push(new_event);
269 task.run_immediately();
270 continue;
271 }
272
273 existing.push(new_event);
274
275 if (existing.length() as u16)
277 < revolt_config::config()
278 .await
279 .features
280 .advanced
281 .process_message_delay_limit
282 {
283 task.delay();
284 }
285 } else {
286 let err_msg = format!("Got zero-length message event: {event:?}");
287 capture_message(&err_msg, revolt_config::Level::Warning);
288 info!("{err_msg}")
289 }
290 } else {
291 panic!("Somehow got an ack message in the add mention arm");
292 }
293 }
294 AckEvent::AckMessage { .. } => {
295 task.data.event = event;
297 task.delay();
298 }
299 }
300 } else {
301 tasks.insert(key, DelayedTask::new(Task { event }));
302 }
303 }
304
305 async_std::task::sleep(Duration::from_secs(1)).await;
307 }
308}