1use crate::VectorDbReader;
15use crate::delta::{
16 VectorDbDeltaContext, VectorDbDeltaOpts, VectorDbWrite, VectorDbWriteDelta, VectorWrite,
17};
18use crate::error::{Error, Result};
19use crate::flusher::VectorDbFlusher;
20use crate::hnsw::{CentroidGraph, build_centroid_graph};
21use crate::lire::rebalancer::{IndexRebalancer, IndexRebalancerOpts};
22use crate::model::{
23 AttributeValue, Config, Query, SearchResult, VECTOR_FIELD_NAME, Vector, attributes_to_map,
24};
25use crate::query_engine::{QueryEngine, QueryEngineOptions};
26use crate::serde::centroid_chunk::CentroidEntry;
27use crate::serde::key::SeqBlockKey;
28use crate::storage::VectorDbStorageReadExt;
29use crate::storage::merge_operator::VectorDbMergeOperator;
30use async_trait::async_trait;
31use common::SequenceAllocator;
32use common::coordinator::{Durability, WriteCoordinator, WriteCoordinatorConfig};
33use common::storage::{Storage, StorageRead, StorageSnapshot};
34use common::{StorageBuilder, StorageSemantics};
35use dashmap::DashMap;
36use std::collections::HashMap;
37use std::sync::{Arc, OnceLock};
38use std::time::Duration;
39
/// Write-coordinator channel carrying user-issued vector writes.
pub(crate) const WRITE_CHANNEL: &str = "write";
/// Write-coordinator channel carrying background rebalancing writes.
pub(crate) const REBALANCE_CHANNEL: &str = "rebalance";
42
#[async_trait]
pub trait VectorDbRead {
    /// Searches for vectors matching `query`, routed via the centroid index.
    async fn search(&self, query: &Query) -> Result<Vec<SearchResult>>;

    /// Like [`search`](Self::search), but with an explicit number of centroid
    /// partitions (`nprobe`) to probe.
    async fn search_with_nprobe(&self, query: &Query, nprobe: usize) -> Result<Vec<SearchResult>>;

    /// Looks up a single vector by its external ID, returning `None` when no
    /// vector with that ID exists.
    async fn get(&self, id: &str) -> Result<Option<Vector>>;
}
82
/// A vector database that partitions vectors around centroids and funnels
/// all mutations through a write coordinator.
pub struct VectorDb {
    /// Collection configuration (dimensions, distance metric, thresholds, ...).
    config: Config,
    /// Underlying storage handle; retained so `close()` can shut it down.
    #[allow(dead_code)]
    storage: Arc<dyn Storage>,

    /// Batches, applies, and flushes writes on the write/rebalance channels.
    write_coordinator: WriteCoordinator<VectorDbWriteDelta, VectorDbFlusher>,

