use std::{marker::PhantomData, sync::Arc, collections::HashMap, time::Duration};
use parking_lot::Mutex;
use redis::AsyncCommands;
use serde::{de::DeserializeOwned, Serialize};
use crate::{RedisObjects, ErrorTypes, retry_call};
/// Lua script that atomically reads one field of a hash and deletes it if present.
///
/// Arguments: `ARGV[1]` is the hash name and `ARGV[2]` is the field.
/// Returns the field's raw value, or nil when the field does not exist.
///
/// NOTE(review): the hash name is passed through ARGV rather than KEYS. Redis
/// expects scripts to declare the keys they touch via KEYS (required for
/// cluster routing). Consider `KEYS[1]`/`ARGV[1]` together with a matching
/// `.key(..)` change at the call site.
const POP_SCRIPT: &str = r#"
local result = redis.call('hget', ARGV[1], ARGV[2])
if result then redis.call('hdel', ARGV[1], ARGV[2]) end
return result
"#;
/// Lua script that atomically deletes a hash field only if its current value
/// equals an expected value.
///
/// Arguments: `KEYS[1]` is the hash name, `ARGV[1]` the field and `ARGV[2]`
/// the expected raw value. The comparison is an exact string/byte match.
/// A missing field never matches.
/// Returns 1 if the field was deleted, 0 otherwise.
const CONDITIONAL_REMOVE_SCRIPT: &str = r#"
local hash_name = KEYS[1]
local key_in_hash = ARGV[1]
local expected_value = ARGV[2]
local result = redis.call('hget', hash_name, key_in_hash)
if result == expected_value then
redis.call('hdel', hash_name, key_in_hash)
return 1
end
return 0
"#;
/// Typed handle to a single Redis hash whose field values are stored as
/// JSON-encoded `T`.
///
/// Clones share the store and the TTL-refresh bookkeeping through `Arc`.
/// All clones therefore refresh the expiry on one common schedule.
pub struct Hashmap<T> {
    /// Redis key under which the hash lives.
    name: String,
    /// Shared Redis client state; commands are issued through its `pool`.
    store: Arc<RedisObjects>,
    /// Script backing `pop` (atomic read-and-delete of one field).
    pop_script: redis::Script,
    /// Script backing `conditional_remove` (delete only if the value matches).
    conditional_remove_script: redis::Script,
    /// Expiry applied to the whole hash after writes; `None` means this
    /// handle never sets one.
    ttl: Option<Duration>,
    /// When this handle (or a clone) last refreshed the expiry.
    last_expire_time: Arc<Mutex<Option<std::time::Instant>>>,
    /// Ties the handle to its value type without storing any `T`.
    _data: PhantomData<T>,
}

// Implemented by hand: `#[derive(Clone)]` would require `T: Clone`, even though
// no `T` is ever stored (only `PhantomData<T>`).
impl<T> Clone for Hashmap<T> {
    fn clone(&self) -> Self {
        Self {
            name: self.name.clone(),
            store: Arc::clone(&self.store),
            pop_script: self.pop_script.clone(),
            conditional_remove_script: self.conditional_remove_script.clone(),
            ttl: self.ttl,
            last_expire_time: Arc::clone(&self.last_expire_time),
            _data: PhantomData,
        }
    }
}
impl<T: Serialize + DeserializeOwned> Hashmap<T> {
    /// Create a handle for the Redis hash stored under `name`.
    ///
    /// When `ttl` is `Some`, writes refresh the hash's expiry at most once per
    /// quarter of the TTL (see `conditional_expire`). When it is `None`, this
    /// handle never sets an expiry.
    pub(crate) fn new(name: String, store: Arc<RedisObjects>, ttl: Option<Duration>) -> Self {
        Self {
            name,
            store,
            pop_script: redis::Script::new(POP_SCRIPT),
            conditional_remove_script: redis::Script::new(CONDITIONAL_REMOVE_SCRIPT),
            ttl,
            last_expire_time: Arc::new(Mutex::new(None)),
            _data: PhantomData,
        }
    }

