redis_objects/
hashmap.rs

1//! A hash map stored under a single redis key.
2use std::{marker::PhantomData, sync::Arc, collections::HashMap, time::Duration};
3
4use parking_lot::Mutex;
5use redis::AsyncCommands;
6use serde::{de::DeserializeOwned, Serialize};
7
8use crate::{RedisObjects, ErrorTypes, retry_call};
9
/// Lua script that reads a field from a hash and deletes it in the same script call.
///
/// `ARGV[1]` is the name of the hash and `ARGV[2]` is the field to pop. The script returns
/// the value `hget` produced; the field is only deleted when that value is truthy.
// NOTE(review): the hash name is passed through ARGV rather than KEYS. Redis recommends passing
// key names via KEYS — confirm whether this matters for the deployment; `pop` builds the
// invocation with `.arg(&self.name)` and would have to change in lockstep.
const POP_SCRIPT: &str = r#"
local result = redis.call('hget', ARGV[1], ARGV[2])
if result then redis.call('hdel', ARGV[1], ARGV[2]) end
return result
"#;
15
16
/// Lua script that deletes a field from a hash only when its current value matches.
///
/// `KEYS[1]` is the name of the hash, `ARGV[1]` the field within it and `ARGV[2]` the value
/// the field is expected to hold. Returns 1 when the field was deleted and 0 otherwise.
const CONDITIONAL_REMOVE_SCRIPT: &str = r#"
local hash_name = KEYS[1]
local key_in_hash = ARGV[1]
local expected_value = ARGV[2]
local result = redis.call('hget', hash_name, key_in_hash)
if result == expected_value then
    redis.call('hdel', hash_name, key_in_hash)
    return 1
