Skip to main content

oxirs_vec/
hnsw_persistence.rs

1//! HNSW Index Serialization / Persistence
2//!
3//! This module provides binary serialization and file-based persistence for
4//! Hierarchical Navigable Small World (HNSW) index snapshots.
5//!
6//! # Format
7//!
8//! The binary format is a custom little-endian encoding (no bincode):
9//!
10//! ```text
11//! Magic:     4 bytes  = b"HNSW"
12//! Version:   2 bytes  = u16 LE
13//! CRC32:     4 bytes  = u32 LE (checksum of all following bytes)
14//! meta_len:  4 bytes  = u32 LE (length of serialized HnswMeta)
15//! meta_data: N bytes
16//! layers:    zero or more layer records:
17//!   layer_level: 4 bytes = u32 LE
18//!   node_count:  4 bytes = u32 LE
19//!   nodes:       node_count × LayerNode records:
20//!     node_id:        8 bytes = u64 LE
21//!     neighbor_count: 4 bytes = u32 LE
22//!     neighbors:      neighbor_count × 8 bytes (u64 LE each)
23//! ```
24//!
25//! # Atomic file writes
26//!
//! `save_to_file` first writes the snapshot to a hidden sibling file named
//! `.<file_name>.tmp` and then renames it over the destination, so the
//! destination path never refers to a partially written file.
29
30use std::collections::HashMap;
31use std::io::{self, Write};
32use std::path::Path;
33
34// ─── CRC32 implementation ─────────────────────────────────────────────────────
35//
36// Simple table-driven CRC32 (ISO 3309 / IEEE 802.3 polynomial 0xEDB88320).
37
/// Reflected CRC32 generator polynomial used to derive the look-up table.
const CRC32_POLY: u32 = 0xEDB8_8320;

/// Build the 256-entry CRC32 look-up table.
///
/// Entry `i` is the result of pushing the byte value `i` through eight
/// shift/conditional-XOR steps with [`CRC32_POLY`]. The function is `const`
/// so the table can be computed at compile time.
const fn build_crc32_table() -> [u32; 256] {
    let mut table = [0u32; 256];
    // `for` loops are not available in `const fn`, hence the manual counters.
    let mut i = 0usize;
    while i < 256 {
        let mut c = i as u32;
        let mut bit = 0;
        while bit < 8 {
            c = if c & 1 != 0 {
                CRC32_POLY ^ (c >> 1)
            } else {
                c >> 1
            };
            bit += 1;
        }
        table[i] = c;
        i += 1;
    }
    table
}

/// CRC32 look-up table, evaluated once at compile time instead of on every
/// [`crc32`] call.
static CRC32_TABLE: [u32; 256] = build_crc32_table();

/// Compute the CRC32 checksum of a byte slice.
///
/// The register starts at `0xFFFF_FFFF`, is updated one byte at a time via the
/// look-up table, and is inverted at the end. An empty slice yields `0`.
pub fn crc32(data: &[u8]) -> u32 {
    let crc = data.iter().fold(0xFFFF_FFFFu32, |crc, &byte| {
        let idx = ((crc ^ u32::from(byte)) & 0xFF) as usize;
        CRC32_TABLE[idx] ^ (crc >> 8)
    });
    crc ^ 0xFFFF_FFFF
}
67
68// ─── Domain types ─────────────────────────────────────────────────────────────
69
/// Configuration and summary values stored alongside an HNSW graph.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HnswMeta {
    /// Index of the highest layer; layers are numbered from 0 (the base layer).
    pub max_layer: usize,
    /// ID of the entry-point node, or `None` when the index is empty.
    pub entry_point: Option<u64>,
    /// The `ef_construction` parameter.
    pub ef_construction: usize,
    /// The `M` parameter (number of bidirectional links per node).
    pub m: usize,
    /// Number of distinct nodes across all layers.
    pub node_count: usize,
}
84
/// One node as it appears in a single HNSW layer, together with its neighbour list.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LayerNode {
    /// Unique identifier of the node.
    pub id: u64,
    /// IDs of this node's neighbours in the layer (up to `M` per layer).
    pub neighbors: Vec<u64>,
}
93
94/// One layer of the HNSW graph.
95#[derive(Debug, Clone, PartialEq, Eq)]
96pub struct HnswLayer {
97    /// Layer level (0 = base)
98    pub level: usize,
99    /// Nodes in this layer
100    pub nodes: Vec<LayerNode>,
101}
102
103/// A complete snapshot of an HNSW index.
104#[derive(Debug, Clone, PartialEq, Eq)]
105pub struct HnswSnapshot {
106    /// Index metadata
107    pub meta: HnswMeta,
108    /// Layer data (ordered by level)
109    pub layers: Vec<HnswLayer>,
110}
111
112// ─── Errors ───────────────────────────────────────────────────────────────────
113
/// Failure modes of HNSW snapshot encoding, decoding and file persistence.
#[derive(Debug)]
pub enum PersistError {
    /// An underlying I/O operation failed; the original error is kept as the source.
    Io(io::Error),
    /// The input does not start with the expected magic bytes.
    BadMagic,
    /// The format version found in the input (carried in the variant) is not supported.
    UnsupportedVersion(u16),
    /// The stored CRC32 checksum does not match the computed one.
    ChecksumMismatch,
    /// The input is truncated or otherwise malformed; the string describes the problem.
    Malformed(String),
}

