1use crate::{Database, Message, AMQP};
3
4use deadqueue::limited::Queue;
5use once_cell::sync::Lazy;
6use revolt_config::capture_message;
7use revolt_models::v0::PushNotification;
8use std::{
9 collections::{HashMap, HashSet},
10 time::Duration,
11};
12use validator::HasLen;
13
14use revolt_result::Result;
15
16use super::DelayedTask;
17use crate::Channel::TextChannel;
18
/// An event queued for, and executed by, the ack [`worker`].
#[derive(Debug, Eq, PartialEq)]
pub enum AckEvent {
    /// Record mentions for, and send push notifications about, a batch of
    /// messages sent to one channel.
    ProcessMessage {
        /// One entry per message: `(push payload, message, recipients, silenced)`.
        ///
        /// `recipients` are the users that get the message added to their
        /// unread mentions; no push is sent when `silenced` is true or the
        /// payload is `None`.
        messages: Vec<(Option<PushNotification>, Message, Vec<String>, bool)>,
    },

    /// Acknowledge message `id` in a channel on behalf of a user.
    AckMessage {
        /// ID of the message being acknowledged.
        id: String,
    },
}
34
/// A single item on the shared queue `Q`.
struct Data {
    /// Channel the event applies to.
    channel: String,
    /// User performing the ack; `None` for message-processing events
    /// (`queue_message` always pushes `None`, `queue_ack` always `Some`).
    user: Option<String>,
    /// The event itself.
    event: AckEvent,
}
44
/// Payload stored in a `DelayedTask` while the worker debounces events for
/// one `(user, channel, kind)` key.
#[derive(Debug)]
struct Task {
    /// The (possibly merged) event to run once the task is due.
    event: AckEvent,
}
49
/// Shared, bounded queue feeding the ack [`worker`].
///
/// Holds at most 10,000 items; producers use `try_push`, so events pushed
/// while the queue is full are not stored.
static Q: Lazy<Queue<Data>> = Lazy::new(|| Queue::new(10_000));
51
52pub async fn queue_ack(channel: String, user: String, event: AckEvent) {
54 Q.try_push(Data {
55 channel,
56 user: Some(user),
57 event,
58 })
59 .ok();
60
61 info!(
62 "Queue is using {} slots from {}. Queued type: ACK",
63 Q.len(),
64 Q.capacity()
65 );
66}
67
68pub async fn queue_message(channel: String, event: AckEvent) {
70 Q.try_push(Data {
71 channel,
72 user: None,
73 event,
74 })
75 .ok();
76
77 info!(
78 "Queue is using {} slots from {}. Queued type: MENTION",
79 Q.len(),
80 Q.capacity()
81 );
82}
83
84pub async fn handle_ack_event(
85 event: &AckEvent,
86 db: &Database,
87 amqp: &AMQP,
88 user: &Option<String>,
89 channel: &str,
90) -> Result<()> {
91 match &event {
92 #[allow(clippy::disallowed_methods)] AckEvent::AckMessage { id } => {
94 let user = user.as_ref().unwrap();
95 let user: &str = user.as_str();
96
97 let unread = db.fetch_unread(user, channel).await?;
98 let updated = db.acknowledge_message(channel, user, id).await?;
99
100 if let (Some(before), Some(after)) = (unread, updated) {
101 let before_mentions = before.mentions.unwrap_or_default().len();
102 let after_mentions = after.mentions.unwrap_or_default().len();
103
104 let mentions_acked = before_mentions - after_mentions;
105
106 if mentions_acked > 0 {
107 if let Err(err) = amqp
108 .ack_notification_message(
109 user.to_string(),
110 channel.to_string(),
111 id.to_owned(),
112 )
113 .await
114 {
115 revolt_config::capture_error(&err);
116 }
117 };
118 }
119 }
120 AckEvent::ProcessMessage { messages } => {
121 let mut users: HashSet<&String> = HashSet::new();
122 info!(
123 "Processing {} messages from channel {}",
124 messages.len(),
125 messages[0].1.channel
126 );
127
128 messages.iter().for_each(|(_, _, recipents, _)| {
130 users.extend(recipents.iter());
131 });
132
133 info!("Found {} users to notify.", users.len());
134
135 for user in users {
136 let message_ids: Vec<String> = messages
137 .iter()
138 .filter_map(|(_, message, recipients, _)| {
139 if recipients.contains(user) {
140 Some(message.id.clone())
141 } else {
142 None
143 }
144 })
145 .collect();
146
147 if !message_ids.is_empty() {
148 db.add_mention_to_unread(channel, user, &message_ids)
149 .await?;
150 }
151 info!("Added {} mentions for user {}", message_ids.len(), &user);
152 }
153
154 let mut mass_mentions = vec![];
155
156 for (push, message, recipients, silenced) in messages {
157 if *silenced
158 || push.is_none()
159 || (recipients.is_empty() && !message.contains_mass_push_mention())
160 {
161 debug!(
162 "Rejecting push: silenced: {}, recipient count: {}, push exists: {:?}",
163 *silenced,
164 recipients.length(),
165 push.is_some()
166 );
167 continue;
168 }
169
170 debug!(
171 "Sending push event to AMQP; message {} for {} users",
172 push.as_ref().unwrap().message.id,
173 recipients.len()
174 );
175 if let Err(err) = amqp
176 .message_sent(recipients.clone(), push.clone().unwrap())
177 .await
178 {
179 revolt_config::capture_error(&err);
180 }
181
182 if message.contains_mass_push_mention() {
183 mass_mentions.push(push.clone().unwrap());
184 }
185 }
186
187 if !mass_mentions.is_empty() {
188 debug!(
189 "Sending mass mention push event to AMQP; channel {}",
190 &mass_mentions[0].message.channel
191 );
192
193 let channel = db
194 .fetch_channel(&mass_mentions[0].message.channel)
195 .await
196 .expect("Failed to fetch channel from db");
197
198 if let TextChannel { server, .. } = channel {
199 if let Err(err) = amqp.mass_mention_message_sent(server, mass_mentions).await {
200 revolt_config::capture_error(&err);
201 }
202 } else {
203 panic!("Unknown channel type when sending mass mention event");
204 }
205 }
206 }
207 };
208
209 Ok(())
210}
211
212pub async fn worker(db: Database, amqp: AMQP) {
214 let mut tasks = HashMap::<(Option<String>, String, u8), DelayedTask<Task>>::new();
215 let mut keys: Vec<(Option<String>, String, u8)> = vec![];
216
217 loop {
218 for (key, task) in &tasks {
220 if task.should_run() {
221 keys.push(key.clone());
222 }
223 }
224
225 for key in &keys {
227 if let Some(task) = tasks.remove(key) {
228 let Task { event } = task.data;
229 let (user, channel, _) = key;
230
231 if let Err(err) = handle_ack_event(&event, &db, &amqp, user, channel).await {
232 revolt_config::capture_error(&err);
233 error!("{err:?} for {event:?}. ({user:?}, {channel})");
234 } else {
235 info!("User {user:?} ack in {channel} with {event:?}");
236 }
237 }
238 }
239
240 keys.clear();
242
243 while let Some(Data {
245 channel,
246 user,
247 mut event,
248 }) = Q.try_pop()
249 {
250 info!("Took next ack from queue, now {} remaining", Q.len());
251
252 let key: (Option<String>, String, u8) = (
253 user,
254 channel,
255 match &event {
256 AckEvent::AckMessage { .. } => 0,
257 AckEvent::ProcessMessage { .. } => 1,
258 },
259 );
260 if let Some(task) = tasks.get_mut(&key) {
261 match &mut event {
262 AckEvent::ProcessMessage { messages: new_data } => {
263 if let AckEvent::ProcessMessage { messages: existing } =
264 &mut task.data.event
265 {
266 if let Some(new_event) = new_data.pop() {
267 if new_event.1.contains_mass_push_mention() {
269 existing.push(new_event);
271 task.run_immediately();
272 continue;
273 }
274
275 existing.push(new_event);
276
277 if (existing.length() as u16)
279 < revolt_config::config()
280 .await
281 .features
282 .advanced
283 .process_message_delay_limit
284 {
285 task.delay();
286 }
287 } else {
288 let err_msg = format!("Got zero-length message event: {event:?}");
289 capture_message(&err_msg, revolt_config::Level::Warning);
290 info!("{err_msg}")
291 }
292 } else {
293 panic!("Somehow got an ack message in the add mention arm");
294 }
295 }
296 AckEvent::AckMessage { .. } => {
297 task.data.event = event;
299 task.delay();
300 }
301 }
302 } else {
303 tasks.insert(key, DelayedTask::new(Task { event }));
304 }
305 }
306
307 async_std::task::sleep(Duration::from_secs(1)).await;
309 }
310}