1use super::{
6 CacheStats, EmbeddingCacheEntry, PersistenceConfig, PersistenceFormat, QueryCacheEntry,
7 ResultCacheEntry, SemanticCacheEntry,
8};
9use crate::{RragError, RragResult};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs;
13use std::io::{Read, Write};
14use std::path::PathBuf;
15
/// Saves and loads cache snapshots as files inside a single storage directory.
pub struct PersistenceManager {
    /// Persistence settings; `format` selects the serializer and
    /// `storage_path` names the directory the files live in.
    config: PersistenceConfig,

    /// Directory that holds the cache, temporary and backup files.
    storage_path: PathBuf,

    /// Serializer chosen from `config.format` when the manager is constructed.
    serializer: Box<dyn CacheSerializer>,

    /// Counters updated by the save and load operations.
    stats: PersistenceStats,
}
30
/// Running counters describing the persistence activity of a manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistenceStats {
    /// Number of saves that completed successfully.
    pub save_count: u64,

    /// Number of loads that completed successfully.
    pub load_count: u64,

    /// Number of save attempts that failed.
    pub save_failures: u64,

    /// Number of load attempts that failed.
    pub load_failures: u64,

    /// Total size, in bytes, of the serialized payloads written by successful saves.
    pub bytes_written: u64,

    /// Total size, in bytes, of the cache files read by successful loads.
    pub bytes_read: u64,

    /// Wall-clock time of the most recent successful save, if any.
    pub last_save: Option<std::time::SystemTime>,

    /// Wall-clock time of the most recent successful load, if any.
    pub last_load: Option<std::time::SystemTime>,
}
58
/// Converts a [`PersistedCacheData`] snapshot to and from a byte buffer.
///
/// Implementations must be `Send + Sync`, which lets a boxed serializer be
/// moved onto another thread.
pub trait CacheSerializer: Send + Sync {
    /// Encodes `data` into a freshly allocated byte buffer.
    fn serialize_cache_data(&self, data: &PersistedCacheData) -> RragResult<Vec<u8>>;

    /// Decodes a snapshot from `data`, which is expected to have been produced
    /// by the same format's `serialize_cache_data`.
    fn deserialize_cache_data(&self, data: &[u8]) -> RragResult<PersistedCacheData>;

