1use super::fingerprint::{blake3_hash, DocumentFingerprint};
14use super::types::{Bm25Config, RrfConfig};
15use super::IndexedDocument;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fs;
19use std::io::{self, Write};
20use std::path::{Path, PathBuf};
21
/// Version tag written into every manifest by `save`.
///
/// `load` rejects a cache whose first dot-separated component differs from
/// the first component of this string; later components are not compared.
pub const INDEX_VERSION: &str = "1.1.0";

/// Sub-directory joined onto the base cache directory to form the default cache path.
const CACHE_SUBDIR: &str = "batuta/rag";

/// File name of the serialized `RagManifest` inside the cache directory.
const MANIFEST_FILE: &str = "manifest.json";

/// File name of the serialized `PersistedIndex` inside the cache directory.
const INDEX_FILE: &str = "index.json";

/// File name of the serialized `PersistedDocuments` inside the cache directory.
const DOCUMENTS_FILE: &str = "documents.json";

/// File name of the standalone fingerprint map, read by `load_fingerprints_only`.
const FINGERPRINTS_FILE: &str = "fingerprints.json";
41
/// Metadata describing one saved cache; serialized to `manifest.json`.
///
/// `save` builds a fresh manifest on every call and `load` uses it to validate
/// the version and the checksums of the other cache files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagManifest {
    /// Value of `INDEX_VERSION` at the time the cache was saved.
    pub version: String,
    /// `blake3_hash` of the exact JSON bytes written to `index.json`.
    pub index_checksum: [u8; 32],
    /// `blake3_hash` of the exact JSON bytes written to `documents.json`.
    pub docs_checksum: [u8; 32],
    /// Corpus sources as handed to `save` by the caller.
    pub sources: Vec<CorpusSource>,
    /// Save time in milliseconds since the Unix epoch (0 if the clock reads earlier than the epoch).
    pub indexed_at: u64,
    /// `CARGO_PKG_VERSION` of the crate build that wrote the cache.
    pub batuta_version: String,
}
60
/// Description of one corpus source recorded in the manifest.
///
/// This module only stores and reloads these values; it never interprets them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorpusSource {
    /// Identifier of the source (format is decided by the caller).
    pub id: String,
    /// Optional revision marker; presumably a VCS commit hash, to be confirmed against the caller.
    pub commit: Option<String>,
    /// Number of documents attributed to this source.
    pub doc_count: usize,
    /// Number of chunks attributed to this source.
    pub chunk_count: usize,
}
73
/// Serializable snapshot of the search index; written to `index.json`.
///
/// The field order is part of the on-disk format: the manifest checksum is
/// computed over the serialized bytes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PersistedIndex {
    /// Nested string-keyed map of counts.
    /// NOTE(review): presumably term -> document id -> term frequency; confirm against the indexer.
    pub inverted_index: HashMap<String, HashMap<String, usize>>,
    /// Length per string key. NOTE(review): presumably keyed by document id; confirm.
    pub doc_lengths: HashMap<String, usize>,
    /// BM25 configuration stored alongside the index.
    pub bm25_config: Bm25Config,
    /// RRF configuration stored alongside the index.
    pub rrf_config: RrfConfig,
    /// Average document length stored alongside the index.
    pub avg_doc_length: f64,
}
88
/// Serializable snapshot of the indexed documents; written to `documents.json`.
///
/// The field order is part of the on-disk format: the manifest checksum is
/// computed over the serialized bytes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PersistedDocuments {
    /// Indexed documents by string key.
    pub documents: HashMap<String, IndexedDocument>,
    /// Document fingerprints by string key; `save` also writes this map on its
    /// own to `fingerprints.json`.
    pub fingerprints: HashMap<String, DocumentFingerprint>,
    /// Total chunk count stored with the documents.
    pub total_chunks: usize,
    /// Chunk text by string key. `#[serde(default)]` lets files that lack this
    /// field deserialize with an empty map.
    #[serde(default)]
    pub chunk_contents: HashMap<String, String>,
}
102
/// Errors reported by the persistence layer.
#[derive(Debug, thiserror::Error)]
pub enum PersistenceError {
    /// A filesystem operation failed; converted automatically from `io::Error`.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),

    /// JSON (de)serialization failed; converted automatically from `serde_json::Error`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// The hash of a cache file's contents differs from the one recorded in the manifest.
    #[error("Checksum mismatch for {file}: expected {expected:x?}, got {actual:x?}")]
    ChecksumMismatch { file: String, expected: [u8; 32], actual: [u8; 32] },

    /// The manifest's major version differs from the major version of `INDEX_VERSION`.
    #[error("Version mismatch: index version {index_version}, expected {expected_version}")]
    VersionMismatch { index_version: String, expected_version: String },

    /// The cache directory is missing (not constructed anywhere in this section of the file).
    #[error("Cache directory not found")]
    CacheDirNotFound,

