1use std::collections::HashMap;
23
24use nodedb_types::{Surrogate, VectorQuantization};
25use serde::{Deserialize, Serialize};
26
27use crate::collection::payload_index::PayloadIndexSetSnapshot;
28use crate::collection::segment::{DEFAULT_SEAL_THRESHOLD, SealedSegment};
29use crate::collection::tier::StorageTier;
30use crate::distance::DistanceMetric;
31use crate::error::VectorError;
32use crate::flat::FlatIndex;
33use crate::hnsw::{HnswIndex, HnswParams};
34use crate::quantize::pq::PqCodec;
35use crate::quantize::sq8::Sq8Codec;
36
37use super::lifecycle::VectorCollection;
38
/// Four-byte prefix marking an encrypted vector-checkpoint envelope.
///
/// It is passed to the envelope encrypt/decrypt helpers, and `from_checkpoint`
/// checks for it at the start of a blob to decide whether decryption is needed.
const SEGV_MAGIC: [u8; 4] = [b'S', b'E', b'G', b'V'];
42
43fn encrypt_checkpoint(
45 key: &nodedb_wal::crypto::WalEncryptionKey,
46 plaintext: &[u8],
47) -> Result<Vec<u8>, VectorError> {
48 nodedb_wal::crypto::encrypt_segment_envelope(key, &SEGV_MAGIC, plaintext).map_err(|e| {
49 VectorError::CheckpointEncryptionError {
50 detail: e.to_string(),
51 }
52 })
53}
54
55fn decrypt_checkpoint(
57 key: &nodedb_wal::crypto::WalEncryptionKey,
58 blob: &[u8],
59) -> Result<Vec<u8>, VectorError> {
60 nodedb_wal::crypto::decrypt_segment_envelope(key, &SEGV_MAGIC, blob).map_err(|e| {
61 VectorError::CheckpointEncryptionError {
62 detail: e.to_string(),
63 }
64 })
65}
66
/// Serialized form of a whole `VectorCollection`.
///
/// It is built by `VectorCollection::checkpoint_to_bytes`, encoded with
/// `zerompk::to_msgpack_vec`, and decoded again in `VectorCollection::from_checkpoint`.
///
/// NOTE(review): the struct derives both serde and zerompk codecs. Depending on
/// how zerompk lays out structs, renaming or reordering fields may break decoding
/// of existing checkpoints. Confirm this before touching the field list.
#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub(crate) struct CollectionSnapshot {
    /// Vector dimensionality shared by every segment.
    pub dim: usize,
    /// HNSW `m` parameter.
    pub params_m: usize,
    /// HNSW `m0` parameter.
    pub params_m0: usize,
    /// HNSW `ef_construction` parameter.
    pub params_ef_construction: usize,
    /// `DistanceMetric` written with `as u8`.
    /// On restore, tags 0..=7 are mapped back and any other value becomes Cosine.
    pub params_metric: u8,
    /// Next id to hand out (copied from `VectorCollection::next_id`).
    pub next_id: u32,
    /// Base id of the growing segment.
    pub growing_base_id: u32,
    /// Raw vectors of the growing (flat) segment, in slot order.
    pub growing_vectors: Vec<Vec<f32>>,
    /// One tombstone flag per growing slot.
    /// On restore it is read index-aligned with `growing_vectors`; missing
    /// entries are treated as live.
    pub growing_deleted: Vec<bool>,
    /// Sealed segments with their serialized HNSW graphs.
    pub sealed_segments: Vec<SealedSnapshot>,
    /// Segments that were still building; they are rebuilt as sealed on restore.
    pub building_segments: Vec<BuildingSnapshot>,
    /// `(local id, surrogate)` pairs.
    /// On restore they populate both `surrogate_map` and the reverse
    /// `surrogate_to_local` map.
    #[serde(default)]
    pub surrogate_map: Vec<(u32, u32)>,
    /// `(surrogate, ids)` pairs from `multi_doc_map`.
    /// NOTE(review): the `Vec<u32>` values are presumably local ids; confirm.
    #[serde(default)]
    pub multi_doc_map: Vec<(u32, Vec<u32>)>,
    /// `VectorQuantization` encoded by `quantization_to_tag`.
    /// When absent, serde defaults it to 0 (`None`).
    #[serde(default)]
    pub quantization_tag: u8,
    /// MessagePack-encoded `PayloadIndexSetSnapshot`.
    /// Empty bytes are restored as an empty payload index set.
    #[serde(default)]
    pub payload_index_bytes: Vec<u8>,
}
96
/// Serialized form of one sealed segment.
///
/// It holds the HNSW graph plus at most one quantizer. The writer emits SQ8
/// data only when it produced no PQ codec bytes.
#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub(crate) struct SealedSnapshot {
    /// `base_id` copied from the segment.
    /// NOTE(review): presumably the first local id it covers.
    pub base_id: u32,
    /// Output of `HnswIndex::checkpoint_to_bytes`.
    /// If it fails to decode on restore, the segment is not restored.
    pub hnsw_bytes: Vec<u8>,
    /// Serialized `PqCodec`.
    /// `None` when the segment had no PQ codec or the codec failed to serialize.
    #[serde(default)]
    pub pq_bytes: Option<Vec<u8>>,
    /// PQ codes. They are restored only together with a decodable `pq_bytes`.
    #[serde(default)]
    pub pq_codes: Option<Vec<u8>>,
    /// Serialized `Sq8Codec`. Written only when `pq_bytes` is `None`.
    #[serde(default)]
    pub sq8_bytes: Option<Vec<u8>>,
    /// SQ8 codes. They are restored only when no PQ codec was restored and
    /// `sq8_bytes` decodes.
    #[serde(default)]
    pub sq8_codes: Option<Vec<u8>>,
}
116
/// Serialized form of a still-building segment.
///
/// Only raw vectors and tombstones are stored. `from_checkpoint` inserts them
/// into a fresh HNSW index, builds an SQ8 codec for it, and restores the result
/// as a sealed segment.
#[derive(Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack)]
pub(crate) struct BuildingSnapshot {
    /// `base_id` copied from the building segment.
    pub base_id: u32,
    /// Raw vectors in insertion order.
    pub vectors: Vec<Vec<f32>>,
    /// Tombstone flag per vector. A `true` at index `i` deletes id `i` in the
    /// rebuilt index.
    #[serde(default)]
    pub deleted: Vec<bool>,
}
124
125impl VectorCollection {
126 pub fn checkpoint_to_bytes(
135 &self,
136 kek: Option<&nodedb_wal::crypto::WalEncryptionKey>,
137 ) -> Vec<u8> {
138 let snapshot = CollectionSnapshot {
139 dim: self.dim,
140 params_m: self.params.m,
141 params_m0: self.params.m0,
142 params_ef_construction: self.params.ef_construction,
143 params_metric: self.params.metric as u8,
144 next_id: self.next_id,
145 growing_base_id: self.growing_base_id,
146 growing_vectors: (0..self.growing.len() as u32)
147 .filter_map(|i| self.growing.get_vector_raw(i).map(|v| v.to_vec()))
148 .collect(),
149 growing_deleted: (0..self.growing.len() as u32)
150 .map(|i| self.growing.is_deleted(i))
151 .collect(),
152 sealed_segments: self
153 .sealed
154 .iter()
155 .map(|s| {
156 let (pq_bytes, pq_codes) = match &s.pq {
157 Some((codec, codes)) => (codec.to_bytes().ok(), Some(codes.clone())),
158 None => (None, None),
159 };
160 let (sq8_bytes, sq8_codes) = if pq_bytes.is_none() {
162 match &s.sq8 {
163 Some((codec, codes)) => (Some(codec.to_bytes()), Some(codes.clone())),
164 None => (None, None),
165 }
166 } else {
167 (None, None)
168 };
169 SealedSnapshot {
170 base_id: s.base_id,
171 hnsw_bytes: s.index.checkpoint_to_bytes(),
172 pq_bytes,
173 pq_codes,
174 sq8_bytes,
175 sq8_codes,
176 }
177 })
178 .collect(),
179 building_segments: self
180 .building
181 .iter()
182 .map(|b| BuildingSnapshot {
183 base_id: b.base_id,
184 vectors: (0..b.flat.len() as u32)
185 .filter_map(|i| b.flat.get_vector_raw(i).map(|v| v.to_vec()))
186 .collect(),
187 deleted: (0..b.flat.len() as u32)
188 .map(|i| b.flat.is_deleted(i))
189 .collect(),
190 })
191 .collect(),
192 surrogate_map: self
193 .surrogate_map
194 .iter()
195 .map(|(&k, s)| (k, s.as_u32()))
196 .collect(),
197 multi_doc_map: self
198 .multi_doc_map
199 .iter()
200 .map(|(k, v)| (k.as_u32(), v.clone()))
201 .collect(),
202 quantization_tag: quantization_to_tag(self.quantization),
203 payload_index_bytes: {
204 let snap = self.payload.to_snapshot();
205 match zerompk::to_msgpack_vec(&snap) {
206 Ok(bytes) => bytes,
207 Err(e) => {
208 tracing::warn!(
209 error = %e,
210 "vector payload index snapshot serialization failed"
211 );
212 return Vec::new();
213 }
214 }
215 },
216 };
217 let msgpack = match zerompk::to_msgpack_vec(&snapshot) {
218 Ok(bytes) => bytes,
219 Err(e) => {
220 tracing::warn!(error = %e, "vector collection checkpoint serialization failed");
221 return Vec::new();
222 }
223 };
224
225 if let Some(key) = kek {
226 match encrypt_checkpoint(key, &msgpack) {
227 Ok(encrypted) => encrypted,
228 Err(e) => {
229 tracing::warn!(error = %e, "vector collection checkpoint encryption failed");
230 Vec::new()
231 }
232 }
233 } else {
234 msgpack
235 }
236 }
237
238 pub fn from_checkpoint(
249 bytes: &[u8],
250 kek: Option<&nodedb_wal::crypto::WalEncryptionKey>,
251 ) -> Result<Option<Self>, VectorError> {
252 let is_encrypted = bytes.len() >= 4 && bytes[0..4] == SEGV_MAGIC;
253
254 let msgpack: Vec<u8>;
255 let msgpack_ref: &[u8];
256
257 if is_encrypted {
258 if let Some(key) = kek {
259 msgpack = decrypt_checkpoint(key, bytes)?;
260 msgpack_ref = &msgpack;
261 } else {
262 return Err(VectorError::CheckpointEncryptedNoKey);
263 }
264 } else if kek.is_some() {
265 return Err(VectorError::CheckpointPlaintextKeyRequired);
266 } else {
267 msgpack_ref = bytes;
268 }
269
270 let snap: CollectionSnapshot = match zerompk::from_msgpack(msgpack_ref) {
271 Ok(s) => s,
272 Err(_) => return Ok(None),
273 };
274 let metric = match snap.params_metric {
275 0 => DistanceMetric::L2,
276 1 => DistanceMetric::Cosine,
277 2 => DistanceMetric::InnerProduct,
278 3 => DistanceMetric::Manhattan,
279 4 => DistanceMetric::Chebyshev,
280 5 => DistanceMetric::Hamming,
281 6 => DistanceMetric::Jaccard,
282 7 => DistanceMetric::Pearson,
283 _ => DistanceMetric::Cosine,
284 };
285 let params = HnswParams {
286 m: snap.params_m,
287 m0: snap.params_m0,
288 ef_construction: snap.params_ef_construction,
289 metric,
290 };
291
292 let mut growing = FlatIndex::new(snap.dim, metric);
293 for (i, v) in snap.growing_vectors.iter().enumerate() {
294 let deleted = snap.growing_deleted.get(i).copied().unwrap_or(false);
295 if deleted {
296 growing.insert_tombstoned(v.clone());
297 } else {
298 growing.insert(v.clone());
299 }
300 }
301
302 let mut sealed = Vec::with_capacity(snap.sealed_segments.len());
304 for ss in &snap.sealed_segments {
305 if let Some(index) = HnswIndex::from_checkpoint(&ss.hnsw_bytes).ok().flatten() {
306 let pq = match (&ss.pq_bytes, &ss.pq_codes) {
307 (Some(bytes), Some(codes)) => PqCodec::from_bytes(bytes)
308 .ok()
309 .map(|codec| (codec, codes.clone())),
310 _ => None,
311 };
312 let sq8 = if pq.is_some() {
315 None
316 } else {
317 match (&ss.sq8_bytes, &ss.sq8_codes) {
318 (Some(codec_bytes), Some(codes)) => Sq8Codec::from_bytes(codec_bytes)
319 .ok()
320 .map(|codec| (codec, codes.clone())),
321 _ => None,
322 }
323 };
324 sealed.push(SealedSegment {
325 index,
326 base_id: ss.base_id,
327 sq8,
328 pq,
329 tier: StorageTier::L0Ram,
330 mmap_vectors: None,
331 });
332 }
333 }
334
335 for bs in &snap.building_segments {
336 let mut index = HnswIndex::new(snap.dim, params.clone());
337 for v in &bs.vectors {
338 index
339 .insert(v.clone())
340 .expect("dimension guaranteed by checkpoint");
341 }
342 for (i, &dead) in bs.deleted.iter().enumerate() {
344 if dead {
345 index.delete(i as u32);
346 }
347 }
348 let sq8 = VectorCollection::build_sq8_for_index(&index);
349 sealed.push(SealedSegment {
350 index,
351 base_id: bs.base_id,
352 sq8,
353 pq: None,
354 tier: StorageTier::L0Ram,
355 mmap_vectors: None,
356 });
357 }
358
359 let next_segment_id = (sealed.len() + 1) as u32;
360
361 let index_config = crate::index_config::IndexConfig {
362 hnsw: params.clone(),
363 ..crate::index_config::IndexConfig::default()
364 };
365 Ok(Some(Self {
366 growing,
367 growing_base_id: snap.growing_base_id,
368 sealed,
369 building: Vec::new(),
370 params,
371 next_id: snap.next_id,
372 next_segment_id,
373 dim: snap.dim,
374 data_dir: None,
375 ram_budget_bytes: 0,
376 mmap_fallback_count: 0,
377 mmap_segment_count: 0,
378 surrogate_map: snap
379 .surrogate_map
380 .iter()
381 .map(|&(k, s)| (k, Surrogate::new(s)))
382 .collect(),
383 surrogate_to_local: snap
384 .surrogate_map
385 .iter()
386 .map(|&(k, s)| (Surrogate::new(s), k))
387 .collect(),
388 multi_doc_map: snap
389 .multi_doc_map
390 .into_iter()
391 .map(|(k, v)| (Surrogate::new(k), v))
392 .collect::<HashMap<_, _>>(),
393 seal_threshold: DEFAULT_SEAL_THRESHOLD,
394 index_config,
395 codec_dispatch: None,
396 quantization: quantization_from_tag(snap.quantization_tag),
397 payload: if snap.payload_index_bytes.is_empty() {
398 super::payload_index::PayloadIndexSet::default()
399 } else {
400 zerompk::from_msgpack::<PayloadIndexSetSnapshot>(&snap.payload_index_bytes)
401 .map(super::payload_index::PayloadIndexSet::from_snapshot)
402 .unwrap_or_default()
403 },
404 arena_index: None,
405 }))
406 }
407}
408
409fn quantization_to_tag(q: VectorQuantization) -> u8 {
411 match q {
412 VectorQuantization::None => 0,
413 VectorQuantization::Sq8 => 1,
414 VectorQuantization::Pq => 2,
415 VectorQuantization::RaBitQ => 3,
416 VectorQuantization::Bbq => 4,
417 VectorQuantization::Binary => 5,
418 VectorQuantization::Ternary => 6,
419 VectorQuantization::Opq => 7,
420 _ => 0,
421 }
422}
423
424fn quantization_from_tag(tag: u8) -> VectorQuantization {
426 match tag {
427 0 => VectorQuantization::None,
428 1 => VectorQuantization::Sq8,
429 2 => VectorQuantization::Pq,
430 3 => VectorQuantization::RaBitQ,
431 4 => VectorQuantization::Bbq,
432 5 => VectorQuantization::Binary,
433 6 => VectorQuantization::Ternary,
434 7 => VectorQuantization::Opq,
435 _ => VectorQuantization::None,
436 }
437}
438
#[cfg(test)]
mod tests {
    use crate::collection::lifecycle::VectorCollection;
    use crate::distance::DistanceMetric;
    use crate::error::VectorError;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Default HNSW parameters with the L2 metric, shared by the tests below.
    fn l2_params() -> HnswParams {
        HnswParams {
            metric: DistanceMetric::L2,
            ..HnswParams::default()
        }
    }

