manifoldb_vector/index/
persistence.rs

//! HNSW index persistence.
//!
//! This module saves and loads HNSW indexes to and from a storage engine.
//!
//! ## Storage Layout
//!
//! Each index lives in its own table (`hnsw_{index_name}`, see [`table_name`]) and uses
//! the following key prefixes:
//! - `0x20` - Index metadata (entry point, max layer, config)
//! - `0x21` - Node data (embedding, max_layer), keyed by the big-endian entity ID
//! - `0x22` - Node connections (per-layer neighbor lists), keyed by the big-endian entity ID
//!   followed by the big-endian `u32` layer number
11
12// Allow unwrap on try_into for fixed-size slice conversions which are guaranteed to succeed
13#![allow(clippy::unwrap_used)]
14
15use manifoldb_core::EntityId;
16use manifoldb_storage::{Cursor, StorageEngine, Transaction};
17
18use crate::distance::DistanceMetric;
19use crate::error::VectorError;
20use crate::types::Embedding;
21
22use super::config::HnswConfig;
23use super::graph::{HnswGraph, HnswNode};
24
/// First key byte of the (single) index-metadata record.
pub const PREFIX_HNSW_META: u8 = 0x20;

/// First key byte of every per-node record (embedding plus the node's top layer).
pub const PREFIX_HNSW_NODE: u8 = 0x21;

