use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, OnceLock};

use crate::hnsw::CentroidGraph;
use crate::lire::commands::RebalanceCommand;
use crate::lire::rebalancer::IndexRebalancer;
use crate::model::{AttributeValue, MetadataFieldSpec, VECTOR_FIELD_NAME};
use crate::serde::FieldValue;
use crate::serde::posting_list::PostingUpdate;
use crate::storage::record;
use common::SequenceAllocator;
use common::coordinator::{Delta, PauseHandle};
use common::storage::RecordOp;
use dashmap::DashMap;
use roaring::RoaringTreemap;
use tracing::debug;
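
/// A unit of work applied to the write delta: either a batch of vector
/// upserts or a rebalance command from the index rebalancer.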
pub(crate) enum VectorDbWrite {
    Write(Vec<VectorWrite>),
    Rebalance(RebalanceCommand),
}

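/// A single vector upsert: the caller-supplied external id, the vector
/// values, and any accompanying attributes.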
#[derive(Debug, Clone)]
pub(crate) struct VectorWrite {
    pub(crate) external_id: String,
    pub(crate) values: Vec<f32>,
    pub(crate) attributes: Vec<(String, AttributeValue)>,
}

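/// Options that shape delta construction: vector dimensionality, chunk
/// sizing, rebalance backpressure limits, and which metadata fields are
/// indexed.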
pub(crate) struct VectorDbDeltaOpts {
    pub(crate) dimensions: usize,
    pub(crate) chunk_target: usize,
    pub(crate) max_pending_and_running_rebalance_tasks: usize,
    pub(crate) split_threshold_vectors: usize,
    pub(crate) rebalance_backpressure_resume_threshold: usize,
    pub(crate) indexed_fields: HashSet<String>,
}

impl VectorDbDeltaOpts {
    /// Collects the names of the metadata fields marked as indexed.
    pub(crate) fn indexed_fields_from(specs: &[MetadataFieldSpec]) -> HashSet<String> {
        specs
            .iter()
            .filter(|s| s.indexed)
            .map(|s| s.name.clone())
            .collect()
    }
}

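/// State carried across deltas: configuration, the external-id dictionary,
/// the centroid graph, the internal-id allocator, chunk bookkeeping, the
/// rebalancer, and the coordinator's pause handle.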
pub(crate) struct VectorDbDeltaContext {
    pub(crate) opts: VectorDbDeltaOpts,
    pub(crate) dictionary: Arc<DashMap<String, u64>>,
    pub(crate) centroid_graph: Arc<dyn CentroidGraph>,
    pub(crate) id_allocator: SequenceAllocator,
    pub(crate) current_chunk_id: u32,
    pub(crate) current_chunk_count: usize,
    pub(crate) rebalancer: IndexRebalancer,
    pub(crate) pause_handle: Arc<OnceLock<PauseHandle>>,
}

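/// The frozen form of a delta: the full list of storage record operations.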
#[derive(Clone)]
pub struct VectorDbImmutableDelta {
    pub ops: Vec<RecordOp>,
}

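/// The mutable, in-flight delta. Writes accumulate record ops and updates to
/// the shared [`VectorDbDeltaView`]; freezing turns them into a
/// [`VectorDbImmutableDelta`].
///
/// A minimal lifecycle sketch, assuming a fully populated
/// [`VectorDbDeltaContext`]:
///
/// ```ignore
/// let mut delta = VectorDbWriteDelta::init(ctx);
/// delta.apply(VectorDbWrite::Write(vec![write]))?;
/// let (frozen, view, ctx) = delta.freeze();
/// // `frozen.ops` now holds the RecordOps to hand to storage;
/// // `ctx` is reused to init the next delta.
/// ```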
pub(crate) struct VectorDbWriteDelta {
    pub(crate) ctx: VectorDbDeltaContext,
    pub(crate) ops: Vec<RecordOp>,
    pub(crate) view: Arc<std::sync::RwLock<VectorDbDeltaView>>,
}

impl VectorDbWriteDelta {
    /// Returns the nearest centroid for the given vector, falling back to
    /// centroid 1 when the graph returns no results.
    fn assign_to_centroid(&self, vector: &[f32]) -> u64 {
        self.ctx
            .centroid_graph
            .search(vector, 1)
            .first()
            .copied()
            .unwrap_or(1)
    }
}

impl Delta for VectorDbWriteDelta {
    type Context = VectorDbDeltaContext;
    type Write = VectorDbWrite;
    type DeltaView = Arc<std::sync::RwLock<VectorDbDeltaView>>;
    type Frozen = VectorDbImmutableDelta;
    type FrozenView = Arc<VectorDbDeltaView>;
    type ApplyResult = Arc<dyn Any + Send + Sync + 'static>;

    fn init(context: VectorDbDeltaContext) -> Self {
        Self {
            ctx: context,
            ops: Vec::new(),
            view: Arc::new(std::sync::RwLock::new(VectorDbDeltaView::new())),
        }
    }

    fn apply(
        &mut self,
        write: Self::Write,
    ) -> Result<Arc<dyn Any + Send + Sync + 'static>, String> {
        let result = match write {
            VectorDbWrite::Write(writes) => self.apply_write(writes),
            VectorDbWrite::Rebalance(cmd) => self.apply_rebalance_cmd(cmd),
        };
        self.toggle_rebalance_backpressure();
        result
    }

    fn estimate_size(&self) -> usize {
        let view = self.view.read().expect("lock poisoned");
        self.ops.len() * 100
            + view
                .posting_updates
                .values()
                .map(|v| v.len())
                .sum::<usize>()
                * 50
            + view.deleted_centroids.len() as usize * 8
    }

    fn freeze(self) -> (Self::Frozen, Self::FrozenView, Self::Context) {
        self.ctx.rebalancer.log_summary();
        let mut ops = self.ops;
        let view = self.view.read().expect("lock poisoned").clone();

        // Turn the accumulated per-centroid posting updates into merge ops,
        // along with a centroid stats delta per centroid.
        for (centroid_id, updates) in &view.posting_updates {
            let count = updates.len() as i32;
            if let Ok(op) = record::merge_posting_list(*centroid_id, updates.clone()) {
                ops.push(op);
            }
            ops.push(record::merge_centroid_stats(*centroid_id, count));
        }

        // Metadata index bitmaps are merged per encoded field/value key.
        for (encoded_key, vector_ids) in &view.metadata_index_updates {
            if let Ok(op) = record::merge_metadata_index_bitmap(encoded_key.clone(), vector_ids) {
                ops.push(op);
            }
        }

        if !view.deleted_centroids.is_empty() {
            let op = record::merge_deleted_vectors(view.deleted_centroids.clone())
                .expect("failed to construct deleted vectors row");
            ops.push(op);
        }

        (VectorDbImmutableDelta { ops }, Arc::new(view), self.ctx)
    }

    fn reader(&self) -> Self::DeltaView {
        self.view.clone()
    }
}

impl VectorDbWriteDelta {
    fn pause_handle(&self) -> PauseHandle {
        self.ctx.pause_handle.get().unwrap().clone()
    }

    /// Pauses the coordinator when rebalance work backs up (too many pending
    /// or running tasks, or an oversized centroid) and resumes it once the
    /// backlog drops below the resume threshold.
    fn toggle_rebalance_backpressure(&self) {
        let total_tasks = self.ctx.rebalancer.total_ops_pending_and_running();
        let max_centroid_limit = self.ctx.opts.split_threshold_vectors.saturating_mul(2) as u64;
        if total_tasks >= self.ctx.opts.max_pending_and_running_rebalance_tasks
            || self.ctx.rebalancer.max_centroid_size() >= max_centroid_limit
        {
            debug!(
                "applying rebalance backpressure: {} {}",
                total_tasks, self.ctx.opts.max_pending_and_running_rebalance_tasks
            );
            self.pause_handle().pause();
        } else if total_tasks < self.ctx.opts.rebalance_backpressure_resume_threshold {
            self.pause_handle().unpause();
        }
    }

    fn apply_write(
        &mut self,
        vector_writes: Vec<VectorWrite>,
    ) -> Result<Arc<dyn Any + Send + Sync + 'static>, String> {
        let mut view = self.view.write().expect("lock poisoned");

        for write in vector_writes {
            // Allocate a fresh internal id; persist the allocator state when needed.
            let (new_internal_id, seq_alloc_put) = self.ctx.id_allocator.allocate_one();
            if let Some(seq_alloc_put) = seq_alloc_put {
                self.ops.push(RecordOp::Put(seq_alloc_put.into()));
            }

            let old_internal_id = self.ctx.dictionary.get(&write.external_id).map(|r| *r);

            let centroid_id = self.assign_to_centroid(&write.values);

            self.ctx
                .dictionary
                .insert(write.external_id.clone(), new_internal_id);

            self.ops.push(record::put_id_dictionary(
                &write.external_id,
                new_internal_id,
            ));

            // An existing mapping means this is an update: drop the old row.
            if let Some(old_id) = old_internal_id {
                self.ops.push(record::delete_vector_data(old_id));
            }

            self.ops.push(record::put_vector_data(
                new_internal_id,
                &write.external_id,
                &write.attributes,
            ));

            for (attr_name, attr_value) in &write.attributes {
                if attr_name == VECTOR_FIELD_NAME {
                    continue;
                }
                if !self.ctx.opts.indexed_fields.contains(attr_name) {
                    continue;
                }
                let field_value: FieldValue = attr_value.clone().into();
                view.add_to_metadata_index(attr_name.clone(), field_value, new_internal_id);
            }

            view.add_to_posting(centroid_id, new_internal_id, write.values);
            self.ctx.rebalancer.update_counts(&[(centroid_id, 1)]);
        }

        drop(view);

        Ok(Arc::new(()))
    }
}

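/// Shared, queryable view of a delta's accumulated state: pending posting
/// updates per centroid, a deletion bitmap, and metadata index bitmap
/// updates keyed by encoded index key.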
#[derive(Clone)]
pub(crate) struct VectorDbDeltaView {
    pub(crate) posting_updates: HashMap<u64, Vec<PostingUpdate>>,
    pub(crate) deleted_centroids: RoaringTreemap,
    pub(crate) metadata_index_updates: HashMap<bytes::Bytes, RoaringTreemap>,
}

impl VectorDbDeltaView {
    fn new() -> Self {
        Self {
            posting_updates: HashMap::new(),
            deleted_centroids: RoaringTreemap::new(),
            metadata_index_updates: HashMap::new(),
        }
    }

    pub(crate) fn add_to_posting(&mut self, centroid_id: u64, vector_id: u64, vector: Vec<f32>) {
        self.posting_updates
            .entry(centroid_id)
            .or_default()
            .push(PostingUpdate::append(vector_id, vector));
    }

    pub(crate) fn add_to_metadata_index(
        &mut self,
        field_name: String,
        field_value: FieldValue,
        vector_id: u64,
    ) {
        let key = crate::serde::key::MetadataIndexKey::new(field_name, field_value).encode();
        #[allow(clippy::unwrap_or_default)]
        self.metadata_index_updates
            .entry(key)
            .or_insert_with(RoaringTreemap::new)
            .insert(vector_id);
    }

    pub(crate) fn delete_from_posting(&mut self, centroid_id: u64, vector_id: u64) {
        self.posting_updates
            .entry(centroid_id)
            .or_default()
            .push(PostingUpdate::delete(vector_id));
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::CentroidGraph;
    use crate::lire::rebalancer::{IndexRebalancer, IndexRebalancerOpts};
    use crate::model::AttributeValue;
    use crate::serde::centroid_chunk::CentroidEntry;
    use crate::serde::collection_meta::DistanceMetric;
    use crate::serde::key::{CentroidStatsKey, IdDictionaryKey, PostingListKey, VectorDataKey};
    use bytes::{Buf, Bytes};
    use common::SequenceAllocator;
    use common::coordinator::Delta;
    use common::storage::RecordOp;
    use common::storage::in_memory::InMemoryStorage;

    struct MockCentroidGraph {
        centroids: Vec<(u64, Vec<f32>)>,
    }

    impl MockCentroidGraph {
        fn new(centroids: Vec<(u64, Vec<f32>)>) -> Self {
            Self { centroids }
        }
    }

    impl CentroidGraph for MockCentroidGraph {
        fn search(&self, _query: &[f32], _k: usize) -> Vec<u64> {
            self.centroids.iter().map(|(id, _)| *id).collect()
        }

        fn add_centroid(&self, _entry: &CentroidEntry) -> crate::error::Result<()> {
            Ok(())
        }

        fn remove_centroid(&self, _centroid_id: u64) -> crate::error::Result<()> {
            Ok(())
        }

        fn get_centroid_vector(&self, centroid_id: u64) -> Option<Vec<f32>> {
            self.centroids
                .iter()
                .find(|(id, _)| *id == centroid_id)
                .map(|(_, v)| v.clone())
        }

        fn len(&self) -> usize {
            self.centroids.len()
        }
    }

    async fn create_test_context(centroid_id: u64) -> VectorDbDeltaContext {
        let storage: Arc<dyn common::Storage> = Arc::new(InMemoryStorage::new());
        let key = Bytes::from_static(&[0x01, 0x02]);
        let id_allocator = SequenceAllocator::load(storage.as_ref(), key)
            .await
            .unwrap();
        let centroid_graph: Arc<dyn CentroidGraph> =
            Arc::new(MockCentroidGraph::new(vec![(centroid_id, vec![0.0; 3])]));
        let rebalancer = IndexRebalancer::new(
            IndexRebalancerOpts {
                dimensions: 3,
                distance_metric: DistanceMetric::L2,
                split_search_neighbourhood: 4,
                split_threshold_vectors: 10_000,
                merge_threshold_vectors: 0,
                max_rebalance_tasks: 0,
            },
            centroid_graph.clone(),
            HashMap::new(),
            Arc::new(std::sync::OnceLock::new()),
        );

        VectorDbDeltaContext {
            opts: VectorDbDeltaOpts {
                dimensions: 3,
                chunk_target: 4096,
                max_pending_and_running_rebalance_tasks: usize::MAX,
                split_threshold_vectors: usize::MAX,
                rebalance_backpressure_resume_threshold: 0,
                indexed_fields: HashSet::new(),
            },
            dictionary: Arc::new(DashMap::new()),
            centroid_graph,
            id_allocator,
            current_chunk_id: 0,
            current_chunk_count: 0,
            rebalancer,
            pause_handle: Arc::new(OnceLock::new()),
        }
    }

    fn create_vector_write(external_id: &str, values: Vec<f32>) -> VectorWrite {
        VectorWrite {
            external_id: external_id.to_string(),
            values: values.clone(),
            attributes: vec![
                ("vector".to_string(), AttributeValue::Vector(values)),
                (
                    "category".to_string(),
                    AttributeValue::String("test".to_string()),
                ),
            ],
        }
    }

    fn is_put_with_key_prefix(op: &RecordOp, prefix: &[u8]) -> bool {
        match op {
            RecordOp::Put(record) => record.record.key.starts_with(prefix),
            _ => false,
        }
    }

    fn is_merge_with_key_prefix(op: &RecordOp, prefix: &[u8]) -> bool {
        match op {
            RecordOp::Merge(record) => record.record.key.starts_with(prefix),
            _ => false,
        }
    }

    #[tokio::test]
    async fn should_add_vectors() {
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);

        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        let id_dict_key = IdDictionaryKey::new("vec-1").encode();
        let vector_data_key_prefix = VectorDataKey::new(0).encode();

        let has_id_dict_put = frozen.ops.iter().any(|op| match op {
            RecordOp::Put(record) => record.record.key == id_dict_key,
            _ => false,
        });
        assert!(has_id_dict_put, "should have ID dictionary put op");

        let has_vector_data_put = frozen
            .ops
            .iter()
            .any(|op| is_put_with_key_prefix(op, &vector_data_key_prefix[..2]));
        assert!(has_vector_data_put, "should have vector data put op");
    }

    #[tokio::test]
    async fn should_assign_vectors_to_postings() {
        let centroid_id = 42u64;
        let ctx = create_test_context(centroid_id).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);

        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        let posting_key = PostingListKey::new(centroid_id).encode();
        let has_posting_merge = frozen.ops.iter().any(|op| match op {
            RecordOp::Merge(record) => record.record.key == posting_key,
            _ => false,
        });
        assert!(
            has_posting_merge,
            "should have posting list merge op for centroid {}",
            centroid_id
        );
    }

    #[tokio::test]
    async fn should_update_dictionary_on_insert() {
        let ctx = create_test_context(1).await;
        let dictionary = Arc::clone(&ctx.dictionary);
        let mut delta = VectorDbWriteDelta::init(ctx);

        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);

        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();

        assert!(dictionary.contains_key("vec-1"));
        let internal_id = *dictionary.get("vec-1").unwrap();
        assert_eq!(internal_id, 0, "first allocated ID should be 0");
    }

    #[tokio::test]
    async fn should_add_vectors_on_update() {
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);
        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let write = create_vector_write("vec-1", vec![4.0, 5.0, 6.0]);
        let first_id = *delta.ctx.dictionary.get("vec-1").unwrap();

        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, ctx) = delta.freeze();

        let id_dict_key = IdDictionaryKey::new("vec-1").encode();
        let id_dict_puts: Vec<_> = frozen
            .ops
            .clone()
            .into_iter()
            .filter(|op| match op {
                RecordOp::Put(record) => record.record.key == id_dict_key,
                _ => false,
            })
            .collect();
        assert!(!id_dict_puts.is_empty());
        let RecordOp::Put(record) = id_dict_puts.last().unwrap() else {
            panic!("should have ID dictionary put op");
        };
        let new_id = record.record.value.clone().get_u64_le();
        assert!(new_id > first_id);
        let new_id_dict = *ctx.dictionary.get("vec-1").unwrap();
        assert_eq!(new_id_dict, new_id);
    }

    #[tokio::test]
    async fn should_assign_vectors_to_postings_on_update() {
        let centroid_id = 5u64;
        let ctx = create_test_context(centroid_id).await;

        ctx.dictionary.insert("vec-1".to_string(), 100);

        let mut delta = VectorDbWriteDelta::init(ctx);

        let write = create_vector_write("vec-1", vec![4.0, 5.0, 6.0]);

        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        let posting_key = PostingListKey::new(centroid_id).encode();
        let has_posting_merge = frozen.ops.iter().any(|op| match op {
            RecordOp::Merge(record) => record.record.key == posting_key,
            _ => false,
        });
        assert!(
            has_posting_merge,
            "should have posting list merge op on update"
        );
    }

    #[tokio::test]
    async fn should_delete_old_vector_data_on_update() {
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);
        let write = create_vector_write("vec-1", vec![4.0, 5.0, 6.0]);
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let old_internal_id = *delta.ctx.dictionary.get("vec-1").unwrap();

        let write = create_vector_write("vec-1", vec![4.0, 5.0, 6.0]);
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        let old_vector_key = VectorDataKey::new(old_internal_id).encode();
        let has_vector_delete = frozen.ops.iter().any(|op| match op {
            RecordOp::Delete(key) => *key == old_vector_key,
            _ => false,
        });
        assert!(has_vector_delete, "should have vector data delete op");
    }

    #[tokio::test]
    async fn should_handle_multiple_vectors_in_single_apply() {
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            create_vector_write("vec-1", vec![1.0, 0.0, 0.0]),
            create_vector_write("vec-2", vec![0.0, 1.0, 0.0]),
            create_vector_write("vec-3", vec![0.0, 0.0, 1.0]),
        ];

        delta.apply(VectorDbWrite::Write(writes)).unwrap();
        let (frozen, _view, ctx) = delta.freeze();

        assert_eq!(ctx.dictionary.len(), 3);
        assert!(ctx.dictionary.contains_key("vec-1"));
        assert!(ctx.dictionary.contains_key("vec-2"));
        assert!(ctx.dictionary.contains_key("vec-3"));

        let id_dict_puts = frozen
            .ops
            .iter()
            .filter(|op| is_put_with_key_prefix(op, &IdDictionaryKey::new("").encode()[..2]))
            .count();
        assert_eq!(id_dict_puts, 3, "should have 3 ID dictionary put ops");

        let vector_data_puts = frozen
            .ops
            .iter()
            .filter(|op| is_put_with_key_prefix(op, &VectorDataKey::new(0).encode()[..2]))
            .count();
        assert_eq!(vector_data_puts, 3, "should have 3 vector data put ops");
    }

    #[tokio::test]
    async fn should_allocate_sequential_internal_ids() {
        let ctx = create_test_context(1).await;
        let dictionary = Arc::clone(&ctx.dictionary);
        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            create_vector_write("vec-1", vec![1.0, 0.0, 0.0]),
            create_vector_write("vec-2", vec![0.0, 1.0, 0.0]),
            create_vector_write("vec-3", vec![0.0, 0.0, 1.0]),
        ];

        delta.apply(VectorDbWrite::Write(writes)).unwrap();

        let id1 = *dictionary.get("vec-1").unwrap();
        let id2 = *dictionary.get("vec-2").unwrap();
        let id3 = *dictionary.get("vec-3").unwrap();

        assert_eq!(id1, 0);
        assert_eq!(id2, 1);
        assert_eq!(id3, 2);
    }

    #[tokio::test]
    async fn should_group_postings_by_centroid() {
        struct MultiCentroidGraph;

        impl CentroidGraph for MultiCentroidGraph {
            fn search(&self, query: &[f32], _k: usize) -> Vec<u64> {
                if query[0] > query[1] && query[0] > query[2] {
                    vec![1]
                } else if query[1] > query[2] {
                    vec![2]
                } else {
                    vec![3]
                }
            }

            fn add_centroid(&self, _entry: &CentroidEntry) -> crate::error::Result<()> {
                Ok(())
            }

            fn remove_centroid(&self, _centroid_id: u64) -> crate::error::Result<()> {
                Ok(())
            }

            fn get_centroid_vector(&self, _centroid_id: u64) -> Option<Vec<f32>> {
                None
            }

            fn len(&self) -> usize {
                3
            }
        }

        let storage: Arc<dyn common::Storage> = Arc::new(InMemoryStorage::new());
        let key = Bytes::from_static(&[0x01, 0x02]);
        let id_allocator = SequenceAllocator::load(storage.as_ref(), key)
            .await
            .unwrap();

        let centroid_graph: Arc<dyn CentroidGraph> = Arc::new(MultiCentroidGraph);
        let rebalancer = IndexRebalancer::new(
            IndexRebalancerOpts {
                dimensions: 3,
                distance_metric: DistanceMetric::L2,
                split_search_neighbourhood: 4,
                split_threshold_vectors: 10_000,
                merge_threshold_vectors: 0,
                max_rebalance_tasks: 0,
            },
            centroid_graph.clone(),
            HashMap::new(),
            Arc::new(std::sync::OnceLock::new()),
        );

        let ctx = VectorDbDeltaContext {
            opts: VectorDbDeltaOpts {
                dimensions: 3,
                chunk_target: 4096,
                max_pending_and_running_rebalance_tasks: usize::MAX,
                split_threshold_vectors: usize::MAX,
                rebalance_backpressure_resume_threshold: 0,
                indexed_fields: HashSet::new(),
            },
            dictionary: Arc::new(DashMap::new()),
            centroid_graph,
            id_allocator,
            current_chunk_id: 0,
            current_chunk_count: 0,
            rebalancer,
            pause_handle: Arc::new(OnceLock::new()),
        };

        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            create_vector_write("vec-1", vec![1.0, 0.0, 0.0]),
            create_vector_write("vec-2", vec![0.0, 1.0, 0.0]),
            create_vector_write("vec-3", vec![0.0, 0.0, 1.0]),
            create_vector_write("vec-4", vec![0.9, 0.1, 0.0]),
        ];

        delta.apply(VectorDbWrite::Write(writes)).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        let posting_merges: Vec<_> = frozen
            .ops
            .iter()
            .filter(|op| is_merge_with_key_prefix(op, &PostingListKey::new(0).encode()[..2]))
            .collect();

        assert_eq!(
            posting_merges.len(),
            3,
            "should have 3 posting list merge ops"
        );
    }

    #[tokio::test]
    async fn should_emit_centroid_stats_on_freeze() {
        let centroid_id = 42u64;
        let ctx = create_test_context(centroid_id).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            create_vector_write("vec-1", vec![1.0, 2.0, 3.0]),
            create_vector_write("vec-2", vec![4.0, 5.0, 6.0]),
        ];

        delta.apply(VectorDbWrite::Write(writes)).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        let stats_key = CentroidStatsKey::new(centroid_id).encode();
        let stats_merge = frozen.ops.iter().find(|op| match op {
            RecordOp::Merge(record) => record.record.key == stats_key,
            _ => false,
        });
        assert!(
            stats_merge.is_some(),
            "should have centroid stats merge op for centroid {}",
            centroid_id
        );

        if let Some(RecordOp::Merge(record)) = stats_merge {
            let value = crate::serde::centroid_stats::CentroidStatsValue::decode_from_bytes(
                &record.record.value,
            )
            .unwrap();
            assert_eq!(value.num_vectors, 2, "should have delta of 2 for 2 vectors");
        }
    }

    #[tokio::test]
    async fn should_estimate_size_correctly() {
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        assert_eq!(delta.estimate_size(), 0);

        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();

        let size = delta.estimate_size();
        assert!(size > 0, "size should be non-zero after adding vector");
    }

    #[tokio::test]
    async fn should_expose_posting_updates_via_reader() {
        let centroid_id = 7u64;
        let ctx = create_test_context(centroid_id).await;
        let mut delta = VectorDbWriteDelta::init(ctx);
        let reader = delta.reader();

        let writes = vec![
            create_vector_write("vec-2", vec![1.0, 0.0, 0.0]),
            create_vector_write("vec-1", vec![0.0, 1.0, 0.0]),
        ];
        delta.apply(VectorDbWrite::Write(writes)).unwrap();

        let view = reader.read().expect("lock poisoned");
        let postings = view
            .posting_updates
            .get(&centroid_id)
            .expect("should have postings for centroid");
        assert_eq!(
            postings.len(),
            2,
            "should have posting updates for both vectors"
        );
    }

    #[tokio::test]
    async fn should_emit_metadata_index_merge_ops_for_indexed_fields() {
        let storage: Arc<dyn common::Storage> = Arc::new(InMemoryStorage::new());
        let key = Bytes::from_static(&[0x01, 0x02]);
        let id_allocator = SequenceAllocator::load(storage.as_ref(), key)
            .await
            .unwrap();
        let centroid_graph: Arc<dyn CentroidGraph> =
            Arc::new(MockCentroidGraph::new(vec![(1, vec![0.0; 3])]));
        let rebalancer = IndexRebalancer::new(
            IndexRebalancerOpts {
                dimensions: 3,
                distance_metric: DistanceMetric::L2,
                split_search_neighbourhood: 4,
                split_threshold_vectors: 10_000,
                merge_threshold_vectors: 0,
                max_rebalance_tasks: 0,
            },
            centroid_graph.clone(),
            HashMap::new(),
            Arc::new(std::sync::OnceLock::new()),
        );

        let ctx = VectorDbDeltaContext {
            opts: VectorDbDeltaOpts {
                dimensions: 3,
                chunk_target: 4096,
                max_pending_and_running_rebalance_tasks: usize::MAX,
                split_threshold_vectors: usize::MAX,
                rebalance_backpressure_resume_threshold: 0,
                indexed_fields: HashSet::from(["category".to_string()]),
            },
            dictionary: Arc::new(DashMap::new()),
            centroid_graph,
            id_allocator,
            current_chunk_id: 0,
            current_chunk_count: 0,
            rebalancer,
            pause_handle: Arc::new(OnceLock::new()),
        };
        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            VectorWrite {
                external_id: "vec-1".to_string(),
                values: vec![1.0, 0.0, 0.0],
                attributes: vec![
                    (
                        "vector".to_string(),
                        AttributeValue::Vector(vec![1.0, 0.0, 0.0]),
                    ),
                    (
                        "category".to_string(),
                        AttributeValue::String("shoes".to_string()),
                    ),
                ],
            },
            VectorWrite {
                external_id: "vec-2".to_string(),
                values: vec![0.0, 1.0, 0.0],
                attributes: vec![
                    (
                        "vector".to_string(),
                        AttributeValue::Vector(vec![0.0, 1.0, 0.0]),
                    ),
                    (
                        "category".to_string(),
                        AttributeValue::String("shoes".to_string()),
                    ),
                ],
            },
            VectorWrite {
                external_id: "vec-3".to_string(),
                values: vec![0.0, 0.0, 1.0],
                attributes: vec![
                    (
                        "vector".to_string(),
                        AttributeValue::Vector(vec![0.0, 0.0, 1.0]),
                    ),
                    (
                        "category".to_string(),
                        AttributeValue::String("boots".to_string()),
                    ),
                ],
            },
        ];

        delta.apply(VectorDbWrite::Write(writes)).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        let metadata_prefix = crate::serde::RecordType::MetadataIndex.prefix();
        let mut prefix_buf = bytes::BytesMut::with_capacity(2);
        metadata_prefix.write_to(&mut prefix_buf);
        let prefix = prefix_buf.freeze();

        let metadata_merges: Vec<_> = frozen
            .ops
            .iter()
            .filter(|op| is_merge_with_key_prefix(op, &prefix))
            .collect();

        assert_eq!(
            metadata_merges.len(),
            2,
            "should have 2 metadata index merge ops (one per unique field/value pair)"
        );

        let mut bitmap_sizes: Vec<u64> = metadata_merges
            .iter()
            .map(|op| {
                let RecordOp::Merge(record) = op else {
                    panic!("expected merge op");
                };
                let bitmap = crate::serde::metadata_index::MetadataIndexValue::decode_from_bytes(
                    &record.record.value,
                )
                .unwrap();
                bitmap.len()
            })
            .collect();
        bitmap_sizes.sort();
        assert_eq!(
            bitmap_sizes,
            vec![1, 2],
            "should have bitmaps with 1 and 2 entries"
        );
    }

    #[tokio::test]
    async fn should_not_emit_metadata_index_ops_for_non_indexed_fields() {
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);

        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        let metadata_prefix = crate::serde::RecordType::MetadataIndex.prefix();
        let mut prefix_buf = bytes::BytesMut::with_capacity(2);
        metadata_prefix.write_to(&mut prefix_buf);
        let prefix = prefix_buf.freeze();

        let metadata_merges = frozen
            .ops
            .iter()
            .filter(|op| is_merge_with_key_prefix(op, &prefix))
            .count();
        assert_eq!(
            metadata_merges, 0,
            "should have no metadata index ops when no fields are indexed"
        );
    }

    #[tokio::test]
    async fn should_update_centroid_counts_per_centroid() {
        struct MultiCentroidGraph;

        impl CentroidGraph for MultiCentroidGraph {
            fn search(&self, query: &[f32], _k: usize) -> Vec<u64> {
                if query[0] > query[1] && query[0] > query[2] {
                    vec![1]
                } else if query[1] > query[2] {
                    vec![2]
                } else {
                    vec![3]
                }
            }

            fn add_centroid(&self, _entry: &CentroidEntry) -> crate::error::Result<()> {
                Ok(())
            }

            fn remove_centroid(&self, _centroid_id: u64) -> crate::error::Result<()> {
                Ok(())
            }

            fn get_centroid_vector(&self, _centroid_id: u64) -> Option<Vec<f32>> {
                None
            }

            fn len(&self) -> usize {
                3
            }
        }

        let storage: Arc<dyn common::Storage> = Arc::new(InMemoryStorage::new());
        let key = Bytes::from_static(&[0x01, 0x02]);
        let id_allocator = SequenceAllocator::load(storage.as_ref(), key)
            .await
            .unwrap();

        let centroid_graph: Arc<dyn CentroidGraph> = Arc::new(MultiCentroidGraph);
        let rebalancer = IndexRebalancer::new(
            IndexRebalancerOpts {
                dimensions: 3,
                distance_metric: DistanceMetric::L2,
                split_search_neighbourhood: 4,
                split_threshold_vectors: 10_000,
                merge_threshold_vectors: 0,
                max_rebalance_tasks: 0,
            },
            centroid_graph.clone(),
            HashMap::new(),
            Arc::new(std::sync::OnceLock::new()),
        );

        let ctx = VectorDbDeltaContext {
            opts: VectorDbDeltaOpts {
                dimensions: 3,
                chunk_target: 4096,
                max_pending_and_running_rebalance_tasks: usize::MAX,
                split_threshold_vectors: usize::MAX,
                rebalance_backpressure_resume_threshold: 0,
                indexed_fields: HashSet::new(),
            },
            dictionary: Arc::new(DashMap::new()),
            centroid_graph,
            id_allocator,
            current_chunk_id: 0,
            current_chunk_count: 0,
            rebalancer,
            pause_handle: Arc::new(OnceLock::new()),
        };

        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            create_vector_write("vec-1", vec![1.0, 0.0, 0.0]),
            create_vector_write("vec-2", vec![0.0, 1.0, 0.0]),
            create_vector_write("vec-3", vec![0.0, 0.0, 1.0]),
            create_vector_write("vec-4", vec![0.9, 0.1, 0.0]),
        ];

        delta.apply(VectorDbWrite::Write(writes)).unwrap();
        let (frozen, _view, ctx) = delta.freeze();

        assert_eq!(ctx.rebalancer.centroid_count(1), Some(2));
        assert_eq!(ctx.rebalancer.centroid_count(2), Some(1));
        assert_eq!(ctx.rebalancer.centroid_count(3), Some(1));

        for (centroid_id, expected_count) in [(1u64, 2i32), (2, 1), (3, 1)] {
            let stats_key = CentroidStatsKey::new(centroid_id).encode();
            let stats_merge = frozen.ops.iter().find(|op| match op {
                RecordOp::Merge(record) => record.record.key == stats_key,
                _ => false,
            });
            assert!(
                stats_merge.is_some(),
                "should have centroid stats merge op for centroid {}",
                centroid_id
            );
            if let Some(RecordOp::Merge(record)) = stats_merge {
                let value = crate::serde::centroid_stats::CentroidStatsValue::decode_from_bytes(
                    &record.record.value,
                )
                .unwrap();
                assert_eq!(
                    value.num_vectors, expected_count,
                    "centroid {} should have count delta {}",
                    centroid_id, expected_count
                );
            }
        }
    }
}