1use std::{marker::PhantomData, sync::Arc, collections::HashMap, time::Duration};
3
4use parking_lot::Mutex;
5use redis::AsyncCommands;
6use serde::{de::DeserializeOwned, Serialize};
7use tracing::instrument;
8
9use crate::{RedisObjects, ErrorTypes, retry_call};
10
/// Lua script that reads a field from a hash and deletes it in one script call.
///
/// `ARGV[1]` is the name of the hash and `ARGV[2]` is the field. The script
/// returns the field's stored value; when the field is missing `hget` yields
/// nil and the caller receives `None`.
// NOTE(review): the hash name is passed through ARGV instead of KEYS, so the
// script does not declare the key it touches (Redis Cluster routing depends on
// KEYS). Changing this requires updating the caller in `Hashmap::pop` as well.
const POP_SCRIPT: &str = r#"
local result = redis.call('hget', ARGV[1], ARGV[2])
if result then redis.call('hdel', ARGV[1], ARGV[2]) end
return result
"#;
16
17
/// Lua script that deletes a hash field only when it currently holds an expected value.
///
/// `KEYS[1]` is the hash, `ARGV[1]` the field inside it and `ARGV[2]` the
/// expected (already serialized) value, compared with Lua string equality.
/// Returns 1 when the field matched and was deleted, 0 otherwise.
const CONDITIONAL_REMOVE_SCRIPT: &str = r#"
local hash_name = KEYS[1]
local key_in_hash = ARGV[1]
local expected_value = ARGV[2]
local result = redis.call('hget', hash_name, key_in_hash)
if result == expected_value then
 redis.call('hdel', hash_name, key_in_hash)
 return 1