    /// Short identifier of the format (for example `"json"`).
    fn format_name(&self) -> &str;
}
70
71pub struct BinarySerializer;
73
74impl CacheSerializer for BinarySerializer {
75 fn serialize_cache_data(&self, data: &PersistedCacheData) -> RragResult<Vec<u8>> {
76 bincode::serialize(data)
77 .map_err(|e| RragError::serialization_with_message("binary", e.to_string()))
78 }
79
80 fn deserialize_cache_data(&self, data: &[u8]) -> RragResult<PersistedCacheData> {
81 bincode::deserialize(data)
82 .map_err(|e| RragError::serialization_with_message("binary", e.to_string()))
83 }
84
85 fn format_name(&self) -> &str {
86 "binary"
87 }
88}
89
90pub struct JsonSerializer;
92
93impl CacheSerializer for JsonSerializer {
94 fn serialize_cache_data(&self, data: &PersistedCacheData) -> RragResult<Vec<u8>> {
95 serde_json::to_vec(data)
96 .map_err(|e| RragError::serialization_with_message("json", e.to_string()))
97 }
98
99 fn deserialize_cache_data(&self, data: &[u8]) -> RragResult<PersistedCacheData> {
100 serde_json::from_slice(data)
101 .map_err(|e| RragError::serialization_with_message("json", e.to_string()))
102 }
103
104 fn format_name(&self) -> &str {
105 "json"
106 }
107}
108
109pub struct MessagePackSerializer;
111
112impl CacheSerializer for MessagePackSerializer {
113 fn serialize_cache_data(&self, data: &PersistedCacheData) -> RragResult<Vec<u8>> {
114 rmp_serde::to_vec(data)
115 .map_err(|e| RragError::serialization_with_message("msgpack", e.to_string()))
116 }
117
118 fn deserialize_cache_data(&self, data: &[u8]) -> RragResult<PersistedCacheData> {
119 rmp_serde::from_slice(data)
120 .map_err(|e| RragError::serialization_with_message("msgpack", e.to_string()))
121 }
122
123 fn format_name(&self) -> &str {
124 "msgpack"
125 }
126}
127
/// Complete cache snapshot that is written to, and read back from, disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedCacheData {
    /// Format version of the snapshot; `PersistenceManager::load` rejects any
    /// value other than `CACHE_VERSION`.
    pub version: u32,

    /// Time associated with the snapshot; the `Default` impl sets it to the
    /// moment of construction.
    pub timestamp: std::time::SystemTime,

    /// Query cache entries, keyed by string.
    pub query_cache: HashMap<String, QueryCacheEntry>,

    /// Embedding cache entries, keyed by string.
    pub embedding_cache: HashMap<String, EmbeddingCacheEntry>,

    /// Semantic cache entries, keyed by string.
    pub semantic_cache: HashMap<String, SemanticCacheEntry>,

    /// Result cache entries, keyed by string.
    pub result_cache: HashMap<String, ResultCacheEntry>,

    /// Cache statistics keyed by string (presumably the cache name; confirm
    /// against the code that fills this map).
    pub stats: HashMap<String, CacheStats>,

    /// Descriptive information stored alongside the cache contents.
    pub metadata: PersistenceMetadata,
}
155
/// Descriptive information stored with every persisted snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistenceMetadata {
    /// Version of the application that produced the snapshot; defaults to
    /// this crate's `CARGO_PKG_VERSION`.
    pub app_version: String,

    /// Hash of the configuration the snapshot was produced with; empty by default.
    pub config_hash: String,

    /// Entry count reported in the save/load log messages. Nothing in this
    /// file computes it, so the producer of the snapshot must fill it in.
    pub total_entries: usize,

    /// Whether the snapshot was compressed; `false` by default.
    pub compression_enabled: bool,

    /// Free-form key/value pairs.
    pub custom: HashMap<String, String>,
}
174
175impl PersistenceManager {
176 pub fn new(config: PersistenceConfig) -> RragResult<Self> {
178 let storage_path = PathBuf::from(&config.storage_path);
179
180 if !storage_path.exists() {
182 fs::create_dir_all(&storage_path)
183 .map_err(|e| RragError::storage("create_cache_directory", e))?;
184 }
185
186 let serializer: Box<dyn CacheSerializer> = match config.format {
187 PersistenceFormat::Binary => Box::new(BinarySerializer),
188 PersistenceFormat::Json => Box::new(JsonSerializer),
189 PersistenceFormat::MessagePack => Box::new(MessagePackSerializer),
190 };
191
192 Ok(Self {
193 config,
194 storage_path,
195 serializer,
196 stats: PersistenceStats::default(),
197 })
198 }
199
200 pub fn save(&mut self, data: &PersistedCacheData) -> RragResult<()> {
202 let start = std::time::Instant::now();
203
204 let serialized = self.serializer.serialize_cache_data(data)?;
206
207 let temp_path = self.get_temp_path();
209 let mut file =
210 fs::File::create(&temp_path).map_err(|e| RragError::storage("create_temp_file", e))?;
211
212 file.write_all(&serialized)
213 .map_err(|e| RragError::storage("write_cache_data", e))?;
214
215 file.sync_all()
216 .map_err(|e| RragError::storage("sync_cache_file", e))?;
217
218 let final_path = self.get_cache_path();
220 fs::rename(&temp_path, &final_path)
221 .map_err(|e| RragError::storage("rename_cache_file", e))?;
222
223 self.stats.save_count += 1;
225 self.stats.bytes_written += serialized.len() as u64;
226 self.stats.last_save = Some(std::time::SystemTime::now());
227
228 let duration = start.elapsed();
229 tracing::info!(
230 "Cache saved: {} entries, {} bytes, {:?}",
231 data.metadata.total_entries,
232 serialized.len(),
233 duration
234 );
235
236 Ok(())
237 }
238
239 pub fn load(&mut self) -> RragResult<PersistedCacheData> {
241 let start = std::time::Instant::now();
242 let cache_path = self.get_cache_path();
243
244 if !cache_path.exists() {
245 return Err(RragError::memory("load_cache", "Cache file not found"));
246 }
247
248 let mut file =
250 fs::File::open(&cache_path).map_err(|e| RragError::storage("open_cache_file", e))?;
251
252 let mut buffer = Vec::new();
253 file.read_to_end(&mut buffer)
254 .map_err(|e| RragError::storage("read_cache_file", e))?;
255
256 let data = self.serializer.deserialize_cache_data(&buffer)?;
258
259 if data.version != CACHE_VERSION {
261 return Err(RragError::validation(
262 "cache_version",
263 format!("version {}", CACHE_VERSION),
264 format!("version {}", data.version),
265 ));
266 }
267
268 self.stats.load_count += 1;
270 self.stats.bytes_read += buffer.len() as u64;
271 self.stats.last_load = Some(std::time::SystemTime::now());
272
273 let duration = start.elapsed();
274 tracing::info!(
275 "Cache loaded: {} entries, {} bytes, {:?}",
276 data.metadata.total_entries,
277 buffer.len(),
278 duration
279 );
280
281 Ok(data)
282 }
283
284 pub async fn save_async(&mut self, data: PersistedCacheData) -> RragResult<()> {
286 let serializer = self.create_serializer();
287 let path = self.get_cache_path();
288 let temp_path = self.get_temp_path();
289
290 tokio::task::spawn_blocking(move || {
292 let serialized = serializer.serialize_cache_data(&data)?;
293
294 let mut file = fs::File::create(&temp_path)
295 .map_err(|e| RragError::storage("create_temp_file", e))?;
296
297 file.write_all(&serialized)
298 .map_err(|e| RragError::storage("write_cache_data", e))?;
299
300 file.sync_all()
301 .map_err(|e| RragError::storage("sync_cache_file", e))?;
302
303 fs::rename(&temp_path, &path)
304 .map_err(|e| RragError::storage("rename_cache_file", e))?;
305
306 Ok(())
307 })
308 .await
309 .map_err(|e| RragError::memory("async_save", e.to_string()))?
310 }
311
312 pub fn backup(&self) -> RragResult<()> {
314 let cache_path = self.get_cache_path();
315 if !cache_path.exists() {
316 return Ok(());
317 }
318
319 let backup_path = self.get_backup_path();
320 fs::copy(&cache_path, &backup_path).map_err(|e| RragError::storage("create_backup", e))?;
321
322 tracing::info!("Cache backed up to {:?}", backup_path);
323 Ok(())
324 }
325
326 pub fn restore(&self) -> RragResult<()> {
328 let backup_path = self.get_backup_path();
329 if !backup_path.exists() {
330 return Err(RragError::memory("restore_backup", "Backup file not found"));
331 }
332
333 let cache_path = self.get_cache_path();
334 fs::copy(&backup_path, &cache_path)
335 .map_err(|e| RragError::storage("restore_from_backup", e))?;
336
337 tracing::info!("Cache restored from backup");
338 Ok(())
339 }
340
341 pub fn cleanup(&self, keep_days: u32) -> RragResult<()> {
343 let cutoff =
344 std::time::SystemTime::now() - std::time::Duration::from_secs(keep_days as u64 * 86400);
345
346 let entries = fs::read_dir(&self.storage_path)
347 .map_err(|e| RragError::storage("read_cache_directory", e))?;
348
349 let mut removed = 0;
350 for entry in entries {
351 let entry = entry.map_err(|e| RragError::storage("read_directory_entry", e))?;
352 let metadata = entry
353 .metadata()
354 .map_err(|e| RragError::storage("read_file_metadata", e))?;
355
356 if let Ok(modified) = metadata.modified() {
357 if modified < cutoff {
358 fs::remove_file(entry.path())
359 .map_err(|e| RragError::storage("remove_old_cache", e))?;
360 removed += 1;
361 }
362 }
363 }
364
365 tracing::info!("Cleaned up {} old cache files", removed);
366 Ok(())
367 }
368
369 fn get_cache_path(&self) -> PathBuf {
371 self.storage_path.join("cache.dat")
372 }
373
374 fn get_temp_path(&self) -> PathBuf {
376 self.storage_path.join("cache.tmp")
377 }
378
379 fn get_backup_path(&self) -> PathBuf {
381 self.storage_path.join("cache.bak")
382 }
383
384 fn create_serializer(&self) -> Box<dyn CacheSerializer> {
386 match self.config.format {
387 PersistenceFormat::Binary => Box::new(BinarySerializer),
388 PersistenceFormat::Json => Box::new(JsonSerializer),
389 PersistenceFormat::MessagePack => Box::new(MessagePackSerializer),
390 }
391 }
392}
393
/// Snapshot format version: stored by `PersistedCacheData::default()` and
/// required by `PersistenceManager::load`, which rejects any other value.
const CACHE_VERSION: u32 = 1;
396
397impl Default for PersistenceStats {
398 fn default() -> Self {
399 Self {
400 save_count: 0,
401 load_count: 0,
402 save_failures: 0,
403 load_failures: 0,
404 bytes_written: 0,
405 bytes_read: 0,
406 last_save: None,
407 last_load: None,
408 }
409 }
410}
411
412impl Default for PersistedCacheData {
413 fn default() -> Self {
414 Self {
415 version: CACHE_VERSION,
416 timestamp: std::time::SystemTime::now(),
417 query_cache: HashMap::new(),
418 embedding_cache: HashMap::new(),
419 semantic_cache: HashMap::new(),
420 result_cache: HashMap::new(),
421 stats: HashMap::new(),
422 metadata: PersistenceMetadata::default(),
423 }
424 }
425}
426
427impl Default for PersistenceMetadata {
428 fn default() -> Self {
429 Self {
430 app_version: env!("CARGO_PKG_VERSION").to_string(),
431 config_hash: String::new(),
432 total_entries: 0,
433 compression_enabled: false,
434 custom: HashMap::new(),
435 }
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;
    // The parent module imports only `PathBuf`, so `use super::*` does not
    // bring `Path` into scope; import it here for `create_test_config`.
    use std::path::Path;
    use tempfile::TempDir;

    /// Builds a binary-format configuration that stores its files in `dir`.
    fn create_test_config(dir: &Path) -> PersistenceConfig {
        PersistenceConfig {
            enabled: true,
            storage_path: dir.to_str().unwrap().to_string(),
            auto_save_interval: std::time::Duration::from_secs(60),
            format: PersistenceFormat::Binary,
        }
    }

    #[test]
    fn test_binary_serializer() {
        let serializer = BinarySerializer;
        let data = PersistedCacheData::default();

        let serialized = serializer.serialize_cache_data(&data).unwrap();
        let deserialized = serializer.deserialize_cache_data(&serialized).unwrap();

        assert_eq!(data.version, deserialized.version);
    }

    #[test]
    fn test_json_serializer() {
        let serializer = JsonSerializer;
        let data = PersistedCacheData::default();

        let serialized = serializer.serialize_cache_data(&data).unwrap();
        let deserialized = serializer.deserialize_cache_data(&serialized).unwrap();

        assert_eq!(data.version, deserialized.version);
    }

    #[test]
    fn test_msgpack_serializer() {
        let serializer = MessagePackSerializer;
        let data = PersistedCacheData::default();

        let serialized = serializer.serialize_cache_data(&data).unwrap();
        let deserialized = serializer.deserialize_cache_data(&serialized).unwrap();

        assert_eq!(data.version, deserialized.version);
    }

    #[test]
    fn test_save_and_load() {
        let temp_dir = TempDir::new().unwrap();
        let config = create_test_config(temp_dir.path());
        let mut manager = PersistenceManager::new(config).unwrap();

        let data = PersistedCacheData::default();
        manager.save(&data).unwrap();

        let loaded = manager.load().unwrap();
        assert_eq!(loaded.version, data.version);
    }

    #[test]
    fn test_load_without_cache_file_fails() {
        let temp_dir = TempDir::new().unwrap();
        let config = create_test_config(temp_dir.path());
        let mut manager = PersistenceManager::new(config).unwrap();

        // Nothing has been saved yet, so there is no cache file to read.
        assert!(manager.load().is_err());
    }

    #[test]
    fn test_backup_and_restore() {
        let temp_dir = TempDir::new().unwrap();
        let config = create_test_config(temp_dir.path());
        let mut manager = PersistenceManager::new(config).unwrap();

        let data = PersistedCacheData::default();
        manager.save(&data).unwrap();

        manager.backup().unwrap();

        // Simulate losing the live cache file.
        fs::remove_file(manager.get_cache_path()).unwrap();

        manager.restore().unwrap();

        let loaded = manager.load().unwrap();
        assert_eq!(loaded.version, data.version);
    }
}