Skip to main content

oxirs_vec/persistence/
snapshot.rs

//! HNSW Index Snapshot - Save and Restore without rebuilding
//!
//! Binary format (little-endian throughout):
//!   [4 bytes]  magic "HNSW"
//!   [4 bytes]  format version (u32)
//!   [8 bytes]  num_nodes (u64)
//!   [8 bytes]  num_layers (u64)
//!   [8 bytes]  dimension (u64)
//!   [8 bytes]  ef_construction (u64)
//!   [8 bytes]  m (u64)
//!   [8 bytes]  m_l0 (u64)
//!   [1 byte]   has_entry_point (u8: 0 or 1)
//!   [8 bytes]  entry_point (u64, only present if has_entry_point == 1)
//!   For each node:
//!     \[8 bytes\]  uri_len (u64)
//!     \[uri_len\]  uri bytes (UTF-8)
//!     [8 bytes]  vector_len (u64) -- number of f32 elements
//!     [4*n bytes] vector data (f32 little-endian)
//!     [8 bytes]  num_connection_layers (u64)
//!     For each layer:
//!       [8 bytes]  num_connections (u64)
//!       For each connection:
//!         [8 bytes]  connected_node_id (u64)
24
25use crate::hnsw::{HnswConfig, HnswIndex, Node};
26use crate::Vector;
27use crate::VectorError;
28use std::collections::HashSet;
29use std::io::{Read, Write};
30use std::path::Path;
31
/// Four-byte signature written at the very start of every HNSW snapshot.
const SNAPSHOT_MAGIC: &[u8; 4] = &[b'H', b'N', b'S', b'W'];

