use crate::{backends::CacheBackend, error::Error, metrics::Metrics, Result};
use async_trait::async_trait;
use rocksdb::{Options, DB};
use serde::{Deserialize, Serialize};
use std::{
path::Path,
sync::Arc,
time::{Duration, SystemTime},
};
/// On-disk representation of a single cached value.
///
/// Instances are serialized with `bincode` before being written to RocksDB
/// and deserialized on every read.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry {
    // Raw payload exactly as supplied by the caller.
    value: Vec<u8>,
    // Absolute expiry instant; `None` means the entry never expires.
    expires_at: Option<SystemTime>,
}
/// Persistent cache backend that stores entries in a RocksDB database.
#[derive(Debug)]
pub struct RocksDBBackend {
    // Shared handle to the underlying RocksDB instance.
    db: Arc<DB>,
    // Hit/miss/insertion counters recorded by the `CacheBackend` methods.
    metrics: Arc<Metrics>,
}
impl RocksDBBackend {
pub fn new<P: AsRef<Path>>(db_path: P) -> Result<Self> {
let mut options = Options::default();
options.create_if_missing(true);
let db = DB::open(&options, db_path)
.map_err(|e| Error::Backend(format!("Failed to open RocksDB: {}", e)))?;
Ok(Self {
db: Arc::new(db),
metrics: Arc::new(Metrics::new()),
})
}
fn is_expired(entry: &CacheEntry) -> bool {
if let Some(expires_at) = entry.expires_at {
SystemTime::now() > expires_at
} else {
false
}
}
}
#[async_trait]
impl CacheBackend for RocksDBBackend {
async fn get(&self, key: &String) -> Result<Option<Vec<u8>>> {
match self.db.get(key.as_bytes()) {
Ok(Some(bytes)) => match bincode::deserialize::<CacheEntry>(&bytes) {
Ok(entry) => {
if Self::is_expired(&entry) {
if let Err(e) = self.db.delete(key.as_bytes()) {
return Err(Error::Backend(format!(
"Failed to delete expired key: {}",
e
)));
}
self.metrics.record_miss();
Ok(None)
} else {
self.metrics.record_hit();
Ok(Some(entry.value))
}
}
Err(e) => {
self.metrics.record_miss();
Err(Error::Codec(format!(
"Failed to deserialize cache entry: {}",
e
)))
}
},
Ok(None) => {
self.metrics.record_miss();
Ok(None)
}
Err(e) => Err(Error::Backend(format!("RocksDB error: {}", e))),
}
}
async fn set(&self, key: String, value: Vec<u8>, ttl: Option<Duration>) -> Result<()> {
let expires_at = ttl.map(|duration| {
SystemTime::now()
.checked_add(duration)
.unwrap_or_else(|| SystemTime::now() + duration)
});
let entry = CacheEntry { value, expires_at };
let bytes = bincode::serialize(&entry)
.map_err(|e| Error::Codec(format!("Failed to serialize cache entry: {}", e)))?;
self.db
.put(key.as_bytes(), bytes)
.map_err(|e| Error::Backend(format!("Failed to store in RocksDB: {}", e)))?;
self.metrics.record_insertion();
Ok(())
}
async fn remove(&self, key: &String) -> Result<()> {
self.db
.delete(key.as_bytes())
.map_err(|e| Error::Backend(format!("Failed to remove from RocksDB: {}", e)))?;
Ok(())
}
async fn contains_key(&self, key: &String) -> Result<bool> {
match self.db.get(key.as_bytes()) {
Ok(Some(bytes)) => match bincode::deserialize::<CacheEntry>(&bytes) {
Ok(entry) => {
if Self::is_expired(&entry) {
Ok(false)
} else {
Ok(true)
}
}
Err(_) => Ok(false),
},
Ok(None) => Ok(false),
Err(e) => Err(Error::Backend(format!("RocksDB error: {}", e))),
}
}
async fn clear(&self) -> Result<()> {
let iter = self.db.iterator(rocksdb::IteratorMode::Start);
let keys: Vec<Vec<u8>> = iter.map(|item| item.unwrap().0.to_vec()).collect();
for key in keys {
if let Err(e) = self.db.delete(&key) {
return Err(Error::Backend(format!(
"Failed to delete key during clear: {}",
e
)));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use tempfile::tempdir;
    use tokio::time::sleep;

    /// Round-trips a value through set / get / contains_key / remove.
    #[tokio::test]
    #[serial]
    async fn test_get_set() {
        let dir = tempdir().unwrap();
        let backend = RocksDBBackend::new(dir.path()).unwrap();
        let key = "test_key".to_string();
        let value = b"test_value".to_vec();

        backend.set(key.clone(), value.clone(), None).await.unwrap();
        assert_eq!(backend.get(&key).await.unwrap(), Some(value));
        assert!(backend.contains_key(&key).await.unwrap());

        backend.remove(&key).await.unwrap();
        assert_eq!(backend.get(&key).await.unwrap(), None);
        assert!(!backend.contains_key(&key).await.unwrap());
    }

    /// An entry with a 100ms TTL must be gone after 150ms.
    #[tokio::test]
    #[serial]
    async fn test_ttl() {
        let dir = tempdir().unwrap();
        let backend = RocksDBBackend::new(dir.path()).unwrap();
        let key = "test_ttl".to_string();

        backend
            .set(
                key.clone(),
                b"test_value".to_vec(),
                Some(Duration::from_millis(100)),
            )
            .await
            .unwrap();
        assert!(backend.get(&key).await.unwrap().is_some());

        // Wait past the TTL, then confirm the entry has lapsed.
        sleep(Duration::from_millis(150)).await;
        assert!(backend.get(&key).await.unwrap().is_none());
    }

    /// `clear` removes every stored entry.
    #[tokio::test]
    #[serial]
    async fn test_clear() {
        let dir = tempdir().unwrap();
        let backend = RocksDBBackend::new(dir.path()).unwrap();
        let key1 = "test_key1".to_string();
        let key2 = "test_key2".to_string();
        let value = b"test_value".to_vec();

        for key in [&key1, &key2] {
            backend.set(key.clone(), value.clone(), None).await.unwrap();
            assert!(backend.contains_key(key).await.unwrap());
        }

        backend.clear().await.unwrap();
        assert!(!backend.contains_key(&key1).await.unwrap());
        assert!(!backend.contains_key(&key2).await.unwrap());
    }

    /// Hit/miss counters advance exactly once per corresponding get.
    #[tokio::test]
    #[serial]
    async fn test_metrics() {
        let dir = tempdir().unwrap();
        let backend = RocksDBBackend::new(dir.path()).unwrap();
        let key = "test_metrics".to_string();

        assert_eq!(backend.metrics.hits(), 0);
        assert_eq!(backend.metrics.misses(), 0);

        // Missing key -> one miss.
        assert!(backend.get(&key).await.unwrap().is_none());
        assert_eq!(backend.metrics.misses(), 1);

        // Present key -> one hit.
        backend
            .set(key.clone(), b"test_value".to_vec(), None)
            .await
            .unwrap();
        assert!(backend.get(&key).await.unwrap().is_some());
        assert_eq!(backend.metrics.hits(), 1);
    }
}