use core::{fmt::Display, time::Duration};
use std::{hash::Hash, time::Instant};
use moka::future::Cache as LocalCache;
use redis::{AsyncCommands, RedisError, ToRedisArgs};
use serde::{de::DeserializeOwned, Serialize};
use siphasher::sip128::{Hasher128, SipHasher24};
use snafu::Snafu;
/// Connection handle used for every Redis command issued by the cache.
type RedisConnection = redis::aio::ConnectionManager;
/// Two-tier cache: an in-process store that can optionally be combined with Redis.
///
/// Without a Redis configuration only the in-process store is used.
pub struct Cache<K, V> {
// In-process store; every value is wrapped in a `LocalEntry` so it can carry its own deadline.
local: LocalCache<K, LocalEntry<V>>,
// Redis settings; `None` means the cache works purely in-process.
redis: Option<RedisConfig>,
}
/// Settings for the Redis tier of a `Cache`.
struct RedisConfig {
// Connection that is cloned for each Redis command.
redis: RedisConnection,
// Namespace segment placed between the fixed `opentalk-cache:` prefix and the key.
prefix: String,
// Expiry applied to entries written by `Cache::insert`.
ttl: Duration,
// When `true`, the Redis key contains a hash of the cache key instead of its `Display` output.
hash_key: bool,
}
/// Errors returned by the fallible `Cache` operations.
///
/// The cache methods rely on `?` to turn the underlying Redis and bincode errors
/// into these variants.
#[derive(Debug, Snafu)]
pub enum CacheError {
/// A Redis command failed.
#[snafu(display("Redis error: {}", source), context(false))]
Redis { source: RedisError },
/// A value could not be encoded to, or decoded from, its bincode representation.
#[snafu(display("Serde error: {}", source), context(false))]
Serde { source: bincode::Error },
}
impl<K, V> Cache<K, V>
where
K: Display + Hash + Eq + Send + Sync + 'static,
V: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
pub fn new(ttl: Duration) -> Self {
Self {
local: LocalCache::builder().time_to_live(ttl).build(),
redis: None,
}
}
pub fn with_redis(
self,
redis: RedisConnection,
prefix: impl Into<String>,
ttl: Duration,
hash_key: bool,
) -> Self {
Self {
redis: Some(RedisConfig {
redis,
prefix: prefix.into(),
ttl,
hash_key,
}),
..self
}
}
pub fn longest_ttl(&self) -> Duration {
let local_ttl = self
.local
.policy()
.time_to_live()
.expect("local always has a ttl");
if let Some(redis) = &self.redis {
redis.ttl.max(local_ttl)
} else {
local_ttl
}
}
pub async fn get(&self, key: &K) -> Result<Option<V>, CacheError> {
if let Some(entry) = self
.local
.get(key)
.await
.filter(|entry| entry.still_valid())
{
Ok(Some(entry.value))
} else if let Some(RedisConfig {
redis,
prefix,
hash_key,
..
}) = &self.redis
{
let v: Option<Vec<u8>> = redis
.clone()
.get(RedisCacheKey {
prefix,
key,
hash_key: *hash_key,
})
.await?;
if let Some(v) = v {
let v = bincode::deserialize(&v)?;
Ok(Some(v))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
pub async fn insert(&self, key: K, value: V) -> Result<(), CacheError> {
if let Some(RedisConfig {
redis,
prefix,
ttl,
hash_key,
}) = &self.redis
{
redis
.clone()
.set_ex(
RedisCacheKey {
prefix,
key: &key,
hash_key: *hash_key,
},
bincode::serialize(&value)?,
ttl.as_secs(),
)
.await?;
}
self.local
.insert(
key,
LocalEntry {
value,
expires_at: None,
},
)
.await;
Ok(())
}
pub async fn insert_with_ttl(&self, key: K, value: V, ttl: Duration) -> Result<(), CacheError> {
if ttl >= self.longest_ttl() {
return self.insert(key, value).await;
}
if let Some(RedisConfig {
redis,
prefix,
hash_key,
..
}) = &self.redis
{
redis
.clone()
.set_ex(
RedisCacheKey {
prefix,
key: &key,
hash_key: *hash_key,
},
bincode::serialize(&value)?,
ttl.as_secs(),
)
.await?;
}
self.local
.insert(
key,
LocalEntry {
value,
expires_at: Some(Instant::now() + ttl),
},
)
.await;
Ok(())
}
}
/// Borrowed view of a cache key that knows how to render itself as a Redis key.
struct RedisCacheKey<'a, K> {
// Whether the rendered key contains a hash of `key` instead of its `Display` output.
hash_key: bool,
// Namespace segment inserted after the fixed `opentalk-cache:` prefix.
prefix: &'a str,
// The cache key being rendered.
key: &'a K,
}
impl<K> Display for RedisCacheKey<'_, K>
where
    K: Display + Hash,
{
    /// Renders the key as `opentalk-cache:<prefix>:<key>`.
    ///
    /// With `hash_key` set, `<key>` is the lowercase hexadecimal 128-bit digest
    /// produced by feeding the key's `Hash` implementation into a `SipHasher24`
    /// with fixed keys; otherwise it is the key's `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if !self.hash_key {
            return write!(f, "opentalk-cache:{}:{}", self.prefix, self.key);
        }

        let mut hasher = SipHasher24::new_with_keys(!0x113, 0x311);
        self.key.hash(&mut hasher);
        let digest = hasher.finish128().as_u128();
        write!(f, "opentalk-cache:{}:{:x}", self.prefix, digest)
    }
}
impl<K> ToRedisArgs for RedisCacheKey<'_, K>
where
    K: Display + Hash,
{
    /// Writes the key's `Display` rendering as a single Redis argument.
    fn write_redis_args<W>(&self, out: &mut W)
    where
        W: ?Sized + redis::RedisWrite,
    {
        out.write_arg_fmt(self);
    }
}
/// A value held in the in-process store together with an optional deadline of its own.
#[derive(Debug, Clone, Copy)]
struct LocalEntry<V> {
    /// The cached value.
    value: V,
    /// Instant from which the entry must no longer be served; `None` means the
    /// entry has no deadline of its own.
    expires_at: Option<Instant>,
}

impl<V> LocalEntry<V> {
    /// Returns `true` if the entry has no deadline or its deadline lies strictly
    /// in the future.
    fn still_valid(&self) -> bool {
        match self.expires_at {
            None => true,
            Some(deadline) => deadline > Instant::now(),
        }
    }
}