use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use serde::{Deserialize, Serialize};
use crate::distance::Distance;
use crate::embedding::Embedder;
use crate::error::RagError;
use crate::retriever::Retriever;
use crate::vector_store::{VectorEntry, VectorStore};
/// Schema version stamped into every snapshot produced by this module.
///
/// The `check_version` methods below reject any snapshot whose
/// `schema_version` field differs from this value.
pub const SCHEMA_VERSION: u32 = 1;
/// Serializable image of a [`VectorStore`]: its dimension, distance setting
/// and entries, plus a schema version used to gate loading.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexSnapshot {
/// Version tag; compared against [`SCHEMA_VERSION`] by `check_version`.
pub schema_version: u32,
/// Vector dimension; `VectorStore::from_snapshot` requires every entry's
/// vector to have exactly this length.
pub dim: usize,
/// Distance setting of the store. Falls back to `Distance::default()` when
/// the field is absent from the serialized input.
#[serde(default)]
pub distance: Distance,
/// The stored entries.
pub entries: Vec<VectorEntry>,
/// Optional opaque JSON payload. Defaults to `None` when absent and is
/// omitted from the output when `None`. `VectorStore::to_snapshot` in this
/// file always writes `None`.
// NOTE(review): presumably reserved for a TF-IDF embedder's state — confirm with the producer.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tfidf_state: Option<serde_json::Value>,
}
impl IndexSnapshot {
    /// Verifies that this snapshot carries the schema version this build understands.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::Persistence`] when `schema_version` is not equal to
    /// [`SCHEMA_VERSION`].
    pub fn check_version(&self) -> Result<(), RagError> {
        let found = self.schema_version;
        if found == SCHEMA_VERSION {
            return Ok(());
        }
        Err(RagError::Persistence(format!(
            "unsupported schema_version {} (expected {})",
            found, SCHEMA_VERSION
        )))
    }
}
/// Serializable image of a [`Retriever`]: a document count wrapped around the
/// snapshot of its vector store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrieverSnapshot {
/// Version tag of this outer envelope; compared against [`SCHEMA_VERSION`].
pub schema_version: u32,
/// Value of `Retriever::document_count()` at the time of saving.
pub doc_count: usize,
/// Snapshot of the retriever's vector store (carries its own version tag).
pub store: IndexSnapshot,
}
impl RetrieverSnapshot {
    /// Validates the version tag of this envelope and then that of the nested
    /// store snapshot.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::Persistence`] if either version tag differs from
    /// [`SCHEMA_VERSION`]; the outer tag is examined first.
    fn check_version(&self) -> Result<(), RagError> {
        if self.schema_version == SCHEMA_VERSION {
            // The envelope is acceptable; the embedded store has its own tag.
            self.store.check_version()
        } else {
            Err(RagError::Persistence(format!(
                "unsupported schema_version {} (expected {})",
                self.schema_version, SCHEMA_VERSION
            )))
        }
    }
}
impl VectorStore {
    /// Captures the store's current contents as a serializable [`IndexSnapshot`].
    ///
    /// The snapshot is stamped with [`SCHEMA_VERSION`], holds a copy of the
    /// entries, and leaves `tfidf_state` unset.
    pub fn to_snapshot(&self) -> IndexSnapshot {
        IndexSnapshot {
            schema_version: SCHEMA_VERSION,
            dim: self.dim(),
            distance: self.distance(),
            entries: self.entries().to_vec(),
            tfidf_state: None,
        }
    }

    /// Rebuilds a store from a snapshot after validating it.
    ///
    /// # Errors
    ///
    /// * [`RagError::Persistence`] if the snapshot's schema version is not
    ///   [`SCHEMA_VERSION`].
    /// * [`RagError::DimensionMismatch`] if any entry's vector length differs
    ///   from `snapshot.dim`; the first offending entry is reported.
    pub fn from_snapshot(snapshot: IndexSnapshot) -> Result<Self, RagError> {
        snapshot.check_version()?;

        let expected = snapshot.dim;
        if let Some(bad) = snapshot
            .entries
            .iter()
            .find(|entry| entry.vector.len() != expected)
        {
            return Err(RagError::DimensionMismatch {
                expected,
                got: bad.vector.len(),
            });
        }

        let mut store = VectorStore::new_with_distance(snapshot.dim, snapshot.distance);
        store.set_entries(snapshot.entries);
        Ok(store)
    }

