nodedb-vector 0.0.4

Shared vector engine (HNSW index + distance functions) for NodeDB Origin and Lite
Documentation
//! Checkpoint serialization and deserialization for `VectorCollection`.

use serde::{Deserialize, Serialize};

use crate::collection::segment::{DEFAULT_SEAL_THRESHOLD, SealedSegment};
use crate::collection::tier::StorageTier;
use crate::distance::DistanceMetric;
use crate::flat::FlatIndex;
use crate::hnsw::{HnswIndex, HnswParams};

use super::lifecycle::VectorCollection;

#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub(crate) struct CollectionSnapshot {
    pub dim: usize,
    pub params_m: usize,
    pub params_m0: usize,
    pub params_ef_construction: usize,
    pub params_metric: u8,
    pub next_id: u32,
    pub growing_base_id: u32,
    pub growing_vectors: Vec<Vec<f32>>,
    pub growing_deleted: Vec<bool>,
    pub sealed_segments: Vec<SealedSnapshot>,
    pub building_segments: Vec<BuildingSnapshot>,
    #[serde(default)]
    pub doc_id_map: Vec<(u32, String)>,
    #[serde(default)]
    pub multi_doc_map: Vec<(String, Vec<u32>)>,
}

#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub(crate) struct SealedSnapshot {
    pub base_id: u32,
    pub hnsw_bytes: Vec<u8>,
}

#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub(crate) struct BuildingSnapshot {
    pub base_id: u32,
    pub vectors: Vec<Vec<f32>>,
}

impl VectorCollection {
    /// Serialize all segments for checkpointing.
    pub fn checkpoint_to_bytes(&self) -> Vec<u8> {
        let snapshot = CollectionSnapshot {
            dim: self.dim,
            params_m: self.params.m,
            params_m0: self.params.m0,
            params_ef_construction: self.params.ef_construction,
            params_metric: self.params.metric as u8,
            next_id: self.next_id,
            growing_base_id: self.growing_base_id,
            growing_vectors: (0..self.growing.len() as u32)
                .filter_map(|i| self.growing.get_vector(i).map(|v| v.to_vec()))
                .collect(),
            growing_deleted: (0..self.growing.len() as u32)
                .map(|i| self.growing.get_vector(i).is_none())
                .collect(),
            sealed_segments: self
                .sealed
                .iter()
                .map(|s| SealedSnapshot {
                    base_id: s.base_id,
                    hnsw_bytes: s.index.checkpoint_to_bytes(),
                })
                .collect(),
            building_segments: self
                .building
                .iter()
                .map(|b| BuildingSnapshot {
                    base_id: b.base_id,
                    vectors: (0..b.flat.len() as u32)
                        .filter_map(|i| b.flat.get_vector(i).map(|v| v.to_vec()))
                        .collect(),
                })
                .collect(),
            doc_id_map: self
                .doc_id_map
                .iter()
                .map(|(&k, v)| (k, v.clone()))
                .collect(),
            multi_doc_map: self
                .multi_doc_map
                .iter()
                .map(|(k, v)| (k.clone(), v.clone()))
                .collect(),
        };
        match zerompk::to_msgpack_vec(&snapshot) {
            Ok(bytes) => bytes,
            Err(e) => {
                tracing::warn!(error = %e, "vector collection checkpoint serialization failed");
                Vec::new()
            }
        }
    }

    /// Restore a collection from checkpoint bytes.
    pub fn from_checkpoint(bytes: &[u8]) -> Option<Self> {
        let snap: CollectionSnapshot = zerompk::from_msgpack(bytes).ok()?;
        let metric = match snap.params_metric {
            0 => DistanceMetric::L2,
            1 => DistanceMetric::Cosine,
            2 => DistanceMetric::InnerProduct,
            3 => DistanceMetric::Manhattan,
            4 => DistanceMetric::Chebyshev,
            5 => DistanceMetric::Hamming,
            6 => DistanceMetric::Jaccard,
            7 => DistanceMetric::Pearson,
            _ => DistanceMetric::Cosine,
        };
        let params = HnswParams {
            m: snap.params_m,
            m0: snap.params_m0,
            ef_construction: snap.params_ef_construction,
            metric,
        };

        let mut growing = FlatIndex::new(snap.dim, metric);
        for v in &snap.growing_vectors {
            growing.insert(v.clone());
        }

        let mut sealed = Vec::with_capacity(snap.sealed_segments.len());
        for ss in &snap.sealed_segments {
            if let Some(index) = HnswIndex::from_checkpoint(&ss.hnsw_bytes) {
                let sq8 = VectorCollection::build_sq8_for_index(&index);
                sealed.push(SealedSegment {
                    index,
                    base_id: ss.base_id,
                    sq8,
                    tier: StorageTier::L0Ram,
                    mmap_vectors: None,
                });
            }
        }

        for bs in &snap.building_segments {
            let mut index = HnswIndex::new(snap.dim, params.clone());
            for v in &bs.vectors {
                index
                    .insert(v.clone())
                    .expect("dimension guaranteed by checkpoint");
            }
            let sq8 = VectorCollection::build_sq8_for_index(&index);
            sealed.push(SealedSegment {
                index,
                base_id: bs.base_id,
                sq8,
                tier: StorageTier::L0Ram,
                mmap_vectors: None,
            });
        }

        let next_segment_id = (sealed.len() + 1) as u32;

        Some(Self {
            growing,
            growing_base_id: snap.growing_base_id,
            sealed,
            building: Vec::new(),
            params,
            next_id: snap.next_id,
            next_segment_id,
            dim: snap.dim,
            data_dir: None,
            ram_budget_bytes: 0,
            mmap_fallback_count: 0,
            mmap_segment_count: 0,
            doc_id_map: snap.doc_id_map.into_iter().collect(),
            multi_doc_map: snap.multi_doc_map.into_iter().collect(),
            seal_threshold: DEFAULT_SEAL_THRESHOLD,
        })
    }
}

#[cfg(test)]
mod tests {
    use crate::collection::lifecycle::VectorCollection;
    use crate::distance::DistanceMetric;
    use crate::hnsw::HnswParams;

    #[test]
    fn checkpoint_roundtrip() {
        let mut coll = VectorCollection::new(
            3,
            HnswParams {
                metric: DistanceMetric::L2,
                ..HnswParams::default()
            },
        );
        for i in 0..50u32 {
            coll.insert(vec![i as f32, 0.0, 0.0]);
        }
        let bytes = coll.checkpoint_to_bytes();
        let restored = VectorCollection::from_checkpoint(&bytes).unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);

        let results = restored.search(&[25.0, 0.0, 0.0], 1, 64);
        assert_eq!(results[0].id, 25);
    }
}