1use super::binary_quantize::{BinaryIndex, BinaryVector};
55use super::int8_quantize::Int8Index;
56use std::cmp::Ordering;
57
/// Tuning knobs for the tiered (binary → int8 → optional fp32) search.
///
/// The binary stage fetches `k * rescore_multiplier` candidates. That count is
/// raised to at least `min_binary_candidates`, then capped at
/// `max_binary_candidates` and at the index size. The candidates are re-ranked
/// with int8 distances and, when `use_fp32_final` is set and the index keeps
/// fp32 vectors, with exact cosine distance.
#[derive(Debug, Clone)]
pub struct TieredSearchConfig {
    /// Multiplier applied to `k` to size the binary candidate pool.
    pub rescore_multiplier: usize,
    /// Re-rank the int8 survivors with full-precision cosine distance.
    pub use_fp32_final: bool,
    /// Lower bound on the binary candidate pool.
    pub min_binary_candidates: usize,
    /// Upper bound on the binary candidate pool.
    pub max_binary_candidates: usize,
}

impl Default for TieredSearchConfig {
    /// Balanced preset: 4x over-fetch, int8 final ranking, 10..=1000 candidates.
    fn default() -> Self {
        TieredSearchConfig {
            max_binary_candidates: 1000,
            min_binary_candidates: 10,
            use_fp32_final: false,
            rescore_multiplier: 4,
        }
    }
}

impl TieredSearchConfig {
    /// Speed-oriented preset: like the default, but with 2x over-fetch and at
    /// most 500 binary candidates.
    pub fn fast() -> Self {
        Self {
            rescore_multiplier: 2,
            max_binary_candidates: 500,
            ..Self::default()
        }
    }

    /// Quality-oriented preset: 8x over-fetch, fp32 final ranking,
    /// 20..=2000 binary candidates.
    pub fn quality() -> Self {
        Self {
            max_binary_candidates: 2000,
            min_binary_candidates: 20,
            use_fp32_final: true,
            rescore_multiplier: 8,
        }
    }

    /// Highest-recall preset: 10x over-fetch, fp32 final ranking,
    /// 50..=5000 binary candidates.
    pub fn precise() -> Self {
        Self {
            max_binary_candidates: 5000,
            min_binary_candidates: 50,
            use_fp32_final: true,
            rescore_multiplier: 10,
        }
    }
}
114
/// One hit produced by a `TieredIndex` search.
///
/// `distance` holds the score of the last stage that ran for this hit. The
/// per-stage fields record what each stage produced.
#[derive(Debug, Clone)]
pub struct TieredSearchResult {
    /// Index of the matched vector. The fp32 stage uses it to index the stored
    /// vectors, which are kept in insertion order.
    pub id: usize,
    /// Final score: the Hamming distance as `f32`, the int8 distance, or the
    /// fp32 cosine distance, depending on which stages ran.
    pub distance: f32,
    /// Hamming distance from the binary stage.
    pub hamming_distance: u32,
    /// int8 distance, if the int8 stage scored this hit.
    pub int8_distance: Option<f32>,
    /// fp32 cosine distance, if the fp32 stage scored this hit.
    pub fp32_distance: Option<f32>,
}

impl TieredSearchResult {
    /// Builds a binary-stage-only result whose `distance` is the Hamming distance.
    pub fn new(id: usize, hamming_distance: u32) -> Self {
        let distance = hamming_distance as f32;
        Self {
            id,
            hamming_distance,
            distance,
            int8_distance: None,
            fp32_distance: None,
        }
    }
}
141
142pub struct TieredIndex {
144 binary_index: BinaryIndex,
146 int8_index: Int8Index,
148 fp32_vectors: Option<Vec<Vec<f32>>>,
150 dim: usize,
152 store_fp32: bool,
154 memory_config: Option<MemoryConstraint>,
156}
157
/// Capacity budget for a memory-constrained tiered index.
///
/// `max_vectors` is the limit that is actually enforced. The other fields
/// record how that limit was derived.
#[derive(Debug, Clone)]
pub struct MemoryConstraint {
    /// Total byte budget.
    pub max_bytes: usize,
    /// Maximum number of vectors the index accepts.
    pub max_vectors: usize,
    /// Estimated storage cost of one vector across all kept tiers.
    pub bytes_per_vector: usize,
    /// Fraction of `max_bytes` held back as overhead rather than vector storage.
    pub overhead_factor: f32,
}
170
/// Error returned by a checked insert when the vector limit has been reached.
#[derive(Debug, Clone)]
pub struct MemoryLimitError {
    /// Vectors stored when the insert was rejected.
    pub current_vectors: usize,
    /// Configured vector limit.
    pub max_vectors: usize,
    /// Estimated bytes in use when the insert was rejected.
    pub current_bytes: usize,
    /// Configured byte budget.
    pub max_bytes: usize,
}

