1mod scripts;
2
3use std::time;
4
5#[cfg(feature = "local_accelerate")]
6use std::{
7 collections::HashMap,
8 sync::{LazyLock, RwLock},
9};
10
11use scripts::ALLOW_N_SCRIPT;
12
/// Process-wide cache mapping a (prefixed) key to the local instant at which its
/// limiter state resets.
///
/// `Limiter::allow_n` consults it to reject requests for keys that are known to be
/// still limited without a Redis round trip.
#[cfg(feature = "local_accelerate")]
static RESET_TIME_STORE: LazyLock<RwLock<HashMap<String, time::Instant>>> =
    LazyLock::new(|| RwLock::new(HashMap::default()));

/// Prefix prepended to every Redis key unless replaced via `Limiter::set_key_prefix`.
const DEFAULT_LIMITER_KEY_PREFIX: &str = "redis_rate:";

/// Pub/sub channel for reset notifications unless replaced via
/// `Limiter::set_event_channel`.
#[cfg(feature = "local_accelerate")]
const DEFAULT_LIMITER_EVENT_CHANNEL: &str = "redis_rate_channel";
/// Marker that starts a reset-event payload; the rest of the payload is the prefixed key.
#[cfg(feature = "local_accelerate")]
const LIMITER_RESET_EVENT_PREFIX: &str = "reset:";
23
/// A rate limit of `rate` requests per `period_seconds`, with up to `burst`
/// requests allowed at once.
///
/// Build it with [`Limit::new`], or with the `new_limit!` macro when the arguments
/// are compile-time constants.
#[derive(Debug, Clone)]
pub struct Limit {
    rate: usize,
    burst: usize,
    period_seconds: usize,
}