/// First key byte of every per-node, per-layer neighbor list.
pub const PREFIX_HNSW_CONNECTIONS: u8 = 0x22;
33
/// Returns the name of the storage table holding the HNSW data of `index_name`.
///
/// The table name is `index_name` prefixed with `hnsw_`.
pub fn table_name(index_name: &str) -> String {
    const PREFIX: &str = "hnsw_";
    let mut name = String::with_capacity(PREFIX.len() + index_name.len());
    name.push_str(PREFIX);
    name.push_str(index_name);
    name
}
38
39/// Encode the metadata key for an index.
40fn encode_meta_key() -> Vec<u8> {
41    vec![PREFIX_HNSW_META]
42}
43
44/// Encode a node key.
45fn encode_node_key(entity_id: EntityId) -> Vec<u8> {
46    let mut key = Vec::with_capacity(9);
47    key.push(PREFIX_HNSW_NODE);
48    key.extend_from_slice(&entity_id.as_u64().to_be_bytes());
49    key
50}
51
52/// Encode a connections key for a node at a specific layer.
53fn encode_connections_key(entity_id: EntityId, layer: usize) -> Vec<u8> {
54    let mut key = Vec::with_capacity(13);
55    key.push(PREFIX_HNSW_CONNECTIONS);
56    key.extend_from_slice(&entity_id.as_u64().to_be_bytes());
57    key.extend_from_slice(&(layer as u32).to_be_bytes());
58    key
59}
60
61/// Decode a node key.
62fn decode_node_key(key: &[u8]) -> Option<EntityId> {
63    if key.len() != 9 || key[0] != PREFIX_HNSW_NODE {
64        return None;
65    }
66    let bytes: [u8; 8] = key[1..9].try_into().ok()?;
67    Some(EntityId::new(u64::from_be_bytes(bytes)))
68}
69
70/// Index metadata stored in the database.
71#[derive(Debug, Clone)]
72pub struct IndexMetadata {
73    /// The dimension of embeddings.
74    pub dimension: usize,
75    /// The distance metric.
76    pub distance_metric: DistanceMetric,
77    /// The entry point entity ID, if any.
78    pub entry_point: Option<EntityId>,
79    /// The maximum layer in the graph.
80    pub max_layer: usize,
81    /// The M parameter.
82    pub m: usize,
83    /// The M_max0 parameter.
84    pub m_max0: usize,
85    /// The ef_construction parameter.
86    pub ef_construction: usize,
87    /// The ef_search parameter.
88    pub ef_search: usize,
89    /// The ml parameter (stored as bits).
90    pub ml_bits: u64,
91    /// Number of PQ segments (0 = disabled).
92    pub pq_segments: usize,
93    /// Number of PQ centroids per segment.
94    pub pq_centroids: usize,
95}
96
97impl IndexMetadata {
98    /// Serialize metadata to bytes.
99    pub fn to_bytes(&self) -> Vec<u8> {
100        let mut bytes = Vec::with_capacity(72);
101
102        // Version byte (2 = with PQ support)
103        bytes.push(2);
104
105        // Dimension (4 bytes)
106        bytes.extend_from_slice(&(self.dimension as u32).to_be_bytes());
107
108        // Distance metric (1 byte)
109        bytes.push(match self.distance_metric {
110            DistanceMetric::Euclidean => 0,
111            DistanceMetric::Cosine => 1,
112            DistanceMetric::DotProduct => 2,
113            DistanceMetric::Manhattan => 3,
114            DistanceMetric::Chebyshev => 4,
115        });
116
117        // Entry point (1 byte flag + 8 bytes if present)
118        if let Some(ep) = self.entry_point {
119            bytes.push(1);
120            bytes.extend_from_slice(&ep.as_u64().to_be_bytes());
121        } else {
122            bytes.push(0);
123        }
124
125        // Max layer (4 bytes)
126        bytes.extend_from_slice(&(self.max_layer as u32).to_be_bytes());
127
128        // Config parameters (4 bytes each)
129        bytes.extend_from_slice(&(self.m as u32).to_be_bytes());
130        bytes.extend_from_slice(&(self.m_max0 as u32).to_be_bytes());
131        bytes.extend_from_slice(&(self.ef_construction as u32).to_be_bytes());
132        bytes.extend_from_slice(&(self.ef_search as u32).to_be_bytes());
133
134        // ml as bits (8 bytes)
135        bytes.extend_from_slice(&self.ml_bits.to_be_bytes());
136
137        // PQ parameters (4 bytes each)
138        bytes.extend_from_slice(&(self.pq_segments as u32).to_be_bytes());
139        bytes.extend_from_slice(&(self.pq_centroids as u32).to_be_bytes());
140
141        bytes
142    }
143
144    /// Deserialize metadata from bytes.
145    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
146        if bytes.is_empty() {
147            return Err(VectorError::Encoding("empty metadata".into()));
148        }
149
150        let version = bytes[0];
151        if version != 1 && version != 2 {
152            return Err(VectorError::Encoding(format!("unsupported metadata version: {version}")));
153        }
154
155        let mut pos = 1;
156
157        let read_u32 = |bytes: &[u8], pos: &mut usize| -> Result<u32, VectorError> {
158            if *pos + 4 > bytes.len() {
159                return Err(VectorError::Encoding("truncated metadata".into()));
160            }
161            let val = u32::from_be_bytes(bytes[*pos..*pos + 4].try_into().unwrap());
162            *pos += 4;
163            Ok(val)
164        };
165
166        let read_u64 = |bytes: &[u8], pos: &mut usize| -> Result<u64, VectorError> {
167            if *pos + 8 > bytes.len() {
168                return Err(VectorError::Encoding("truncated metadata".into()));
169            }
170            let val = u64::from_be_bytes(bytes[*pos..*pos + 8].try_into().unwrap());
171            *pos += 8;
172            Ok(val)
173        };
174
175        let dimension = read_u32(bytes, &mut pos)? as usize;
176
177        if pos >= bytes.len() {
178            return Err(VectorError::Encoding("truncated metadata".into()));
179        }
180        let distance_metric = match bytes[pos] {
181            0 => DistanceMetric::Euclidean,
182            1 => DistanceMetric::Cosine,
183            2 => DistanceMetric::DotProduct,
184            3 => DistanceMetric::Manhattan,
185            4 => DistanceMetric::Chebyshev,
186            b => return Err(VectorError::Encoding(format!("unknown distance metric: {b}"))),
187        };
188        pos += 1;
189
190        if pos >= bytes.len() {
191            return Err(VectorError::Encoding("truncated metadata".into()));
192        }
193        let has_entry_point = bytes[pos] == 1;
194        pos += 1;
195
196        let entry_point =
197            if has_entry_point { Some(EntityId::new(read_u64(bytes, &mut pos)?)) } else { None };
198
199        let max_layer = read_u32(bytes, &mut pos)? as usize;
200        let m = read_u32(bytes, &mut pos)? as usize;
201        let m_max0 = read_u32(bytes, &mut pos)? as usize;
202        let ef_construction = read_u32(bytes, &mut pos)? as usize;
203        let ef_search = read_u32(bytes, &mut pos)? as usize;
204        let ml_bits = read_u64(bytes, &mut pos)?;
205
206        // PQ parameters (version 2+)
207        let (pq_segments, pq_centroids) = if version >= 2 {
208            let segments = read_u32(bytes, &mut pos)? as usize;
209            let centroids = read_u32(bytes, &mut pos)? as usize;
210            (segments, centroids)
211        } else {
212            (0, 256) // Defaults for version 1
213        };
214
215        Ok(Self {
216            dimension,
217            distance_metric,
218            entry_point,
219            max_layer,
220            m,
221            m_max0,
222            ef_construction,
223            ef_search,
224            ml_bits,
225            pq_segments,
226            pq_centroids,
227        })
228    }
229}
230
231/// Node data stored in the database (without connections).
232#[derive(Debug, Clone)]
233pub struct NodeData {
234    /// The embedding vector.
235    pub embedding: Embedding,
236    /// The maximum layer this node appears in.
237    pub max_layer: usize,
238}
239
240impl NodeData {
241    /// Serialize node data to bytes.
242    pub fn to_bytes(&self) -> Vec<u8> {
243        let embedding_bytes = self.embedding.to_bytes();
244        let mut bytes = Vec::with_capacity(5 + embedding_bytes.len());
245
246        // Version byte
247        bytes.push(1);
248
249        // Max layer (4 bytes)
250        bytes.extend_from_slice(&(self.max_layer as u32).to_be_bytes());
251
252        // Embedding bytes
253        bytes.extend_from_slice(&embedding_bytes);
254
255        bytes
256    }
257
258    /// Deserialize node data from bytes.
259    pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
260        if bytes.len() < 5 {
261            return Err(VectorError::Encoding("truncated node data".into()));
262        }
263
264        let version = bytes[0];
265        if version != 1 {
266            return Err(VectorError::Encoding(format!("unsupported node data version: {version}")));
267        }
268
269        let max_layer = u32::from_be_bytes(bytes[1..5].try_into().unwrap()) as usize;
270        let embedding = Embedding::from_bytes(&bytes[5..])?;
271
272        Ok(Self { embedding, max_layer })
273    }
274}
275
276/// Serialize a list of neighbor IDs.
277fn serialize_connections(neighbors: &[EntityId]) -> Vec<u8> {
278    let mut bytes = Vec::with_capacity(4 + neighbors.len() * 8);
279
280    // Number of neighbors (4 bytes)
281    bytes.extend_from_slice(&(neighbors.len() as u32).to_be_bytes());
282
283    // Each neighbor ID (8 bytes each)
284    for &id in neighbors {
285        bytes.extend_from_slice(&id.as_u64().to_be_bytes());
286    }
287
288    bytes
289}
290
291/// Deserialize a list of neighbor IDs.
292fn deserialize_connections(bytes: &[u8]) -> Result<Vec<EntityId>, VectorError> {
293    if bytes.len() < 4 {
294        return Err(VectorError::Encoding("truncated connections data".into()));
295    }
296
297    let count = u32::from_be_bytes(bytes[0..4].try_into().unwrap()) as usize;
298    let expected_len = 4 + count * 8;
299
300    if bytes.len() < expected_len {
301        return Err(VectorError::Encoding("truncated connections data".into()));
302    }
303
304    let mut neighbors = Vec::with_capacity(count);
305    for i in 0..count {
306        let start = 4 + i * 8;
307        let id = u64::from_be_bytes(bytes[start..start + 8].try_into().unwrap());
308        neighbors.push(EntityId::new(id));
309    }
310
311    Ok(neighbors)
312}
313
314/// Save index metadata to storage.
315pub fn save_metadata<E: StorageEngine>(
316    engine: &E,
317    table: &str,
318    metadata: &IndexMetadata,
319) -> Result<(), VectorError> {
320    let mut tx = engine.begin_write()?;
321    tx.put(table, &encode_meta_key(), &metadata.to_bytes())?;
322    tx.commit()?;
323    Ok(())
324}
325
326/// Load index metadata from storage.
327pub fn load_metadata<E: StorageEngine>(
328    engine: &E,
329    table: &str,
330) -> Result<Option<IndexMetadata>, VectorError> {
331    let tx = engine.begin_read()?;
332    match tx.get(table, &encode_meta_key())? {
333        Some(bytes) => Ok(Some(IndexMetadata::from_bytes(&bytes)?)),
334        None => Ok(None),
335    }
336}
337
338/// Save a single node to storage.
339pub fn save_node<E: StorageEngine>(
340    engine: &E,
341    table: &str,
342    node: &HnswNode,
343) -> Result<(), VectorError> {
344    let mut tx = engine.begin_write()?;
345
346    // Save node data
347    let node_data = NodeData { embedding: node.embedding.clone(), max_layer: node.max_layer };
348    tx.put(table, &encode_node_key(node.entity_id), &node_data.to_bytes())?;
349
350    // Save connections for each layer
351    for (layer, neighbors) in node.connections.iter().enumerate() {
352        let key = encode_connections_key(node.entity_id, layer);
353        tx.put(table, &key, &serialize_connections(neighbors))?;
354    }
355
356    tx.commit()?;
357    Ok(())
358}
359
360/// Load a single node from storage.
361pub fn load_node<E: StorageEngine>(
362    engine: &E,
363    table: &str,
364    entity_id: EntityId,
365) -> Result<Option<HnswNode>, VectorError> {
366    let tx = engine.begin_read()?;
367
368    // Load node data
369    let node_data = match tx.get(table, &encode_node_key(entity_id))? {
370        Some(bytes) => NodeData::from_bytes(&bytes)?,
371        None => return Ok(None),
372    };
373
374    // Load connections for each layer
375    let mut connections = Vec::with_capacity(node_data.max_layer + 1);
376    for layer in 0..=node_data.max_layer {
377        let key = encode_connections_key(entity_id, layer);
378        let neighbors = match tx.get(table, &key)? {
379            Some(bytes) => deserialize_connections(&bytes)?,
380            None => Vec::new(),
381        };
382        connections.push(neighbors);
383    }
384
385    Ok(Some(HnswNode {
386        entity_id,
387        embedding: node_data.embedding,
388        max_layer: node_data.max_layer,
389        connections,
390    }))
391}
392
393/// Delete a node from storage.
394pub fn delete_node<E: StorageEngine>(
395    engine: &E,
396    table: &str,
397    entity_id: EntityId,
398    max_layer: usize,
399) -> Result<bool, VectorError> {
400    let mut tx = engine.begin_write()?;
401
402    // Delete node data
403    let existed = tx.delete(table, &encode_node_key(entity_id))?;
404
405    // Delete connections for each layer
406    for layer in 0..=max_layer {
407        let key = encode_connections_key(entity_id, layer);
408        tx.delete(table, &key)?;
409    }
410
411    tx.commit()?;
412    Ok(existed)
413}
414
415/// Update the connections for a node at a specific layer.
416pub fn update_connections<E: StorageEngine>(
417    engine: &E,
418    table: &str,
419    entity_id: EntityId,
420    layer: usize,
421    neighbors: &[EntityId],
422) -> Result<(), VectorError> {
423    let mut tx = engine.begin_write()?;
424    let key = encode_connections_key(entity_id, layer);
425    tx.put(table, &key, &serialize_connections(neighbors))?;
426    tx.commit()?;
427    Ok(())
428}
429
430/// Load the entire graph from storage.
431pub fn load_graph<E: StorageEngine>(
432    engine: &E,
433    table: &str,
434    metadata: &IndexMetadata,
435) -> Result<HnswGraph, VectorError> {
436    let mut graph = HnswGraph::new(metadata.dimension, metadata.distance_metric);
437    graph.entry_point = metadata.entry_point;
438    graph.max_layer = metadata.max_layer;
439
440    let tx = engine.begin_read()?;
441
442    // Scan for all node keys
443    let node_prefix = [PREFIX_HNSW_NODE];
444    let node_end = [PREFIX_HNSW_NODE + 1];
445
446    let mut cursor = tx.range(
447        table,
448        std::ops::Bound::Included(&node_prefix[..]),
449        std::ops::Bound::Excluded(&node_end[..]),
450    )?;
451
452    let mut entity_ids = Vec::new();
453    while let Some((key, _)) = cursor.next()? {
454        if let Some(entity_id) = decode_node_key(&key) {
455            entity_ids.push(entity_id);
456        }
457    }
458    drop(cursor);
459    drop(tx);
460
461    // Load each node
462    for entity_id in entity_ids {
463        if let Some(node) = load_node(engine, table, entity_id)? {
464            graph.nodes.insert(entity_id, node);
465        }
466    }
467
468    Ok(graph)
469}
470
471/// Save the entire graph to storage.
472pub fn save_graph<E: StorageEngine>(
473    engine: &E,
474    table: &str,
475    graph: &HnswGraph,
476    config: &HnswConfig,
477) -> Result<(), VectorError> {
478    // Save metadata
479    let metadata = IndexMetadata {
480        dimension: graph.dimension,
481        distance_metric: graph.distance_metric,
482        entry_point: graph.entry_point,
483        max_layer: graph.max_layer,
484        m: config.m,
485        m_max0: config.m_max0,
486        ef_construction: config.ef_construction,
487        ef_search: config.ef_search,
488        ml_bits: config.ml.to_bits(),
489        pq_segments: config.pq_segments,
490        pq_centroids: config.pq_centroids,
491    };
492    save_metadata(engine, table, &metadata)?;
493
494    // Save all nodes
495    for node in graph.nodes.values() {
496        save_node(engine, table, node)?;
497    }
498
499    Ok(())
500}
501
502// =============================================================================
503// Transaction-aware persistence functions
504// =============================================================================
505// These functions accept an existing transaction reference, allowing HNSW
506// operations to be batched within a larger transaction (e.g., during DML).
507
508/// Save index metadata within an existing transaction.
509pub fn save_metadata_tx<T: Transaction>(
510    tx: &mut T,
511    table: &str,
512    metadata: &IndexMetadata,
513) -> Result<(), VectorError> {
514    tx.put(table, &encode_meta_key(), &metadata.to_bytes())?;
515    Ok(())
516}
517
518/// Load index metadata within an existing transaction.
519pub fn load_metadata_tx<T: Transaction>(
520    tx: &T,
521    table: &str,
522) -> Result<Option<IndexMetadata>, VectorError> {
523    match tx.get(table, &encode_meta_key())? {
524        Some(bytes) => Ok(Some(IndexMetadata::from_bytes(&bytes)?)),
525        None => Ok(None),
526    }
527}
528
529/// Save a single node within an existing transaction.
530pub fn save_node_tx<T: Transaction>(
531    tx: &mut T,
532    table: &str,
533    node: &HnswNode,
534) -> Result<(), VectorError> {
535    // Save node data
536    let node_data = NodeData { embedding: node.embedding.clone(), max_layer: node.max_layer };
537    tx.put(table, &encode_node_key(node.entity_id), &node_data.to_bytes())?;
538
539    // Save connections for each layer
540    for (layer, neighbors) in node.connections.iter().enumerate() {
541        let key = encode_connections_key(node.entity_id, layer);
542        tx.put(table, &key, &serialize_connections(neighbors))?;
543    }
544
545    Ok(())
546}
547
548/// Load a single node within an existing transaction.
549pub fn load_node_tx<T: Transaction>(
550    tx: &T,
551    table: &str,
552    entity_id: EntityId,
553) -> Result<Option<HnswNode>, VectorError> {
554    // Load node data
555    let node_data = match tx.get(table, &encode_node_key(entity_id))? {
556        Some(bytes) => NodeData::from_bytes(&bytes)?,
557        None => return Ok(None),
558    };
559
560    // Load connections for each layer
561    let mut connections = Vec::with_capacity(node_data.max_layer + 1);
562    for layer in 0..=node_data.max_layer {
563        let key = encode_connections_key(entity_id, layer);
564        let neighbors = match tx.get(table, &key)? {
565            Some(bytes) => deserialize_connections(&bytes)?,
566            None => Vec::new(),
567        };
568        connections.push(neighbors);
569    }
570
571    Ok(Some(HnswNode {
572        entity_id,
573        embedding: node_data.embedding,
574        max_layer: node_data.max_layer,
575        connections,
576    }))
577}
578
579/// Delete a node within an existing transaction.
580pub fn delete_node_tx<T: Transaction>(
581    tx: &mut T,
582    table: &str,
583    entity_id: EntityId,
584    max_layer: usize,
585) -> Result<bool, VectorError> {
586    // Delete node data
587    let existed = tx.delete(table, &encode_node_key(entity_id))?;
588
589    // Delete connections for each layer
590    for layer in 0..=max_layer {
591        let key = encode_connections_key(entity_id, layer);
592        tx.delete(table, &key)?;
593    }
594
595    Ok(existed)
596}
597
598/// Update the connections for a node at a specific layer within an existing transaction.
599pub fn update_connections_tx<T: Transaction>(
600    tx: &mut T,
601    table: &str,
602    entity_id: EntityId,
603    layer: usize,
604    neighbors: &[EntityId],
605) -> Result<(), VectorError> {
606    let key = encode_connections_key(entity_id, layer);
607    tx.put(table, &key, &serialize_connections(neighbors))?;
608    Ok(())
609}
610
611/// Load the entire graph within an existing transaction.
612pub fn load_graph_tx<T: Transaction>(
613    tx: &T,
614    table: &str,
615    metadata: &IndexMetadata,
616) -> Result<HnswGraph, VectorError> {
617    let mut graph = HnswGraph::new(metadata.dimension, metadata.distance_metric);
618    graph.entry_point = metadata.entry_point;
619    graph.max_layer = metadata.max_layer;
620
621    // Scan for all node keys
622    let node_prefix = [PREFIX_HNSW_NODE];
623    let node_end = [PREFIX_HNSW_NODE + 1];
624
625    let mut cursor = tx.range(
626        table,
627        std::ops::Bound::Included(&node_prefix[..]),
628        std::ops::Bound::Excluded(&node_end[..]),
629    )?;
630
631    // First pass: collect entity IDs
632    let mut entity_ids = Vec::new();
633    while let Some((key, _)) = cursor.next()? {
634        if let Some(entity_id) = decode_node_key(&key) {
635            entity_ids.push(entity_id);
636        }
637    }
638
639    // Second pass: load full nodes
640    for entity_id in entity_ids {
641        if let Some(node) = load_node_tx(tx, table, entity_id)? {
642            graph.nodes.insert(entity_id, node);
643        }
644    }
645
646    Ok(graph)
647}
648
649/// Save the entire graph within an existing transaction.
650pub fn save_graph_tx<T: Transaction>(
651    tx: &mut T,
652    table: &str,
653    graph: &HnswGraph,
654    config: &HnswConfig,
655) -> Result<(), VectorError> {
656    // Save metadata
657    let metadata = IndexMetadata {
658        dimension: graph.dimension,
659        distance_metric: graph.distance_metric,
660        entry_point: graph.entry_point,
661        max_layer: graph.max_layer,
662        m: config.m,
663        m_max0: config.m_max0,
664        ef_construction: config.ef_construction,
665        ef_search: config.ef_search,
666        ml_bits: config.ml.to_bits(),
667        pq_segments: config.pq_segments,
668        pq_centroids: config.pq_centroids,
669    };
670    save_metadata_tx(tx, table, &metadata)?;
671
672    // Save all nodes
673    for node in graph.nodes.values() {
674        save_node_tx(tx, table, node)?;
675    }
676
677    Ok(())
678}
679
680/// Clear all index data within an existing transaction.
681///
682/// This removes all keys for the given index table.
683pub fn clear_index_tx<T: Transaction>(tx: &mut T, table: &str) -> Result<(), VectorError> {
684    // Delete metadata
685    let _ = tx.delete(table, &encode_meta_key());
686
687    // Find and delete all node and connection keys
688    let node_prefix = [PREFIX_HNSW_NODE];
689    let connection_prefix = [PREFIX_HNSW_CONNECTIONS];
690
691    // We need to collect keys first since we can't mutate while iterating
692    let mut keys_to_delete = Vec::new();
693
694    // Collect node keys
695    {
696        let node_end = [PREFIX_HNSW_NODE + 1];
697        let mut cursor = tx.range(
698            table,
699            std::ops::Bound::Included(&node_prefix[..]),
700            std::ops::Bound::Excluded(&node_end[..]),
701        )?;
702
703        while let Some((key, _)) = cursor.next()? {
704            keys_to_delete.push(key.clone());
705        }
706    }
707
708    // Collect connection keys
709    {
710        let connection_end = [PREFIX_HNSW_CONNECTIONS + 1];
711        let mut cursor = tx.range(
712            table,
713            std::ops::Bound::Included(&connection_prefix[..]),
714            std::ops::Bound::Excluded(&connection_end[..]),
715        )?;
716
717        while let Some((key, _)) = cursor.next()? {
718            keys_to_delete.push(key.clone());
719        }
720    }
721
722    // Delete all collected keys
723    for key in keys_to_delete {
724        tx.delete(table, &key)?;
725    }
726
727    Ok(())
728}
729
#[cfg(test)]
mod tests {
    use super::*;
    use manifoldb_storage::backends::RedbEngine;