/// Version number of the snapshot layout this module reads and writes.
const SNAPSHOT_VERSION: u32 = 1;
37
/// Fixed-size header fields of an HNSW snapshot, in decoded form.
#[derive(Debug, Clone)]
pub struct SnapshotHeader {
    /// Signature bytes found at the start of the file; expected to be `b"HNSW"`.
    pub magic: [u8; 4],
    /// Snapshot format version.
    pub version: u32,
    /// How many nodes the snapshot contains.
    pub num_nodes: usize,
    /// How many hierarchy layers the snapshot records.
    pub num_layers: usize,
    /// Length of each stored vector.
    pub dimension: usize,
    /// Value of the `ef_construction` parameter when the snapshot was taken.
    pub ef_construction: usize,
    /// Value of the `M` parameter when the snapshot was taken.
    pub m: usize,
    /// Value of the `M_l0` parameter when the snapshot was taken.
    pub m_l0: usize,
    /// Id of the entry-point node, or `None` for an empty index.
    pub entry_point: Option<usize>,
}
60
/// Unit type grouping the snapshot save/load routines for an [`HnswIndex`].
///
/// Integers are encoded as little-endian `u32` / `u64`, and floating-point
/// values as little-endian `f32`.
pub struct IndexSnapshot;
66
67impl IndexSnapshot {
68    // ──────────────────────────────────────────────────────────────────────────
69    // Public API
70    // ──────────────────────────────────────────────────────────────────────────
71
72    /// Serialize `index` into `writer`.
73    ///
74    /// Returns the total number of bytes written.
75    pub fn save<W: Write>(index: &HnswIndex, writer: &mut W) -> Result<usize, VectorError> {
76        let mut written = 0usize;
77
78        // ── magic ──────────────────────────────────────────────────────────────
79        writer
80            .write_all(SNAPSHOT_MAGIC)
81            .map_err(VectorError::IoError)?;
82        written += 4;
83
84        // ── version ────────────────────────────────────────────────────────────
85        Self::write_u32(writer, SNAPSHOT_VERSION).map_err(VectorError::IoError)?;
86        written += 4;
87
88        let nodes = index.nodes();
89        let config = index.config();
90
91        // Derive the maximum layer count from the stored nodes
92        let num_layers = nodes.iter().map(|n| n.connections.len()).max().unwrap_or(0);
93
94        let dimension = nodes.first().map(|n| n.vector_data_f32.len()).unwrap_or(0);
95
96        // ── header scalars ─────────────────────────────────────────────────────
97        Self::write_u64(writer, nodes.len() as u64).map_err(VectorError::IoError)?;
98        written += 8;
99        Self::write_u64(writer, num_layers as u64).map_err(VectorError::IoError)?;
100        written += 8;
101        Self::write_u64(writer, dimension as u64).map_err(VectorError::IoError)?;
102        written += 8;
103        Self::write_u64(writer, config.ef_construction as u64).map_err(VectorError::IoError)?;
104        written += 8;
105        Self::write_u64(writer, config.m as u64).map_err(VectorError::IoError)?;
106        written += 8;
107        Self::write_u64(writer, config.m_l0 as u64).map_err(VectorError::IoError)?;
108        written += 8;
109
110        // ── entry point ────────────────────────────────────────────────────────
111        match index.entry_point() {
112            None => {
113                Self::write_u8(writer, 0).map_err(VectorError::IoError)?;
114                written += 1;
115            }
116            Some(ep) => {
117                Self::write_u8(writer, 1).map_err(VectorError::IoError)?;
118                written += 1;
119                Self::write_u64(writer, ep as u64).map_err(VectorError::IoError)?;
120                written += 8;
121            }
122        }
123
124        // ── nodes ──────────────────────────────────────────────────────────────
125        for node in nodes {
126            written += Self::write_node(writer, node).map_err(VectorError::IoError)?;
127        }
128
129        writer.flush().map_err(VectorError::IoError)?;
130        Ok(written)
131    }
132
133    /// Deserialize an [`HnswIndex`] from `reader`.
134    pub fn load<R: Read>(reader: &mut R) -> Result<HnswIndex, VectorError> {
135        // ── magic ──────────────────────────────────────────────────────────────
136        let mut magic = [0u8; 4];
137        reader
138            .read_exact(&mut magic)
139            .map_err(VectorError::IoError)?;
140        if &magic != SNAPSHOT_MAGIC {
141            return Err(VectorError::InvalidData(format!(
142                "Invalid snapshot magic: expected {:?}, got {:?}",
143                SNAPSHOT_MAGIC, magic
144            )));
145        }
146
147        // ── version ────────────────────────────────────────────────────────────
148        let version = Self::read_u32(reader).map_err(VectorError::IoError)?;
149        if version != SNAPSHOT_VERSION {
150            return Err(VectorError::InvalidData(format!(
151                "Unsupported snapshot version: {}",
152                version
153            )));
154        }
155
156        // ── header scalars ─────────────────────────────────────────────────────
157        let num_nodes = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
158        let _num_layers = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
159        let _dimension = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
160        let ef_construction = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
161        let m = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
162        let m_l0 = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
163
164        // ── entry point ────────────────────────────────────────────────────────
165        let has_entry = Self::read_u8(reader).map_err(VectorError::IoError)?;
166        let entry_point = if has_entry == 1 {
167            Some(Self::read_u64(reader).map_err(VectorError::IoError)? as usize)
168        } else {
169            None
170        };
171
172        // ── reconstruct config ─────────────────────────────────────────────────
173        let config = HnswConfig {
174            m,
175            m_l0,
176            ef_construction,
177            ..HnswConfig::default()
178        };
179
180        // ── nodes ──────────────────────────────────────────────────────────────
181        let mut nodes: Vec<Node> = Vec::with_capacity(num_nodes);
182        let mut uri_to_id: std::collections::HashMap<String, usize> =
183            std::collections::HashMap::with_capacity(num_nodes);
184
185        for idx in 0..num_nodes {
186            let node = Self::read_node(reader)?;
187            uri_to_id.insert(node.uri.clone(), idx);
188            nodes.push(node);
189        }
190
191        // ── assemble index ─────────────────────────────────────────────────────
192        let mut index = HnswIndex::new_cpu_only(config);
193        // Replace internal state via the provided accessors
194        *index.nodes_mut() = nodes;
195        *index.uri_to_id_mut() = uri_to_id;
196        index.set_entry_point(entry_point);
197
198        Ok(index)
199    }
200
201    /// Persist `index` to a file at `path`.
202    ///
203    /// The file is created (or truncated) atomically via a temporary sibling file.
204    pub fn save_to_file(index: &HnswIndex, path: &Path) -> Result<usize, VectorError> {
205        // Write to a temporary file first, then rename for atomicity
206        let tmp_path = path.with_extension("hnsw.tmp");
207        let file = std::fs::File::create(&tmp_path).map_err(VectorError::IoError)?;
208        let mut writer = std::io::BufWriter::new(file);
209
210        let written = Self::save(index, &mut writer)?;
211        drop(writer);
212
213        std::fs::rename(&tmp_path, path).map_err(VectorError::IoError)?;
214        Ok(written)
215    }
216
217    /// Load an index from a file at `path`.
218    pub fn load_from_file(path: &Path) -> Result<HnswIndex, VectorError> {
219        let file = std::fs::File::open(path).map_err(VectorError::IoError)?;
220        let mut reader = std::io::BufReader::new(file);
221        Self::load(&mut reader)
222    }
223
224    // ──────────────────────────────────────────────────────────────────────────
225    // Private node I/O
226    // ──────────────────────────────────────────────────────────────────────────
227
228    fn write_node<W: Write>(writer: &mut W, node: &Node) -> std::io::Result<usize> {
229        let mut written = 0usize;
230
231        // uri
232        let uri_bytes = node.uri.as_bytes();
233        Self::write_u64(writer, uri_bytes.len() as u64)?;
234        written += 8;
235        writer.write_all(uri_bytes)?;
236        written += uri_bytes.len();
237
238        // vector data (f32 array)
239        Self::write_u64(writer, node.vector_data_f32.len() as u64)?;
240        written += 8;
241        for &v in &node.vector_data_f32 {
242            Self::write_f32(writer, v)?;
243            written += 4;
244        }
245
246        // connections per layer
247        Self::write_u64(writer, node.connections.len() as u64)?;
248        written += 8;
249        for layer_connections in &node.connections {
250            Self::write_u64(writer, layer_connections.len() as u64)?;
251            written += 8;
252            for &neighbor_id in layer_connections {
253                Self::write_u64(writer, neighbor_id as u64)?;
254                written += 8;
255            }
256        }
257
258        Ok(written)
259    }
260
261    fn read_node<R: Read>(reader: &mut R) -> Result<Node, VectorError> {
262        // uri
263        let uri_len = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
264        let mut uri_bytes = vec![0u8; uri_len];
265        reader
266            .read_exact(&mut uri_bytes)
267            .map_err(VectorError::IoError)?;
268        let uri = String::from_utf8(uri_bytes)
269            .map_err(|e| VectorError::InvalidData(format!("Invalid UTF-8 in URI: {}", e)))?;
270
271        // vector data
272        let vec_len = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
273        let mut vector_data_f32 = Vec::with_capacity(vec_len);
274        for _ in 0..vec_len {
275            let v = Self::read_f32(reader).map_err(VectorError::IoError)?;
276            vector_data_f32.push(v);
277        }
278
279        // Reconstruct Vector from f32 data
280        let vector = Vector::new(vector_data_f32.clone());
281
282        // connections
283        let num_layers = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
284        let mut connections: Vec<HashSet<usize>> = Vec::with_capacity(num_layers);
285        for _ in 0..num_layers {
286            let num_conn = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
287            let mut layer_set = HashSet::with_capacity(num_conn);
288            for _ in 0..num_conn {
289                let neighbor = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
290                layer_set.insert(neighbor);
291            }
292            connections.push(layer_set);
293        }
294
295        let max_level = num_layers.saturating_sub(1);
296        let mut node = Node::new(uri, vector, max_level);
297        node.connections = connections;
298        node.vector_data_f32 = vector_data_f32;
299
300        Ok(node)
301    }
302
303    // ──────────────────────────────────────────────────────────────────────────
304    // Low-level I/O helpers (no external serialization crates)
305    // ──────────────────────────────────────────────────────────────────────────
306
307    fn write_u64<W: Write>(w: &mut W, v: u64) -> std::io::Result<()> {
308        w.write_all(&v.to_le_bytes())
309    }
310
311    fn read_u64<R: Read>(r: &mut R) -> std::io::Result<u64> {
312        let mut buf = [0u8; 8];
313        r.read_exact(&mut buf)?;
314        Ok(u64::from_le_bytes(buf))
315    }
316
317    fn write_u32<W: Write>(w: &mut W, v: u32) -> std::io::Result<()> {
318        w.write_all(&v.to_le_bytes())
319    }
320
321    fn read_u32<R: Read>(r: &mut R) -> std::io::Result<u32> {
322        let mut buf = [0u8; 4];
323        r.read_exact(&mut buf)?;
324        Ok(u32::from_le_bytes(buf))
325    }
326
327    fn write_u8<W: Write>(w: &mut W, v: u8) -> std::io::Result<()> {
328        w.write_all(&[v])
329    }
330
331    fn read_u8<R: Read>(r: &mut R) -> std::io::Result<u8> {
332        let mut buf = [0u8; 1];
333        r.read_exact(&mut buf)?;
334        Ok(buf[0])
335    }
336
337    fn write_f32<W: Write>(w: &mut W, v: f32) -> std::io::Result<()> {
338        w.write_all(&v.to_le_bytes())
339    }
340
341    fn read_f32<R: Read>(r: &mut R) -> std::io::Result<f32> {
342        let mut buf = [0u8; 4];
343        r.read_exact(&mut buf)?;
344        Ok(f32::from_le_bytes(buf))
345    }
346}
347
348// ────────────────────────────────────────────────────────────────────────────
349// Tests
350// ────────────────────────────────────────────────────────────────────────────
351
352#[cfg(test)]
353mod tests {
354    use super::*;
355    use crate::hnsw::{HnswConfig, HnswIndex};
356    use crate::VectorIndex;
357
358    fn make_index_with_vectors(n: usize, dim: usize) -> HnswIndex {
359        let config = HnswConfig::default();
360        let mut index = HnswIndex::new_cpu_only(config);
361
362        for i in 0..n {
363            let data: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 / 1000.0).collect();
364            let uri = format!("http://example.org/v{}", i);
365            let vec = Vector::new(data);
366            index.insert(uri, vec).expect("insert failed");
367        }
368
369        index
370    }
371
372    #[test]
373    fn test_save_and_load_empty_index() {
374        let index = HnswIndex::new_cpu_only(HnswConfig::default());
375        let mut buf = Vec::new();
376        let bytes = IndexSnapshot::save(&index, &mut buf).expect("save failed");
377        assert!(bytes > 0);
378
379        let loaded = IndexSnapshot::load(&mut buf.as_slice()).expect("load failed");
380        assert_eq!(loaded.len(), 0);
381        assert_eq!(loaded.entry_point(), None);
382    }
383
384    #[test]
385    fn test_save_and_load_roundtrip() {
386        let original = make_index_with_vectors(20, 8);
387        assert_eq!(original.len(), 20);
388
389        let mut buf = Vec::new();
390        IndexSnapshot::save(&original, &mut buf).expect("save failed");
391
392        let restored = IndexSnapshot::load(&mut buf.as_slice()).expect("load failed");
393
394        // Node count preserved
395        assert_eq!(restored.len(), original.len());
396
397        // URI mapping preserved
398        for uri in original.uri_to_id().keys() {
399            assert!(
400                restored.uri_to_id().contains_key(uri),
401                "URI {} missing after restore",
402                uri
403            );
404        }
405
406        // Entry point preserved
407        assert_eq!(original.entry_point(), restored.entry_point());
408    }
409
410    #[test]
411    fn test_save_and_load_vectors_preserved() {
412        let original = make_index_with_vectors(10, 4);
413
414        let mut buf = Vec::new();
415        IndexSnapshot::save(&original, &mut buf).expect("save failed");
416        let restored = IndexSnapshot::load(&mut buf.as_slice()).expect("load failed");
417
418        // Check each node's vector data matches
419        for (orig_node, rest_node) in original.nodes().iter().zip(restored.nodes().iter()) {
420            assert_eq!(orig_node.uri, rest_node.uri);
421            assert_eq!(
422                orig_node.vector_data_f32.len(),
423                rest_node.vector_data_f32.len()
424            );
425            for (a, b) in orig_node
426                .vector_data_f32
427                .iter()
428                .zip(rest_node.vector_data_f32.iter())
429            {
430                assert!((a - b).abs() < 1e-6, "Vector data mismatch: {} vs {}", a, b);
431            }
432        }
433    }
434
435    #[test]
436    fn test_save_and_load_connections_preserved() {
437        let original = make_index_with_vectors(30, 8);
438
439        let mut buf = Vec::new();
440        IndexSnapshot::save(&original, &mut buf).expect("save failed");
441        let restored = IndexSnapshot::load(&mut buf.as_slice()).expect("load failed");
442
443        // Verify connection structure is preserved for each node
444        for (i, (orig, rest)) in original
445            .nodes()
446            .iter()
447            .zip(restored.nodes().iter())
448            .enumerate()
449        {
450            assert_eq!(
451                orig.connections.len(),
452                rest.connections.len(),
453                "Node {} layer count mismatch",
454                i
455            );
456            for (layer, (oc, rc)) in orig
457                .connections
458                .iter()
459                .zip(rest.connections.iter())
460                .enumerate()
461            {
462                assert_eq!(oc, rc, "Node {} layer {} connections mismatch", i, layer);
463            }
464        }
465    }
466
467    #[test]
468    fn test_file_save_and_load() {
469        let original = make_index_with_vectors(15, 6);
470
471        let dir = std::env::temp_dir();
472        let path = dir.join("oxirs_snapshot_test.hnsw");
473
474        let bytes = IndexSnapshot::save_to_file(&original, &path).expect("save_to_file failed");
475        assert!(bytes > 0);
476        assert!(path.exists());
477
478        let restored = IndexSnapshot::load_from_file(&path).expect("load_from_file failed");
479        assert_eq!(restored.len(), original.len());
480
481        // Cleanup
482        let _ = std::fs::remove_file(&path);
483    }
484
485    #[test]
486    fn test_corrupt_magic_rejected() {
487        let mut buf = vec![0u8; 100];
488        buf[0] = b'X'; // corrupt magic
489        let result = IndexSnapshot::load(&mut buf.as_slice());
490        assert!(result.is_err());
491    }
492
493    #[test]
494    fn test_config_restored() {
495        let config = HnswConfig {
496            m: 8,
497            m_l0: 16,
498            ef_construction: 50,
499            ..Default::default()
500        };
501
502        let mut index = HnswIndex::new_cpu_only(config);
503        let vec_a = Vector::new(vec![1.0, 0.0, 0.0, 0.0]);
504        index
505            .insert("http://example.org/a".to_string(), vec_a)
506            .expect("insert");
507
508        let mut buf = Vec::new();
509        IndexSnapshot::save(&index, &mut buf).expect("save");
510        let restored = IndexSnapshot::load(&mut buf.as_slice()).expect("load");
511
512        assert_eq!(restored.config().m, 8);
513        assert_eq!(restored.config().m_l0, 16);
514        assert_eq!(restored.config().ef_construction, 50);
515    }
516}