Skip to main content

vector/
delta.rs

1//! In-memory delta for buffering vector writes before flush.
2//!
3//! This module implements the delta pattern for accumulating vector writes
4//! in memory before they are atomically flushed to storage.
5//!
6//! ## Key Design
7//!
8//! - The `VectorDbDeltaContext` contains shared state: the in-memory ID dictionary
9//!   (DashMap), the centroid graph for assignment, and a sync-safe ID allocator
10//! - The delta handles all write logic in `apply()`: ID allocation, dictionary
11//!   lookup for upsert detection, centroid assignment, and dictionary updates
12//! - The write path just validates and enqueues
13//!
14//! ## WriteCoordinator Integration
15//!
16//! The `VectorDbWriteDelta` implements the `Delta` trait for use with the
17//! WriteCoordinator. The delta receives `VectorWrite` instances and handles
18//! ID allocation, dictionary updates, and centroid assignment.
19
20use std::any::Any;
21use std::collections::HashMap;
22use std::sync::{Arc, OnceLock};
23
24use crate::hnsw::CentroidGraph;
25use crate::lire::commands::RebalanceCommand;
26use crate::lire::rebalancer::IndexRebalancer;
27use crate::model::{AttributeValue, MetadataFieldSpec, VECTOR_FIELD_NAME};
28use crate::serde::FieldValue;
29use crate::serde::posting_list::PostingUpdate;
30use crate::storage::record;
31use common::SequenceAllocator;
32use common::coordinator::{Delta, PauseHandle};
33use common::storage::RecordOp;
34use dashmap::DashMap;
35use roaring::RoaringTreemap;
36use std::collections::HashSet;
37use tracing::debug;
38// ============================================================================
39// WriteCoordinator Integration Types
40// ============================================================================
41
/// Top-level write message consumed by the delta in `Delta::apply`.
pub(crate) enum VectorDbWrite {
    /// A batch of vector upserts, handled by `apply_write`.
    Write(Vec<VectorWrite>),
    /// A rebalance command, handled by `apply_rebalance_cmd`.
    Rebalance(RebalanceCommand),
}
46
/// A vector write ready for the coordinator.
///
/// The write path validates and enqueues this struct.
/// The delta handles ID allocation, dictionary lookup, centroid assignment, and updates.
#[derive(Debug, Clone)]
pub(crate) struct VectorWrite {
    /// User-provided external ID.
    pub(crate) external_id: String,
    /// Vector embedding values.
    pub(crate) values: Vec<f32>,
    /// All attributes including the vector field.
    /// The entry named `VECTOR_FIELD_NAME` is excluded from metadata indexing.
    pub(crate) attributes: Vec<(String, AttributeValue)>,
}
60
/// Configuration options for the delta.
pub(crate) struct VectorDbDeltaOpts {
    /// Vector dimensions for encoding.
    pub(crate) dimensions: usize,
    /// Target number of centroid entries per chunk.
    pub(crate) chunk_target: usize,
    /// Backpressure pause threshold: writes are paused once the rebalancer's
    /// pending + running task count reaches this value.
    pub(crate) max_pending_and_running_rebalance_tasks: usize,
    /// Per-centroid split threshold in vectors; backpressure also engages once
    /// the largest centroid reaches twice this value.
    pub(crate) split_threshold_vectors: usize,
    /// Backpressure resume threshold: writes are unpaused once the rebalance
    /// task count drops below this value.
    pub(crate) rebalance_backpressure_resume_threshold: usize,
    /// Names of indexed metadata fields (for maintaining the inverted index).
    pub(crate) indexed_fields: HashSet<String>,
}
73
74impl VectorDbDeltaOpts {
75    /// Build the set of indexed field names from metadata field specs.
76    pub(crate) fn indexed_fields_from(specs: &[MetadataFieldSpec]) -> HashSet<String> {
77        specs
78            .iter()
79            .filter(|s| s.indexed)
80            .map(|s| s.name.clone())
81            .collect()
82    }
83}
84
/// Image containing shared state for the delta.
///
/// This is passed to `Delta::init()` when creating a fresh delta. The image
/// contains references to shared in-memory structures that persist across
/// delta lifecycles.
pub(crate) struct VectorDbDeltaContext {
    /// Configuration options.
    pub(crate) opts: VectorDbDeltaOpts,
    /// In-memory ID dictionary mapping external_id -> internal_id.
    /// Updated by the delta during apply().
    pub(crate) dictionary: Arc<DashMap<String, u64>>,
    /// In-memory centroid graph for assignment (immutable after initialization).
    pub(crate) centroid_graph: Arc<dyn CentroidGraph>,
    /// Synchronous ID allocator for internal ID generation.
    pub(crate) id_allocator: SequenceAllocator,
    /// The current centroid chunk being appended to.
    pub(crate) current_chunk_id: u32,
    /// Number of centroid entries in the current chunk.
    pub(crate) current_chunk_count: usize,
    /// Tracks per-centroid vector counts and task load; consulted for the
    /// backpressure decision after every apply().
    pub(crate) rebalancer: IndexRebalancer,
    /// Coordinator pause handle, installed lazily into the shared `OnceLock`
    /// (empty until set elsewhere); reading it before initialization panics.
    pub(crate) pause_handle: Arc<OnceLock<PauseHandle>>,
}
107
/// Immutable delta containing all RecordOps ready to be flushed.
///
/// This is the result of `Delta::freeze()` and contains the finalized
/// operations to apply atomically to storage.
#[derive(Clone)]
pub struct VectorDbImmutableDelta {
    /// All RecordOps accumulated and finalized from the delta: the per-write
    /// puts/deletes plus the merge ops built during `freeze()`.
    pub ops: Vec<RecordOp>,
}
117
/// Mutable delta that accumulates writes and builds RecordOps.
///
/// Implements the `Delta` trait for use with WriteCoordinator.
pub(crate) struct VectorDbWriteDelta {
    /// Reference to the shared image.
    pub(crate) ctx: VectorDbDeltaContext,
    /// Accumulated RecordOps (ID dictionary, vector data).
    pub(crate) ops: Vec<RecordOp>,
    /// Shared view of the delta's current state, readable by concurrent readers.
    /// Handed out by `reader()` and snapshotted (cloned) during `freeze()`.
    pub(crate) view: Arc<std::sync::RwLock<VectorDbDeltaView>>,
}
129
130impl VectorDbWriteDelta {
131    /// Assign a vector to its nearest centroid using the HNSW graph.
132    fn assign_to_centroid(&self, vector: &[f32]) -> u64 {
133        self.ctx
134            .centroid_graph
135            .search(vector, 1)
136            .first()
137            .copied()
138            .unwrap_or(1)
139    }
140}
141
142impl Delta for VectorDbWriteDelta {
143    type Context = VectorDbDeltaContext;
144    type Write = VectorDbWrite;
145    type DeltaView = Arc<std::sync::RwLock<VectorDbDeltaView>>;
146    type Frozen = VectorDbImmutableDelta;
147    type FrozenView = Arc<VectorDbDeltaView>;
148    type ApplyResult = Arc<dyn Any + Send + Sync + 'static>;
149
150    fn init(context: VectorDbDeltaContext) -> Self {
151        Self {
152            ctx: context,
153            ops: Vec::new(),
154            view: Arc::new(std::sync::RwLock::new(VectorDbDeltaView::new())),
155        }
156    }
157
158    fn apply(
159        &mut self,
160        write: Self::Write,
161    ) -> Result<Arc<dyn Any + Send + Sync + 'static>, String> {
162        let result = match write {
163            VectorDbWrite::Write(writes) => self.apply_write(writes),
164            VectorDbWrite::Rebalance(cmd) => self.apply_rebalance_cmd(cmd),
165        };
166        self.toggle_rebalance_backpressure();
167        result
168    }
169
170    fn estimate_size(&self) -> usize {
171        let view = self.view.read().expect("lock poisoned");
172        // Rough estimate: 100 bytes per op, 50 bytes per posting update, 8 bytes per deletion
173        self.ops.len() * 100
174            + view
175                .posting_updates
176                .values()
177                .map(|v| v.len())
178                .sum::<usize>()
179                * 50
180            + view.deleted_centroids.len() as usize * 8
181    }
182
183    fn freeze(self) -> (Self::Frozen, Self::FrozenView, Self::Context) {
184        self.ctx.rebalancer.log_summary();
185        let mut ops = self.ops;
186        let view = self.view.read().expect("lock poisoned").clone();
187
188        // Finalize posting list merges and centroid stats deltas
189        for (centroid_id, updates) in &view.posting_updates {
190            let count = updates.len() as i32;
191            if let Ok(op) = record::merge_posting_list(*centroid_id, updates.clone()) {
192                ops.push(op);
193            }
194            ops.push(record::merge_centroid_stats(*centroid_id, count));
195        }
196
197        // Finalize metadata inverted index merges
198        for (encoded_key, vector_ids) in &view.metadata_index_updates {
199            if let Ok(op) = record::merge_metadata_index_bitmap(encoded_key.clone(), vector_ids) {
200                ops.push(op);
201            }
202        }
203
204        // Finalize deleted vectors merge
205        if !view.deleted_centroids.is_empty() {
206            let op = record::merge_deleted_vectors(view.deleted_centroids.clone())
207                .expect("failure to construct deleted vectors row");
208            ops.push(op);
209        }
210
211        (VectorDbImmutableDelta { ops }, Arc::new(view), self.ctx)
212    }
213
214    fn reader(&self) -> Self::DeltaView {
215        self.view.clone()
216    }
217}
218
219impl VectorDbWriteDelta {
220    fn pause_handle(&self) -> PauseHandle {
221        self.ctx.pause_handle.get().unwrap().clone()
222    }
223
224    fn toggle_rebalance_backpressure(&self) {
225        let total_tasks = self.ctx.rebalancer.total_ops_pending_and_running();
226        let max_centroid_limit = self.ctx.opts.split_threshold_vectors.saturating_mul(2) as u64;
227        if total_tasks >= self.ctx.opts.max_pending_and_running_rebalance_tasks
228            || self.ctx.rebalancer.max_centroid_size() >= max_centroid_limit
229        {
230            debug!(
231                "applying rebalance backpressure: {} {}",
232                total_tasks, self.ctx.opts.max_pending_and_running_rebalance_tasks
233            );
234            self.pause_handle().pause();
235        } else if total_tasks < self.ctx.opts.rebalance_backpressure_resume_threshold {
236            self.pause_handle().unpause();
237        }
238    }
239
240    fn apply_write(
241        &mut self,
242        vector_writes: Vec<VectorWrite>,
243    ) -> Result<Arc<dyn Any + Send + Sync + 'static>, String> {
244        let mut view = self.view.write().expect("lock poisoned");
245
246        for write in vector_writes {
247            // Allocate new internal ID
248            let (new_internal_id, seq_alloc_put) = self.ctx.id_allocator.allocate_one();
249            if let Some(seq_alloc_put) = seq_alloc_put {
250                self.ops.push(RecordOp::Put(seq_alloc_put.into()));
251            }
252
253            // Check dictionary for existing mapping (upsert detection)
254            let old_internal_id = self.ctx.dictionary.get(&write.external_id).map(|r| *r);
255
256            // Assign to centroid using the graph
257            let centroid_id = self.assign_to_centroid(&write.values);
258
259            // Update ID dictionary (in-memory)
260            self.ctx
261                .dictionary
262                .insert(write.external_id.clone(), new_internal_id);
263
264            // Build storage ops for ID dictionary
265            self.ops.push(record::put_id_dictionary(
266                &write.external_id,
267                new_internal_id,
268            ));
269
270            // Handle old vector deletion (if upsert)
271            if let Some(old_id) = old_internal_id {
272                self.ops.push(record::delete_vector_data(old_id));
273            }
274
275            // Write new vector data
276            self.ops.push(record::put_vector_data(
277                new_internal_id,
278                &write.external_id,
279                &write.attributes,
280            ));
281
282            // Accumulate metadata inverted index postings for indexed attributes
283            for (attr_name, attr_value) in &write.attributes {
284                if attr_name == VECTOR_FIELD_NAME {
285                    continue;
286                }
287                if !self.ctx.opts.indexed_fields.contains(attr_name) {
288                    continue;
289                }
290                let field_value: FieldValue = attr_value.clone().into();
291                view.add_to_metadata_index(attr_name.clone(), field_value, new_internal_id);
292            }
293
294            // Accumulate posting list update
295            view.add_to_posting(centroid_id, new_internal_id, write.values);
296            self.ctx.rebalancer.update_counts(&[(centroid_id, 1)])
297        }
298
299        drop(view);
300
301        Ok(Arc::new(()))
302    }
303}
304
/// Accumulated, reader-visible state of an in-flight delta; finalized into
/// merge ops during `freeze()`.
#[derive(Clone)]
pub(crate) struct VectorDbDeltaView {
    /// Pending posting-list updates (appends and deletes), keyed by centroid ID.
    pub(crate) posting_updates: HashMap<u64, Vec<PostingUpdate>>,
    /// IDs merged into the deleted-vectors row during freeze().
    /// NOTE(review): despite the field name, this is passed to
    /// `merge_deleted_vectors` — presumably vector IDs, not centroid IDs; confirm.
    pub(crate) deleted_centroids: RoaringTreemap,
    /// Accumulated metadata index postings: encoded key → set of vector IDs.
    /// Built into merge ops during freeze().
    pub(crate) metadata_index_updates: HashMap<bytes::Bytes, RoaringTreemap>,
}
313
314impl VectorDbDeltaView {
315    fn new() -> Self {
316        Self {
317            posting_updates: HashMap::new(),
318            deleted_centroids: RoaringTreemap::new(),
319            metadata_index_updates: HashMap::new(),
320        }
321    }
322
323    pub(crate) fn add_to_posting(&mut self, centroid_id: u64, vector_id: u64, vector: Vec<f32>) {
324        self.posting_updates
325            .entry(centroid_id)
326            .or_default()
327            .push(PostingUpdate::append(vector_id, vector));
328    }
329
330    pub(crate) fn add_to_metadata_index(
331        &mut self,
332        field_name: String,
333        field_value: FieldValue,
334        vector_id: u64,
335    ) {
336        let key = crate::serde::key::MetadataIndexKey::new(field_name, field_value).encode();
337        #[allow(clippy::unwrap_or_default)]
338        self.metadata_index_updates
339            .entry(key)
340            .or_insert_with(RoaringTreemap::new)
341            .insert(vector_id);
342    }
343
344    pub(crate) fn delete_from_posting(&mut self, centroid_id: u64, vector_id: u64) {
345        self.posting_updates
346            .entry(centroid_id)
347            .or_default()
348            .push(PostingUpdate::delete(vector_id));
349    }
350}
351
352#[cfg(test)]
353mod tests {
354    use super::*;
355    use crate::hnsw::CentroidGraph;
356    use crate::lire::rebalancer::{IndexRebalancer, IndexRebalancerOpts};
357    use crate::model::AttributeValue;
358    use crate::serde::centroid_chunk::CentroidEntry;
359    use crate::serde::collection_meta::DistanceMetric;
360    use crate::serde::key::{CentroidStatsKey, IdDictionaryKey, PostingListKey, VectorDataKey};
361    use bytes::{Buf, Bytes};
362    use common::SequenceAllocator;
363    use common::coordinator::Delta;
364    use common::storage::RecordOp;
365    use common::storage::in_memory::InMemoryStorage;
366
    /// Mock CentroidGraph with configurable centroids. Search returns all
    /// centroid IDs in insertion order (first = assignment target).
    struct MockCentroidGraph {
        // (centroid_id, centroid_vector) pairs, kept in insertion order.
        centroids: Vec<(u64, Vec<f32>)>,
    }
372
    impl MockCentroidGraph {
        /// Build a mock over the given (id, vector) pairs.
        fn new(centroids: Vec<(u64, Vec<f32>)>) -> Self {
            Self { centroids }
        }
    }
378
379    impl CentroidGraph for MockCentroidGraph {
380        fn search(&self, _query: &[f32], _k: usize) -> Vec<u64> {
381            self.centroids.iter().map(|(id, _)| *id).collect()
382        }
383
384        fn add_centroid(&self, _entry: &CentroidEntry) -> crate::error::Result<()> {
385            Ok(())
386        }
387
388        fn remove_centroid(&self, _centroid_id: u64) -> crate::error::Result<()> {
389            Ok(())
390        }
391
392        fn get_centroid_vector(&self, centroid_id: u64) -> Option<Vec<f32>> {
393            self.centroids
394                .iter()
395                .find(|(id, _)| *id == centroid_id)
396                .map(|(_, v)| v.clone())
397        }
398
399        fn len(&self) -> usize {
400            self.centroids.len()
401        }
402    }
403
    /// Create a test context with the given centroid ID for assignment.
    async fn create_test_context(centroid_id: u64) -> VectorDbDeltaContext {
        let storage: Arc<dyn common::Storage> = Arc::new(InMemoryStorage::new());
        let key = Bytes::from_static(&[0x01, 0x02]);
        let id_allocator = SequenceAllocator::load(storage.as_ref(), key)
            .await
            .unwrap();
        // Single-centroid graph: every write is assigned to `centroid_id`.
        let centroid_graph: Arc<dyn CentroidGraph> =
            Arc::new(MockCentroidGraph::new(vec![(centroid_id, vec![0.0; 3])]));
        let rebalancer = IndexRebalancer::new(
            IndexRebalancerOpts {
                dimensions: 3,
                distance_metric: DistanceMetric::L2,
                split_search_neighbourhood: 4,
                split_threshold_vectors: 10_000,
                merge_threshold_vectors: 0,
                max_rebalance_tasks: 0,
            },
            centroid_graph.clone(),
            HashMap::new(),
            Arc::new(std::sync::OnceLock::new()),
        );

        VectorDbDeltaContext {
            opts: VectorDbDeltaOpts {
                dimensions: 3,
                chunk_target: 4096,
                // Thresholds chosen so neither the pause nor the unpause
                // branch of toggle_rebalance_backpressure can fire; the empty
                // pause_handle OnceLock below is therefore never dereferenced.
                max_pending_and_running_rebalance_tasks: usize::MAX,
                split_threshold_vectors: usize::MAX,
                rebalance_backpressure_resume_threshold: 0,
                indexed_fields: HashSet::new(),
            },
            dictionary: Arc::new(DashMap::new()),
            centroid_graph,
            id_allocator,
            current_chunk_id: 0,
            current_chunk_count: 0,
            rebalancer,
            pause_handle: Arc::new(OnceLock::new()),
        }
    }
445
446    /// Create a simple vector write for testing.
447    fn create_vector_write(external_id: &str, values: Vec<f32>) -> VectorWrite {
448        VectorWrite {
449            external_id: external_id.to_string(),
450            values: values.clone(),
451            attributes: vec![
452                ("vector".to_string(), AttributeValue::Vector(values)),
453                (
454                    "category".to_string(),
455                    AttributeValue::String("test".to_string()),
456                ),
457            ],
458        }
459    }
460
461    /// Helper to check if an op is a Put for a specific key prefix.
462    fn is_put_with_key_prefix(op: &RecordOp, prefix: &[u8]) -> bool {
463        match op {
464            RecordOp::Put(record) => record.record.key.starts_with(prefix),
465            _ => false,
466        }
467    }
468
469    /// Helper to check if an op is a Merge for a specific key prefix.
470    fn is_merge_with_key_prefix(op: &RecordOp, prefix: &[u8]) -> bool {
471        match op {
472            RecordOp::Merge(record) => record.record.key.starts_with(prefix),
473            _ => false,
474        }
475    }
476
    /// A single insert produces an ID-dictionary put and a vector-data put.
    #[tokio::test]
    async fn should_add_vectors() {
        // given
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);

        // when
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        // then - should have ops for ID dictionary put and vector data put
        let id_dict_key = IdDictionaryKey::new("vec-1").encode();
        let vector_data_key_prefix = VectorDataKey::new(0).encode();

        // Find ID dictionary put
        let has_id_dict_put = frozen.ops.iter().any(|op| match op {
            RecordOp::Put(record) => record.record.key == id_dict_key,
            _ => false,
        });
        assert!(has_id_dict_put, "should have ID dictionary put op");

        // Find vector data put (key starts with vector data prefix)
        let has_vector_data_put = frozen
            .ops
            .iter()
            .any(|op| is_put_with_key_prefix(op, &vector_data_key_prefix[..2]));
        assert!(has_vector_data_put, "should have vector data put op");
    }
