1use std::collections::HashMap;
22use std::sync::{Arc, Mutex};
23
24use crate::core::{DocId, FieldId, LuciError, SegmentId};
25use crate::mapping::Mapping;
26
27use super::DistanceMetric;
28use super::hnsw::{
29 BuildThreads, HnswBuilder, HnswIndex, HnswParams, checked_len, read_u32, read_u64, take_bytes,
30};
31
/// Value passed as `HnswParams::m` for every per-field graph created by
/// `GlobalHnsw::new`.
const HNSW_M: usize = 16;

/// Value passed as `HnswParams::ef_construction` for every per-field graph
/// created by `GlobalHnsw::new`.
const HNSW_EF_CONSTRUCTION: usize = 100;

/// Four-byte tag written at the start of every blob produced by
/// `GlobalHnsw::field_to_bytes`; `GlobalHnsw::load_field` rejects blobs that
/// do not begin with it.
const VECTOR_INDEX_MAGIC: [u8; 4] = *b"VIDX";

/// Format version byte written immediately after the magic tag.
/// `GlobalHnsw::load_field` refuses any other value.
const VECTOR_INDEX_VERSION: u8 = 1;
45
/// Mutable HNSW state kept for a single vector field.
struct FieldGlobalHnsw {
    /// Builder that receives every vector added or stored for this field.
    builder: HnswBuilder,
    /// `resolver[ordinal]` is the `(segment, segment-local doc id)` pair that
    /// was recorded when the vector with that ordinal was inserted.
    resolver: Vec<(SegmentId, u32)>,
    /// Index built from `builder` by `get_or_build_index`. `None` until the
    /// first build and again after every `invalidate_cache`.
    cached: Option<Arc<HnswIndex>>,
}
60
61impl FieldGlobalHnsw {
62 fn new(params: HnswParams) -> Self {
63 Self {
64 builder: HnswBuilder::new(params),
65 resolver: Vec::new(),
66 cached: None,
67 }
68 }
69
70 fn invalidate_cache(&mut self) {
71 self.cached = None;
72 }
73
74 fn get_or_build_index(&mut self) -> Arc<HnswIndex> {
81 debug_assert!(
86 !self.builder.has_pending_tail(),
87 "get_or_build_index called with an unlinked pending tail; \
88 connect_pending was not run before persist",
89 );
90 if self.cached.is_none() {
91 self.cached = Some(Arc::new(self.builder.clone().build()));
92 }
93 Arc::clone(self.cached.as_ref().unwrap())
94 }
95}
96
/// HNSW state for all vector fields, keyed by field id.
///
/// Every method takes `&self` and locks `per_field` for the part of its work
/// that touches the per-field state.
pub struct GlobalHnsw {
    /// One `FieldGlobalHnsw` per field registered by `GlobalHnsw::new`,
    /// guarded by a single mutex.
    per_field: Mutex<HashMap<FieldId, FieldGlobalHnsw>>,
}
111
/// One result returned by `GlobalHnsw::search`.
#[derive(Clone, Copy, Debug)]
pub struct GlobalHit {
    /// Segment recorded in the resolver for the matched vector.
    pub segment_id: SegmentId,
    /// `DocId` built from the segment-local doc id recorded in the resolver
    /// for the matched vector.
    pub doc_id: DocId,
    /// Distance reported by the index search for this hit.
    pub distance: f32,
}
120
121impl GlobalHnsw {
122 pub fn new(schema: &Mapping) -> Self {
134 let mut per_field = HashMap::new();
135 for mapping in schema.fields() {
136 let Some(dims) = mapping.field_type.vector_dims() else {
137 continue;
138 };
139 if dims == 0 {
140 continue;
141 }
142 let field_id = schema.field_id(&mapping.name).unwrap_or_else(|| {
143 panic!(
144 "schema.fields() returned mapping for {:?} but \
145 field_id() couldn't find it; schema is internally \
146 inconsistent",
147 mapping.name
148 );
149 });
150 let quantization = mapping
151 .field_type
152 .vector_quantization()
153 .expect("dense_vector mapping must carry quantization");
154 per_field.insert(
155 field_id,
156 FieldGlobalHnsw::new(HnswParams {
157 dims,
158 m: HNSW_M,
159 ef_construction: HNSW_EF_CONSTRUCTION,
160 metric: DistanceMetric::Cosine,
161 quantization,
162 }),
163 );
164 }
165 Self {
166 per_field: Mutex::new(per_field),
167 }
168 }
169
170 pub fn add_vector(
184 &self,
185 field_id: FieldId,
186 segment_id: SegmentId,
187 local_doc_id: u32,
188 vector: Vec<f32>,
189 ) -> Result<u32, LuciError> {
190 let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
191 let field = guard.get_mut(&field_id).ok_or_else(|| {
192 LuciError::InvalidQuery(format!(
193 "GlobalHnsw::add_vector called for field {field_id:?} which is \
194 not a dense_vector field in the schema; this is an internal \
195 wiring bug",
196 ))
197 })?;
198 let ord = field.builder.len() as u32;
199 field.builder.add_vector(vector)?;
200 field.resolver.push((segment_id, local_doc_id));
201 field.invalidate_cache();
202 debug_assert_eq!(
203 field.resolver.len(),
204 field.builder.len(),
205 "resolver and builder lengths must agree after add_vector",
206 );
207 Ok(ord)
208 }
209
210 pub fn store_vector(
218 &self,
219 field_id: FieldId,
220 segment_id: SegmentId,
221 local_doc_id: u32,
222 vector: Vec<f32>,
223 ) -> Result<u32, LuciError> {
224 let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
225 let field = guard.get_mut(&field_id).ok_or_else(|| {
226 LuciError::InvalidQuery(format!(
227 "GlobalHnsw::store_vector called for field {field_id:?} which is \
228 not a dense_vector field in the schema; this is an internal \
229 wiring bug",
230 ))
231 })?;
232 let ord = field.builder.len() as u32;
233 field.builder.store_vector(vector)?;
234 field.resolver.push((segment_id, local_doc_id));
235 field.invalidate_cache();
236 debug_assert_eq!(
237 field.resolver.len(),
238 field.builder.len(),
239 "resolver and builder lengths must agree after store_vector",
240 );
241 Ok(ord)
242 }
243
244 pub fn connect_pending(&self, field_id: FieldId, threads: BuildThreads) {
254 let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
255 if let Some(field) = guard.get_mut(&field_id) {
256 field.builder.connect_pending(threads);
257 field.invalidate_cache();
258 }
259 }
260
261 pub fn search(
269 &self,
270 field_id: FieldId,
271 query: &[f32],
272 k: usize,
273 ef: usize,
274 ) -> Result<Option<(Vec<GlobalHit>, DistanceMetric)>, LuciError> {
275 let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
276 let field = match guard.get_mut(&field_id) {
277 Some(f) => f,
278 None => return Ok(None),
279 };
280 let metric = field.builder.params().metric;
281 let resolver = field.resolver.clone();
282 let index = field.get_or_build_index();
283 drop(guard);
286
287 let raw = index.search(query, k, ef)?;
288 let hits = raw
289 .into_iter()
290 .map(|(global_ord, dist)| {
291 let (seg, doc) = resolver[global_ord as usize];
292 GlobalHit {
293 segment_id: seg,
294 doc_id: DocId::new(doc),
295 distance: dist,
296 }
297 })
298 .collect();
299 Ok(Some((hits, metric)))
300 }
301
302 pub fn len(&self, field_id: FieldId) -> Option<usize> {
305 let guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
306 guard.get(&field_id).map(|f| f.builder.len())
307 }
308
309 pub fn is_empty(&self) -> bool {
311 let guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
312 guard.values().all(|f| f.builder.is_empty())
313 }
314
315 pub fn rewrite_after_merge(&self, merge_map: &HashMap<(SegmentId, u32), (SegmentId, u32)>) {
325 if merge_map.is_empty() {
326 return;
327 }
328 let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
329 for field in guard.values_mut() {
330 let mut changed = false;
331 for entry in &mut field.resolver {
332 if let Some(&(new_seg, new_doc)) = merge_map.get(entry) {
333 *entry = (new_seg, new_doc);
334 changed = true;
335 }
336 }
337 if changed {
338 field.invalidate_cache();
339 }
340 }
341 }
342
343 pub fn field_ids(&self) -> Vec<FieldId> {
345 let guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
346 guard.keys().copied().collect()
347 }
348
349 pub fn non_empty_field_ids(&self) -> Vec<FieldId> {
352 let guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
353 let mut ids: Vec<FieldId> = guard
354 .iter()
355 .filter(|(_, f)| !f.builder.is_empty())
356 .map(|(fid, _)| *fid)
357 .collect();
358 ids.sort();
359 ids
360 }
361
362 pub fn field_to_bytes(&self, field_id: FieldId) -> Option<Vec<u8>> {
376 let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
377 let field = guard.get_mut(&field_id)?;
378
379 let mut buf = Vec::new();
380 buf.extend_from_slice(&VECTOR_INDEX_MAGIC);
381 buf.push(VECTOR_INDEX_VERSION);
382
383 let index = field.get_or_build_index();
384 let hnsw_bytes = index.to_bytes();
385 buf.extend_from_slice(&(hnsw_bytes.len() as u32).to_le_bytes());
386 buf.extend_from_slice(&hnsw_bytes);
387
388 buf.extend_from_slice(&(field.resolver.len() as u32).to_le_bytes());
389 for (seg, doc) in &field.resolver {
390 buf.extend_from_slice(&seg.as_u64().to_le_bytes());
391 buf.extend_from_slice(&doc.to_le_bytes());
392 }
393
394 Some(buf)
395 }
396
397 pub fn load_field(&self, field_id: FieldId, data: &[u8]) -> Result<(), LuciError> {
405 if data.len() < 5 {
406 return Err(LuciError::IndexCorrupted(format!(
407 "vector index blob for field {field_id:?} too short: {} bytes",
408 data.len()
409 )));
410 }
411 if data[0..4] != VECTOR_INDEX_MAGIC {
412 return Err(LuciError::IndexCorrupted(format!(
413 "vector index blob for field {field_id:?} missing magic prefix"
414 )));
415 }
416 if data[4] != VECTOR_INDEX_VERSION {
417 return Err(LuciError::SegmentFormatUnknown(format!(
418 "unknown vector index blob version {} for field {field_id:?}",
419 data[4]
420 )));
421 }
422 let mut pos = 5;
423 let hnsw_len = read_u32(data, &mut pos)? as usize;
424 let hnsw_bytes = take_bytes(data, &mut pos, hnsw_len)?;
425 let index = HnswIndex::from_bytes(hnsw_bytes)?;
426 let builder = HnswBuilder::from_index(index);
427
428 let resolver_len = read_u32(data, &mut pos)? as usize;
433 let mut resolver = Vec::with_capacity(checked_len(resolver_len, 12, data, pos)?);
434 for _ in 0..resolver_len {
435 let seg = SegmentId::new(read_u64(data, &mut pos)?);
436 let doc = read_u32(data, &mut pos)?;
437 resolver.push((seg, doc));
438 }
439 if resolver.len() != builder.len() {
440 return Err(LuciError::IndexCorrupted(format!(
441 "vector index resolver/graph mismatch for field {field_id:?}: \
442 graph has {} vectors, resolver has {}",
443 builder.len(),
444 resolver.len()
445 )));
446 }
447
448 let mut guard = self.per_field.lock().expect("GlobalHnsw mutex poisoned");
449 let field = guard.get_mut(&field_id).ok_or_else(|| {
450 LuciError::InvalidQuery(format!(
451 "GlobalHnsw::load_field called for field {field_id:?} which is \
452 not a dense_vector field in the current schema"
453 ))
454 })?;
455 field.builder = builder;
456 field.resolver = resolver;
457 field.cached = None;
458 Ok(())
459 }
460}
461
#[cfg(test)]
mod tests {
    use crate::mapping::{FieldType, Mapping};