impl std::fmt::Display for MemoryLimitError {
    /// Renders the counts and sizes in decimal megabytes (1 MB = 1,000,000 bytes),
    /// e.g. `Memory limit reached: 5/10 vectors, 1.50 MB/3.00 MB`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const BYTES_PER_MB: f64 = 1_000_000.0;
        let used_mb = self.current_bytes as f64 / BYTES_PER_MB;
        let limit_mb = self.max_bytes as f64 / BYTES_PER_MB;
        write!(
            f,
            "Memory limit reached: {}/{} vectors, {:.2} MB/{:.2} MB",
            self.current_vectors, self.max_vectors, used_mb, limit_mb
        )
    }
}

impl std::error::Error for MemoryLimitError {}
198
199impl MemoryConstraint {
200 pub fn bytes_per_vector(dim: usize, store_fp32: bool) -> usize {
202 let binary_bytes = dim.div_ceil(64) * 8; let int8_bytes = dim + 8; let fp32_bytes = if store_fp32 { dim * 4 } else { 0 };
205 binary_bytes + int8_bytes + fp32_bytes
206 }
207
208 pub fn from_bytes(max_bytes: usize, dim: usize, store_fp32: bool) -> Self {
210 let overhead_factor = 0.1; let usable_bytes = (max_bytes as f32 * (1.0 - overhead_factor)) as usize;
212 let bytes_per_vec = Self::bytes_per_vector(dim, store_fp32);
213 let max_vectors = usable_bytes / bytes_per_vec;
214
215 Self {
216 max_bytes,
217 max_vectors,
218 bytes_per_vector: bytes_per_vec,
219 overhead_factor,
220 }
221 }
222
223 pub fn from_vectors(max_vectors: usize, dim: usize, store_fp32: bool) -> Self {
225 let bytes_per_vec = Self::bytes_per_vector(dim, store_fp32);
226 let overhead_factor = 0.1;
227 let max_bytes = ((max_vectors * bytes_per_vec) as f32 / (1.0 - overhead_factor)) as usize;
228
229 Self {
230 max_bytes,
231 max_vectors,
232 bytes_per_vector: bytes_per_vec,
233 overhead_factor,
234 }
235 }
236}
237
238impl TieredIndex {
239 pub fn new(dim: usize) -> Self {
241 Self {
242 binary_index: BinaryIndex::new(dim),
243 int8_index: Int8Index::new(dim),
244 fp32_vectors: None,
245 dim,
246 store_fp32: false,
247 memory_config: None,
248 }
249 }
250
251 pub fn with_fp32_storage(dim: usize) -> Self {
253 Self {
254 binary_index: BinaryIndex::new(dim),
255 int8_index: Int8Index::new(dim),
256 fp32_vectors: Some(Vec::new()),
257 dim,
258 store_fp32: true,
259 memory_config: None,
260 }
261 }
262
263 pub fn with_capacity(dim: usize, capacity: usize, store_fp32: bool) -> Self {
265 Self {
266 binary_index: BinaryIndex::with_capacity(dim, capacity),
267 int8_index: Int8Index::with_capacity(dim, capacity),
268 fp32_vectors: if store_fp32 {
269 Some(Vec::with_capacity(capacity))
270 } else {
271 None
272 },
273 dim,
274 store_fp32,
275 memory_config: None,
276 }
277 }
278
279 pub fn memory_constrained(dim: usize, max_bytes: usize) -> Self {
297 let config = MemoryConstraint::from_bytes(max_bytes, dim, false);
298 let capacity = config.max_vectors;
299
300 Self {
301 binary_index: BinaryIndex::with_capacity(dim, capacity),
302 int8_index: Int8Index::with_capacity(dim, capacity),
303 fp32_vectors: None,
304 dim,
305 store_fp32: false,
306 memory_config: Some(config),
307 }
308 }
309
310 pub fn memory_constrained_precise(dim: usize, max_bytes: usize) -> Self {
314 let config = MemoryConstraint::from_bytes(max_bytes, dim, true);
315 let capacity = config.max_vectors;
316
317 Self {
318 binary_index: BinaryIndex::with_capacity(dim, capacity),
319 int8_index: Int8Index::with_capacity(dim, capacity),
320 fp32_vectors: Some(Vec::with_capacity(capacity)),
321 dim,
322 store_fp32: true,
323 memory_config: Some(config),
324 }
325 }
326
327 #[inline]
329 #[allow(non_snake_case)]
330 pub const fn MB(mb: usize) -> usize {
331 mb * 1024 * 1024
332 }
333
334 #[inline]
336 #[allow(non_snake_case)]
337 pub const fn GB(gb: usize) -> usize {
338 gb * 1024 * 1024 * 1024
339 }
340
341 #[inline]
343 pub fn is_constrained(&self) -> bool {
344 self.memory_config.is_some()
345 }
346
347 pub fn memory_constraint(&self) -> Option<&MemoryConstraint> {
349 self.memory_config.as_ref()
350 }
351
352 #[inline]
354 pub fn can_add(&self) -> bool {
355 match &self.memory_config {
356 Some(config) => self.len() < config.max_vectors,
357 None => true,
358 }
359 }
360
361 #[inline]
363 pub fn can_add_n(&self, n: usize) -> bool {
364 match &self.memory_config {
365 Some(config) => self.len() + n <= config.max_vectors,
366 None => true,
367 }
368 }
369
370 pub fn remaining_capacity(&self) -> Option<usize> {
372 self.memory_config
373 .as_ref()
374 .map(|c| c.max_vectors.saturating_sub(self.len()))
375 }
376
377 pub fn remaining_bytes(&self) -> Option<usize> {
379 self.memory_config.as_ref().map(|c| {
380 let used = self.memory_stats().total_bytes;
381 c.max_bytes.saturating_sub(used)
382 })
383 }
384
385 pub fn memory_utilization(&self) -> Option<f32> {
387 self.memory_config.as_ref().map(|c| {
388 if c.max_vectors == 0 {
389 0.0
390 } else {
391 self.len() as f32 / c.max_vectors as f32
392 }
393 })
394 }
395
396 pub fn add(&mut self, vector: &[f32]) -> bool {
401 debug_assert_eq!(vector.len(), self.dim, "Dimension mismatch");
402
403 if !self.can_add() {
405 return false;
406 }
407
408 self.binary_index.add_f32(vector);
409 self.int8_index.add_f32(vector);
410
411 if let Some(ref mut fp32) = self.fp32_vectors {
412 fp32.push(vector.to_vec());
413 }
414
415 true
416 }
417
418 pub fn add_unchecked(&mut self, vector: &[f32]) {
422 debug_assert_eq!(vector.len(), self.dim, "Dimension mismatch");
423
424 self.binary_index.add_f32(vector);
425 self.int8_index.add_f32(vector);
426
427 if let Some(ref mut fp32) = self.fp32_vectors {
428 fp32.push(vector.to_vec());
429 }
430 }
431
432 pub fn try_add(&mut self, vector: &[f32]) -> Result<(), MemoryLimitError> {
434 debug_assert_eq!(vector.len(), self.dim, "Dimension mismatch");
435
436 if let Some(ref config) = self.memory_config {
437 if self.len() >= config.max_vectors {
438 return Err(MemoryLimitError {
439 current_vectors: self.len(),
440 max_vectors: config.max_vectors,
441 current_bytes: self.memory_stats().total_bytes,
442 max_bytes: config.max_bytes,
443 });
444 }
445 }
446
447 self.add_unchecked(vector);
448 Ok(())
449 }
450
451 pub fn add_batch(&mut self, vectors: &[Vec<f32>]) -> usize {
455 let mut added = 0;
456 for v in vectors {
457 if self.add(v) {
458 added += 1;
459 } else {
460 break;
461 }
462 }
463 added
464 }
465
466 pub fn add_batch_partial(&mut self, vectors: &[Vec<f32>]) -> (usize, usize) {
470 let added = self.add_batch(vectors);
471 (added, vectors.len() - added)
472 }
473
474 #[inline]
476 pub fn len(&self) -> usize {
477 self.binary_index.len()
478 }
479
480 #[inline]
482 pub fn is_empty(&self) -> bool {
483 self.binary_index.is_empty()
484 }
485
486 #[inline]
488 pub fn dim(&self) -> usize {
489 self.dim
490 }
491
492 pub fn memory_stats(&self) -> TieredMemoryStats {
494 let binary_bytes = self.binary_index.memory_bytes();
495 let int8_bytes = self.int8_index.memory_bytes();
496 let fp32_bytes = self
497 .fp32_vectors
498 .as_ref()
499 .map(|v| v.len() * self.dim * 4)
500 .unwrap_or(0);
501
502 TieredMemoryStats {
503 binary_bytes,
504 int8_bytes,
505 fp32_bytes,
506 total_bytes: binary_bytes + int8_bytes + fp32_bytes,
507 n_vectors: self.len(),
508 dim: self.dim,
509 }
510 }
511
512 pub fn search(&self, query: &[f32], k: usize) -> Vec<TieredSearchResult> {
516 self.search_with_config(query, k, &TieredSearchConfig::default())
517 }
518
519 pub fn search_with_config(
521 &self,
522 query: &[f32],
523 k: usize,
524 config: &TieredSearchConfig,
525 ) -> Vec<TieredSearchResult> {
526 if self.is_empty() {
527 return Vec::new();
528 }
529
530 let k = k.min(self.len());
531
532 let n_binary_candidates = (k * config.rescore_multiplier)
534 .max(config.min_binary_candidates)
535 .min(config.max_binary_candidates)
536 .min(self.len());
537
538 let binary_query = BinaryVector::from_f32(query);
539 let binary_results = self.binary_index.search(&binary_query, n_binary_candidates);
540
541 let int8_rescored = self.int8_index.rescore_candidates(&binary_results, query);
543
544 let mut results: Vec<TieredSearchResult> = int8_rescored
546 .iter()
547 .take(if config.use_fp32_final { k * 2 } else { k })
548 .map(|&(id, int8_dist)| {
549 let hamming = binary_results
550 .iter()
551 .find(|(i, _)| *i == id)
552 .map(|(_, d)| *d)
553 .unwrap_or(0);
554
555 let mut result = TieredSearchResult::new(id, hamming);
556 result.int8_distance = Some(int8_dist);
557 result.distance = int8_dist;
558 result
559 })
560 .collect();
561
562 if config.use_fp32_final {
564 if let Some(ref fp32_vectors) = self.fp32_vectors {
565 for result in results.iter_mut() {
566 if result.id < fp32_vectors.len() {
567 let fp32_dist = cosine_distance_f32(query, &fp32_vectors[result.id]);
568 result.fp32_distance = Some(fp32_dist);
569 result.distance = fp32_dist;
570 }
571 }
572 results.sort_by(|a, b| {
574 a.distance
575 .partial_cmp(&b.distance)
576 .unwrap_or(Ordering::Equal)
577 .then_with(|| a.id.cmp(&b.id))
578 });
579 }
580 }
581
582 results.truncate(k);
583 results
584 }
585
586 pub fn search_binary_only(&self, query: &[f32], k: usize) -> Vec<TieredSearchResult> {
588 let binary_query = BinaryVector::from_f32(query);
589 let results = self.binary_index.search(&binary_query, k);
590
591 results
592 .into_iter()
593 .map(|(id, hamming)| TieredSearchResult::new(id, hamming))
594 .collect()
595 }
596
597 pub fn search_int8(
599 &self,
600 query: &[f32],
601 k: usize,
602 rescore_multiplier: usize,
603 ) -> Vec<TieredSearchResult> {
604 let config = TieredSearchConfig {
605 rescore_multiplier,
606 use_fp32_final: false,
607 ..Default::default()
608 };
609 self.search_with_config(query, k, &config)
610 }
611}
612
613#[derive(Debug, Clone)]
615pub struct TieredMemoryStats {
616 pub binary_bytes: usize,
618 pub int8_bytes: usize,
620 pub fp32_bytes: usize,
622 pub total_bytes: usize,
624 pub n_vectors: usize,
626 pub dim: usize,
628}
629
630impl TieredMemoryStats {
631 pub fn compression_ratio(&self) -> f32 {
633 let fp32_only = self.n_vectors * self.dim * 4;
634 if self.total_bytes > 0 {
635 fp32_only as f32 / self.total_bytes as f32
636 } else {
637 0.0
638 }
639 }
640
641 pub fn format(&self) -> String {
643 format!(
644 "Tiered Index: {} vectors × {} dim\n\
645 Binary: {} ({:.1} MB)\n\
646 int8: {} ({:.1} MB)\n\
647 fp32: {} ({:.1} MB)\n\
648 Total: {:.1} MB (vs {:.1} MB fp32-only, {:.1}x compression)",
649 self.n_vectors,
650 self.dim,
651 format_bytes(self.binary_bytes),
652 self.binary_bytes as f64 / 1_000_000.0,
653 format_bytes(self.int8_bytes),
654 self.int8_bytes as f64 / 1_000_000.0,
655 format_bytes(self.fp32_bytes),
656 self.fp32_bytes as f64 / 1_000_000.0,
657 self.total_bytes as f64 / 1_000_000.0,
658 (self.n_vectors * self.dim * 4) as f64 / 1_000_000.0,
659 self.compression_ratio()
660 )
661 }
662}
663
/// Formats a byte count with decimal units (B, KB, MB, GB).
///
/// Counts of 1,000 bytes or more are shown with two decimals.
fn format_bytes(bytes: usize) -> String {
    // Largest unit first, so the first threshold the value reaches wins.
    const UNITS: [(usize, &str); 3] = [(1_000_000_000, "GB"), (1_000_000, "MB"), (1_000, "KB")];

    UNITS
        .iter()
        .find(|&&(threshold, _)| bytes >= threshold)
        .map(|&(threshold, unit)| format!("{:.2} {}", bytes as f64 / threshold as f64, unit))
        .unwrap_or_else(|| format!("{} B", bytes))
}
675
/// Cosine distance `1 - cos(a, b)`. Results range from 0 (same direction) to
/// 2 (opposite directions).
///
/// Accumulates in `f64` and divides by the product of the individual norms.
/// With `f32` accumulation, `|a|^2 * |b|^2` overflowed to infinity for
/// large-magnitude inputs, giving identical vectors a distance of 1.0, and
/// squares of tiny inputs underflowed to a zero denominator.
///
/// Returns `1.0` when the norm product is not positive (a zero vector, or NaN
/// input). Only the first `min(a.len(), b.len())` components are compared.
fn cosine_distance_f32(a: &[f32], b: &[f32]) -> f32 {
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom > 0.0 {
        (1.0 - dot / denom) as f32
    } else {
        1.0
    }
}
695
696pub struct TieredIndexBuilder {
702 dim: usize,
703 capacity: Option<usize>,
704 store_fp32: bool,
705}
706
707impl TieredIndexBuilder {
708 pub fn new(dim: usize) -> Self {
709 Self {
710 dim,
711 capacity: None,
712 store_fp32: false,
713 }
714 }
715
716 pub fn with_capacity(mut self, capacity: usize) -> Self {
718 self.capacity = Some(capacity);
719 self
720 }
721
722 pub fn with_fp32_storage(mut self) -> Self {
724 self.store_fp32 = true;
725 self
726 }
727
728 pub fn build(self) -> TieredIndex {
730 match self.capacity {
731 Some(cap) => TieredIndex::with_capacity(self.dim, cap, self.store_fp32),
732 None => {
733 if self.store_fp32 {
734 TieredIndex::with_fp32_storage(self.dim)
735 } else {
736 TieredIndex::new(self.dim)
737 }
738 }
739 }
740 }
741}
742
#[cfg(test)]
// Unit tests for the tiered index, its search presets and its memory limits.
mod tests {
    use super::*;