507
    /// An insert is assigned to the mock's centroid and lands in its
    /// posting-list merge op on freeze.
    #[tokio::test]
    async fn should_assign_vectors_to_postings() {
        // given
        let centroid_id = 42u64;
        let ctx = create_test_context(centroid_id).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);

        // when
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        // then - should have a merge op for the posting list of centroid 42
        let posting_key = PostingListKey::new(centroid_id).encode();
        let has_posting_merge = frozen.ops.iter().any(|op| match op {
            RecordOp::Merge(record) => record.record.key == posting_key,
            _ => false,
        });
        assert!(
            has_posting_merge,
            "should have posting list merge op for centroid {}",
            centroid_id
        );
    }
533
    /// The shared in-memory dictionary is updated immediately on apply(),
    /// before any freeze/flush.
    #[tokio::test]
    async fn should_update_dictionary_on_insert() {
        // given
        let ctx = create_test_context(1).await;
        let dictionary = Arc::clone(&ctx.dictionary);
        let mut delta = VectorDbWriteDelta::init(ctx);

        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);

        // when
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();

        // then - dictionary should be updated in memory
        assert!(dictionary.contains_key("vec-1"));
        let internal_id = *dictionary.get("vec-1").unwrap();
        assert_eq!(internal_id, 0, "first allocated ID should be 0");
    }
