use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, Instant};
use bytes::Bytes;
use moka::Expiry;
use moka::sync::Cache;
use crate::Cache as CacheTrait;
/// Owned cache key: an immutable byte string held in a [`Bytes`] buffer.
///
/// `PartialEq` and `Hash` are written by hand over the raw `[u8]` slice so they agree
/// exactly with `BorrowedKey`, which lets lookups probe the cache with a borrowed
/// `&[u8]` instead of allocating a `CacheKey` first.
#[derive(Clone, Debug, Eq)]
struct CacheKey(Bytes);

impl CacheKey {
    /// Builds a key by copying `key` into a newly allocated buffer.
    #[inline]
    fn new(key: &[u8]) -> Self {
        let owned = Bytes::copy_from_slice(key);
        CacheKey(owned)
    }

    /// Borrows the key's raw bytes.
    #[inline]
    fn as_slice(&self) -> &[u8] {
        &self.0[..]
    }
}

impl PartialEq for CacheKey {
    /// Two keys are equal when their byte contents are equal.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.as_slice() == other.as_slice()
    }
}

impl Hash for CacheKey {
    /// Hashes the raw byte slice; must stay identical to `BorrowedKey`'s hash so that
    /// borrowed lookups land in the same bucket as the stored key.
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.as_slice().hash(state);
    }
}
/// Borrowed probe key used to look entries up without copying the caller's bytes.
///
/// It hashes the same `[u8]` slice that `CacheKey` hashes, which is what makes the
/// `Equivalent<CacheKey>` impl below sound: equal keys must produce equal hashes.
#[derive(PartialEq, Eq)]
struct BorrowedKey<'a>(&'a [u8]);

impl Hash for BorrowedKey<'_> {
    /// Hashes the borrowed slice exactly as `CacheKey` hashes its owned bytes.
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(self.0, state);
    }
}

impl equivalent::Equivalent<CacheKey> for BorrowedKey<'_> {
    /// A borrowed probe matches a stored key when their bytes are identical.
    #[inline]
    fn equivalent(&self, key: &CacheKey) -> bool {
        key.as_slice() == self.0
    }
}
/// A cached value paired with the absolute instant at which it should expire.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// The stored payload; `Bytes` is cheap to clone, so copying an entry out of the
    /// cache does not copy the payload.
    value: Bytes,
    /// Absolute deadline; `PerEntryExpiry` turns it into the remaining lifetime it
    /// reports to moka.
    expires_at: Instant,
}
/// moka [`Expiry`] policy that expires every entry at its own `CacheEntry::expires_at`.
///
/// Creating or overwriting an entry (re)computes its remaining lifetime from the
/// entry's absolute deadline; reading it leaves the current expiration untouched.
struct PerEntryExpiry;

impl PerEntryExpiry {
    /// Time left until `entry`'s deadline, measured from `Instant::now()`.
    ///
    /// A deadline that is now or already in the past yields `Duration::ZERO`, asking
    /// moka to expire the entry right away.
    #[inline]
    fn remaining(entry: &CacheEntry) -> Option<Duration> {
        Some(entry.expires_at.saturating_duration_since(Instant::now()))
    }
}

impl Expiry<CacheKey, CacheEntry> for PerEntryExpiry {
    fn expire_after_create(
        &self,
        _key: &CacheKey,
        value: &CacheEntry,
        _current_time: Instant,
    ) -> Option<Duration> {
        Self::remaining(value)
    }

    fn expire_after_read(
        &self,
        _key: &CacheKey,
        _value: &CacheEntry,
        _current_time: Instant,
        current_duration: Option<Duration>,
        _last_modified_at: Instant,
    ) -> Option<Duration> {
        // Reads neither extend nor shorten an entry's lifetime.
        current_duration
    }

    fn expire_after_update(
        &self,
        _key: &CacheKey,
        value: &CacheEntry,
        _current_time: Instant,
        _current_duration: Option<Duration>,
    ) -> Option<Duration> {
        // An overwrite carries its own deadline, which replaces the previous one.
        Self::remaining(value)
    }
}
/// In-process, size-bounded byte cache backed by a moka synchronous cache.
///
/// The underlying cache sits behind an `Arc`, so cloning a `LocalCache` yields another
/// handle to the same entries rather than a copy.
#[derive(Clone, Debug)]
pub struct LocalCache {
    inner: Arc<Cache<CacheKey, CacheEntry>>,
}

