use crate::{Database, Message, AMQP};
use deadqueue::limited::Queue;
use once_cell::sync::Lazy;
use revolt_config::capture_message;
use revolt_models::v0::PushNotification;
use std::{
collections::{HashMap, HashSet},
time::Duration,
};
use validator::HasLen;
use revolt_result::Result;
use super::DelayedTask;
use crate::Channel::TextChannel;
#[derive(Debug, Eq, PartialEq)]
pub enum AckEvent {
ProcessMessage {
messages: Vec<(Option<PushNotification>, Message, Vec<String>, bool)>,
},
AckMessage {
id: String,
},
}
struct Data {
channel: String,
user: Option<String>,
event: AckEvent,
}
#[derive(Debug)]
struct Task {
event: AckEvent,
}
static Q: Lazy<Queue<Data>> = Lazy::new(|| Queue::new(10_000));
pub async fn queue_ack(channel: String, user: String, event: AckEvent) {
Q.try_push(Data {
channel,
user: Some(user),
event,
})
.ok();
info!(
"Queue is using {} slots from {}. Queued type: ACK",
Q.len(),
Q.capacity()
);
}
pub async fn queue_message(channel: String, event: AckEvent) {
Q.try_push(Data {
channel,
user: None,
event,
})
.ok();
info!(
"Queue is using {} slots from {}. Queued type: MENTION",
Q.len(),
Q.capacity()
);
}
pub async fn handle_ack_event(
event: &AckEvent,
db: &Database,
amqp: &AMQP,
user: &Option<String>,
channel: &str,
) -> Result<()> {
match &event {
#[allow(clippy::disallowed_methods)] AckEvent::AckMessage { id } => {
let user = user.as_ref().unwrap();
let user: &str = user.as_str();
let unread = db.fetch_unread(user, channel).await?;
let updated = db.acknowledge_message(channel, user, id).await?;
if let (Some(before), Some(after)) = (unread, updated) {
let before_mentions = before.mentions.unwrap_or_default().len();
let after_mentions = after.mentions.unwrap_or_default().len();
let mentions_acked = before_mentions - after_mentions;
if mentions_acked > 0 {
if let Err(err) = amqp
.ack_notification_message(
user.to_string(),
channel.to_string(),
id.to_owned(),
)
.await
{
revolt_config::capture_error(&err);
}
};
}
}
AckEvent::ProcessMessage { messages } => {
let mut users: HashSet<&String> = HashSet::new();
info!(
"Processing {} messages from channel {}",
messages.len(),
messages[0].1.channel
);
messages.iter().for_each(|(_, _, recipents, _)| {
users.extend(recipents.iter());
});
info!("Found {} users to notify.", users.len());
for user in users {
let message_ids: Vec<String> = messages
.iter()
.filter_map(|(_, message, recipients, _)| {
if recipients.contains(user) {
Some(message.id.clone())
} else {
None
}
})
.collect();
if !message_ids.is_empty() {
db.add_mention_to_unread(channel, user, &message_ids)
.await?;
}
info!("Added {} mentions for user {}", message_ids.len(), &user);
}
let mut mass_mentions = vec![];
for (push, message, recipients, silenced) in messages {
if *silenced
|| push.is_none()
|| (recipients.is_empty() && !message.contains_mass_push_mention())
{
debug!(
"Rejecting push: silenced: {}, recipient count: {}, push exists: {:?}",
*silenced,
recipients.length(),
push.is_some()
);
continue;
}
debug!(
"Sending push event to AMQP; message {} for {} users",
push.as_ref().unwrap().message.id,
recipients.len()
);
if let Err(err) = amqp
.message_sent(recipients.clone(), push.clone().unwrap())
.await
{
revolt_config::capture_error(&err);
}
if message.contains_mass_push_mention() {
mass_mentions.push(push.clone().unwrap());
}
}
if !mass_mentions.is_empty() {
debug!(
"Sending mass mention push event to AMQP; channel {}",
&mass_mentions[0].message.channel
);
let channel = db
.fetch_channel(&mass_mentions[0].message.channel)
.await
.expect("Failed to fetch channel from db");
if let TextChannel { server, .. } = channel {
if let Err(err) = amqp.mass_mention_message_sent(server, mass_mentions).await {
revolt_config::capture_error(&err);
}
} else {
panic!("Unknown channel type when sending mass mention event");
}
}
}
};
Ok(())
}
pub async fn worker(db: Database, amqp: AMQP) {
let mut tasks = HashMap::<(Option<String>, String, u8), DelayedTask<Task>>::new();
let mut keys: Vec<(Option<String>, String, u8)> = vec![];
loop {
for (key, task) in &tasks {
if task.should_run() {
keys.push(key.clone());
}
}
for key in &keys {
if let Some(task) = tasks.remove(key) {
let Task { event } = task.data;
let (user, channel, _) = key;
if let Err(err) = handle_ack_event(&event, &db, &amqp, user, channel).await {
revolt_config::capture_error(&err);
error!("{err:?} for {event:?}. ({user:?}, {channel})");
} else {
info!("User {user:?} ack in {channel} with {event:?}");
}
}
}
keys.clear();
while let Some(Data {
channel,
user,
mut event,
}) = Q.try_pop()
{
info!("Took next ack from queue, now {} remaining", Q.len());
let key: (Option<String>, String, u8) = (
user,
channel,
match &event {
AckEvent::AckMessage { .. } => 0,
AckEvent::ProcessMessage { .. } => 1,
},
);
if let Some(task) = tasks.get_mut(&key) {
match &mut event {
AckEvent::ProcessMessage { messages: new_data } => {
if let AckEvent::ProcessMessage { messages: existing } =
&mut task.data.event
{
if let Some(new_event) = new_data.pop() {
if new_event.1.contains_mass_push_mention() {
existing.push(new_event);
task.run_immediately();
continue;
}
existing.push(new_event);
if (existing.length() as u16)
< revolt_config::config()
.await
.features
.advanced
.process_message_delay_limit
{
task.delay();
}
} else {
let err_msg = format!("Got zero-length message event: {event:?}");
capture_message(&err_msg, revolt_config::Level::Warning);
info!("{err_msg}")
}
} else {
panic!("Somehow got an ack message in the add mention arm");
}
}
AckEvent::AckMessage { .. } => {
task.data.event = event;
task.delay();
}
}
} else {
tasks.insert(key, DelayedTask::new(Task { event }));
}
}
async_std::task::sleep(Duration::from_secs(1)).await;
}
}