1use std::{marker::PhantomData, sync::Arc, collections::HashMap, time::Duration};
3
4use parking_lot::Mutex;
5use redis::AsyncCommands;
6use serde::{de::DeserializeOwned, Serialize};
7
8use crate::{RedisObjects, ErrorTypes, retry_call};
9
/// Lua script that atomically reads and removes one field of a hash.
///
/// ARGV[1] is the hash name and ARGV[2] the field. It returns the field's stored
/// value. When the field is absent, `hget` yields a false value, nothing is deleted,
/// and no value is returned.
///
/// NOTE(review): the hash name is passed through ARGV rather than KEYS. Redis asks
/// scripts to declare every key they access in KEYS, and Redis Cluster needs this for
/// routing. `CONDITIONAL_REMOVE_SCRIPT` already does so. Changing it requires updating
/// the `pop` call site (`.arg(&self.name)`) at the same time.
const POP_SCRIPT: &str = r#"
local result = redis.call('hget', ARGV[1], ARGV[2])
if result then redis.call('hdel', ARGV[1], ARGV[2]) end
return result
"#;
15
16
/// Lua script that deletes a hash field only if it currently holds an expected value.
///
/// KEYS[1] is the hash name, ARGV[1] the field and ARGV[2] the expected value. The
/// comparison is an exact equality against the stored value, i.e. the serialized bytes.
/// Returns 1 when the field was deleted. Returns 0 otherwise, including when the
/// field does not exist (`hget` then yields false, which never equals ARGV[2]).
const CONDITIONAL_REMOVE_SCRIPT: &str = r#"
local hash_name = KEYS[1]
local key_in_hash = ARGV[1]
local expected_value = ARGV[2]
local result = redis.call('hget', hash_name, key_in_hash)
if result == expected_value then
    redis.call('hdel', hash_name, key_in_hash)
    return 1
