Skip to main content

oxirs_vec/
hnsw_persistence.rs

1//! HNSW Index Serialization / Persistence
2//!
3//! This module provides binary serialization and file-based persistence for
4//! Hierarchical Navigable Small World (HNSW) index snapshots.
5//!
6//! # Format
7//!
8//! The binary format is a custom little-endian encoding (no bincode):
9//!
10//! ```text
11//! Magic:     4 bytes  = b"HNSW"
12//! Version:   2 bytes  = u16 LE
13//! CRC32:     4 bytes  = u32 LE (checksum of all following bytes)
14//! meta_len:  4 bytes  = u32 LE (length of serialized HnswMeta)
15//! meta_data: N bytes
16//! layers:    zero or more layer records:
17//!   layer_level: 4 bytes = u32 LE
18//!   node_count:  4 bytes = u32 LE
19//!   nodes:       node_count × LayerNode records:
20//!     node_id:        8 bytes = u64 LE
21//!     neighbor_count: 4 bytes = u32 LE
22//!     neighbors:      neighbor_count × 8 bytes (u64 LE each)
23//! ```
24//!
25//! # Atomic file writes
26//!
27//! `save_to_file` writes to a `.tmp` file first, then renames into place —
28//! ensuring that the destination file is never partially written.
29
30use std::collections::HashMap;
31use std::io::{self, Write};
32use std::path::Path;
33
34// ─── CRC32 implementation ─────────────────────────────────────────────────────
35//
36// Simple table-driven CRC32 (ISO 3309 / IEEE 802.3 polynomial 0xEDB88320).
37
/// Reflected CRC-32 polynomial (ISO 3309 / IEEE 802.3, as used by zlib).
const CRC32_POLY: u32 = 0xEDB8_8320;

/// Build the 256-entry CRC32 look-up table.
///
/// This is a `const fn` so the table can be computed once, at compile time,
/// into [`CRC32_TABLE`] instead of being rebuilt on every checksum call.
/// (`while` loops are used because `for` is not allowed in `const fn`.)
const fn build_crc32_table() -> [u32; 256] {
    let mut table = [0u32; 256];
    let mut i = 0;
    while i < 256 {
        let mut c = i as u32;
        let mut bit = 0;
        while bit < 8 {
            c = if c & 1 != 0 {
                CRC32_POLY ^ (c >> 1)
            } else {
                c >> 1
            };
            bit += 1;
        }
        table[i] = c;
        i += 1;
    }
    table
}

/// Pre-computed CRC32 look-up table, evaluated at compile time.
static CRC32_TABLE: [u32; 256] = build_crc32_table();

/// Compute the standard (zlib-compatible) CRC32 checksum of a byte slice.
///
/// The register starts at `0xFFFF_FFFF` and the result is bit-inverted, so
/// the empty input yields `0` and `b"123456789"` yields `0xCBF4_3926`.
pub fn crc32(data: &[u8]) -> u32 {
    let mut crc: u32 = 0xFFFF_FFFF;
    for &byte in data {
        let idx = ((crc ^ u32::from(byte)) & 0xFF) as usize;
        crc = CRC32_TABLE[idx] ^ (crc >> 8);
    }
    crc ^ 0xFFFF_FFFF
}
67
68// ─── Domain types ─────────────────────────────────────────────────────────────
69
/// Configuration and summary metadata of an HNSW index.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HnswMeta {
    /// Highest layer index present (0-based; layer 0 is the base layer).
    pub max_layer: usize,
    /// Entry-point node ID; `None` when the index is empty.
    pub entry_point: Option<u64>,
    /// The `ef_construction` build parameter.
    pub ef_construction: usize,
    /// The `M` build parameter (bidirectional links per node).
    pub m: usize,
    /// Number of distinct nodes across every layer.
    pub node_count: usize,
}

/// One node of a single HNSW layer together with its adjacency list.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LayerNode {
    /// Node identifier, unique within the index.
    pub id: u64,
    /// IDs of adjacent nodes in this layer (up to `M`).
    pub neighbors: Vec<u64>,
}

/// All nodes belonging to one level of the HNSW graph.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HnswLayer {
    /// Level of this layer; 0 is the base layer.
    pub level: usize,
    /// Nodes that live on this level.
    pub nodes: Vec<LayerNode>,
}