    /// Builds an embedding of `dim` copies of `value`.
    fn create_test_embedding(dim: usize, value: f32) -> Embedding {
        Embedding::new(vec![value; dim]).unwrap()
    }

    /// Version-2 metadata fixture shared by several tests.
    fn sample_metadata() -> IndexMetadata {
        IndexMetadata {
            dimension: 128,
            distance_metric: DistanceMetric::Cosine,
            entry_point: Some(EntityId::new(42)),
            max_layer: 3,
            m: 16,
            m_max0: 32,
            ef_construction: 200,
            ef_search: 50,
            ml_bits: 0.5_f64.to_bits(),
            pq_segments: 8,
            pq_centroids: 256,
        }
    }

    #[test]
    fn test_metadata_roundtrip() {
        let metadata = sample_metadata();

        let bytes = metadata.to_bytes();
        let decoded = IndexMetadata::from_bytes(&bytes).unwrap();

        assert_eq!(decoded.dimension, 128);
        assert_eq!(decoded.distance_metric, DistanceMetric::Cosine);
        assert_eq!(decoded.entry_point, Some(EntityId::new(42)));
        assert_eq!(decoded.max_layer, 3);
        assert_eq!(decoded.m, 16);
        assert_eq!(decoded.m_max0, 32);
        assert_eq!(decoded.ef_construction, 200);
        assert_eq!(decoded.ef_search, 50);
        assert_eq!(decoded.ml_bits, 0.5_f64.to_bits());
        assert_eq!(decoded.pq_segments, 8);
        assert_eq!(decoded.pq_centroids, 256);
    }