    /// In-memory index over centroids used to route queries and writes.
    centroid_graph: Arc<dyn CentroidGraph>,
}
99
100impl VectorDb {
101 pub async fn open(config: Config) -> Result<Self> {
121 let sb = StorageBuilder::new(&config.storage)
122 .await
123 .map_err(|e| Error::Storage(format!("Failed to create storage: {e}")))?;
124 Self::open_with_storage(config, sb).await
125 }
126
127 pub async fn open_with_storage(config: Config, builder: StorageBuilder) -> Result<Self> {
128 let centroid1: Vec<f32> = vec![0.0f32; config.dimensions as usize];
129 Self::open_with_centroids(config, vec![centroid1], builder).await
130 }
131
132 pub async fn open_with_centroids(
133 config: Config,
134 centroids: Vec<Vec<f32>>,
135 builder: StorageBuilder,
136 ) -> Result<Self> {
137 let merge_op = VectorDbMergeOperator::new(config.dimensions as usize);
138 let storage = builder
139 .with_semantics(StorageSemantics::new().with_merge_operator(Arc::new(merge_op)))
140 .build()
141 .await
142 .map_err(|e| Error::Storage(format!("Failed to create storage: {e}")))?;
143
144 Self::load_or_init_db(storage, config, centroids).await
145 }
146
147 async fn load_or_init_db(
154 storage: Arc<dyn Storage>,
155 config: Config,
156 centroids: Vec<Vec<f32>>,
157 ) -> Result<Self> {
158 let seq_key = SeqBlockKey.encode();
160 let mut id_allocator = SequenceAllocator::load(storage.as_ref(), seq_key).await?;
161
162 let snapshot = storage.snapshot().await?;
164
165 let dictionary = Arc::new(DashMap::new());
169 {
170 Self::load_dictionary_from_storage(snapshot.as_ref(), &dictionary).await?;
171 }
172
173 let centroid_counts = Self::load_centroid_counts_from_storage(snapshot.as_ref()).await?;
175
176 let (centroid_graph, current_chunk_id, current_chunk_count) =
179 Self::load_or_create_centroids(
180 &storage,
181 snapshot.as_ref(),
182 &config,
183 centroids,
184 &mut id_allocator,
185 )
186 .await?;
187
188 let flusher = VectorDbFlusher {
190 storage: Arc::clone(&storage),
191 };
192
193 let handle_tx = Arc::new(OnceLock::new());
194 let rebalancer = IndexRebalancer::new(
195 IndexRebalancerOpts {
196 dimensions: config.dimensions as usize,
197 distance_metric: config.distance_metric,
198 split_search_neighbourhood: config.split_search_neighbourhood,
199 split_threshold_vectors: config.split_threshold_vectors,
200 merge_threshold_vectors: config.merge_threshold_vectors,
201 max_rebalance_tasks: config.max_rebalance_tasks,
202 },
203 centroid_graph.clone(),
204 centroid_counts,
205 handle_tx.clone(),
206 );
207
208 let pause_handle = Arc::new(OnceLock::new());
209 let ctx = VectorDbDeltaContext {
210 opts: VectorDbDeltaOpts {
211 dimensions: config.dimensions as usize,
212 chunk_target: config.chunk_target as usize,
213 max_pending_and_running_rebalance_tasks: config
214 .max_pending_and_running_rebalance_tasks,
215 rebalance_backpressure_resume_threshold: config
216 .rebalance_backpressure_resume_threshold,
217 split_threshold_vectors: config.split_threshold_vectors,
218 indexed_fields: VectorDbDeltaOpts::indexed_fields_from(&config.metadata_fields),
219 },
220 dictionary: Arc::clone(&dictionary),
221 centroid_graph: Arc::clone(¢roid_graph),
222 id_allocator,
223 current_chunk_id,
224 current_chunk_count,
225 rebalancer,
226 pause_handle: pause_handle.clone(),
227 };
228
229 let coordinator_config = WriteCoordinatorConfig {
231 queue_capacity: 1000,
232 flush_interval: Duration::from_secs(5),
233 flush_size_threshold: 64 * 1024 * 1024,
234 };
235 let mut write_coordinator = WriteCoordinator::new(
236 coordinator_config,
237 vec![WRITE_CHANNEL.to_string(), REBALANCE_CHANNEL.to_string()],
238 ctx,
239 snapshot.clone(),
240 flusher,
241 );
242 handle_tx
243 .set(write_coordinator.handle(REBALANCE_CHANNEL))
244 .map_err(|_e| "unreachable")
245 .unwrap();
246 pause_handle
247 .set(write_coordinator.pause_handle(WRITE_CHANNEL))
248 .map_err(|_e| "unreachable")
249 .unwrap();
250 write_coordinator.start();
251
252 Ok(Self {
253 config,
254 storage,
255 write_coordinator,
256 centroid_graph,
257 })
258 }
259
260 async fn load_or_create_centroids(
264 storage: &Arc<dyn Storage>,
265 snapshot: &dyn StorageSnapshot,
266 config: &Config,
267 centroids: Vec<Vec<f32>>,
268 id_allocator: &mut SequenceAllocator,
269 ) -> Result<(Arc<dyn CentroidGraph>, u32, usize)> {
270 let scan_result = snapshot
272 .scan_all_centroids(config.dimensions as usize)
273 .await?;
274
275 if !scan_result.entries.is_empty() {
276 let last_chunk_id = scan_result.last_chunk_id;
277 let last_chunk_count = scan_result.last_chunk_count;
278 let deletions = snapshot.get_deleted_vectors().await?;
280 let live_centroids: Vec<CentroidEntry> = scan_result
281 .entries
282 .into_iter()
283 .filter(|c| !deletions.contains(c.centroid_id))
284 .collect();
285 let graph = build_centroid_graph(live_centroids, config.distance_metric)?;
286 return Ok((Arc::from(graph), last_chunk_id, last_chunk_count));
287 }
288
289 if centroids.is_empty() {
291 return Err(Error::InvalidInput(
292 "Centroids must be provided when creating a new database".to_string(),
293 ));
294 }
295
296 for centroid in ¢roids {
298 if centroid.len() != config.dimensions as usize {
299 return Err(Error::InvalidInput(format!(
300 "Centroid dimension mismatch: expected {}, got {}",
301 config.dimensions,
302 centroid.len()
303 )));
304 }
305 }
306
307 let mut ops = Vec::new();
309 let mut entries = Vec::with_capacity(centroids.len());
310 for vector in centroids {
311 let (centroid_id, seq_alloc_put) = id_allocator.allocate_one();
312 if let Some(seq_alloc_put) = seq_alloc_put {
313 ops.push(common::storage::RecordOp::Put(seq_alloc_put.into()));
314 }
315 entries.push(CentroidEntry::new(centroid_id, vector));
316 }
317
318 let chunk_target = config.chunk_target as usize;
320 let num_chunks = entries.chunks(chunk_target).len();
321 for (chunk_idx, chunk_entries) in entries.chunks(chunk_target).enumerate() {
322 ops.push(crate::storage::record::put_centroid_chunk(
323 chunk_idx as u32,
324 chunk_entries.to_vec(),
325 config.dimensions as usize,
326 ));
327 }
328 storage.apply(ops).await?;
329
330 let last_chunk_id = if num_chunks == 0 {
332 0
333 } else {
334 (num_chunks - 1) as u32
335 };
336 let last_chunk_count = if entries.is_empty() {
337 0
338 } else {
339 entries.len() - (last_chunk_id as usize * chunk_target)
340 };
341
342 let graph = build_centroid_graph(entries, config.distance_metric)?;
344 Ok((Arc::from(graph), last_chunk_id, last_chunk_count))
345 }
346
347 async fn load_dictionary_from_storage(
349 snapshot: &dyn StorageRead,
350 dictionary: &DashMap<String, u64>,
351 ) -> Result<()> {
352 let mut prefix_buf = bytes::BytesMut::with_capacity(2);
354 crate::serde::RecordType::IdDictionary
355 .prefix()
356 .write_to(&mut prefix_buf);
357 let prefix = prefix_buf.freeze();
358
359 let range = common::BytesRange::prefix(prefix);
361 let records = snapshot.scan(range).await?;
362
363 for record in records {
364 let key = crate::serde::key::IdDictionaryKey::decode(&record.key)?;
366 let external_id = key.external_id.clone();
367
368 let mut slice = record.value.as_ref();
370 let internal_id = common::serde::encoding::decode_u64(&mut slice).map_err(|e| {
371 Error::Encoding(format!(
372 "failed to decode internal ID from ID dictionary: {e}"
373 ))
374 })?;
375
376 dictionary.insert(external_id, internal_id);
377 }
378
379 Ok(())
380 }
381
382 async fn load_centroid_counts_from_storage(
387 snapshot: &dyn StorageRead,
388 ) -> Result<HashMap<u64, u64>> {
389 let stats = snapshot.scan_all_centroid_stats().await?;
390 let mut counts = HashMap::new();
391 for (centroid_id, value) in stats {
392 counts.insert(centroid_id, value.num_vectors.max(0) as u64);
393 }
394 Ok(counts)
395 }
396
397 pub async fn write(&self, vectors: Vec<Vector>) -> Result<()> {
423 let mut writes = Vec::with_capacity(vectors.len());
425 for vector in vectors {
426 writes.push(self.prepare_vector_write(vector)?);
427 }
428
429 let mut write_handle = self
431 .write_coordinator
432 .handle(WRITE_CHANNEL)
433 .write(VectorDbWrite::Write(writes))
434 .await
435 .map_err(|e| Error::Internal(format!("{}", e)))?;
436 write_handle
437 .wait(Durability::Applied)
438 .await
439 .map_err(|e| Error::Internal(format!("{}", e)))?;
440
441 Ok(())
442 }
443
444 pub async fn write_timeout(&self, vectors: Vec<Vector>, timeout: Duration) -> Result<()> {
472 let mut writes = Vec::with_capacity(vectors.len());
474 for vector in vectors {
475 writes.push(self.prepare_vector_write(vector)?);
476 }
477
478 let mut write_handle = self
480 .write_coordinator
481 .handle(WRITE_CHANNEL)
482 .write_timeout(VectorDbWrite::Write(writes), timeout)
483 .await
484 .map_err(|e| Error::Internal(format!("{}", e)))?;
485 write_handle
486 .wait(Durability::Applied)
487 .await
488 .map_err(|e| Error::Internal(format!("{}", e)))?;
489
490 Ok(())
491 }
492
493 fn prepare_vector_write(&self, vector: Vector) -> Result<VectorWrite> {
498 if vector.id.len() > 64 {
500 return Err(Error::InvalidInput(format!(
501 "External ID too long: {} bytes (max 64)",
502 vector.id.len()
503 )));
504 }
505
506 let attributes = attributes_to_map(&vector.attributes);
508
509 let values = match attributes.get(VECTOR_FIELD_NAME) {
511 Some(AttributeValue::Vector(v)) => v.clone(),
512 Some(_) => {
513 return Err(Error::InvalidInput(format!(
514 "Field '{}' must have type Vector",
515 VECTOR_FIELD_NAME
516 )));
517 }
518 None => {
519 return Err(Error::InvalidInput(format!(
520 "Missing required field '{}'",
521 VECTOR_FIELD_NAME
522 )));
523 }
524 };
525
526 if values.len() != self.config.dimensions as usize {
528 return Err(Error::InvalidInput(format!(
529 "Vector dimension mismatch: expected {}, got {}",
530 self.config.dimensions,
531 values.len()
532 )));
533 }
534
535 if !self.config.metadata_fields.is_empty() {
537 self.validate_attributes(&attributes)?;
538 }
539
540 let attributes_vec: Vec<(String, AttributeValue)> = attributes.into_iter().collect();
542
543 Ok(VectorWrite {
544 external_id: vector.id,
545 values,
546 attributes: attributes_vec,
547 })
548 }
549
550 fn validate_attributes(&self, metadata: &HashMap<String, AttributeValue>) -> Result<()> {
552 let schema: HashMap<&str, crate::serde::FieldType> = self
554 .config
555 .metadata_fields
556 .iter()
557 .map(|spec| (spec.name.as_str(), spec.field_type))
558 .collect();
559
560 for (field_name, value) in metadata {
562 if field_name == VECTOR_FIELD_NAME {
564 continue;
565 }
566
567 match schema.get(field_name.as_str()) {
568 Some(expected_type) => {
569 let actual_type = match value {
571 AttributeValue::String(_) => crate::serde::FieldType::String,
572 AttributeValue::Int64(_) => crate::serde::FieldType::Int64,
573 AttributeValue::Float64(_) => crate::serde::FieldType::Float64,
574 AttributeValue::Bool(_) => crate::serde::FieldType::Bool,
575 AttributeValue::Vector(_) => crate::serde::FieldType::Vector,
576 };
577
578 if actual_type != *expected_type {
579 return Err(Error::InvalidInput(format!(
580 "Type mismatch for field '{}': expected {:?}, got {:?}",
581 field_name, expected_type, actual_type
582 )));
583 }
584 }
585 None => {
586 return Err(Error::InvalidInput(format!(
587 "Unknown metadata field: '{}'. Valid fields: {:?}",
588 field_name,
589 schema.keys().collect::<Vec<_>>()
590 )));
591 }
592 }
593 }
594
595 Ok(())
596 }
597
598 pub async fn flush(&self) -> Result<()> {
615 let mut handle = self
616 .write_coordinator
617 .handle(WRITE_CHANNEL)
618 .flush(true)
619 .await
620 .map_err(|e| Error::Internal(format!("{}", e)))?;
621 handle
622 .wait(Durability::Durable)
623 .await
624 .map_err(|e| Error::Internal(format!("{}", e)))?;
625 Ok(())
626 }
627
628 pub async fn close(self) -> Result<()> {
634 self.flush().await?;
635 self.write_coordinator
636 .stop()
637 .await
638 .map_err(Error::Internal)?;
639 self.storage.close().await?;
640 Ok(())
641 }
642
643 pub fn num_centroids(&self) -> usize {
644 self.centroid_graph.len()
645 }
646
647 pub(crate) fn query_engine(&self) -> QueryEngine {
649 let snapshot = self.write_coordinator.view().snapshot.clone();
650 let options = QueryEngineOptions {
651 dimensions: self.config.dimensions,
652 distance_metric: self.config.distance_metric,
653 query_pruning_factor: self.config.query_pruning_factor,
654 };
655 QueryEngine::new(options, self.centroid_graph.clone(), snapshot)
656 }
657
658 pub async fn search_exact_nprobe(
660 &self,
661 query: &Query,
662 nprobe: usize,
663 ) -> Result<Vec<SearchResult>> {
664 self.query_engine().search_exact_nprobe(query, nprobe).await
665 }
666
667 pub async fn snapshot(&self) -> Box<dyn VectorDbRead> {
668 Box::new(VectorDbReader::new(self.query_engine())) as Box<dyn VectorDbRead>
669 }
670}
671
672#[async_trait]
673impl VectorDbRead for VectorDb {
674 async fn search(&self, query: &Query) -> Result<Vec<SearchResult>> {
675 self.query_engine().search(query).await
676 }
677
678 async fn search_with_nprobe(&self, query: &Query, nprobe: usize) -> Result<Vec<SearchResult>> {
679 self.query_engine().search_with_nprobe(query, nprobe).await
680 }
681
682 async fn get(&self, id: &str) -> Result<Option<Vector>> {
683 self.query_engine().get(id).await
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{MetadataFieldSpec, Vector};
    use crate::serde::FieldType;
    use crate::serde::collection_meta::DistanceMetric;
    use crate::serde::key::{IdDictionaryKey, VectorDataKey};
    use crate::serde::vector_data::VectorDataValue;
    use common::StorageConfig;
    use opendata_macros::storage_test;
    use std::time::Duration;

    /// 3-dimensional in-memory config with two indexed metadata fields.
    fn create_test_config() -> Config {
        Config {
            storage: StorageConfig::InMemory,
            dimensions: 3,
            distance_metric: DistanceMetric::L2,
            flush_interval: Duration::from_secs(60),
            split_threshold_vectors: 10_000,
            merge_threshold_vectors: 200,
            split_search_neighbourhood: 8,
            chunk_target: 4096,
            metadata_fields: vec![
                MetadataFieldSpec::new("category", FieldType::String, true),
                MetadataFieldSpec::new("price", FieldType::Int64, true),
            ],
            ..Default::default()
        }
    }

    /// A single all-ones seed centroid of the requested dimensionality.
    fn create_test_centroids(dimensions: usize) -> Vec<Vec<f32>> {
        vec![vec![1.0; dimensions]]
    }

    #[tokio::test]
    async fn should_open_vector_db() {
        let config = create_test_config();

        let result = VectorDb::open(config).await;

        assert!(result.is_ok());
    }

    #[storage_test(merge_operator = VectorDbMergeOperator::new(3))]
    async fn should_write_and_flush_vectors(storage: Arc<dyn Storage>) {
        let config = create_test_config();
        let centroids = create_test_centroids(3);
        let db = VectorDb::load_or_init_db(Arc::clone(&storage), config, centroids)
            .await
            .unwrap();

        let vectors = vec![
            Vector::builder("vec-1", vec![1.0, 0.0, 0.0])
                .attribute("category", "shoes")
                .attribute("price", 99i64)
                .build(),
            Vector::builder("vec-2", vec![0.0, 1.0, 0.0])
                .attribute("category", "boots")
                .attribute("price", 149i64)
                .build(),
        ];

        db.write(vectors).await.unwrap();
        db.flush().await.unwrap();

        // Internal ID 0 is taken by the seed centroid, so the two user
        // vectors receive internal IDs 1 and 2.
        let vec1_data_key = VectorDataKey::new(1).encode();
        let vec1_data = storage.get(vec1_data_key).await.unwrap();
        assert!(vec1_data.is_some());

        let vec2_data_key = VectorDataKey::new(2).encode();
        let vec2_data = storage.get(vec2_data_key).await.unwrap();
        assert!(vec2_data.is_some());

        // The external-ID dictionary entry must be persisted too.
        let dict_key1 = IdDictionaryKey::new("vec-1").encode();
        let dict_entry1 = storage.get(dict_key1).await.unwrap();
        assert!(dict_entry1.is_some());
    }

    #[storage_test(merge_operator = VectorDbMergeOperator::new(3))]
    async fn should_upsert_existing_vector(storage: Arc<dyn Storage>) {
        let config = create_test_config();
        let centroids = create_test_centroids(3);
        let db = VectorDb::load_or_init_db(Arc::clone(&storage), config, centroids)
            .await
            .unwrap();

        let vector1 = Vector::builder("vec-1", vec![1.0, 0.0, 0.0])
            .attribute("category", "shoes")
            .attribute("price", 99i64)
            .build();
        db.write(vec![vector1]).await.unwrap();
        db.flush().await.unwrap();

        // Writing the same external ID again upserts the vector.
        let vector2 = Vector::builder("vec-1", vec![2.0, 3.0, 4.0])
            .attribute("category", "boots")
            .attribute("price", 199i64)
            .build();
        db.write(vec![vector2]).await.unwrap();
        db.flush().await.unwrap();

        // The upsert is stored under a fresh internal ID (2), holding the
        // new vector values.
        let vec_data_key = VectorDataKey::new(2).encode();
        let vec_data = storage.get(vec_data_key).await.unwrap();
        assert!(vec_data.is_some());
        let decoded = VectorDataValue::decode_from_bytes(&vec_data.unwrap().value, 3).unwrap();
        assert_eq!(decoded.vector_field(), &[2.0, 3.0, 4.0]);

        let dict_key = IdDictionaryKey::new("vec-1").encode();
        let dict_entry = storage.get(dict_key).await.unwrap();
        assert!(dict_entry.is_some());
    }

    #[tokio::test]
    async fn should_reject_vectors_with_wrong_dimensions() {
        let config = create_test_config();
        let db = VectorDb::open(config).await.unwrap();

        // Config expects 3 dimensions; this vector has only 2.
        let vector = Vector::new("vec-1", vec![1.0, 2.0]);
        let result = db.write(vec![vector]).await;

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("dimension mismatch")
        );
    }

    #[tokio::test]
    async fn should_flush_empty_delta_without_error() {
        let config = create_test_config();
        let db = VectorDb::open(config).await.unwrap();

        let result = db.flush().await;

        assert!(result.is_ok());
    }

    #[storage_test(merge_operator = VectorDbMergeOperator::new(3))]
    async fn should_load_dictionary_on_reopen(storage: Arc<dyn Storage>) {
        let config = create_test_config();
        let centroids = create_test_centroids(3);

        {
            let db =
                VectorDb::load_or_init_db(Arc::clone(&storage), config.clone(), centroids.clone())
                    .await
                    .unwrap();
            let vectors = vec![
                Vector::builder("vec-1", vec![1.0, 0.0, 0.0])
                    .attribute("category", "shoes")
                    .attribute("price", 99i64)
                    .build(),
                Vector::builder("vec-2", vec![0.0, 1.0, 0.0])
                    .attribute("category", "boots")
                    .attribute("price", 149i64)
                    .build(),
            ];
            db.write(vectors).await.unwrap();
            db.flush().await.unwrap();
        }

        // Reopen with no seed centroids: existing state must be loaded.
        let db2 = VectorDb::load_or_init_db(Arc::clone(&storage), config, vec![])
            .await
            .unwrap();

        let results = db2
            .search(&Query::new(vec![1.0, 0.0, 0.0]).with_limit(10))
            .await
            .unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn flush_should_be_durable_across_reopen() {
        use common::storage::config::{
            LocalObjectStoreConfig, ObjectStoreConfig, SlateDbStorageConfig,
        };

        let tmp_dir = tempfile::tempdir().unwrap();
        let storage_config = StorageConfig::SlateDb(SlateDbStorageConfig {
            path: "data".to_string(),
            object_store: ObjectStoreConfig::Local(LocalObjectStoreConfig {
                path: tmp_dir.path().to_str().unwrap().to_string(),
            }),
            settings_path: None,
            block_cache: None,
        });

        let config = Config {
            storage: storage_config.clone(),
            dimensions: 3,
            distance_metric: DistanceMetric::L2,
            ..Default::default()
        };

        let db = VectorDb::open(config.clone()).await.unwrap();
        db.write(vec![
            Vector::new("vec-1", vec![1.0, 0.0, 0.0]),
            Vector::new("vec-2", vec![0.0, 1.0, 0.0]),
        ])
        .await
        .unwrap();
        db.flush().await.unwrap();
        drop(db);

        let db2 = VectorDb::open(config).await.unwrap();
        let results = db2
            .search(&Query::new(vec![1.0, 0.0, 0.0]).with_limit(10))
            .await
            .unwrap();
        assert!(
            !results.is_empty(),
            "expected data to be durable after flush, but search returned no results"
        );
    }

    #[tokio::test]
    #[allow(clippy::needless_return)]
    async fn close_without_explicit_flush_guarantees_durability() {
        use common::storage::config::{
            LocalObjectStoreConfig, ObjectStoreConfig, SlateDbStorageConfig,
        };

        let tmp_dir = tempfile::tempdir().unwrap();
        let storage_config = StorageConfig::SlateDb(SlateDbStorageConfig {
            path: "data".to_string(),
            object_store: ObjectStoreConfig::Local(LocalObjectStoreConfig {
                path: tmp_dir.path().to_str().unwrap().to_string(),
            }),
            settings_path: None,
            block_cache: None,
        });

        let config = Config {
            storage: storage_config.clone(),
            dimensions: 3,
            distance_metric: DistanceMetric::L2,
            ..Default::default()
        };

        {
            let db = VectorDb::open(config.clone()).await.unwrap();
            db.write(vec![Vector::new("vec-1", vec![1.0, 0.0, 0.0])])
                .await
                .unwrap();
            // close() must flush internally before shutting down.
            db.close().await.unwrap();
        }

        let db2 = VectorDb::open(config).await.unwrap();
        let results = db2
            .search(&Query::new(vec![1.0, 0.0, 0.0]).with_limit(1))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].vector.id, "vec-1");
    }

    #[tokio::test]
    async fn should_fail_if_no_centroids_provided_for_new_db() {
        let config = create_test_config();

        let sb = StorageBuilder::new(&config.storage).await.unwrap();
        let result = VectorDb::open_with_centroids(config, vec![], sb).await;

        match result {
            Err(e) => assert!(
                e.to_string().contains("Centroids must be provided"),
                "unexpected error: {}",
                e
            ),
            Ok(_) => panic!("expected error when no centroids provided"),
        }
    }
}