    /// Refresh the hash's expiry if a TTL is configured and the last refresh
    /// done by this handle (or a clone) is older than a quarter of the TTL.
    ///
    /// The refresh uses millisecond precision (`PEXPIRE`). With whole seconds a
    /// sub-second TTL would truncate to `EXPIRE key 0`, which Redis treats as
    /// an immediate delete rather than an expiry.
    ///
    /// # Errors
    /// Propagates the Redis error. In that case the refresh claim is cleared
    /// so the next write retries instead of waiting out the window.
    async fn conditional_expire(&self) -> Result<(), ErrorTypes> {
        let ttl = match self.ttl {
            Some(ttl) => ttl,
            None => return Ok(()),
        };

        // Decide and claim the refresh under the lock, so concurrent writers
        // sharing this handle normally issue a single PEXPIRE per window.
        // The guard is dropped at the end of this block, before any `.await`.
        let due = {
            let mut last = self.last_expire_time.lock();
            let due = match *last {
                Some(when) => when.elapsed() > ttl / 4,
                None => true,
            };
            if due {
                *last = Some(std::time::Instant::now());
            }
            due
        };
        if !due {
            return Ok(());
        }

        // Saturate instead of wrapping for absurdly large TTLs.
        let millis = ttl.as_millis().min(i64::MAX as u128) as i64;
        let outcome: Result<(), ErrorTypes> =
            retry_call!(self.store.pool, pexpire, &self.name, millis);
        if outcome.is_err() {
            // The expiry was not applied; drop the claim so the next write retries.
            *self.last_expire_time.lock() = None;
        }
        outcome
    }

    /// Decode an optional raw JSON payload fetched from Redis into `T`.
    fn decode(item: Option<Vec<u8>>) -> Result<Option<T>, ErrorTypes> {
        Ok(item.map(|data| serde_json::from_slice(&data)).transpose()?)
    }

    /// Store `value` under `key` only if `key` is not already present (`HSETNX`).
    ///
    /// Returns `true` if the field was written and `false` if it already existed.
    pub async fn add(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {
        let data = serde_json::to_vec(value)?;
        let result = retry_call!(self.store.pool, hset_nx, &self.name, key, &data)?;
        self.conditional_expire().await?;
        Ok(result)
    }

    /// Add `increment` (which may be negative) to the integer counter at `key`
    /// and return the new total.
    ///
    /// The counter is stored as a plain Redis integer. When `T` is not numeric,
    /// `get`/`items` may fail to decode this field.
    pub async fn increment(&self, key: &str, increment: i64) -> Result<i64, ErrorTypes> {
        let result = retry_call!(self.store.pool, hincr, &self.name, key, increment)?;
        self.conditional_expire().await?;
        Ok(result)
    }

    /// Whether `key` is present in the hash.
    pub async fn exists(&self, key: &str) -> Result<bool, ErrorTypes> {
        retry_call!(self.store.pool, hexists, &self.name, key)
    }

    /// Fetch and decode the value at `key`; `Ok(None)` when it is absent.
    ///
    /// # Errors
    /// Redis failures, or a stored value that is not valid JSON for `T`.
    pub async fn get(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
        let item: Option<Vec<u8>> = retry_call!(self.store.pool, hget, &self.name, key)?;
        Self::decode(item)
    }

    /// Fetch the stored bytes at `key` without decoding them.
    pub async fn get_raw(&self, key: &str) -> Result<Option<Vec<u8>>, ErrorTypes> {
        retry_call!(self.store.pool, hget, &self.name, key)
    }

    /// All field names currently in the hash.
    pub async fn keys(&self) -> Result<Vec<String>, ErrorTypes> {
        retry_call!(self.store.pool, hkeys, &self.name)
    }

    /// Number of fields currently in the hash.
    pub async fn length(&self) -> Result<u64, ErrorTypes> {
        retry_call!(self.store.pool, hlen, &self.name)
    }

    /// Fetch and decode every field of the hash.
    ///
    /// # Errors
    /// Fails as a whole if any stored value is not valid JSON for `T`.
    pub async fn items(&self) -> Result<HashMap<String, T>, ErrorTypes> {
        let items: Vec<(String, Vec<u8>)> = retry_call!(self.store.pool, hgetall, &self.name)?;
        let mut out = HashMap::with_capacity(items.len());
        for (key, data) in items {
            out.insert(key, serde_json::from_slice(&data)?);
        }
        Ok(out)
    }

    /// Atomically remove `key` only if its stored bytes equal the JSON
    /// serialization of `value`. Returns `true` if the field was removed.
    ///
    /// NOTE: the match is byte-for-byte on the serialized form. It is only
    /// reliable for types whose JSON encoding is deterministic. A type
    /// containing a `std::collections::HashMap`, for instance, may serialize
    /// its entries in a different order.
    pub async fn conditional_remove(&self, key: &str, value: &T) -> Result<bool, ErrorTypes> {
        let data = serde_json::to_vec(value)?;
        retry_call!(
            method,
            self.store.pool,
            self.conditional_remove_script.key(&self.name).arg(key).arg(&data),
            invoke_async
        )
    }

    /// Atomically read and remove `key`, returning its decoded value if present.
    ///
    /// NOTE(review): removing the last field makes Redis drop the key and its
    /// expiry. A write shortly afterwards may recreate it without a TTL until
    /// the current refresh window lapses.
    pub async fn pop(&self, key: &str) -> Result<Option<T>, ErrorTypes> {
        // The pop script expects the hash name as ARGV[1] and the field as ARGV[2].
        let item: Option<Vec<u8>> = retry_call!(
            method,
            self.store.pool,
            self.pop_script.arg(&self.name).arg(key),
            invoke_async
        )?;
        Self::decode(item)
    }

    /// Store `value` under `key`, overwriting any existing value (`HSET`).
    ///
    /// Returns the number of newly created fields (0 when overwriting).
    pub async fn set(&self, key: &str, value: &T) -> Result<i64, ErrorTypes> {
        let data = serde_json::to_vec(value)?;
        let result = retry_call!(self.store.pool, hset, &self.name, key, &data)?;
        self.conditional_expire().await?;
        Ok(result)
    }

    /// Delete the whole hash.
    ///
    /// Also forgets when the expiry was last refreshed. The next write
    /// recreates the key with no TTL, so it must re-apply the TTL immediately
    /// instead of waiting for the current refresh window to elapse.
    pub async fn delete(&self) -> Result<(), ErrorTypes> {
        let outcome: Result<(), ErrorTypes> = retry_call!(self.store.pool, del, &self.name);
        *self.last_expire_time.lock() = None;
        outcome
    }
}
#[cfg(test)]
mod test {
    use crate::test::redis_connection;
    use crate::ErrorTypes;
    use std::time::Duration;

