actix_ratelimit/stores/
memory.rs

//! In-memory store for rate limiting, backed by a concurrent [`DashMap`].
2use actix::prelude::*;
3use dashmap::DashMap;
4use futures::future::{self};
5use log::*;
6use std::sync::Arc;
7use std::time::{Duration, SystemTime, UNIX_EPOCH};
8
9use crate::errors::ARError;
10use crate::{ActorMessage, ActorResponse};
11
12/// Type used to create a concurrent hashmap store
13#[derive(Clone)]
14pub struct MemoryStore {
15    inner: Arc<DashMap<String, (usize, Duration)>>,
16}
17
18impl MemoryStore {
19    /// Create a new hashmap
20    ///
21    /// # Example
22    /// ```rust
23    /// use actix_ratelimit::MemoryStore;
24    ///
25    /// let store = MemoryStore::new();
26    /// ```
27    pub fn new() -> Self {
28        debug!("Creating new MemoryStore");
29        MemoryStore {
30            inner: Arc::new(DashMap::<String, (usize, Duration)>::new()),
31        }
32    }
33
34    /// Create a new hashmap with the provided capacity
35    pub fn with_capacity(capacity: usize) -> Self {
36        debug!("Creating new MemoryStore");
37        MemoryStore {
38            inner: Arc::new(DashMap::<String, (usize, Duration)>::with_capacity(
39                capacity,
40            )),
41        }
42    }
43}
44
45/// Actor for memory store
46pub struct MemoryStoreActor {
47    inner: Arc<DashMap<String, (usize, Duration)>>,
48}
49
50impl From<MemoryStore> for MemoryStoreActor {
51    fn from(store: MemoryStore) -> Self {
52        MemoryStoreActor { inner: store.inner }
53    }
54}
55
56impl MemoryStoreActor {
57    /// Starts the memory actor and returns it's address
58    pub fn start(self) -> Addr<Self> {
59        debug!("Started memory store");
60        Supervisor::start(|_| self)
61    }
62}
63
64impl Actor for MemoryStoreActor {
65    type Context = Context<Self>;
66}
67
68impl Supervised for MemoryStoreActor {
69    fn restarting(&mut self, _: &mut Self::Context) {
70        debug!("Restarting memory store");
71    }
72}
73
74impl Handler<ActorMessage> for MemoryStoreActor {
75    type Result = ActorResponse;
76    fn handle(&mut self, msg: ActorMessage, ctx: &mut Self::Context) -> Self::Result {
77        match msg {
78            ActorMessage::Set { key, value, expiry } => {
79                debug!("Inserting key {} with expiry {}", &key, &expiry.as_secs());
80                let future_key = String::from(&key);
81                let now = SystemTime::now();
82                let now = now.duration_since(UNIX_EPOCH).unwrap();
83                self.inner.insert(key, (value, now + expiry));
84                ctx.notify_later(ActorMessage::Remove(future_key), expiry);
85                ActorResponse::Set(Box::pin(future::ready(Ok(()))))
86            }
87            ActorMessage::Update { key, value } => match self.inner.get_mut(&key) {
88                Some(mut c) => {
89                    let val_mut: &mut (usize, Duration) = c.value_mut();
90                    if val_mut.0 > value {
91                        val_mut.0 -= value;
92                    } else {
93                        val_mut.0 = 0;
94                    }
95                    let new_val = val_mut.0;
96                    ActorResponse::Update(Box::pin(future::ready(Ok(new_val))))
97                }
98                None => {
99                    return ActorResponse::Update(Box::pin(future::ready(Err(
100                        ARError::ReadWriteError("memory store: read failed!".to_string()),
101                    ))))
102                }
103            },
104            ActorMessage::Get(key) => {
105                if self.inner.contains_key(&key) {
106                    let val = match self.inner.get(&key) {
107                        Some(c) => c,
108                        None => {
109                            return ActorResponse::Get(Box::pin(future::ready(Err(
110                                ARError::ReadWriteError("memory store: read failed!".to_string()),
111                            ))))
112                        }
113                    };
114                    let val = val.value().0;
115                    ActorResponse::Get(Box::pin(future::ready(Ok(Some(val)))))
116                } else {
117                    ActorResponse::Get(Box::pin(future::ready(Ok(None))))
118                }
119            }
120            ActorMessage::Expire(key) => {
121                let c = match self.inner.get(&key) {
122                    Some(d) => d,
123                    None => {
124                        return ActorResponse::Expire(Box::pin(future::ready(Err(
125                            ARError::ReadWriteError("memory store: read failed!".to_string()),
126                        ))))
127                    }
128                };
129                let dur = c.value().1;
130                let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
131                let res = dur.checked_sub(now).unwrap_or_else(|| Duration::new(0, 0));
132                ActorResponse::Expire(Box::pin(future::ready(Ok(res))))
133            }
134            ActorMessage::Remove(key) => {
135                debug!("Removing key: {}", &key);
136                let val = match self.inner.remove::<String>(&key) {
137                    Some(c) => c,
138                    None => {
139                        return ActorResponse::Remove(Box::pin(future::ready(Err(
140                            ARError::ReadWriteError("memory store: remove failed!".to_string()),
141                        ))))
142                    }
143                };
144                let val = val.1;
145                ActorResponse::Remove(Box::pin(future::ready(Ok(val.0))))
146            }
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts a fresh actor over a fresh, empty store.
    fn start_actor() -> Addr<MemoryStoreActor> {
        MemoryStoreActor::from(MemoryStore::new()).start()
    }

    /// Sends `Set` and panics unless the store acknowledges it.
    async fn send_set(addr: &Addr<MemoryStoreActor>, key: &str, value: usize, expiry: Duration) {
        let res = addr
            .send(ActorMessage::Set {
                key: key.to_string(),
                value,
                expiry,
            })
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Set(c) => {
                if let Err(e) = c.await {
                    panic!("Shouldn't happen {}", &e);
                }
            }
            _ => panic!("Shouldn't happen!"),
        }
    }

    /// Sends `Get` and returns the stored count, panicking on a store error.
    async fn send_get(addr: &Addr<MemoryStoreActor>, key: &str) -> Option<usize> {
        let res = addr
            .send(ActorMessage::Get(key.to_string()))
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Get(c) => match c.await {
                Ok(d) => d,
                Err(e) => panic!("Shouldn't happen {}", &e),
            },
            _ => panic!("Shouldn't happen!"),
        }
    }

    /// Sends `Update` and returns the new count, panicking on a store error.
    async fn send_update(addr: &Addr<MemoryStoreActor>, key: &str, value: usize) -> usize {
        let res = addr
            .send(ActorMessage::Update {
                key: key.to_string(),
                value,
            })
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Update(c) => match c.await {
                Ok(n) => n,
                Err(e) => panic!("Shouldn't happen {}", &e),
            },
            _ => panic!("Shouldn't happen!"),
        }
    }

    #[actix_rt::test]
    async fn test_set() {
        let addr = start_actor();
        send_set(&addr, "hello", 30, Duration::from_secs(5)).await;
    }

    #[actix_rt::test]
    async fn test_get() {
        let addr = start_actor();
        send_set(&addr, "hello", 30, Duration::from_secs(5)).await;
        assert_eq!(send_get(&addr, "hello").await, Some(30));
    }

    #[actix_rt::test]
    async fn test_get_missing_key() {
        let addr = start_actor();
        assert_eq!(send_get(&addr, "missing").await, None);
    }

    #[actix_rt::test]
    async fn test_update_saturates_at_zero() {
        let addr = start_actor();
        send_set(&addr, "hello", 3, Duration::from_secs(5)).await;
        assert_eq!(send_update(&addr, "hello", 1).await, 2);
        assert_eq!(send_update(&addr, "hello", 5).await, 0);
        assert_eq!(send_get(&addr, "hello").await, Some(0));
    }

    #[actix_rt::test]
    async fn test_remove() {
        let addr = start_actor();
        send_set(&addr, "hello", 30, Duration::from_secs(5)).await;
        let res = addr
            .send(ActorMessage::Remove("hello".to_string()))
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Remove(c) => match c.await {
                Ok(val) => assert_eq!(val, 30),
                Err(e) => panic!("Shouldn't happen {}", &e),
            },
            _ => panic!("Shouldn't happen!"),
        }
        assert_eq!(send_get(&addr, "hello").await, None);
    }

    #[actix_rt::test]
    async fn test_expiry() {
        let addr = start_actor();
        let expiry = Duration::from_secs(3);
        send_set(&addr, "hello", 30, expiry).await;
        assert!(addr.connected());

        let res = addr
            .send(ActorMessage::Expire("hello".to_string()))
            .await
            .expect("Failed to send msg");
        match res {
            ActorResponse::Expire(c) => match c.await {
                Ok(dur) => {
                    // The remaining time can only shrink from `expiry`, and almost no time has
                    // passed since the `Set` above.
                    assert!(dur <= expiry, "Expiry is invalid!");
                    assert!(dur > Duration::from_secs(1), "Expiry is invalid!");
                }
                Err(e) => panic!("Shouldn't happen: {}", &e),
            },
            _ => panic!("Shouldn't happen!"),
        };
    }
}