// actix_extensible_rate_limit/backend/redis.rs

1use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput};
2use actix_web::rt::time::Instant;
3use actix_web::{HttpResponse, ResponseError};
4use redis::aio::ConnectionManager;
5use redis::AsyncCommands;
6use std::borrow::Cow;
7use std::time::Duration;
8use thiserror::Error;
9
/// Redis `BITFIELD` integer type used for the request counter: an unsigned 63-bit integer.
const BITFIELD_ENCODING: &str = "u63";
/// Bit offset of the counter within the key's string value (the very start of the value).
const BITFIELD_OFFSET: u8 = 0;
12
13#[derive(Debug, Error)]
14pub enum Error {
15    #[error("Redis error: {0}")]
16    Redis(
17        #[source]
18        #[from]
19        redis::RedisError,
20    ),
21    #[error("Unexpected negative TTL response for the rate limit key")]
22    NegativeTtl,
23}
24
25impl ResponseError for Error {
26    fn error_response(&self) -> HttpResponse {
27        HttpResponse::InternalServerError().finish()
28    }
29}
30
31/// A Fixed Window rate limiter [Backend] that uses stores data in Redis.
32#[derive(Clone)]
33pub struct RedisBackend {
34    connection: ConnectionManager,
35    key_prefix: Option<String>,
36}
37
38impl RedisBackend {
39    /// Create a RedisBackendBuilder.
40    ///
41    /// # Arguments
42    ///
43    /// * `pool`: [A Redis connection pool](https://github.com/importcjj/mobc-redis)
44    ///
45    /// # Examples
46    ///
47    /// ```no_run
48    /// # use actix_extensible_rate_limit::backend::redis::RedisBackend;
49    /// # use redis::aio::ConnectionManager;
50    /// # async fn example() {
51    /// let client = redis::Client::open("redis://127.0.0.1/").unwrap();
52    /// let manager = ConnectionManager::new(client).await.unwrap();
53    /// let backend = RedisBackend::builder(manager).build();
54    /// # };
55    /// ```
56    pub fn builder(connection: ConnectionManager) -> Builder {
57        Builder {
58            connection,
59            key_prefix: None,
60        }
61    }
62
63    fn make_key<'t>(&self, key: &'t str) -> Cow<'t, str> {
64        match &self.key_prefix {
65            None => Cow::Borrowed(key),
66            Some(prefix) => Cow::Owned(format!("{prefix}{key}")),
67        }
68    }
69}
70
71pub struct Builder {
72    connection: ConnectionManager,
73    key_prefix: Option<String>,
74}
75
76impl Builder {
77    /// Apply an optional prefix to all rate limit keys given to this backend.
78    ///
79    /// This may be useful when the Redis instance is being used for other purposes; the prefix is
80    /// used as a 'namespace' to avoid collision with other caches or keys inside Redis.
81    pub fn key_prefix(mut self, key_prefix: Option<&str>) -> Self {
82        self.key_prefix = key_prefix.map(ToOwned::to_owned);
83        self
84    }
85
86    pub fn build(self) -> RedisBackend {
87        RedisBackend {
88            connection: self.connection,
89            key_prefix: self.key_prefix,
90        }
91    }
92}
93
94impl Backend<SimpleInput> for RedisBackend {
95    type Output = SimpleOutput;
96    type RollbackToken = String;
97    type Error = Error;
98
99    async fn request(
100        &self,
101        input: SimpleInput,
102    ) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> {
103        let key = self.make_key(&input.key);
104
105        let mut pipe = redis::pipe();
106        pipe.atomic()
107            // Increment the rate limit count
108            .cmd("BITFIELD")
109            .arg(key.as_ref())
110            .arg("OVERFLOW")
111            .arg("SAT")
112            .arg("INCRBY")
113            .arg(BITFIELD_ENCODING)
114            .arg(BITFIELD_OFFSET)
115            .arg(1)
116            .arg("GET")
117            .arg(BITFIELD_ENCODING)
118            .arg(BITFIELD_OFFSET)
119            // Set the key to expire (only if it doesn't already have an expiry)
120            .cmd("EXPIRE")
121            .arg(key.as_ref())
122            .arg(input.interval.as_secs())
123            .arg("NX")
124            .ignore()
125            // Return time-to-live of key
126            .cmd("TTL")
127            .arg(key.as_ref());
128
129        let mut con = self.connection.clone();
130        let (counts, ttl): (Vec<u64>, i64) = pipe.query_async(&mut con).await?;
131        if ttl < 0 {
132            return Err(Error::NegativeTtl);
133        }
134        let count = *counts.first().expect("BITFIELD should return one value");
135
136        let allow = count <= input.max_requests;
137        let output = SimpleOutput {
138            limit: input.max_requests,
139            remaining: input.max_requests.saturating_sub(count),
140            reset: Instant::now() + Duration::from_secs(ttl as u64),
141        };
142        Ok((Decision::from_allowed(allow), output, input.key))
143    }
144
145    async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> {
146        let key = self.make_key(&token);
147
148        let mut con = self.connection.clone();
149
150        let mut pipe = redis::pipe();
151        pipe.atomic()
152            // Decrement the rate limit count
153            .cmd("BITFIELD")
154            .arg(key.as_ref())
155            .arg("OVERFLOW")
156            .arg("SAT")
157            .arg("INCRBY")
158            .arg(BITFIELD_ENCODING)
159            .arg(BITFIELD_OFFSET)
160            .arg(-1)
161            // Set the key to expire immediately, if it doesn't already have an expiry
162            .cmd("EXPIRE")
163            .arg(key.as_ref())
164            .arg(0)
165            .arg("NX")
166            .ignore();
167
168        pipe.query_async(&mut con).await?;
169
170        Ok(())
171    }
172}
173
174impl SimpleBackend for RedisBackend {
175    /// Note that the key prefix (if set) is automatically included, you do not need to prepend
176    /// it yourself.
177    async fn remove_key(&self, key: &str) -> Result<(), Self::Error> {
178        let key = self.make_key(key);
179        let mut con = self.connection.clone();
180        con.del(key.as_ref()).await?;
181        Ok(())
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HeaderCompatibleOutput;
    use redis::Cmd;

    const MINUTE: Duration = Duration::from_secs(60);

    /// Connects to the test Redis server, deletes `clear_test_key`, and returns a builder for a
    /// backend on that connection.
    ///
    /// Tests may run concurrently, so each test must use keys that no other test touches. Each
    /// test must also clear its keys up front so that every run starts from a clean state.
    ///
    /// The server address is read from `REDIS_HOST` / `REDIS_PORT` at compile time
    /// (`option_env!`), defaulting to `127.0.0.1:6379`.
    async fn make_backend(clear_test_key: &str) -> Builder {
        let url = format!(
            "redis://{}:{}",
            option_env!("REDIS_HOST").unwrap_or("127.0.0.1"),
            option_env!("REDIS_PORT").unwrap_or("6379"),
        );
        let client = redis::Client::open(url).unwrap();
        let mut manager = ConnectionManager::new(client).await.unwrap();
        manager.del::<_, ()>(clear_test_key).await.unwrap();
        RedisBackend::builder(manager)
    }

    #[actix_web::test]
    async fn test_allow_deny() {
        let backend = make_backend("test_allow_deny").await.build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 5,
            key: "test_allow_deny".to_string(),
        };
        let mut last_seconds_until_reset = u64::MAX;
        // The first five requests fit in the limit; `remaining` counts down from 4 to 0.
        for expected_remaining in (0..5).rev() {
            let (decision, output, _) = backend.request(input.clone()).await.unwrap();
            assert_eq!(output.remaining, expected_remaining);
            assert_eq!(output.limit, 5);
            assert!(decision.is_allowed());
            // The window must keep running down; a new request must not restart it.
            assert!(output.seconds_until_reset() < last_seconds_until_reset);
            last_seconds_until_reset = output.seconds_until_reset();
            tokio::time::sleep(Duration::from_secs(1)).await;
        }

        // The sixth request is over the limit.
        let (decision, output, _) = backend.request(input).await.unwrap();
        assert_eq!(output.remaining, 0);
        assert_eq!(output.limit, 5);
        assert!(decision.is_denied());
    }

    #[actix_web::test]
    async fn test_reset() {
        let backend = make_backend("test_reset").await.build();
        let input = SimpleInput {
            interval: Duration::from_secs(3),
            max_requests: 1,
            key: "test_reset".to_string(),
        };
        // The single request this window allows...
        let (first, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(first.is_allowed());

        // ...so an immediate retry is rejected.
        let (retry, retry_output, _) = backend.request(input.clone()).await.unwrap();
        assert!(retry.is_denied());

        // Once the window has elapsed, requests are allowed again.
        let wait = Duration::from_secs(retry_output.seconds_until_reset());
        tokio::time::sleep(wait).await;
        let (after_reset, _, _) = backend.request(input).await.unwrap();
        assert!(after_reset.is_allowed());
    }

    #[actix_web::test]
    async fn test_output() {
        let backend = make_backend("test_output").await.build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 2,
            key: "test_output".to_string(),
        };
        // Both requests within the limit are allowed, each one using up a slot.
        for expected_remaining in [1, 0] {
            let (decision, output, _) = backend.request(input.clone()).await.unwrap();
            assert!(decision.is_allowed());
            assert_eq!(output.remaining, expected_remaining);
            assert_eq!(output.limit, 2);
            assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);
        }

        // The third request is rejected; the output still describes the same window.
        let (decision, output, _) = backend.request(input).await.unwrap();
        assert!(decision.is_denied());
        assert_eq!(output.remaining, 0);
        assert_eq!(output.limit, 2);
        assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);
    }

    #[actix_web::test]
    async fn test_rollback() {
        let backend = make_backend("test_rollback").await.build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 5,
            key: "test_rollback".to_string(),
        };
        let (_, counted, token) = backend.request(input.clone()).await.unwrap();
        assert_eq!(counted.remaining, 4);
        backend.rollback(token).await.unwrap();
        // The rolled-back request no longer counts, so the next one sees the same `remaining`.
        let (_, recounted, _) = backend.request(input).await.unwrap();
        assert_eq!(recounted.remaining, 4);
        // Rolling back must leave the key's expiry intact.
        assert!(recounted.seconds_until_reset() > 0 && recounted.seconds_until_reset() <= 60);
    }

    #[actix_web::test]
    async fn test_rollback_key_gone() {
        let key = "test_rollback_key_gone";
        let backend = make_backend(key).await.build();
        let mut con = backend.connection.clone();
        // A rollback may arrive after its key has already expired or been removed...
        backend.rollback(key.to_string()).await.unwrap();
        // ...in which case the counter must stay at zero instead of going negative.
        let mut get_count = Cmd::new();
        get_count
            .arg("BITFIELD")
            .arg(key)
            .arg("GET")
            .arg(BITFIELD_ENCODING)
            .arg(BITFIELD_OFFSET);
        let counts: Vec<u64> = get_count.query_async(&mut con).await.unwrap();
        assert_eq!(counts[0], 0u64);
    }

    #[actix_web::test]
    async fn test_remove_key() {
        let key = "test_remove_key";
        let backend = make_backend(key).await.build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 1,
            key: key.to_string(),
        };
        let (first, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(first.is_allowed());
        let (second, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(second.is_denied());
        backend.remove_key(key).await.unwrap();
        // With the key gone the count starts over, so the next request is allowed.
        let (after_removal, _, _) = backend.request(input).await.unwrap();
        assert!(after_removal.is_allowed());
    }

    #[actix_web::test]
    async fn test_key_prefix() {
        const STORED_KEY: &str = "prefix:test_key_prefix";
        let backend = make_backend(STORED_KEY)
            .await
            .key_prefix(Some("prefix:"))
            .build();
        let mut con = backend.connection.clone();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 5,
            key: "test_key_prefix".to_string(),
        };
        // Requests are recorded under the prefixed Redis key...
        backend.request(input).await.unwrap();
        assert!(con.exists::<_, bool>(STORED_KEY).await.unwrap());

        // ...and `remove_key` applies the prefix by itself.
        backend.remove_key("test_key_prefix").await.unwrap();
        assert!(!con.exists::<_, bool>(STORED_KEY).await.unwrap());
    }
}