/// Full point-in-time snapshot of an HNSW index: metadata plus every layer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HnswSnapshot {
    /// Index metadata / configuration.
    pub meta: HnswMeta,
    /// Graph layers, ordered by level.
    pub layers: Vec<HnswLayer>,
}
111
112// ─── Errors ───────────────────────────────────────────────────────────────────
113
/// Failure modes of the HNSW persistence operations.
#[derive(Debug)]
pub enum PersistError {
    /// An underlying I/O operation failed.
    Io(io::Error),
    /// The first four bytes were not `b"HNSW"`.
    BadMagic,
    /// The header declares a format version this reader does not support.
    UnsupportedVersion(u16),
    /// The stored CRC32 does not match the one computed over the payload.
    ChecksumMismatch,
    /// The input is truncated or structurally invalid; the string says why.
    Malformed(String),
}

impl std::fmt::Display for PersistError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(inner) => write!(f, "IO error: {inner}"),
            Self::BadMagic => f.write_str("bad magic bytes (expected 'HNSW')"),
            Self::UnsupportedVersion(version) => write!(f, "unsupported version: {version}"),
            Self::ChecksumMismatch => f.write_str("CRC32 checksum mismatch"),
            Self::Malformed(detail) => write!(f, "malformed data: {detail}"),
        }
    }
}

