actix_extensible_rate_limit/backend/
redis.rs1use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput};
2use actix_web::rt::time::Instant;
3use actix_web::{HttpResponse, ResponseError};
4use redis::aio::ConnectionManager;
5use redis::AsyncCommands;
6use std::borrow::Cow;
7use std::time::Duration;
8use thiserror::Error;
9
/// Integer type of the request counter stored with Redis `BITFIELD`.
///
/// `u63` is the widest unsigned type that `BITFIELD` accepts.
const BITFIELD_ENCODING: &str = "u63";
/// Bit offset of the counter inside the Redis string value.
const BITFIELD_OFFSET: u8 = 0;
12
13#[derive(Debug, Error)]
14pub enum Error {
15 #[error("Redis error: {0}")]
16 Redis(
17 #[source]
18 #[from]
19 redis::RedisError,
20 ),
21 #[error("Unexpected negative TTL response for the rate limit key")]
22 NegativeTtl,
23}
24
25impl ResponseError for Error {
26 fn error_response(&self) -> HttpResponse {
27 HttpResponse::InternalServerError().finish()
28 }
29}
30
31#[derive(Clone)]
33pub struct RedisBackend {
34 connection: ConnectionManager,
35 key_prefix: Option<String>,
36}
37
38impl RedisBackend {
39 pub fn builder(connection: ConnectionManager) -> Builder {
57 Builder {
58 connection,
59 key_prefix: None,
60 }
61 }
62
63 fn make_key<'t>(&self, key: &'t str) -> Cow<'t, str> {
64 match &self.key_prefix {
65 None => Cow::Borrowed(key),
66 Some(prefix) => Cow::Owned(format!("{prefix}{key}")),
67 }
68 }
69}
70
71pub struct Builder {
72 connection: ConnectionManager,
73 key_prefix: Option<String>,
74}
75
76impl Builder {
77 pub fn key_prefix(mut self, key_prefix: Option<&str>) -> Self {
82 self.key_prefix = key_prefix.map(ToOwned::to_owned);
83 self
84 }
85
86 pub fn build(self) -> RedisBackend {
87 RedisBackend {
88 connection: self.connection,
89 key_prefix: self.key_prefix,
90 }
91 }
92}
93
94impl Backend<SimpleInput> for RedisBackend {
95 type Output = SimpleOutput;
96 type RollbackToken = String;
97 type Error = Error;
98
99 async fn request(
100 &self,
101 input: SimpleInput,
102 ) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> {
103 let key = self.make_key(&input.key);
104
105 let mut pipe = redis::pipe();
106 pipe.atomic()
107 .cmd("BITFIELD")
109 .arg(key.as_ref())
110 .arg("OVERFLOW")
111 .arg("SAT")
112 .arg("INCRBY")
113 .arg(BITFIELD_ENCODING)
114 .arg(BITFIELD_OFFSET)
115 .arg(1)
116 .arg("GET")
117 .arg(BITFIELD_ENCODING)
118 .arg(BITFIELD_OFFSET)
119 .cmd("EXPIRE")
121 .arg(key.as_ref())
122 .arg(input.interval.as_secs())
123 .arg("NX")
124 .ignore()
125 .cmd("TTL")
127 .arg(key.as_ref());
128
129 let mut con = self.connection.clone();
130 let (counts, ttl): (Vec<u64>, i64) = pipe.query_async(&mut con).await?;
131 if ttl < 0 {
132 return Err(Error::NegativeTtl);
133 }
134 let count = *counts.first().expect("BITFIELD should return one value");
135
136 let allow = count <= input.max_requests;
137 let output = SimpleOutput {
138 limit: input.max_requests,
139 remaining: input.max_requests.saturating_sub(count),
140 reset: Instant::now() + Duration::from_secs(ttl as u64),
141 };
142 Ok((Decision::from_allowed(allow), output, input.key))
143 }
144
145 async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> {
146 let key = self.make_key(&token);
147
148 let mut con = self.connection.clone();
149
150 let mut pipe = redis::pipe();
151 pipe.atomic()
152 .cmd("BITFIELD")
154 .arg(key.as_ref())
155 .arg("OVERFLOW")
156 .arg("SAT")
157 .arg("INCRBY")
158 .arg(BITFIELD_ENCODING)
159 .arg(BITFIELD_OFFSET)
160 .arg(-1)
161 .cmd("EXPIRE")
163 .arg(key.as_ref())
164 .arg(0)
165 .arg("NX")
166 .ignore();
167
168 pipe.query_async(&mut con).await?;
169
170 Ok(())
171 }
172}
173
174impl SimpleBackend for RedisBackend {
175 async fn remove_key(&self, key: &str) -> Result<(), Self::Error> {
178 let key = self.make_key(key);
179 let mut con = self.connection.clone();
180 con.del(key.as_ref()).await?;
181 Ok(())
182 }
183}
184
#[cfg(test)]
mod tests {
    //! Integration tests for the Redis backend. They need a reachable Redis server (7.0 or
    //! newer, for the `NX` expiry option) at the address given by the compile-time
    //! `REDIS_HOST` / `REDIS_PORT` variables, defaulting to `127.0.0.1:6379`.

    use super::*;
    use crate::HeaderCompatibleOutput;
    use redis::Cmd;

    const MINUTE: Duration = Duration::from_secs(60);

    /// Connects to the test Redis server, deletes `clear_test_key` so the test starts from
    /// an empty counter, and returns a builder wrapping that connection.
    async fn make_backend(clear_test_key: &str) -> Builder {
        let host = option_env!("REDIS_HOST").unwrap_or("127.0.0.1");
        let port = option_env!("REDIS_PORT").unwrap_or("6379");
        let url = format!("redis://{host}:{port}");
        let client = redis::Client::open(url).unwrap();
        let mut manager = ConnectionManager::new(client).await.unwrap();
        manager.del::<_, ()>(clear_test_key).await.unwrap();
        RedisBackend::builder(manager)
    }

