actix_ratelimit/stores/
memcached.rs

1//! Memcached store for rate limiting
2use crate::errors::ARError;
3use crate::{ActorMessage, ActorResponse};
4use actix::prelude::*;
5use backoff::backoff::Backoff;
6use backoff::ExponentialBackoff;
7use log::*;
8use r2d2_memcache::r2d2::Pool;
9use r2d2_memcache::MemcacheConnectionManager;
10use std::convert::TryInto;
11use std::time::{Duration, SystemTime, UNIX_EPOCH};
12
13struct GetAddr;
14impl Message for GetAddr {
15    type Result = Result<Pool<MemcacheConnectionManager>, ARError>;
16}
17
18/// Type used to connect to a running memecached store
19pub struct MemcacheStore {
20    addr: String,
21    backoff: ExponentialBackoff,
22    client: Option<Pool<MemcacheConnectionManager>>,
23}
24
25impl MemcacheStore {
26    /// Accepts a valid connection string to connect to memcache
27    ///
28    /// # Example
29    /// ```rust
30    /// use actix_ratelimit::MemcacheStore;
31    /// #[actix_rt::main]
32    /// async fn main() -> std::io::Result<()>{
33    ///     let store = MemcacheStore::connect("memcache://127.0.0.1:11211");
34    ///     Ok(())
35    /// }
36    /// ```
37    pub fn connect<S: Into<String>>(addr: S) -> Addr<Self> {
38        let addr = addr.into();
39        let mut backoff = ExponentialBackoff::default();
40        backoff.max_elapsed_time = None;
41        let manager = MemcacheConnectionManager::new(addr.clone());
42        let pool = Pool::builder().max_size(15).build(manager).unwrap();
43        Supervisor::start(|_| MemcacheStore {
44            addr,
45            backoff,
46            client: Some(pool),
47        })
48    }
49}
50
51impl Actor for MemcacheStore {
52    type Context = Context<Self>;
53
54    fn started(&mut self, ctx: &mut Context<Self>) {
55        info!("Started memcached store");
56        let addr = self.addr.clone();
57        let manager = MemcacheConnectionManager::new(addr);
58        let pool = Pool::builder().max_size(15).build(manager);
59        async move { pool }
60            .into_actor(self)
61            .map(|con, act, context| {
62                match con {
63                    Ok(c) => {
64                        act.client = Some(c);
65                    }
66                    Err(e) => {
67                        error!("Error connecting to memcached: {}", &e);
68                        if let Some(timeout) = act.backoff.next_backoff() {
69                            context.run_later(timeout, |_, ctx| ctx.stop());
70                        }
71                    }
72                };
73                info!("Connected to memcached server");
74                act.backoff.reset();
75            })
76            .wait(ctx);
77    }
78}
79
80impl Supervised for MemcacheStore {
81    fn restarting(&mut self, _: &mut Self::Context) {
82        debug!("restarting memcache store");
83        self.client.take();
84    }
85}
86
87impl Handler<GetAddr> for MemcacheStore {
88    type Result = Result<Pool<MemcacheConnectionManager>, ARError>;
89    fn handle(&mut self, _: GetAddr, ctx: &mut Self::Context) -> Self::Result {
90        if let Some(con) = &self.client {
91            Ok(con.clone())
92        } else {
93            if let Some(backoff) = self.backoff.next_backoff() {
94                ctx.run_later(backoff, |_, ctx| ctx.stop());
95            };
96            Err(ARError::NotConnected)
97        }
98    }
99}
100
101/// Actor for MemcacheStore
102pub struct MemcacheStoreActor {
103    addr: Addr<MemcacheStore>,
104    backoff: ExponentialBackoff,
105    inner: Option<Pool<MemcacheConnectionManager>>,
106}
107
108impl Actor for MemcacheStoreActor {
109    type Context = Context<Self>;
110
111    fn started(&mut self, ctx: &mut Context<Self>) {
112        let addr = self.addr.clone();
113        async move { addr.send(GetAddr).await }
114            .into_actor(self)
115            .map(|res, act, context| match res {
116                Ok(c) => {
117                    if let Ok(pool) = c {
118                        act.inner = Some(pool)
119                    } else {
120                        error!("could not get memecache store address");
121                        if let Some(timeout) = act.backoff.next_backoff() {
122                            context.run_later(timeout, |_, ctx| ctx.stop());
123                        }
124                    }
125                }
126                Err(_) => {
127                    error!("mailboxerror: could not get memcached store address");
128                    if let Some(timeout) = act.backoff.next_backoff() {
129                        context.run_later(timeout, |_, ctx| ctx.stop());
130                    }
131                }
132            })
133            .wait(ctx);
134    }
135}
136
137impl From<Addr<MemcacheStore>> for MemcacheStoreActor {
138    fn from(addr: Addr<MemcacheStore>) -> Self {
139        let mut backoff = ExponentialBackoff::default();
140        backoff.max_interval = Duration::from_secs(3);
141        MemcacheStoreActor {
142            addr,
143            backoff,
144            inner: None,
145        }
146    }
147}
148
149impl MemcacheStoreActor {
150    /// Starts the memcached store actor and returns it's address
151    pub fn start(self) -> Addr<Self> {
152        debug!("Started memcache actor");
153        Supervisor::start(|_| self)
154    }
155}
156
157impl Supervised for MemcacheStoreActor {
158    fn restarting(&mut self, _: &mut Self::Context) {
159        debug!("restarting memcache actor");
160        self.inner.take();
161    }
162}
163
164impl Handler<ActorMessage> for MemcacheStoreActor {
165    type Result = ActorResponse;
166    fn handle(&mut self, msg: ActorMessage, ctx: &mut Self::Context) -> Self::Result {
167        let pool = self.inner.clone();
168        if let Some(p) = pool {
169            if let Ok(client) = p.get() {
170                match msg {
171                    ActorMessage::Set { key, value, expiry } => {
172                        ActorResponse::Set(Box::pin(async move {
173                            let ex_key = format!("{}:expire", key);
174                            let now = SystemTime::now();
175                            let now = now.duration_since(UNIX_EPOCH).unwrap();
176                            let result = client.set(
177                                &key,
178                                value as u64,
179                                expiry.as_secs().try_into().unwrap(),
180                            );
181                            let val = now + expiry;
182                            let val: u64 = val.as_secs().try_into().unwrap();
183                            client
184                                .set(&ex_key, val, expiry.as_secs().try_into().unwrap())
185                                .unwrap();
186                            match result {
187                                Ok(_) => Ok(()),
188                                Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
189                            }
190                        }))
191                    }
192                    ActorMessage::Update { key, value } => {
193                        ActorResponse::Update(Box::pin(async move {
194                            let result = client.decrement(&key, value as u64);
195                            match result {
196                                Ok(c) => Ok(c as usize),
197                                Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
198                            }
199                        }))
200                    }
201                    ActorMessage::Get(key) => ActorResponse::Get(Box::pin(async move {
202                        let result: Result<Option<u64>, _> = client.get(&key);
203                        match result {
204                            Ok(c) => match c {
205                                Some(v) => Ok(Some(v as usize)),
206                                None => Ok(None),
207                            }
208                            Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
209                        }
210                    })),
211                    ActorMessage::Expire(key) => ActorResponse::Expire(Box::pin(async move {
212                        let result: Result<Option<u64>, _> =
213                            client.get(&format!("{}:expire", &key));
214                        match result {
215                            Ok(c) => {
216                                if let Some(d) = c {
217                                    let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
218                                    let now = now.as_secs().try_into().unwrap();
219                                    let res = d.checked_sub(now).unwrap_or_else(|| 0);
220                                    Ok(Duration::from_secs(res))
221                                } else {
222                                    Err(ARError::ReadWriteError(
223                                        "error: expiration data not found".to_owned(),
224                                    ))
225                                }
226                            }
227                            Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
228                        }
229                    })),
230                    ActorMessage::Remove(key) => ActorResponse::Remove(Box::pin(async move {
231                        let result = client.delete(&key);
232                        let _ = client.delete(&format!("{}:expire", &key));
233                        match result {
234                            Ok(_) => Ok(1),
235                            Err(e) => Err(ARError::ReadWriteError(format!("{:?}", &e))),
236                        }
237                    })),
238                }
239            } else {
240                ctx.stop();
241                ActorResponse::Set(Box::pin(async move { Err(ARError::Disconnected) }))
242            }
243        } else {
244            ctx.stop();
245            ActorResponse::Set(Box::pin(async move { Err(ARError::Disconnected) }))
246        }
247    }
248}
249
250
251
252
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Address of the memcached instance these integration tests expect to be running.
    const MEMCACHE_URL: &str = "memcache://127.0.0.1:11211";

    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Starts a supervised store plus store-actor pair connected to [`MEMCACHE_URL`].
    fn start_actor() -> Addr<MemcacheStoreActor> {
        let store = MemcacheStore::connect(MEMCACHE_URL);
        MemcacheStoreActor::from(store).start()
    }

    /// Sends `Set { key, value, expiry }` and panics unless memcached accepted it.
    async fn set_key(addr: &Addr<MemcacheStoreActor>, key: &str, value: usize, expiry: Duration) {
        let res = addr
            .send(ActorMessage::Set {
                key: key.to_string(),
                value,
                expiry,
            })
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Set(c) => {
                if let Err(e) = c.await {
                    panic!("Shouldn't happen: {}", &e);
                }
            }
            _ => panic!("Shouldn't happen!"),
        }
    }

    #[actix_rt::test]
    async fn test_set() {
        init();
        let addr = start_actor();
        set_key(&addr, "hello", 30, Duration::from_secs(5)).await;
    }

    #[actix_rt::test]
    async fn test_get() {
        init();
        let addr = start_actor();
        set_key(&addr, "hello", 30, Duration::from_secs(5)).await;
        let res = addr
            .send(ActorMessage::Get("hello".to_string()))
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Get(c) => match c.await {
                Ok(d) => assert_eq!(d, Some(30usize)),
                Err(e) => panic!("Shouldn't happen {}", &e),
            },
            _ => panic!("Shouldn't happen!"),
        }
    }

    #[actix_rt::test]
    async fn test_expiry() {
        init();
        let addr = start_actor();
        let expiry = Duration::from_secs(3);
        set_key(&addr, "hello_test", 30, expiry).await;
        assert!(addr.connected());

        let res = addr
            .send(ActorMessage::Expire("hello_test".to_string()))
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Expire(c) => match c.await {
                // The remaining time can never exceed the window that was just set.
                Ok(dur) => assert!(
                    dur <= expiry,
                    "Shouldn't happen: {}, {}",
                    dur.as_secs(),
                    expiry.as_secs()
                ),
                Err(e) => panic!("Shouldn't happen: {}", &e),
            },
            _ => panic!("Shouldn't happen!"),
        }
    }
}