1use std::sync::{Arc, Mutex};
2use std::time::{Duration, Instant};
3
/// Mutable bookkeeping behind a rate limiter: the configured limits, the
/// recorded hits and the one-shot notification flag.
#[derive(Default, Debug)]
pub(crate) struct RateLimiterState {
    /// Cap on the total number of recorded hits; `None` means no cap.
    max_calls: Option<u64>,
    /// Cap on the number of hits within the last second; `None` means no cap.
    max_per_second: Option<u64>,
    /// Timestamp of every hit recorded since the last `clear`.
    hits: Vec<Instant>,
    /// Set once a blocked status has been reported with `notify: true`.
    notified: bool,
}
11
12#[derive(Debug, Clone)]
13pub struct RateLimiterHandle {
14 state: Arc<Mutex<RateLimiterState>>,
15}
16
/// Outcome of asking the limiter whether another call may proceed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum RateLimitStatus {
    /// No configured limit is currently reached.
    Allowed,
    /// At least one configured limit is reached.
    Blocked {
        /// `true` when the notification flag had not been set before this
        /// status was produced.
        notify: bool,
    },
}
22
23impl RateLimiterHandle {
24 pub(crate) fn new(state: Arc<Mutex<RateLimiterState>>) -> Self {
25 Self { state }
26 }
27
28 pub fn max(self, max_calls: i64) -> Self {
30 {
31 let mut state = self.state.lock().unwrap();
32 state.max_calls = normalize_limit(max_calls);
33 if state.max_calls.is_none() {
34 state.notified = false;
35 }
36 }
37 self
38 }
39
40 pub fn per_second(self, max_per_second: i64) -> Self {
42 {
43 let mut state = self.state.lock().unwrap();
44 state.max_per_second = normalize_limit(max_per_second);
45 if state.max_per_second.is_none() {
46 state.notified = false;
47 }
48 }
49 self
50 }
51
52 pub fn clear(self) -> Self {
54 {
55 let mut state = self.state.lock().unwrap();
56 state.clear();
57 }
58 self
59 }
60
61 pub(crate) fn status(&self) -> RateLimitStatus {
62 let mut state = self.state.lock().unwrap();
63 let max_reached = state.is_max_reached();
64 let per_second_reached = state.is_max_per_second_reached();
65
66 if max_reached || per_second_reached {
67 let notify = !state.notified;
68 if notify {
69 state.notified = true;
70 }
71 return RateLimitStatus::Blocked { notify };
72 }
73
74 RateLimitStatus::Allowed
75 }
76
77 pub(crate) fn hit(&self) {
78 let mut state = self.state.lock().unwrap();
79 state.hit();
80 }
81}
82
83impl RateLimiterState {
84 fn hit(&mut self) {
85 self.hits.push(Instant::now());
86 }
87
88 fn clear(&mut self) {
89 self.max_calls = None;
90 self.max_per_second = None;
91 self.hits.clear();
92 self.notified = false;
93 }
94
95 fn is_max_reached(&mut self) -> bool {
96 let Some(max_calls) = self.max_calls else {
97 return false;
98 };
99
100 let reached = self.hits.len() as u64 >= max_calls;
101 if !reached {
102 self.notified = false;
103 }
104 reached
105 }
106
107 fn is_max_per_second_reached(&mut self) -> bool {
108 let Some(max_per_second) = self.max_per_second else {
109 return false;
110 };
111
112 let reached = self.count_last_second() >= max_per_second;
113 if !reached {
114 self.notified = false;
115 }
116 reached
117 }
118
119 fn count_last_second(&self) -> u64 {
120 let now = Instant::now();
121 self.hits
122 .iter()
123 .filter(|hit| now.duration_since(**hit) <= Duration::from_secs(1))
124 .count() as u64
125 }
126}
127
/// Converts a caller-supplied limit into its stored form.
///
/// Strictly positive values become `Some(value)`; zero and negative values
/// become `None`.
fn normalize_limit(value: i64) -> Option<u64> {
    // `unsigned_abs` is only reached for positive values, where it equals `value`.
    (value > 0).then(|| value.unsigned_abs())
}