//! Vector database implementation with atomic flush semantics.
//!
//! This module provides the main `VectorDb` struct that handles:
//! - Vector ingestion with validation
//! - In-memory delta buffering via WriteCoordinator
//! - Atomic flush with ID allocation
//! - Snapshot management for consistency
//!
//! The implementation uses the WriteCoordinator pattern for the write path:
//! - Validation happens in write()
//! - Delta handles ID allocation, dictionary lookup, centroid assignment, and builds RecordOps
//! - Flusher applies ops atomically to storage

use crate::VectorDbReader;
use crate::delta::{
    VectorDbDeltaContext, VectorDbDeltaOpts, VectorDbWrite, VectorDbWriteDelta, VectorWrite,
};
use crate::error::{Error, Result};
use crate::flusher::VectorDbFlusher;
use crate::hnsw::{CentroidGraph, build_centroid_graph};
use crate::lire::rebalancer::{IndexRebalancer, IndexRebalancerOpts};
use crate::model::{
    AttributeValue, Config, Query, SearchResult, VECTOR_FIELD_NAME, Vector, attributes_to_map,
};
use crate::query_engine::{QueryEngine, QueryEngineOptions};
use crate::serde::centroid_chunk::CentroidEntry;
use crate::serde::key::SeqBlockKey;
use crate::storage::VectorDbStorageReadExt;
use crate::storage::merge_operator::VectorDbMergeOperator;
use async_trait::async_trait;
use common::SequenceAllocator;
use common::coordinator::{Durability, WriteCoordinator, WriteCoordinatorConfig};
use common::storage::{Storage, StorageRead, StorageSnapshot};
use common::{StorageBuilder, StorageSemantics};
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use std::time::Duration;