    /// Writes the store to `path` as pretty-printed JSON, creating or
    /// truncating the file.
    ///
    /// # Errors
    ///
    /// Propagates I/O failures from creating or flushing the file, and returns
    /// [`RagError::Persistence`] if serialization fails.
    pub fn save_json(&self, path: impl AsRef<Path>) -> Result<(), RagError> {
        let file = File::create(path.as_ref())?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, &self.to_snapshot())
            .map_err(|e| RagError::Persistence(format!("serialize failed: {e}")))?;
        // Flush explicitly: `BufWriter`'s `Drop` impl discards flush errors, which
        // would let a truncated file be reported as a successful save.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Reads a store from a JSON file previously written by [`VectorStore::save_json`].
    ///
    /// # Errors
    ///
    /// Propagates the I/O error if the file cannot be opened, returns
    /// [`RagError::Persistence`] if the JSON cannot be parsed, and otherwise
    /// fails as [`VectorStore::from_snapshot`] does.
    pub fn load_json(path: impl AsRef<Path>) -> Result<Self, RagError> {
        let file = File::open(path.as_ref())?;
        let reader = BufReader::new(file);
        let snapshot: IndexSnapshot = serde_json::from_reader(reader)
            .map_err(|e| RagError::Persistence(format!("parse failed: {e}")))?;
        Self::from_snapshot(snapshot)
    }
}
impl<E: Embedder> Retriever<E> {
    /// Writes the retriever's document count and vector store to `path` as
    /// pretty-printed JSON, creating or truncating the file.
    ///
    /// The embedder and the retriever configuration are not part of the
    /// snapshot.
    ///
    /// # Errors
    ///
    /// Propagates I/O failures from creating or flushing the file, and returns
    /// [`RagError::Persistence`] if serialization fails.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), RagError> {
        let snapshot = RetrieverSnapshot {
            schema_version: SCHEMA_VERSION,
            doc_count: self.document_count(),
            store: self.store().to_snapshot(),
        };
        let file = File::create(path.as_ref())?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, &snapshot)
            .map_err(|e| RagError::Persistence(format!("serialize failed: {e}")))?;
        // Flush explicitly: `BufWriter`'s `Drop` impl discards flush errors, which
        // would let a truncated file be reported as a successful save.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Restores a retriever from a file written by [`Retriever::save`], pairing
    /// the stored index with the supplied `embedder`.
    ///
    /// The restored retriever uses `RetrieverConfig::default()`, since the
    /// snapshot carries no configuration.
    ///
    /// # Errors
    ///
    /// * Propagates the I/O error if the file cannot be opened.
    /// * [`RagError::Persistence`] if the JSON cannot be parsed or a schema
    ///   version is unsupported.
    /// * [`RagError::DimensionMismatch`] if `embedder.embedding_dim()` differs
    ///   from the stored dimension (`expected` is the stored dimension, `got`
    ///   is the embedder's), or if an entry's vector length is inconsistent.
    pub fn load(embedder: E, path: impl AsRef<Path>) -> Result<Self, RagError> {
        let file = File::open(path.as_ref())?;
        let reader = BufReader::new(file);
        let snapshot: RetrieverSnapshot = serde_json::from_reader(reader)
            .map_err(|e| RagError::Persistence(format!("parse failed: {e}")))?;
        snapshot.check_version()?;

        // Reject an embedder that cannot produce vectors comparable to the stored ones.
        let stored_dim = snapshot.store.dim;
        let embedder_dim = embedder.embedding_dim();
        if embedder_dim != stored_dim {
            return Err(RagError::DimensionMismatch {
                expected: stored_dim,
                got: embedder_dim,
            });
        }

        let store = VectorStore::from_snapshot(snapshot.store)?;
        Ok(Self::from_parts(
            embedder,
            store,
            snapshot.doc_count,
            crate::retriever::RetrieverConfig::default(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::Chunk;

    /// Builds a temp-dir path that embeds `tag`, the process id and the
    /// current time in nanoseconds so concurrent test runs do not collide.
    fn tmp_path(tag: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let pid = std::process::id();
        std::env::temp_dir().join(format!("oxibonsai_rag_persist_{tag}_{pid}_{nanos}.json"))
    }

    /// Saving a store and loading it back keeps the entry count and dimension.
    #[test]
    fn roundtrip_preserves_entries() {
        let mut store = VectorStore::new(3);
        let chunk = Chunk::new("hello".into(), 0, 0, 0);
        store.insert(vec![1.0, 0.0, 0.0], chunk).expect("insert");

        let path = tmp_path("roundtrip");
        let save_result = store.save_json(&path);
        let load_result = VectorStore::load_json(&path);
        // Remove the file before asserting anything so a failing run does not
        // leave it behind in the temp directory.
        std::fs::remove_file(&path).ok();

        save_result.expect("save");
        let loaded = load_result.expect("load");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.dim(), 3);
    }

    /// A snapshot with an unrecognized schema version is refused with a
    /// persistence error.
    #[test]
    fn unknown_version_rejected() {
        let snapshot = IndexSnapshot {
            schema_version: 9999,
            dim: 1,
            distance: Distance::Cosine,
            entries: Vec::new(),
            tfidf_state: None,
        };
        let result = VectorStore::from_snapshot(snapshot);
        assert!(matches!(result, Err(RagError::Persistence(_))));
    }
}