    use super::*;

    /// Schema with a single dense-vector field `name` of `dims` dimensions.
    fn vector_schema(name: &str, dims: usize) -> Mapping {
        let builder = Mapping::builder().field(name, FieldType::dense_vector(dims));
        builder.build()
    }

    /// Schema with a lone `"embedding"` field, that field's id, and an empty
    /// index created from the schema.
    fn embedding_index(dims: usize) -> (Mapping, FieldId, GlobalHnsw) {
        let schema = vector_schema("embedding", dims);
        let field_id = schema.field_id("embedding").unwrap();
        let index = GlobalHnsw::new(&schema);
        (schema, field_id, index)
    }

    /// A dense-vector mapping is registered as exactly one empty field.
    #[test]
    fn new_finds_dense_vector_fields() {
        let (_schema, field_id, index) = embedding_index(4);
        assert_eq!(index.field_ids().len(), 1);
        assert_eq!(index.len(field_id), Some(0));
    }

    /// A schema without vector fields yields an empty index.
    #[test]
    fn new_with_no_vector_fields_is_empty() {
        let index = GlobalHnsw::new(&Mapping::builder().build());
        assert!(index.is_empty());
    }

    /// Ordinals count up from zero in insertion order.
    #[test]
    fn add_vector_returns_increasing_ordinals() {
        let (_schema, field_id, index) = embedding_index(3);
        let seg = SegmentId::new(1);
        let axes = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        for (doc, axis) in (0u32..).zip(axes) {
            let ord = index.add_vector(field_id, seg, doc, axis.to_vec()).unwrap();
            assert_eq!(ord, doc);
        }
        assert_eq!(index.len(field_id), Some(3));
    }

