use hnsw_rs::prelude::*;
use ipfrs_core::{Cid, Error, Result};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

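/// Distance metric used when scoring search results.
///
/// The underlying `hnsw_rs` graph is always built with `DistL2`; the other
/// metrics are layered on top of it (see `normalize_vector` and
/// `convert_distance` below):
///
/// * `L2` - scores are raw Euclidean distances (lower is closer).
/// * `Cosine` - vectors are L2-normalized before indexing, and the score is
///   recovered as `1 - d^2 / 2`, i.e. the cosine similarity of the unit
///   vectors (higher is closer).
/// * `DotProduct` - vectors are indexed unmodified and the score is the
///   negated L2 distance (higher is closer).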
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum DistanceMetric {
    L2,
    Cosine,
    DotProduct,
}

#[derive(Debug, Clone)]
pub struct SearchResult {
    pub cid: Cid,
    pub score: f32,
}

#[derive(Debug, Clone)]
pub struct IncrementalBuildStats {
    pub initial_size: usize,
    pub final_size: usize,
    pub vectors_inserted: usize,
    pub vectors_failed: usize,
    pub chunks_processed: usize,
    pub should_rebuild: bool,
}

#[derive(Debug, Clone)]
pub struct RebuildStats {
    pub vectors_reinserted: usize,
    pub old_parameters: (usize, usize),
    pub new_parameters: (usize, usize),
}

#[derive(Debug, Clone)]
pub struct BuildHealthStats {
    pub index_size: usize,
    pub current_m: usize,
    pub current_ef_construction: usize,
    pub optimal_m: usize,
    pub optimal_ef_construction: usize,
    pub parameter_efficiency: f32,
    pub rebuild_recommended: bool,
}

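/// An in-memory HNSW vector index keyed by content identifiers (CIDs).
///
/// Embeddings are kept both in the `hnsw_rs` graph (under integer ids) and in
/// a side table of raw vectors, with bidirectional CID <-> id maps. Every
/// field sits behind `Arc<RwLock<...>>`, so a populated index can be shared
/// across threads for concurrent searches.
///
/// Minimal usage sketch (marked `ignore`: the import path and error handling
/// depend on the enclosing crate, and the CID below is just the sample value
/// used in this file's tests):
///
/// ```ignore
/// let mut index = VectorIndex::with_defaults(4)?;
/// let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi".parse()?;
/// index.insert(&cid, &[1.0, 0.0, 0.0, 0.0])?;
/// let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 1, 50)?;
/// assert_eq!(hits[0].cid, cid);
/// ```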
pub struct VectorIndex {
    index: Arc<RwLock<Hnsw<'static, f32, DistL2>>>,
    id_to_cid: Arc<RwLock<HashMap<usize, Cid>>>,
    cid_to_id: Arc<RwLock<HashMap<Cid, usize>>>,
    vectors: Arc<RwLock<HashMap<Cid, Vec<f32>>>>,
    next_id: Arc<RwLock<usize>>,
    dimension: usize,
    metric: DistanceMetric,
}

impl VectorIndex {
    pub fn new(
        dimension: usize,
        metric: DistanceMetric,
        max_nb_connection: usize,
        ef_construction: usize,
    ) -> Result<Self> {
        if dimension == 0 {
            return Err(Error::InvalidInput(
                "Vector dimension must be greater than 0".to_string(),
            ));
        }

        let index = Hnsw::<f32, DistL2>::new(
            max_nb_connection,
            dimension,
            ef_construction,
            200,
            DistL2 {},
        );

        Ok(Self {
            index: Arc::new(RwLock::new(index)),
            id_to_cid: Arc::new(RwLock::new(HashMap::new())),
            cid_to_id: Arc::new(RwLock::new(HashMap::new())),
            vectors: Arc::new(RwLock::new(HashMap::new())),
            next_id: Arc::new(RwLock::new(0)),
            dimension,
            metric,
        })
    }

    pub fn with_defaults(dimension: usize) -> Result<Self> {
        Self::new(dimension, DistanceMetric::L2, 16, 200)
    }

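    /// Inserts `vector` under `cid`.
    ///
    /// Returns an error if the vector's length does not match the index
    /// dimension or if the CID is already present. For the `Cosine` metric the
    /// vector is L2-normalized before it enters the HNSW graph; the raw vector
    /// is always kept in the side table used by `get_embedding`.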
    pub fn insert(&mut self, cid: &Cid, vector: &[f32]) -> Result<()> {
        if vector.len() != self.dimension {
            return Err(Error::InvalidInput(format!(
                "Vector dimension mismatch: expected {}, got {}",
                self.dimension,
                vector.len()
            )));
        }

        if self.cid_to_id.read().unwrap().contains_key(cid) {
            return Err(Error::InvalidInput(format!(
                "CID already exists in index: {}",
                cid
            )));
        }

        let mut next_id = self.next_id.write().unwrap();
        let id = *next_id;
        *next_id += 1;
        drop(next_id);

        let normalized = self.normalize_vector(vector);

        let data_with_id = (normalized.as_slice(), id);
        self.index.write().unwrap().insert(data_with_id);

        self.vectors.write().unwrap().insert(*cid, vector.to_vec());

        self.id_to_cid.write().unwrap().insert(id, *cid);
        self.cid_to_id.write().unwrap().insert(*cid, id);

        Ok(())
    }

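    /// Returns the `k` approximate nearest neighbours of `query`.
    ///
    /// `ef_search` is the HNSW candidate-list size: larger values trade speed
    /// for accuracy (see `compute_optimal_ef_search`). Score semantics depend
    /// on the metric, as described on `DistanceMetric`.
    ///
    /// Sketch, continuing the example on `VectorIndex` (marked `ignore` for
    /// the same reasons):
    ///
    /// ```ignore
    /// let k = 10;
    /// let ef = index.compute_optimal_ef_search(k);
    /// let results = index.search(&query, k, ef)?;
    /// for r in &results {
    ///     println!("{} -> {}", r.cid, r.score);
    /// }
    /// ```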
    pub fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Result<Vec<SearchResult>> {
        if query.len() != self.dimension {
            return Err(Error::InvalidInput(format!(
                "Query dimension mismatch: expected {}, got {}",
                self.dimension,
                query.len()
            )));
        }

        if k == 0 {
            return Ok(Vec::new());
        }

        let normalized = self.normalize_vector(query);

        let neighbors = self.index.read().unwrap().search(&normalized, k, ef_search);

        let id_to_cid = self.id_to_cid.read().unwrap();
        let results: Vec<SearchResult> = neighbors
            .iter()
            .filter_map(|neighbor| {
                id_to_cid.get(&neighbor.d_id).map(|cid| SearchResult {
                    cid: *cid,
                    score: self.convert_distance(neighbor.distance),
                })
            })
            .collect();

        Ok(results)
    }

    pub fn delete(&mut self, cid: &Cid) -> Result<()> {
        let id = self
            .cid_to_id
            .read()
            .unwrap()
            .get(cid)
            .copied()
            .ok_or_else(|| Error::NotFound(format!("CID not found in index: {}", cid)))?;

        self.vectors.write().unwrap().remove(cid);

        self.cid_to_id.write().unwrap().remove(cid);
        self.id_to_cid.write().unwrap().remove(&id);

        Ok(())
    }

    pub fn contains(&self, cid: &Cid) -> bool {
        self.cid_to_id.read().unwrap().contains_key(cid)
    }

    pub fn len(&self) -> usize {
        self.cid_to_id.read().unwrap().len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub fn dimension(&self) -> usize {
        self.dimension
    }

    pub fn metric(&self) -> DistanceMetric {
        self.metric
    }

    pub fn get_all_cids(&self) -> Vec<Cid> {
        self.cid_to_id.read().unwrap().keys().copied().collect()
    }

    pub fn get_embedding(&self, cid: &Cid) -> Option<Vec<f32>> {
        self.vectors.read().unwrap().get(cid).cloned()
    }

    pub fn get_all_embeddings(&self) -> Vec<(Cid, Vec<f32>)> {
        self.vectors
            .read()
            .unwrap()
            .iter()
            .map(|(cid, vec)| (*cid, vec.clone()))
            .collect()
    }

    pub fn iter(&self) -> Vec<(Cid, Vec<f32>)> {
        self.get_all_embeddings()
    }

    fn normalize_vector(&self, vector: &[f32]) -> Vec<f32> {
        match self.metric {
            DistanceMetric::L2 => vector.to_vec(),
            DistanceMetric::Cosine => {
                let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    vector.iter().map(|x| x / norm).collect()
                } else {
                    vector.to_vec()
                }
            }
            DistanceMetric::DotProduct => vector.to_vec(),
        }
    }

    fn convert_distance(&self, distance: f32) -> f32 {
        match self.metric {
            DistanceMetric::L2 => distance,
            DistanceMetric::Cosine => 1.0 - (distance * distance / 2.0),
            DistanceMetric::DotProduct => -distance,
        }
    }

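    /// Suggests `(max_nb_connection, ef_construction)` for the current index
    /// size: `(16, 200)` below 10k vectors, `(32, 400)` below 100k, and
    /// `(48, 600)` beyond that.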
    pub fn compute_optimal_parameters(&self) -> (usize, usize) {
        let size = self.len();

        if size < 10_000 {
            (16, 200)
        } else if size < 100_000 {
            (32, 400)
        } else {
            (48, 600)
        }
    }

    pub fn compute_optimal_ef_search(&self, k: usize) -> usize {
        if k <= 50 {
            50.max(k)
        } else {
            2 * k
        }
    }

    pub fn get_parameter_recommendations(&self, use_case: UseCase) -> ParameterRecommendation {
        let size = self.len();
        ParameterTuner::recommend(size, self.dimension, use_case)
    }

    pub fn insert_batch(&mut self, items: &[(Cid, Vec<f32>)]) -> Result<()> {
        for (cid, vector) in items {
            self.insert(cid, vector)?;
        }
        Ok(())
    }

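    /// Inserts `items` in chunks of `chunk_size`, tolerating per-item failures.
    ///
    /// Unlike `insert_batch`, a failed insert (dimension mismatch, duplicate
    /// CID) is counted in `vectors_failed` instead of aborting the build. The
    /// returned stats also report whether a rebuild looks worthwhile.
    ///
    /// Sketch (marked `ignore`; `embeddings` stands for any
    /// `Vec<(Cid, Vec<f32>)>` you have on hand):
    ///
    /// ```ignore
    /// let stats = index.insert_incremental(&embeddings, 1_000)?;
    /// if stats.should_rebuild {
    ///     index.rebuild(UseCase::Balanced)?;
    /// }
    /// ```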
    pub fn insert_incremental(
        &mut self,
        items: &[(Cid, Vec<f32>)],
        chunk_size: usize,
    ) -> Result<IncrementalBuildStats> {
        let start_size = self.len();
        let mut chunks_processed = 0;
        let mut failed_inserts = 0;

        for chunk in items.chunks(chunk_size) {
            for (cid, vector) in chunk {
                if let Err(_e) = self.insert(cid, vector) {
                    failed_inserts += 1;
                }
            }
            chunks_processed += 1;
        }

        let end_size = self.len();
        let inserted = end_size - start_size;

        let should_rebuild = self.should_rebuild();

        Ok(IncrementalBuildStats {
            initial_size: start_size,
            final_size: end_size,
            vectors_inserted: inserted,
            vectors_failed: failed_inserts,
            chunks_processed,
            should_rebuild,
        })
    }

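    /// Heuristic check for whether the index has outgrown its build
    /// parameters: true when the current `M`/`ef_construction` fall below half
    /// of the values `compute_optimal_parameters` recommends, or when the
    /// index exceeds 100k vectors while `M < 32`.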
    pub fn should_rebuild(&self) -> bool {
        let size = self.len();
        let (current_m, current_ef) = {
            let idx = self.index.read().unwrap();
            (
                idx.get_max_nb_connection() as usize,
                idx.get_ef_construction(),
            )
        };

        let (optimal_m, optimal_ef) = self.compute_optimal_parameters();

        if current_m < optimal_m / 2 || current_ef < optimal_ef / 2 {
            return true;
        }

        if size > 100_000 && current_m < 32 {
            return true;
        }

        false
    }

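    /// Rebuilds the HNSW graph with parameters recommended for the current
    /// size and `use_case`, re-inserting every stored vector into the new
    /// graph. CID/id mappings and stored embeddings are preserved.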
    pub fn rebuild(&mut self, use_case: UseCase) -> Result<RebuildStats> {
        let start_size = self.len();

        if start_size == 0 {
            return Ok(RebuildStats {
                vectors_reinserted: 0,
                old_parameters: (0, 0),
                new_parameters: (0, 0),
            });
        }

        let old_params = {
            let idx = self.index.read().unwrap();
            (
                idx.get_max_nb_connection() as usize,
                idx.get_ef_construction(),
            )
        };

        let recommendation = ParameterTuner::recommend(start_size, self.dimension, use_case);

        let new_index = Hnsw::<f32, DistL2>::new(
            recommendation.m,
            self.dimension,
            recommendation.ef_construction,
            start_size,
            DistL2 {},
        );

        *self.index.write().unwrap() = new_index;

        // Re-insert every stored vector into the freshly parameterized graph,
        // keeping the existing CID <-> id assignments.
        let mut reinserted = 0;
        {
            let vectors = self.vectors.read().unwrap();
            let cid_to_id = self.cid_to_id.read().unwrap();
            for (cid, vector) in vectors.iter() {
                if let Some(&id) = cid_to_id.get(cid) {
                    let normalized = self.normalize_vector(vector);
                    self.index.write().unwrap().insert((normalized.as_slice(), id));
                    reinserted += 1;
                }
            }
        }

        Ok(RebuildStats {
            vectors_reinserted: reinserted,
            old_parameters: old_params,
            new_parameters: (recommendation.m, recommendation.ef_construction),
        })
    }

    pub fn get_build_stats(&self) -> BuildHealthStats {
        let size = self.len();
        let (current_m, current_ef) = {
            let idx = self.index.read().unwrap();
            (
                idx.get_max_nb_connection() as usize,
                idx.get_ef_construction(),
            )
        };

        let (optimal_m, optimal_ef) = self.compute_optimal_parameters();

        let parameter_efficiency = if optimal_m > 0 {
            (current_m as f32 / optimal_m as f32).min(1.0)
        } else {
            1.0
        };

        BuildHealthStats {
            index_size: size,
            current_m,
            current_ef_construction: current_ef,
            optimal_m,
            optimal_ef_construction: optimal_ef,
            parameter_efficiency,
            rebuild_recommended: self.should_rebuild(),
        }
    }

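    /// Persists the index metadata (dimension, metric, CID/id maps, raw
    /// vectors, build parameters) to `path` using `oxicode`. The HNSW graph
    /// itself is not serialized; `load` reconstructs it by re-inserting the
    /// persisted vectors.
    ///
    /// Sketch (marked `ignore`; the path is just an example):
    ///
    /// ```ignore
    /// index.save("/tmp/vectors.idx")?;
    /// let restored = VectorIndex::load("/tmp/vectors.idx")?;
    /// assert_eq!(restored.len(), index.len());
    /// ```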
    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<()> {
        use std::fs::File;
        use std::io::Write;

        let (max_nb_connection, ef_construction) = {
            let idx = self.index.read().unwrap();
            (idx.get_max_nb_connection(), idx.get_ef_construction())
        };

        let metadata = IndexMetadata {
            dimension: self.dimension,
            metric: self.metric,
            id_to_cid: self.id_to_cid.read().unwrap().clone(),
            cid_to_id: self.cid_to_id.read().unwrap().clone(),
            vectors: self.vectors.read().unwrap().clone(),
            next_id: *self.next_id.read().unwrap(),
            max_nb_connection: max_nb_connection as usize,
            ef_construction,
        };

        let encoded = oxicode::serde::encode_to_vec(&metadata, oxicode::config::standard())
            .map_err(|e| Error::Serialization(format!("Failed to serialize index: {}", e)))?;

        let mut file = File::create(path.as_ref())
            .map_err(|e| Error::Storage(format!("Failed to create index file: {}", e)))?;

        file.write_all(&encoded)
            .map_err(|e| Error::Storage(format!("Failed to write index file: {}", e)))?;

        Ok(())
    }

    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
        use std::fs::File;
        use std::io::Read;

        let mut file = File::open(path.as_ref())
            .map_err(|e| Error::Storage(format!("Failed to open index file: {}", e)))?;

        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)
            .map_err(|e| Error::Storage(format!("Failed to read index file: {}", e)))?;

        let metadata: IndexMetadata =
            oxicode::serde::decode_owned_from_slice(&buffer, oxicode::config::standard())
                .map(|(v, _)| v)
                .map_err(|e| {
                    Error::Deserialization(format!("Failed to deserialize index: {}", e))
                })?;

        let index = Hnsw::<f32, DistL2>::new(
            metadata.max_nb_connection,
            metadata.dimension,
            metadata.ef_construction,
            200,
            DistL2 {},
        );

        let loaded = Self {
            index: Arc::new(RwLock::new(index)),
            id_to_cid: Arc::new(RwLock::new(metadata.id_to_cid)),
            cid_to_id: Arc::new(RwLock::new(metadata.cid_to_id)),
            vectors: Arc::new(RwLock::new(metadata.vectors)),
            next_id: Arc::new(RwLock::new(metadata.next_id)),
            dimension: metadata.dimension,
            metric: metadata.metric,
        };

        // The HNSW graph is not serialized, so rebuild it from the persisted
        // vectors under their original ids.
        {
            let vectors = loaded.vectors.read().unwrap();
            let cid_to_id = loaded.cid_to_id.read().unwrap();
            for (cid, vector) in vectors.iter() {
                if let Some(&id) = cid_to_id.get(cid) {
                    let normalized = loaded.normalize_vector(vector);
                    loaded.index.write().unwrap().insert((normalized.as_slice(), id));
                }
            }
        }

        Ok(loaded)
    }
}

#[derive(serde::Serialize, serde::Deserialize)]
struct IndexMetadata {
    dimension: usize,
    metric: DistanceMetric,
    #[serde(
        serialize_with = "serialize_id_to_cid",
        deserialize_with = "deserialize_id_to_cid"
    )]
    id_to_cid: HashMap<usize, Cid>,
    #[serde(
        serialize_with = "serialize_cid_to_id",
        deserialize_with = "deserialize_cid_to_id"
    )]
    cid_to_id: HashMap<Cid, usize>,
    #[serde(
        serialize_with = "serialize_vectors",
        deserialize_with = "deserialize_vectors"
    )]
    vectors: HashMap<Cid, Vec<f32>>,
    next_id: usize,
    max_nb_connection: usize,
    ef_construction: usize,
}

fn serialize_id_to_cid<S>(
    map: &HashMap<usize, Cid>,
    serializer: S,
) -> std::result::Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::Serialize;
    let string_map: HashMap<usize, String> =
        map.iter().map(|(id, cid)| (*id, cid.to_string())).collect();
    string_map.serialize(serializer)
}

fn deserialize_id_to_cid<'de, D>(
    deserializer: D,
) -> std::result::Result<HashMap<usize, Cid>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;
    let string_map: HashMap<usize, String> = HashMap::deserialize(deserializer)?;
    string_map
        .into_iter()
        .map(|(id, cid_str)| {
            cid_str
                .parse::<Cid>()
                .map(|cid| (id, cid))
                .map_err(serde::de::Error::custom)
        })
        .collect()
}

fn serialize_cid_to_id<S>(
    map: &HashMap<Cid, usize>,
    serializer: S,
) -> std::result::Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::Serialize;
    let string_map: HashMap<String, usize> =
        map.iter().map(|(cid, id)| (cid.to_string(), *id)).collect();
    string_map.serialize(serializer)
}

fn deserialize_cid_to_id<'de, D>(
    deserializer: D,
) -> std::result::Result<HashMap<Cid, usize>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;
    let string_map: HashMap<String, usize> = HashMap::deserialize(deserializer)?;
    string_map
        .into_iter()
        .map(|(cid_str, id)| {
            cid_str
                .parse::<Cid>()
                .map(|cid| (cid, id))
                .map_err(serde::de::Error::custom)
        })
        .collect()
}

fn serialize_vectors<S>(
    map: &HashMap<Cid, Vec<f32>>,
    serializer: S,
) -> std::result::Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::Serialize;
    let string_map: HashMap<String, Vec<f32>> = map
        .iter()
        .map(|(cid, vec)| (cid.to_string(), vec.clone()))
        .collect();
    string_map.serialize(serializer)
}

fn deserialize_vectors<'de, D>(
    deserializer: D,
) -> std::result::Result<HashMap<Cid, Vec<f32>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;
    let string_map: HashMap<String, Vec<f32>> = HashMap::deserialize(deserializer)?;
    string_map
        .into_iter()
        .map(|(cid_str, vec)| {
            cid_str
                .parse::<Cid>()
                .map(|cid| (cid, vec))
                .map_err(serde::de::Error::custom)
        })
        .collect()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Default)]
pub enum UseCase {
    LowLatency,
    HighRecall,
    #[default]
    Balanced,
    LowMemory,
    LargeScale,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ParameterRecommendation {
    pub m: usize,
    pub ef_construction: usize,
    pub ef_search: usize,
    pub memory_per_vector: usize,
    pub estimated_recall: f32,
    pub latency_factor: f32,
    pub explanation: String,
}

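/// Rule-of-thumb HNSW parameter tuning for common dataset sizes and use cases.
///
/// Sketch (marked `ignore`; the numbers are just an illustrative workload):
///
/// ```ignore
/// let rec = ParameterTuner::recommend(500_000, 768, UseCase::HighRecall);
/// println!("M={}, ef_construction={}, ef_search={}", rec.m, rec.ef_construction, rec.ef_search);
///
/// let bytes = ParameterTuner::estimate_memory(500_000, 768, rec.m);
/// println!("estimated memory: {} MiB", bytes / (1024 * 1024));
/// ```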
pub struct ParameterTuner;

impl ParameterTuner {
    pub fn recommend(
        num_vectors: usize,
        dimension: usize,
        use_case: UseCase,
    ) -> ParameterRecommendation {
        let (m, ef_construction, ef_search, recall, latency) = match use_case {
            UseCase::LowLatency => {
                if num_vectors < 10_000 {
                    (8, 100, 32, 0.90, 0.6)
                } else if num_vectors < 100_000 {
                    (12, 150, 50, 0.88, 0.7)
                } else {
                    (16, 200, 64, 0.85, 0.8)
                }
            }
            UseCase::HighRecall => {
                if num_vectors < 10_000 {
                    (32, 400, 200, 0.99, 2.0)
                } else if num_vectors < 100_000 {
                    (48, 500, 300, 0.98, 2.5)
                } else {
                    (64, 600, 400, 0.97, 3.0)
                }
            }
            UseCase::Balanced => {
                if num_vectors < 10_000 {
                    (16, 200, 50, 0.95, 1.0)
                } else if num_vectors < 100_000 {
                    (24, 300, 100, 0.94, 1.2)
                } else {
                    (32, 400, 150, 0.93, 1.5)
                }
            }
            UseCase::LowMemory => {
                if num_vectors < 10_000 {
                    (8, 100, 50, 0.88, 0.9)
                } else if num_vectors < 100_000 {
                    (10, 120, 64, 0.85, 1.0)
                } else {
                    (12, 150, 80, 0.82, 1.1)
                }
            }
            UseCase::LargeScale => (32, 400, 100, 0.93, 1.5),
        };

        let memory_per_vector = dimension * 4 + m * 2 * 4;

        let explanation =
            Self::generate_explanation(num_vectors, use_case, m, ef_construction, ef_search);

        ParameterRecommendation {
            m,
            ef_construction,
            ef_search,
            memory_per_vector,
            estimated_recall: recall,
            latency_factor: latency,
            explanation,
        }
    }

    fn generate_explanation(
        num_vectors: usize,
        use_case: UseCase,
        m: usize,
        ef_construction: usize,
        ef_search: usize,
    ) -> String {
        let size_category = if num_vectors < 10_000 {
            "small"
        } else if num_vectors < 100_000 {
            "medium"
        } else {
            "large"
        };

        let use_case_str = match use_case {
            UseCase::LowLatency => "low latency",
            UseCase::HighRecall => "high recall",
            UseCase::Balanced => "balanced",
            UseCase::LowMemory => "low memory",
            UseCase::LargeScale => "large scale",
        };

        format!(
            "For a {} dataset (~{} vectors) optimized for {}: \
             M={} provides good connectivity, ef_construction={} ensures a quality graph, \
             ef_search={} balances speed and accuracy.",
            size_category, num_vectors, use_case_str, m, ef_construction, ef_search
        )
    }

    pub fn pareto_configurations(
        num_vectors: usize,
        dimension: usize,
    ) -> Vec<ParameterRecommendation> {
        vec![
            Self::recommend(num_vectors, dimension, UseCase::LowLatency),
            Self::recommend(num_vectors, dimension, UseCase::LowMemory),
            Self::recommend(num_vectors, dimension, UseCase::Balanced),
            Self::recommend(num_vectors, dimension, UseCase::HighRecall),
        ]
    }

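    /// Rough memory estimate in bytes: `n * d * 4` for f32 vectors,
    /// `n * m * 2 * 4` for graph links, plus ~50 bytes of bookkeeping per
    /// vector.
    ///
    /// Worked example: 100_000 vectors of dimension 768 with `m = 16` gives
    /// roughly 307.2 MB of vectors + 12.8 MB of links + 5 MB of overhead,
    /// about 325 MB in total.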
    pub fn estimate_memory(num_vectors: usize, dimension: usize, m: usize) -> usize {
        let vector_memory = num_vectors * dimension * 4;

        let graph_memory = num_vectors * m * 2 * 4;

        let overhead = num_vectors * 50;

        vector_memory + graph_memory + overhead
    }

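    /// Picks an `ef_search` for a target recall by scaling `k`: 10x for
    /// >= 0.99, 4x for >= 0.95, 2x for >= 0.90, otherwise 1.5x. For example,
    /// `ef_search_for_recall(10, 0.95)` returns 40.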
    pub fn ef_search_for_recall(k: usize, target_recall: f32) -> usize {
        let multiplier = if target_recall >= 0.99 {
            10.0
        } else if target_recall >= 0.95 {
            4.0
        } else if target_recall >= 0.90 {
            2.0
        } else {
            1.5
        };

        ((k as f32) * multiplier).ceil() as usize
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    #[test]
    fn test_vector_index_creation() {
        let index = VectorIndex::with_defaults(128);
        assert!(index.is_ok());
        let index = index.unwrap();
        assert_eq!(index.dimension(), 128);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn test_insert_and_search() {
        let mut index = VectorIndex::with_defaults(4).unwrap();

        let cid1 = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse::<Cid>()
            .unwrap();
        let vec1 = vec![1.0, 0.0, 0.0, 0.0];

        let cid2 = "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354"
            .parse::<Cid>()
            .unwrap();
        let vec2 = vec![0.9, 0.1, 0.0, 0.0];

        index.insert(&cid1, &vec1).unwrap();
        index.insert(&cid2, &vec2).unwrap();

        assert_eq!(index.len(), 2);

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 1, 50).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].cid, cid1);
    }

    #[test]
    fn test_parameter_tuner() {
        let balanced = ParameterTuner::recommend(50_000, 768, UseCase::Balanced);
        assert!(balanced.m > 0);
        assert!(balanced.ef_construction > 0);
        assert!(balanced.estimated_recall > 0.0);

        let low_latency = ParameterTuner::recommend(50_000, 768, UseCase::LowLatency);
        let high_recall = ParameterTuner::recommend(50_000, 768, UseCase::HighRecall);

        assert!(high_recall.m > low_latency.m);
        assert!(high_recall.estimated_recall > low_latency.estimated_recall);

        let pareto = ParameterTuner::pareto_configurations(50_000, 768);
        assert_eq!(pareto.len(), 4);

        let memory = ParameterTuner::estimate_memory(100_000, 768, 16);
        assert!(memory > 0);

        let ef_high = ParameterTuner::ef_search_for_recall(10, 0.99);
        let ef_low = ParameterTuner::ef_search_for_recall(10, 0.85);
        assert!(ef_high > ef_low);
    }

    #[test]
    fn test_incremental_build() {
        let mut index = VectorIndex::with_defaults(4).unwrap();

        let items: Vec<(Cid, Vec<f32>)> = (0..20)
            .map(|i| {
                let cid_str = format!(
                    "bafybei{}yrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
                    i
                );
                let cid = cid_str.parse::<Cid>().unwrap_or_else(|_| {
                    "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
                        .parse()
                        .unwrap()
                });
                let vec = vec![i as f32, 0.0, 0.0, 0.0];
                (cid, vec)
            })
            .collect();

        let stats = index.insert_incremental(&items, 5).unwrap();

        assert_eq!(stats.chunks_processed, 4);
        assert!(stats.vectors_inserted <= 20);
        assert_eq!(stats.final_size, index.len());
    }

    #[test]
    fn test_build_health_stats() {
        let index = VectorIndex::new(128, DistanceMetric::L2, 16, 200).unwrap();

        let stats = index.get_build_stats();
        assert_eq!(stats.index_size, 0);
        assert_eq!(stats.current_m, 16);
        assert_eq!(stats.current_ef_construction, 200);
        assert!(stats.parameter_efficiency > 0.0);

        assert!(!stats.rebuild_recommended);
    }

    #[test]
    fn test_should_rebuild() {
        let index1 = VectorIndex::new(128, DistanceMetric::L2, 16, 200).unwrap();
        assert!(!index1.should_rebuild());

        let index2 = VectorIndex::new(128, DistanceMetric::L2, 4, 50).unwrap();
        let _ = index2.should_rebuild();
    }

    #[test]
    fn test_rebuild() {
        let mut index = VectorIndex::with_defaults(4).unwrap();

        for i in 0..10 {
            let cid_str = format!(
                "bafybei{}yrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
                i
            );
            let cid = cid_str.parse::<Cid>().unwrap_or_else(|_| {
                "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
                    .parse()
                    .unwrap()
            });
            let vec = vec![i as f32, 0.0, 0.0, 0.0];
            let _ = index.insert(&cid, &vec);
        }

        let rebuild_stats = index.rebuild(UseCase::Balanced).unwrap();

        assert_eq!(rebuild_stats.old_parameters.0, 16);
        assert!(rebuild_stats.new_parameters.0 > 0);
    }

    fn compute_ground_truth(query: &[f32], vectors: &[(Cid, Vec<f32>)], k: usize) -> Vec<Cid> {
        let mut distances: Vec<(Cid, f32)> = vectors
            .iter()
            .map(|(cid, vec)| {
                let dist: f32 = query
                    .iter()
                    .zip(vec.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum();
                (*cid, dist.sqrt())
            })
            .collect();

        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        distances.iter().take(k).map(|(cid, _)| *cid).collect()
    }

    fn calculate_recall_at_k(predicted: &[Cid], ground_truth: &[Cid], k: usize) -> f32 {
        let predicted_set: std::collections::HashSet<_> = predicted.iter().take(k).collect();
        let ground_truth_set: std::collections::HashSet<_> = ground_truth.iter().take(k).collect();

        let intersection = predicted_set.intersection(&ground_truth_set).count();
        intersection as f32 / k as f32
    }

    fn generate_test_cid(index: usize) -> Cid {
        use multihash_codetable::{Code, MultihashDigest};
        let data = format!("test_vector_{}", index);
        let hash = Code::Sha2_256.digest(data.as_bytes());
        Cid::new_v1(0x55, hash)
    }

    #[test]
    fn test_recall_at_k() {
        let mut index = VectorIndex::with_defaults(128).unwrap();

        let mut rng = rand::rng();
        let num_vectors = 100;
        let dimension = 128;

        let mut vectors = Vec::new();
        for i in 0..num_vectors {
            let cid = generate_test_cid(i);

            let vec: Vec<f32> = (0..dimension)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect();

            vectors.push((cid, vec.clone()));
            let _ = index.insert(&cid, &vec);
        }

        let num_queries = 10;
        let mut total_recall_at_1 = 0.0;
        let mut total_recall_at_10 = 0.0;

        for _ in 0..num_queries {
            let query: Vec<f32> = (0..dimension)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect();

            let hnsw_results = index.search(&query, 10, 50).unwrap();
            let hnsw_cids: Vec<Cid> = hnsw_results.iter().map(|r| r.cid).collect();

            let ground_truth = compute_ground_truth(&query, &vectors, 10);

            total_recall_at_1 += calculate_recall_at_k(&hnsw_cids, &ground_truth, 1);
            total_recall_at_10 += calculate_recall_at_k(&hnsw_cids, &ground_truth, 10);
        }

        let avg_recall_at_1 = total_recall_at_1 / num_queries as f32;
        let avg_recall_at_10 = total_recall_at_10 / num_queries as f32;

        assert!(
            avg_recall_at_10 > 0.8,
            "Recall@10 too low: {}",
            avg_recall_at_10
        );

        assert!(
            avg_recall_at_1 > 0.5,
            "Recall@1 too low: {}",
            avg_recall_at_1
        );
    }

    #[test]
    fn test_concurrent_queries() {
        use std::sync::Arc;
        use std::thread;

        let mut index = VectorIndex::with_defaults(128).unwrap();

        let mut rng = rand::rng();
        for i in 0..100 {
            let cid = generate_test_cid(i + 1000);
            let vec: Vec<f32> = (0..128).map(|_| rng.random_range(-1.0..1.0)).collect();

            let _ = index.insert(&cid, &vec);
        }

        let index = Arc::new(index);
        let num_threads = 10;
        let queries_per_thread = 100;

        let mut handles = vec![];
        for _ in 0..num_threads {
            let index_clone = Arc::clone(&index);
            let handle = thread::spawn(move || {
                let mut thread_rng = rand::rng();
                let mut success_count = 0;

                for _ in 0..queries_per_thread {
                    let query: Vec<f32> = (0..128)
                        .map(|_| thread_rng.random_range(-1.0..1.0))
                        .collect();

                    if let Ok(results) = index_clone.search(&query, 10, 50) {
                        if !results.is_empty() {
                            success_count += 1;
                        }
                    }
                }
                success_count
            });
            handles.push(handle);
        }

        let mut total_success = 0;
        for handle in handles {
            total_success += handle.join().unwrap();
        }

        let total_queries = num_threads * queries_per_thread;
        assert_eq!(
            total_success, total_queries,
            "Some queries failed under concurrent load"
        );
    }

    #[test]
    fn test_precision_at_k() {
        let mut index = VectorIndex::with_defaults(32).unwrap();

        let num_clusters = 5;
        let vectors_per_cluster = 10;

        for cluster in 0..num_clusters {
            let mut center = [0.0; 32];
            center[cluster] = 10.0;

            for i in 0..vectors_per_cluster {
                let idx = cluster * vectors_per_cluster + i;
                let cid = generate_test_cid(idx + 2000);

                let mut rng = rand::rng();
                let vec: Vec<f32> = center
                    .iter()
                    .map(|&c| c + rng.random_range(-0.5..0.5))
                    .collect();

                let _ = index.insert(&cid, &vec);
            }
        }

        let mut query = vec![0.0; 32];
        query[0] = 10.0;

        let results = index.search(&query, 10, 50).unwrap();

        assert_eq!(results.len(), 10, "Should return 10 results");

        for result in &results {
            assert!(
                result.score < 5.0,
                "Result too far from query: {}",
                result.score
            );
        }
    }
}