    #[test]
    fn test_metadata_no_entry_point() {
        let metadata = IndexMetadata {
            dimension: 64,
            distance_metric: DistanceMetric::Euclidean,
            entry_point: None,
            max_layer: 0,
            m: 32,
            m_max0: 64,
            ef_construction: 100,
            ef_search: 25,
            ml_bits: 0.3_f64.to_bits(),
            pq_segments: 0,
            pq_centroids: 256,
        };

        let bytes = metadata.to_bytes();
        let decoded = IndexMetadata::from_bytes(&bytes).unwrap();

        assert_eq!(decoded.entry_point, None);
        assert_eq!(decoded.pq_segments, 0);
    }

    #[test]
    fn test_metadata_version1_uses_pq_defaults() {
        // A version-1 record is a version-2 record without the trailing 8 PQ bytes.
        let mut bytes = sample_metadata().to_bytes();
        bytes[0] = 1;
        bytes.truncate(bytes.len() - 8);

        let decoded = IndexMetadata::from_bytes(&bytes).unwrap();

        assert_eq!(decoded.m, 16);
        assert_eq!(decoded.pq_segments, 0);
        assert_eq!(decoded.pq_centroids, 256);
    }

    #[test]
    fn test_metadata_rejects_malformed_input() {
        assert!(IndexMetadata::from_bytes(&[]).is_err());

        let bytes = sample_metadata().to_bytes();
        assert!(IndexMetadata::from_bytes(&bytes[..bytes.len() - 1]).is_err());

        let mut bad_version = bytes.clone();
        bad_version[0] = 99;
        assert!(IndexMetadata::from_bytes(&bad_version).is_err());

        // Byte 5 is the distance-metric tag (after version + 4-byte dimension).
        let mut bad_metric = bytes;
        bad_metric[5] = 0xFF;
        assert!(IndexMetadata::from_bytes(&bad_metric).is_err());
    }

