chik_sdk_client/rate_limiter.rs

1use std::{
2    collections::HashMap,
3    time::{SystemTime, UNIX_EPOCH},
4};
5
6use chik_protocol::{Message, ProtocolMessageTypes};
7
8use crate::RateLimits;
9
10#[derive(Debug, Clone)]
11pub struct RateLimiter {
12    incoming: bool,
13    reset_seconds: u64,
14    period: u64,
15    message_counts: HashMap<ProtocolMessageTypes, f64>,
16    message_cumulative_sizes: HashMap<ProtocolMessageTypes, f64>,
17    limit_factor: f64,
18    non_tx_count: f64,
19    non_tx_size: f64,
20    rate_limits: RateLimits,
21}
22
23impl RateLimiter {
24    pub fn new(
25        incoming: bool,
26        reset_seconds: u64,
27        limit_factor: f64,
28        rate_limits: RateLimits,
29    ) -> Self {
30        Self {
31            incoming,
32            reset_seconds,
33            period: time() / reset_seconds,
34            message_counts: HashMap::new(),
35            message_cumulative_sizes: HashMap::new(),
36            limit_factor,
37            non_tx_count: 0.0,
38            non_tx_size: 0.0,
39            rate_limits,
40        }
41    }
42
43    pub fn handle_message(&mut self, message: &Message) -> bool {
44        let size: u32 = message.data.len().try_into().expect("Message too large");
45        let size = f64::from(size);
46        let period = time() / self.reset_seconds;
47
48        if self.period != period {
49            self.period = period;
50            self.message_counts.clear();
51            self.message_cumulative_sizes.clear();
52            self.non_tx_count = 0.0;
53            self.non_tx_size = 0.0;
54        }
55
56        let new_message_count = self.message_counts.get(&message.msg_type).unwrap_or(&0.0) + 1.0;
57        let new_cumulative_size = self
58            .message_cumulative_sizes
59            .get(&message.msg_type)
60            .unwrap_or(&0.0)
61            + size;
62        let mut new_non_tx_count = self.non_tx_count;
63        let mut new_non_tx_size = self.non_tx_size;
64
65        let passed = 'checker: {
66            let mut limits = self.rate_limits.default_settings;
67
68            if let Some(tx_limits) = self.rate_limits.tx.get(&message.msg_type) {
69                limits = *tx_limits;
70            } else if let Some(other_limits) = self.rate_limits.other.get(&message.msg_type) {
71                limits = *other_limits;
72
73                new_non_tx_count += 1.0;
74                new_non_tx_size += size;
75
76                if new_non_tx_count > self.rate_limits.non_tx_frequency * self.limit_factor {
77                    break 'checker false;
78                }
79
80                if new_non_tx_size > self.rate_limits.non_tx_max_total_size * self.limit_factor {
81                    break 'checker false;
82                }
83            }
84
85            let max_total_size = limits
86                .max_total_size
87                .unwrap_or(limits.frequency * limits.max_size);
88
89            if new_message_count > limits.frequency * self.limit_factor {
90                break 'checker false;
91            }
92
93            if size > limits.max_size {
94                break 'checker false;
95            }
96
97            if new_cumulative_size > max_total_size * self.limit_factor {
98                break 'checker false;
99            }
100
101            true
102        };
103
104        if self.incoming || passed {
105            *self.message_counts.entry(message.msg_type).or_default() = new_message_count;
106            *self
107                .message_cumulative_sizes
108                .entry(message.msg_type)
109                .or_default() = new_cumulative_size;
110            self.non_tx_count = new_non_tx_count;
111            self.non_tx_size = new_non_tx_size;
112        }
113
114        passed
115    }
116}
117
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn time() -> u64 {
    let now = SystemTime::now();
    let since_epoch = now.duration_since(UNIX_EPOCH).expect("Time went backwards");
    since_epoch.as_secs()
}