1use std::path::Path;
7use std::sync::Arc;
8
9use rocksdb::{Options, DB};
10use serde::{Deserialize, Serialize};
11use tracing::{debug, warn};
12
13use common::{DakeraError, Result, Vector, VectorId};
14
/// Configuration for the RocksDB-backed disk cache.
#[derive(Debug, Clone)]
pub struct DiskCacheConfig {
    /// Filesystem path of the RocksDB database directory.
    pub path: String,
    /// Size budget in bytes. NOTE(review): not read or enforced anywhere
    /// in this file — confirm whether eviction happens elsewhere.
    pub max_size_bytes: u64,
    /// When true, values are stored with LZ4 compression.
    pub compression: bool,
    /// RocksDB memtable (write buffer) size in bytes.
    pub write_buffer_size: usize,
    /// Maximum number of in-memory write buffers RocksDB may keep.
    pub max_write_buffer_number: i32,
}
29
30impl Default for DiskCacheConfig {
31 fn default() -> Self {
32 Self {
33 path: "./cache".to_string(),
34 max_size_bytes: 10 * 1024 * 1024 * 1024, compression: true,
36 write_buffer_size: 64 * 1024 * 1024, max_write_buffer_number: 3,
38 }
39 }
40}
41
/// On-disk record: the cached vector plus bookkeeping metadata.
/// Serialized as JSON via `serde_json`.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry {
    /// The cached vector payload.
    vector: Vector,
    /// Access counter. NOTE(review): always written as 1 in this file and
    /// never incremented on reads — confirm whether this is still used.
    access_count: u64,
    /// Creation time, seconds since the Unix epoch.
    created_at: u64,
}
49
/// Persistent vector cache backed by a RocksDB database.
///
/// Keys are `"{namespace}:{id}"`; values are JSON-encoded `CacheEntry`s.
pub struct DiskCache {
    /// Shared handle to the underlying RocksDB instance.
    db: Arc<DB>,
    /// Retained for introspection; not read after construction.
    #[allow(dead_code)]
    config: DiskCacheConfig,
}
56
57impl DiskCache {
58 pub fn new(config: DiskCacheConfig) -> Result<Self> {
60 let mut opts = Options::default();
61 opts.create_if_missing(true);
62 opts.set_write_buffer_size(config.write_buffer_size);
63 opts.set_max_write_buffer_number(config.max_write_buffer_number);
64
65 if config.compression {
66 opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
67 }
68
69 opts.set_level_compaction_dynamic_level_bytes(true);
71 opts.set_max_background_jobs(4);
72
73 let db = DB::open(&opts, &config.path)
74 .map_err(|e| DakeraError::Storage(format!("Failed to open RocksDB: {}", e)))?;
75
76 debug!(path = %config.path, "Disk cache initialized");
77
78 Ok(Self {
79 db: Arc::new(db),
80 config,
81 })
82 }
83
84 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
86 let config = DiskCacheConfig {
87 path: path.as_ref().to_string_lossy().to_string(),
88 ..Default::default()
89 };
90 Self::new(config)
91 }
92
93 fn make_key(namespace: &str, id: &VectorId) -> Vec<u8> {
95 format!("{}:{}", namespace, id).into_bytes()
96 }
97
98 fn namespace_prefix(namespace: &str) -> Vec<u8> {
100 format!("{}:", namespace).into_bytes()
101 }
102
103 pub fn put(&self, namespace: &str, vector: &Vector) -> Result<()> {
105 let key = Self::make_key(namespace, &vector.id);
106 let entry = CacheEntry {
107 vector: vector.clone(),
108 access_count: 1,
109 created_at: std::time::SystemTime::now()
110 .duration_since(std::time::UNIX_EPOCH)
111 .unwrap_or_default()
112 .as_secs(),
113 };
114
115 let value = serde_json::to_vec(&entry)
116 .map_err(|e| DakeraError::Storage(format!("Failed to serialize cache entry: {}", e)))?;
117
118 self.db
119 .put(&key, &value)
120 .map_err(|e| DakeraError::Storage(format!("Failed to write to disk cache: {}", e)))?;
121
122 debug!(namespace = %namespace, id = %vector.id, "Cached vector to disk");
123 Ok(())
124 }
125
126 pub fn put_batch(&self, namespace: &str, vectors: &[Vector]) -> Result<usize> {
128 let mut batch = rocksdb::WriteBatch::default();
129 let now = std::time::SystemTime::now()
130 .duration_since(std::time::UNIX_EPOCH)
131 .unwrap_or_default()
132 .as_secs();
133
134 for vector in vectors {
135 let key = Self::make_key(namespace, &vector.id);
136 let entry = CacheEntry {
137 vector: vector.clone(),
138 access_count: 1,
139 created_at: now,
140 };
141
142 let value = serde_json::to_vec(&entry).map_err(|e| {
143 DakeraError::Storage(format!("Failed to serialize cache entry: {}", e))
144 })?;
145
146 batch.put(&key, &value);
147 }
148
149 let count = vectors.len();
150 self.db.write(batch).map_err(|e| {
151 DakeraError::Storage(format!("Failed to write batch to disk cache: {}", e))
152 })?;
153
154 debug!(namespace = %namespace, count = count, "Batch cached vectors to disk");
155 Ok(count)
156 }
157
158 pub fn get(&self, namespace: &str, id: &VectorId) -> Result<Option<Vector>> {
160 let key = Self::make_key(namespace, id);
161
162 match self.db.get(&key) {
163 Ok(Some(value)) => {
164 let entry: CacheEntry = serde_json::from_slice(&value).map_err(|e| {
165 DakeraError::Storage(format!("Failed to deserialize cache entry: {}", e))
166 })?;
167 Ok(Some(entry.vector))
168 }
169 Ok(None) => Ok(None),
170 Err(e) => {
171 warn!(error = %e, "Failed to read from disk cache");
172 Ok(None)
173 }
174 }
175 }
176
177 pub fn get_batch(&self, namespace: &str, ids: &[VectorId]) -> Result<Vec<Vector>> {
179 let keys: Vec<Vec<u8>> = ids.iter().map(|id| Self::make_key(namespace, id)).collect();
180
181 let results = self.db.multi_get(&keys);
182 let mut vectors = Vec::with_capacity(ids.len());
183
184 for result in results {
185 if let Ok(Some(value)) = result {
186 if let Ok(entry) = serde_json::from_slice::<CacheEntry>(&value) {
187 vectors.push(entry.vector);
188 }
189 }
190 }
191
192 Ok(vectors)
193 }
194
195 pub fn get_all(&self, namespace: &str) -> Result<Vec<Vector>> {
197 let prefix = Self::namespace_prefix(namespace);
198 let mut vectors = Vec::new();
199
200 let iter = self.db.prefix_iterator(&prefix);
201 for item in iter {
202 match item {
203 Ok((key, value)) => {
204 if !key.starts_with(&prefix) {
206 break;
207 }
208
209 if let Ok(entry) = serde_json::from_slice::<CacheEntry>(&value) {
210 vectors.push(entry.vector);
211 }
212 }
213 Err(e) => {
214 warn!(error = %e, "Error iterating disk cache");
215 break;
216 }
217 }
218 }
219
220 Ok(vectors)
221 }
222
223 pub fn delete(&self, namespace: &str, id: &VectorId) -> Result<bool> {
225 let key = Self::make_key(namespace, id);
226
227 let existed = self
229 .db
230 .get(&key)
231 .map_err(|e| DakeraError::Storage(format!("Failed to check disk cache: {}", e)))?
232 .is_some();
233
234 if existed {
235 self.db.delete(&key).map_err(|e| {
236 DakeraError::Storage(format!("Failed to delete from disk cache: {}", e))
237 })?;
238 }
239
240 Ok(existed)
241 }
242
243 pub fn delete_batch(&self, namespace: &str, ids: &[VectorId]) -> Result<usize> {
245 let mut batch = rocksdb::WriteBatch::default();
246 let mut count = 0;
247
248 for id in ids {
249 let key = Self::make_key(namespace, id);
250 if self.db.get(&key).ok().flatten().is_some() {
251 batch.delete(&key);
252 count += 1;
253 }
254 }
255
256 self.db.write(batch).map_err(|e| {
257 DakeraError::Storage(format!("Failed to delete batch from disk cache: {}", e))
258 })?;
259
260 Ok(count)
261 }
262
263 pub fn clear_namespace(&self, namespace: &str) -> Result<usize> {
265 let prefix = Self::namespace_prefix(namespace);
266 let mut batch = rocksdb::WriteBatch::default();
267 let mut count = 0;
268
269 let iter = self.db.prefix_iterator(&prefix);
270 for item in iter {
271 match item {
272 Ok((key, _)) => {
273 if !key.starts_with(&prefix) {
274 break;
275 }
276 batch.delete(&key);
277 count += 1;
278 }
279 Err(_) => break,
280 }
281 }
282
283 if count > 0 {
284 self.db.write(batch).map_err(|e| {
285 DakeraError::Storage(format!("Failed to clear namespace from disk cache: {}", e))
286 })?;
287 }
288
289 debug!(namespace = %namespace, count = count, "Cleared namespace from disk cache");
290 Ok(count)
291 }
292
293 pub fn approximate_size(&self) -> u64 {
295 self.db
296 .property_int_value("rocksdb.estimate-live-data-size")
297 .ok()
298 .flatten()
299 .unwrap_or(0)
300 }
301
302 pub fn stats(&self) -> DiskCacheStats {
304 DiskCacheStats {
305 approximate_size_bytes: self.approximate_size(),
306 approximate_num_keys: self
307 .db
308 .property_int_value("rocksdb.estimate-num-keys")
309 .ok()
310 .flatten()
311 .unwrap_or(0),
312 }
313 }
314
315 pub fn flush(&self) -> Result<()> {
317 self.db
318 .flush()
319 .map_err(|e| DakeraError::Storage(format!("Failed to flush disk cache: {}", e)))
320 }
321}
322
/// Point-in-time, approximate cache statistics reported by RocksDB.
#[derive(Debug, Clone)]
pub struct DiskCacheStats {
    /// Estimated live data size in bytes ("rocksdb.estimate-live-data-size").
    pub approximate_size_bytes: u64,
    /// Estimated number of stored keys ("rocksdb.estimate-num-keys").
    pub approximate_num_keys: u64,
}
329
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a cache backed by a fresh temporary directory. The directory
    /// guard is returned so it outlives the cache for the test's duration.
    fn create_test_cache() -> (DiskCache, TempDir) {
        let dir = TempDir::new().unwrap();
        let cache = DiskCache::new(DiskCacheConfig {
            path: dir.path().to_string_lossy().to_string(),
            ..Default::default()
        })
        .unwrap();
        (cache, dir)
    }

    /// Minimal three-component vector with the given id and no metadata.
    fn test_vector(id: &str) -> Vector {
        Vector {
            id: id.to_string(),
            values: vec![1.0, 2.0, 3.0],
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    #[test]
    fn test_put_and_get() {
        let (cache, _dir) = create_test_cache();
        cache.put("test", &test_vector("v1")).unwrap();

        let fetched = cache
            .get("test", &"v1".to_string())
            .unwrap()
            .expect("vector should be present after put");
        assert_eq!(fetched.id, "v1");
        assert_eq!(fetched.values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_get_nonexistent() {
        let (cache, _dir) = create_test_cache();
        assert!(cache
            .get("test", &"nonexistent".to_string())
            .unwrap()
            .is_none());
    }

    #[test]
    fn test_batch_operations() {
        let (cache, _dir) = create_test_cache();
        let stored = vec![test_vector("v1"), test_vector("v2"), test_vector("v3")];

        assert_eq!(cache.put_batch("test", &stored).unwrap(), 3);

        let wanted: Vec<String> = ["v1", "v2", "v3"].iter().map(|s| s.to_string()).collect();
        let fetched = cache.get_batch("test", &wanted).unwrap();
        assert_eq!(fetched.len(), 3);
    }

    #[test]
    fn test_get_all() {
        let (cache, _dir) = create_test_cache();
        cache
            .put_batch("test", &[test_vector("v1"), test_vector("v2")])
            .unwrap();

        assert_eq!(cache.get_all("test").unwrap().len(), 2);
    }

    #[test]
    fn test_delete() {
        let (cache, _dir) = create_test_cache();
        let id = "v1".to_string();

        cache.put("test", &test_vector("v1")).unwrap();
        assert!(cache.get("test", &id).unwrap().is_some());

        // First delete reports the entry existed; it is gone afterwards.
        assert!(cache.delete("test", &id).unwrap());
        assert!(cache.get("test", &id).unwrap().is_none());
    }

    #[test]
    fn test_delete_batch() {
        let (cache, _dir) = create_test_cache();
        cache
            .put_batch(
                "test",
                &[test_vector("v1"), test_vector("v2"), test_vector("v3")],
            )
            .unwrap();

        let doomed = vec!["v1".to_string(), "v2".to_string()];
        assert_eq!(cache.delete_batch("test", &doomed).unwrap(), 2);

        // Only the undeleted vector survives.
        assert!(cache.get("test", &"v1".to_string()).unwrap().is_none());
        assert!(cache.get("test", &"v2".to_string()).unwrap().is_none());
        assert!(cache.get("test", &"v3".to_string()).unwrap().is_some());
    }

    #[test]
    fn test_clear_namespace() {
        let (cache, _dir) = create_test_cache();
        let shared = vec![test_vector("v1"), test_vector("v2")];

        cache.put_batch("ns1", &shared).unwrap();
        cache.put_batch("ns2", &shared).unwrap();

        // Clearing ns1 must not disturb ns2.
        assert_eq!(cache.clear_namespace("ns1").unwrap(), 2);
        assert!(cache.get_all("ns1").unwrap().is_empty());
        assert_eq!(cache.get_all("ns2").unwrap().len(), 2);
    }

    #[test]
    fn test_namespace_isolation() {
        let (cache, _dir) = create_test_cache();
        let id = "v1".to_string();
        let vector = test_vector("v1");

        cache.put("ns1", &vector).unwrap();
        cache.put("ns2", &vector).unwrap();
        assert!(cache.get("ns1", &id).unwrap().is_some());
        assert!(cache.get("ns2", &id).unwrap().is_some());

        // Deleting from ns1 leaves ns2's copy intact.
        cache.delete("ns1", &id).unwrap();
        assert!(cache.get("ns1", &id).unwrap().is_none());
        assert!(cache.get("ns2", &id).unwrap().is_some());
    }

    #[test]
    fn test_stats() {
        let (cache, _dir) = create_test_cache();

        cache
            .put_batch("test", &[test_vector("v1"), test_vector("v2")])
            .unwrap();
        cache.flush().unwrap();

        // The counters are RocksDB estimates, so only exercise the call
        // path rather than asserting exact values.
        let stats = cache.stats();
        let _ = stats.approximate_num_keys;
    }
}