revolt_database/models/ratelimit_events/ops/reference.rs

1use std::cmp::Ordering;
2use std::time::Duration;
3use std::time::SystemTime;
4
5use super::AbstractRatelimitEvents;
6use crate::RatelimitEvent;
7use crate::RatelimitEventType;
8use crate::ReferenceDb;
9use revolt_result::Result;
10use ulid::Ulid;
11
12#[async_trait]
13impl AbstractRatelimitEvents for ReferenceDb {
14    /// Insert a new ratelimit event
15    async fn insert_ratelimit_event(&self, event: &RatelimitEvent) -> Result<()> {
16        let mut ratelimit_events = self.ratelimit_events.lock().await;
17        if ratelimit_events.contains_key(&event.id) {
18            Err(create_database_error!("insert", "message"))
19        } else {
20            ratelimit_events.insert(event.id.to_string(), event.clone());
21            Ok(())
22        }
23    }
24
25    /// Count number of events in given duration and check if we've hit the limit
26    async fn has_ratelimited(
27        &self,
28        target_id: &str,
29        event_type: RatelimitEventType,
30        period: Duration,
31        count: usize,
32    ) -> Result<bool> {
33        let ratelimit_events = self.ratelimit_events.lock().await;
34        let gte_cmp_id = Ulid::from_datetime(SystemTime::now() - period).to_string();
35
36        Ok(ratelimit_events
37            .iter()
38            .filter(|(id, event)| {
39                id.cmp(&&gte_cmp_id) == Ordering::Greater
40                    && event.target_id == target_id
41                    && event.event_type == event_type
42            })
43            .count()
44            >= count)
45    }
46}