    // Deterministic pseudo-random vector. Each component is a hash of (seed, i)
    // reduced mod 2^31 into [0, 1) and mapped to [-1, 1). Equal seeds give
    // equal vectors, which the search tests rely on.
    // NOTE(review): `seed * 1103515245` overflows a 32-bit `usize` for the seeds
    // used below; these tests assume a 64-bit target.
    fn random_vector(dim: usize, seed: usize) -> Vec<f32> {
        (0..dim)
            .map(|i| {
                let x = ((seed * 1103515245 + i * 12345) % 2147483648) as f32 / 2147483648.0;
                x * 2.0 - 1.0
            })
            .collect()
    }

    // Adding to an unconstrained index increments `len`.
    #[test]
    fn test_tiered_index_basic() {
        let mut index = TieredIndex::new(64);

        let v1 = random_vector(64, 1);
        let v2 = random_vector(64, 2);
        let v3 = random_vector(64, 3);

        index.add(&v1);
        index.add(&v2);
        index.add(&v3);

        assert_eq!(index.len(), 3);
    }

    // A query identical to a stored vector (seed 0 == id 0) must rank first.
    #[test]
    fn test_tiered_search() {
        let mut index = TieredIndex::new(64);

        for i in 0..100 {
            index.add(&random_vector(64, i));
        }

        let query = random_vector(64, 0);
        let results = index.search(&query, 5);

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].id, 0);
    }

    // With fp32 storage, the quality preset fills in `fp32_distance`.
    #[test]
    fn test_tiered_with_fp32() {
        let mut index = TieredIndex::with_fp32_storage(64);

        for i in 0..50 {
            index.add(&random_vector(64, i));
        }

        let query = random_vector(64, 0);
        let results = index.search_with_config(&query, 5, &TieredSearchConfig::quality());

        assert_eq!(results.len(), 5);
        assert!(results[0].fp32_distance.is_some());
    }

    // 1000 x 1024-dim vectors: about 128 B/vector binary and about 1032 B/vector
    // int8, per `MemoryConstraint::bytes_per_vector`. The bounds leave slack for
    // whatever overhead the index types report.
    #[test]
    fn test_memory_stats() {
        let mut index = TieredIndex::new(1024);

        for i in 0..1000 {
            index.add(&random_vector(1024, i));
        }

        let stats = index.memory_stats();

        assert!(stats.binary_bytes > 100_000);
        assert!(stats.binary_bytes < 200_000);

        assert!(stats.int8_bytes > 1_000_000);
        assert!(stats.int8_bytes < 1_500_000);

        assert!(stats.compression_ratio() > 2.0);
    }

    // Binary-only search never runs the int8 stage.
    #[test]
    fn test_binary_only_search() {
        let mut index = TieredIndex::new(128);

        for i in 0..100 {
            index.add(&random_vector(128, i));
        }

        let query = random_vector(128, 50);
        let results = index.search_binary_only(&query, 10);

        assert_eq!(results.len(), 10);
        assert!(results[0].int8_distance.is_none());
    }

    // Every preset returns k results; the fp32 presets populate fp32 distances.
    #[test]
    fn test_search_configs() {
        let mut index = TieredIndex::with_fp32_storage(64);

        for i in 0..100 {
            index.add(&random_vector(64, i));
        }

        let query = random_vector(64, 0);

        let fast = index.search_with_config(&query, 5, &TieredSearchConfig::fast());
        let quality = index.search_with_config(&query, 5, &TieredSearchConfig::quality());
        let precise = index.search_with_config(&query, 5, &TieredSearchConfig::precise());

        assert_eq!(fast.len(), 5);
        assert_eq!(quality.len(), 5);
        assert_eq!(precise.len(), 5);

        assert!(quality[0].fp32_distance.is_some());
        assert!(precise[0].fp32_distance.is_some());
    }

    // The builder produces an empty index of the requested dimension.
    #[test]
    fn test_builder() {
        let index = TieredIndexBuilder::new(256)
            .with_capacity(1000)
            .with_fp32_storage()
            .build();

        assert_eq!(index.dim(), 256);
        assert!(index.is_empty());
    }

    // 100 KiB budget, 64 dims: 80 bytes/vector (8 binary + 72 int8), and 90% of
    // 102400 bytes is usable, so the limit is 1152 vectors. `add` must stop
    // exactly at the limit.
    #[test]
    fn test_memory_constrained() {
        let mut index = TieredIndex::memory_constrained(64, 100 * 1024);

        assert!(index.is_constrained());
        assert!(index.can_add());

        let config = index.memory_constraint().unwrap();
        assert!(config.max_vectors > 1000);
        assert!(config.max_vectors < 1500);

        let mut added = 0;
        for i in 0..2000 {
            if index.add(&random_vector(64, i)) {
                added += 1;
            } else {
                break;
            }
        }

        assert!(added < 2000);
        assert_eq!(index.len(), added);
        assert!(!index.can_add());
    }

    // Batch insertion adds a prefix and reports the rest as not added.
    #[test]
    fn test_memory_constrained_batch() {
        let mut index = TieredIndex::memory_constrained(32, 50 * 1024);
        let vectors: Vec<Vec<f32>> = (0..1000).map(|i| random_vector(32, i)).collect();

        let (added, remaining) = index.add_batch_partial(&vectors);

        assert!(added > 0);
        assert!(added < 1000);
        assert_eq!(added + remaining, 1000);
        assert_eq!(index.len(), added);
    }

    // 16 dims = 32 bytes/vector, and a 1 KiB budget gives 28 vectors. The next
    // `try_add` must fail with counts that reflect the full index.
    #[test]
    fn test_memory_constrained_try_add() {
        let mut index = TieredIndex::memory_constrained(16, 1024);
        let config = index.memory_constraint().unwrap();
        let max = config.max_vectors;

        for i in 0..max {
            assert!(index.try_add(&random_vector(16, i)).is_ok());
        }

        let result = index.try_add(&random_vector(16, max + 1));
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(err.current_vectors, max);
        assert_eq!(err.max_vectors, max);
    }

    // Filling half the vector limit gives a utilization of about 0.5.
    #[test]
    fn test_memory_utilization() {
        let mut index = TieredIndex::memory_constrained(64, 10 * 1024);

        assert_eq!(index.memory_utilization(), Some(0.0));

        let max = index.memory_constraint().unwrap().max_vectors;
        let half = max / 2;

        for i in 0..half {
            index.add(&random_vector(64, i));
        }

        let util = index.memory_utilization().unwrap();
        assert!(util > 0.4 && util < 0.6);
    }

    // Each accepted vector reduces the remaining capacity by one.
    #[test]
    fn test_remaining_capacity() {
        let mut index = TieredIndex::memory_constrained(64, 20 * 1024);

        let initial = index.remaining_capacity().unwrap();
        assert!(initial > 0);

        index.add(&random_vector(64, 0));

        let after = index.remaining_capacity().unwrap();
        assert_eq!(after, initial - 1);
    }

    // 1024 dims: 128 binary bytes (1024 bits), 1032 int8 bytes (1024 + 8),
    // plus 4096 fp32 bytes when fp32 storage is on.
    #[test]
    fn test_bytes_per_vector_calculation() {
        let bpv = MemoryConstraint::bytes_per_vector(1024, false);
        assert_eq!(bpv, 128 + 1032);

        let bpv_fp32 = MemoryConstraint::bytes_per_vector(1024, true);
        assert_eq!(bpv_fp32, 128 + 1032 + 4096);
    }
}