    /// Adding to a field the schema does not know is reported as an error.
    #[test]
    fn add_vector_for_unknown_field_errors() {
        let index = GlobalHnsw::new(&vector_schema("embedding", 3));
        let outcome = index.add_vector(FieldId(999), SegmentId::new(1), 0, vec![1.0, 0.0, 0.0]);
        assert!(matches!(outcome, Err(LuciError::InvalidQuery(_))));
    }

    /// The all-zero vector is refused.
    #[test]
    fn cosine_zero_vector_rejected() {
        let (_schema, field_id, index) = embedding_index(3);
        let outcome = index.add_vector(field_id, SegmentId::new(1), 0, vec![0.0, 0.0, 0.0]);
        assert!(matches!(outcome, Err(LuciError::InvalidQuery(_))));
    }

    /// Hits carry the segment and segment-local doc id given at insert time.
    #[test]
    fn search_returns_hits_in_segment_local_doc_space() {
        let (_schema, field_id, index) = embedding_index(3);
        let seg1 = SegmentId::new(1);
        let seg2 = SegmentId::new(2);
        let inserts = [
            (seg1, 0, [1.0, 0.0, 0.0]),
            (seg1, 1, [0.0, 1.0, 0.0]),
            (seg2, 0, [0.9, 0.1, 0.0]),
        ];
        for (seg, doc, vector) in inserts {
            index.add_vector(field_id, seg, doc, vector.to_vec()).unwrap();
        }

        let found = index.search(field_id, &[1.0, 0.0, 0.0], 3, 16).unwrap();
        let (hits, metric) = found.unwrap();
        assert_eq!(metric, DistanceMetric::Cosine);
        assert_eq!(hits.len(), 3);
        // The exact match was inserted as doc 0 of segment 1.
        let nearest = hits[0];
        assert_eq!(nearest.segment_id, seg1);
        assert_eq!(nearest.doc_id, DocId::new(0));
    }

