redis_rate/
lib.rs

1mod scripts;
2
3use std::time;
4
5#[cfg(feature = "local_accelerate")]
6use std::{
7    collections::HashMap,
8    sync::{LazyLock, RwLock},
9};
10
11use scripts::ALLOW_N_SCRIPT;
12
/// Process-wide cache, keyed by the full (prefixed) Redis key, of the instant at which
/// that key's limit was last reported by Redis to be totally reset.
#[cfg(feature = "local_accelerate")]
static RESET_TIME_STORE: LazyLock<RwLock<HashMap<String, time::Instant>>> =
    LazyLock::new(|| RwLock::new(HashMap::default()));

/// Key prefix a new `Limiter` uses unless overridden with `set_key_prefix`.
const DEFAULT_LIMITER_KEY_PREFIX: &str = "redis_rate:";

/// Pub/sub channel a new `Limiter` uses for reset events unless overridden.
#[cfg(feature = "local_accelerate")]
const DEFAULT_LIMITER_EVENT_CHANNEL: &str = "redis_rate_channel";
/// Marker at the start of a reset event message; the remainder is the prefixed key.
#[cfg(feature = "local_accelerate")]
const LIMITER_RESET_EVENT_PREFIX: &str = "reset:";
23
/// Rate limit setting.
///
/// Requests are replenished at `rate` per `period_seconds` seconds, and `burst`
/// caps how many of them can be spent at once.
#[derive(Debug, Clone)]
pub struct Limit {
    rate: usize,
    burst: usize,
    period_seconds: usize,
}