    #[test]
    fn test_node_data_roundtrip() {
        let embedding = create_test_embedding(4, 1.5);
        let node_data = NodeData { embedding: embedding.clone(), max_layer: 2 };

        let bytes = node_data.to_bytes();
        let decoded = NodeData::from_bytes(&bytes).unwrap();

        assert_eq!(decoded.max_layer, 2);
        assert_eq!(decoded.embedding.as_slice(), embedding.as_slice());
    }

    #[test]
    fn test_node_data_rejects_malformed_input() {
        assert!(NodeData::from_bytes(&[1, 0, 0]).is_err());

        let mut bytes =
            NodeData { embedding: create_test_embedding(4, 1.0), max_layer: 0 }.to_bytes();
        bytes[0] = 2;
        assert!(NodeData::from_bytes(&bytes).is_err());
    }

    #[test]
    fn test_connections_roundtrip() {
        let neighbors = vec![EntityId::new(1), EntityId::new(5), EntityId::new(10)];
        let bytes = serialize_connections(&neighbors);
        let decoded = deserialize_connections(&bytes).unwrap();

        assert_eq!(decoded, neighbors);
    }

    #[test]
    fn test_connections_empty() {
        let neighbors: Vec<EntityId> = vec![];
        let bytes = serialize_connections(&neighbors);
        let decoded = deserialize_connections(&bytes).unwrap();

        assert!(decoded.is_empty());
    }

