actix_ratelimit/stores/redis.rs

1//! Redis store for rate limiting
2use actix::prelude::*;
3use backoff::backoff::Backoff;
4use backoff::ExponentialBackoff;
5use log::*;
6use redis_rs::{self as redis, aio::MultiplexedConnection};
7use std::time::Duration;
8
9use crate::errors::ARError;
10use crate::{ActorMessage, ActorResponse};
11
12struct GetAddr;
13impl Message for GetAddr {
14    type Result = Result<MultiplexedConnection, ARError>;
15}
16
17/// Type used to connect to a running redis instance
18pub struct RedisStore {
19    addr: String,
20    backoff: ExponentialBackoff,
21    client: Option<MultiplexedConnection>,
22}
23
24impl RedisStore {
25    /// Accepts a valid connection string to connect to redis
26    ///
27    /// # Example
28    /// ```rust
29    /// use actix_ratelimit::RedisStore;
30    ///
31    /// #[actix_rt::main]
32    /// async fn main() -> std::io::Result<()>{
33    ///     let store = RedisStore::connect("redis://127.0.0.1");
34    ///     Ok(())
35    /// }
36    /// ```
37    pub fn connect<S: Into<String>>(addr: S) -> Addr<Self> {
38        let addr = addr.into();
39        let mut backoff = ExponentialBackoff::default();
40        backoff.max_elapsed_time = None;
41        Supervisor::start(|_| RedisStore {
42            addr,
43            backoff,
44            client: None,
45        })
46    }
47}
48
49impl Actor for RedisStore {
50    type Context = Context<Self>;
51
52    fn started(&mut self, ctx: &mut Context<Self>) {
53        info!("Started main redis store");
54        let addr = self.addr.clone();
55        async move {
56            let client = redis::Client::open(addr.as_ref()).unwrap();
57            client.get_multiplexed_async_connection().await
58        }
59        .into_actor(self)
60        .map(|con, act, context| {
61            match con {
62                Ok(c) => {
63                    act.client = Some(c.0);
64                    let fut = c.1;
65                    fut.into_actor(act).spawn(context);
66                }
67                Err(e) => {
68                    error!("Error connecting to redis: {}", &e);
69                    if let Some(timeout) = act.backoff.next_backoff() {
70                        context.run_later(timeout, |_, ctx| ctx.stop());
71                    }
72                }
73            };
74            info!("Connected to redis server");
75            act.backoff.reset();
76        })
77        .wait(ctx);
78    }
79}
80
81impl Supervised for RedisStore {
82    fn restarting(&mut self, _: &mut Self::Context) {
83        debug!("restarting redis store");
84        self.client.take();
85    }
86}
87
88impl Handler<GetAddr> for RedisStore {
89    type Result = Result<MultiplexedConnection, ARError>;
90    fn handle(&mut self, _: GetAddr, ctx: &mut Self::Context) -> Self::Result {
91        if let Some(con) = &self.client {
92            Ok(con.clone())
93        } else {
94            // No connection exists
95            if let Some(backoff) = self.backoff.next_backoff() {
96                ctx.run_later(backoff, |_, ctx| ctx.stop());
97            };
98            Err(ARError::NotConnected)
99        }
100    }
101}
102
103/// Actor for redis store
104pub struct RedisStoreActor {
105    addr: Addr<RedisStore>,
106    backoff: ExponentialBackoff,
107    inner: Option<MultiplexedConnection>,
108}
109
110impl Actor for RedisStoreActor {
111    type Context = Context<Self>;
112
113    fn started(&mut self, ctx: &mut Context<Self>) {
114        let addr = self.addr.clone();
115        async move { addr.send(GetAddr).await }
116            .into_actor(self)
117            .map(|res, act, context| match res {
118                Ok(c) => {
119                    if let Ok(conn) = c {
120                        act.inner = Some(conn);
121                    } else {
122                        error!("could not get redis store address");
123                        if let Some(timeout) = act.backoff.next_backoff() {
124                            context.run_later(timeout, |_, ctx| ctx.stop());
125                        }
126                    }
127                }
128                Err(_) => {
129                    error!("mailboxerror: could not get redis store address");
130                    if let Some(timeout) = act.backoff.next_backoff() {
131                        context.run_later(timeout, |_, ctx| ctx.stop());
132                    }
133                }
134            })
135            .wait(ctx);
136    }
137}
138
139impl From<Addr<RedisStore>> for RedisStoreActor {
140    fn from(addr: Addr<RedisStore>) -> Self {
141        let mut backoff = ExponentialBackoff::default();
142        backoff.max_interval = Duration::from_secs(3);
143        RedisStoreActor {
144            addr,
145            backoff,
146            inner: None,
147        }
148    }
149}
150
151impl RedisStoreActor {
152    /// Starts the redis actor and returns it's address
153    pub fn start(self) -> Addr<Self> {
154        debug!("started redis actor");
155        Supervisor::start(|_| self)
156    }
157}
158
159impl Supervised for RedisStoreActor {
160    fn restarting(&mut self, _: &mut Self::Context) {
161        debug!("restarting redis actor!");
162        self.inner.take();
163    }
164}
165
166impl Handler<ActorMessage> for RedisStoreActor {
167    type Result = ActorResponse;
168    fn handle(&mut self, msg: ActorMessage, ctx: &mut Self::Context) -> Self::Result {
169        let connection = self.inner.clone();
170        if let Some(mut con) = connection {
171            match msg {
172                ActorMessage::Set { key, value, expiry } => {
173                    ActorResponse::Set(Box::pin(async move {
174                        let mut cmd = redis::Cmd::new();
175                        cmd.arg("SET")
176                            .arg(key)
177                            .arg(value)
178                            .arg("EX")
179                            .arg(expiry.as_secs());
180                        let result = cmd.query_async::<MultiplexedConnection, ()>(&mut con).await;
181                        match result {
182                            Ok(_) => Ok(()),
183                            Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
184                        }
185                    }))
186                }
187                ActorMessage::Update { key, value } => {
188                    ActorResponse::Update(Box::pin(async move {
189                        let mut cmd = redis::Cmd::new();
190                        cmd.arg("DECRBY").arg(key).arg(value);
191                        let result = cmd
192                            .query_async::<MultiplexedConnection, usize>(&mut con)
193                            .await;
194                        match result {
195                            Ok(c) => Ok(c),
196                            Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
197                        }
198                    }))
199                }
200                ActorMessage::Get(key) => ActorResponse::Get(Box::pin(async move {
201                    let mut cmd = redis::Cmd::new();
202                    cmd.arg("GET").arg(key);
203                    let result = cmd
204                        .query_async::<MultiplexedConnection, Option<usize>>(&mut con)
205                        .await;
206
207                    match result {
208                        Ok(c) => Ok(c),
209                        Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
210                    }
211                })),
212                ActorMessage::Expire(key) => ActorResponse::Expire(Box::pin(async move {
213                    let mut cmd = redis::Cmd::new();
214                    cmd.arg("TTL").arg(key);
215                    let result = cmd
216                        .query_async::<MultiplexedConnection, isize>(&mut con)
217                        .await;
218                    match result {
219                        Ok(c) => {
220                            if c > 0 {
221                                Ok(Duration::new(c as u64, 0))
222                            } else {
223                                Err(ARError::ReadWriteError("redis error: key does not exists or does not has a associated ttl.".to_string()))
224                            }
225                        }
226                        Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
227                    }
228                })),
229                ActorMessage::Remove(key) => ActorResponse::Remove(Box::pin(async move {
230                    let mut cmd = redis::Cmd::new();
231                    cmd.arg("DEL").arg(key);
232                    let result = cmd
233                        .query_async::<MultiplexedConnection, usize>(&mut con)
234                        .await;
235                    match result {
236                        Ok(c) => Ok(c),
237                        Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
238                    }
239                })),
240            }
241        } else {
242            ctx.stop();
243            ActorResponse::Set(Box::pin(async move { Err(ARError::Disconnected) }))
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    //! These tests talk to a redis server at `redis://127.0.0.1/`.
    use super::*;

    /// Initialises test logging; repeated calls are ignored.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Sends `Set { key, value: 30, expiry }` and panics unless redis accepted it.
    async fn set_key(addr: &Addr<RedisStoreActor>, key: &str, expiry: Duration) {
        let res = addr
            .send(ActorMessage::Set {
                key: key.to_string(),
                value: 30usize,
                expiry,
            })
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Set(c) => {
                if let Err(e) = c.await {
                    panic!("Shouldn't happen: {}", &e);
                }
            }
            _ => panic!("Shouldn't happen!"),
        }
    }

    #[actix_rt::test]
    async fn test_set() {
        init();
        let store = RedisStore::connect("redis://127.0.0.1/");
        let addr = RedisStoreActor::from(store.clone()).start();
        set_key(&addr, "hello", Duration::from_secs(5)).await;
    }

    #[actix_rt::test]
    async fn test_get() {
        init();
        let store = RedisStore::connect("redis://127.0.0.1/");
        let addr = RedisStoreActor::from(store.clone()).start();
        set_key(&addr, "hello", Duration::from_secs(5)).await;

        let res = addr
            .send(ActorMessage::Get("hello".to_string()))
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Get(c) => match c.await {
                Ok(d) => assert_eq!(d, Some(30usize)),
                Err(e) => panic!("Shouldn't happen {}", &e),
            },
            _ => panic!("Shouldn't happen!"),
        }
    }

    #[actix_rt::test]
    async fn test_expiry() {
        init();
        let store = RedisStore::connect("redis://127.0.0.1/");
        let addr = RedisStoreActor::from(store.clone()).start();
        let expiry = Duration::from_secs(3);
        set_key(&addr, "hello_test", expiry).await;
        assert!(addr.connected());

        let res = addr
            .send(ActorMessage::Expire("hello_test".to_string()))
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Expire(c) => match c.await {
                Ok(dur) => assert!(
                    dur <= expiry,
                    "Shouldn't happen: {}, {}",
                    dur.as_secs(),
                    expiry.as_secs()
                ),
                Err(e) => panic!("Shouldn't happen: {}", &e),
            },
            _ => panic!("Shouldn't happen!"),
        }
    }
}