551
    /// Upserting the same external ID allocates a fresh (larger) internal ID
    /// and repoints the dictionary entry to it.
    #[tokio::test]
    async fn should_add_vectors_on_update() {
        // given
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);
        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let write = create_vector_write("vec-1", vec![4.0, 5.0, 6.0]);
        let first_id = *delta.ctx.dictionary.get("vec-1").unwrap();

        // when:
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, ctx) = delta.freeze();

        // then - should have put for new ID dictionary entry only
        let id_dict_key = IdDictionaryKey::new("vec-1").encode();
        let id_dict_puts: Vec<_> = frozen
            .ops
            .clone()
            .into_iter()
            .filter(|op| match op {
                RecordOp::Put(record) => record.record.key == id_dict_key,
                _ => false,
            })
            .collect();
        assert!(!id_dict_puts.is_empty());
        let RecordOp::Put(record) = id_dict_puts.last().unwrap() else {
            panic!("should have ID dictionary put op");
        };
        // Dictionary values are little-endian u64 internal IDs.
        let new_id = record.record.value.clone().get_u64_le();
        assert!(new_id > first_id);
        // Dictionary should have new internal ID
        let new_id_dict = *ctx.dictionary.get("vec-1").unwrap();
        assert_eq!(new_id_dict, new_id);
    }