    #[test]
    fn test_connections_truncated() {
        assert!(deserialize_connections(&[0, 0]).is_err());
        // Header claims two neighbors but only one 8-byte ID follows.
        assert!(deserialize_connections(&[0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1]).is_err());
    }

    #[test]
    fn test_save_load_node() {
        let engine = RedbEngine::in_memory().unwrap();
        let table = "test_hnsw";

        let mut node = HnswNode::new(EntityId::new(42), create_test_embedding(4, 1.0), 2);
        node.connections[0] = vec![EntityId::new(1), EntityId::new(2)];
        node.connections[1] = vec![EntityId::new(3)];

        save_node(&engine, table, &node).unwrap();
        let loaded = load_node(&engine, table, EntityId::new(42)).unwrap().unwrap();

        assert_eq!(loaded.entity_id, EntityId::new(42));
        assert_eq!(loaded.max_layer, 2);
        assert_eq!(loaded.connections[0], vec![EntityId::new(1), EntityId::new(2)]);
        assert_eq!(loaded.connections[1], vec![EntityId::new(3)]);
    }

    #[test]
    fn test_load_missing_node_returns_none() {
        let engine = RedbEngine::in_memory().unwrap();
        let table = "test_hnsw";

        let node = HnswNode::new(EntityId::new(42), create_test_embedding(4, 1.0), 0);
        save_node(&engine, table, &node).unwrap();

        assert!(load_node(&engine, table, EntityId::new(7)).unwrap().is_none());
    }

