use std::path::Path;
use std::sync::Arc;
use rocksdb::{Options, DB};
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use common::{DakeraError, Result, Vector, VectorId};
/// Configuration for the RocksDB-backed disk cache.
#[derive(Debug, Clone)]
pub struct DiskCacheConfig {
/// Filesystem path of the RocksDB directory (created if missing).
pub path: String,
/// Target maximum on-disk size in bytes.
/// NOTE(review): not enforced anywhere in the visible code — confirm whether
/// eviction based on this limit is implemented elsewhere.
pub max_size_bytes: u64,
/// When true, SST files are compressed with LZ4.
pub compression: bool,
/// RocksDB memtable (write buffer) size in bytes.
pub write_buffer_size: usize,
/// Maximum number of memtables kept in memory before flushing stalls writes.
pub max_write_buffer_number: i32,
}
impl Default for DiskCacheConfig {
fn default() -> Self {
Self {
path: "./cache".to_string(),
max_size_bytes: 10 * 1024 * 1024 * 1024, compression: true,
write_buffer_size: 64 * 1024 * 1024, max_write_buffer_number: 3,
}
}
}
/// On-disk value stored for each cached vector, serialized as JSON.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry {
/// The cached vector payload.
vector: Vector,
/// NOTE(review): always written as 1 on put and never incremented on reads
/// in the visible code — confirm whether hit counting is intended.
access_count: u64,
/// Unix timestamp (seconds) recorded when the entry was written.
created_at: u64,
}
/// Persistent vector cache backed by a single RocksDB instance.
///
/// Keys are namespaced as `"{namespace}:{id}"`, so distinct namespaces
/// share one database but never collide.
pub struct DiskCache {
/// Shared handle to the underlying RocksDB database.
db: Arc<DB>,
/// Retained for introspection; only `path` and the tuning fields are
/// consumed at open time.
#[allow(dead_code)]
config: DiskCacheConfig,
}
impl DiskCache {
    /// Opens (or creates) the RocksDB database described by `config`.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` if the database cannot be opened.
    pub fn new(config: DiskCacheConfig) -> Result<Self> {
        let mut opts = Options::default();
        opts.create_if_missing(true);
        opts.set_write_buffer_size(config.write_buffer_size);
        opts.set_max_write_buffer_number(config.max_write_buffer_number);
        if config.compression {
            opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
        }
        // Dynamic level sizing keeps space amplification bounded as data grows.
        opts.set_level_compaction_dynamic_level_bytes(true);
        opts.set_max_background_jobs(4);
        let db = DB::open(&opts, &config.path)
            .map_err(|e| DakeraError::Storage(format!("Failed to open RocksDB: {}", e)))?;
        debug!(path = %config.path, "Disk cache initialized");
        Ok(Self {
            db: Arc::new(db),
            config,
        })
    }