587
    /// An upsert still produces a posting-list merge for the newly assigned
    /// centroid (pre-seeded dictionary simulates the existing vector).
    #[tokio::test]
    async fn should_assign_vectors_to_postings_on_update() {
        // given
        let centroid_id = 5u64;
        let ctx = create_test_context(centroid_id).await;

        // Pre-populate dictionary to simulate existing vector
        ctx.dictionary.insert("vec-1".to_string(), 100);

        let mut delta = VectorDbWriteDelta::init(ctx);

        let write = create_vector_write("vec-1", vec![4.0, 5.0, 6.0]);

        // when
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        // then - should have posting list merge for the new vector
        let posting_key = PostingListKey::new(centroid_id).encode();
        let has_posting_merge = frozen.ops.iter().any(|op| match op {
            RecordOp::Merge(record) => record.record.key == posting_key,
            _ => false,
        });
        assert!(
            has_posting_merge,
            "should have posting list merge op on update"
        );
    }
616
    /// An upsert emits a delete op for the previous internal ID's vector-data row.
    #[tokio::test]
    async fn should_delete_old_vector_data_on_update() {
        // given
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);
        let write = create_vector_write("vec-1", vec![4.0, 5.0, 6.0]);
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let old_internal_id = *delta.ctx.dictionary.get("vec-1").unwrap();

        // when
        let write = create_vector_write("vec-1", vec![4.0, 5.0, 6.0]);
        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        // then - should have delete op for old vector data
        let old_vector_key = VectorDataKey::new(old_internal_id).encode();
        let has_vector_delete = frozen.ops.iter().any(|op| match op {
            RecordOp::Delete(key) => *key == old_vector_key,
            _ => false,
        });
        assert!(has_vector_delete, "should have vector data delete op");
    }