    /// A serialized field loads into a fresh index and answers searches.
    #[test]
    fn roundtrip_field_to_bytes_load_field() {
        let (schema, field_id, source) = embedding_index(3);
        let seg1 = SegmentId::new(1);
        let axes = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        for (doc, axis) in (0u32..).zip(axes) {
            source.add_vector(field_id, seg1, doc, axis.to_vec()).unwrap();
        }

        let blob = source.field_to_bytes(field_id).unwrap();
        let restored = GlobalHnsw::new(&schema);
        restored.load_field(field_id, &blob).unwrap();
        assert_eq!(restored.len(field_id), Some(3));

        let found = restored.search(field_id, &[1.0, 0.0, 0.0], 1, 16).unwrap();
        let (hits, _) = found.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].segment_id, seg1);
        assert_eq!(hits[0].doc_id, DocId::new(0));
    }

    /// Truncated blobs and an absurd resolver count are rejected with an
    /// error instead of a panic or a huge allocation.
    #[test]
    fn load_field_rejects_corrupt_blob() {
        let (schema, field_id, source) = embedding_index(3);
        let seg = SegmentId::new(1);
        source.add_vector(field_id, seg, 0, vec![1.0, 0.0, 0.0]).unwrap();
        source.add_vector(field_id, seg, 1, vec![0.0, 1.0, 0.0]).unwrap();
        let valid = source.field_to_bytes(field_id).unwrap();

        // Every attempt loads into a brand-new index.
        let load = |blob: &[u8]| GlobalHnsw::new(&schema).load_field(field_id, blob);

        assert!(load(&valid).is_ok(), "valid blob must load");

        for cut in [5usize, 6, 9, valid.len() / 2, valid.len() - 1] {
            assert!(
                load(&valid[..cut]).is_err(),
                "truncated-to-{cut} blob must be rejected, not panic"
            );
        }

        // Overwrite the resolver count, which sits right after the header,
        // the graph length prefix and the graph bytes.
        let hnsw_len = u32::from_le_bytes(valid[5..9].try_into().unwrap()) as usize;
        let resolver_len_off = 5 + 4 + hnsw_len;
        let mut bad_resolver = valid.clone();
        bad_resolver[resolver_len_off..][..4].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(
            matches!(load(&bad_resolver), Err(LuciError::IndexCorrupted(_))),
            "corrupt resolver length must be IndexCorrupted, not OOM/panic"
        );
    }

    /// Fields that never received a vector are left out of the non-empty list.
    #[test]
    fn non_empty_field_ids_omits_empty_fields() {
        let schema = Mapping::builder()
            .field("a", FieldType::dense_vector(2))
            .field("b", FieldType::dense_vector(2))
            .build();
        let a = schema.field_id("a").unwrap();
        let b = schema.field_id("b").unwrap();
        let index = GlobalHnsw::new(&schema);
        index
            .add_vector(a, SegmentId::new(1), 0, vec![1.0, 0.0])
            .unwrap();

        assert_eq!(index.non_empty_field_ids(), vec![a]);
        assert_eq!(index.len(b), Some(0));
    }

    /// After a merge, hits resolve to the merged segment.
    #[test]
    fn rewrite_after_merge_remaps_resolver() {
        let (_schema, field_id, index) = embedding_index(3);
        let s1 = SegmentId::new(1);
        let s2 = SegmentId::new(2);
        let s3 = SegmentId::new(3);
        index.add_vector(field_id, s1, 0, vec![1.0, 0.0, 0.0]).unwrap();
        index.add_vector(field_id, s2, 0, vec![0.0, 1.0, 0.0]).unwrap();

        // Both source segments fold into s3.
        let merge_map = HashMap::from([((s1, 0), (s3, 0)), ((s2, 0), (s3, 1))]);
        index.rewrite_after_merge(&merge_map);

        let found = index.search(field_id, &[1.0, 0.0, 0.0], 2, 16).unwrap();
        let (hits, _) = found.unwrap();
        assert_eq!(hits.len(), 2);
        for hit in &hits {
            assert_eq!(hit.segment_id, s3);
        }
    }
}