end
return 0
"#;
28
29// const LIMITED_ADD: &str = r#"
30// local set_name = KEYS[1]
31// local key = ARGV[1]
32// local value = ARGV[2]
33// local limit = tonumber(ARGV[3])
34
35// if redis.call('hlen', set_name) < limit then
36//     return redis.call('hsetnx', set_name, key, value)
37// end
38// return nil
39// "#;
40
41
/// Hashmap opened by `RedisObjects::hashmap`
///
/// Values of type `T` are stored JSON-encoded in the fields of a single redis hash.
/// Clones refer to the same redis key and share the expiry bookkeeping.
#[derive(Clone)]
pub struct Hashmap<T> {
    /// Name of the redis key that holds the hash
    name: String,
    /// Shared redis handle; its `pool` is used for every command issued by this object
    store: Arc<RedisObjects>,
    /// Script used by `pop` to read and delete a field in one call
    pop_script: redis::Script,
//     self._limited_add = self.c.register_script(_limited_add)
    /// Script used by `conditional_remove` to delete a field only when its value matches
    conditional_remove_script: redis::Script,
    /// Time to live pushed to the redis key after writes, when one is configured
    ttl: Option<Duration>,
    /// When the expiry was last sent to redis; `None` until the first time it is sent
    last_expire_time: Arc<Mutex<Option<std::time::Instant>>>,
    /// Marker tying this handle to its value type without storing a `T`
    _data: PhantomData<T>
}
54
55impl<T: Serialize + DeserializeOwned> Hashmap<T> {
56    pub (crate) fn new(name: String, store: Arc<RedisObjects>, ttl: Option<Duration>) -> Self {
57        Self {
58            name,
59            store,
60            pop_script: redis::Script::new(POP_SCRIPT),
61    //     self._limited_add = self.c.register_script(_limited_add)
62            conditional_remove_script: redis::Script::new(CONDITIONAL_REMOVE_SCRIPT),
63            ttl,
64            last_expire_time: Arc::new(Mutex::new(None)),
65            _data: PhantomData,
66        }
67    }
68
69    /// set the expiry in redis but only if we haven't called it recently
70    async fn conditional_expire(&self) -> Result<(), ErrorTypes> {
71        // load the ttl of this object has one set
72        if let Some(ttl) = self.ttl {
73            let call = {
74                // the last expire time is behind a mutex so that the queue object is threadsafe
75                let mut last_expire_time = self.last_expire_time.lock();
76
77                // figure out if its time to update the expiry, wait until we are 25% through the
78                // ttl to avoid resetting something only milliseconds old
79                let call = match *last_expire_time {
80                    Some(time) => {
81                        time.elapsed() > (ttl / 4)
82                    },
83                    None => true // always update the expiry if we haven't run it before on this object
84                };
85
86                if call {
87                    // update the time in the mutex then drop it so we aren't holding the lock 
88                    // while we make the call to the redis server
89                    *last_expire_time = Some(std::time::Instant::now());
90                }
91                call
92            };
93
94            if call {
95                let ttl = ttl.as_secs() as i64;
96                let _: () = retry_call!(self.store.pool, expire, &self.name, ttl)?;
97            }
98        }
99        Ok(())
100    }
101
102    // def __iter__(self):
103    //     return HashIterator(self)
104
105
106    /// Add the (key, value) pair to the hash for new keys.
107    /// If a key already exists this operation doesn't add it.
108    /// Returns true if key has been added to the table, False otherwise.
109    pub async fn add(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {    
110        let data = serde_json::to_vec(value)?;
111        let result = retry_call!(self.store.pool, hset_nx, &self.name, &key, &data)?;
112        self.conditional_expire().await?;
113        Ok(result)
114    }
115
116    /// Increment a key within a hash by the given delta
117    pub async fn increment(&self, key: &str, increment: i64) -> Result<i64, ErrorTypes> {
118        let result = retry_call!(self.store.pool, hincr, &self.name, key, increment)?;
119        self.conditional_expire().await?;
120        Ok(result)
121    }
122
123    // def limited_add(self, key, value, size_limit):
124    //     """Add a single value to the set, but only if that wouldn't make the set grow past a given size.
125
126    //     If the hash has hit the size limit returns None
127    //     Otherwise, returns the result of hsetnx (same as `add`)
128    //     """
129    //     self._conditional_expire()
130    //     return retry_call(self._limited_add, keys=[self.name], args=[key, json.dumps(value), size_limit])
131
132    /// Test if a given key is defind within this hash
133    pub async fn exists(&self, key: &str) -> Result<bool, ErrorTypes> {
134        retry_call!(self.store.pool, hexists, &self.name, key)
135    }
136
137    /// Read the value stored at the given key
138    pub async fn get(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
139        let item: Option<Vec<u8>> = retry_call!(self.store.pool, hget, &self.name, key)?;
140        Ok(match item {
141            Some(data) => Some(serde_json::from_slice(&data)?),
142            None => None,
143        })
144    }
145
146    /// Read the value stored at the given key
147    pub async fn get_raw(&self, key: &str) -> Result<Option<Vec<u8>>, ErrorTypes> {
148        Ok(retry_call!(self.store.pool, hget, &self.name, key)?)
149    }
150
151    /// Load all keys from the hash
152    pub async fn keys(&self) -> Result<Vec<String>, ErrorTypes> {
153        retry_call!(self.store.pool, hkeys, &self.name)
154    }
155
156    /// Read the number of items in the hash
157    pub async fn length(&self) -> Result<u64, ErrorTypes> {
158        retry_call!(self.store.pool, hlen, &self.name)
159    }
160
161    /// Download the entire hash into memory
162    pub async fn items(&self) -> Result<HashMap<String, T>, ErrorTypes> {
163        let items: Vec<(String, Vec<u8>)> = retry_call!(self.store.pool, hgetall, &self.name)?;
164        let mut out = HashMap::new();
165        for (key, data) in items {
166            out.insert(key, serde_json::from_slice(&data)?);
167        }
168        Ok(out)
169    }
170
171    /// Remove an item, but only if its value is as given
172    pub async fn conditional_remove(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {
173        let data = serde_json::to_vec(value)?;
174        retry_call!(method, self.store.pool, self.conditional_remove_script.key(&self.name).arg(key).arg(&data), invoke_async)
175    }
176
177    /// Remove and return the item in the hash if found
178    pub async fn pop(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
179        let item: Option<Vec<u8>>  = retry_call!(method, self.store.pool, self.pop_script.arg(&self.name).arg(key), invoke_async)?;
180        Ok(match item {
181            Some(data) => Some(serde_json::from_slice(&data)?),
182            None => None,
183        })
184    }
185
186    /// Unconditionally overwrite the value stored at a given key
187    pub async fn set(&self, key: &str, value: &T) -> Result<i64, ErrorTypes> {
188        let data = serde_json::to_vec(value)?;
189        let result = retry_call!(self.store.pool, hset, &self.name, key, &data)?;
190        self.conditional_expire().await?;
191        Ok(result)
192    }
193
194    // def multi_set(self, data: dict[str, T]):
195    //     if any(isinstance(key, bytes) for key in data.keys()):
196    //         raise ValueError("Cannot use bytes for hashmap keys")
197    //     encoded = {key: json.dumps(value) for key, value in data.items()}
198    //     self._conditional_expire()
199    //     return retry_call(self.c.hset, self.name, mapping=encoded)
200
201    /// Clear the content of this hash
202    pub async fn delete(&self) -> Result<(), ErrorTypes> {
203        retry_call!(self.store.pool, del, &self.name)
204    }
205
206}
207
208
209#[cfg(test)]
210mod test {
211    use crate::test::redis_connection;
212    use crate::ErrorTypes;
213    use std::time::Duration;
214
215    #[tokio::test]
216    async fn hash() -> Result<(), ErrorTypes> {
217        let redis = redis_connection().await;
218        let h = redis.hashmap("test-hashmap".to_string(), None);
219        h.delete().await?;
220
221        let value_string = "value".to_owned();
222        let new_value_string = "new-value".to_owned();
223
224        assert!(h.add("key", &value_string).await?);
225        assert!(!h.add("key", &value_string).await?);
226        assert!(h.exists("key").await?);
227        assert_eq!(h.get("key").await?.unwrap(), value_string);
228        assert_eq!(h.set("key", &new_value_string).await?, 0);
229        assert!(!h.add("key", &value_string).await?);
230        assert_eq!(h.keys().await?, ["key"]);
231        assert_eq!(h.length().await?, 1);
232        assert_eq!(h.items().await?, [("key".to_owned(), new_value_string.clone())].into_iter().collect());
233        assert_eq!(h.pop("key").await?.unwrap(), new_value_string);
234        assert_eq!(h.length().await?, 0);
235        assert!(h.add("key", &value_string).await?);
236        // assert h.conditional_remove("key", "value1") is False
237        // assert h.conditional_remove("key", "value") is True
238        // assert h.length(), 0
239
240        // // Make sure we can limit the size of a hash table
241        // assert h.limited_add("a", 1, 2) == 1
242        // assert h.limited_add("a", 1, 2) == 0
243        // assert h.length() == 1
244        // assert h.limited_add("b", 10, 2) == 1
245        // assert h.length() == 2
246        // assert h.limited_add("c", 1, 2) is None
247        // assert h.length() == 2
248        // assert h.pop("a")
249
250        // Can we increment integer values in the hash
251        assert_eq!(h.increment("a", 1).await?, 1);
252        assert_eq!(h.increment("a", 1).await?, 2);
253        assert_eq!(h.increment("a", 10).await?, 12);
254        assert_eq!(h.increment("a", -22).await?, -10);
255        h.delete().await?;
256
257        // // Load a bunch of items and test iteration
258        // data_before = [''.join(_x) for _x in itertools.product('abcde', repeat=5)]
259        // data_before = {_x: _x + _x for _x in data_before}
260        // h.multi_set(data_before)
261
262        // data_after = {}
263        // for key, value in h:
264        //     data_after[key] = value
265        // assert data_before == data_after
266        Ok(())
267    }
268
269    #[tokio::test] 
270    async fn expiring_hash() -> Result<(), ErrorTypes> {
271        let redis = redis_connection().await;
272        let eh = redis.hashmap("test-expiring-hashmap".to_string(), Duration::from_secs(1).into());
273        eh.delete().await?;
274        assert!(eh.add("key", &"value".to_owned()).await?);
275        assert_eq!(eh.length().await?, 1);
276        tokio::time::sleep(Duration::from_secs_f32(1.1)).await;
277        assert_eq!(eh.length().await?, 0);
278        Ok(())
279    }
280
281}