    /// A sealed segment's trained SQ8 codec must be persisted and restored
    /// byte-for-byte, not recalibrated from the vectors on restore.
    #[test]
    fn checkpoint_roundtrip_preserves_sq8() {
        let mut coll = VectorCollection::with_seal_threshold(8, l2_params(), 1024);
        // 1024 deterministic, non-degenerate 8-d vectors: enough to hit the
        // seal threshold and to train an SQ8 codec.
        for i in 0..1024u32 {
            let v: Vec<f32> = (0..8)
                .map(|d| ((i as f32) * 0.01 + (d as f32) * 0.1).sin())
                .collect();
            coll.insert(v);
        }
        let req = coll.seal("sq8_test").expect("seal produced request");
        let mut idx = HnswIndex::new(req.dim, req.params.clone());
        for v in &req.vectors {
            idx.insert(v.clone()).unwrap();
        }
        coll.complete_build(req.segment_id, idx);

        let sealed = coll.sealed_segments();
        assert!(!sealed.is_empty(), "expected at least one sealed segment");
        let orig_sq8 = sealed[0]
            .sq8
            .as_ref()
            .expect("sq8 must be Some after complete_build with ≥1000 vectors");
        let orig_dim = orig_sq8.0.dim();
        let orig_bytes = orig_sq8.0.to_bytes();

        let checkpoint = coll.checkpoint_to_bytes(None);
        let restored = VectorCollection::from_checkpoint(&checkpoint, None)
            .unwrap()
            .unwrap();

        let restored_sealed = restored.sealed_segments();
        assert!(!restored_sealed.is_empty());
        let restored_sq8 = restored_sealed[0]
            .sq8
            .as_ref()
            .expect("sq8 must be Some after restoring checkpoint — never recomputed");

        assert_eq!(restored_sq8.0.dim(), orig_dim, "dim mismatch after restore");
        assert_eq!(
            restored_sq8.0.to_bytes(),
            orig_bytes,
            "sq8 codec bytes differ — calibration data was recomputed rather than persisted"
        );
    }