639
    /// A batched apply() of three writes yields three dictionary entries,
    /// three ID-dictionary puts, and three vector-data puts.
    #[tokio::test]
    async fn should_handle_multiple_vectors_in_single_apply() {
        // given
        let ctx = create_test_context(1).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            create_vector_write("vec-1", vec![1.0, 0.0, 0.0]),
            create_vector_write("vec-2", vec![0.0, 1.0, 0.0]),
            create_vector_write("vec-3", vec![0.0, 0.0, 1.0]),
        ];

        // when
        delta.apply(VectorDbWrite::Write(writes)).unwrap();
        let (frozen, _view, ctx) = delta.freeze();

        // then - should have 3 vectors in dictionary
        assert_eq!(ctx.dictionary.len(), 3);
        assert!(ctx.dictionary.contains_key("vec-1"));
        assert!(ctx.dictionary.contains_key("vec-2"));
        assert!(ctx.dictionary.contains_key("vec-3"));

        // Should have ID dictionary puts for each
        let id_dict_puts = frozen
            .ops
            .iter()
            .filter(|op| is_put_with_key_prefix(op, &IdDictionaryKey::new("").encode()[..2]))
            .count();
        assert_eq!(id_dict_puts, 3, "should have 3 ID dictionary put ops");

        // Should have vector data puts for each
        let vector_data_puts = frozen
            .ops
            .iter()
            .filter(|op| is_put_with_key_prefix(op, &VectorDataKey::new(0).encode()[..2]))
            .count();
        assert_eq!(vector_data_puts, 3, "should have 3 vector data put ops");
    }
678
    /// Internal IDs from the allocator are sequential, starting at 0.
    #[tokio::test]
    async fn should_allocate_sequential_internal_ids() {
        // given
        let ctx = create_test_context(1).await;
        let dictionary = Arc::clone(&ctx.dictionary);
        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            create_vector_write("vec-1", vec![1.0, 0.0, 0.0]),
            create_vector_write("vec-2", vec![0.0, 1.0, 0.0]),
            create_vector_write("vec-3", vec![0.0, 0.0, 1.0]),
        ];

        // when
        delta.apply(VectorDbWrite::Write(writes)).unwrap();

        // then - internal IDs should be sequential starting from 0
        let id1 = *dictionary.get("vec-1").unwrap();
        let id2 = *dictionary.get("vec-2").unwrap();
        let id3 = *dictionary.get("vec-3").unwrap();

        assert_eq!(id1, 0);
        assert_eq!(id2, 1);
        assert_eq!(id3, 2);
    }
