oxibonsai_rag/persistence.rs
1//! JSON persistence for the vector store and retriever.
2//!
3//! This module provides durable snapshots of an indexed corpus so callers
4//! can save an index to disk and reload it without re-embedding every
5//! document. Two snapshot types are exposed:
6//!
7//! - [`IndexSnapshot`] captures a [`crate::vector_store::VectorStore`] — its
8//! dimensionality, distance metric, and every stored entry.
9//! - [`RetrieverSnapshot`] wraps an [`IndexSnapshot`] together with the
10//! [`Retriever`]'s document counter so that `add_document` continues to
11//! produce monotonically increasing `doc_id`s after a round-trip.
12//!
//! A monotonically increasing [`SCHEMA_VERSION`] is stored in every
//! snapshot. [`VectorStore::load_json`] and [`Retriever::load`] reject any
//! snapshot whose version differs from the current one, returning
//! [`RagError::Persistence`].
16//!
17//! # Embedder state
18//!
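//! The [`Retriever`] is generic over its [`Embedder`]. Because the trait
//! does not require `Serialize` we cannot persist embedder internals
//! generically — callers must reconstruct the embedder themselves and pass
//! it to [`Retriever::load`]. The retriever's configuration is not stored
//! either: [`Retriever::load`] rebuilds it with `RetrieverConfig::default()`.
//!
//! The `tfidf_state` field of [`IndexSnapshot`] is an optional escape hatch
//! (`serde_json::Value`) for round-tripping a TF-IDF vocabulary alongside
//! the index. [`VectorStore::to_snapshot`] always leaves it `None` and
//! [`VectorStore::from_snapshot`] ignores it, so callers who want it must
//! populate, serialise and deserialise the [`IndexSnapshot`] themselves
//! rather than relying on `save_json` / `load_json`.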
25//!
26//! # Example
27//!
28//! ```no_run
29//! use oxibonsai_rag::embedding::IdentityEmbedder;
30//! use oxibonsai_rag::pipeline::RagConfig;
31//! use oxibonsai_rag::retriever::{Retriever, RetrieverConfig};
32//!
33//! let embedder = IdentityEmbedder::new(32).expect("valid dim");
34//! let mut retriever = Retriever::new(embedder, RetrieverConfig::default());
35//! retriever
36//! .add_document("some text", &RagConfig::default().chunk_config)
37//! .expect("index");
38//!
39//! let path = std::env::temp_dir().join("rag_snapshot.json");
40//! retriever.save(&path).expect("save");
41//!
42//! let embedder = IdentityEmbedder::new(32).expect("valid dim");
43//! let restored = Retriever::load(embedder, &path).expect("load");
44//! assert_eq!(restored.chunk_count(), 1);
45//! ```
46
47use std::fs::File;
48use std::io::{BufReader, BufWriter};
49use std::path::Path;
50
51use serde::{Deserialize, Serialize};
52
53use crate::distance::Distance;
54use crate::embedding::Embedder;
55use crate::error::RagError;
56use crate::retriever::Retriever;
57use crate::vector_store::{VectorEntry, VectorStore};
58
59// ─────────────────────────────────────────────────────────────────────────────
60// Schema version
61// ─────────────────────────────────────────────────────────────────────────────
62
/// Version number written into every on-disk snapshot.
///
/// Increase this whenever the [`IndexSnapshot`] layout changes in a way that
/// older readers cannot handle. Loaders treat any value other than this one
/// as an error ([`RagError::Persistence`]), so an outdated binary never
/// silently misreads a file produced by a newer one.
pub const SCHEMA_VERSION: u32 = 1;
70
71// ─────────────────────────────────────────────────────────────────────────────
72// IndexSnapshot
73// ─────────────────────────────────────────────────────────────────────────────
74
75/// Serde-serialisable snapshot of a [`VectorStore`].
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct IndexSnapshot {
78 /// Schema version tag (see [`SCHEMA_VERSION`]).
79 pub schema_version: u32,
80 /// Embedding dimensionality.
81 pub dim: usize,
82 /// Distance metric the store was configured with.
83 #[serde(default)]
84 pub distance: Distance,
85 /// All stored entries, in insertion order.
86 pub entries: Vec<VectorEntry>,
87 /// Optional serialised TF-IDF state. Advanced users may populate this
88 /// by hand (see module-level docs). Kept as an opaque
89 /// [`serde_json::Value`] so that the store does not need to know the
90 /// embedder's internal layout.
91 #[serde(default, skip_serializing_if = "Option::is_none")]
92 pub tfidf_state: Option<serde_json::Value>,
93}
94
95impl IndexSnapshot {
96 /// Ensure the schema version matches the build-time constant. Returns
97 /// [`RagError::Persistence`] for unknown versions.
98 pub fn check_version(&self) -> Result<(), RagError> {
99 if self.schema_version != SCHEMA_VERSION {
100 return Err(RagError::Persistence(format!(
101 "unsupported schema_version {} (expected {})",
102 self.schema_version, SCHEMA_VERSION
103 )));
104 }
105 Ok(())
106 }
107}
108
109// ─────────────────────────────────────────────────────────────────────────────
110// RetrieverSnapshot
111// ─────────────────────────────────────────────────────────────────────────────
112
113/// Serde-serialisable snapshot of a [`Retriever`].
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct RetrieverSnapshot {
116 /// Schema version tag (see [`SCHEMA_VERSION`]).
117 pub schema_version: u32,
118 /// Number of distinct documents indexed so far.
119 pub doc_count: usize,
120 /// The underlying vector-store snapshot.
121 pub store: IndexSnapshot,
122}
123
124impl RetrieverSnapshot {
125 fn check_version(&self) -> Result<(), RagError> {
126 if self.schema_version != SCHEMA_VERSION {
127 return Err(RagError::Persistence(format!(
128 "unsupported schema_version {} (expected {})",
129 self.schema_version, SCHEMA_VERSION
130 )));
131 }
132 self.store.check_version()
133 }
134}
135
136// ─────────────────────────────────────────────────────────────────────────────
137// VectorStore <-> IndexSnapshot conversions
138// ─────────────────────────────────────────────────────────────────────────────
139
140impl VectorStore {
141 /// Produce an [`IndexSnapshot`] capturing the current store contents.
142 pub fn to_snapshot(&self) -> IndexSnapshot {
143 IndexSnapshot {
144 schema_version: SCHEMA_VERSION,
145 dim: self.dim(),
146 distance: self.distance(),
147 entries: self.entries().to_vec(),
148 tfidf_state: None,
149 }
150 }
151
152 /// Build a [`VectorStore`] from a previously-produced snapshot.
153 ///
154 /// Returns [`RagError::Persistence`] if the schema version is unknown,
155 /// and [`RagError::DimensionMismatch`] if any stored entry has a
156 /// vector whose length disagrees with the snapshot's `dim`.
157 pub fn from_snapshot(snapshot: IndexSnapshot) -> Result<Self, RagError> {
158 snapshot.check_version()?;
159 for entry in &snapshot.entries {
160 if entry.vector.len() != snapshot.dim {
161 return Err(RagError::DimensionMismatch {
162 expected: snapshot.dim,
163 got: entry.vector.len(),
164 });
165 }
166 }
167 let mut store = VectorStore::new_with_distance(snapshot.dim, snapshot.distance);
168 store.set_entries(snapshot.entries);
169 Ok(store)
170 }
171
172 /// Serialise this store to `path` as pretty-printed JSON.
173 pub fn save_json(&self, path: impl AsRef<Path>) -> Result<(), RagError> {
174 let file = File::create(path.as_ref())?;
175 let writer = BufWriter::new(file);
176 serde_json::to_writer_pretty(writer, &self.to_snapshot())
177 .map_err(|e| RagError::Persistence(format!("serialize failed: {e}")))?;
178 Ok(())
179 }
180
181 /// Deserialise a store previously written by [`VectorStore::save_json`].
182 ///
183 /// Returns [`RagError::Persistence`] on malformed JSON or unknown
184 /// schema version, and [`RagError::DimensionMismatch`] if any stored
185 /// entry's vector length disagrees with the snapshot's `dim`.
186 pub fn load_json(path: impl AsRef<Path>) -> Result<Self, RagError> {
187 let file = File::open(path.as_ref())?;
188 let reader = BufReader::new(file);
189 let snapshot: IndexSnapshot = serde_json::from_reader(reader)
190 .map_err(|e| RagError::Persistence(format!("parse failed: {e}")))?;
191 Self::from_snapshot(snapshot)
192 }
193}
194
195// ─────────────────────────────────────────────────────────────────────────────
196// Retriever persistence
197// ─────────────────────────────────────────────────────────────────────────────
198
199impl<E: Embedder> Retriever<E> {
200 /// Serialise this retriever's index to `path` (pretty JSON).
201 ///
202 /// The embedder itself is *not* persisted — callers must provide an
203 /// equivalent embedder to [`Retriever::load`].
204 pub fn save(&self, path: impl AsRef<Path>) -> Result<(), RagError> {
205 let snapshot = RetrieverSnapshot {
206 schema_version: SCHEMA_VERSION,
207 doc_count: self.document_count(),
208 store: self.store().to_snapshot(),
209 };
210 let file = File::create(path.as_ref())?;
211 let writer = BufWriter::new(file);
212 serde_json::to_writer_pretty(writer, &snapshot)
213 .map_err(|e| RagError::Persistence(format!("serialize failed: {e}")))?;
214 Ok(())
215 }
216
217 /// Reconstruct a [`Retriever`] from a previously-saved snapshot.
218 ///
219 /// `embedder` must produce vectors of the same dimensionality as the
220 /// snapshot, otherwise [`RagError::DimensionMismatch`] is returned.
221 pub fn load(embedder: E, path: impl AsRef<Path>) -> Result<Self, RagError> {
222 let file = File::open(path.as_ref())?;
223 let reader = BufReader::new(file);
224 let snapshot: RetrieverSnapshot = serde_json::from_reader(reader)
225 .map_err(|e| RagError::Persistence(format!("parse failed: {e}")))?;
226 snapshot.check_version()?;
227
228 if embedder.embedding_dim() != snapshot.store.dim {
229 return Err(RagError::DimensionMismatch {
230 expected: snapshot.store.dim,
231 got: embedder.embedding_dim(),
232 });
233 }
234
235 let store = VectorStore::from_snapshot(snapshot.store)?;
236 Ok(Self::from_parts(
237 embedder,
238 store,
239 snapshot.doc_count,
240 crate::retriever::RetrieverConfig::default(),
241 ))
242 }
243}
244
245// ─────────────────────────────────────────────────────────────────────────────
246// Inline tests
247// ─────────────────────────────────────────────────────────────────────────────
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::Chunk;

    /// Unique temporary file path that is deleted when dropped, so a failing
    /// assertion does not leave stray files in the temp directory.
    struct TempPath(std::path::PathBuf);

    impl TempPath {
        fn new(tag: &str) -> Self {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos())
                .unwrap_or(0);
            let pid = std::process::id();
            TempPath(
                std::env::temp_dir()
                    .join(format!("oxibonsai_rag_persist_{tag}_{pid}_{nanos}.json")),
            )
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            // Best-effort cleanup; the file may not exist if saving failed.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn roundtrip_preserves_entries() {
        let mut store = VectorStore::new_with_distance(3, Distance::Cosine);
        let chunk = Chunk::new("hello".into(), 0, 0, 0);
        store.insert(vec![1.0, 0.0, 0.0], chunk).expect("insert");

        let path = TempPath::new("roundtrip");
        store.save_json(&path.0).expect("save");
        let loaded = VectorStore::load_json(&path.0).expect("load");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.dim(), 3);
        assert!(matches!(loaded.distance(), Distance::Cosine));
        assert_eq!(loaded.entries()[0].vector, store.entries()[0].vector);
    }

    #[test]
    fn unknown_version_rejected() {
        let snapshot = IndexSnapshot {
            schema_version: 9999,
            dim: 1,
            distance: Distance::Cosine,
            entries: Vec::new(),
            tfidf_state: None,
        };
        let result = VectorStore::from_snapshot(snapshot);
        assert!(matches!(result, Err(RagError::Persistence(_))));
    }

    #[test]
    fn mismatched_entry_dim_rejected() {
        let mut store = VectorStore::new(3);
        let chunk = Chunk::new("hello".into(), 0, 0, 0);
        store.insert(vec![1.0, 0.0, 0.0], chunk).expect("insert");

        let mut snapshot = store.to_snapshot();
        snapshot.dim = 2;
        let result = VectorStore::from_snapshot(snapshot);
        assert!(matches!(
            result,
            Err(RagError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
    }
}