impl Limit {
    /// Create a new `Limit` setting.
    ///
    /// If you want the checks below to happen at compile time, use the
    /// `new_limit!` macro instead.
    ///
    /// # Panics
    ///
    /// Panics if `period_seconds` is 0, if `rate` is 0, or if `rate` is greater
    /// than `burst` (checked in that order).
    pub fn new(rate: usize, burst: usize, period_seconds: usize) -> Self {
        assert!(period_seconds > 0, "period_seconds must be greater than 0");
        assert!(rate > 0, "rate must be greater than 0");
        assert!(rate <= burst, "rate must be less than or equal to burst");

        Self { rate, burst, period_seconds }
    }
}
53
/// Compile-time checked macro to create a new `Limit` instance.
///
/// All three arguments must be constant expressions of type `usize`. They are
/// validated with the same rules as `Limit::new`, but a violation is reported as a
/// compilation error instead of a runtime panic.
/// If you want to create dynamically configured limits, use `Limit::new` instead.
#[macro_export]
macro_rules! new_limit {
    ($rate:expr, $burst:expr, $period_seconds:expr) => {{
        const _: () = {
            // Bind the arguments to `usize` locals so the checks run on the same type
            // `Limit::new` takes. Compared bare, integer literals would default to
            // `i32` here and valid `usize` values above `i32::MAX` would fail to compile.
            // (Locals are hygienic, so these names cannot clash with caller tokens.)
            let rate: usize = $rate;
            let burst: usize = $burst;
            let period_seconds: usize = $period_seconds;
            assert!(period_seconds > 0, "period_seconds must be greater than 0");
            assert!(rate > 0, "rate must be greater than 0");
            assert!(rate <= burst, "rate must be less than or equal to burst");
        };
        $crate::Limit::new($rate, $burst, $period_seconds)
    }};
}
67
/// Outcome of a single limit check.
#[derive(Debug, Clone)]
pub struct LimitResult {
    /// `true` when the request was rejected by the limit.
    pub limited: bool,
    /// How many more requests can currently be made within the limit.
    pub remaining: usize,
    /// How long to wait before retrying; `None` when the request was not limited.
    pub retry_after: Option<std::time::Duration>,
    /// How long until the limit is totally reset.
    pub reset_after: std::time::Duration,
}
81
82/// Rate limiter backed by Redis.
83#[derive(Debug, Clone)]
84pub struct Limiter {
85    client: redis::Client,
86    key_prefix: String,
87
88    #[cfg(feature = "local_accelerate")]
89    event_channel: String,
90}
91
92impl Limiter {
93    /// Create a new limiter with the given Redis client.
94    pub fn new(client: redis::Client) -> Self {
95        Limiter {
96            client,
97            key_prefix: DEFAULT_LIMITER_KEY_PREFIX.to_string(),
98
99            #[cfg(feature = "local_accelerate")]
100            event_channel: DEFAULT_LIMITER_EVENT_CHANNEL.to_string(),
101        }
102    }
103
104    /// Set the key prefix for the limiter's Redis keys.
105    pub fn set_key_prefix(mut self, key_prefix: &str) -> Self {
106        self.key_prefix = key_prefix.to_string();
107        self
108    }
109
110    /// Set the event channel name for the limiter.
111    /// This should be called before `start_event_sync`.
112    #[cfg(feature = "local_accelerate")]
113    pub fn set_event_channel(mut self, channel: &str) -> Self {
114        self.event_channel = channel.to_string();
115        self
116    }
117
118    /// Start a listening loop on the event channel.
119    /// When reset event is triggered on other instances, the limiter will reset the local cache for the key.
120    #[cfg(feature = "local_accelerate")]
121    pub fn start_event_sync(&self) -> Result<(), redis::RedisError> {
122        let mut con = self.client.get_connection()?;
123        let mut pubsub = con.as_pubsub();
124        pubsub.subscribe(&self.event_channel).unwrap();
125        loop {
126            let msg = pubsub.get_message()?.get_payload::<String>()?;
127            if msg.starts_with(LIMITER_RESET_EVENT_PREFIX) {
128                let payload = msg.split_at(LIMITER_RESET_EVENT_PREFIX.len()).1;
129                if let Ok(mut store) = RESET_TIME_STORE.try_write() {
130                    store.remove(payload);
131                }
132            }
133        }
134    }
135
136    /// Reset the limit for a key.
137    pub fn reset(&self, key: &str) -> Result<(), redis::RedisError> {
138        let key = format!("{}{}", self.key_prefix, key);
139        let mut con = self.client.get_connection()?;
140        redis::cmd("DEL").arg(&key).query::<()>(&mut con)?;
141
142        #[cfg(feature = "local_accelerate")]
143        {
144            let reset_notify = format!("{}{}", LIMITER_RESET_EVENT_PREFIX, key);
145            redis::cmd("PUBLISH")
146                .arg(self.event_channel.clone())
147                .arg(&reset_notify)
148                .query::<()>(&mut con)?;
149        }
150
151        Ok(())
152    }
153
154    /// Allow a request to be made within the limit.
155    pub fn allow(&self, key: &str, limit: &Limit) -> Result<LimitResult, redis::RedisError> {
156        self.allow_n(key, limit, 1)
157    }
158
159    /// Allow n requests to be made within the limit.
160    pub fn allow_n(
161        &self,
162        key: &str,
163        limit: &Limit,
164        n: usize,
165    ) -> Result<LimitResult, redis::RedisError> {
166        let key = format!("{}{}", self.key_prefix, key);
167
168        let emission_interval = limit.period_seconds as f64 / limit.rate as f64;
169        let tat_increment = emission_interval * n as f64;
170        let brust_offset = limit.burst as f64 * emission_interval;
171
172        #[cfg(feature = "local_accelerate")]
173        let now = time::Instant::now();
174        #[cfg(feature = "local_accelerate")]
175        if let Ok(store) = RESET_TIME_STORE.try_read() {
176            if let Some(reset_time) = store.get(&key) {
177                let reset_after = reset_time.duration_since(now).as_secs_f64();
178                let diff: f64 = reset_after + tat_increment - brust_offset;
179                if diff > 0.0 {
180                    return Ok(LimitResult {
181                        limited: true,
182                        remaining: f64::floor((brust_offset - reset_after) / emission_interval)
183                            as usize,
184                        retry_after: Some(time::Duration::from_secs_f64(diff.abs())),
185                        reset_after: reset_time.duration_since(now),
186                    });
187                }
188            }
189        }
190
191        let mut con = self.client.get_connection()?;
192        let result: redis::Value = ALLOW_N_SCRIPT
193            .key(&key)
194            .arg(emission_interval)
195            .arg(brust_offset)
196            .arg(tat_increment)
197            .arg(n)
198            .invoke(&mut con)?;
199
200        let (limited, remaining, retry_after_secs, reset_after_secs): (bool, usize, f64, f64) =
201            redis::from_redis_value(&result)?;
202        let retry_after = if retry_after_secs < 0.0 {
203            None
204        } else {
205            Some(time::Duration::from_secs_f64(retry_after_secs))
206        };
207        let reset_after = time::Duration::from_secs_f64(reset_after_secs);
208
209        #[cfg(feature = "local_accelerate")]
210        if let Ok(mut store) = RESET_TIME_STORE.try_write() {
211            store.insert(key, now + reset_after);
212        }
213
214        Ok(LimitResult {
215            limited,
216            remaining,
217            retry_after,
218            reset_after,
219        })
220    }
221}
222
/// End-to-end check against a Redis server on `127.0.0.1`.
#[test]
fn test_limiter() {
    #[cfg(feature = "local_accelerate")]
    use std::thread;

    let limit = Limit::new(5, 5, 20);
    let key = "test";
    let client = redis::Client::open("redis://127.0.0.1/").unwrap();
    let limiter = Limiter::new(client);
    limiter.reset(key).unwrap();

    #[cfg(feature = "local_accelerate")]
    {
        let syncer = limiter.clone();
        thread::spawn(move || syncer.start_event_sync().unwrap());
    }

    // Spend 4 of the 5 available requests.
    let outcome = limiter.allow_n(key, &limit, 4).unwrap();
    assert!(!outcome.limited);
    assert_eq!(outcome.remaining, 1);

    // Asking for 3 more exceeds what is left.
    let outcome = limiter.allow_n(key, &limit, 3).unwrap();
    assert!(outcome.limited);
    assert_eq!(outcome.remaining, 1);

    // The last single request still fits.
    let outcome = limiter.allow(key, &limit).unwrap();
    assert!(!outcome.limited);
    assert_eq!(outcome.remaining, 0);

    let outcome = limiter.allow_n(key, &limit, 5).unwrap();
    assert!(outcome.limited);
    limiter.reset(key).unwrap();

    // Wait for the reset event to be processed in the other thread
    #[cfg(feature = "local_accelerate")]
    thread::sleep(time::Duration::from_millis(100));

    let outcome = limiter.allow_n(key, &limit, 5).unwrap();
    assert!(!outcome.limited);
}