704
    /// Writes assigned to different centroids are grouped into one
    /// posting-list merge op per centroid (4 writes over 3 centroids → 3 ops).
    #[tokio::test]
    async fn should_group_postings_by_centroid() {
        // given - create a mock that returns different centroids based on query
        struct MultiCentroidGraph;

        impl CentroidGraph for MultiCentroidGraph {
            fn search(&self, query: &[f32], _k: usize) -> Vec<u64> {
                // Return centroid based on which dimension has highest value
                if query[0] > query[1] && query[0] > query[2] {
                    vec![1]
                } else if query[1] > query[2] {
                    vec![2]
                } else {
                    vec![3]
                }
            }

            fn add_centroid(&self, _entry: &CentroidEntry) -> crate::error::Result<()> {
                Ok(())
            }

            fn remove_centroid(&self, _centroid_id: u64) -> crate::error::Result<()> {
                Ok(())
            }

            fn get_centroid_vector(&self, _centroid_id: u64) -> Option<Vec<f32>> {
                None
            }

            fn len(&self) -> usize {
                3
            }
        }

        let storage: Arc<dyn common::Storage> = Arc::new(InMemoryStorage::new());
        let key = Bytes::from_static(&[0x01, 0x02]);
        let id_allocator = SequenceAllocator::load(storage.as_ref(), key)
            .await
            .unwrap();

        let centroid_graph: Arc<dyn CentroidGraph> = Arc::new(MultiCentroidGraph);
        let rebalancer = IndexRebalancer::new(
            IndexRebalancerOpts {
                dimensions: 3,
                distance_metric: DistanceMetric::L2,
                split_search_neighbourhood: 4,
                split_threshold_vectors: 10_000,
                merge_threshold_vectors: 0,
                max_rebalance_tasks: 0,
            },
            centroid_graph.clone(),
            HashMap::new(),
            Arc::new(std::sync::OnceLock::new()),
        );

        // Same shape as create_test_context, but wired to the multi-centroid mock.
        let ctx = VectorDbDeltaContext {
            opts: VectorDbDeltaOpts {
                dimensions: 3,
                chunk_target: 4096,
                max_pending_and_running_rebalance_tasks: usize::MAX,
                split_threshold_vectors: usize::MAX,
                rebalance_backpressure_resume_threshold: 0,
                indexed_fields: HashSet::new(),
            },
            dictionary: Arc::new(DashMap::new()),
            centroid_graph,
            id_allocator,
            current_chunk_id: 0,
            current_chunk_count: 0,
            rebalancer,
            pause_handle: Arc::new(OnceLock::new()),
        };

        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            create_vector_write("vec-1", vec![1.0, 0.0, 0.0]), // -> centroid 1
            create_vector_write("vec-2", vec![0.0, 1.0, 0.0]), // -> centroid 2
            create_vector_write("vec-3", vec![0.0, 0.0, 1.0]), // -> centroid 3
            create_vector_write("vec-4", vec![0.9, 0.1, 0.0]), // -> centroid 1
        ];

        // when
        delta.apply(VectorDbWrite::Write(writes)).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        // then - should have posting list merges for centroids 1, 2, and 3
        let posting_merges: Vec<_> = frozen
            .ops
            .iter()
            .filter(|op| is_merge_with_key_prefix(op, &PostingListKey::new(0).encode()[..2]))
            .collect();

        assert_eq!(
            posting_merges.len(),
            3,
            "should have 3 posting list merge ops"
        );
    }
804
    /// freeze() emits a centroid-stats merge whose delta equals the number of
    /// vectors written to that centroid.
    #[tokio::test]
    async fn should_emit_centroid_stats_on_freeze() {
        // given
        let centroid_id = 42u64;
        let ctx = create_test_context(centroid_id).await;
        let mut delta = VectorDbWriteDelta::init(ctx);

        let writes = vec![
            create_vector_write("vec-1", vec![1.0, 2.0, 3.0]),
            create_vector_write("vec-2", vec![4.0, 5.0, 6.0]),
        ];

        // when
        delta.apply(VectorDbWrite::Write(writes)).unwrap();
        let (frozen, _view, _ctx) = delta.freeze();

        // then - should have a centroid stats merge op with delta = 2
        let stats_key = CentroidStatsKey::new(centroid_id).encode();
        let stats_merge = frozen.ops.iter().find(|op| match op {
            RecordOp::Merge(record) => record.record.key == stats_key,
            _ => false,
        });
        assert!(
            stats_merge.is_some(),
            "should have centroid stats merge op for centroid {}",
            centroid_id
        );

        // Verify the delta value is 2
        if let Some(RecordOp::Merge(record)) = stats_merge {
            let value = crate::serde::centroid_stats::CentroidStatsValue::decode_from_bytes(
                &record.record.value,
            )
            .unwrap();
            assert_eq!(value.num_vectors, 2, "should have delta of 2 for 2 vectors");
        }
    }