end
return 0
"#;
29
30#[derive(Clone)]
45pub struct Hashmap<T> {
46 name: String,
47 store: Arc<RedisObjects>,
48 pop_script: redis::Script,
49conditional_remove_script: redis::Script,
51 ttl: Option<Duration>,
52 last_expire_time: Arc<Mutex<Option<std::time::Instant>>>,
53 _data: PhantomData<T>
54}
55
56impl<T> std::fmt::Debug for Hashmap<T> {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("Hashmp").field("store", &self.store).field("name", &self.name).finish()
59 }
60}
61
62impl<T: Serialize + DeserializeOwned> Hashmap<T> {
63 pub (crate) fn new(name: String, store: Arc<RedisObjects>, ttl: Option<Duration>) -> Self {
64 Self {
65 name,
66 store,
67 pop_script: redis::Script::new(POP_SCRIPT),
68 conditional_remove_script: redis::Script::new(CONDITIONAL_REMOVE_SCRIPT),
70 ttl,
71 last_expire_time: Arc::new(Mutex::new(None)),
72 _data: PhantomData,
73 }
74 }
75
76 async fn conditional_expire(&self) -> Result<(), ErrorTypes> {
78 if let Some(ttl) = self.ttl {
80 let call = {
81 let mut last_expire_time = self.last_expire_time.lock();
83
84 let call = match *last_expire_time {
87 Some(time) => {
88 time.elapsed() > (ttl / 4)
89 },
90 None => true };
92
93 if call {
94 *last_expire_time = Some(std::time::Instant::now());
97 }
98 call
99 };
100
101 if call {
102 let ttl = ttl.as_secs() as i64;
103 let _: () = retry_call!(self.store.pool, expire, &self.name, ttl)?;
104 }
105 }
106 Ok(())
107 }
108
109 #[instrument(skip(value))]
117 pub async fn add(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {
118 let data = serde_json::to_vec(value)?;
119 let result = retry_call!(self.store.pool, hset_nx, &self.name, &key, &data)?;
120 self.conditional_expire().await?;
121 Ok(result)
122 }
123
124 #[instrument]
126 pub async fn increment(&self, key: &str, increment: i64) -> Result<i64, ErrorTypes> {
127 let result = retry_call!(self.store.pool, hincr, &self.name, key, increment)?;
128 self.conditional_expire().await?;
129 Ok(result)
130 }
131
132 #[instrument]
143 pub async fn exists(&self, key: &str) -> Result<bool, ErrorTypes> {
144 retry_call!(self.store.pool, hexists, &self.name, key)
145 }
146
147 #[instrument]
149 pub async fn get(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
150 let item: Option<Vec<u8>> = retry_call!(self.store.pool, hget, &self.name, key)?;
151 Ok(match item {
152 Some(data) => Some(serde_json::from_slice(&data)?),
153 None => None,
154 })
155 }
156
157 #[instrument]
159 pub async fn get_raw(&self, key: &str) -> Result<Option<Vec<u8>>, ErrorTypes> {
160 Ok(retry_call!(self.store.pool, hget, &self.name, key)?)
161 }
162
163 #[instrument]
165 pub async fn keys(&self) -> Result<Vec<String>, ErrorTypes> {
166 retry_call!(self.store.pool, hkeys, &self.name)
167 }
168
169 #[instrument]
171 pub async fn length(&self) -> Result<u64, ErrorTypes> {
172 retry_call!(self.store.pool, hlen, &self.name)
173 }
174
175 #[instrument]
177 pub async fn items(&self) -> Result<HashMap<String, T>, ErrorTypes> {
178 let items: Vec<(String, Vec<u8>)> = retry_call!(self.store.pool, hgetall, &self.name)?;
179 let mut out = HashMap::new();
180 for (key, data) in items {
181 out.insert(key, serde_json::from_slice(&data)?);
182 }
183 Ok(out)
184 }
185
186 #[instrument(skip(value))]
188 pub async fn conditional_remove(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {
189 let data = serde_json::to_vec(value)?;
190 retry_call!(method, self.store.pool, self.conditional_remove_script.key(&self.name).arg(key).arg(&data), invoke_async)
191 }
192
193 #[instrument]
195 pub async fn pop(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
196 let item: Option<Vec<u8>> = retry_call!(method, self.store.pool, self.pop_script.arg(&self.name).arg(key), invoke_async)?;
197 Ok(match item {
198 Some(data) => Some(serde_json::from_slice(&data)?),
199 None => None,
200 })
201 }
202
203 #[instrument(skip(value))]
205 pub async fn set(&self, key: &str, value: &T) -> Result<i64, ErrorTypes> {
206 let data = serde_json::to_vec(value)?;
207 let result = retry_call!(self.store.pool, hset, &self.name, key, &data)?;
208 self.conditional_expire().await?;
209 Ok(result)
210 }
211
212 #[instrument]
221 pub async fn delete(&self) -> Result<(), ErrorTypes> {
222 retry_call!(self.store.pool, del, &self.name)
223 }
224
225}
226
227
#[cfg(test)]
mod test {
    use crate::test::redis_connection;
    use crate::ErrorTypes;
    use std::time::Duration;

    /// Walk through the basic hash operations against a live redis instance.
    #[tokio::test]
    async fn hash() -> Result<(), ErrorTypes> {
        let connection = redis_connection().await;
        let map = connection.hashmap("test-hashmap".to_string(), None);
        // Begin from an empty hash in case an earlier run left data behind.
        map.delete().await?;

        let original = "value".to_owned();
        let replacement = "new-value".to_owned();

        // `add` writes only when the field is absent.
        assert!(map.add("key", &original).await?);
        assert!(!map.add("key", &original).await?);
        assert!(map.exists("key").await?);
        assert_eq!(map.get("key").await?.unwrap(), original);

        // `set` overwrites; an overwrite creates zero new fields.
        assert_eq!(map.set("key", &replacement).await?, 0);
        assert!(!map.add("key", &original).await?);

        assert_eq!(map.keys().await?, ["key"]);
        assert_eq!(map.length().await?, 1);
        let expected: std::collections::HashMap<String, String> =
            [("key".to_owned(), replacement.clone())].into_iter().collect();
        assert_eq!(map.items().await?, expected);

        // `pop` hands back the current value and removes the field.
        assert_eq!(map.pop("key").await?.unwrap(), replacement);
        assert_eq!(map.length().await?, 0);
        assert!(map.add("key", &original).await?);

        // Counter updates accumulate in order, including a negative step.
        let steps: [(i64, i64); 4] = [(1, 1), (1, 2), (10, 12), (-22, -10)];
        for &(delta, total) in steps.iter() {
            assert_eq!(map.increment("a", delta).await?, total);
        }

        map.delete().await?;
        Ok(())
    }

    /// A hash created with a TTL vanishes once that TTL has passed.
    #[tokio::test]
    async fn expiring_hash() -> Result<(), ErrorTypes> {
        let connection = redis_connection().await;
        let ttl: Option<Duration> = Some(Duration::from_secs(1));
        let map = connection.hashmap("test-expiring-hashmap".to_string(), ttl);
        map.delete().await?;

        assert!(map.add("key", &"value".to_owned()).await?);
        assert_eq!(map.length().await?, 1);

        // Wait a little longer than the TTL so redis has expired the key.
        tokio::time::sleep(Duration::from_secs_f32(1.1)).await;
        assert_eq!(map.length().await?, 0);
        Ok(())
    }

}