use super::invalidation::EntityTag;
use super::key::{CacheKey, KeyPattern};
use std::future::Future;
use std::time::Duration;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum CacheError {
#[error("serialization error: {0}")]
Serialization(String),
#[error("deserialization error: {0}")]
Deserialization(String),
#[error("connection error: {0}")]
Connection(String),
#[error("operation timed out")]
Timeout,
#[error("key not found: {0}")]
NotFound(String),
#[error("backend error: {0}")]
Backend(String),
#[error("configuration error: {0}")]
Config(String),
}
/// Shorthand for a `Result` whose error type is [`CacheError`].
pub type CacheResult<T> = std::result::Result<T, CacheError>;
/// A cached value together with its bookkeeping metadata.
///
/// Expiry is measured against a [`std::time::Instant`] captured when the entry is
/// created. `Instant` is monotonic, so wall-clock adjustments do not affect it.
#[derive(Debug, Clone)]
pub struct CacheEntry<T> {
    /// The cached payload.
    pub value: T,
    /// When this entry was constructed; the reference point for `ttl`.
    pub created_at: std::time::Instant,
    /// How long the entry stays live after `created_at`; `None` means it never expires.
    pub ttl: Option<Duration>,
    /// Entity tags associated with this entry.
    pub tags: Vec<EntityTag>,
    /// Size of the entry in bytes, if supplied via `with_size`; never computed automatically.
    pub size_bytes: Option<usize>,
}

impl<T> CacheEntry<T> {
    /// Wraps `value` in an entry stamped with the current instant.
    ///
    /// The entry starts with no TTL, no tags and no recorded size.
    #[must_use]
    pub fn new(value: T) -> Self {
        Self {
            value,
            created_at: std::time::Instant::now(),
            ttl: None,
            tags: Vec::new(),
            size_bytes: None,
        }
    }

    /// Sets the time-to-live, measured from `created_at` (not from this call).
    #[must_use = "builder methods return the updated entry; ignoring it discards the entry"]
    pub fn with_ttl(mut self, ttl: Duration) -> Self {
        self.ttl = Some(ttl);
        self
    }

    /// Replaces any previously attached tags with `tags`.
    #[must_use = "builder methods return the updated entry; ignoring it discards the entry"]
    pub fn with_tags(mut self, tags: Vec<EntityTag>) -> Self {
        self.tags = tags;
        self
    }

    /// Records the entry size in bytes, stored exactly as given.
    #[must_use = "builder methods return the updated entry; ignoring it discards the entry"]
    pub fn with_size(mut self, size: usize) -> Self {
        self.size_bytes = Some(size);
        self
    }

    /// Returns `true` once at least `ttl` has elapsed since `created_at`.
    ///
    /// Entries without a TTL never expire. Because the comparison is `>=`, a zero TTL is
    /// already expired at creation.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        self.ttl
            .is_some_and(|ttl| self.created_at.elapsed() >= ttl)
    }

    /// Returns the time left before expiry, or `None` if the entry has no TTL.
    ///
    /// The subtraction saturates, so an expired entry reports `Some(Duration::ZERO)`
    /// rather than `None`.
    #[must_use]
    pub fn remaining_ttl(&self) -> Option<Duration> {
        self.ttl
            .map(|ttl| ttl.saturating_sub(self.created_at.elapsed()))
    }
}
pub trait CacheBackend: Send + Sync + 'static {
fn get<T>(&self, key: &CacheKey) -> impl Future<Output = CacheResult<Option<T>>> + Send
where
T: serde::de::DeserializeOwned;
fn set<T>(
&self,
key: &CacheKey,
value: &T,
ttl: Option<Duration>,
) -> impl Future<Output = CacheResult<()>> + Send
where
T: serde::Serialize + Sync;
fn delete(&self, key: &CacheKey) -> impl Future<Output = CacheResult<bool>> + Send;
fn exists(&self, key: &CacheKey) -> impl Future<Output = CacheResult<bool>> + Send;
fn get_many<T>(
&self,
keys: &[CacheKey],
) -> impl Future<Output = CacheResult<Vec<Option<T>>>> + Send
where
T: serde::de::DeserializeOwned + Send,
{
async move {
let mut results = Vec::with_capacity(keys.len());
for key in keys {
results.push(self.get::<T>(key).await?);
}
Ok(results)
}
}
fn set_many<T>(
&self,
entries: &[(&CacheKey, &T)],
ttl: Option<Duration>,
) -> impl Future<Output = CacheResult<()>> + Send
where
T: serde::Serialize + Sync + Send,
{
async move {
for (key, value) in entries {
self.set(key, *value, ttl).await?;
}
Ok(())
}
}
fn delete_many(&self, keys: &[CacheKey]) -> impl Future<Output = CacheResult<u64>> + Send {
async move {
let mut count = 0u64;
for key in keys {
if self.delete(key).await? {
count += 1;
}
}
Ok(count)
}
}
fn invalidate_pattern(
&self,
pattern: &KeyPattern,
) -> impl Future<Output = CacheResult<u64>> + Send;
fn invalidate_tags(&self, tags: &[EntityTag]) -> impl Future<Output = CacheResult<u64>> + Send;
fn clear(&self) -> impl Future<Output = CacheResult<()>> + Send;
fn len(&self) -> impl Future<Output = CacheResult<usize>> + Send;
fn is_empty(&self) -> impl Future<Output = CacheResult<bool>> + Send {
async move { Ok(self.len().await? == 0) }
}
fn stats(&self) -> impl Future<Output = CacheResult<BackendStats>> + Send {
async move {
Ok(BackendStats {
entries: self.len().await?,
..Default::default()
})
}
}
}
/// Point-in-time statistics reported by a cache backend.
///
/// `entries` is always present. The optional fields stay `None` unless a backend
/// chooses to report them, which is also their `Default` value.
#[derive(Debug, Clone, Default)]
pub struct BackendStats {
    /// Number of entries currently held.
    pub entries: usize,
    /// Memory footprint in bytes, if the backend reports it.
    pub memory_bytes: Option<usize>,
    /// Connection count, if the backend reports one (meaning is backend-defined).
    pub connections: Option<usize>,
    /// Free-form, backend-specific details.
    pub info: Option<String>,
}
/// A cache backend that never stores anything.
///
/// Writes are accepted and discarded, reads always miss, and every count is zero.
/// This makes it a stand-in wherever a backend is required but caching is unwanted.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopCache;
impl CacheBackend for NoopCache {
    // --- reads: nothing is ever stored, so every lookup comes back empty ---