842
843    #[tokio::test]
844    async fn should_estimate_size_correctly() {
845        // given
846        let ctx = create_test_context(1).await;
847        let mut delta = VectorDbWriteDelta::init(ctx);
848
849        // Initial size should be 0
850        assert_eq!(delta.estimate_size(), 0);
851
852        // when - add a vector
853        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);
854        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
855
856        // then - size should be non-zero
857        let size = delta.estimate_size();
858        assert!(size > 0, "size should be non-zero after adding vector");
859    }
860
861    #[tokio::test]
862    async fn should_expose_posting_updates_via_reader() {
863        // given
864        let centroid_id = 7u64;
865        let ctx = create_test_context(centroid_id).await;
866        let mut delta = VectorDbWriteDelta::init(ctx);
867        let reader = delta.reader();
868
869        // when - insert a new vector and upsert an existing one
870        let writes = vec![
871            create_vector_write("vec-2", vec![1.0, 0.0, 0.0]),
872            create_vector_write("vec-1", vec![0.0, 1.0, 0.0]),
873        ];
874        delta.apply(VectorDbWrite::Write(writes)).unwrap();
875
876        // then - reader should see posting updates for both vectors
877        let view = reader.read().expect("lock poisoned");
878        let postings = view
879            .posting_updates
880            .get(&centroid_id)
881            .expect("should have postings for centroid");
882        assert_eq!(
883            postings.len(),
884            2,
885            "should have posting updates for both vectors"
886        );
887    }
888
889    #[tokio::test]
890    async fn should_emit_metadata_index_merge_ops_for_indexed_fields() {
891        // given - context with "category" as an indexed field
892        let storage: Arc<dyn common::Storage> = Arc::new(InMemoryStorage::new());
893        let key = Bytes::from_static(&[0x01, 0x02]);
894        let id_allocator = SequenceAllocator::load(storage.as_ref(), key)
895            .await
896            .unwrap();
897        let centroid_graph: Arc<dyn CentroidGraph> =
898            Arc::new(MockCentroidGraph::new(vec![(1, vec![0.0; 3])]));
899        let rebalancer = IndexRebalancer::new(
900            IndexRebalancerOpts {
901                dimensions: 3,
902                distance_metric: DistanceMetric::L2,
903                split_search_neighbourhood: 4,
904                split_threshold_vectors: 10_000,
905                merge_threshold_vectors: 0,
906                max_rebalance_tasks: 0,
907            },
908            centroid_graph.clone(),
909            HashMap::new(),
910            Arc::new(std::sync::OnceLock::new()),
911        );
912
913        let ctx = VectorDbDeltaContext {
914            opts: VectorDbDeltaOpts {
915                dimensions: 3,
916                chunk_target: 4096,
917                max_pending_and_running_rebalance_tasks: usize::MAX,
918                split_threshold_vectors: usize::MAX,
919                rebalance_backpressure_resume_threshold: 0,
920                indexed_fields: HashSet::from(["category".to_string()]),
921            },
922            dictionary: Arc::new(DashMap::new()),
923            centroid_graph,
924            id_allocator,
925            current_chunk_id: 0,
926            current_chunk_count: 0,
927            rebalancer,
928            pause_handle: Arc::new(OnceLock::new()),
929        };
930        let mut delta = VectorDbWriteDelta::init(ctx);
931
932        let writes = vec![
933            VectorWrite {
934                external_id: "vec-1".to_string(),
935                values: vec![1.0, 0.0, 0.0],
936                attributes: vec![
937                    (
938                        "vector".to_string(),
939                        AttributeValue::Vector(vec![1.0, 0.0, 0.0]),
940                    ),
941                    (
942                        "category".to_string(),
943                        AttributeValue::String("shoes".to_string()),
944                    ),
945                ],
946            },
947            VectorWrite {
948                external_id: "vec-2".to_string(),
949                values: vec![0.0, 1.0, 0.0],
950                attributes: vec![
951                    (
952                        "vector".to_string(),
953                        AttributeValue::Vector(vec![0.0, 1.0, 0.0]),
954                    ),
955                    (
956                        "category".to_string(),
957                        AttributeValue::String("shoes".to_string()),
958                    ),
959                ],
960            },
961            VectorWrite {
962                external_id: "vec-3".to_string(),
963                values: vec![0.0, 0.0, 1.0],
964                attributes: vec![
965                    (
966                        "vector".to_string(),
967                        AttributeValue::Vector(vec![0.0, 0.0, 1.0]),
968                    ),
969                    (
970                        "category".to_string(),
971                        AttributeValue::String("boots".to_string()),
972                    ),
973                ],
974            },
975        ];
976
977        // when
978        delta.apply(VectorDbWrite::Write(writes)).unwrap();
979        let (frozen, _view, _ctx) = delta.freeze();
980
981        // then - should have metadata index merge ops
982        let metadata_prefix = crate::serde::RecordType::MetadataIndex.prefix();
983        let mut prefix_buf = bytes::BytesMut::with_capacity(2);
984        metadata_prefix.write_to(&mut prefix_buf);
985        let prefix = prefix_buf.freeze();
986
987        let metadata_merges: Vec<_> = frozen
988            .ops
989            .iter()
990            .filter(|op| is_merge_with_key_prefix(op, &prefix))
991            .collect();
992
993        // Should have exactly 2 merge ops: one for (category, "shoes") and one for (category, "boots")
994        assert_eq!(
995            metadata_merges.len(),
996            2,
997            "should have 2 metadata index merge ops (one per unique field/value pair)"
998        );
999
1000        // Decode the bitmaps and verify: "shoes" should have 2 vector IDs, "boots" should have 1
1001        let mut bitmap_sizes: Vec<u64> = metadata_merges
1002            .iter()
1003            .map(|op| {
1004                let RecordOp::Merge(record) = op else {
1005                    panic!("expected merge op");
1006                };
1007                let bitmap = crate::serde::metadata_index::MetadataIndexValue::decode_from_bytes(
1008                    &record.record.value,
1009                )
1010                .unwrap();
1011                bitmap.len()
1012            })
1013            .collect();
1014        bitmap_sizes.sort();
1015        assert_eq!(
1016            bitmap_sizes,
1017            vec![1, 2],
1018            "should have bitmaps with 1 and 2 entries"
1019        );
1020    }
1021
1022    #[tokio::test]
1023    async fn should_not_emit_metadata_index_ops_for_non_indexed_fields() {
1024        // given - context with NO indexed fields
1025        let ctx = create_test_context(1).await;
1026        let mut delta = VectorDbWriteDelta::init(ctx);
1027
1028        let write = create_vector_write("vec-1", vec![1.0, 2.0, 3.0]);
1029
1030        // when
1031        delta.apply(VectorDbWrite::Write(vec![write])).unwrap();
1032        let (frozen, _view, _ctx) = delta.freeze();
1033
1034        // then - should have NO metadata index merge ops
1035        let metadata_prefix = crate::serde::RecordType::MetadataIndex.prefix();
1036        let mut prefix_buf = bytes::BytesMut::with_capacity(2);
1037        metadata_prefix.write_to(&mut prefix_buf);
1038        let prefix = prefix_buf.freeze();
1039
1040        let metadata_merges = frozen
1041            .ops
1042            .iter()
1043            .filter(|op| is_merge_with_key_prefix(op, &prefix))
1044            .count();
1045        assert_eq!(
1046            metadata_merges, 0,
1047            "should have no metadata index ops when no fields are indexed"
1048        );
1049    }
1050
1051    #[tokio::test]
1052    async fn should_update_centroid_counts_per_centroid() {
1053        // given - create a mock that routes vectors to different centroids
1054        struct MultiCentroidGraph;
1055
1056        impl CentroidGraph for MultiCentroidGraph {
1057            fn search(&self, query: &[f32], _k: usize) -> Vec<u64> {
1058                if query[0] > query[1] && query[0] > query[2] {
1059                    vec![1]
1060                } else if query[1] > query[2] {
1061                    vec![2]
1062                } else {
1063                    vec![3]
1064                }
1065            }
1066
1067            fn add_centroid(&self, _entry: &CentroidEntry) -> crate::error::Result<()> {
1068                Ok(())
1069            }
1070
1071            fn remove_centroid(&self, _centroid_id: u64) -> crate::error::Result<()> {
1072                Ok(())
1073            }
1074
1075            fn get_centroid_vector(&self, _centroid_id: u64) -> Option<Vec<f32>> {
1076                None
1077            }
1078
1079            fn len(&self) -> usize {
1080                3
1081            }
1082        }
1083
1084        let storage: Arc<dyn common::Storage> = Arc::new(InMemoryStorage::new());
1085        let key = Bytes::from_static(&[0x01, 0x02]);
1086        let id_allocator = SequenceAllocator::load(storage.as_ref(), key)
1087            .await
1088            .unwrap();
1089
1090        let centroid_graph: Arc<dyn CentroidGraph> = Arc::new(MultiCentroidGraph);
1091        let rebalancer = IndexRebalancer::new(
1092            IndexRebalancerOpts {
1093                dimensions: 3,
1094                distance_metric: DistanceMetric::L2,
1095                split_search_neighbourhood: 4,
1096                split_threshold_vectors: 10_000,
1097                merge_threshold_vectors: 0,
1098                max_rebalance_tasks: 0,
1099            },
1100            centroid_graph.clone(),
1101            HashMap::new(),
1102            Arc::new(std::sync::OnceLock::new()),
1103        );
1104
1105        let ctx = VectorDbDeltaContext {
1106            opts: VectorDbDeltaOpts {
1107                dimensions: 3,
1108                chunk_target: 4096,
1109                max_pending_and_running_rebalance_tasks: usize::MAX,
1110                split_threshold_vectors: usize::MAX,
1111                rebalance_backpressure_resume_threshold: 0,
1112                indexed_fields: HashSet::new(),
1113            },
1114            dictionary: Arc::new(DashMap::new()),
1115            centroid_graph,
1116            id_allocator,
1117            current_chunk_id: 0,
1118            current_chunk_count: 0,
1119            rebalancer,
1120            pause_handle: Arc::new(OnceLock::new()),
1121        };
1122
1123        let mut delta = VectorDbWriteDelta::init(ctx);
1124
1125        let writes = vec![
1126            create_vector_write("vec-1", vec![1.0, 0.0, 0.0]), // -> centroid 1
1127            create_vector_write("vec-2", vec![0.0, 1.0, 0.0]), // -> centroid 2
1128            create_vector_write("vec-3", vec![0.0, 0.0, 1.0]), // -> centroid 3
1129            create_vector_write("vec-4", vec![0.9, 0.1, 0.0]), // -> centroid 1
1130        ];
1131
1132        // when
1133        delta.apply(VectorDbWrite::Write(writes)).unwrap();
1134        let (frozen, _view, ctx) = delta.freeze();
1135
1136        // then - rebalancer should have correct counts per centroid
1137        assert_eq!(ctx.rebalancer.centroid_count(1), Some(2));
1138        assert_eq!(ctx.rebalancer.centroid_count(2), Some(1));
1139        assert_eq!(ctx.rebalancer.centroid_count(3), Some(1));
1140
1141        // and - frozen delta should have centroid stats merge ops with correct deltas
1142        for (centroid_id, expected_count) in [(1u64, 2i32), (2, 1), (3, 1)] {
1143            let stats_key = CentroidStatsKey::new(centroid_id).encode();
1144            let stats_merge = frozen.ops.iter().find(|op| match op {
1145                RecordOp::Merge(record) => record.record.key == stats_key,
1146                _ => false,
1147            });
1148            assert!(
1149                stats_merge.is_some(),
1150                "should have centroid stats merge op for centroid {}",
1151                centroid_id
1152            );
1153            if let Some(RecordOp::Merge(record)) = stats_merge {
1154                let value = crate::serde::centroid_stats::CentroidStatsValue::decode_from_bytes(
1155                    &record.record.value,
1156                )
1157                .unwrap();
1158                assert_eq!(
1159                    value.num_vectors, expected_count,
1160                    "centroid {} should have count delta {}",
1161                    centroid_id, expected_count
1162                );
1163            }
1164        }
1165    }
1166}