impl std::fmt::Display for PersistError {
    /// Render a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::BadMagic => f.write_str("bad magic bytes (expected 'HNSW')"),
            Self::ChecksumMismatch => f.write_str("CRC32 checksum mismatch"),
            Self::Io(err) => write!(f, "IO error: {err}"),
            Self::UnsupportedVersion(version) => write!(f, "unsupported version: {version}"),
            Self::Malformed(detail) => write!(f, "malformed data: {detail}"),
        }
    }
}

impl std::error::Error for PersistError {
    /// Expose the wrapped I/O error, if any, as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::BadMagic
            | Self::UnsupportedVersion(_)
            | Self::ChecksumMismatch
            | Self::Malformed(_) => None,
        }
    }
}

impl From<io::Error> for PersistError {
    /// Wrap an I/O error so that `?` can be used on `io::Result` values.
    fn from(err: io::Error) -> Self {
        PersistError::Io(err)
    }
}
155
156// ─── Format constants ─────────────────────────────────────────────────────────
157
/// Four magic bytes that open every serialized snapshot.
const MAGIC: &[u8; 4] = b"HNSW";

/// Format version written by `serialize`; `deserialize` rejects any other value.
const FORMAT_VERSION: u16 = 1;
160
161// ─── Low-level read / write helpers ──────────────────────────────────────────
162
/// Append `v` to `buf` as two little-endian bytes.
fn write_u16_le(buf: &mut Vec<u8>, v: u16) {
    buf.extend(v.to_le_bytes());
}
166
/// Append `v` to `buf` as four little-endian bytes.
fn write_u32_le(buf: &mut Vec<u8>, v: u32) {
    buf.extend(v.to_le_bytes());
}
170
/// Append `v` to `buf` as eight little-endian bytes.
fn write_u64_le(buf: &mut Vec<u8>, v: u64) {
    buf.extend(v.to_le_bytes());
}
174
175fn read_u16_le(data: &[u8], pos: &mut usize) -> Result<u16, PersistError> {
176    if *pos + 2 > data.len() {
177        return Err(PersistError::Malformed("unexpected end (u16)".into()));
178    }
179    let v = u16::from_le_bytes(data[*pos..*pos + 2].try_into().expect("slice of 2"));
180    *pos += 2;
181    Ok(v)
182}
183
184fn read_u32_le(data: &[u8], pos: &mut usize) -> Result<u32, PersistError> {
185    if *pos + 4 > data.len() {
186        return Err(PersistError::Malformed("unexpected end (u32)".into()));
187    }
188    let v = u32::from_le_bytes(data[*pos..*pos + 4].try_into().expect("slice of 4"));
189    *pos += 4;
190    Ok(v)
191}
192
193fn read_u64_le(data: &[u8], pos: &mut usize) -> Result<u64, PersistError> {
194    if *pos + 8 > data.len() {
195        return Err(PersistError::Malformed("unexpected end (u64)".into()));
196    }
197    let v = u64::from_le_bytes(data[*pos..*pos + 8].try_into().expect("slice of 8"));
198    *pos += 8;
199    Ok(v)
200}
201
202// ─── Meta serialization ────────────────────────────────────────────────────────
203
204fn serialize_meta(meta: &HnswMeta) -> Vec<u8> {
205    let mut buf = Vec::new();
206    write_u64_le(&mut buf, meta.max_layer as u64);
207    match meta.entry_point {
208        Some(ep) => {
209            write_u32_le(&mut buf, 1);
210            write_u64_le(&mut buf, ep);
211        }
212        None => {
213            write_u32_le(&mut buf, 0);
214            write_u64_le(&mut buf, 0);
215        }
216    }
217    write_u64_le(&mut buf, meta.ef_construction as u64);
218    write_u64_le(&mut buf, meta.m as u64);
219    write_u64_le(&mut buf, meta.node_count as u64);
220    buf
221}
222
223fn deserialize_meta(data: &[u8], pos: &mut usize) -> Result<HnswMeta, PersistError> {
224    let max_layer = read_u64_le(data, pos)? as usize;
225    let has_ep = read_u32_le(data, pos)?;
226    let ep_raw = read_u64_le(data, pos)?;
227    let entry_point = if has_ep == 1 { Some(ep_raw) } else { None };
228    let ef_construction = read_u64_le(data, pos)? as usize;
229    let m = read_u64_le(data, pos)? as usize;
230    let node_count = read_u64_le(data, pos)? as usize;
231    Ok(HnswMeta {
232        max_layer,
233        entry_point,
234        ef_construction,
235        m,
236        node_count,
237    })
238}
239
240// ─── HnswPersistence ─────────────────────────────────────────────────────────
241
/// Stateless entry point for serializing HNSW index snapshots and persisting
/// them to files.
#[derive(Clone, Debug, Default)]
pub struct HnswPersistence;
245
246impl HnswPersistence {
247    /// Create a new persistence instance.
248    pub fn new() -> Self {
249        Self
250    }
251
252    /// Serialize `meta` + `layers` into a byte vector.
253    ///
254    /// The format is described in the module documentation.
255    pub fn serialize(&self, meta: &HnswMeta, layers: &[HnswLayer]) -> Vec<u8> {
256        // Serialize meta
257        let meta_bytes = serialize_meta(meta);
258
259        // Serialize layers
260        let mut layer_bytes: Vec<u8> = Vec::new();
261        for layer in layers {
262            write_u32_le(&mut layer_bytes, layer.level as u32);
263            write_u32_le(&mut layer_bytes, layer.nodes.len() as u32);
264            for node in &layer.nodes {
265                write_u64_le(&mut layer_bytes, node.id);
266                write_u32_le(&mut layer_bytes, node.neighbors.len() as u32);
267                for &nb in &node.neighbors {
268                    write_u64_le(&mut layer_bytes, nb);
269                }
270            }
271        }
272
273        // Build the payload (meta + layers) for checksumming
274        let mut payload: Vec<u8> = Vec::new();
275        write_u32_le(&mut payload, meta_bytes.len() as u32);
276        payload.extend_from_slice(&meta_bytes);
277        payload.extend_from_slice(&layer_bytes);
278
279        let checksum = crc32(&payload);
280
281        // Build the final buffer
282        let mut buf: Vec<u8> = Vec::with_capacity(10 + payload.len());
283        buf.extend_from_slice(MAGIC);
284        write_u16_le(&mut buf, FORMAT_VERSION);
285        write_u32_le(&mut buf, checksum);
286        buf.extend_from_slice(&payload);
287
288        buf
289    }
290
291    /// Deserialize bytes into `(HnswMeta, Vec<HnswLayer>)`.
292    pub fn deserialize(&self, bytes: &[u8]) -> Result<(HnswMeta, Vec<HnswLayer>), PersistError> {
293        if bytes.len() < 10 {
294            return Err(PersistError::Malformed("too short".into()));
295        }
296
297        // Magic
298        if &bytes[..4] != MAGIC {
299            return Err(PersistError::BadMagic);
300        }
301        let mut pos = 4;
302
303        // Version
304        let version = read_u16_le(bytes, &mut pos)?;
305        if version != FORMAT_VERSION {
306            return Err(PersistError::UnsupportedVersion(version));
307        }
308
309        // Checksum
310        let stored_crc = read_u32_le(bytes, &mut pos)?;
311        let computed_crc = crc32(&bytes[pos..]);
312        if stored_crc != computed_crc {
313            return Err(PersistError::ChecksumMismatch);
314        }
315
316        // Meta
317        let meta_len = read_u32_le(bytes, &mut pos)? as usize;
318        if pos + meta_len > bytes.len() {
319            return Err(PersistError::Malformed("meta_len out of bounds".into()));
320        }
321        let meta = deserialize_meta(bytes, &mut pos)?;
322
323        // Layers
324        let mut layers = Vec::new();
325        while pos < bytes.len() {
326            let level = read_u32_le(bytes, &mut pos)? as usize;
327            let node_count = read_u32_le(bytes, &mut pos)?;
328            let mut nodes = Vec::with_capacity(node_count as usize);
329            for _ in 0..node_count {
330                let id = read_u64_le(bytes, &mut pos)?;
331                let neighbor_count = read_u32_le(bytes, &mut pos)?;
332                let mut neighbors = Vec::with_capacity(neighbor_count as usize);
333                for _ in 0..neighbor_count {
334                    neighbors.push(read_u64_le(bytes, &mut pos)?);
335                }
336                nodes.push(LayerNode { id, neighbors });
337            }
338            layers.push(HnswLayer { level, nodes });
339        }
340
341        Ok((meta, layers))
342    }
343
344    /// Write a snapshot to a file atomically (write → temp, rename).
345    pub fn save_to_file(
346        &self,
347        path: &Path,
348        meta: &HnswMeta,
349        layers: &[HnswLayer],
350    ) -> Result<(), PersistError> {
351        let bytes = self.serialize(meta, layers);
352
353        // Write to a temporary file alongside the destination
354        let mut tmp_path = path.to_path_buf();
355        let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("hnsw");
356        tmp_path.set_file_name(format!(".{file_name}.tmp"));
357
358        {
359            let mut file = std::fs::File::create(&tmp_path)?;
360            file.write_all(&bytes)?;
361            file.flush()?;
362        }
363
364        // Atomic rename
365        std::fs::rename(&tmp_path, path)?;
366
367        Ok(())
368    }
369
370    /// Read a snapshot from a file.
371    pub fn load_from_file(&self, path: &Path) -> Result<(HnswMeta, Vec<HnswLayer>), PersistError> {
372        let bytes = std::fs::read(path)?;
373        self.deserialize(&bytes)
374    }
375
376    /// Validate the CRC32 checksum of serialized bytes.
377    ///
378    /// Returns `true` if the checksum field matches the computed checksum
379    /// of the payload, `false` otherwise.
380    pub fn validate_checksum(&self, bytes: &[u8]) -> bool {
381        if bytes.len() < 10 {
382            return false;
383        }
384        if &bytes[..4] != MAGIC {
385            return false;
386        }
387        let stored_crc = u32::from_le_bytes(bytes[6..10].try_into().expect("4 bytes"));
388        let computed_crc = crc32(&bytes[10..]);
389        stored_crc == computed_crc
390    }
391
392    /// Merge a base snapshot with a delta snapshot.
393    ///
394    /// Merge strategy:
395    /// - `meta` is taken from `delta` (delta always has the newer configuration)
396    /// - For each layer in delta, its nodes override/supplement the base layer nodes
397    ///   (keyed by `node.id`)
398    /// - Layers present only in base are preserved; layers only in delta are added
399    pub fn merge_snapshots(&self, base: &HnswSnapshot, delta: &HnswSnapshot) -> HnswSnapshot {
400        // Index base layers by level
401        let mut layer_map: HashMap<usize, HashMap<u64, LayerNode>> = HashMap::new();
402        for layer in &base.layers {
403            let node_map: HashMap<u64, LayerNode> =
404                layer.nodes.iter().map(|n| (n.id, n.clone())).collect();
405            layer_map.insert(layer.level, node_map);
406        }
407
408        // Apply delta layers
409        for layer in &delta.layers {
410            let node_map = layer_map.entry(layer.level).or_default();
411            for node in &layer.nodes {
412                node_map.insert(node.id, node.clone());
413            }
414        }
415
416        // Reconstruct ordered layers
417        let mut levels: Vec<usize> = layer_map.keys().copied().collect();
418        levels.sort();
419
420        let merged_layers: Vec<HnswLayer> = levels
421            .into_iter()
422            .map(|level| {
423                let node_map = &layer_map[&level];
424                let mut nodes: Vec<LayerNode> = node_map.values().cloned().collect();
425                nodes.sort_by_key(|n| n.id);
426                HnswLayer { level, nodes }
427            })
428            .collect();
429
430        HnswSnapshot {
431            meta: delta.meta.clone(),
432            layers: merged_layers,
433        }
434    }
435}
436
437// ─── Tests ──────────────────────────────────────────────────────────────────
438
#[cfg(test)]
mod tests {
    // Unit tests for the CRC32 routine, meta encoding, snapshot
    // (de)serialization, file persistence and snapshot merging.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
    use super::*;
    use std::env::temp_dir;
    use std::path::PathBuf;

