redis_objects/
hashmap.rs

1//! A hash map stored under a single redis key.
2use std::{marker::PhantomData, sync::Arc, collections::HashMap, time::Duration};
3
4use parking_lot::Mutex;
5use redis::AsyncCommands;
6use serde::{de::DeserializeOwned, Serialize};
7use tracing::instrument;
8
9use crate::{RedisObjects, ErrorTypes, retry_call};
10
/// Lua script that reads one field of a hash and removes it in the same script call.
///
/// Arguments: `ARGV[1]` is the hash name, `ARGV[2]` is the field. The field's value is
/// returned when it was present (and the field is deleted); otherwise the script returns
/// the falsy `hget` result, which `Hashmap::pop` reads as `None`.
///
/// NOTE(review): the hash name arrives through `ARGV` rather than `KEYS`, so redis cannot
/// tell which key the script touches (relevant for Redis Cluster routing). Moving it to
/// `KEYS[1]` also requires changing the call site in `Hashmap::pop`.
const POP_SCRIPT: &str = r#"
local result = redis.call('hget', ARGV[1], ARGV[2])
if result then redis.call('hdel', ARGV[1], ARGV[2]) end
return result
"#;
16
17
/// Lua compare-and-delete for a single hash field.
///
/// Arguments: `KEYS[1]` is the hash, `ARGV[1]` the field and `ARGV[2]` the expected
/// serialized value. The field is deleted only when its current value equals the expected
/// one; the script returns `1` when it deleted the field and `0` otherwise.
const CONDITIONAL_REMOVE_SCRIPT: &str = r#"
local hash_name = KEYS[1]
local key_in_hash = ARGV[1]
local expected_value = ARGV[2]
local result = redis.call('hget', hash_name, key_in_hash)
if result == expected_value then
    redis.call('hdel', hash_name, key_in_hash)
    return 1