    #[test]
    fn test_delete_node() {
        let engine = RedbEngine::in_memory().unwrap();
        let table = "test_hnsw";

        let mut node = HnswNode::new(EntityId::new(42), create_test_embedding(4, 1.0), 1);
        node.connections[0] = vec![EntityId::new(1)];
        save_node(&engine, table, &node).unwrap();

        assert!(delete_node(&engine, table, EntityId::new(42), 1).unwrap());
        assert!(load_node(&engine, table, EntityId::new(42)).unwrap().is_none());
    }

    #[test]
    fn test_save_load_metadata() {
        let engine = RedbEngine::in_memory().unwrap();
        let table = "test_hnsw";

        let metadata = IndexMetadata { entry_point: Some(EntityId::new(1)), ..sample_metadata() };

        save_metadata(&engine, table, &metadata).unwrap();
        let loaded = load_metadata(&engine, table).unwrap().unwrap();

        assert_eq!(loaded.dimension, metadata.dimension);
        assert_eq!(loaded.entry_point, metadata.entry_point);
        assert_eq!(loaded.pq_segments, metadata.pq_segments);
        assert_eq!(loaded.pq_centroids, metadata.pq_centroids);
    }

    #[test]
    fn test_load_graph() {
        let engine = RedbEngine::in_memory().unwrap();
        let table = "test_hnsw";

        let metadata = IndexMetadata {
            dimension: 4,
            entry_point: Some(EntityId::new(1)),
            max_layer: 1,
            ..sample_metadata()
        };
        save_metadata(&engine, table, &metadata).unwrap();

        let mut top = HnswNode::new(EntityId::new(1), create_test_embedding(4, 1.0), 1);
        top.connections[0] = vec![EntityId::new(2)];
        let mut bottom = HnswNode::new(EntityId::new(2), create_test_embedding(4, 2.0), 0);
        bottom.connections[0] = vec![EntityId::new(1)];
        save_node(&engine, table, &top).unwrap();
        save_node(&engine, table, &bottom).unwrap();

        let graph = load_graph(&engine, table, &metadata).unwrap();

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.entry_point, Some(EntityId::new(1)));
        assert_eq!(graph.max_layer, 1);
        let loaded_top = graph.nodes.get(&EntityId::new(1)).unwrap();
        assert_eq!(loaded_top.max_layer, 1);
        assert_eq!(loaded_top.connections[0], vec![EntityId::new(2)]);
    }

    #[test]
    fn test_clear_index_tx_removes_everything() {
        let engine = RedbEngine::in_memory().unwrap();
        let table = "test_hnsw";

        save_metadata(&engine, table, &sample_metadata()).unwrap();
        let mut node = HnswNode::new(EntityId::new(42), create_test_embedding(4, 1.0), 1);
        node.connections[0] = vec![EntityId::new(1)];
        save_node(&engine, table, &node).unwrap();

        let mut tx = engine.begin_write().unwrap();
        clear_index_tx(&mut tx, table).unwrap();
        tx.commit().unwrap();

        assert!(load_metadata(&engine, table).unwrap().is_none());
        assert!(load_node(&engine, table, EntityId::new(42)).unwrap().is_none());

        let read = engine.begin_read().unwrap();
        let conn_key = encode_connections_key(EntityId::new(42), 0);
        assert!(read.get(table, &conn_key).unwrap().is_none());
    }
}