impl LocalCache {
    /// Creates a cache that holds at most `capacity` entries (clamped to at least 1).
    ///
    /// Each entry expires at the deadline supplied when it was written. The cache is
    /// also built with `time_to_live(default_ttl)`, and moka enforces that limit in
    /// addition to the per-entry expiry, so `default_ttl` effectively caps every
    /// entry's lifetime. NOTE(review): confirm a cap, rather than a fallback, is intended.
    pub fn new(capacity: u64, default_ttl: Duration) -> Self {
        // Never build a zero-capacity cache.
        let max_entries = capacity.max(1);
        let inner = Cache::builder()
            .max_capacity(max_entries)
            .time_to_live(default_ttl)
            .expire_after(PerEntryExpiry)
            .build();
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Returns whether the cache currently holds an entry for `key`.
    ///
    /// Probes with a borrowed key, so `key` is not copied.
    #[inline]
    pub fn contains_sync(&self, key: &[u8]) -> bool {
        let probe = BorrowedKey(key);
        self.inner.contains_key(&probe)
    }

    /// Returns the value stored under `key`, if any.
    ///
    /// The returned `Bytes` shares the cached buffer; `key` is not copied.
    #[inline]
    pub fn get_sync(&self, key: &[u8]) -> Option<Bytes> {
        let entry = self.inner.get(&BorrowedKey(key))?;
        Some(entry.value)
    }
}
/// Computes the absolute deadline of an entry written now with the given `ttl`.
///
/// `Instant + Duration` panics on overflow, so a huge `ttl` (such as `Duration::MAX`
/// used to mean "never expire") would crash the writer. The TTL is therefore clamped to
/// roughly 100 years, far beyond any practical cache lifetime. The cache-wide
/// `time_to_live` configured in `LocalCache::new` still bounds the entry independently.
fn entry_deadline(ttl: Duration) -> Instant {
    const MAX_ENTRY_TTL: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);
    let now = Instant::now();
    // Falls back to `now` (expire immediately) only if even the clamped deadline is
    // unrepresentable on this platform, rather than panicking.
    now.checked_add(ttl.min(MAX_ENTRY_TTL)).unwrap_or(now)
}

impl CacheTrait for LocalCache {
    /// Stores `value` under `key` only if the cache holds no entry for it.
    ///
    /// Resolves to `true` when this call inserted the value. It resolves to `false` when
    /// an entry already existed; that entry's value and deadline are left as they were.
    /// Nothing happens until the returned future is polled.
    fn set_nx_px(
        &self,
        key: &[u8],
        value: &[u8],
        ttl: Duration,
    ) -> impl Future<Output = anyhow::Result<bool>> + Send {
        // Copy the borrowed inputs so the future owns everything it touches.
        let cache_key = CacheKey::new(key);
        let value = Bytes::copy_from_slice(value);
        let inner = Arc::clone(&self.inner);
        async move {
            // Take the deadline when the insert actually runs, so a future awaited
            // later does not start with part of its TTL already consumed.
            let entry = CacheEntry {
                value,
                expires_at: entry_deadline(ttl),
            };
            // `is_fresh()` reports whether this call's entry was the one inserted.
            Ok(inner.entry(cache_key).or_insert(entry).is_fresh())
        }
    }

    /// Unconditionally stores `value` under `key`, replacing any existing entry.
    ///
    /// Nothing happens until the returned future is polled.
    fn set(
        &self,
        key: &[u8],
        value: &[u8],
        ttl: Duration,
    ) -> impl Future<Output = anyhow::Result<()>> + Send {
        let cache_key = CacheKey::new(key);
        let value = Bytes::copy_from_slice(value);
        let inner = Arc::clone(&self.inner);
        async move {
            // Deadline measured from the actual insertion; see `set_nx_px`.
            let entry = CacheEntry {
                value,
                expires_at: entry_deadline(ttl),
            };
            inner.insert(cache_key, entry);
            Ok(())
        }
    }

    /// Returns a copy of the value stored under `key`, if any.
    ///
    /// Unlike the writers, the lookup happens when `get` is called rather than when the
    /// future is polled. A read has no side effect to defer, and looking up eagerly lets
    /// the cache be probed with the borrowed key without copying it. The future only
    /// converts that snapshot into an owned `Vec`.
    fn get(&self, key: &[u8]) -> impl Future<Output = anyhow::Result<Option<Vec<u8>>>> + Send {
        let snapshot = self.inner.get(&BorrowedKey(key)).map(|entry| entry.value);
        async move { Ok(snapshot.map(|bytes| bytes.to_vec())) }
    }