    /// No cached index is available (not constructed anywhere in this section of the file).
    #[error("No cached index found")]
    NoCachedIndex,
}
130
/// Handle for reading and writing the on-disk RAG cache.
///
/// All cache files (`manifest.json`, `index.json`, `documents.json`,
/// `fingerprints.json`) live directly under `cache_path`.
#[derive(Debug)]
pub struct RagPersistence {
    /// Directory that holds the cache files.
    cache_path: PathBuf,
}
139
140impl RagPersistence {
141 pub fn new() -> Self {
145 Self { cache_path: Self::default_cache_path() }
146 }
147
148 pub fn with_path(path: PathBuf) -> Self {
150 Self { cache_path: path }
151 }
152
153 fn default_cache_path() -> PathBuf {
157 #[cfg(feature = "native")]
158 {
159 dirs::cache_dir().unwrap_or_else(|| PathBuf::from(".cache")).join(CACHE_SUBDIR)
160 }
161 #[cfg(not(feature = "native"))]
162 {
163 PathBuf::from(".cache").join(CACHE_SUBDIR)
164 }
165 }
166
167 pub fn cache_path(&self) -> &Path {
169 &self.cache_path
170 }
171
172 pub fn save(
184 &self,
185 index: &PersistedIndex,
186 docs: &PersistedDocuments,
187 sources: Vec<CorpusSource>,
188 ) -> Result<(), PersistenceError> {
189 fs::create_dir_all(&self.cache_path)?;
191
192 self.cleanup_tmp_files();
194
195 let index_json = serde_json::to_string_pretty(index)?;
197 let docs_json = serde_json::to_string_pretty(docs)?;
198
199 let fingerprints_json = serde_json::to_string_pretty(&docs.fingerprints)?;
201
202 let index_checksum = blake3_hash(index_json.as_bytes());
204 let docs_checksum = blake3_hash(docs_json.as_bytes());
205
206 let manifest = RagManifest {
208 version: INDEX_VERSION.to_string(),
209 index_checksum,
210 docs_checksum,
211 sources,
212 indexed_at: current_timestamp_ms(),
213 batuta_version: env!("CARGO_PKG_VERSION").to_string(),
214 };
215 let manifest_json = serde_json::to_string_pretty(&manifest)?;
216
217 self.prepare_write(INDEX_FILE, index_json.as_bytes())?;
219 self.prepare_write(DOCUMENTS_FILE, docs_json.as_bytes())?;
220 self.prepare_write(FINGERPRINTS_FILE, fingerprints_json.as_bytes())?;
221 self.prepare_write(MANIFEST_FILE, manifest_json.as_bytes())?;
222
223 self.commit_rename(INDEX_FILE)?;
227 self.commit_rename(DOCUMENTS_FILE)?;
228 self.commit_rename(FINGERPRINTS_FILE)?;
229 self.commit_rename(MANIFEST_FILE)?;
230
231 Ok(())
232 }
233
234 pub fn load(
243 &self,
244 ) -> Result<Option<(PersistedIndex, PersistedDocuments, RagManifest)>, PersistenceError> {
245 let manifest_path = self.cache_path.join(MANIFEST_FILE);
246
247 if !manifest_path.exists() {
249 return Ok(None);
250 }
251
252 let manifest_json = match fs::read_to_string(&manifest_path) {
254 Ok(s) => s,
255 Err(e) => {
256 eprintln!("Warning: failed to read RAG manifest, will rebuild: {e}");
257 return Ok(None);
258 }
259 };
260 let manifest: RagManifest = match serde_json::from_str(&manifest_json) {
261 Ok(m) => m,
262 Err(e) => {
263 eprintln!("Warning: corrupt RAG manifest JSON, will rebuild: {e}");
264 return Ok(None);
265 }
266 };
267
268 self.validate_version(&manifest)?;
270
271 let index_json = match fs::read_to_string(self.cache_path.join(INDEX_FILE)) {
273 Ok(s) => s,
274 Err(e) => {
275 eprintln!("Warning: failed to read RAG index file, will rebuild: {e}");
276 return Ok(None);
277 }
278 };
279 if let Err(e) = self.validate_checksum(&index_json, manifest.index_checksum, "index.json") {
280 eprintln!("Warning: {e}, will rebuild");
281 return Ok(None);
282 }
283 let index: PersistedIndex = match serde_json::from_str(&index_json) {
284 Ok(i) => i,
285 Err(e) => {
286 eprintln!("Warning: corrupt RAG index JSON, will rebuild: {e}");
287 return Ok(None);
288 }
289 };
290
291 let docs_json = match fs::read_to_string(self.cache_path.join(DOCUMENTS_FILE)) {
293 Ok(s) => s,
294 Err(e) => {
295 eprintln!("Warning: failed to read RAG documents file, will rebuild: {e}");
296 return Ok(None);
297 }
298 };
299 if let Err(e) = self.validate_checksum(&docs_json, manifest.docs_checksum, "documents.json")
300 {
301 eprintln!("Warning: {e}, will rebuild");
302 return Ok(None);
303 }
304 let docs: PersistedDocuments = match serde_json::from_str(&docs_json) {
305 Ok(d) => d,
306 Err(e) => {
307 eprintln!("Warning: corrupt RAG documents JSON, will rebuild: {e}");
308 return Ok(None);
309 }
310 };
311
312 Ok(Some((index, docs, manifest)))
313 }
314
315 pub fn load_fingerprints_only(
320 &self,
321 ) -> Result<Option<HashMap<String, DocumentFingerprint>>, PersistenceError> {
322 let fp_path = self.cache_path.join(FINGERPRINTS_FILE);
323
324 if fp_path.exists() {
325 let fp_json = match fs::read_to_string(&fp_path) {
326 Ok(s) => s,
327 Err(_) => return self.load_fingerprints_fallback(),
328 };
329 match serde_json::from_str(&fp_json) {
330 Ok(fps) => return Ok(Some(fps)),
331 Err(_) => return self.load_fingerprints_fallback(),
332 }
333 }
334
335 self.load_fingerprints_fallback()
337 }
338
339 fn load_fingerprints_fallback(
341 &self,
342 ) -> Result<Option<HashMap<String, DocumentFingerprint>>, PersistenceError> {
343 self.load().map(|opt| opt.map(|(_, docs, _)| docs.fingerprints))
344 }
345
346 pub fn save_fingerprints_only(
351 &self,
352 fingerprints: &HashMap<String, DocumentFingerprint>,
353 ) -> Result<(), PersistenceError> {
354 fs::create_dir_all(&self.cache_path)?;
355 let fingerprints_json = serde_json::to_string_pretty(fingerprints)?;
356 self.prepare_write(FINGERPRINTS_FILE, fingerprints_json.as_bytes())?;
357 self.commit_rename(FINGERPRINTS_FILE)?;
358 Ok(())
359 }
360
361 pub fn clear(&self) -> Result<(), PersistenceError> {
363 if self.cache_path.exists() {
364 let _ = fs::remove_file(self.cache_path.join(MANIFEST_FILE));
366 let _ = fs::remove_file(self.cache_path.join(INDEX_FILE));
367 let _ = fs::remove_file(self.cache_path.join(DOCUMENTS_FILE));
368 let _ = fs::remove_file(self.cache_path.join(FINGERPRINTS_FILE));
369
370 let _ = fs::remove_dir(&self.cache_path);
372 }
373 Ok(())
374 }
375
376 pub fn stats(&self) -> Result<Option<RagManifest>, PersistenceError> {
378 let manifest_path = self.cache_path.join(MANIFEST_FILE);
379
380 if !manifest_path.exists() {
381 return Ok(None);
382 }
383
384 let manifest_json = fs::read_to_string(&manifest_path)?;
385 let manifest: RagManifest = serde_json::from_str(&manifest_json)?;
386
387 Ok(Some(manifest))
388 }
389
390 fn prepare_write(&self, filename: &str, data: &[u8]) -> Result<(), io::Error> {
392 let tmp_path = self.cache_path.join(format!("{}.tmp", filename));
393
394 let mut file = fs::File::create(&tmp_path)?;
395 file.write_all(data)?;
396 file.sync_all()?;
397
398 Ok(())
399 }
400
401 fn commit_rename(&self, filename: &str) -> Result<(), io::Error> {
403 let tmp_path = self.cache_path.join(format!("{}.tmp", filename));
404 let final_path = self.cache_path.join(filename);
405
406 fs::rename(&tmp_path, &final_path)?;
407
408 Ok(())
409 }
410
411 fn cleanup_tmp_files(&self) {
413 for filename in &[MANIFEST_FILE, INDEX_FILE, DOCUMENTS_FILE] {
414 let tmp_path = self.cache_path.join(format!("{}.tmp", filename));
415 let _ = fs::remove_file(tmp_path);
416 }
417 }
418
419 fn validate_version(&self, manifest: &RagManifest) -> Result<(), PersistenceError> {
421 let index_parts: Vec<&str> = manifest.version.split('.').collect();
423 let expected_parts: Vec<&str> = INDEX_VERSION.split('.').collect();
424
425 if index_parts.first() != expected_parts.first() {
427 return Err(PersistenceError::VersionMismatch {
428 index_version: manifest.version.clone(),
429 expected_version: INDEX_VERSION.to_string(),
430 });
431 }
432
433 Ok(())
434 }
435
436 fn validate_checksum(
438 &self,
439 data: &str,
440 expected: [u8; 32],
441 filename: &str,
442 ) -> Result<(), PersistenceError> {
443 let actual = blake3_hash(data.as_bytes());
444
445 if actual != expected {
446 return Err(PersistenceError::ChecksumMismatch {
447 file: filename.to_string(),
448 expected,
449 actual,
450 });
451 }
452
453 Ok(())
454 }
455}
456
457impl Default for RagPersistence {
458 fn default() -> Self {
459 Self::new()
460 }
461}
462
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields 0 when the system clock reports a time earlier than the epoch.
fn current_timestamp_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
470
// Unit tests are compiled only for test builds and live in the sibling file
// named by the `#[path]` attribute.
#[cfg(test)]
#[path = "persistence_tests.rs"]
mod tests;