impl std::error::Error for PersistError {
    /// Only the `Io` variant wraps another error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}

impl From<io::Error> for PersistError {
    /// Wrap an I/O error so `?` can be used on `std::fs` / `std::io` calls.
    fn from(inner: io::Error) -> Self {
        PersistError::Io(inner)
    }
}
155
156// ─── Format constants ─────────────────────────────────────────────────────────
157
/// File signature stored in the first four bytes of every snapshot.
const MAGIC: &[u8; 4] = &[b'H', b'N', b'S', b'W'];
/// On-disk format revision written by this module; other values are rejected on read.
const FORMAT_VERSION: u16 = 1;
160
161// ─── Low-level read / write helpers ──────────────────────────────────────────
162
/// Append `v` to `buf` as 2 little-endian bytes.
fn write_u16_le(buf: &mut Vec<u8>, v: u16) {
    let bytes: [u8; 2] = v.to_le_bytes();
    buf.extend(bytes);
}

/// Append `v` to `buf` as 4 little-endian bytes.
fn write_u32_le(buf: &mut Vec<u8>, v: u32) {
    let bytes: [u8; 4] = v.to_le_bytes();
    buf.extend(bytes);
}

/// Append `v` to `buf` as 8 little-endian bytes.
fn write_u64_le(buf: &mut Vec<u8>, v: u64) {
    let bytes: [u8; 8] = v.to_le_bytes();
    buf.extend(bytes);
}
174
175fn read_u16_le(data: &[u8], pos: &mut usize) -> Result<u16, PersistError> {
176    if *pos + 2 > data.len() {
177        return Err(PersistError::Malformed("unexpected end (u16)".into()));
178    }
179    let v = u16::from_le_bytes(data[*pos..*pos + 2].try_into().expect("slice of 2"));
180    *pos += 2;
181    Ok(v)
182}
183
184fn read_u32_le(data: &[u8], pos: &mut usize) -> Result<u32, PersistError> {
185    if *pos + 4 > data.len() {
186        return Err(PersistError::Malformed("unexpected end (u32)".into()));
187    }
188    let v = u32::from_le_bytes(data[*pos..*pos + 4].try_into().expect("slice of 4"));
189    *pos += 4;
190    Ok(v)
191}
192
193fn read_u64_le(data: &[u8], pos: &mut usize) -> Result<u64, PersistError> {
194    if *pos + 8 > data.len() {
195        return Err(PersistError::Malformed("unexpected end (u64)".into()));
196    }
197    let v = u64::from_le_bytes(data[*pos..*pos + 8].try_into().expect("slice of 8"));
198    *pos += 8;
199    Ok(v)
200}
201
202// ─── Meta serialization ────────────────────────────────────────────────────────
203
204fn serialize_meta(meta: &HnswMeta) -> Vec<u8> {
205    let mut buf = Vec::new();
206    write_u64_le(&mut buf, meta.max_layer as u64);
207    match meta.entry_point {
208        Some(ep) => {
209            write_u32_le(&mut buf, 1);
210            write_u64_le(&mut buf, ep);
211        }
212        None => {
213            write_u32_le(&mut buf, 0);
214            write_u64_le(&mut buf, 0);
215        }
216    }
217    write_u64_le(&mut buf, meta.ef_construction as u64);
218    write_u64_le(&mut buf, meta.m as u64);
219    write_u64_le(&mut buf, meta.node_count as u64);
220    buf
221}
222
223fn deserialize_meta(data: &[u8], pos: &mut usize) -> Result<HnswMeta, PersistError> {
224    let max_layer = read_u64_le(data, pos)? as usize;
225    let has_ep = read_u32_le(data, pos)?;
226    let ep_raw = read_u64_le(data, pos)?;
227    let entry_point = if has_ep == 1 { Some(ep_raw) } else { None };
228    let ef_construction = read_u64_le(data, pos)? as usize;
229    let m = read_u64_le(data, pos)? as usize;
230    let node_count = read_u64_le(data, pos)? as usize;
231    Ok(HnswMeta {
232        max_layer,
233        entry_point,
234        ef_construction,
235        m,
236        node_count,
237    })
238}
239
240// ─── HnswPersistence ─────────────────────────────────────────────────────────
241
242/// Provides serialization and file persistence for HNSW index snapshots.
243#[derive(Debug, Default, Clone)]
244pub struct HnswPersistence;
245
246impl HnswPersistence {
247    /// Create a new persistence instance.
248    pub fn new() -> Self {
249        Self
250    }
251
252    /// Serialize `meta` + `layers` into a byte vector.
253    ///
254    /// The format is described in the module documentation.
255    pub fn serialize(&self, meta: &HnswMeta, layers: &[HnswLayer]) -> Vec<u8> {
256        // Serialize meta
257        let meta_bytes = serialize_meta(meta);
258
259        // Serialize layers
260        let mut layer_bytes: Vec<u8> = Vec::new();
261        for layer in layers {
262            write_u32_le(&mut layer_bytes, layer.level as u32);
263            write_u32_le(&mut layer_bytes, layer.nodes.len() as u32);
264            for node in &layer.nodes {
265                write_u64_le(&mut layer_bytes, node.id);
266                write_u32_le(&mut layer_bytes, node.neighbors.len() as u32);
267                for &nb in &node.neighbors {
268                    write_u64_le(&mut layer_bytes, nb);
269                }
270            }
271        }
272
273        // Build the payload (meta + layers) for checksumming
274        let mut payload: Vec<u8> = Vec::new();
275        write_u32_le(&mut payload, meta_bytes.len() as u32);
276        payload.extend_from_slice(&meta_bytes);
277        payload.extend_from_slice(&layer_bytes);
278
279        let checksum = crc32(&payload);
280
281        // Build the final buffer
282        let mut buf: Vec<u8> = Vec::with_capacity(10 + payload.len());
283        buf.extend_from_slice(MAGIC);
284        write_u16_le(&mut buf, FORMAT_VERSION);
285        write_u32_le(&mut buf, checksum);
286        buf.extend_from_slice(&payload);
287
288        buf
289    }
290
291    /// Deserialize bytes into `(HnswMeta, Vec<HnswLayer>)`.
292    pub fn deserialize(&self, bytes: &[u8]) -> Result<(HnswMeta, Vec<HnswLayer>), PersistError> {
293        if bytes.len() < 10 {
294            return Err(PersistError::Malformed("too short".into()));
295        }
296
297        // Magic
298        if &bytes[..4] != MAGIC {
299            return Err(PersistError::BadMagic);
300        }
301        let mut pos = 4;
302
303        // Version
304        let version = read_u16_le(bytes, &mut pos)?;
305        if version != FORMAT_VERSION {
306            return Err(PersistError::UnsupportedVersion(version));
307        }
308
309        // Checksum
310        let stored_crc = read_u32_le(bytes, &mut pos)?;
311        let computed_crc = crc32(&bytes[pos..]);
312        if stored_crc != computed_crc {
313            return Err(PersistError::ChecksumMismatch);
314        }
315
316        // Meta
317        let meta_len = read_u32_le(bytes, &mut pos)? as usize;
318        if pos + meta_len > bytes.len() {
319            return Err(PersistError::Malformed("meta_len out of bounds".into()));
320        }
321        let meta = deserialize_meta(bytes, &mut pos)?;
322
323        // Layers
324        let mut layers = Vec::new();
325        while pos < bytes.len() {
326            let level = read_u32_le(bytes, &mut pos)? as usize;
327            let node_count = read_u32_le(bytes, &mut pos)?;
328            let mut nodes = Vec::with_capacity(node_count as usize);
329            for _ in 0..node_count {
330                let id = read_u64_le(bytes, &mut pos)?;
331                let neighbor_count = read_u32_le(bytes, &mut pos)?;
332                let mut neighbors = Vec::with_capacity(neighbor_count as usize);
333                for _ in 0..neighbor_count {
334                    neighbors.push(read_u64_le(bytes, &mut pos)?);
335                }
336                nodes.push(LayerNode { id, neighbors });
337            }
338            layers.push(HnswLayer { level, nodes });
339        }
340
341        Ok((meta, layers))
342    }
343
344    /// Write a snapshot to a file atomically (write → temp, rename).
345    pub fn save_to_file(
346        &self,
347        path: &Path,
348        meta: &HnswMeta,
349        layers: &[HnswLayer],
350    ) -> Result<(), PersistError> {
351        let bytes = self.serialize(meta, layers);
352
353        // Write to a temporary file alongside the destination
354        let mut tmp_path = path.to_path_buf();
355        let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("hnsw");
356        tmp_path.set_file_name(format!(".{file_name}.tmp"));
357
358        {
359            let mut file = std::fs::File::create(&tmp_path)?;
360            file.write_all(&bytes)?;
361            file.flush()?;
362        }
363
364        // Atomic rename
365        std::fs::rename(&tmp_path, path)?;
366
367        Ok(())
368    }
369
370    /// Read a snapshot from a file.
371    pub fn load_from_file(&self, path: &Path) -> Result<(HnswMeta, Vec<HnswLayer>), PersistError> {
372        let bytes = std::fs::read(path)?;
373        self.deserialize(&bytes)
374    }
375
376    /// Validate the CRC32 checksum of serialized bytes.
377    ///
378    /// Returns `true` if the checksum field matches the computed checksum
379    /// of the payload, `false` otherwise.
380    pub fn validate_checksum(&self, bytes: &[u8]) -> bool {
381        if bytes.len() < 10 {
382            return false;
383        }
384        if &bytes[..4] != MAGIC {
385            return false;
386        }
387        let stored_crc = u32::from_le_bytes(bytes[6..10].try_into().expect("4 bytes"));
388        let computed_crc = crc32(&bytes[10..]);
389        stored_crc == computed_crc
390    }
391
392    /// Merge a base snapshot with a delta snapshot.
393    ///
394    /// Merge strategy:
395    /// - `meta` is taken from `delta` (delta always has the newer configuration)
396    /// - For each layer in delta, its nodes override/supplement the base layer nodes
397    ///   (keyed by `node.id`)
398    /// - Layers present only in base are preserved; layers only in delta are added
399    pub fn merge_snapshots(&self, base: &HnswSnapshot, delta: &HnswSnapshot) -> HnswSnapshot {
400        // Index base layers by level
401        let mut layer_map: HashMap<usize, HashMap<u64, LayerNode>> = HashMap::new();
402        for layer in &base.layers {
403            let node_map: HashMap<u64, LayerNode> =
404                layer.nodes.iter().map(|n| (n.id, n.clone())).collect();
405            layer_map.insert(layer.level, node_map);
406        }
407
408        // Apply delta layers
409        for layer in &delta.layers {
410            let node_map = layer_map.entry(layer.level).or_default();
411            for node in &layer.nodes {
412                node_map.insert(node.id, node.clone());
413            }
414        }
415
416        // Reconstruct ordered layers
417        let mut levels: Vec<usize> = layer_map.keys().copied().collect();
418        levels.sort();
419
420        let merged_layers: Vec<HnswLayer> = levels
421            .into_iter()
422            .map(|level| {
423                let node_map = &layer_map[&level];
424                let mut nodes: Vec<LayerNode> = node_map.values().cloned().collect();
425                nodes.sort_by_key(|n| n.id);
426                HnswLayer { level, nodes }
427            })
428            .collect();
429
430        HnswSnapshot {
431            meta: delta.meta.clone(),
432            layers: merged_layers,
433        }
434    }
435}
436
437// ─── Tests ──────────────────────────────────────────────────────────────────
438
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;
    use std::path::PathBuf;