40pub(crate) const WRITE_CHANNEL: &str = "write";
41pub(crate) const REBALANCE_CHANNEL: &str = "rebalance";
42
43/// Trait for querying the vector db
44#[async_trait]
45pub trait VectorDbRead {
46    /// Search for k-nearest neighbors to a query vector.
47    ///
48    /// This implements the SPANN-style query algorithm:
49    /// 1. Search HNSW for nearest centroids
50    /// 2. Load posting lists for those centroids
51    /// 3. Filter deleted vectors
52    /// 4. Score candidates and return top-k
53    ///
54    /// # Arguments
55    /// * `query` - search query
56    ///
57    /// # Returns
58    /// Vector of SearchResults sorted by similarity (best first)
59    ///
60    /// # Errors
61    /// Returns an error if:
62    /// - Query dimensions don't match collection dimensions
63    /// - Storage read fails
64    async fn search(&self, query: &Query) -> Result<Vec<SearchResult>>;
65
66    async fn search_with_nprobe(&self, query: &Query, nprobe: usize) -> Result<Vec<SearchResult>>;
67
68    /// Retrieve a vector record by its external ID.
69    ///
70    /// This is a point lookup operation that retrieves a single record with all its fields.
71    /// Returns `None` if the record doesn't exist or has been deleted.
72    ///
73    /// # Arguments
74    ///
75    /// * `id` - External ID of the record to retrieve
76    ///
77    /// # Returns
78    ///
79    /// `Some(VectorRecord)` if found, `None` if not found or deleted.
80    async fn get(&self, id: &str) -> Result<Option<Vector>>;
81}
82
83/// Vector database for storing and querying embedding vectors.
84///
85/// `VectorDb` provides a high-level API for ingesting vectors with metadata.
86/// It handles internal details like ID allocation, centroid assignment,
87/// and metadata index maintenance automatically.
88pub struct VectorDb {
89    config: Config,
90    #[allow(dead_code)]
91    storage: Arc<dyn Storage>,
92
93    /// The WriteCoordinator itself (stored to keep it alive).
94    write_coordinator: WriteCoordinator<VectorDbWriteDelta, VectorDbFlusher>,
95
96    /// In-memory HNSW graph for centroid search (immutable after initialization).
97    centroid_graph: Arc<dyn CentroidGraph>,
98}
99
impl VectorDb {
101    /// Open or create a vector database with the given configuration and centroids.
102    ///
103    /// If the database already exists (centroids are already stored), the provided
104    /// centroids are ignored and the stored centroids are used instead.
105    ///
106    /// If the database is new, the provided centroids are written to storage and
107    /// used to build the HNSW index.
108    ///
109    /// # Arguments
110    /// * `config` - Database configuration
111    /// * `centroids` - Initial centroids to use if database is new
112    ///
113    /// # Configuration Compatibility
114    /// If the database already exists, the configuration must be compatible:
115    /// - `dimensions` must match exactly
116    /// - `distance_metric` must match exactly
117    ///
118    /// Other configuration options (like `flush_interval`) can be changed
119    /// on subsequent opens.
120    pub async fn open(config: Config) -> Result<Self> {
121        let sb = StorageBuilder::new(&config.storage)
122            .await
123            .map_err(|e| Error::Storage(format!("Failed to create storage: {e}")))?;
124        Self::open_with_storage(config, sb).await
125    }
126
127    pub async fn open_with_storage(config: Config, builder: StorageBuilder) -> Result<Self> {
128        let centroid1: Vec<f32> = vec![0.0f32; config.dimensions as usize];
129        Self::open_with_centroids(config, vec![centroid1], builder).await
130    }
131
132    pub async fn open_with_centroids(
133        config: Config,
134        centroids: Vec<Vec<f32>>,
135        builder: StorageBuilder,
136    ) -> Result<Self> {
137        let merge_op = VectorDbMergeOperator::new(config.dimensions as usize);
138        let storage = builder
139            .with_semantics(StorageSemantics::new().with_merge_operator(Arc::new(merge_op)))
140            .build()
141            .await
142            .map_err(|e| Error::Storage(format!("Failed to create storage: {e}")))?;
143
144        Self::load_or_init_db(storage, config, centroids).await
145    }
146
147    /// Create a vector database with the given storage, configuration, and centroids.
148    /// The fn is internal to this module. It is intended to be used by the public api and
149    /// by tests.
150    ///
151    /// If centroids already exist in storage, the provided centroids are ignored.
152    /// Otherwise, the provided centroids are written to storage.
153    async fn load_or_init_db(
154        storage: Arc<dyn Storage>,
155        config: Config,
156        centroids: Vec<Vec<f32>>,
157    ) -> Result<Self> {
158        // Initialize sequence allocator for internal ID generation
159        let seq_key = SeqBlockKey.encode();
160        let mut id_allocator = SequenceAllocator::load(storage.as_ref(), seq_key).await?;
161
162        // Get initial snapshot
163        let snapshot = storage.snapshot().await?;
164
165        // For now, load the full ID dictionary from storage into memory at startup
166        // Eventually, we should load this in the background and allow the delta to
167        // read ids that are not yet loaded from storage
168        let dictionary = Arc::new(DashMap::new());
169        {
170            Self::load_dictionary_from_storage(snapshot.as_ref(), &dictionary).await?;
171        }
172
173        // Load centroid counts from storage
174        let centroid_counts = Self::load_centroid_counts_from_storage(snapshot.as_ref()).await?;
175
176        // For now, just force bootstrap centroids. Eventually we'll derive these automatically
177        // from the vectors
178        let (centroid_graph, current_chunk_id, current_chunk_count) =
179            Self::load_or_create_centroids(
180                &storage,
181                snapshot.as_ref(),
182                &config,
183                centroids,
184                &mut id_allocator,
185            )
186            .await?;
187
188        // Create flusher for the WriteCoordinator
189        let flusher = VectorDbFlusher {
190            storage: Arc::clone(&storage),
191        };
192
193        let handle_tx = Arc::new(OnceLock::new());
194        let rebalancer = IndexRebalancer::new(
195            IndexRebalancerOpts {
196                dimensions: config.dimensions as usize,
197                distance_metric: config.distance_metric,
198                split_search_neighbourhood: config.split_search_neighbourhood,
199                split_threshold_vectors: config.split_threshold_vectors,
200                merge_threshold_vectors: config.merge_threshold_vectors,
201                max_rebalance_tasks: config.max_rebalance_tasks,
202            },
203            centroid_graph.clone(),
204            centroid_counts,
205            handle_tx.clone(),
206        );
207
208        let pause_handle = Arc::new(OnceLock::new());
209        let ctx = VectorDbDeltaContext {
210            opts: VectorDbDeltaOpts {
211                dimensions: config.dimensions as usize,
212                chunk_target: config.chunk_target as usize,
213                max_pending_and_running_rebalance_tasks: config
214                    .max_pending_and_running_rebalance_tasks,
215                rebalance_backpressure_resume_threshold: config
216                    .rebalance_backpressure_resume_threshold,
217                split_threshold_vectors: config.split_threshold_vectors,
218                indexed_fields: VectorDbDeltaOpts::indexed_fields_from(&config.metadata_fields),
219            },
220            dictionary: Arc::clone(&dictionary),
221            centroid_graph: Arc::clone(&centroid_graph),
222            id_allocator,
223            current_chunk_id,
224            current_chunk_count,
225            rebalancer,
226            pause_handle: pause_handle.clone(),
227        };
228
229        // start write coordinator
230        let coordinator_config = WriteCoordinatorConfig {
231            queue_capacity: 1000,
232            flush_interval: Duration::from_secs(5),
233            flush_size_threshold: 64 * 1024 * 1024,
234        };
235        let mut write_coordinator = WriteCoordinator::new(
236            coordinator_config,
237            vec![WRITE_CHANNEL.to_string(), REBALANCE_CHANNEL.to_string()],
238            ctx,
239            snapshot.clone(),
240            flusher,
241        );
242        handle_tx
243            .set(write_coordinator.handle(REBALANCE_CHANNEL))
244            .map_err(|_e| "unreachable")
245            .unwrap();
246        pause_handle
247            .set(write_coordinator.pause_handle(WRITE_CHANNEL))
248            .map_err(|_e| "unreachable")
249            .unwrap();
250        write_coordinator.start();
251
252        Ok(Self {
253            config,
254            storage,
255            write_coordinator,
256            centroid_graph,
257        })
258    }
259
260    /// Load centroids from storage if they exist, otherwise create them from the provided entries.
261    /// Returns the centroid graph and the last chunk's ID and entry count, used for initializing
262    /// chunk tracking state.
263    async fn load_or_create_centroids(
264        storage: &Arc<dyn Storage>,
265        snapshot: &dyn StorageSnapshot,
266        config: &Config,
267        centroids: Vec<Vec<f32>>,
268        id_allocator: &mut SequenceAllocator,
269    ) -> Result<(Arc<dyn CentroidGraph>, u32, usize)> {
270        // Check if centroids already exist in storage
271        let scan_result = snapshot
272            .scan_all_centroids(config.dimensions as usize)
273            .await?;
274
275        if !scan_result.entries.is_empty() {
276            let last_chunk_id = scan_result.last_chunk_id;
277            let last_chunk_count = scan_result.last_chunk_count;
278            // Filter out centroids that have been deleted (tracked in deletions bitmap)
279            let deletions = snapshot.get_deleted_vectors().await?;
280            let live_centroids: Vec<CentroidEntry> = scan_result
281                .entries
282                .into_iter()
283                .filter(|c| !deletions.contains(c.centroid_id))
284                .collect();
285            let graph = build_centroid_graph(live_centroids, config.distance_metric)?;
286            return Ok((Arc::from(graph), last_chunk_id, last_chunk_count));
287        }
288
289        // No existing centroids - validate and write the provided ones
290        if centroids.is_empty() {
291            return Err(Error::InvalidInput(
292                "Centroids must be provided when creating a new database".to_string(),
293            ));
294        }
295
296        // Validate centroid dimensions
297        for centroid in &centroids {
298            if centroid.len() != config.dimensions as usize {
299                return Err(Error::InvalidInput(format!(
300                    "Centroid dimension mismatch: expected {}, got {}",
301                    config.dimensions,
302                    centroid.len()
303                )));
304            }
305        }
306
307        // Allocate IDs and build CentroidEntries
308        let mut ops = Vec::new();
309        let mut entries = Vec::with_capacity(centroids.len());
310        for vector in centroids {
311            let (centroid_id, seq_alloc_put) = id_allocator.allocate_one();
312            if let Some(seq_alloc_put) = seq_alloc_put {
313                ops.push(common::storage::RecordOp::Put(seq_alloc_put.into()));
314            }
315            entries.push(CentroidEntry::new(centroid_id, vector));
316        }
317
318        // Write centroids to storage in chunks
319        let chunk_target = config.chunk_target as usize;
320        let num_chunks = entries.chunks(chunk_target).len();
321        for (chunk_idx, chunk_entries) in entries.chunks(chunk_target).enumerate() {
322            ops.push(crate::storage::record::put_centroid_chunk(
323                chunk_idx as u32,
324                chunk_entries.to_vec(),
325                config.dimensions as usize,
326            ));
327        }
328        storage.apply(ops).await?;
329
330        // Compute last chunk state from what we just wrote
331        let last_chunk_id = if num_chunks == 0 {
332            0
333        } else {
334            (num_chunks - 1) as u32
335        };
336        let last_chunk_count = if entries.is_empty() {
337            0
338        } else {
339            entries.len() - (last_chunk_id as usize * chunk_target)
340        };
341
342        // Build and return the graph
343        let graph = build_centroid_graph(entries, config.distance_metric)?;
344        Ok((Arc::from(graph), last_chunk_id, last_chunk_count))
345    }
346
347    /// Load ID dictionary entries from storage into the in-memory DashMap.
348    async fn load_dictionary_from_storage(
349        snapshot: &dyn StorageRead,
350        dictionary: &DashMap<String, u64>,
351    ) -> Result<()> {
352        // Create prefix for all IdDictionary records
353        let mut prefix_buf = bytes::BytesMut::with_capacity(2);
354        crate::serde::RecordType::IdDictionary
355            .prefix()
356            .write_to(&mut prefix_buf);
357        let prefix = prefix_buf.freeze();
358
359        // Scan all IdDictionary records
360        let range = common::BytesRange::prefix(prefix);
361        let records = snapshot.scan(range).await?;
362
363        for record in records {
364            // Decode the key to get external_id
365            let key = crate::serde::key::IdDictionaryKey::decode(&record.key)?;
366            let external_id = key.external_id.clone();
367
368            // Decode the value to get internal_id
369            let mut slice = record.value.as_ref();
370            let internal_id = common::serde::encoding::decode_u64(&mut slice).map_err(|e| {
371                Error::Encoding(format!(
372                    "failed to decode internal ID from ID dictionary: {e}"
373                ))
374            })?;
375
376            dictionary.insert(external_id, internal_id);
377        }
378
379        Ok(())
380    }
381
382    /// Load centroid counts from storage into a HashMap.
383    ///
384    /// Scans all CentroidStats records and extracts the accumulated num_vectors
385    /// for each centroid.
386    async fn load_centroid_counts_from_storage(
387        snapshot: &dyn StorageRead,
388    ) -> Result<HashMap<u64, u64>> {
389        let stats = snapshot.scan_all_centroid_stats().await?;
390        let mut counts = HashMap::new();
391        for (centroid_id, value) in stats {
392            counts.insert(centroid_id, value.num_vectors.max(0) as u64);
393        }
394        Ok(counts)
395    }
396
397    /// Write vectors to the database.
398    ///
399    /// This is the primary write method. It accepts a batch of vectors and
400    /// returns when the data has been accepted for ingestion (but not
401    /// necessarily flushed to durable storage).
402    ///
403    /// # Atomicity
404    ///
405    /// This operation is atomic: either all vectors in the batch are accepted,
406    /// or none are. This matches the behavior of `TimeSeriesDb::write()`.
407    ///
408    /// # Upsert Semantics
409    ///
410    /// Writing a vector with an ID that already exists performs an upsert:
411    /// the old vector is deleted and replaced with the new one. The system
412    /// allocates a new internal ID for the updated vector and marks the old
413    /// internal ID as deleted. This ensures index structures are updated
414    /// correctly without expensive read-modify-write cycles.
415    ///
416    /// # Validation
417    ///
418    /// The following validations are performed:
419    /// - Vector dimensions must match `Config::dimensions`
420    /// - Attribute names must be defined in `Config::metadata_fields` (if specified)
421    /// - Attribute types must match the schema
422    pub async fn write(&self, vectors: Vec<Vector>) -> Result<()> {
423        // Validate and prepare all vectors
424        let mut writes = Vec::with_capacity(vectors.len());
425        for vector in vectors {
426            writes.push(self.prepare_vector_write(vector)?);
427        }
428
429        // Send all writes to coordinator in a single batch and wait to be applied
430        let mut write_handle = self
431            .write_coordinator
432            .handle(WRITE_CHANNEL)
433            .write(VectorDbWrite::Write(writes))
434            .await
435            .map_err(|e| Error::Internal(format!("{}", e)))?;
436        write_handle
437            .wait(Durability::Applied)
438            .await
439            .map_err(|e| Error::Internal(format!("{}", e)))?;
440
441        Ok(())
442    }
443
444    /// Write vectors to the database with a timeout.
445    ///
446    /// This is the primary write method. It accepts a batch of vectors and
447    /// returns when the data has been accepted for ingestion (but not
448    /// necessarily flushed to durable storage).
449    ///
450    /// The write may time out if the db is busy compacting or indexing.
451    ///
452    /// # Atomicity
453    ///
454    /// This operation is atomic: either all vectors in the batch are accepted,
455    /// or none are. This matches the behavior of `TimeSeriesDb::write()`.
456    ///
457    /// # Upsert Semantics
458    ///
459    /// Writing a vector with an ID that already exists performs an upsert:
460    /// the old vector is deleted and replaced with the new one. The system
461    /// allocates a new internal ID for the updated vector and marks the old
462    /// internal ID as deleted. This ensures index structures are updated
463    /// correctly without expensive read-modify-write cycles.
464    ///
465    /// # Validation
466    ///
467    /// The following validations are performed:
468    /// - Vector dimensions must match `Config::dimensions`
469    /// - Attribute names must be defined in `Config::metadata_fields` (if specified)
470    /// - Attribute types must match the schema
471    pub async fn write_timeout(&self, vectors: Vec<Vector>, timeout: Duration) -> Result<()> {
472        // Validate and prepare all vectors
473        let mut writes = Vec::with_capacity(vectors.len());
474        for vector in vectors {
475            writes.push(self.prepare_vector_write(vector)?);
476        }
477
478        // Send all writes to coordinator in a single batch and wait to be applied
479        let mut write_handle = self
480            .write_coordinator
481            .handle(WRITE_CHANNEL)
482            .write_timeout(VectorDbWrite::Write(writes), timeout)
483            .await
484            .map_err(|e| Error::Internal(format!("{}", e)))?;
485        write_handle
486            .wait(Durability::Applied)
487            .await
488            .map_err(|e| Error::Internal(format!("{}", e)))?;
489
490        Ok(())
491    }
492
493    /// Validate and prepare a vector write for the coordinator.
494    ///
495    /// This validates the vector. The delta handles ID allocation,
496    /// dictionary lookup, and centroid assignment.
497    fn prepare_vector_write(&self, vector: Vector) -> Result<VectorWrite> {
498        // Validate external ID length
499        if vector.id.len() > 64 {
500            return Err(Error::InvalidInput(format!(
501                "External ID too long: {} bytes (max 64)",
502                vector.id.len()
503            )));
504        }
505
506        // Convert attributes to map for validation
507        let attributes = attributes_to_map(&vector.attributes);
508
509        // Extract and validate "vector" attribute
510        let values = match attributes.get(VECTOR_FIELD_NAME) {
511            Some(AttributeValue::Vector(v)) => v.clone(),
512            Some(_) => {
513                return Err(Error::InvalidInput(format!(
514                    "Field '{}' must have type Vector",
515                    VECTOR_FIELD_NAME
516                )));
517            }
518            None => {
519                return Err(Error::InvalidInput(format!(
520                    "Missing required field '{}'",
521                    VECTOR_FIELD_NAME
522                )));
523            }
524        };
525
526        // Validate dimensions
527        if values.len() != self.config.dimensions as usize {
528            return Err(Error::InvalidInput(format!(
529                "Vector dimension mismatch: expected {}, got {}",
530                self.config.dimensions,
531                values.len()
532            )));
533        }
534
535        // Validate attributes against schema (if schema is defined)
536        if !self.config.metadata_fields.is_empty() {
537            self.validate_attributes(&attributes)?;
538        }
539
540        // Convert attributes to vec of tuples for VectorWrite
541        let attributes_vec: Vec<(String, AttributeValue)> = attributes.into_iter().collect();
542
543        Ok(VectorWrite {
544            external_id: vector.id,
545            values,
546            attributes: attributes_vec,
547        })
548    }
549
550    /// Validates attributes against the configured schema.
551    fn validate_attributes(&self, metadata: &HashMap<String, AttributeValue>) -> Result<()> {
552        // Build a map of field name -> expected type for quick lookup
553        let schema: HashMap<&str, crate::serde::FieldType> = self
554            .config
555            .metadata_fields
556            .iter()
557            .map(|spec| (spec.name.as_str(), spec.field_type))
558            .collect();
559
560        // Check each provided attribute (skip VECTOR_FIELD_NAME which is always allowed)
561        for (field_name, value) in metadata {
562            // Skip the special "vector" field
563            if field_name == VECTOR_FIELD_NAME {
564                continue;
565            }
566
567            match schema.get(field_name.as_str()) {
568                Some(expected_type) => {
569                    // Validate type matches
570                    let actual_type = match value {
571                        AttributeValue::String(_) => crate::serde::FieldType::String,
572                        AttributeValue::Int64(_) => crate::serde::FieldType::Int64,
573                        AttributeValue::Float64(_) => crate::serde::FieldType::Float64,
574                        AttributeValue::Bool(_) => crate::serde::FieldType::Bool,
575                        AttributeValue::Vector(_) => crate::serde::FieldType::Vector,
576                    };
577
578                    if actual_type != *expected_type {
579                        return Err(Error::InvalidInput(format!(
580                            "Type mismatch for field '{}': expected {:?}, got {:?}",
581                            field_name, expected_type, actual_type
582                        )));
583                    }
584                }
585                None => {
586                    return Err(Error::InvalidInput(format!(
587                        "Unknown metadata field: '{}'. Valid fields: {:?}",
588                        field_name,
589                        schema.keys().collect::<Vec<_>>()
590                    )));
591                }
592            }
593        }
594
595        Ok(())
596    }
597
598    /// Force flush all pending data to durable storage.
599    ///
600    /// Flushes the in-memory delta to the storage memtable, then persists
601    /// to durable storage. After this returns, data is both readable and
602    /// durable.
603    ///
604    /// # Atomic Flush
605    ///
606    /// The flush operation is atomic:
607    /// 1. All pending writes are frozen into an immutable delta
608    /// 2. RecordOps are applied in one batch via `storage.apply()`
609    /// 3. The snapshot is updated for queries
610    /// 4. Data is flushed to durable storage
611    ///
612    /// This ensures ID dictionary updates, deletes, and new records are all
613    /// applied together, maintaining consistency.
614    pub async fn flush(&self) -> Result<()> {
615        let mut handle = self
616            .write_coordinator
617            .handle(WRITE_CHANNEL)
618            .flush(true)
619            .await
620            .map_err(|e| Error::Internal(format!("{}", e)))?;
621        handle
622            .wait(Durability::Durable)
623            .await
624            .map_err(|e| Error::Internal(format!("{}", e)))?;
625        Ok(())
626    }
627
628    /// Closes the vector database, flushing any pending data and releasing resources.
629    ///
630    /// All written data is flushed to durable storage before the database is
631    /// closed. For SlateDB-backed storage, this also releases the database
632    /// fence.
633    pub async fn close(self) -> Result<()> {
634        self.flush().await?;
635        self.write_coordinator
636            .stop()
637            .await
638            .map_err(Error::Internal)?;
639        self.storage.close().await?;
640        Ok(())
641    }
642
643    pub fn num_centroids(&self) -> usize {
644        self.centroid_graph.len()
645    }
646
647    /// Create a QueryEngine from the current snapshot for executing queries.
648    pub(crate) fn query_engine(&self) -> QueryEngine {
649        let snapshot = self.write_coordinator.view().snapshot.clone();
650        let options = QueryEngineOptions {
651            dimensions: self.config.dimensions,
652            distance_metric: self.config.distance_metric,
653            query_pruning_factor: self.config.query_pruning_factor,
654        };
655        QueryEngine::new(options, self.centroid_graph.clone(), snapshot)
656    }
657
658    /// Search using brute-force centroid lookup (for diagnostics).
659    pub async fn search_exact_nprobe(
660        &self,
661        query: &Query,
662        nprobe: usize,
663    ) -> Result<Vec<SearchResult>> {
664        self.query_engine().search_exact_nprobe(query, nprobe).await
665    }
666
667    pub async fn snapshot(&self) -> Box<dyn VectorDbRead> {
668        Box::new(VectorDbReader::new(self.query_engine())) as Box<dyn VectorDbRead>
669    }
}