    /// Metadata fixture shared by most tests.
    fn sample_meta() -> HnswMeta {
        HnswMeta {
            max_layer: 3,
            entry_point: Some(42),
            ef_construction: 200,
            m: 16,
            node_count: 1000,
        }
    }

    /// Build a layer in which every node links to (at most) the first four
    /// other IDs of `node_ids`.
    fn sample_layer(level: usize, node_ids: &[u64]) -> HnswLayer {
        let nodes: Vec<LayerNode> = node_ids
            .iter()
            .map(|&id| LayerNode {
                id,
                neighbors: node_ids
                    .iter()
                    .copied()
                    .filter(|&n| n != id)
                    .take(4)
                    .collect(),
            })
            .collect();
        HnswLayer { level, nodes }
    }

    fn persist() -> HnswPersistence {
        HnswPersistence::new()
    }

    /// Path in the system temp directory. The process ID keeps concurrently
    /// running test binaries from touching each other's files.
    fn tmp_path(name: &str) -> PathBuf {
        temp_dir().join(format!("hnsw_test_{name}_{}.bin", std::process::id()))
    }

    // ── CRC32 ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_crc32_empty() {
        // Standard CRC32 of empty input = 0x00000000
        assert_eq!(crc32(&[]), 0x0000_0000);
    }

    #[test]
    fn test_crc32_hello() {
        // CRC32 of "hello" = 0x3610a686
        assert_eq!(crc32(b"hello"), 0x3610_a686);
    }

    #[test]
    fn test_crc32_check_value() {
        // Standard CRC-32 check value for the ASCII digits "123456789".
        assert_eq!(crc32(b"123456789"), 0xCBF4_3926);
    }

    #[test]
    fn test_crc32_deterministic() {
        let data = b"deterministic test data";
        assert_eq!(crc32(data), crc32(data));
    }

    #[test]
    fn test_crc32_different_data() {
        assert_ne!(crc32(b"foo"), crc32(b"bar"));
    }

    // ── HnswMeta ─────────────────────────────────────────────────────────────

    #[test]
    fn test_meta_roundtrip() {
        let meta = sample_meta();
        let bytes = serialize_meta(&meta);
        let mut pos = 0;
        let decoded = deserialize_meta(&bytes, &mut pos).expect("ok");
        assert_eq!(decoded, meta);
        // The decoder must consume the record exactly.
        assert_eq!(pos, bytes.len());
    }

    #[test]
    fn test_meta_no_entry_point() {
        let meta = HnswMeta {
            max_layer: 0,
            entry_point: None,
            ef_construction: 100,
            m: 8,
            node_count: 0,
        };
        let bytes = serialize_meta(&meta);
        let mut pos = 0;
        let decoded = deserialize_meta(&bytes, &mut pos).expect("ok");
        assert_eq!(decoded.entry_point, None);
        assert_eq!(decoded, meta);
    }

    // ── Serialize / Deserialize ───────────────────────────────────────────────

    #[test]
    fn test_serialize_deserialize_empty() {
        let meta = sample_meta();
        let bytes = persist().serialize(&meta, &[]);
        let (decoded_meta, decoded_layers) = persist().deserialize(&bytes).expect("ok");
        assert_eq!(decoded_meta, meta);
        assert!(decoded_layers.is_empty());
    }

    #[test]
    fn test_serialize_deserialize_single_layer() {
        let meta = sample_meta();
        let layer = sample_layer(0, &[1, 2, 3, 4, 5]);
        let bytes = persist().serialize(&meta, std::slice::from_ref(&layer));
        let (_, layers) = persist().deserialize(&bytes).expect("ok");
        assert_eq!(layers.len(), 1);
        assert_eq!(layers[0].level, 0);
        assert_eq!(layers[0].nodes.len(), 5);
        // Full structural equality, not just counts.
        assert_eq!(layers, vec![layer]);
    }

    #[test]
    fn test_serialize_deserialize_multiple_layers() {
        let meta = sample_meta();
        let original = vec![
            sample_layer(0, &[1, 2, 3]),
            sample_layer(1, &[1, 2]),
            sample_layer(2, &[1]),
        ];
        let bytes = persist().serialize(&meta, &original);
        let (_, layers) = persist().deserialize(&bytes).expect("ok");
        let levels: Vec<usize> = layers.iter().map(|l| l.level).collect();
        assert_eq!(levels, vec![0, 1, 2]);
        assert_eq!(layers, original);
    }

    #[test]
    fn test_serialize_deserialize_neighbors() {
        let meta = sample_meta();
        let layer = HnswLayer {
            level: 0,
            nodes: vec![LayerNode {
                id: 42,
                neighbors: vec![1, 2, 3],
            }],
        };
        let bytes = persist().serialize(&meta, &[layer]);
        let (_, layers) = persist().deserialize(&bytes).expect("ok");
        assert_eq!(layers[0].nodes[0].id, 42);
        assert_eq!(layers[0].nodes[0].neighbors, vec![1, 2, 3]);
    }

    #[test]
    fn test_serialize_magic() {
        let bytes = persist().serialize(&sample_meta(), &[]);
        assert_eq!(&bytes[..4], b"HNSW");
    }

    #[test]
    fn test_serialize_version() {
        let bytes = persist().serialize(&sample_meta(), &[]);
        let version = u16::from_le_bytes([bytes[4], bytes[5]]);
        assert_eq!(version, FORMAT_VERSION);
    }

    #[test]
    fn test_deserialize_bad_magic() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        bytes[0] = b'X';
        let result = persist().deserialize(&bytes);
        assert!(matches!(result, Err(PersistError::BadMagic)));
    }

    #[test]
    fn test_deserialize_checksum_mismatch() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        let last = bytes.len() - 1;
        bytes[last] ^= 0xFF; // corrupt last byte
        let result = persist().deserialize(&bytes);
        assert!(matches!(result, Err(PersistError::ChecksumMismatch)));
    }

    #[test]
    fn test_deserialize_too_short() {
        let result = persist().deserialize(&[0u8; 5]);
        assert!(matches!(result, Err(PersistError::Malformed(_))));
    }

    #[test]
    fn test_deserialize_unsupported_version() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        // Overwrite the version field with 99. The CRC only covers the bytes
        // after the header, so the stored checksum stays valid.
        bytes[4] = 99;
        bytes[5] = 0;
        assert!(persist().validate_checksum(&bytes));
        let result = persist().deserialize(&bytes);
        assert!(matches!(result, Err(PersistError::UnsupportedVersion(99))));
    }

    // ── validate_checksum ─────────────────────────────────────────────────────

    #[test]
    fn test_validate_checksum_valid() {
        let bytes = persist().serialize(&sample_meta(), &[]);
        assert!(persist().validate_checksum(&bytes));
    }

    #[test]
    fn test_validate_checksum_corrupted() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        assert!(!persist().validate_checksum(&bytes));
    }

    #[test]
    fn test_validate_checksum_too_short() {
        assert!(!persist().validate_checksum(&[0u8; 5]));
    }

    #[test]
    fn test_validate_checksum_bad_magic() {
        let mut bytes = persist().serialize(&sample_meta(), &[]);
        bytes[0] = b'X';
        assert!(!persist().validate_checksum(&bytes));
    }

    // ── File I/O ──────────────────────────────────────────────────────────────

    #[test]
    fn test_save_and_load() {
        let path = tmp_path("save_load");
        let meta = sample_meta();
        let layer = sample_layer(0, &[10, 20, 30]);

        let saved = persist().save_to_file(&path, &meta, std::slice::from_ref(&layer));
        let loaded = persist().load_from_file(&path);

        // Clean up before asserting so a failing assertion leaves no file behind.
        let _ = std::fs::remove_file(&path);

        saved.expect("save ok");
        let (loaded_meta, loaded_layers) = loaded.expect("load ok");
        assert_eq!(loaded_meta, meta);
        assert_eq!(loaded_layers, vec![layer]);
    }

    #[test]
    fn test_save_atomic_no_tmp_left() -> Result<()> {
        let path = tmp_path("atomic_no_tmp");
        let name = path
            .file_name()
            .and_then(|n| n.to_str())
            .ok_or("temp path has no UTF-8 file name")?;
        let tmp = path.with_file_name(format!(".{name}.tmp"));

        let saved = persist().save_to_file(&path, &sample_meta(), &[]);
        let tmp_left_behind = tmp.exists();

        // Clean up before asserting so a failing assertion leaves no file behind.
        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_file(&tmp);

        saved?;
        assert!(!tmp_left_behind, "temp file should have been cleaned up");
        Ok(())
    }

    #[test]
    fn test_load_nonexistent_file() {
        let path = tmp_path("nonexistent_xyz_12345");
        let result = persist().load_from_file(&path);
        assert!(matches!(result, Err(PersistError::Io(_))));
    }

    // ── merge_snapshots ────────────────────────────────────────────────────────

    #[test]
    fn test_merge_empty_delta() {
        let base = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![sample_layer(0, &[1, 2, 3])],
        };
        let delta = HnswSnapshot {
            meta: HnswMeta {
                max_layer: 3,
                entry_point: Some(1),
                ef_construction: 200,
                m: 16,
                node_count: 4,
            },
            layers: vec![],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        // Meta from delta
        assert_eq!(merged.meta, delta.meta);
        // Layer from base preserved
        assert_eq!(merged.layers, base.layers);
    }

    #[test]
    fn test_merge_delta_adds_nodes() {
        let base = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![sample_layer(0, &[1, 2, 3])],
        };
        let delta = HnswSnapshot {
            meta: HnswMeta {
                node_count: 5,
                ..sample_meta()
            },
            layers: vec![HnswLayer {
                level: 0,
                nodes: vec![
                    LayerNode {
                        id: 4,
                        neighbors: vec![1, 2],
                    },
                    LayerNode {
                        id: 5,
                        neighbors: vec![3],
                    },
                ],
            }],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        assert_eq!(merged.layers.len(), 1);
        // Nodes of a merged layer come back sorted by ID.
        let ids: Vec<u64> = merged.layers[0].nodes.iter().map(|n| n.id).collect();
        assert_eq!(ids, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_merge_delta_overrides_node() {
        let base = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![HnswLayer {
                level: 0,
                nodes: vec![LayerNode {
                    id: 1,
                    neighbors: vec![2, 3],
                }],
            }],
        };
        let delta = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![HnswLayer {
                level: 0,
                nodes: vec![LayerNode {
                    id: 1,
                    neighbors: vec![4, 5, 6], // new neighbors
                }],
            }],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        // The delta version of node 1 replaces the base version entirely.
        assert_eq!(merged.layers, delta.layers);
    }

    #[test]
    fn test_merge_adds_new_layer() {
        let base = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![sample_layer(0, &[1, 2])],
        };
        let delta = HnswSnapshot {
            meta: sample_meta(),
            layers: vec![sample_layer(1, &[1])],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        // Layers come back ordered by level.
        let levels: Vec<usize> = merged.layers.iter().map(|l| l.level).collect();
        assert_eq!(levels, vec![0, 1]);
    }

    #[test]
    fn test_merge_meta_from_delta() {
        let base = HnswSnapshot {
            meta: HnswMeta {
                max_layer: 0,
                entry_point: Some(1),
                ef_construction: 100,
                m: 8,
                node_count: 10,
            },
            layers: vec![],
        };
        let delta = HnswSnapshot {
            meta: HnswMeta {
                max_layer: 2,
                entry_point: Some(99),
                ef_construction: 200,
                m: 16,
                node_count: 500,
            },
            layers: vec![],
        };
        let merged = persist().merge_snapshots(&base, &delta);
        assert_eq!(merged.meta, delta.meta);
        assert!(merged.layers.is_empty());
    }

    // ── Large-scale roundtrip ─────────────────────────────────────────────────

    #[test]
    fn test_large_roundtrip() {
        let meta = HnswMeta {
            max_layer: 4,
            entry_point: Some(0),
            ef_construction: 400,
            m: 32,
            node_count: 10_000,
        };
        // 1000, 500, 250, 125 and 62 nodes on levels 0 through 4.
        let layers: Vec<HnswLayer> = (0..5usize)
            .map(|level| {
                let ids: Vec<u64> = (0..(1000u64 >> level)).collect();
                sample_layer(level, &ids)
            })
            .collect();
        let bytes = persist().serialize(&meta, &layers);
        assert!(persist().validate_checksum(&bytes));
        let (decoded_meta, decoded_layers) = persist().deserialize(&bytes).expect("ok");
        assert_eq!(decoded_meta, meta);
        assert_eq!(decoded_layers, layers);
    }
}
873}