    fn sample_meta() -> HnswMeta {
        HnswMeta {
            max_layer: 3,
            entry_point: Some(42),
            ef_construction: 200,
            m: 16,
            node_count: 1000,
        }
    }

    fn sample_layer(level: usize, node_ids: &[u64]) -> HnswLayer {
        let nodes: Vec<LayerNode> = node_ids
            .iter()
            .map(|&id| LayerNode {
                id,
                neighbors: node_ids
                    .iter()
                    .copied()
                    .filter(|&n| n != id)
                    .take(4)
                    .collect(),
            })
            .collect();
        HnswLayer { level, nodes }
    }

    fn persist() -> HnswPersistence {
        HnswPersistence::new()
    }

    /// Path for a test artefact in the system temp dir.
    ///
    /// The process id is part of the name so that concurrently running test
    /// binaries (parallel CI jobs, repeated `cargo test` runs on one machine)
    /// never read, overwrite, or delete each other's files.
    fn tmp_path(name: &str) -> PathBuf {
        temp_dir().join(format!("hnsw_test_{}_{name}.bin", std::process::id()))
    }

    // ── CRC32 ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_crc32_empty() {
        let c = crc32(&[]);
        // Standard CRC32 of empty input = 0x00000000
        assert_eq!(c, 0x0000_0000);
    }

    #[test]
    fn test_crc32_hello() {
        // CRC32 of "hello" = 0x3610a686
        let c = crc32(b"hello");
        assert_eq!(c, 0x3610_a686);
    }

    #[test]
    fn test_crc32_deterministic() {
        let data = b"deterministic test data";
        assert_eq!(crc32(data), crc32(data));
    }

    #[test]
    fn test_crc32_different_data() {
        assert_ne!(crc32(b"foo"), crc32(b"bar"));
    }

    // ── HnswMeta ─────────────────────────────────────────────────────────────

    #[test]
    fn test_meta_roundtrip() {
        let meta = sample_meta();
        let bytes = serialize_meta(&meta);
        let mut pos = 0;
        let decoded = deserialize_meta(&bytes, &mut pos).expect("ok");
        assert_eq!(decoded, meta);
    }

    #[test]
    fn test_meta_no_entry_point() {
        let meta = HnswMeta {
            max_layer: 0,
            entry_point: None,
            ef_construction: 100,
            m: 8,
            node_count: 0,
        };
        let bytes = serialize_meta(&meta);
        let mut pos = 0;
        let decoded = deserialize_meta(&bytes, &mut pos).expect("ok");
        assert_eq!(decoded.entry_point, None);
    }

    // ── Serialize / Deserialize ───────────────────────────────────────────────

    #[test]
    fn test_serialize_deserialize_empty() {
        let meta = sample_meta();
        let layers: Vec<HnswLayer> = vec![];
        let bytes = persist().serialize(&meta, &layers);
        let (decoded_meta, decoded_layers) = persist().deserialize(&bytes).expect("ok");
        assert_eq!(decoded_meta, meta);
        assert!(decoded_layers.is_empty());
    }

    #[test]
    fn test_serialize_deserialize_single_layer() {
        let meta = sample_meta();
        let layer = sample_layer(0, &[1, 2, 3, 4, 5]);
        let bytes = persist().serialize(&meta, std::slice::from_ref(&layer));
        let (_, layers) = persist().deserialize(&bytes).expect("ok");
        assert_eq!(layers.len(), 1);
        assert_eq!(layers[0].level, 0);
        assert_eq!(layers[0].nodes.len(), 5);
    }

    #[test]
    fn test_serialize_deserialize_multiple_layers() {
        let meta = sample_meta();
        let l0 = sample_layer(0, &[1, 2, 3]);
        let l1 = sample_layer(1, &[1, 2]);
        let l2 = sample_layer(2, &[1]);
        let bytes = persist().serialize(&meta, &[l0, l1, l2]);
        let (_, layers) = persist().deserialize(&bytes).expect("ok");
        assert_eq!(layers.len(), 3);
        assert_eq!(layers[0].level, 0);
        assert_eq!(layers[1].level, 1);
        assert_eq!(layers[2].level, 2);
    }

    #[test]
    fn test_serialize_deserialize_neighbors() {
        let meta = sample_meta();
        let layer = HnswLayer {
            level: 0,
            nodes: vec![LayerNode {
                id: 42,
                neighbors: vec![1, 2, 3],
            }],
        };
        let bytes = persist().serialize(&meta, &[layer]);
        let (_, layers) = persist().deserialize(&bytes).expect("ok");
        assert_eq!(layers[0].nodes[0].id, 42);
        assert_eq!(layers[0].nodes[0].neighbors, vec![1, 2, 3]);
    }

    #[test]
    fn test_serialize_magic() {
        let meta = sample_meta();
        let bytes = persist().serialize(&meta, &[]);
        assert_eq!(&bytes[..4], b"HNSW");
    }

    #[test]
    fn test_serialize_version() {
        let meta = sample_meta();
        let bytes = persist().serialize(&meta, &[]);
        let version = u16::from_le_bytes(bytes[4..6].try_into().expect("2 bytes"));
        assert_eq!(version, FORMAT_VERSION);
    }

    #[test]
    fn test_deserialize_bad_magic() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        bytes[0] = b'X';
        let result = persist().deserialize(&bytes);
        assert!(matches!(result, Err(PersistError::BadMagic)));
    }

    #[test]
    fn test_deserialize_checksum_mismatch() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        let last = bytes.len() - 1;
        bytes[last] ^= 0xFF; // corrupt last byte
        let result = persist().deserialize(&bytes);
        assert!(matches!(result, Err(PersistError::ChecksumMismatch)));
    }

    #[test]
    fn test_deserialize_too_short() {
        let result = persist().deserialize(&[0u8; 5]);
        assert!(matches!(result, Err(PersistError::Malformed(_))));
    }

    #[test]
    fn test_deserialize_unsupported_version() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        // Overwrite version field with 99
        bytes[4] = 99;
        bytes[5] = 0;
        // Recompute checksum
        let checksum = crc32(&bytes[10..]);
        bytes[6..10].copy_from_slice(&checksum.to_le_bytes());
        let result = persist().deserialize(&bytes);
        assert!(matches!(result, Err(PersistError::UnsupportedVersion(99))));
    }

    // ── validate_checksum ─────────────────────────────────────────────────────

    #[test]
    fn test_validate_checksum_valid() {
        let bytes = persist().serialize(&sample_meta(), &[]);
        assert!(persist().validate_checksum(&bytes));
    }

    #[test]
    fn test_validate_checksum_corrupted() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        assert!(!persist().validate_checksum(&bytes));
    }

    #[test]
    fn test_validate_checksum_too_short() {
        assert!(!persist().validate_checksum(&[0u8; 5]));
    }

    #[test]
    fn test_validate_checksum_bad_magic() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        bytes[0] = b'X';
        assert!(!persist().validate_checksum(&bytes));
    }

    // ── File I/O ──────────────────────────────────────────────────────────────

    #[test]
    fn test_save_and_load() {
        let path = tmp_path("save_load");
        let meta = sample_meta();
        let layer = sample_layer(0, &[10, 20, 30]);

        persist()
            .save_to_file(&path, &meta, std::slice::from_ref(&layer))
            .expect("save ok");

        let (loaded_meta, loaded_layers) = persist().load_from_file(&path).expect("load ok");

        assert_eq!(loaded_meta, meta);
        assert_eq!(loaded_layers.len(), 1);
        assert_eq!(loaded_layers[0].nodes.len(), layer.nodes.len());

        // Cleanup
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_save_atomic_no_tmp_left() {
        let path = tmp_path("atomic_no_tmp");
        persist()
            .save_to_file(&path, &sample_meta(), &[])
            .expect("save ok");

        // Temporary file should not exist
        let mut tmp = path.clone();
        let name = path.file_name().unwrap().to_str().unwrap();
        tmp.set_file_name(format!(".{name}.tmp"));
        assert!(!tmp.exists(), "temp file should have been cleaned up");

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_load_nonexistent_file() {
        let path = tmp_path("nonexistent_xyz_12345");
        let result = persist().load_from_file(&path);
        assert!(matches!(result, Err(PersistError::Io(_))));
    }

    // ── merge_snapshots ────────────────────────────────────────────────────────

    #[test]
    fn test_merge_empty_delta() {
        let base = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![sample_layer(0, &[1, 2, 3])],
        };
        let delta = HnswSnapshot {
            meta: HnswMeta {
                max_layer: 3,
                entry_point: Some(1),
                ef_construction: 200,
                m: 16,
                node_count: 4,
            },
            layers: vec![],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        // Meta from delta
        assert_eq!(merged.meta.node_count, 4);
        // Layer from base preserved
        assert_eq!(merged.layers.len(), 1);
        assert_eq!(merged.layers[0].nodes.len(), 3);
    }

    #[test]
    fn test_merge_delta_adds_nodes() {
        let base = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![sample_layer(0, &[1, 2, 3])],
        };
        let delta = HnswSnapshot {
            meta: HnswMeta {
                node_count: 5,
                ..sample_meta()
            },
            layers: vec![HnswLayer {
                level: 0,
                nodes: vec![
                    LayerNode {
                        id: 4,
                        neighbors: vec![1, 2],
                    },
                    LayerNode {
                        id: 5,
                        neighbors: vec![3],
                    },
                ],
            }],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        let l0 = &merged.layers[0];
        let ids: Vec<u64> = l0.nodes.iter().map(|n| n.id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&4));
        assert!(ids.contains(&5));
        assert_eq!(ids.len(), 5);
    }

    #[test]
    fn test_merge_delta_overrides_node() {
        let base = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![HnswLayer {
                level: 0,
                nodes: vec![LayerNode {
                    id: 1,
                    neighbors: vec![2, 3],
                }],
            }],
        };
        let delta = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![HnswLayer {
                level: 0,
                nodes: vec![LayerNode {
                    id: 1,
                    neighbors: vec![4, 5, 6],
                }], // new neighbors
            }],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        let node1 = merged.layers[0]
            .nodes
            .iter()
            .find(|n| n.id == 1)
            .expect("node 1");
        assert_eq!(node1.neighbors, vec![4, 5, 6]);
    }

    #[test]
    fn test_merge_adds_new_layer() {
        let base = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![sample_layer(0, &[1, 2])],
        };
        let delta = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![sample_layer(1, &[1])],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        assert_eq!(merged.layers.len(), 2);
        let levels: Vec<usize> = merged.layers.iter().map(|l| l.level).collect();
        assert!(levels.contains(&0));
        assert!(levels.contains(&1));
    }

    #[test]
    fn test_merge_meta_from_delta() {
        let base = HnswSnapshot {
            meta: HnswMeta {
                max_layer: 0,
                entry_point: Some(1),
                ef_construction: 100,
                m: 8,
                node_count: 10,
            },
            layers: vec![],
        };
        let delta = HnswSnapshot {
            meta: HnswMeta {
                max_layer: 2,
                entry_point: Some(99),
                ef_construction: 200,
                m: 16,
                node_count: 500,
            },
            layers: vec![],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        assert_eq!(merged.meta.node_count, 500);
        assert_eq!(merged.meta.entry_point, Some(99));
    }

    // ── Large-scale roundtrip ─────────────────────────────────────────────────

    #[test]
    fn test_large_roundtrip() {
        let meta = HnswMeta {
            max_layer: 4,
            entry_point: Some(0),
            ef_construction: 400,
            m: 32,
            node_count: 10_000,
        };
        let mut layers = Vec::new();
        for level in 0..5 {
            let node_count = 1000 >> level;
            let ids: Vec<u64> = (0..node_count).map(|i| i as u64).collect();
            layers.push(sample_layer(level, &ids));
        }
        let bytes = persist().serialize(&meta, &layers);
        assert!(persist().validate_checksum(&bytes));
        let (decoded_meta, decoded_layers) = persist().deserialize(&bytes).expect("ok");
        assert_eq!(decoded_meta.m, 32);
        assert_eq!(decoded_layers.len(), 5);
    }
}