    /// Walks every basic operation of a hash that has no TTL.
    #[tokio::test]
    async fn hash() -> Result<(), ErrorTypes> {
        let store = redis_connection().await;
        let map = store.hashmap("test-hashmap".to_string(), None);
        // Clear leftovers from any earlier, aborted run.
        map.delete().await?;

        let original = "value".to_owned();
        let replacement = "new-value".to_owned();

        // `add` only writes when the field is absent.
        assert!(map.add("key", &original).await?);
        assert!(!map.add("key", &original).await?);
        assert!(map.exists("key").await?);
        assert_eq!(map.get("key").await?.unwrap(), original);

        // `set` overwrites unconditionally and reports no newly created field.
        assert_eq!(map.set("key", &replacement).await?, 0);
        assert!(!map.add("key", &original).await?);

        // The inspection helpers all see exactly one field holding the replacement.
        assert_eq!(map.keys().await?, ["key"]);
        assert_eq!(map.length().await?, 1);
        let expected: std::collections::HashMap<String, String> =
            std::iter::once(("key".to_owned(), replacement.clone())).collect();
        assert_eq!(map.items().await?, expected);

        // `pop` hands back the value and leaves the hash empty.
        assert_eq!(map.pop("key").await?.unwrap(), replacement);
        assert_eq!(map.length().await?, 0);
        assert!(map.add("key", &original).await?);

        // Counters accumulate, including negative deltas.
        for (delta, expected_total) in [(1, 1), (1, 2), (10, 12), (-22, -10)] {
            assert_eq!(map.increment("a", delta).await?, expected_total);
        }

        map.delete().await?;
        Ok(())
    }

    /// A hash with a one-second TTL is gone once that time has passed.
    #[tokio::test]
    async fn expiring_hash() -> Result<(), ErrorTypes> {
        let store = redis_connection().await;
        let ttl = Some(Duration::from_secs(1));
        let map = store.hashmap("test-expiring-hashmap".to_string(), ttl);
        map.delete().await?;

        assert!(map.add("key", &"value".to_owned()).await?);
        assert_eq!(map.length().await?, 1);

        // Wait a little longer than the TTL so Redis has expired the key.
        tokio::time::sleep(Duration::from_secs_f32(1.1)).await;
        assert_eq!(map.length().await?, 0);
        Ok(())
    }
}