Skip to main content

revolt_database/amqp/
amqp.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use crate::events::rabbit::*;
5use crate::User;
6use lapin::{
7    options::BasicPublishOptions,
8    protocol::basic::AMQPProperties,
9    types::{AMQPValue, FieldTable},
10    Channel, Connection, ConnectionProperties, Error as AMQPError,
11};
12use revolt_models::v0::PushNotification;
13use revolt_presence::filter_online;
14use revolt_result::Result;
15
16use serde_json::to_string;
17
18#[derive(Clone)]
19pub struct AMQP {
20    friend_request_accepted: Arc<Channel>,
21    friend_request_received: Arc<Channel>,
22    generic_message: Arc<Channel>,
23    message_sent: Arc<Channel>,
24    mass_mention_message_sent: Arc<Channel>,
25    ack_notification_message: Arc<Channel>,
26    dm_call_updated: Arc<Channel>,
27    process_ack: Arc<Channel>,
28    #[allow(unused)]
29    connection: Arc<Connection>,
30}
31
32impl AMQP {
33    pub async fn new(connection: Arc<Connection>) -> Self {
34        Self {
35            friend_request_accepted: Self::create_channel(&connection).await,
36            friend_request_received: Self::create_channel(&connection).await,
37            generic_message: Self::create_channel(&connection).await,
38            message_sent: Self::create_channel(&connection).await,
39            mass_mention_message_sent: Self::create_channel(&connection).await,
40            ack_notification_message: Self::create_channel(&connection).await,
41            dm_call_updated: Self::create_channel(&connection).await,
42            process_ack: Self::create_channel(&connection).await,
43            connection,
44        }
45    }
46
47    pub async fn new_auto() -> Self {
48        let config = revolt_config::config().await;
49
50        let connection = Arc::new(
51            Connection::connect(
52                &format!(
53                    "amqp://{}:{}@{}:{}",
54                    &config.rabbit.username,
55                    &config.rabbit.password,
56                    &config.rabbit.host,
57                    &config.rabbit.port,
58                ),
59                ConnectionProperties::default(),
60            )
61            .await
62            .expect("Failed to connect to RabbitMQ"),
63        );
64
65        Self::new(connection).await
66    }
67
68    async fn create_channel(connection: &Connection) -> Arc<Channel> {
69        Arc::new(
70            connection
71                .create_channel()
72                .await
73                .expect("Failed to create channel"),
74        )
75    }
76
77    pub async fn friend_request_accepted(
78        &self,
79        accepted_request_user: &User,
80        sent_request_user: &User,
81    ) -> Result<(), AMQPError> {
82        let config = revolt_config::config().await;
83        let payload = FRAcceptedPayload {
84            accepted_user: accepted_request_user.to_owned(),
85            user: sent_request_user.id.clone(),
86        };
87        let payload = to_string(&payload).unwrap();
88
89        debug!(
90            "Sending friend request accept payload on channel {}: {}",
91            config.pushd.get_fr_accepted_routing_key(),
92            payload
93        );
94
95        self.friend_request_accepted
96            .basic_publish(
97                config.pushd.exchange.clone().into(),
98                config.pushd.get_fr_accepted_routing_key().into(),
99                BasicPublishOptions::default(),
100                payload.as_bytes(),
101                AMQPProperties::default()
102                    .with_content_type("application/json".into())
103                    .with_delivery_mode(2),
104            )
105            .await?;
106
107        Ok(())
108    }
109
110    pub async fn friend_request_received(
111        &self,
112        received_request_user: &User,
113        sent_request_user: &User,
114    ) -> Result<(), AMQPError> {
115        let config = revolt_config::config().await;
116        let payload = FRReceivedPayload {
117            from_user: sent_request_user.to_owned(),
118            user: received_request_user.id.clone(),
119        };
120        let payload = to_string(&payload).unwrap();
121
122        debug!(
123            "Sending friend request received payload on channel {}: {}",
124            config.pushd.get_fr_received_routing_key(),
125            payload
126        );
127
128        self.friend_request_received
129            .basic_publish(
130                config.pushd.exchange.clone().into(),
131                config.pushd.get_fr_received_routing_key().into(),
132                BasicPublishOptions::default(),
133                payload.as_bytes(),
134                AMQPProperties::default()
135                    .with_content_type("application/json".into())
136                    .with_delivery_mode(2),
137            )
138            .await?;
139
140        Ok(())
141    }
142
143    pub async fn generic_message(
144        &self,
145        user: &User,
146        title: String,
147        body: String,
148        icon: Option<String>,
149    ) -> Result<(), AMQPError> {
150        let config = revolt_config::config().await;
151        let payload = GenericPayload {
152            title,
153            body,
154            icon,
155            user: user.to_owned(),
156        };
157        let payload = to_string(&payload).unwrap();
158
159        debug!(
160            "Sending generic payload on channel {}: {}",
161            config.pushd.get_generic_routing_key(),
162            payload
163        );
164
165        self.generic_message
166            .basic_publish(
167                config.pushd.exchange.clone().into(),
168                config.pushd.get_generic_routing_key().into(),
169                BasicPublishOptions::default(),
170                payload.as_bytes(),
171                AMQPProperties::default()
172                    .with_content_type("application/json".into())
173                    .with_delivery_mode(2),
174            )
175            .await?;
176
177        Ok(())
178    }
179
180    pub async fn message_sent(
181        &self,
182        recipients: Vec<String>,
183        payload: PushNotification,
184    ) -> Result<(), AMQPError> {
185        if recipients.is_empty() {
186            return Ok(());
187        }
188
189        let config = revolt_config::config().await;
190
191        let online_ids = filter_online(&recipients).await;
192        let recipients = (&recipients.into_iter().collect::<HashSet<String>>() - &online_ids)
193            .into_iter()
194            .collect::<Vec<String>>();
195
196        let payload = MessageSentPayload {
197            notification: payload,
198            users: recipients,
199        };
200        let payload = to_string(&payload).unwrap();
201
202        debug!(
203            "Sending message payload on channel {}: {}",
204            config.pushd.get_message_routing_key(),
205            payload
206        );
207
208        self.message_sent
209            .basic_publish(
210                config.pushd.exchange.clone().into(),
211                config.pushd.get_message_routing_key().into(),
212                BasicPublishOptions::default(),
213                payload.as_bytes(),
214                AMQPProperties::default()
215                    .with_content_type("application/json".into())
216                    .with_delivery_mode(2),
217            )
218            .await?;
219
220        Ok(())
221    }
222
223    pub async fn mass_mention_message_sent(
224        &self,
225        server_id: String,
226        payload: Vec<PushNotification>,
227    ) -> Result<(), AMQPError> {
228        let config = revolt_config::config().await;
229
230        let payload = MassMessageSentPayload {
231            notifications: payload,
232            server_id,
233        };
234        let payload = to_string(&payload).unwrap();
235
236        let routing_key = config.pushd.get_mass_mention_routing_key();
237
238        debug!(
239            "Sending mass mention payload on channel {}: {}",
240            routing_key, payload
241        );
242
243        self.mass_mention_message_sent
244            .basic_publish(
245                config.pushd.exchange.clone().into(),
246                routing_key.into(),
247                BasicPublishOptions::default(),
248                payload.as_bytes(),
249                AMQPProperties::default()
250                    .with_content_type("application/json".into())
251                    .with_delivery_mode(2),
252            )
253            .await?;
254
255        Ok(())
256    }
257
258    /// # Sends an ack to pushd to update badges on iPhones.
259    /// Not to be confused with the process_ack function, which handles sending all acks to crond for processing.
260    pub async fn ack_notification_message(
261        &self,
262        user_id: String,
263        channel_id: String,
264        message_id: String,
265    ) -> Result<(), AMQPError> {
266        let config = revolt_config::config().await;
267
268        let payload = AckPayload {
269            user_id: user_id.clone(),
270            channel_id: channel_id.clone(),
271            message_id,
272        };
273        let payload = to_string(&payload).unwrap();
274
275        info!(
276            "Sending ack payload on channel {}: {}",
277            config.pushd.ack_queue, payload
278        );
279
280        let mut headers = FieldTable::default();
281        headers.insert(
282            "x-deduplication-header".into(),
283            AMQPValue::LongString(format!("{}-{}", &user_id, &channel_id).into()),
284        );
285
286        self.ack_notification_message
287            .basic_publish(
288                config.pushd.exchange.clone().into(),
289                config.pushd.ack_queue.into(),
290                BasicPublishOptions::default(),
291                payload.as_bytes(),
292                AMQPProperties::default()
293                    .with_content_type("application/json".into())
294                    .with_delivery_mode(2),
295            )
296            .await?;
297
298        Ok(())
299    }
300
301    /// # DM Call Update
302    /// Used to send an update about a DM call, eg. start or end of a call.
303    /// Recipients can be used to narrow the scope of recipients, otherwise all recipients will be notified.
304    /// `ended` refers to the ringing period, not necessarily the call itself.
305    pub async fn dm_call_updated(
306        &self,
307        initiator_id: &str,
308        channel_id: &str,
309        started_at: Option<&str>,
310        ended: bool,
311        recipients: Option<Vec<String>>,
312    ) -> Result<(), AMQPError> {
313        let config = revolt_config::config().await;
314
315        let payload = InternalDmCallPayload {
316            payload: DmCallPayload {
317                initiator_id: initiator_id.to_string(),
318                channel_id: channel_id.to_string(),
319                started_at: started_at.map(|f| f.to_string()),
320                ended,
321            },
322            recipients,
323        };
324        let payload = to_string(&payload).unwrap();
325
326        debug!(
327            "Sending dm call update payload on channel {}: {}",
328            config.pushd.get_dm_call_routing_key(),
329            payload
330        );
331
332        self.dm_call_updated
333            .basic_publish(
334                config.pushd.exchange.clone().into(),
335                config.pushd.get_dm_call_routing_key().into(),
336                BasicPublishOptions::default(),
337                payload.as_bytes(),
338                AMQPProperties::default()
339                    .with_content_type("application/json".into())
340                    .with_delivery_mode(2),
341            )
342            .await?;
343
344        Ok(())
345    }
346
347    /// # Send an ack to crond for processing
348    pub async fn process_ack(
349        &self,
350        user_id: &str,
351        channel_id: Option<&str>,
352        server_id: Option<&str>,
353    ) -> Result<(), AMQPError> {
354        let config = revolt_config::config().await;
355
356        let payload = AckEventPayload {
357            user_id: user_id.to_string(),
358            channel_id: channel_id.map(|value| value.to_string()),
359            server_id: server_id.map(|value| value.to_string()),
360        };
361        let payload = to_string(&payload).unwrap();
362
363        info!(
364            "Sending ack processor event on exchange {}, channel {}: {}",
365            config.rabbit.default_exchange, config.rabbit.queues.acks, payload
366        );
367
368        self.process_ack
369            .basic_publish(
370                config.rabbit.default_exchange.clone().into(),
371                config.rabbit.queues.acks.into(),
372                BasicPublishOptions::default(),
373                payload.as_bytes(),
374                AMQPProperties::default()
375                    .with_content_type("application/json".into())
376                    .with_delivery_mode(2),
377            )
378            .await?;
379
380        Ok(())
381    }
382}