rexis_rag/caching/
persistence.rs

1//! # Cache Persistence
2//!
3//! Persistence layer for cache data across restarts.
4
5use super::{
6    CacheStats, EmbeddingCacheEntry, PersistenceConfig, PersistenceFormat, QueryCacheEntry,
7    ResultCacheEntry, SemanticCacheEntry,
8};
9use crate::{RragError, RragResult};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs;
13use std::io::{Read, Write};
14use std::path::PathBuf;
15
16/// Cache persistence manager
17pub struct PersistenceManager {
18    /// Configuration
19    config: PersistenceConfig,
20
21    /// Storage path
22    storage_path: PathBuf,
23
24    /// Serializer based on format
25    serializer: Box<dyn CacheSerializer>,
26
27    /// Persistence statistics
28    stats: PersistenceStats,
29}
30
31/// Persistence statistics
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct PersistenceStats {
34    /// Total saves
35    pub save_count: u64,
36
37    /// Total loads
38    pub load_count: u64,
39
40    /// Failed saves
41    pub save_failures: u64,
42
43    /// Failed loads
44    pub load_failures: u64,
45
46    /// Total bytes written
47    pub bytes_written: u64,
48
49    /// Total bytes read
50    pub bytes_read: u64,
51
52    /// Last save time
53    pub last_save: Option<std::time::SystemTime>,
54
55    /// Last load time
56    pub last_load: Option<std::time::SystemTime>,
57}
58
59/// Cache serializer trait
60pub trait CacheSerializer: Send + Sync {
61    /// Serialize cache data
62    fn serialize_cache_data(&self, data: &PersistedCacheData) -> RragResult<Vec<u8>>;
63
64    /// Deserialize cache data
65    fn deserialize_cache_data(&self, data: &[u8]) -> RragResult<PersistedCacheData>;
66
67    /// Get format name
68    fn format_name(&self) -> &str;
69}
70
71/// Binary serializer using bincode
72pub struct BinarySerializer;
73
74impl CacheSerializer for BinarySerializer {
75    fn serialize_cache_data(&self, data: &PersistedCacheData) -> RragResult<Vec<u8>> {
76        bincode::serialize(data)
77            .map_err(|e| RragError::serialization_with_message("binary", e.to_string()))
78    }
79
80    fn deserialize_cache_data(&self, data: &[u8]) -> RragResult<PersistedCacheData> {
81        bincode::deserialize(data)
82            .map_err(|e| RragError::serialization_with_message("binary", e.to_string()))
83    }
84
85    fn format_name(&self) -> &str {
86        "binary"
87    }
88}
89
90/// JSON serializer
91pub struct JsonSerializer;
92
93impl CacheSerializer for JsonSerializer {
94    fn serialize_cache_data(&self, data: &PersistedCacheData) -> RragResult<Vec<u8>> {
95        serde_json::to_vec(data)
96            .map_err(|e| RragError::serialization_with_message("json", e.to_string()))
97    }
98
99    fn deserialize_cache_data(&self, data: &[u8]) -> RragResult<PersistedCacheData> {
100        serde_json::from_slice(data)
101            .map_err(|e| RragError::serialization_with_message("json", e.to_string()))
102    }
103
104    fn format_name(&self) -> &str {
105        "json"
106    }
107}
108
109/// MessagePack serializer
110pub struct MessagePackSerializer;
111
112impl CacheSerializer for MessagePackSerializer {
113    fn serialize_cache_data(&self, data: &PersistedCacheData) -> RragResult<Vec<u8>> {
114        rmp_serde::to_vec(data)
115            .map_err(|e| RragError::serialization_with_message("msgpack", e.to_string()))
116    }
117
118    fn deserialize_cache_data(&self, data: &[u8]) -> RragResult<PersistedCacheData> {
119        rmp_serde::from_slice(data)
120            .map_err(|e| RragError::serialization_with_message("msgpack", e.to_string()))
121    }
122
123    fn format_name(&self) -> &str {
124        "msgpack"
125    }
126}
127
128/// Persisted cache data structure
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct PersistedCacheData {
131    /// Version for compatibility
132    pub version: u32,
133
134    /// Timestamp of persistence
135    pub timestamp: std::time::SystemTime,
136
137    /// Query cache entries
138    pub query_cache: HashMap<String, QueryCacheEntry>,
139
140    /// Embedding cache entries
141    pub embedding_cache: HashMap<String, EmbeddingCacheEntry>,
142
143    /// Semantic cache entries
144    pub semantic_cache: HashMap<String, SemanticCacheEntry>,
145
146    /// Result cache entries
147    pub result_cache: HashMap<String, ResultCacheEntry>,
148
149    /// Cache statistics
150    pub stats: HashMap<String, CacheStats>,
151
152    /// Metadata
153    pub metadata: PersistenceMetadata,
154}
155
156/// Persistence metadata
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct PersistenceMetadata {
159    /// Application version
160    pub app_version: String,
161
162    /// Cache configuration hash
163    pub config_hash: String,
164
165    /// Total entries
166    pub total_entries: usize,
167
168    /// Compression enabled
169    pub compression_enabled: bool,
170
171    /// Custom metadata
172    pub custom: HashMap<String, String>,
173}
174
175impl PersistenceManager {
176    /// Create new persistence manager
177    pub fn new(config: PersistenceConfig) -> RragResult<Self> {
178        let storage_path = PathBuf::from(&config.storage_path);
179
180        // Create storage directory if it doesn't exist
181        if !storage_path.exists() {
182            fs::create_dir_all(&storage_path)
183                .map_err(|e| RragError::storage("create_cache_directory", e))?;
184        }
185
186        let serializer: Box<dyn CacheSerializer> = match config.format {
187            PersistenceFormat::Binary => Box::new(BinarySerializer),
188            PersistenceFormat::Json => Box::new(JsonSerializer),
189            PersistenceFormat::MessagePack => Box::new(MessagePackSerializer),
190        };
191
192        Ok(Self {
193            config,
194            storage_path,
195            serializer,
196            stats: PersistenceStats::default(),
197        })
198    }
199
200    /// Save cache data to disk
201    pub fn save(&mut self, data: &PersistedCacheData) -> RragResult<()> {
202        let start = std::time::Instant::now();
203
204        // Serialize data
205        let serialized = self.serializer.serialize_cache_data(data)?;
206
207        // Write to temporary file first
208        let temp_path = self.get_temp_path();
209        let mut file =
210            fs::File::create(&temp_path).map_err(|e| RragError::storage("create_temp_file", e))?;
211
212        file.write_all(&serialized)
213            .map_err(|e| RragError::storage("write_cache_data", e))?;
214
215        file.sync_all()
216            .map_err(|e| RragError::storage("sync_cache_file", e))?;
217
218        // Rename to final path (atomic on most systems)
219        let final_path = self.get_cache_path();
220        fs::rename(&temp_path, &final_path)
221            .map_err(|e| RragError::storage("rename_cache_file", e))?;
222
223        // Update stats
224        self.stats.save_count += 1;
225        self.stats.bytes_written += serialized.len() as u64;
226        self.stats.last_save = Some(std::time::SystemTime::now());
227
228        let duration = start.elapsed();
229        tracing::info!(
230            "Cache saved: {} entries, {} bytes, {:?}",
231            data.metadata.total_entries,
232            serialized.len(),
233            duration
234        );
235
236        Ok(())
237    }
238
239    /// Load cache data from disk
240    pub fn load(&mut self) -> RragResult<PersistedCacheData> {
241        let start = std::time::Instant::now();
242        let cache_path = self.get_cache_path();
243
244        if !cache_path.exists() {
245            return Err(RragError::memory("load_cache", "Cache file not found"));
246        }
247
248        // Read file
249        let mut file =
250            fs::File::open(&cache_path).map_err(|e| RragError::storage("open_cache_file", e))?;
251
252        let mut buffer = Vec::new();
253        file.read_to_end(&mut buffer)
254            .map_err(|e| RragError::storage("read_cache_file", e))?;
255
256        // Deserialize data
257        let data = self.serializer.deserialize_cache_data(&buffer)?;
258
259        // Validate version
260        if data.version != CACHE_VERSION {
261            return Err(RragError::validation(
262                "cache_version",
263                format!("version {}", CACHE_VERSION),
264                format!("version {}", data.version),
265            ));
266        }
267
268        // Update stats
269        self.stats.load_count += 1;
270        self.stats.bytes_read += buffer.len() as u64;
271        self.stats.last_load = Some(std::time::SystemTime::now());
272
273        let duration = start.elapsed();
274        tracing::info!(
275            "Cache loaded: {} entries, {} bytes, {:?}",
276            data.metadata.total_entries,
277            buffer.len(),
278            duration
279        );
280
281        Ok(data)
282    }
283
284    /// Save cache asynchronously
285    pub async fn save_async(&mut self, data: PersistedCacheData) -> RragResult<()> {
286        let serializer = self.create_serializer();
287        let path = self.get_cache_path();
288        let temp_path = self.get_temp_path();
289
290        // Spawn blocking task for IO
291        tokio::task::spawn_blocking(move || {
292            let serialized = serializer.serialize_cache_data(&data)?;
293
294            let mut file = fs::File::create(&temp_path)
295                .map_err(|e| RragError::storage("create_temp_file", e))?;
296
297            file.write_all(&serialized)
298                .map_err(|e| RragError::storage("write_cache_data", e))?;
299
300            file.sync_all()
301                .map_err(|e| RragError::storage("sync_cache_file", e))?;
302
303            fs::rename(&temp_path, &path)
304                .map_err(|e| RragError::storage("rename_cache_file", e))?;
305
306            Ok(())
307        })
308        .await
309        .map_err(|e| RragError::memory("async_save", e.to_string()))?
310    }
311
312    /// Create backup of current cache
313    pub fn backup(&self) -> RragResult<()> {
314        let cache_path = self.get_cache_path();
315        if !cache_path.exists() {
316            return Ok(());
317        }
318
319        let backup_path = self.get_backup_path();
320        fs::copy(&cache_path, &backup_path).map_err(|e| RragError::storage("create_backup", e))?;
321
322        tracing::info!("Cache backed up to {:?}", backup_path);
323        Ok(())
324    }
325
326    /// Restore from backup
327    pub fn restore(&self) -> RragResult<()> {
328        let backup_path = self.get_backup_path();
329        if !backup_path.exists() {
330            return Err(RragError::memory("restore_backup", "Backup file not found"));
331        }
332
333        let cache_path = self.get_cache_path();
334        fs::copy(&backup_path, &cache_path)
335            .map_err(|e| RragError::storage("restore_from_backup", e))?;
336
337        tracing::info!("Cache restored from backup");
338        Ok(())
339    }
340
341    /// Clean old cache files
342    pub fn cleanup(&self, keep_days: u32) -> RragResult<()> {
343        let cutoff =
344            std::time::SystemTime::now() - std::time::Duration::from_secs(keep_days as u64 * 86400);
345
346        let entries = fs::read_dir(&self.storage_path)
347            .map_err(|e| RragError::storage("read_cache_directory", e))?;
348
349        let mut removed = 0;
350        for entry in entries {
351            let entry = entry.map_err(|e| RragError::storage("read_directory_entry", e))?;
352            let metadata = entry
353                .metadata()
354                .map_err(|e| RragError::storage("read_file_metadata", e))?;
355
356            if let Ok(modified) = metadata.modified() {
357                if modified < cutoff {
358                    fs::remove_file(entry.path())
359                        .map_err(|e| RragError::storage("remove_old_cache", e))?;
360                    removed += 1;
361                }
362            }
363        }
364
365        tracing::info!("Cleaned up {} old cache files", removed);
366        Ok(())
367    }
368
369    /// Get cache file path
370    fn get_cache_path(&self) -> PathBuf {
371        self.storage_path.join("cache.dat")
372    }
373
374    /// Get temporary file path
375    fn get_temp_path(&self) -> PathBuf {
376        self.storage_path.join("cache.tmp")
377    }
378
379    /// Get backup file path
380    fn get_backup_path(&self) -> PathBuf {
381        self.storage_path.join("cache.bak")
382    }
383
384    /// Create serializer instance
385    fn create_serializer(&self) -> Box<dyn CacheSerializer> {
386        match self.config.format {
387            PersistenceFormat::Binary => Box::new(BinarySerializer),
388            PersistenceFormat::Json => Box::new(JsonSerializer),
389            PersistenceFormat::MessagePack => Box::new(MessagePackSerializer),
390        }
391    }
392}
393
/// Snapshot format version written into new snapshots and required by
/// `PersistenceManager::load`, which rejects any other value.
const CACHE_VERSION: u32 = 1;
396
397impl Default for PersistenceStats {
398    fn default() -> Self {
399        Self {
400            save_count: 0,
401            load_count: 0,
402            save_failures: 0,
403            load_failures: 0,
404            bytes_written: 0,
405            bytes_read: 0,
406            last_save: None,
407            last_load: None,
408        }
409    }
410}
411
412impl Default for PersistedCacheData {
413    fn default() -> Self {
414        Self {
415            version: CACHE_VERSION,
416            timestamp: std::time::SystemTime::now(),
417            query_cache: HashMap::new(),
418            embedding_cache: HashMap::new(),
419            semantic_cache: HashMap::new(),
420            result_cache: HashMap::new(),
421            stats: HashMap::new(),
422            metadata: PersistenceMetadata::default(),
423        }
424    }
425}
426
427impl Default for PersistenceMetadata {
428    fn default() -> Self {
429        Self {
430            app_version: env!("CARGO_PKG_VERSION").to_string(),
431            config_hash: String::new(),
432            total_entries: 0,
433            compression_enabled: false,
434            custom: HashMap::new(),
435        }
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    // `Path` is not imported by the parent module (only `PathBuf` is), so it
    // must be brought in here for `create_test_config` to compile.
    use std::path::Path;
    use tempfile::TempDir;

    /// Build a binary-format persistence config rooted at `dir`.
    fn create_test_config(dir: &Path) -> PersistenceConfig {
        PersistenceConfig {
            enabled: true,
            storage_path: dir.to_str().unwrap().to_string(),
            auto_save_interval: std::time::Duration::from_secs(60),
            format: PersistenceFormat::Binary,
        }
    }

    /// A default snapshot survives a bincode round trip.
    #[test]
    fn test_binary_serializer() {
        let serializer = BinarySerializer;
        let data = PersistedCacheData::default();

        let serialized = serializer.serialize_cache_data(&data).unwrap();
        let deserialized = serializer.deserialize_cache_data(&serialized).unwrap();

        assert_eq!(data.version, deserialized.version);
    }

    /// A default snapshot survives a JSON round trip.
    #[test]
    fn test_json_serializer() {
        let serializer = JsonSerializer;
        let data = PersistedCacheData::default();

        let serialized = serializer.serialize_cache_data(&data).unwrap();
        let deserialized = serializer.deserialize_cache_data(&serialized).unwrap();

        assert_eq!(data.version, deserialized.version);
    }

    /// A default snapshot survives a MessagePack round trip.
    #[test]
    fn test_messagepack_serializer() {
        let serializer = MessagePackSerializer;
        let data = PersistedCacheData::default();

        let serialized = serializer.serialize_cache_data(&data).unwrap();
        let deserialized = serializer.deserialize_cache_data(&serialized).unwrap();

        assert_eq!(data.version, deserialized.version);
    }

    /// Saving and then loading through the manager returns the same version.
    #[test]
    fn test_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let config = create_test_config(temp_dir.path());
        let mut manager = PersistenceManager::new(config).unwrap();

        let data = PersistedCacheData::default();
        manager.save(&data).unwrap();

        let loaded = manager.load().unwrap();
        assert_eq!(loaded.version, data.version);
    }

    /// A backup can be restored after the original cache file is deleted.
    #[test]
    fn test_backup_and_restore() {
        let temp_dir = TempDir::new().unwrap();
        let config = create_test_config(temp_dir.path());
        let mut manager = PersistenceManager::new(config).unwrap();

        let data = PersistedCacheData::default();
        manager.save(&data).unwrap();

        manager.backup().unwrap();

        // Delete original
        fs::remove_file(manager.get_cache_path()).unwrap();

        // Restore from backup
        manager.restore().unwrap();

        let loaded = manager.load().unwrap();
        assert_eq!(loaded.version, data.version);
    }
}
508}