    /// Growing-segment vectors survive a plaintext round-trip and remain searchable.
    #[test]
    fn checkpoint_roundtrip() {
        let mut coll = VectorCollection::new(3, l2_params());
        for i in 0..50u32 {
            coll.insert(vec![i as f32, 0.0, 0.0]);
        }
        let bytes = coll.checkpoint_to_bytes(None);
        let restored = VectorCollection::from_checkpoint(&bytes, None)
            .unwrap()
            .unwrap();
        assert_eq!(restored.len(), 50);
        assert_eq!(restored.dim(), 3);

        let results = restored.search(&[25.0, 0.0, 0.0], 1, 64);
        assert_eq!(results[0].id, 25);
    }

    /// Payload equality-index bitmaps survive a checkpoint round-trip.
    #[test]
    fn checkpoint_roundtrip_preserves_payload_bitmap() {
        use crate::collection::PayloadIndexKind;
        use crate::collection::payload_index::FilterPredicate;
        use nodedb_types::Value;
        use std::collections::HashMap;

        let mut coll = VectorCollection::new(3, l2_params());
        coll.payload
            .add_index("category".to_string(), PayloadIndexKind::Equality);
        for i in 0u32..10 {
            let node_id = coll.insert(vec![i as f32, 0.0, 0.0]);
            let mut fields = HashMap::new();
            let cat = if i % 2 == 0 { "A" } else { "B" };
            fields.insert("category".to_string(), Value::String(cat.to_string()));
            coll.payload.insert_row(node_id, &fields);
        }

        let bytes = coll.checkpoint_to_bytes(None);
        let restored = VectorCollection::from_checkpoint(&bytes, None)
            .unwrap()
            .unwrap();

        let pred = FilterPredicate::Eq {
            field: "category".to_string(),
            value: Value::String("A".to_string()),
        };
        let bm = restored
            .payload
            .pre_filter(&pred)
            .expect("payload index 'category' must be present after restore");
        assert_eq!(
            bm.len(),
            5,
            "5 rows of category=A must survive checkpoint round-trip"
        );
    }

    /// Plaintext bytes that are not a snapshot decode to `Ok(None)`, not an error.
    #[test]
    fn from_checkpoint_returns_none_for_undecodable_bytes() {
        let restored = VectorCollection::from_checkpoint(b"not a checkpoint", None)
            .expect("undecodable plaintext is not an envelope error");
        assert!(restored.is_none());
    }

    /// A blob carrying the encrypted-envelope magic cannot be opened without a key.
    #[test]
    fn from_checkpoint_rejects_encrypted_blob_without_key() {
        let result = VectorCollection::from_checkpoint(b"SEGV-not-really-encrypted", None);
        assert!(matches!(result, Err(VectorError::CheckpointEncryptedNoKey)));
    }

    /// Every tag in 0..=7 round-trips; unknown tags decode as `None` (tag 0).
    #[test]
    fn quantization_tag_roundtrip() {
        use super::{quantization_from_tag, quantization_to_tag};

        for tag in 0u8..=7 {
            assert_eq!(quantization_to_tag(quantization_from_tag(tag)), tag);
        }
        assert_eq!(quantization_to_tag(quantization_from_tag(u8::MAX)), 0);
    }
}