    /// Removes the entry for `key`, if present.
    ///
    /// Like `set`, the removal is deferred until the returned future is polled. A
    /// dropped, never-awaited future therefore deletes nothing, and the removal cannot
    /// jump ahead of writes whose futures were created earlier but awaited later.
    fn del(&self, key: &[u8]) -> impl Future<Output = anyhow::Result<()>> + Send {
        // Own the key so the deferred invalidation does not borrow the caller's slice.
        let cache_key = CacheKey::new(key);
        let inner = Arc::clone(&self.inner);
        async move {
            inner.invalidate(&cache_key);
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: room for 100 entries with a 60 s cache-wide TTL.
    fn new_cache() -> LocalCache {
        LocalCache::new(100, Duration::from_secs(60))
    }

    /// Unconditional write through the async trait API; panics on error.
    async fn put(cache: &LocalCache, key: &[u8], value: &[u8], ttl: Duration) {
        cache.set(key, value, ttl).await.unwrap();
    }

    /// Conditional write through `set_nx_px`; panics on error.
    async fn put_nx(cache: &LocalCache, key: &[u8], value: &[u8], ttl: Duration) -> bool {
        cache.set_nx_px(key, value, ttl).await.unwrap()
    }

    /// Read through the async trait API; panics on error.
    async fn fetch(cache: &LocalCache, key: &[u8]) -> Option<Vec<u8>> {
        cache.get(key).await.unwrap()
    }

    #[tokio::test]
    async fn test_set_and_get() {
        let cache = new_cache();
        put(&cache, b"key1", b"value1", Duration::from_secs(60)).await;
        let found = fetch(&cache, b"key1").await;
        assert_eq!(found, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let cache = new_cache();
        let found = fetch(&cache, b"nonexistent").await;
        assert_eq!(found, None);
    }

    #[tokio::test]
    async fn test_set_nx_px_new_key() {
        let cache = new_cache();
        let inserted = put_nx(&cache, b"key1", b"value1", Duration::from_secs(60)).await;
        assert!(inserted, "Expected key to be set (new key)");
        let found = fetch(&cache, b"key1").await;
        assert_eq!(found, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_set_nx_px_existing_key() {
        let cache = new_cache();
        let ttl = Duration::from_secs(60);
        let first = put_nx(&cache, b"key1", b"value1", ttl).await;
        assert!(first);
        let second = put_nx(&cache, b"key1", b"value2", ttl).await;
        assert!(!second, "Expected key NOT to be set (key exists)");
        // The original value must survive the rejected second write.
        let found = fetch(&cache, b"key1").await;
        assert_eq!(found, Some(b"value1".to_vec()));
    }

    #[tokio::test]
    async fn test_del() {
        let cache = new_cache();
        put(&cache, b"key1", b"value1", Duration::from_secs(60)).await;
        cache.del(b"key1").await.unwrap();
        let found = fetch(&cache, b"key1").await;
        assert_eq!(found, None);
    }

    #[tokio::test]
    async fn test_contains_sync() {
        let cache = new_cache();
        assert!(!cache.contains_sync(b"key1"));
        put(&cache, b"key1", b"value1", Duration::from_secs(60)).await;
        assert!(cache.contains_sync(b"key1"));
    }

    #[tokio::test]
    async fn test_get_sync() {
        let cache = new_cache();
        assert!(cache.get_sync(b"key1").is_none());
        put(&cache, b"key1", b"value1", Duration::from_secs(60)).await;
        let found = cache.get_sync(b"key1");
        assert_eq!(found, Some(Bytes::from_static(b"value1")));
    }

    #[tokio::test]
    async fn test_per_entry_ttl_respected() {
        let cache = new_cache();
        put(&cache, b"short_ttl", b"value", Duration::from_millis(50)).await;
        let before = fetch(&cache, b"short_ttl").await;
        assert_eq!(before, Some(b"value".to_vec()));
        // Sleep for twice the entry's TTL.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let after = fetch(&cache, b"short_ttl").await;
        assert_eq!(after, None, "Entry should have expired after TTL");
    }

    #[tokio::test]
    async fn test_different_ttls_for_different_keys() {
        let cache = new_cache();
        put(&cache, b"short", b"value1", Duration::from_millis(50)).await;
        put(&cache, b"long", b"value2", Duration::from_secs(10)).await;
        assert!(fetch(&cache, b"short").await.is_some());
        assert!(fetch(&cache, b"long").await.is_some());
        tokio::time::sleep(Duration::from_millis(100)).await;
        let short_after = fetch(&cache, b"short").await;
        assert!(short_after.is_none(), "Short TTL entry should have expired");
        let long_after = fetch(&cache, b"long").await;
        assert!(long_after.is_some(), "Long TTL entry should still exist");
    }

    #[tokio::test]
    async fn test_set_nx_px_ttl_respected() {
        let cache = new_cache();
        let inserted = put_nx(&cache, b"key", b"value", Duration::from_millis(50)).await;
        assert!(inserted);
        tokio::time::sleep(Duration::from_millis(100)).await;
        let expired = fetch(&cache, b"key").await;
        assert_eq!(expired, None, "Entry should have expired after TTL");
        // An expired entry must not block a new conditional write.
        let reinserted = put_nx(&cache, b"key", b"new_value", Duration::from_secs(60)).await;
        assert!(
            reinserted,
            "Should be able to set after previous entry expired"
        );
        let found = fetch(&cache, b"key").await;
        assert_eq!(found, Some(b"new_value".to_vec()));
    }
}