672#[async_trait]
673impl VectorDbRead for VectorDb {
674    async fn search(&self, query: &Query) -> Result<Vec<SearchResult>> {
675        self.query_engine().search(query).await
676    }
677
678    async fn search_with_nprobe(&self, query: &Query, nprobe: usize) -> Result<Vec<SearchResult>> {
679        self.query_engine().search_with_nprobe(query, nprobe).await
680    }
681
682    async fn get(&self, id: &str) -> Result<Option<Vector>> {
683        self.query_engine().get(id).await
684    }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{MetadataFieldSpec, Vector};
    use crate::serde::FieldType;
    use crate::serde::collection_meta::DistanceMetric;
    use crate::serde::key::{IdDictionaryKey, VectorDataKey};
    use crate::serde::vector_data::VectorDataValue;
    use common::StorageConfig;
    use opendata_macros::storage_test;
    use std::time::Duration;

699    fn create_test_config() -> Config {
700        Config {
701            storage: StorageConfig::InMemory,
702            dimensions: 3,
703            distance_metric: DistanceMetric::L2,
704            flush_interval: Duration::from_secs(60),
705            split_threshold_vectors: 10_000,
706            merge_threshold_vectors: 200,
707            split_search_neighbourhood: 8,
708            chunk_target: 4096,
709            metadata_fields: vec![
710                MetadataFieldSpec::new("category", FieldType::String, true),
711                MetadataFieldSpec::new("price", FieldType::Int64, true),
712            ],
713            ..Default::default()
714        }
715    }
716
717    fn create_test_centroids(dimensions: usize) -> Vec<Vec<f32>> {
718        vec![vec![1.0; dimensions]]
719    }
720
721    #[tokio::test]
722    async fn should_open_vector_db() {
723        // given
724        let config = create_test_config();
725
726        // when
727        let result = VectorDb::open(config).await;
728
729        // then
730        assert!(result.is_ok());
731    }
732
733    #[storage_test(merge_operator = VectorDbMergeOperator::new(3))]
734    async fn should_write_and_flush_vectors(storage: Arc<dyn Storage>) {
735        // given
736        let config = create_test_config();
737        let centroids = create_test_centroids(3);
738        let db = VectorDb::load_or_init_db(Arc::clone(&storage), config, centroids)
739            .await
740            .unwrap();
741
742        let vectors = vec![
743            Vector::builder("vec-1", vec![1.0, 0.0, 0.0])
744                .attribute("category", "shoes")
745                .attribute("price", 99i64)
746                .build(),
747            Vector::builder("vec-2", vec![0.0, 1.0, 0.0])
748                .attribute("category", "boots")
749                .attribute("price", 149i64)
750                .build(),
751        ];
752
753        // when
754        db.write(vectors).await.unwrap();
755        db.flush().await.unwrap();
756
757        // then - verify records exist in storage
758        // Check VectorData records (now contain external_id, vector, and metadata)
759        // Note: centroid IDs are allocated from the same sequence as vector IDs.
760        // With 1 centroid (ID 0), vectors start at ID 1.
761        let vec1_data_key = VectorDataKey::new(1).encode();
762        let vec1_data = storage.get(vec1_data_key).await.unwrap();
763        assert!(vec1_data.is_some());
764
765        let vec2_data_key = VectorDataKey::new(2).encode();
766        let vec2_data = storage.get(vec2_data_key).await.unwrap();
767        assert!(vec2_data.is_some());
768
769        // Check IdDictionary
770        let dict_key1 = IdDictionaryKey::new("vec-1").encode();
771        let dict_entry1 = storage.get(dict_key1).await.unwrap();
772        assert!(dict_entry1.is_some());
773    }
774
775    #[storage_test(merge_operator = VectorDbMergeOperator::new(3))]
776    async fn should_upsert_existing_vector(storage: Arc<dyn Storage>) {
777        // given
778        let config = create_test_config();
779        let centroids = create_test_centroids(3);
780        let db = VectorDb::load_or_init_db(Arc::clone(&storage), config, centroids)
781            .await
782            .unwrap();
783
784        // First write
785        let vector1 = Vector::builder("vec-1", vec![1.0, 0.0, 0.0])
786            .attribute("category", "shoes")
787            .attribute("price", 99i64)
788            .build();
789        db.write(vec![vector1]).await.unwrap();
790        db.flush().await.unwrap();
791
792        // when - upsert with same ID but different values
793        let vector2 = Vector::builder("vec-1", vec![2.0, 3.0, 4.0])
794            .attribute("category", "boots")
795            .attribute("price", 199i64)
796            .build();
797        db.write(vec![vector2]).await.unwrap();
798        db.flush().await.unwrap();
799
800        // then - verify new vector data
801        // Centroid takes ID 0, first write gets ID 1, upsert gets ID 2
802        let vec_data_key = VectorDataKey::new(2).encode(); // New internal ID
803        let vec_data = storage.get(vec_data_key).await.unwrap();
804        assert!(vec_data.is_some());
805        let decoded = VectorDataValue::decode_from_bytes(&vec_data.unwrap().value, 3).unwrap();
806        assert_eq!(decoded.vector_field(), &[2.0, 3.0, 4.0]);
807
808        // Verify only one IdDictionary entry
809        let dict_key = IdDictionaryKey::new("vec-1").encode();
810        let dict_entry = storage.get(dict_key).await.unwrap();
811        assert!(dict_entry.is_some());
812    }
813
814    #[tokio::test]
815    async fn should_reject_vectors_with_wrong_dimensions() {
816        // given
817        let config = create_test_config();
818        let db = VectorDb::open(config).await.unwrap();
819
820        let vector = Vector::new("vec-1", vec![1.0, 2.0]); // Wrong: 2 instead of 3
821
822        // when
823        let result = db.write(vec![vector]).await;
824
825        // then
826        assert!(result.is_err());
827        assert!(
828            result
829                .unwrap_err()
830                .to_string()
831                .contains("dimension mismatch")
832        );
833    }
834
835    #[tokio::test]
836    async fn should_flush_empty_delta_without_error() {
837        // given
838        let config = create_test_config();
839        let db = VectorDb::open(config).await.unwrap();
840
841        // when
842        let result = db.flush().await;
843
844        // then
845        assert!(result.is_ok());
846    }
847
848    #[storage_test(merge_operator = VectorDbMergeOperator::new(3))]
849    async fn should_load_dictionary_on_reopen(storage: Arc<dyn Storage>) {
850        // given - create database and write vectors
851        let config = create_test_config();
852        let centroids = create_test_centroids(3);
853
854        {
855            let db =
856                VectorDb::load_or_init_db(Arc::clone(&storage), config.clone(), centroids.clone())
857                    .await
858                    .unwrap();
859            let vectors = vec![
860                Vector::builder("vec-1", vec![1.0, 0.0, 0.0])
861                    .attribute("category", "shoes")
862                    .attribute("price", 99i64)
863                    .build(),
864                Vector::builder("vec-2", vec![0.0, 1.0, 0.0])
865                    .attribute("category", "boots")
866                    .attribute("price", 149i64)
867                    .build(),
868            ];
869            db.write(vectors).await.unwrap();
870            db.flush().await.unwrap();
871        }
872
873        // when - reopen database (centroids should be loaded from storage)
874        let db2 = VectorDb::load_or_init_db(Arc::clone(&storage), config, vec![])
875            .await
876            .unwrap();
877
878        // then - should be able to search (dictionary and centroids loaded from storage)
879        let results = db2
880            .search(&Query::new(vec![1.0, 0.0, 0.0]).with_limit(10))
881            .await
882            .unwrap();
883        assert!(!results.is_empty());
884    }
885
886    #[tokio::test]
887    async fn flush_should_be_durable_across_reopen() {
888        use common::storage::config::{
889            LocalObjectStoreConfig, ObjectStoreConfig, SlateDbStorageConfig,
890        };
891
892        let tmp_dir = tempfile::tempdir().unwrap();
893        let storage_config = StorageConfig::SlateDb(SlateDbStorageConfig {
894            path: "data".to_string(),
895            object_store: ObjectStoreConfig::Local(LocalObjectStoreConfig {
896                path: tmp_dir.path().to_str().unwrap().to_string(),
897            }),
898            settings_path: None,
899            block_cache: None,
900        });
901
902        let config = Config {
903            storage: storage_config.clone(),
904            dimensions: 3,
905            distance_metric: DistanceMetric::L2,
906            ..Default::default()
907        };
908
909        // Write vectors and flush
910        let db = VectorDb::open(config.clone()).await.unwrap();
911        db.write(vec![
912            Vector::new("vec-1", vec![1.0, 0.0, 0.0]),
913            Vector::new("vec-2", vec![0.0, 1.0, 0.0]),
914        ])
915        .await
916        .unwrap();
917        db.flush().await.unwrap();
918        drop(db);
919
920        // Reopen from durable state — data should be visible
921        let db2 = VectorDb::open(config).await.unwrap();
922        let results = db2
923            .search(&Query::new(vec![1.0, 0.0, 0.0]).with_limit(10))
924            .await
925            .unwrap();
926        assert!(
927            !results.is_empty(),
928            "expected data to be durable after flush, but search returned no results"
929        );
930    }
931
932    #[tokio::test]
933    #[allow(clippy::needless_return)]
934    async fn close_without_explicit_flush_guarantees_durability() {
935        use common::storage::config::{
936            LocalObjectStoreConfig, ObjectStoreConfig, SlateDbStorageConfig,
937        };
938
939        let tmp_dir = tempfile::tempdir().unwrap();
940        let storage_config = StorageConfig::SlateDb(SlateDbStorageConfig {
941            path: "data".to_string(),
942            object_store: ObjectStoreConfig::Local(LocalObjectStoreConfig {
943                path: tmp_dir.path().to_str().unwrap().to_string(),
944            }),
945            settings_path: None,
946            block_cache: None,
947        });
948
949        let config = Config {
950            storage: storage_config.clone(),
951            dimensions: 3,
952            distance_metric: DistanceMetric::L2,
953            ..Default::default()
954        };
955
956        // Write a vector and close without calling flush()
957        {
958            let db = VectorDb::open(config.clone()).await.unwrap();
959            db.write(vec![Vector::new("vec-1", vec![1.0, 0.0, 0.0])])
960                .await
961                .unwrap();
962            db.close().await.unwrap();
963        }
964
965        // Reopen and verify the vector survived
966        let db2 = VectorDb::open(config).await.unwrap();
967        let results = db2
968            .search(&Query::new(vec![1.0, 0.0, 0.0]).with_limit(1))
969            .await
970            .unwrap();
971        assert_eq!(results.len(), 1);
972        assert_eq!(results[0].vector.id, "vec-1");
973    }
974
975    #[tokio::test]
976    async fn should_fail_if_no_centroids_provided_for_new_db() {
977        // given - new database without centroids
978        let config = create_test_config();
979
980        // when
981        let sb = StorageBuilder::new(&config.storage).await.unwrap();
982        let result = VectorDb::open_with_centroids(config, vec![], sb).await;
983
984        // then
985        match result {
986            Err(e) => assert!(
987                e.to_string().contains("Centroids must be provided"),
988                "unexpected error: {}",
989                e
990            ),
991            Ok(_) => panic!("expected error when no centroids provided"),
992        }
993    }
994}