    /// Convenience constructor: opens a cache at `path` with default tuning.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let config = DiskCacheConfig {
            path: path.as_ref().to_string_lossy().to_string(),
            ..Default::default()
        };
        Self::new(config)
    }

    /// Builds the storage key `"{namespace}:{id}"`.
    fn make_key(namespace: &str, id: &VectorId) -> Vec<u8> {
        format!("{}:{}", namespace, id).into_bytes()
    }

    /// Prefix shared by every key in `namespace` (used for scans).
    fn namespace_prefix(namespace: &str) -> Vec<u8> {
        format!("{}:", namespace).into_bytes()
    }

    /// Seconds since the Unix epoch; clamps to 0 if the clock reads before 1970.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Wraps `vector` in a `CacheEntry` and serializes it to JSON bytes.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` if serialization fails.
    fn encode_entry(vector: &Vector, created_at: u64) -> Result<Vec<u8>> {
        let entry = CacheEntry {
            vector: vector.clone(),
            access_count: 1,
            created_at,
        };
        serde_json::to_vec(&entry)
            .map_err(|e| DakeraError::Storage(format!("Failed to serialize cache entry: {}", e)))
    }

    /// Stores a single vector under `namespace`, overwriting any existing entry.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` on serialization or write failure.
    pub fn put(&self, namespace: &str, vector: &Vector) -> Result<()> {
        let key = Self::make_key(namespace, &vector.id);
        let value = Self::encode_entry(vector, Self::now_secs())?;
        self.db
            .put(&key, &value)
            .map_err(|e| DakeraError::Storage(format!("Failed to write to disk cache: {}", e)))?;
        debug!(namespace = %namespace, id = %vector.id, "Cached vector to disk");
        Ok(())
    }

    /// Stores many vectors atomically via a RocksDB write batch.
    ///
    /// All entries share one `created_at` timestamp. Returns the number of
    /// vectors written.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` on serialization or batch-write failure.
    pub fn put_batch(&self, namespace: &str, vectors: &[Vector]) -> Result<usize> {
        let mut batch = rocksdb::WriteBatch::default();
        let now = Self::now_secs();
        for vector in vectors {
            let key = Self::make_key(namespace, &vector.id);
            let value = Self::encode_entry(vector, now)?;
            batch.put(&key, &value);
        }
        let count = vectors.len();
        self.db.write(batch).map_err(|e| {
            DakeraError::Storage(format!("Failed to write batch to disk cache: {}", e))
        })?;
        debug!(namespace = %namespace, count = count, "Batch cached vectors to disk");
        Ok(count)
    }

    /// Fetches one vector, or `None` if absent.
    ///
    /// Read errors are logged and treated as a cache miss (best-effort cache
    /// semantics), so this never fails on I/O; it only fails if a stored
    /// entry cannot be deserialized.
    pub fn get(&self, namespace: &str, id: &VectorId) -> Result<Option<Vector>> {
        let key = Self::make_key(namespace, id);
        match self.db.get(&key) {
            Ok(Some(value)) => {
                let entry: CacheEntry = serde_json::from_slice(&value).map_err(|e| {
                    DakeraError::Storage(format!("Failed to deserialize cache entry: {}", e))
                })?;
                Ok(Some(entry.vector))
            }
            Ok(None) => Ok(None),
            Err(e) => {
                warn!(error = %e, "Failed to read from disk cache");
                Ok(None)
            }
        }
    }

    /// Fetches the subset of `ids` that are present and decodable.
    ///
    /// Missing keys, read errors, and corrupt entries are silently skipped,
    /// so the result may be shorter than `ids`.
    pub fn get_batch(&self, namespace: &str, ids: &[VectorId]) -> Result<Vec<Vector>> {
        let keys: Vec<Vec<u8>> = ids.iter().map(|id| Self::make_key(namespace, id)).collect();
        let results = self.db.multi_get(&keys);
        let mut vectors = Vec::with_capacity(ids.len());
        for result in results {
            if let Ok(Some(value)) = result {
                if let Ok(entry) = serde_json::from_slice::<CacheEntry>(&value) {
                    vectors.push(entry.vector);
                }
            }
        }
        Ok(vectors)
    }

    /// Returns every decodable vector stored under `namespace`.
    pub fn get_all(&self, namespace: &str) -> Result<Vec<Vector>> {
        let prefix = Self::namespace_prefix(namespace);
        let mut vectors = Vec::new();
        let iter = self.db.prefix_iterator(&prefix);
        for item in iter {
            match item {
                Ok((key, value)) => {
                    // prefix_iterator may run past the prefix; stop at the first
                    // key outside the namespace.
                    if !key.starts_with(&prefix) {
                        break;
                    }
                    if let Ok(entry) = serde_json::from_slice::<CacheEntry>(&value) {
                        vectors.push(entry.vector);
                    }
                }
                Err(e) => {
                    warn!(error = %e, "Error iterating disk cache");
                    break;
                }
            }
        }
        Ok(vectors)
    }

    /// Deletes one entry; returns `true` if it existed.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` on read or delete failure.
    pub fn delete(&self, namespace: &str, id: &VectorId) -> Result<bool> {
        let key = Self::make_key(namespace, id);
        let existed = self
            .db
            .get(&key)
            .map_err(|e| DakeraError::Storage(format!("Failed to check disk cache: {}", e)))?
            .is_some();
        if existed {
            self.db.delete(&key).map_err(|e| {
                DakeraError::Storage(format!("Failed to delete from disk cache: {}", e))
            })?;
        }
        Ok(existed)
    }

    /// Deletes many entries atomically; returns how many existed beforehand.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` on existence-check or batch-write
    /// failure. (Existence-check errors now propagate, matching `delete`,
    /// instead of being silently treated as "absent".)
    pub fn delete_batch(&self, namespace: &str, ids: &[VectorId]) -> Result<usize> {
        let mut batch = rocksdb::WriteBatch::default();
        let mut count = 0;
        for id in ids {
            let key = Self::make_key(namespace, id);
            let exists = self
                .db
                .get(&key)
                .map_err(|e| DakeraError::Storage(format!("Failed to check disk cache: {}", e)))?
                .is_some();
            if exists {
                batch.delete(&key);
                count += 1;
            }
        }
        self.db.write(batch).map_err(|e| {
            DakeraError::Storage(format!("Failed to delete batch from disk cache: {}", e))
        })?;
        Ok(count)
    }

    /// Removes every entry under `namespace`; returns the number removed.
    ///
    /// Iteration errors abort the scan early, deleting only what was seen.
    pub fn clear_namespace(&self, namespace: &str) -> Result<usize> {
        let prefix = Self::namespace_prefix(namespace);
        let mut batch = rocksdb::WriteBatch::default();
        let mut count = 0;
        let iter = self.db.prefix_iterator(&prefix);
        for item in iter {
            match item {
                Ok((key, _)) => {
                    if !key.starts_with(&prefix) {
                        break;
                    }
                    batch.delete(&key);
                    count += 1;
                }
                Err(_) => break,
            }
        }
        if count > 0 {
            self.db.write(batch).map_err(|e| {
                DakeraError::Storage(format!("Failed to clear namespace from disk cache: {}", e))
            })?;
        }
        debug!(namespace = %namespace, count = count, "Cleared namespace from disk cache");
        Ok(count)
    }

    /// Estimated live data size in bytes (0 if the property is unavailable).
    pub fn approximate_size(&self) -> u64 {
        self.db
            .property_int_value("rocksdb.estimate-live-data-size")
            .ok()
            .flatten()
            .unwrap_or(0)
    }

    /// Snapshot of RocksDB's size and key-count estimates.
    pub fn stats(&self) -> DiskCacheStats {
        DiskCacheStats {
            approximate_size_bytes: self.approximate_size(),
            approximate_num_keys: self
                .db
                .property_int_value("rocksdb.estimate-num-keys")
                .ok()
                .flatten()
                .unwrap_or(0),
        }
    }

    /// Flushes memtables to disk.
    ///
    /// # Errors
    /// Returns `DakeraError::Storage` if the flush fails.
    pub fn flush(&self) -> Result<()> {
        self.db
            .flush()
            .map_err(|e| DakeraError::Storage(format!("Failed to flush disk cache: {}", e)))
    }
}
/// Point-in-time statistics reported by the disk cache.
///
/// Both fields come from RocksDB's `estimate-*` properties and are
/// approximations, not exact counts.
#[derive(Debug, Clone)]
pub struct DiskCacheStats {
/// Estimated live data size in bytes.
pub approximate_size_bytes: u64,
/// Estimated number of keys across all namespaces.
pub approximate_num_keys: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a cache in a fresh temp directory. The `TempDir` guard is
    /// returned so the directory outlives the cache under test.
    fn setup() -> (DiskCache, TempDir) {
        let dir = TempDir::new().unwrap();
        let cfg = DiskCacheConfig {
            path: dir.path().to_string_lossy().to_string(),
            ..Default::default()
        };
        (DiskCache::new(cfg).unwrap(), dir)
    }

