Skip to main content

oxibonsai_rag/
persistence.rs

1//! JSON persistence for the vector store and retriever.
2//!
3//! This module provides durable snapshots of an indexed corpus so callers
4//! can save an index to disk and reload it without re-embedding every
5//! document.  Two snapshot types are exposed:
6//!
7//! - [`IndexSnapshot`] captures a [`crate::vector_store::VectorStore`] — its
8//!   dimensionality, distance metric, and every stored entry.
9//! - [`RetrieverSnapshot`] wraps an [`IndexSnapshot`] together with the
10//!   [`Retriever`]'s document counter so that `add_document` continues to
11//!   produce monotonically increasing `doc_id`s after a round-trip.
12//!
13//! A monotonically-increasing [`SCHEMA_VERSION`] is stored in every
14//! snapshot.  [`VectorStore::load_json`] and [`Retriever::load`] refuse to
15//! deserialise unknown versions with [`RagError::Persistence`].
16//!
17//! # Embedder state
18//!
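//! The [`Retriever`] is generic over its [`Embedder`].  Because the trait
//! does not require `Serialize`, embedder internals cannot be persisted
//! generically — callers must reconstruct the embedder themselves and pass
//! it to [`Retriever::load`].  The retriever's configuration is not stored
//! either: [`Retriever::load`] rebuilds the retriever with
//! `RetrieverConfig::default()`.
//!
//! The `tfidf_state` field is an optional escape hatch
//! (`serde_json::Value`) that advanced users can populate by hand to keep a
//! TF-IDF vocabulary alongside the index.  [`VectorStore::to_snapshot`]
//! always leaves it `None` and [`VectorStore::from_snapshot`] ignores it, so
//! it must be read and written directly through the snapshot types.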
25//!
26//! # Example
27//!
28//! ```no_run
29//! use oxibonsai_rag::embedding::IdentityEmbedder;
30//! use oxibonsai_rag::pipeline::RagConfig;
31//! use oxibonsai_rag::retriever::{Retriever, RetrieverConfig};
32//!
33//! let embedder = IdentityEmbedder::new(32).expect("valid dim");
34//! let mut retriever = Retriever::new(embedder, RetrieverConfig::default());
35//! retriever
36//!     .add_document("some text", &RagConfig::default().chunk_config)
37//!     .expect("index");
38//!
39//! let path = std::env::temp_dir().join("rag_snapshot.json");
40//! retriever.save(&path).expect("save");
41//!
42//! let embedder = IdentityEmbedder::new(32).expect("valid dim");
43//! let restored = Retriever::load(embedder, &path).expect("load");
44//! assert_eq!(restored.chunk_count(), 1);
45//! ```
46
47use std::fs::File;
48use std::io::{BufReader, BufWriter};
49use std::path::Path;
50
51use serde::{Deserialize, Serialize};
52
53use crate::distance::Distance;
54use crate::embedding::Embedder;
55use crate::error::RagError;
56use crate::retriever::Retriever;
57use crate::vector_store::{VectorEntry, VectorStore};
58
59// ─────────────────────────────────────────────────────────────────────────────
60// Schema version
61// ─────────────────────────────────────────────────────────────────────────────
62
/// Version tag written into every persisted snapshot.
///
/// Loaders compare the stored tag against this value and reject any
/// mismatch with [`RagError::Persistence`], so an older binary never
/// silently misreads a file written with a newer layout.  Increment it
/// whenever the [`IndexSnapshot`] or [`RetrieverSnapshot`] layout changes
/// in a way that is not backwards-compatible.
pub const SCHEMA_VERSION: u32 = 1;
70
71// ─────────────────────────────────────────────────────────────────────────────
72// IndexSnapshot
73// ─────────────────────────────────────────────────────────────────────────────
74
75/// Serde-serialisable snapshot of a [`VectorStore`].
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct IndexSnapshot {
78    /// Schema version tag (see [`SCHEMA_VERSION`]).
79    pub schema_version: u32,
80    /// Embedding dimensionality.
81    pub dim: usize,
82    /// Distance metric the store was configured with.
83    #[serde(default)]
84    pub distance: Distance,
85    /// All stored entries, in insertion order.
86    pub entries: Vec<VectorEntry>,
87    /// Optional serialised TF-IDF state.  Advanced users may populate this
88    /// by hand (see module-level docs).  Kept as an opaque
89    /// [`serde_json::Value`] so that the store does not need to know the
90    /// embedder's internal layout.
91    #[serde(default, skip_serializing_if = "Option::is_none")]
92    pub tfidf_state: Option<serde_json::Value>,
93}
94
95impl IndexSnapshot {
96    /// Ensure the schema version matches the build-time constant.  Returns
97    /// [`RagError::Persistence`] for unknown versions.
98    pub fn check_version(&self) -> Result<(), RagError> {
99        if self.schema_version != SCHEMA_VERSION {
100            return Err(RagError::Persistence(format!(
101                "unsupported schema_version {} (expected {})",
102                self.schema_version, SCHEMA_VERSION
103            )));
104        }
105        Ok(())
106    }
107}
108
109// ─────────────────────────────────────────────────────────────────────────────
110// RetrieverSnapshot
111// ─────────────────────────────────────────────────────────────────────────────
112
113/// Serde-serialisable snapshot of a [`Retriever`].
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct RetrieverSnapshot {
116    /// Schema version tag (see [`SCHEMA_VERSION`]).
117    pub schema_version: u32,
118    /// Number of distinct documents indexed so far.
119    pub doc_count: usize,
120    /// The underlying vector-store snapshot.
121    pub store: IndexSnapshot,
122}
123
124impl RetrieverSnapshot {
125    fn check_version(&self) -> Result<(), RagError> {
126        if self.schema_version != SCHEMA_VERSION {
127            return Err(RagError::Persistence(format!(
128                "unsupported schema_version {} (expected {})",
129                self.schema_version, SCHEMA_VERSION
130            )));
131        }
132        self.store.check_version()
133    }
134}
135
136// ─────────────────────────────────────────────────────────────────────────────
137// VectorStore <-> IndexSnapshot conversions
138// ─────────────────────────────────────────────────────────────────────────────
139
140impl VectorStore {
141    /// Produce an [`IndexSnapshot`] capturing the current store contents.
142    pub fn to_snapshot(&self) -> IndexSnapshot {
143        IndexSnapshot {
144            schema_version: SCHEMA_VERSION,
145            dim: self.dim(),
146            distance: self.distance(),
147            entries: self.entries().to_vec(),
148            tfidf_state: None,
149        }
150    }
151
152    /// Build a [`VectorStore`] from a previously-produced snapshot.
153    ///
154    /// Returns [`RagError::Persistence`] if the schema version is unknown,
155    /// and [`RagError::DimensionMismatch`] if any stored entry has a
156    /// vector whose length disagrees with the snapshot's `dim`.
157    pub fn from_snapshot(snapshot: IndexSnapshot) -> Result<Self, RagError> {
158        snapshot.check_version()?;
159        for entry in &snapshot.entries {
160            if entry.vector.len() != snapshot.dim {
161                return Err(RagError::DimensionMismatch {
162                    expected: snapshot.dim,
163                    got: entry.vector.len(),
164                });
165            }
166        }
167        let mut store = VectorStore::new_with_distance(snapshot.dim, snapshot.distance);
168        store.set_entries(snapshot.entries);
169        Ok(store)
170    }
171
172    /// Serialise this store to `path` as pretty-printed JSON.
173    pub fn save_json(&self, path: impl AsRef<Path>) -> Result<(), RagError> {
174        let file = File::create(path.as_ref())?;
175        let writer = BufWriter::new(file);
176        serde_json::to_writer_pretty(writer, &self.to_snapshot())
177            .map_err(|e| RagError::Persistence(format!("serialize failed: {e}")))?;
178        Ok(())
179    }
180
181    /// Deserialise a store previously written by [`VectorStore::save_json`].
182    ///
183    /// Returns [`RagError::Persistence`] on malformed JSON or unknown
184    /// schema version, and [`RagError::DimensionMismatch`] if any stored
185    /// entry's vector length disagrees with the snapshot's `dim`.
186    pub fn load_json(path: impl AsRef<Path>) -> Result<Self, RagError> {
187        let file = File::open(path.as_ref())?;
188        let reader = BufReader::new(file);
189        let snapshot: IndexSnapshot = serde_json::from_reader(reader)
190            .map_err(|e| RagError::Persistence(format!("parse failed: {e}")))?;
191        Self::from_snapshot(snapshot)
192    }
193}
194
195// ─────────────────────────────────────────────────────────────────────────────
196// Retriever persistence
197// ─────────────────────────────────────────────────────────────────────────────
198
199impl<E: Embedder> Retriever<E> {
200    /// Serialise this retriever's index to `path` (pretty JSON).
201    ///
202    /// The embedder itself is *not* persisted — callers must provide an
203    /// equivalent embedder to [`Retriever::load`].
204    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), RagError> {
205        let snapshot = RetrieverSnapshot {
206            schema_version: SCHEMA_VERSION,
207            doc_count: self.document_count(),
208            store: self.store().to_snapshot(),
209        };
210        let file = File::create(path.as_ref())?;
211        let writer = BufWriter::new(file);
212        serde_json::to_writer_pretty(writer, &snapshot)
213            .map_err(|e| RagError::Persistence(format!("serialize failed: {e}")))?;
214        Ok(())
215    }
216
217    /// Reconstruct a [`Retriever`] from a previously-saved snapshot.
218    ///
219    /// `embedder` must produce vectors of the same dimensionality as the
220    /// snapshot, otherwise [`RagError::DimensionMismatch`] is returned.
221    pub fn load(embedder: E, path: impl AsRef<Path>) -> Result<Self, RagError> {
222        let file = File::open(path.as_ref())?;
223        let reader = BufReader::new(file);
224        let snapshot: RetrieverSnapshot = serde_json::from_reader(reader)
225            .map_err(|e| RagError::Persistence(format!("parse failed: {e}")))?;
226        snapshot.check_version()?;
227
228        if embedder.embedding_dim() != snapshot.store.dim {
229            return Err(RagError::DimensionMismatch {
230                expected: snapshot.store.dim,
231                got: embedder.embedding_dim(),
232            });
233        }
234
235        let store = VectorStore::from_snapshot(snapshot.store)?;
236        Ok(Self::from_parts(
237            embedder,
238            store,
239            snapshot.doc_count,
240            crate::retriever::RetrieverConfig::default(),
241        ))
242    }
243}
244
245// ─────────────────────────────────────────────────────────────────────────────
246// Inline tests
247// ─────────────────────────────────────────────────────────────────────────────
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::Chunk;
    use crate::embedding::IdentityEmbedder;

    /// Unique temp-file path per test, process and invocation.
    fn tmp_path(tag: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let pid = std::process::id();
        std::env::temp_dir().join(format!("oxibonsai_rag_persist_{tag}_{pid}_{nanos}.json"))
    }

    /// One-entry, 3-dimensional store shared by several tests.
    fn sample_store() -> VectorStore {
        let mut store = VectorStore::new(3);
        let chunk = Chunk::new("hello".into(), 0, 0, 0);
        store.insert(vec![1.0, 0.0, 0.0], chunk).expect("insert");
        store
    }

    /// Write a hand-built `RetrieverSnapshot` JSON file. The outer version is
    /// `outer_version`, the nested store is tagged with the current
    /// `SCHEMA_VERSION`, and the store has dimension `dim` with no entries.
    fn write_retriever_json(tag: &str, outer_version: u32, dim: usize) -> std::path::PathBuf {
        let path = tmp_path(tag);
        let store_version = SCHEMA_VERSION;
        let json = format!(
            r#"{{"schema_version":{outer_version},"doc_count":0,"store":{{"schema_version":{store_version},"dim":{dim},"entries":[]}}}}"#
        );
        std::fs::write(&path, json).expect("write snapshot");
        path
    }

    #[test]
    fn roundtrip_preserves_entries() {
        let store = sample_store();
        let path = tmp_path("roundtrip");
        store.save_json(&path).expect("save");
        let loaded = VectorStore::load_json(&path);
        // Clean up before asserting so a failure does not leak the temp file.
        std::fs::remove_file(&path).ok();
        let loaded = loaded.expect("load");
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.dim(), 3);
    }

    #[test]
    fn unknown_version_rejected() {
        let snapshot = IndexSnapshot {
            schema_version: 9999,
            dim: 1,
            distance: Distance::Cosine,
            entries: Vec::new(),
            tfidf_state: None,
        };
        let result = VectorStore::from_snapshot(snapshot);
        assert!(matches!(result, Err(RagError::Persistence(_))));
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let mut snapshot = sample_store().to_snapshot();
        snapshot.dim = 4;
        let result = VectorStore::from_snapshot(snapshot);
        assert!(matches!(
            result,
            Err(RagError::DimensionMismatch {
                expected: 4,
                got: 3
            })
        ));
    }

    #[test]
    fn malformed_json_rejected() {
        let path = tmp_path("malformed");
        std::fs::write(&path, "{ not json").expect("write");
        let result = VectorStore::load_json(&path);
        std::fs::remove_file(&path).ok();
        assert!(matches!(result, Err(RagError::Persistence(_))));
    }

    #[test]
    fn retriever_unknown_version_rejected() {
        let path = write_retriever_json("retriever_version", 9999, 3);
        let embedder = IdentityEmbedder::new(3).expect("valid dim");
        let result = Retriever::load(embedder, &path);
        std::fs::remove_file(&path).ok();
        assert!(matches!(result, Err(RagError::Persistence(_))));
    }

    #[test]
    fn retriever_embedder_dim_mismatch_rejected() {
        let path = write_retriever_json("retriever_dim", SCHEMA_VERSION, 3);
        let embedder = IdentityEmbedder::new(4).expect("valid dim");
        let result = Retriever::load(embedder, &path);
        std::fs::remove_file(&path).ok();
        assert!(matches!(
            result,
            Err(RagError::DimensionMismatch {
                expected: 3,
                got: 4
            })
        ));
    }
}