    #[actix_web::test]
    async fn test_allow_deny() {
        let backend = make_backend("test_allow_deny").await.build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 5,
            key: "test_allow_deny".to_string(),
        };

        // Every request inside the limit is allowed, `remaining` counts down, and the
        // countdown to the reset strictly shrinks (one second passes between requests).
        let mut last_seconds_until_reset = u64::MAX;
        for expected_remaining in [4, 3, 2, 1, 0] {
            let (decision, output, _) = backend.request(input.clone()).await.unwrap();
            assert_eq!(output.remaining, expected_remaining);
            assert_eq!(output.limit, 5);
            assert!(decision.is_allowed());
            assert!(output.seconds_until_reset() < last_seconds_until_reset);
            last_seconds_until_reset = output.seconds_until_reset();
            tokio::time::sleep(Duration::from_secs(1)).await;
        }

        // The first request over the limit is rejected.
        let (decision, output, _) = backend.request(input.clone()).await.unwrap();
        assert_eq!(output.remaining, 0);
        assert_eq!(output.limit, 5);
        assert!(decision.is_denied());
    }

    #[actix_web::test]
    async fn test_reset() {
        let backend = make_backend("test_reset").await.build();
        let input = SimpleInput {
            interval: Duration::from_secs(3),
            max_requests: 1,
            key: "test_reset".to_string(),
        };

        let (first, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(first.is_allowed());

        let (second, denied_output, _) = backend.request(input.clone()).await.unwrap();
        assert!(second.is_denied());

        // After waiting out the reported reset time, the window starts over.
        tokio::time::sleep(Duration::from_secs(denied_output.seconds_until_reset())).await;
        let (third, _, _) = backend.request(input).await.unwrap();
        assert!(third.is_allowed());
    }

    #[actix_web::test]
    async fn test_output() {
        let backend = make_backend("test_output").await.build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 2,
            key: "test_output".to_string(),
        };

        // 1st request: allowed, one left.
        let (decision, output, _) = backend.request(input.clone()).await.unwrap();
        assert!(decision.is_allowed());
        assert_eq!(output.remaining, 1);
        assert_eq!(output.limit, 2);
        assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);

        // 2nd request: allowed, limit exhausted.
        let (decision, output, _) = backend.request(input.clone()).await.unwrap();
        assert!(decision.is_allowed());
        assert_eq!(output.remaining, 0);
        assert_eq!(output.limit, 2);
        assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);

        // 3rd request: denied, output still reports the same window.
        let (decision, output, _) = backend.request(input).await.unwrap();
        assert!(decision.is_denied());
        assert_eq!(output.remaining, 0);
        assert_eq!(output.limit, 2);
        assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);
    }

    #[actix_web::test]
    async fn test_rollback() {
        let backend = make_backend("test_rollback").await.build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 5,
            key: "test_rollback".to_string(),
        };

        let (_, output, token) = backend.request(input.clone()).await.unwrap();
        assert_eq!(output.remaining, 4);
        backend.rollback(token).await.unwrap();

        // The rolled-back request no longer counts, and the window is still in place.
        let (_, output, _) = backend.request(input).await.unwrap();
        assert_eq!(output.remaining, 4);
        assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);
    }

    #[actix_web::test]
    async fn test_rollback_key_gone() {
        let key = "test_rollback_key_gone";
        let backend = make_backend(key).await.build();
        let mut con = backend.connection.clone();

        // Roll back a key that does not exist, as if its window had already expired.
        backend.rollback(key.to_string()).await.unwrap();

        // Reading the counter straight from Redis must give zero.
        let mut read_counter = Cmd::new();
        read_counter
            .arg("BITFIELD")
            .arg(key)
            .arg("GET")
            .arg(BITFIELD_ENCODING)
            .arg(BITFIELD_OFFSET);
        let counter: Vec<u64> = read_counter.query_async(&mut con).await.unwrap();
        assert_eq!(counter[0], 0u64);
    }

    #[actix_web::test]
    async fn test_remove_key() {
        let backend = make_backend("test_remove_key").await.build();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 1,
            key: "test_remove_key".to_string(),
        };

        let (first, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(first.is_allowed());
        let (second, _, _) = backend.request(input.clone()).await.unwrap();
        assert!(second.is_denied());

        // Removing the key clears the counter, so the next request is allowed again.
        backend.remove_key("test_remove_key").await.unwrap();
        let (after_removal, _, _) = backend.request(input).await.unwrap();
        assert!(after_removal.is_allowed());
    }

    #[actix_web::test]
    async fn test_key_prefix() {
        let backend = make_backend("prefix:test_key_prefix")
            .await
            .key_prefix(Some("prefix:"))
            .build();
        let mut con = backend.connection.clone();
        let input = SimpleInput {
            interval: MINUTE,
            max_requests: 5,
            key: "test_key_prefix".to_string(),
        };

        // The counter is stored under the prefixed key...
        backend.request(input.clone()).await.unwrap();
        let stored = con
            .exists::<_, bool>("prefix:test_key_prefix")
            .await
            .unwrap();
        assert!(stored);

        // ...and `remove_key` applies the same prefix when deleting it.
        backend.remove_key("test_key_prefix").await.unwrap();
        let still_stored = con
            .exists::<_, bool>("prefix:test_key_prefix")
            .await
            .unwrap();
        assert!(!still_stored);
    }
}