1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
use crate::core::{Handler, HandlerResult};
use async_trait::async_trait;
use ratelimit_meter::{DirectRateLimiter, GCRA};
use std::{num::NonZeroU32, sync::Arc, time::Duration};
use tgbot::types::Update;
use tokio::sync::Mutex;

/// Rate-limits every incoming update using one shared limiter.
///
/// The limiter is consulted for each update without looking at its content or
/// the handler context. When the limiter rejects an update, the handler returns
/// `HandlerResult::Stop`; otherwise it returns `HandlerResult::Continue`.
pub struct DirectRateLimitHandler {
    // GCRA-based direct limiter behind an async (tokio) mutex; locked once per handled update.
    limiter: Arc<Mutex<DirectRateLimiter<GCRA>>>,
}

impl DirectRateLimitHandler {
    /// Builds a handler backed by a fresh limiter.
    ///
    /// # Arguments
    ///
    /// - `capacity` - how many updates are allowed per time unit
    /// - `duration` - the length of that time unit
    pub fn new(capacity: NonZeroU32, duration: Duration) -> Self {
        let limiter = DirectRateLimiter::new(capacity, duration);
        let limiter = Arc::new(Mutex::new(limiter));
        Self { limiter }
    }
}

#[async_trait]
impl<C> Handler<C> for DirectRateLimitHandler
where
    C: Send + Sync,
{
    type Input = Update;
    type Output = HandlerResult;

    /// Asks the shared limiter whether one more update may pass.
    ///
    /// Neither the context nor the update itself is inspected; every call
    /// counts against the same limiter.
    async fn handle(&mut self, _context: &C, _update: Self::Input) -> Self::Output {
        // The lock guard is a temporary of this statement, so it is released
        // as soon as the decision has been computed.
        let admitted = self.limiter.lock().await.check().is_ok();
        if !admitted {
            return HandlerResult::Stop;
        }
        HandlerResult::Continue
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use nonzero_ext::nonzero;
    use tgbot::types::Update;

    #[tokio::test]
    async fn handler() {
        let update: Update = serde_json::from_value(serde_json::json!({
            "update_id": 1,
            "message": {
                "message_id": 1,
                "date": 1,
                "from": {"id": 1, "is_bot": false, "first_name": "test", "username": "username_user"},
                "chat": {"id": 1, "type": "supergroup", "title": "test", "username": "username_chat"},
                "text": "test"
            }
        }))
        .unwrap();
        // One update per 1000 seconds: the burst below finishes far inside a single window.
        let mut handler = DirectRateLimitHandler::new(nonzero!(1u32), Duration::from_secs(1000));
        let mut results = Vec::with_capacity(10);
        for _ in 0..10 {
            results.push(handler.handle(&(), update.clone()).await);
        }
        // A fresh limiter must admit the very first update; this guards against a
        // handler that unconditionally stops.
        assert!(matches!(results[0], HandlerResult::Continue));
        // Once the capacity is exhausted, updates must start being rejected.
        assert!(results.iter().any(|x| matches!(x, HandlerResult::Stop)));
    }
}