use super::fingerprint::{blake3_hash, DocumentFingerprint};
use super::types::{Bm25Config, RrfConfig};
use super::IndexedDocument;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
/// Version string written into the `version` field of every saved manifest.
///
/// When a cache is loaded, only the component before the first `.` is
/// compared against this value.
pub const INDEX_VERSION: &str = "1.1.0";
/// Relative path joined onto the base cache directory to form the default cache location.
const CACHE_SUBDIR: &str = "batuta/rag";
/// File name of the serialized `RagManifest`.
const MANIFEST_FILE: &str = "manifest.json";
/// File name of the serialized `PersistedIndex`.
const INDEX_FILE: &str = "index.json";
/// File name of the serialized `PersistedDocuments`.
const DOCUMENTS_FILE: &str = "documents.json";
/// File name of the standalone copy of the document fingerprint map.
const FINGERPRINTS_FILE: &str = "fingerprints.json";
/// Metadata describing one saved cache; serialized to the manifest file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagManifest {
/// Index format version; `save` sets it from `INDEX_VERSION`.
pub version: String,
/// Result of `blake3_hash` over the serialized index JSON bytes.
pub index_checksum: [u8; 32],
/// Result of `blake3_hash` over the serialized documents JSON bytes.
pub docs_checksum: [u8; 32],
/// Corpus sources supplied by the caller of `save`.
pub sources: Vec<CorpusSource>,
/// Time of the save, in milliseconds since the Unix epoch.
pub indexed_at: u64,
/// `CARGO_PKG_VERSION` of the crate that wrote the cache, captured at compile time.
pub batuta_version: String,
}
/// One entry of the `sources` list stored in a `RagManifest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorpusSource {
/// Identifier of the source.
pub id: String,
/// Optional commit string for the source.
/// NOTE(review): presumably a VCS commit hash — confirm against the code that builds this.
pub commit: Option<String>,
/// Number of documents attributed to this source.
pub doc_count: usize,
/// Number of chunks attributed to this source.
pub chunk_count: usize,
}
/// Serializable form of the search index; stored in the index file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PersistedIndex {
/// Two-level map from string keys to counts.
/// NOTE(review): presumably term -> document id -> term frequency — confirm against the indexer.
pub inverted_index: HashMap<String, HashMap<String, usize>>,
/// Per-key length values.
/// NOTE(review): presumably document id -> document length — confirm against the indexer.
pub doc_lengths: HashMap<String, usize>,
/// BM25 configuration persisted alongside the index data.
pub bm25_config: Bm25Config,
/// RRF configuration persisted alongside the index data.
pub rrf_config: RrfConfig,
/// Average document length.
/// NOTE(review): presumably derived from `doc_lengths` — confirm where it is computed.
pub avg_doc_length: f64,
}
/// Serializable form of the document store; stored in the documents file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PersistedDocuments {
/// Indexed documents keyed by a string identifier.
pub documents: HashMap<String, IndexedDocument>,
/// Document fingerprints keyed by a string identifier; `save` also writes
/// this map on its own to the fingerprints file.
pub fingerprints: HashMap<String, DocumentFingerprint>,
/// Total number of chunks.
pub total_chunks: usize,
/// Chunk text keyed by a string identifier. Because of `#[serde(default)]`
/// this becomes an empty map when the field is absent from the input.
#[serde(default)]
pub chunk_contents: HashMap<String, String>,
}
/// Errors reported by the cache persistence layer.
#[derive(Debug, thiserror::Error)]
pub enum PersistenceError {
/// A filesystem operation failed; converted automatically from `io::Error`.
#[error("I/O error: {0}")]
Io(#[from] io::Error),
/// JSON (de)serialization failed; converted automatically from `serde_json::Error`.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// The hash of a file's contents differs from the checksum it was compared against.
#[error("Checksum mismatch for {file}: expected {expected:x?}, got {actual:x?}")]
ChecksumMismatch { file: String, expected: [u8; 32], actual: [u8; 32] },
/// The manifest's leading version component differs from that of `INDEX_VERSION`.
#[error("Version mismatch: index version {index_version}, expected {expected_version}")]
VersionMismatch { index_version: String, expected_version: String },
/// NOTE(review): not constructed anywhere in this part of the file — confirm usage elsewhere.
#[error("Cache directory not found")]
CacheDirNotFound,
/// NOTE(review): not constructed anywhere in this part of the file — confirm usage elsewhere.
#[error("No cached index found")]
NoCachedIndex,
}
/// Reads and writes the cache files (manifest, index, documents, fingerprints)
/// inside a single directory.
#[derive(Debug)]
pub struct RagPersistence {
/// Directory that holds all cache files.
cache_path: PathBuf,
}
impl RagPersistence {
/// Creates a persistence handle rooted at the default cache location
/// (see `default_cache_path`).
pub fn new() -> Self {
    Self::with_path(Self::default_cache_path())
}
/// Creates a persistence handle that keeps its files under `path`.
///
/// Nothing is touched on disk here; the directory is created by the save
/// methods when they run.
pub fn with_path(path: PathBuf) -> Self {
    let cache_path = path;
    Self { cache_path }
}
/// Resolves the default cache directory.
///
/// With the `native` feature the base is `dirs::cache_dir()`, falling back to
/// the relative path `.cache` when that returns `None`; without the feature
/// the base is always `.cache`. `CACHE_SUBDIR` is appended in both cases.
fn default_cache_path() -> PathBuf {
    #[cfg(feature = "native")]
    let base = dirs::cache_dir().unwrap_or_else(|| PathBuf::from(".cache"));
    #[cfg(not(feature = "native"))]
    let base = PathBuf::from(".cache");

    base.join(CACHE_SUBDIR)
}
/// Returns the directory this handle reads from and writes to.
pub fn cache_path(&self) -> &Path {
    self.cache_path.as_path()
}
/// Serializes the index, the documents, the fingerprint map and a fresh
/// manifest, and writes all four files into the cache directory.
///
/// Each file is first written in full to a `<name>.tmp` sibling; only once
/// all four are staged are they renamed onto their final names, with the
/// manifest renamed last.
///
/// # Errors
///
/// Returns `PersistenceError::Io` if the directory cannot be created or a
/// file cannot be written or renamed, and `PersistenceError::Json` if
/// serialization fails.
pub fn save(
    &self,
    index: &PersistedIndex,
    docs: &PersistedDocuments,
    sources: Vec<CorpusSource>,
) -> Result<(), PersistenceError> {
    fs::create_dir_all(&self.cache_path)?;
    self.cleanup_tmp_files();

    // Serialize everything before staging so a serialization error leaves
    // the existing cache files as they were.
    let index_json = serde_json::to_string_pretty(index)?;
    let docs_json = serde_json::to_string_pretty(docs)?;
    let fingerprints_json = serde_json::to_string_pretty(&docs.fingerprints)?;

    // The manifest carries the checksums of the exact bytes written below.
    let manifest_json = serde_json::to_string_pretty(&RagManifest {
        version: INDEX_VERSION.to_string(),
        index_checksum: blake3_hash(index_json.as_bytes()),
        docs_checksum: blake3_hash(docs_json.as_bytes()),
        sources,
        indexed_at: current_timestamp_ms(),
        batuta_version: env!("CARGO_PKG_VERSION").to_string(),
    })?;

    // Order matters: data files first, manifest last.
    let staged: [(&str, &str); 4] = [
        (INDEX_FILE, index_json.as_str()),
        (DOCUMENTS_FILE, docs_json.as_str()),
        (FINGERPRINTS_FILE, fingerprints_json.as_str()),
        (MANIFEST_FILE, manifest_json.as_str()),
    ];
    for &(name, payload) in &staged {
        self.prepare_write(name, payload.as_bytes())?;
    }
    for &(name, _) in &staged {
        self.commit_rename(name)?;
    }
    Ok(())
}
/// Loads the cached index, documents and manifest.
///
/// Returns `Ok(None)` when there is no manifest file, or when any file is
/// unreadable, fails to parse, or does not match the checksum recorded in
/// the manifest; in every case except the missing manifest a warning is
/// printed to stderr.
///
/// # Errors
///
/// Returns `PersistenceError::VersionMismatch` when the manifest parses but
/// its leading version component differs from that of `INDEX_VERSION`.
pub fn load(
    &self,
) -> Result<Option<(PersistedIndex, PersistedDocuments, RagManifest)>, PersistenceError> {
    let manifest_path = self.cache_path.join(MANIFEST_FILE);
    if !manifest_path.exists() {
        return Ok(None);
    }

    // --- manifest ---
    let manifest_json = match fs::read_to_string(&manifest_path)
        .map_err(|e| eprintln!("Warning: failed to read RAG manifest, will rebuild: {e}"))
    {
        Ok(text) => text,
        Err(()) => return Ok(None),
    };
    let manifest = match serde_json::from_str::<RagManifest>(&manifest_json)
        .map_err(|e| eprintln!("Warning: corrupt RAG manifest JSON, will rebuild: {e}"))
    {
        Ok(parsed) => parsed,
        Err(()) => return Ok(None),
    };
    // A version mismatch is the one failure that is surfaced as an error.
    self.validate_version(&manifest)?;

    // --- index ---
    let index_json = match fs::read_to_string(self.cache_path.join(INDEX_FILE))
        .map_err(|e| eprintln!("Warning: failed to read RAG index file, will rebuild: {e}"))
    {
        Ok(text) => text,
        Err(()) => return Ok(None),
    };
    if self
        .validate_checksum(&index_json, manifest.index_checksum, "index.json")
        .map_err(|e| eprintln!("Warning: {e}, will rebuild"))
        .is_err()
    {
        return Ok(None);
    }
    let index = match serde_json::from_str::<PersistedIndex>(&index_json)
        .map_err(|e| eprintln!("Warning: corrupt RAG index JSON, will rebuild: {e}"))
    {
        Ok(parsed) => parsed,
        Err(()) => return Ok(None),
    };

    // --- documents ---
    let docs_json = match fs::read_to_string(self.cache_path.join(DOCUMENTS_FILE))
        .map_err(|e| eprintln!("Warning: failed to read RAG documents file, will rebuild: {e}"))
    {
        Ok(text) => text,
        Err(()) => return Ok(None),
    };
    if self
        .validate_checksum(&docs_json, manifest.docs_checksum, "documents.json")
        .map_err(|e| eprintln!("Warning: {e}, will rebuild"))
        .is_err()
    {
        return Ok(None);
    }
    let docs = match serde_json::from_str::<PersistedDocuments>(&docs_json)
        .map_err(|e| eprintln!("Warning: corrupt RAG documents JSON, will rebuild: {e}"))
    {
        Ok(parsed) => parsed,
        Err(()) => return Ok(None),
    };

    Ok(Some((index, docs, manifest)))
}
/// Loads just the document fingerprints.
///
/// Tries the dedicated fingerprints file first. No checksum is verified for
/// that file. If it is missing, unreadable or does not parse, falls back to
/// `load_fingerprints_fallback`, which performs a full `load`.
///
/// # Errors
///
/// Only the fallback path can fail; it returns whatever `load` returns.
pub fn load_fingerprints_only(
    &self,
) -> Result<Option<HashMap<String, DocumentFingerprint>>, PersistenceError> {
    let fp_path = self.cache_path.join(FINGERPRINTS_FILE);

    // Any failure along the fast path collapses to `None`.
    let fast_path = fp_path
        .exists()
        .then(|| fs::read_to_string(&fp_path).ok())
        .flatten()
        .and_then(|raw| {
            serde_json::from_str::<HashMap<String, DocumentFingerprint>>(&raw).ok()
        });

    match fast_path {
        Some(fingerprints) => Ok(Some(fingerprints)),
        None => self.load_fingerprints_fallback(),
    }
}
/// Obtains the fingerprints by running a full `load` and keeping only the
/// `fingerprints` map of the loaded documents.
///
/// Returns `Ok(None)` whenever `load` does, and forwards its errors.
fn load_fingerprints_fallback(
    &self,
) -> Result<Option<HashMap<String, DocumentFingerprint>>, PersistenceError> {
    let loaded = self.load()?;
    Ok(loaded.map(|(_index, docs, _manifest)| docs.fingerprints))
}
pub fn save_fingerprints_only(
&self,
fingerprints: &HashMap<String, DocumentFingerprint>,
) -> Result<(), PersistenceError> {
fs::create_dir_all(&self.cache_path)?;
let fingerprints_json = serde_json::to_string_pretty(fingerprints)?;
self.prepare_write(FINGERPRINTS_FILE, fingerprints_json.as_bytes())?;
self.commit_rename(FINGERPRINTS_FILE)?;
Ok(())
}
/// Deletes the cache files and then the cache directory itself.
///
/// For each of the four cache files both the final file and its `.tmp`
/// staging sibling are removed, so that staging files orphaned by an
/// interrupted save do not keep the directory from being removed.
///
/// Deliberately best effort: individual removal failures (including "not
/// found") are ignored, and the directory is only removed if it ends up
/// empty. The function currently always returns `Ok(())`.
pub fn clear(&self) -> Result<(), PersistenceError> {
    if self.cache_path.exists() {
        // The manifest goes first so a partially cleared cache is never
        // mistaken for a complete one.
        for &filename in &[MANIFEST_FILE, INDEX_FILE, DOCUMENTS_FILE, FINGERPRINTS_FILE] {
            let _ = fs::remove_file(self.cache_path.join(filename));
            let _ = fs::remove_file(self.cache_path.join(format!("{}.tmp", filename)));
        }
        // `remove_dir` refuses non-empty directories, so unrelated files
        // placed here by someone else are left alone.
        let _ = fs::remove_dir(&self.cache_path);
    }
    Ok(())
}
pub fn stats(&self) -> Result<Option<RagManifest>, PersistenceError> {
let manifest_path = self.cache_path.join(MANIFEST_FILE);
if !manifest_path.exists() {
return Ok(None);
}
let manifest_json = fs::read_to_string(&manifest_path)?;
let manifest: RagManifest = serde_json::from_str(&manifest_json)?;
Ok(Some(manifest))
}
/// Stages `data` in `<filename>.tmp` inside the cache directory.
///
/// The temp file is created (truncating any existing one), written in full
/// and synced with `sync_all` before this returns.
///
/// # Errors
///
/// Returns the underlying `io::Error` from create, write or sync.
fn prepare_write(&self, filename: &str, data: &[u8]) -> Result<(), io::Error> {
    let staging = self.cache_path.join(format!("{}.tmp", filename));
    fs::File::create(&staging).and_then(|mut handle| {
        handle.write_all(data)?;
        handle.sync_all()
    })
}
/// Moves the staged `<filename>.tmp` onto `<filename>` within the cache
/// directory using `fs::rename`.
///
/// # Errors
///
/// Returns the `io::Error` from the rename, e.g. when the staged file does
/// not exist.
fn commit_rename(&self, filename: &str) -> Result<(), io::Error> {
    let staged = self.cache_path.join(format!("{}.tmp", filename));
    let target = self.cache_path.join(filename);
    fs::rename(staged, target)
}
fn cleanup_tmp_files(&self) {
for filename in &[MANIFEST_FILE, INDEX_FILE, DOCUMENTS_FILE] {
let tmp_path = self.cache_path.join(format!("{}.tmp", filename));
let _ = fs::remove_file(tmp_path);
}
}
/// Checks that the manifest's version is compatible with `INDEX_VERSION`.
///
/// Only the text before the first `.` of each version string is compared;
/// the remaining components are ignored.
///
/// # Errors
///
/// Returns `PersistenceError::VersionMismatch` carrying both full version
/// strings when the leading components differ.
fn validate_version(&self, manifest: &RagManifest) -> Result<(), PersistenceError> {
    let found_major = manifest.version.split('.').next();
    let expected_major = INDEX_VERSION.split('.').next();
    if found_major == expected_major {
        return Ok(());
    }
    Err(PersistenceError::VersionMismatch {
        index_version: manifest.version.clone(),
        expected_version: INDEX_VERSION.to_string(),
    })
}
/// Hashes `data` with `blake3_hash` and compares the result to `expected`.
///
/// `filename` is used only to label the error.
///
/// # Errors
///
/// Returns `PersistenceError::ChecksumMismatch` with the expected and the
/// computed hash when they differ.
fn validate_checksum(
    &self,
    data: &str,
    expected: [u8; 32],
    filename: &str,
) -> Result<(), PersistenceError> {
    let actual = blake3_hash(data.as_bytes());
    if actual == expected {
        Ok(())
    } else {
        Err(PersistenceError::ChecksumMismatch {
            file: filename.to_owned(),
            expected,
            actual,
        })
    }
}
}
impl Default for RagPersistence {
fn default() -> Self {
Self::new()
}
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch. The
/// millisecond count is narrowed from `u128` to `u64` with an `as` cast.
fn current_timestamp_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
#[path = "persistence_tests.rs"]
mod tests;