1use std::path::Path;
7use std::sync::Arc;
8
9use rocksdb::{BlockBasedOptions, Cache, Options, DB};
10use serde::{Deserialize, Serialize};
11use tracing::{debug, warn};
12
13use common::{DakeraError, Result, Vector, VectorId};
14
/// Settings used when opening the RocksDB-backed disk cache.
#[derive(Debug, Clone)]
pub struct DiskCacheConfig {
    /// Filesystem location handed to RocksDB as the database path.
    pub path: String,
    /// Size budget in bytes.
    ///
    /// NOTE(review): nothing in this file reads this value; confirm whether
    /// enforcement lives elsewhere.
    pub max_size_bytes: u64,
    /// When `true`, `DiskCache::new` selects LZ4 compression.
    pub compression: bool,
    /// Forwarded to RocksDB's `set_write_buffer_size`.
    pub write_buffer_size: usize,
    /// Forwarded to RocksDB's `set_max_write_buffer_number`.
    pub max_write_buffer_number: i32,
}

impl Default for DiskCacheConfig {
    /// Defaults: `./cache`, a 10 GiB budget, compression enabled,
    /// 64 MiB write buffers and at most 3 of them.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        const GIB: u64 = 1024 * 1024 * 1024;

        Self {
            path: String::from("./cache"),
            max_size_bytes: 10 * GIB,
            compression: true,
            write_buffer_size: 64 * MIB,
            max_write_buffer_number: 3,
        }
    }
}
41
42#[derive(Debug, Serialize, Deserialize)]
44struct CacheEntry {
45 vector: Vector,
46 access_count: u64,
47 created_at: u64,
48}
49
50pub struct DiskCache {
52 db: Arc<DB>,
53 #[allow(dead_code)]
54 config: DiskCacheConfig,
55}
56
57impl DiskCache {
58 pub fn new(config: DiskCacheConfig) -> Result<Self> {
60 let mut opts = Options::default();
61 opts.create_if_missing(true);
62 opts.set_write_buffer_size(config.write_buffer_size);
63 opts.set_max_write_buffer_number(config.max_write_buffer_number);
64
65 if config.compression {
66 opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
67 }
68
69 let block_cache_mb: usize = std::env::var("DAKERA_ROCKSDB_BLOCK_CACHE_MB")
71 .ok()
72 .and_then(|v| v.parse().ok())
73 .unwrap_or(64);
74 let cache = Cache::new_lru_cache(block_cache_mb * 1024 * 1024);
75 let mut block_opts = BlockBasedOptions::default();
76 block_opts.set_block_cache(&cache);
77 block_opts.set_optimize_filters_for_memory(true);
78 opts.set_block_based_table_factory(&block_opts);
79
80 opts.set_level_compaction_dynamic_level_bytes(true);
82 opts.set_max_background_jobs(4);
83
84 let db = DB::open(&opts, &config.path)
85 .map_err(|e| DakeraError::Storage(format!("Failed to open RocksDB: {}", e)))?;
86
87 debug!(path = %config.path, "Disk cache initialized");
88
89 Ok(Self {
90 db: Arc::new(db),
91 config,
92 })
93 }
94
95 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
97 let config = DiskCacheConfig {
98 path: path.as_ref().to_string_lossy().to_string(),
99 ..Default::default()
100 };
101 Self::new(config)
102 }
103
104 fn make_key(namespace: &str, id: &VectorId) -> Vec<u8> {
106 format!("{}:{}", namespace, id).into_bytes()
107 }
108
/// Returns the byte prefix (`<namespace>:`) shared by every key stored under
/// `namespace`.
fn namespace_prefix(namespace: &str) -> Vec<u8> {
    let mut prefix = Vec::with_capacity(namespace.len() + 1);
    prefix.extend_from_slice(namespace.as_bytes());
    prefix.push(b':');
    prefix
}
113
114 pub fn put(&self, namespace: &str, vector: &Vector) -> Result<()> {
116 let key = Self::make_key(namespace, &vector.id);
117 let entry = CacheEntry {
118 vector: vector.clone(),
119 access_count: 1,
120 created_at: std::time::SystemTime::now()
121 .duration_since(std::time::UNIX_EPOCH)
122 .unwrap_or_default()
123 .as_secs(),
124 };
125
126 let value = serde_json::to_vec(&entry)
127 .map_err(|e| DakeraError::Storage(format!("Failed to serialize cache entry: {}", e)))?;
128
129 self.db
130 .put(&key, &value)
131 .map_err(|e| DakeraError::Storage(format!("Failed to write to disk cache: {}", e)))?;
132
133 debug!(namespace = %namespace, id = %vector.id, "Cached vector to disk");
134 Ok(())
135 }
136
137 pub fn put_batch(&self, namespace: &str, vectors: &[Vector]) -> Result<usize> {
139 let mut batch = rocksdb::WriteBatch::default();
140 let now = std::time::SystemTime::now()
141 .duration_since(std::time::UNIX_EPOCH)
142 .unwrap_or_default()
143 .as_secs();
144
145 for vector in vectors {
146 let key = Self::make_key(namespace, &vector.id);
147 let entry = CacheEntry {
148 vector: vector.clone(),
149 access_count: 1,
150 created_at: now,
151 };
152
153 let value = serde_json::to_vec(&entry).map_err(|e| {
154 DakeraError::Storage(format!("Failed to serialize cache entry: {}", e))
155 })?;
156
157 batch.put(&key, &value);
158 }
159
160 let count = vectors.len();
161 self.db.write(batch).map_err(|e| {
162 DakeraError::Storage(format!("Failed to write batch to disk cache: {}", e))
163 })?;
164
165 debug!(namespace = %namespace, count = count, "Batch cached vectors to disk");
166 Ok(count)
167 }
168
169 pub fn get(&self, namespace: &str, id: &VectorId) -> Result<Option<Vector>> {
171 let key = Self::make_key(namespace, id);
172
173 match self.db.get(&key) {
174 Ok(Some(value)) => {
175 let entry: CacheEntry = serde_json::from_slice(&value).map_err(|e| {
176 DakeraError::Storage(format!("Failed to deserialize cache entry: {}", e))
177 })?;
178 Ok(Some(entry.vector))
179 }
180 Ok(None) => Ok(None),
181 Err(e) => {
182 warn!(error = %e, "Failed to read from disk cache");
183 Ok(None)
184 }
185 }
186 }
187
188 pub fn get_batch(&self, namespace: &str, ids: &[VectorId]) -> Result<Vec<Vector>> {
190 let keys: Vec<Vec<u8>> = ids.iter().map(|id| Self::make_key(namespace, id)).collect();
191
192 let results = self.db.multi_get(&keys);
193 let mut vectors = Vec::with_capacity(ids.len());
194
195 for result in results {
196 if let Ok(Some(value)) = result {
197 if let Ok(entry) = serde_json::from_slice::<CacheEntry>(&value) {
198 vectors.push(entry.vector);
199 }
200 }
201 }
202
203 Ok(vectors)
204 }
205
206 pub fn get_all(&self, namespace: &str) -> Result<Vec<Vector>> {
208 let prefix = Self::namespace_prefix(namespace);
209 let mut vectors = Vec::new();
210
211 let iter = self.db.prefix_iterator(&prefix);
212 for item in iter {
213 match item {
214 Ok((key, value)) => {
215 if !key.starts_with(&prefix) {
217 break;
218 }
219
220 if let Ok(entry) = serde_json::from_slice::<CacheEntry>(&value) {
221 vectors.push(entry.vector);
222 }
223 }
224 Err(e) => {
225 warn!(error = %e, "Error iterating disk cache");
226 break;
227 }
228 }
229 }
230
231 Ok(vectors)
232 }
233
234 pub fn delete(&self, namespace: &str, id: &VectorId) -> Result<bool> {
236 let key = Self::make_key(namespace, id);
237
238 let existed = self
240 .db
241 .get(&key)
242 .map_err(|e| DakeraError::Storage(format!("Failed to check disk cache: {}", e)))?
243 .is_some();
244
245 if existed {
246 self.db.delete(&key).map_err(|e| {
247 DakeraError::Storage(format!("Failed to delete from disk cache: {}", e))
248 })?;
249 }
250
251 Ok(existed)
252 }
253
254 pub fn delete_batch(&self, namespace: &str, ids: &[VectorId]) -> Result<usize> {
256 let mut batch = rocksdb::WriteBatch::default();
257 let mut count = 0;
258
259 for id in ids {
260 let key = Self::make_key(namespace, id);
261 if self.db.get(&key).ok().flatten().is_some() {
262 batch.delete(&key);
263 count += 1;
264 }
265 }
266
267 self.db.write(batch).map_err(|e| {
268 DakeraError::Storage(format!("Failed to delete batch from disk cache: {}", e))
269 })?;
270
271 Ok(count)
272 }
273
274 pub fn clear_namespace(&self, namespace: &str) -> Result<usize> {
276 let prefix = Self::namespace_prefix(namespace);
277 let mut batch = rocksdb::WriteBatch::default();
278 let mut count = 0;
279
280 let iter = self.db.prefix_iterator(&prefix);
281 for item in iter {
282 match item {
283 Ok((key, _)) => {
284 if !key.starts_with(&prefix) {
285 break;
286 }
287 batch.delete(&key);
288 count += 1;
289 }
290 Err(_) => break,
291 }
292 }
293
294 if count > 0 {
295 self.db.write(batch).map_err(|e| {
296 DakeraError::Storage(format!("Failed to clear namespace from disk cache: {}", e))
297 })?;
298 }
299
300 debug!(namespace = %namespace, count = count, "Cleared namespace from disk cache");
301 Ok(count)
302 }
303
304 pub fn approximate_size(&self) -> u64 {
306 self.db
307 .property_int_value("rocksdb.estimate-live-data-size")
308 .ok()
309 .flatten()
310 .unwrap_or(0)
311 }
312
313 pub fn stats(&self) -> DiskCacheStats {
315 DiskCacheStats {
316 approximate_size_bytes: self.approximate_size(),
317 approximate_num_keys: self
318 .db
319 .property_int_value("rocksdb.estimate-num-keys")
320 .ok()
321 .flatten()
322 .unwrap_or(0),
323 }
324 }
325
326 pub fn flush(&self) -> Result<()> {
328 self.db
329 .flush()
330 .map_err(|e| DakeraError::Storage(format!("Failed to flush disk cache: {}", e)))
331 }
332}
333
/// Estimates reported by `DiskCache::stats`.
#[derive(Debug, Clone)]
pub struct DiskCacheStats {
    /// Value of RocksDB's `rocksdb.estimate-live-data-size` property (0 if unavailable).
    pub approximate_size_bytes: u64,
    /// Value of RocksDB's `rocksdb.estimate-num-keys` property (0 if unavailable).
    pub approximate_num_keys: u64,
}
340
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a cache inside a fresh temporary directory. The `TempDir` is
    /// returned alongside the cache so it stays alive for the whole test.
    fn create_test_cache() -> (DiskCache, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().to_string_lossy().to_string();
        let cache = DiskCache::new(DiskCacheConfig {
            path,
            ..Default::default()
        })
        .unwrap();
        (cache, temp_dir)
    }

    /// Builds a three-component vector with the given id and no metadata or TTL.
    fn test_vector(id: &str) -> Vector {
        Vector {
            id: id.to_string(),
            values: vec![1.0, 2.0, 3.0],
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// A vector written with `put` can be read back with `get`.
    #[test]
    fn test_put_and_get() {
        let (cache, _dir) = create_test_cache();

        cache.put("test", &test_vector("v1")).unwrap();

        let retrieved = cache
            .get("test", &"v1".to_string())
            .unwrap()
            .expect("vector should be cached");
        assert_eq!(retrieved.id, "v1");
        assert_eq!(retrieved.values, vec![1.0, 2.0, 3.0]);
    }

    /// Looking up an id that was never stored is a miss, not an error.
    #[test]
    fn test_get_nonexistent() {
        let (cache, _dir) = create_test_cache();
        let missing = cache.get("test", &"nonexistent".to_string()).unwrap();
        assert!(missing.is_none());
    }

    /// `put_batch` followed by `get_batch` round-trips every vector.
    #[test]
    fn test_batch_operations() {
        let (cache, _dir) = create_test_cache();
        let vectors = vec![test_vector("v1"), test_vector("v2"), test_vector("v3")];

        let written = cache.put_batch("test", &vectors).unwrap();
        assert_eq!(written, 3);

        let ids: Vec<String> = ["v1", "v2", "v3"].iter().map(|id| id.to_string()).collect();
        let fetched = cache.get_batch("test", &ids).unwrap();
        assert_eq!(fetched.len(), 3);
    }

    /// `get_all` returns everything stored in the namespace.
    #[test]
    fn test_get_all() {
        let (cache, _dir) = create_test_cache();
        let vectors = vec![test_vector("v1"), test_vector("v2")];

        cache.put_batch("test", &vectors).unwrap();

        assert_eq!(cache.get_all("test").unwrap().len(), 2);
    }

    /// `delete` reports that the entry existed and removes it.
    #[test]
    fn test_delete() {
        let (cache, _dir) = create_test_cache();
        let id = "v1".to_string();

        cache.put("test", &test_vector("v1")).unwrap();
        assert!(cache.get("test", &id).unwrap().is_some());

        assert!(cache.delete("test", &id).unwrap());

        assert!(cache.get("test", &id).unwrap().is_none());
    }

    /// `delete_batch` removes only the requested ids.
    #[test]
    fn test_delete_batch() {
        let (cache, _dir) = create_test_cache();
        let vectors = vec![test_vector("v1"), test_vector("v2"), test_vector("v3")];
        cache.put_batch("test", &vectors).unwrap();

        let doomed = vec!["v1".to_string(), "v2".to_string()];
        assert_eq!(cache.delete_batch("test", &doomed).unwrap(), 2);

        assert!(cache.get("test", &"v1".to_string()).unwrap().is_none());
        assert!(cache.get("test", &"v2".to_string()).unwrap().is_none());
        assert!(cache.get("test", &"v3".to_string()).unwrap().is_some());
    }

    /// Clearing one namespace leaves the others untouched.
    #[test]
    fn test_clear_namespace() {
        let (cache, _dir) = create_test_cache();
        let vectors = vec![test_vector("v1"), test_vector("v2")];

        for namespace in ["ns1", "ns2"] {
            cache.put_batch(namespace, &vectors).unwrap();
        }

        assert_eq!(cache.clear_namespace("ns1").unwrap(), 2);

        assert!(cache.get_all("ns1").unwrap().is_empty());
        assert_eq!(cache.get_all("ns2").unwrap().len(), 2);
    }

    /// The same id in two namespaces refers to two independent entries.
    #[test]
    fn test_namespace_isolation() {
        let (cache, _dir) = create_test_cache();
        let vector = test_vector("v1");
        let id = "v1".to_string();

        cache.put("ns1", &vector).unwrap();
        cache.put("ns2", &vector).unwrap();

        assert!(cache.get("ns1", &id).unwrap().is_some());
        assert!(cache.get("ns2", &id).unwrap().is_some());

        cache.delete("ns1", &id).unwrap();

        assert!(cache.get("ns1", &id).unwrap().is_none());
        assert!(cache.get("ns2", &id).unwrap().is_some());
    }

    /// `stats` can be queried after a flush. The values are estimates, so the
    /// test only checks that the call succeeds.
    #[test]
    fn test_stats() {
        let (cache, _dir) = create_test_cache();
        let vectors = vec![test_vector("v1"), test_vector("v2")];

        cache.put_batch("test", &vectors).unwrap();
        cache.flush().unwrap();

        let stats = cache.stats();
        let _ = stats.approximate_num_keys;
    }
}