impl Limit {
    /// Creates a limit of `rate` requests per `period_seconds` with a burst of `burst`.
    ///
    /// # Panics
    ///
    /// Panics, checking in this order, if `period_seconds` is 0, if `rate` is 0,
    /// or if `rate` is greater than `burst`.
    pub fn new(rate: usize, burst: usize, period_seconds: usize) -> Self {
        assert!(period_seconds > 0, "period_seconds must be greater than 0");
        assert!(rate > 0, "rate must be greater than 0");
        assert!(rate <= burst, "rate must be less than or equal to burst");

        Self {
            rate,
            burst,
            period_seconds,
        }
    }
}
53
/// Builds a `Limit` whose arguments are validated at compile time.
///
/// All three arguments must be constant expressions. The invariants that
/// `Limit::new` enforces at runtime are asserted in `const` items first, so an
/// invalid limit fails to compile instead of panicking when it is constructed.
#[macro_export]
macro_rules! new_limit {
    ($rate:expr, $burst:expr, $period_seconds:expr) => {{
        const _: () = assert!($period_seconds > 0, "period_seconds must be greater than 0");
        const _: () = assert!($rate > 0, "rate must be greater than 0");
        const _: () = assert!($rate <= $burst, "rate must be less than or equal to burst");
        $crate::Limit::new($rate, $burst, $period_seconds)
    }};
}
67
/// Outcome of [`Limiter::allow`] / [`Limiter::allow_n`].
#[derive(Debug, Clone)]
pub struct LimitResult {
    /// `true` when the request was rejected.
    pub limited: bool,
    /// Remaining capacity as reported by the limiter (the Redis script, or the
    /// local cache when a request is rejected without contacting Redis).
    pub remaining: usize,
    /// How long to wait before retrying. `None` when the Redis script returned a
    /// negative retry-after (presumably: the request was allowed).
    pub retry_after: Option<time::Duration>,
    /// Time until the key's limiter state resets, as reported by the limiter.
    pub reset_after: time::Duration,
}
81
82#[derive(Debug, Clone)]
84pub struct Limiter {
85 client: redis::Client,
86 key_prefix: String,
87
88 #[cfg(feature = "local_accelerate")]
89 event_channel: String,
90}
91
92impl Limiter {
93 pub fn new(client: redis::Client) -> Self {
95 Limiter {
96 client,
97 key_prefix: DEFAULT_LIMITER_KEY_PREFIX.to_string(),
98
99 #[cfg(feature = "local_accelerate")]
100 event_channel: DEFAULT_LIMITER_EVENT_CHANNEL.to_string(),
101 }
102 }
103
104 pub fn set_key_prefix(mut self, key_prefix: &str) -> Self {
106 self.key_prefix = key_prefix.to_string();
107 self
108 }
109
110 #[cfg(feature = "local_accelerate")]
113 pub fn set_event_channel(mut self, channel: &str) -> Self {
114 self.event_channel = channel.to_string();
115 self
116 }
117
118 #[cfg(feature = "local_accelerate")]
121 pub fn start_event_sync(&self) -> Result<(), redis::RedisError> {
122 let mut con = self.client.get_connection()?;
123 let mut pubsub = con.as_pubsub();
124 pubsub.subscribe(&self.event_channel).unwrap();
125 loop {
126 let msg = pubsub.get_message()?.get_payload::<String>()?;
127 if msg.starts_with(LIMITER_RESET_EVENT_PREFIX) {
128 let payload = msg.split_at(LIMITER_RESET_EVENT_PREFIX.len()).1;
129 if let Ok(mut store) = RESET_TIME_STORE.try_write() {
130 store.remove(payload);
131 }
132 }
133 }
134 }
135
136 pub fn reset(&self, key: &str) -> Result<(), redis::RedisError> {
138 let key = format!("{}{}", self.key_prefix, key);
139 let mut con = self.client.get_connection()?;
140 redis::cmd("DEL").arg(&key).query::<()>(&mut con)?;
141
142 #[cfg(feature = "local_accelerate")]
143 {
144 let reset_notify = format!("{}{}", LIMITER_RESET_EVENT_PREFIX, key);
145 redis::cmd("PUBLISH")
146 .arg(self.event_channel.clone())
147 .arg(&reset_notify)
148 .query::<()>(&mut con)?;
149 }
150
151 Ok(())
152 }
153
154 pub fn allow(&self, key: &str, limit: &Limit) -> Result<LimitResult, redis::RedisError> {
156 self.allow_n(key, limit, 1)
157 }
158
159 pub fn allow_n(
161 &self,
162 key: &str,
163 limit: &Limit,
164 n: usize,
165 ) -> Result<LimitResult, redis::RedisError> {
166 let key = format!("{}{}", self.key_prefix, key);
167
168 let emission_interval = limit.period_seconds as f64 / limit.rate as f64;
169 let tat_increment = emission_interval * n as f64;
170 let brust_offset = limit.burst as f64 * emission_interval;
171
172 #[cfg(feature = "local_accelerate")]
173 let now = time::Instant::now();
174 #[cfg(feature = "local_accelerate")]
175 if let Ok(store) = RESET_TIME_STORE.try_read() {
176 if let Some(reset_time) = store.get(&key) {
177 let reset_after = reset_time.duration_since(now).as_secs_f64();
178 let diff: f64 = reset_after + tat_increment - brust_offset;
179 if diff > 0.0 {
180 return Ok(LimitResult {
181 limited: true,
182 remaining: f64::floor((brust_offset - reset_after) / emission_interval)
183 as usize,
184 retry_after: Some(time::Duration::from_secs_f64(diff.abs())),
185 reset_after: reset_time.duration_since(now),
186 });
187 }
188 }
189 }
190
191 let mut con = self.client.get_connection()?;
192 let result: redis::Value = ALLOW_N_SCRIPT
193 .key(&key)
194 .arg(emission_interval)
195 .arg(brust_offset)
196 .arg(tat_increment)
197 .arg(n)
198 .invoke(&mut con)?;
199
200 let (limited, remaining, retry_after_secs, reset_after_secs): (bool, usize, f64, f64) =
201 redis::from_redis_value(&result)?;
202 let retry_after = if retry_after_secs < 0.0 {
203 None
204 } else {
205 Some(time::Duration::from_secs_f64(retry_after_secs))
206 };
207 let reset_after = time::Duration::from_secs_f64(reset_after_secs);
208
209 #[cfg(feature = "local_accelerate")]
210 if let Ok(mut store) = RESET_TIME_STORE.try_write() {
211 store.insert(key, now + reset_after);
212 }
213
214 Ok(LimitResult {
215 limited,
216 remaining,
217 retry_after,
218 reset_after,
219 })
220 }
221}
222
/// End-to-end check of allow / allow_n / reset.
///
/// Requires a Redis server reachable at `redis://127.0.0.1/`.
#[test]
fn test_limiter() {
    #[cfg(feature = "local_accelerate")]
    use std::thread;

    let limit = Limit::new(5, 5, 20);
    let key = "test";
    let limiter = Limiter::new(redis::Client::open("redis://127.0.0.1/").unwrap());
    limiter.reset(key).unwrap();

    #[cfg(feature = "local_accelerate")]
    {
        let limiter_clone = limiter.clone();
        thread::spawn(move || limiter_clone.start_event_sync().unwrap());
    }

    // Take 4 of the 5 burst tokens.
    let result = limiter.allow_n(key, &limit, 4).unwrap();
    assert!(!result.limited);
    assert_eq!(result.remaining, 1);
    // 3 more would exceed the burst, so the request is rejected...
    let result = limiter.allow_n(key, &limit, 3).unwrap();
    assert!(result.limited);
    assert_eq!(result.remaining, 1);
    // ...and the last token is still available afterwards.
    let result = limiter.allow(key, &limit).unwrap();
    assert!(!result.limited);
    assert_eq!(result.remaining, 0);

    let result = limiter.allow_n(key, &limit, 5).unwrap();
    assert!(result.limited);
    limiter.reset(key).unwrap();

    // Give the event-sync thread time to receive the reset notification.
    #[cfg(feature = "local_accelerate")]
    thread::sleep(time::Duration::from_millis(100));

    let result = limiter.allow_n(key, &limit, 5).unwrap();
    assert!(!result.limited);
}
261}