    /// A minimal three-component vector with the given id and no metadata.
    fn sample(id: &str) -> Vector {
        Vector {
            id: id.to_string(),
            values: vec![1.0, 2.0, 3.0],
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Builds sample vectors for each id in `ids`.
    fn samples(ids: &[&str]) -> Vec<Vector> {
        ids.iter().map(|id| sample(id)).collect()
    }

    #[test]
    fn test_put_and_get() {
        let (cache, _guard) = setup();
        cache.put("test", &sample("v1")).unwrap();
        let fetched = cache
            .get("test", &"v1".to_string())
            .unwrap()
            .expect("vector should be present");
        assert_eq!(fetched.id, "v1");
        assert_eq!(fetched.values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_get_nonexistent() {
        let (cache, _guard) = setup();
        assert!(cache.get("test", &"nonexistent".to_string()).unwrap().is_none());
    }

    #[test]
    fn test_batch_operations() {
        let (cache, _guard) = setup();
        let stored = cache.put_batch("test", &samples(&["v1", "v2", "v3"])).unwrap();
        assert_eq!(stored, 3);
        let ids: Vec<String> = ["v1", "v2", "v3"].iter().map(|s| s.to_string()).collect();
        assert_eq!(cache.get_batch("test", &ids).unwrap().len(), 3);
    }

    #[test]
    fn test_get_all() {
        let (cache, _guard) = setup();
        cache.put_batch("test", &samples(&["v1", "v2"])).unwrap();
        assert_eq!(cache.get_all("test").unwrap().len(), 2);
    }

    #[test]
    fn test_delete() {
        let (cache, _guard) = setup();
        let id = "v1".to_string();
        cache.put("test", &sample("v1")).unwrap();
        assert!(cache.get("test", &id).unwrap().is_some());
        assert!(cache.delete("test", &id).unwrap());
        assert!(cache.get("test", &id).unwrap().is_none());
    }

    #[test]
    fn test_delete_batch() {
        let (cache, _guard) = setup();
        cache.put_batch("test", &samples(&["v1", "v2", "v3"])).unwrap();
        let victims = vec!["v1".to_string(), "v2".to_string()];
        assert_eq!(cache.delete_batch("test", &victims).unwrap(), 2);
        assert!(cache.get("test", &"v1".to_string()).unwrap().is_none());
        assert!(cache.get("test", &"v2".to_string()).unwrap().is_none());
        assert!(cache.get("test", &"v3".to_string()).unwrap().is_some());
    }

    #[test]
    fn test_clear_namespace() {
        let (cache, _guard) = setup();
        let batch = samples(&["v1", "v2"]);
        cache.put_batch("ns1", &batch).unwrap();
        cache.put_batch("ns2", &batch).unwrap();
        assert_eq!(cache.clear_namespace("ns1").unwrap(), 2);
        assert!(cache.get_all("ns1").unwrap().is_empty());
        assert_eq!(cache.get_all("ns2").unwrap().len(), 2);
    }

    #[test]
    fn test_namespace_isolation() {
        let (cache, _guard) = setup();
        let id = "v1".to_string();
        let v = sample("v1");
        cache.put("ns1", &v).unwrap();
        cache.put("ns2", &v).unwrap();
        assert!(cache.get("ns1", &id).unwrap().is_some());
        assert!(cache.get("ns2", &id).unwrap().is_some());
        cache.delete("ns1", &id).unwrap();
        assert!(cache.get("ns1", &id).unwrap().is_none());
        assert!(cache.get("ns2", &id).unwrap().is_some());
    }

    #[test]
    fn test_stats() {
        let (cache, _guard) = setup();
        cache.put_batch("test", &samples(&["v1", "v2"])).unwrap();
        cache.flush().unwrap();
        // Smoke test only: key-count property is an estimate, so just read it.
        let _ = cache.stats().approximate_num_keys;
    }
}