    /// Always a miss.
    async fn get<T: serde::de::DeserializeOwned>(&self, _: &CacheKey) -> CacheResult<Option<T>> {
        Ok(None)
    }

    /// No key is ever present.
    async fn exists(&self, _: &CacheKey) -> CacheResult<bool> {
        Ok(false)
    }

    /// The cache is permanently empty.
    async fn len(&self) -> CacheResult<usize> {
        Ok(0)
    }

    // --- writes: accepted and discarded ---

    /// Drops the value and reports success.
    async fn set<T: serde::Serialize + Sync>(
        &self,
        _: &CacheKey,
        _: &T,
        _: Option<Duration>,
    ) -> CacheResult<()> {
        Ok(())
    }

    /// Nothing was stored, so nothing is removed.
    async fn delete(&self, _: &CacheKey) -> CacheResult<bool> {
        Ok(false)
    }

    /// Succeeds trivially.
    async fn clear(&self) -> CacheResult<()> {
        Ok(())
    }

    // --- invalidation: there is never anything to invalidate ---

    /// Always invalidates zero entries.
    async fn invalidate_pattern(&self, _: &KeyPattern) -> CacheResult<u64> {
        Ok(0)
    }

    /// Always invalidates zero entries.
    async fn invalidate_tags(&self, _: &[EntityTag]) -> CacheResult<u64> {
        Ok(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_entry() {
        let entry = CacheEntry::new("test value")
            .with_ttl(Duration::from_secs(60))
            .with_tags(vec![EntityTag::new("User")]);
        assert!(!entry.is_expired());
        assert!(entry.remaining_ttl().unwrap() > Duration::from_secs(59));
    }

    #[test]
    fn test_cache_entry_without_ttl_never_expires() {
        let entry = CacheEntry::new(42u32).with_size(4);
        assert!(!entry.is_expired());
        assert_eq!(entry.remaining_ttl(), None);
        assert_eq!(entry.size_bytes, Some(4));
        assert!(entry.tags.is_empty());
    }

    #[test]
    fn test_cache_entry_zero_ttl_is_expired_immediately() {
        // `is_expired` compares with `>=`, so a zero TTL is already expired at creation,
        // and `remaining_ttl` saturates to zero instead of returning `None`.
        let entry = CacheEntry::new(()).with_ttl(Duration::ZERO);
        assert!(entry.is_expired());
        assert_eq!(entry.remaining_ttl(), Some(Duration::ZERO));
    }

    #[tokio::test]
    async fn test_noop_cache() {
        let cache = NoopCache;
        cache
            .set(&CacheKey::new("test", "key"), &"value", None)
            .await
            .unwrap();
        let result: Option<String> = cache.get(&CacheKey::new("test", "key")).await.unwrap();
        assert!(result.is_none());
    }

    /// Exercises the trait's default batch/aggregate methods through `NoopCache`.
    #[tokio::test]
    async fn test_noop_cache_default_methods() {
        let cache = NoopCache;
        let keys = [CacheKey::new("test", "a"), CacheKey::new("test", "b")];

        let values: Vec<Option<String>> = cache.get_many(&keys).await.unwrap();
        assert_eq!(values, vec![None, None]);

        cache.set_many(&[(&keys[0], &"value")], None).await.unwrap();
        assert_eq!(cache.delete_many(&keys).await.unwrap(), 0);
        assert!(!cache.exists(&keys[0]).await.unwrap());
        assert_eq!(
            cache
                .invalidate_tags(&[EntityTag::new("User")])
                .await
                .unwrap(),
            0
        );

        cache.clear().await.unwrap();
        assert!(cache.is_empty().await.unwrap());

        let stats = cache.stats().await.unwrap();
        assert_eq!(stats.entries, 0);
        assert!(stats.memory_bytes.is_none());
        assert!(stats.connections.is_none());
        assert!(stats.info.is_none());
    }
}