end
return 0
"#;
28
29#[derive(Clone)]
44pub struct Hashmap<T> {
45 name: String,
46 store: Arc<RedisObjects>,
47 pop_script: redis::Script,
48conditional_remove_script: redis::Script,
50 ttl: Option<Duration>,
51 last_expire_time: Arc<Mutex<Option<std::time::Instant>>>,
52 _data: PhantomData<T>
53}
54
55impl<T: Serialize + DeserializeOwned> Hashmap<T> {
56 pub (crate) fn new(name: String, store: Arc<RedisObjects>, ttl: Option<Duration>) -> Self {
57 Self {
58 name,
59 store,
60 pop_script: redis::Script::new(POP_SCRIPT),
61 conditional_remove_script: redis::Script::new(CONDITIONAL_REMOVE_SCRIPT),
63 ttl,
64 last_expire_time: Arc::new(Mutex::new(None)),
65 _data: PhantomData,
66 }
67 }
68
69 async fn conditional_expire(&self) -> Result<(), ErrorTypes> {
71 if let Some(ttl) = self.ttl {
73 let call = {
74 let mut last_expire_time = self.last_expire_time.lock();
76
77 let call = match *last_expire_time {
80 Some(time) => {
81 time.elapsed() > (ttl / 4)
82 },
83 None => true };
85
86 if call {
87 *last_expire_time = Some(std::time::Instant::now());
90 }
91 call
92 };
93
94 if call {
95 let ttl = ttl.as_secs() as i64;
96 let _: () = retry_call!(self.store.pool, expire, &self.name, ttl)?;
97 }
98 }
99 Ok(())
100 }
101
102 pub async fn add(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {
110 let data = serde_json::to_vec(value)?;
111 let result = retry_call!(self.store.pool, hset_nx, &self.name, &key, &data)?;
112 self.conditional_expire().await?;
113 Ok(result)
114 }
115
116 pub async fn increment(&self, key: &str, increment: i64) -> Result<i64, ErrorTypes> {
118 let result = retry_call!(self.store.pool, hincr, &self.name, key, increment)?;
119 self.conditional_expire().await?;
120 Ok(result)
121 }
122
123 pub async fn exists(&self, key: &str) -> Result<bool, ErrorTypes> {
134 retry_call!(self.store.pool, hexists, &self.name, key)
135 }
136
137 pub async fn get(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
139 let item: Option<Vec<u8>> = retry_call!(self.store.pool, hget, &self.name, key)?;
140 Ok(match item {
141 Some(data) => Some(serde_json::from_slice(&data)?),
142 None => None,
143 })
144 }
145
146 pub async fn get_raw(&self, key: &str) -> Result<Option<Vec<u8>>, ErrorTypes> {
148 Ok(retry_call!(self.store.pool, hget, &self.name, key)?)
149 }
150
151 pub async fn keys(&self) -> Result<Vec<String>, ErrorTypes> {
153 retry_call!(self.store.pool, hkeys, &self.name)
154 }
155
156 pub async fn length(&self) -> Result<u64, ErrorTypes> {
158 retry_call!(self.store.pool, hlen, &self.name)
159 }
160
161 pub async fn items(&self) -> Result<HashMap<String, T>, ErrorTypes> {
163 let items: Vec<(String, Vec<u8>)> = retry_call!(self.store.pool, hgetall, &self.name)?;
164 let mut out = HashMap::new();
165 for (key, data) in items {
166 out.insert(key, serde_json::from_slice(&data)?);
167 }
168 Ok(out)
169 }
170
171 pub async fn conditional_remove(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {
173 let data = serde_json::to_vec(value)?;
174 retry_call!(method, self.store.pool, self.conditional_remove_script.key(&self.name).arg(key).arg(&data), invoke_async)
175 }
176
177 pub async fn pop(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
179 let item: Option<Vec<u8>> = retry_call!(method, self.store.pool, self.pop_script.arg(&self.name).arg(key), invoke_async)?;
180 Ok(match item {
181 Some(data) => Some(serde_json::from_slice(&data)?),
182 None => None,
183 })
184 }
185
186 pub async fn set(&self, key: &str, value: &T) -> Result<i64, ErrorTypes> {
188 let data = serde_json::to_vec(value)?;
189 let result = retry_call!(self.store.pool, hset, &self.name, key, &data)?;
190 self.conditional_expire().await?;
191 Ok(result)
192 }
193
194 pub async fn delete(&self) -> Result<(), ErrorTypes> {
203 retry_call!(self.store.pool, del, &self.name)
204 }
205
206}
207
208
#[cfg(test)]
mod test {
    use crate::test::redis_connection;
    use crate::ErrorTypes;
    use std::time::Duration;

    /// Walks through every basic operation on a hashmap that has no TTL.
    #[tokio::test]
    async fn hash() -> Result<(), ErrorTypes> {
        let connection = redis_connection().await;
        let map = connection.hashmap("test-hashmap".to_string(), None);
        map.delete().await?;

        let original = "value".to_owned();
        let replacement = "new-value".to_owned();

        // Insert-if-absent succeeds only the first time.
        assert!(map.add("key", &original).await?);
        assert!(!map.add("key", &original).await?);
        assert!(map.exists("key").await?);
        assert_eq!(map.get("key").await?.unwrap(), original);

        // Overwriting an existing field creates no new fields.
        assert_eq!(map.set("key", &replacement).await?, 0);
        assert!(!map.add("key", &original).await?);

        assert_eq!(map.keys().await?, ["key"]);
        assert_eq!(map.length().await?, 1);
        assert_eq!(
            map.items().await?,
            [("key".to_owned(), replacement.clone())].into_iter().collect()
        );

        // Popping removes the field, after which it can be added again.
        assert_eq!(map.pop("key").await?.unwrap(), replacement);
        assert_eq!(map.length().await?, 0);
        assert!(map.add("key", &original).await?);

        // Each step: (delta applied, running total expected afterwards).
        for (delta, running_total) in [(1, 1), (1, 2), (10, 12), (-22, -10)] {
            assert_eq!(map.increment("a", delta).await?, running_total);
        }

        map.delete().await?;
        Ok(())
    }

    /// A hashmap with a one-second TTL should be gone shortly after that second.
    #[tokio::test]
    async fn expiring_hash() -> Result<(), ErrorTypes> {
        let connection = redis_connection().await;
        let expiring = connection.hashmap(
            "test-expiring-hashmap".to_string(),
            Some(Duration::from_secs(1)),
        );
        expiring.delete().await?;

        assert!(expiring.add("key", &"value".to_owned()).await?);
        assert_eq!(expiring.length().await?, 1);

        tokio::time::sleep(Duration::from_secs_f32(1.1)).await;
        assert_eq!(expiring.length().await?, 0);
        Ok(())
    }
}