end
return 0
"#;
29
30// const LIMITED_ADD: &str = r#"
31// local set_name = KEYS[1]
32// local key = ARGV[1]
33// local value = ARGV[2]
34// local limit = tonumber(ARGV[3])
35
36// if redis.call('hlen', set_name) < limit then
37//     return redis.call('hsetnx', set_name, key, value)
38// end
39// return nil
40// "#;
41
42
43/// Hashmap opened by `RedisObjects::hashmap`
44#[derive(Clone)]
45pub struct Hashmap<T> {
46    name: String,
47    store: Arc<RedisObjects>,
48    pop_script: redis::Script,
49//     self._limited_add = self.c.register_script(_limited_add)
50    conditional_remove_script: redis::Script,
51    ttl: Option<Duration>,
52    last_expire_time: Arc<Mutex<Option<std::time::Instant>>>,
53    _data: PhantomData<T>
54}
55
56impl<T> std::fmt::Debug for Hashmap<T> {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("Hashmp").field("store", &self.store).field("name", &self.name).finish()
59    }
60}
61
62impl<T: Serialize + DeserializeOwned> Hashmap<T> {
63    pub (crate) fn new(name: String, store: Arc<RedisObjects>, ttl: Option<Duration>) -> Self {
64        Self {
65            name,
66            store,
67            pop_script: redis::Script::new(POP_SCRIPT),
68    //     self._limited_add = self.c.register_script(_limited_add)
69            conditional_remove_script: redis::Script::new(CONDITIONAL_REMOVE_SCRIPT),
70            ttl,
71            last_expire_time: Arc::new(Mutex::new(None)),
72            _data: PhantomData,
73        }
74    }
75
76    /// set the expiry in redis but only if we haven't called it recently
77    async fn conditional_expire(&self) -> Result<(), ErrorTypes> {
78        // load the ttl of this object has one set
79        if let Some(ttl) = self.ttl {
80            let call = {
81                // the last expire time is behind a mutex so that the queue object is threadsafe
82                let mut last_expire_time = self.last_expire_time.lock();
83
84                // figure out if its time to update the expiry, wait until we are 25% through the
85                // ttl to avoid resetting something only milliseconds old
86                let call = match *last_expire_time {
87                    Some(time) => {
88                        time.elapsed() > (ttl / 4)
89                    },
90                    None => true // always update the expiry if we haven't run it before on this object
91                };
92
93                if call {
94                    // update the time in the mutex then drop it so we aren't holding the lock 
95                    // while we make the call to the redis server
96                    *last_expire_time = Some(std::time::Instant::now());
97                }
98                call
99            };
100
101            if call {
102                let ttl = ttl.as_secs() as i64;
103                let _: () = retry_call!(self.store.pool, expire, &self.name, ttl)?;
104            }
105        }
106        Ok(())
107    }
108
109    // def __iter__(self):
110    //     return HashIterator(self)
111
112
113    /// Add the (key, value) pair to the hash for new keys.
114    /// If a key already exists this operation doesn't add it.
115    /// Returns true if key has been added to the table, False otherwise.
116    #[instrument(skip(value))]
117    pub async fn add(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {    
118        let data = serde_json::to_vec(value)?;
119        let result = retry_call!(self.store.pool, hset_nx, &self.name, &key, &data)?;
120        self.conditional_expire().await?;
121        Ok(result)
122    }
123
124    /// Increment a key within a hash by the given delta
125    #[instrument]
126    pub async fn increment(&self, key: &str, increment: i64) -> Result<i64, ErrorTypes> {
127        let result = retry_call!(self.store.pool, hincr, &self.name, key, increment)?;
128        self.conditional_expire().await?;
129        Ok(result)
130    }
131
132    // def limited_add(self, key, value, size_limit):
133    //     """Add a single value to the set, but only if that wouldn't make the set grow past a given size.
134
135    //     If the hash has hit the size limit returns None
136    //     Otherwise, returns the result of hsetnx (same as `add`)
137    //     """
138    //     self._conditional_expire()
139    //     return retry_call(self._limited_add, keys=[self.name], args=[key, json.dumps(value), size_limit])
140
141    /// Test if a given key is defind within this hash
142    #[instrument]
143    pub async fn exists(&self, key: &str) -> Result<bool, ErrorTypes> {
144        retry_call!(self.store.pool, hexists, &self.name, key)
145    }
146
147    /// Read the value stored at the given key
148    #[instrument]
149    pub async fn get(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
150        let item: Option<Vec<u8>> = retry_call!(self.store.pool, hget, &self.name, key)?;
151        Ok(match item {
152            Some(data) => Some(serde_json::from_slice(&data)?),
153            None => None,
154        })
155    }
156
157    /// Read the value stored at the given key
158    #[instrument]
159    pub async fn get_raw(&self, key: &str) -> Result<Option<Vec<u8>>, ErrorTypes> {
160        Ok(retry_call!(self.store.pool, hget, &self.name, key)?)
161    }
162
163    /// Load all keys from the hash
164    #[instrument]
165    pub async fn keys(&self) -> Result<Vec<String>, ErrorTypes> {
166        retry_call!(self.store.pool, hkeys, &self.name)
167    }
168
169    /// Read the number of items in the hash
170    #[instrument]
171    pub async fn length(&self) -> Result<u64, ErrorTypes> {
172        retry_call!(self.store.pool, hlen, &self.name)
173    }
174
175    /// Download the entire hash into memory
176    #[instrument]
177    pub async fn items(&self) -> Result<HashMap<String, T>, ErrorTypes> {
178        let items: Vec<(String, Vec<u8>)> = retry_call!(self.store.pool, hgetall, &self.name)?;
179        let mut out = HashMap::new();
180        for (key, data) in items {
181            out.insert(key, serde_json::from_slice(&data)?);
182        }
183        Ok(out)
184    }
185
186    /// Remove an item, but only if its value is as given
187    #[instrument(skip(value))]
188    pub async fn conditional_remove(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {
189        let data = serde_json::to_vec(value)?;
190        retry_call!(method, self.store.pool, self.conditional_remove_script.key(&self.name).arg(key).arg(&data), invoke_async)
191    }
192
193    /// Remove and return the item in the hash if found
194    #[instrument]
195    pub async fn pop(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
196        let item: Option<Vec<u8>>  = retry_call!(method, self.store.pool, self.pop_script.arg(&self.name).arg(key), invoke_async)?;
197        Ok(match item {
198            Some(data) => Some(serde_json::from_slice(&data)?),
199            None => None,
200        })
201    }
202
203    /// Unconditionally overwrite the value stored at a given key
204    #[instrument(skip(value))]
205    pub async fn set(&self, key: &str, value: &T) -> Result<i64, ErrorTypes> {
206        let data = serde_json::to_vec(value)?;
207        let result = retry_call!(self.store.pool, hset, &self.name, key, &data)?;
208        self.conditional_expire().await?;
209        Ok(result)
210    }
211
212    // def multi_set(self, data: dict[str, T]):
213    //     if any(isinstance(key, bytes) for key in data.keys()):
214    //         raise ValueError("Cannot use bytes for hashmap keys")
215    //     encoded = {key: json.dumps(value) for key, value in data.items()}
216    //     self._conditional_expire()
217    //     return retry_call(self.c.hset, self.name, mapping=encoded)
218
219    /// Clear the content of this hash
220    #[instrument]
221    pub async fn delete(&self) -> Result<(), ErrorTypes> {
222        retry_call!(self.store.pool, del, &self.name)
223    }
224
225}
226
227
#[cfg(test)]
mod test {
    use crate::test::redis_connection;
    use crate::ErrorTypes;
    use std::time::Duration;

    /// Exercise every basic operation on a hashmap that has no TTL.
    #[tokio::test]
    async fn hash() -> Result<(), ErrorTypes> {
        let redis = redis_connection().await;
        let h = redis.hashmap("test-hashmap".to_string(), None);
        h.delete().await?;

        let value_string = "value".to_owned();
        let new_value_string = "new-value".to_owned();

        assert!(h.add("key", &value_string).await?);
        assert!(!h.add("key", &value_string).await?);
        assert!(h.exists("key").await?);
        assert_eq!(h.get("key").await?.unwrap(), value_string);
        assert_eq!(h.set("key", &new_value_string).await?, 0);
        assert!(!h.add("key", &value_string).await?);
        assert_eq!(h.keys().await?, ["key"]);
        assert_eq!(h.length().await?, 1);
        assert_eq!(h.items().await?, [("key".to_owned(), new_value_string.clone())].into_iter().collect());
        assert_eq!(h.pop("key").await?.unwrap(), new_value_string);
        assert!(h.pop("key").await?.is_none());
        assert_eq!(h.length().await?, 0);
        assert!(h.add("key", &value_string).await?);

        // Raw reads return the stored JSON bytes untouched
        assert_eq!(h.get_raw("key").await?, Some(b"\"value\"".to_vec()));
        assert!(h.get_raw("missing").await?.is_none());

        // Conditional removal only succeeds when the stored value matches
        assert!(!h.conditional_remove("key", &"value1".to_owned()).await?);
        assert!(h.exists("key").await?);
        assert!(h.conditional_remove("key", &value_string).await?);
        assert_eq!(h.length().await?, 0);

        // TODO: port once `limited_add` exists:
        // // Make sure we can limit the size of a hash table
        // assert h.limited_add("a", 1, 2) == 1
        // assert h.limited_add("a", 1, 2) == 0
        // assert h.length() == 1
        // assert h.limited_add("b", 10, 2) == 1
        // assert h.length() == 2
        // assert h.limited_add("c", 1, 2) is None
        // assert h.length() == 2
        // assert h.pop("a")

        // Can we increment integer values in the hash
        assert_eq!(h.increment("a", 1).await?, 1);
        assert_eq!(h.increment("a", 1).await?, 2);
        assert_eq!(h.increment("a", 10).await?, 12);
        assert_eq!(h.increment("a", -22).await?, -10);
        h.delete().await?;

        // TODO: port once `multi_set` and iteration exist:
        // // Load a bunch of items and test iteration
        // data_before = [''.join(_x) for _x in itertools.product('abcde', repeat=5)]
        // data_before = {_x: _x + _x for _x in data_before}
        // h.multi_set(data_before)
        //
        // data_after = {}
        // for key, value in h:
        //     data_after[key] = value
        // assert data_before == data_after
        Ok(())
    }

    /// A hashmap opened with a TTL disappears once the TTL elapses after the last write.
    #[tokio::test]
    async fn expiring_hash() -> Result<(), ErrorTypes> {
        let redis = redis_connection().await;
        let eh = redis.hashmap("test-expiring-hashmap".to_string(), Duration::from_secs(1).into());
        eh.delete().await?;
        assert!(eh.add("key", &"value".to_owned()).await?);
        assert_eq!(eh.length().await?, 1);
        tokio::time::sleep(Duration::from_secs_f32(1.1)).await;
        assert_eq!(eh.length().await?, 0);
        Ok(())
    }
}