chik_sdk_client/
rate_limiter.rs1use std::{
2 collections::HashMap,
3 time::{SystemTime, UNIX_EPOCH},
4};
5
6use chik_protocol::{Message, ProtocolMessageTypes};
7
8use crate::RateLimits;
9
10#[derive(Debug, Clone)]
11pub struct RateLimiter {
12 incoming: bool,
13 reset_seconds: u64,
14 period: u64,
15 message_counts: HashMap<ProtocolMessageTypes, f64>,
16 message_cumulative_sizes: HashMap<ProtocolMessageTypes, f64>,
17 limit_factor: f64,
18 non_tx_count: f64,
19 non_tx_size: f64,
20 rate_limits: RateLimits,
21}
22
23impl RateLimiter {
24 pub fn new(
25 incoming: bool,
26 reset_seconds: u64,
27 limit_factor: f64,
28 rate_limits: RateLimits,
29 ) -> Self {
30 Self {
31 incoming,
32 reset_seconds,
33 period: time() / reset_seconds,
34 message_counts: HashMap::new(),
35 message_cumulative_sizes: HashMap::new(),
36 limit_factor,
37 non_tx_count: 0.0,
38 non_tx_size: 0.0,
39 rate_limits,
40 }
41 }
42
43 pub fn handle_message(&mut self, message: &Message) -> bool {
44 let size: u32 = message.data.len().try_into().expect("Message too large");
45 let size = f64::from(size);
46 let period = time() / self.reset_seconds;
47
48 if self.period != period {
49 self.period = period;
50 self.message_counts.clear();
51 self.message_cumulative_sizes.clear();
52 self.non_tx_count = 0.0;
53 self.non_tx_size = 0.0;
54 }
55
56 let new_message_count = self.message_counts.get(&message.msg_type).unwrap_or(&0.0) + 1.0;
57 let new_cumulative_size = self
58 .message_cumulative_sizes
59 .get(&message.msg_type)
60 .unwrap_or(&0.0)
61 + size;
62 let mut new_non_tx_count = self.non_tx_count;
63 let mut new_non_tx_size = self.non_tx_size;
64
65 let passed = 'checker: {
66 let mut limits = self.rate_limits.default_settings;
67
68 if let Some(tx_limits) = self.rate_limits.tx.get(&message.msg_type) {
69 limits = *tx_limits;
70 } else if let Some(other_limits) = self.rate_limits.other.get(&message.msg_type) {
71 limits = *other_limits;
72
73 new_non_tx_count += 1.0;
74 new_non_tx_size += size;
75
76 if new_non_tx_count > self.rate_limits.non_tx_frequency * self.limit_factor {
77 break 'checker false;
78 }
79
80 if new_non_tx_size > self.rate_limits.non_tx_max_total_size * self.limit_factor {
81 break 'checker false;
82 }
83 }
84
85 let max_total_size = limits
86 .max_total_size
87 .unwrap_or(limits.frequency * limits.max_size);
88
89 if new_message_count > limits.frequency * self.limit_factor {
90 break 'checker false;
91 }
92
93 if size > limits.max_size {
94 break 'checker false;
95 }
96
97 if new_cumulative_size > max_total_size * self.limit_factor {
98 break 'checker false;
99 }
100
101 true
102 };
103
104 if self.incoming || passed {
105 *self.message_counts.entry(message.msg_type).or_default() = new_message_count;
106 *self
107 .message_cumulative_sizes
108 .entry(message.msg_type)
109 .or_default() = new_cumulative_size;
110 self.non_tx_count = new_non_tx_count;
111 self.non_tx_size = new_non_tx_size;
112 }
113
114 passed
115 }
116}
117
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock is set to a moment before the epoch, `0` is returned
/// instead of panicking: a misconfigured clock should not take the process down.
fn time() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
}