use super::binary_quantize::{BinaryIndex, BinaryVector};
use super::int8_quantize::Int8Index;
use std::cmp::Ordering;
/// Parameters that size the candidate sets of a tiered search.
///
/// `TieredIndex::search_with_config` requests `k * rescore_multiplier`
/// candidates from the binary stage, raised to at least
/// `min_binary_candidates`, then capped at `max_binary_candidates` and at the
/// number of indexed vectors.
#[derive(Debug, Clone)]
pub struct TieredSearchConfig {
/// Factor applied to `k` to decide how many binary candidates to fetch.
pub rescore_multiplier: usize,
/// When `true`, the best int8 results are re-ranked with fp32 cosine
/// distance; this only has an effect if the index stores fp32 vectors.
pub use_fp32_final: bool,
/// Lower bound on the number of binary candidates.
pub min_binary_candidates: usize,
/// Upper bound on the number of binary candidates; applied after the lower
/// bound, so it wins if the two conflict.
pub max_binary_candidates: usize,
}
impl Default for TieredSearchConfig {
    /// Balanced preset: 4x over-fetch, between 10 and 1000 binary candidates,
    /// and no fp32 re-ranking pass.
    fn default() -> Self {
        let rescore_multiplier = 4;
        let min_binary_candidates = 10;
        let max_binary_candidates = 1000;

        Self {
            rescore_multiplier,
            use_fp32_final: false,
            min_binary_candidates,
            max_binary_candidates,
        }
    }
}
impl TieredSearchConfig {
    /// Speed-oriented preset: 2x over-fetch, 10 to 500 binary candidates,
    /// no fp32 pass.
    pub fn fast() -> Self {
        Self::from_parts(2, false, 10, 500)
    }

    /// Quality-oriented preset: 8x over-fetch, 20 to 2000 binary candidates,
    /// fp32 re-ranking enabled.
    pub fn quality() -> Self {
        Self::from_parts(8, true, 20, 2000)
    }

    /// Highest-precision preset: 10x over-fetch, 50 to 5000 binary
    /// candidates, fp32 re-ranking enabled.
    pub fn precise() -> Self {
        Self::from_parts(10, true, 50, 5000)
    }

    /// Assembles a configuration from its four fields, in declaration order.
    fn from_parts(
        rescore_multiplier: usize,
        use_fp32_final: bool,
        min_binary_candidates: usize,
        max_binary_candidates: usize,
    ) -> Self {
        Self {
            rescore_multiplier,
            use_fp32_final,
            min_binary_candidates,
            max_binary_candidates,
        }
    }
}
/// One hit returned by a tiered search, with the distance from each stage
/// that scored it.
#[derive(Debug, Clone)]
pub struct TieredSearchResult {
/// Identifier of the matched vector; the fp32 pass also uses it as the
/// position of the vector in the stored fp32 list.
pub id: usize,
/// Distance from the last stage that scored this hit: the Hamming distance
/// at construction, later overwritten by the int8 and then the fp32 distance.
pub distance: f32,
/// Distance reported by the binary stage.
pub hamming_distance: u32,
/// Distance from int8 rescoring, if that stage ran.
pub int8_distance: Option<f32>,
/// Cosine distance from the fp32 pass, if that stage ran.
pub fp32_distance: Option<f32>,
}
impl TieredSearchResult {
    /// Builds a result that has only been scored by the binary stage.
    ///
    /// `distance` starts out as the Hamming distance converted to `f32`; the
    /// int8 and fp32 distances are left unset.
    pub fn new(id: usize, hamming_distance: u32) -> Self {
        let distance = hamming_distance as f32;
        Self {
            id,
            distance,
            hamming_distance,
            int8_distance: None,
            fp32_distance: None,
        }
    }
}
/// Vector index that keeps every vector at several precisions: a binary
/// index, an int8 index and, optionally, the original fp32 values.
pub struct TieredIndex {
/// Binary-quantized copies; searched first and the source of `len()`.
binary_index: BinaryIndex,
/// int8-quantized copies used to rescore the binary candidates.
int8_index: Int8Index,
/// Original vectors, present only when fp32 storage was requested; looked
/// up by result id during the fp32 pass.
fp32_vectors: Option<Vec<Vec<f32>>>,
/// Expected length of every vector (checked with `debug_assert_eq!` on insert).
dim: usize,
/// Set by the constructors to match whether `fp32_vectors` is `Some`.
store_fp32: bool,
/// Optional capacity limit; `None` means inserts are never refused.
memory_config: Option<MemoryConstraint>,
}
/// Capacity limit for a `TieredIndex`, expressed both in bytes and in vectors.
#[derive(Debug, Clone)]
pub struct MemoryConstraint {
/// Total byte budget, including the share set aside as overhead.
pub max_bytes: usize,
/// Maximum number of vectors the index will accept.
pub max_vectors: usize,
/// Estimated storage cost of one vector across all tiers.
pub bytes_per_vector: usize,
/// Fraction of `max_bytes` kept back as overhead (the constructors use 0.1).
pub overhead_factor: f32,
}
/// Error returned by `TieredIndex::try_add` when the vector limit is reached.
#[derive(Debug, Clone)]
pub struct MemoryLimitError {
/// Number of vectors in the index when the insert was refused.
pub current_vectors: usize,
/// Configured vector limit.
pub max_vectors: usize,
/// Bytes in use as reported by `memory_stats()` at the time of the failure.
pub current_bytes: usize,
/// Configured byte budget.
pub max_bytes: usize,
}
impl std::fmt::Display for MemoryLimitError {
    /// Writes the vector counts and the byte figures, the latter converted to
    /// decimal megabytes (1 MB = 1,000,000 bytes) with two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const BYTES_PER_MB: f64 = 1_000_000.0;

        let used_mb = self.current_bytes as f64 / BYTES_PER_MB;
        let limit_mb = self.max_bytes as f64 / BYTES_PER_MB;

        write!(
            f,
            "Memory limit reached: {}/{} vectors, {:.2} MB/{:.2} MB",
            self.current_vectors, self.max_vectors, used_mb, limit_mb
        )
    }
}
// Marker impl: relies on the default `Error` methods, so the message comes
// from the `Display` impl and no underlying source error is reported.
impl std::error::Error for MemoryLimitError {}
impl MemoryConstraint {
    /// Estimates the bytes one `dim`-dimensional vector occupies across tiers.
    ///
    /// The estimate is the sum of:
    /// * binary tier: one bit per dimension, rounded up to whole 8-byte words;
    /// * int8 tier: one byte per dimension plus 8 extra bytes per vector;
    /// * fp32 tier: four bytes per dimension, counted only when `store_fp32`.
    pub fn bytes_per_vector(dim: usize, store_fp32: bool) -> usize {
        let binary_tier = dim.div_ceil(64) * 8;
        let int8_tier = dim + 8;
        let fp32_tier = match store_fp32 {
            true => dim * 4,
            false => 0,
        };
        binary_tier + int8_tier + fp32_tier
    }

    /// Derives a constraint from a byte budget.
    ///
    /// 10% of `max_bytes` is set aside as overhead; the remainder divided by
    /// the per-vector estimate gives `max_vectors`.
    pub fn from_bytes(max_bytes: usize, dim: usize, store_fp32: bool) -> Self {
        let overhead_factor: f32 = 0.1;
        let usable_fraction = 1.0 - overhead_factor;
        let usable_bytes = (max_bytes as f32 * usable_fraction) as usize;
        let bytes_per_vector = Self::bytes_per_vector(dim, store_fp32);

        Self {
            max_bytes,
            max_vectors: usable_bytes / bytes_per_vector,
            bytes_per_vector,
            overhead_factor,
        }
    }

    /// Derives a constraint from a vector count.
    ///
    /// `max_bytes` is the raw storage for `max_vectors` vectors scaled up so
    /// that the 10% overhead share fits inside the budget.
    pub fn from_vectors(max_vectors: usize, dim: usize, store_fp32: bool) -> Self {
        let overhead_factor: f32 = 0.1;
        let bytes_per_vector = Self::bytes_per_vector(dim, store_fp32);
        let payload_bytes = max_vectors * bytes_per_vector;
        let max_bytes = (payload_bytes as f32 / (1.0 - overhead_factor)) as usize;

        Self {
            max_bytes,
            max_vectors,
            bytes_per_vector,
            overhead_factor,
        }
    }
}
impl TieredIndex {
pub fn new(dim: usize) -> Self {
Self {
binary_index: BinaryIndex::new(dim),
int8_index: Int8Index::new(dim),
fp32_vectors: None,
dim,
store_fp32: false,
memory_config: None,
}
}
pub fn with_fp32_storage(dim: usize) -> Self {
Self {
binary_index: BinaryIndex::new(dim),
int8_index: Int8Index::new(dim),
fp32_vectors: Some(Vec::new()),
dim,
store_fp32: true,
memory_config: None,
}
}
pub fn with_capacity(dim: usize, capacity: usize, store_fp32: bool) -> Self {
Self {
binary_index: BinaryIndex::with_capacity(dim, capacity),
int8_index: Int8Index::with_capacity(dim, capacity),
fp32_vectors: if store_fp32 {
Some(Vec::with_capacity(capacity))
} else {
None
},
dim,
store_fp32,
memory_config: None,
}
}
pub fn memory_constrained(dim: usize, max_bytes: usize) -> Self {
let config = MemoryConstraint::from_bytes(max_bytes, dim, false);
let capacity = config.max_vectors;
Self {
binary_index: BinaryIndex::with_capacity(dim, capacity),
int8_index: Int8Index::with_capacity(dim, capacity),
fp32_vectors: None,
dim,
store_fp32: false,
memory_config: Some(config),
}
}
pub fn memory_constrained_precise(dim: usize, max_bytes: usize) -> Self {
let config = MemoryConstraint::from_bytes(max_bytes, dim, true);
let capacity = config.max_vectors;
Self {
binary_index: BinaryIndex::with_capacity(dim, capacity),
int8_index: Int8Index::with_capacity(dim, capacity),
fp32_vectors: Some(Vec::with_capacity(capacity)),
dim,
store_fp32: true,
memory_config: Some(config),
}
}
#[inline]
#[allow(non_snake_case)]
pub const fn MB(mb: usize) -> usize {
mb * 1024 * 1024
}
#[inline]
#[allow(non_snake_case)]
pub const fn GB(gb: usize) -> usize {
gb * 1024 * 1024 * 1024
}
#[inline]
pub fn is_constrained(&self) -> bool {
self.memory_config.is_some()
}
pub fn memory_constraint(&self) -> Option<&MemoryConstraint> {
self.memory_config.as_ref()
}
#[inline]
pub fn can_add(&self) -> bool {
match &self.memory_config {
Some(config) => self.len() < config.max_vectors,
None => true,
}
}
#[inline]
pub fn can_add_n(&self, n: usize) -> bool {
match &self.memory_config {
Some(config) => self.len() + n <= config.max_vectors,
None => true,
}
}
pub fn remaining_capacity(&self) -> Option<usize> {
self.memory_config
.as_ref()
.map(|c| c.max_vectors.saturating_sub(self.len()))
}
pub fn remaining_bytes(&self) -> Option<usize> {
self.memory_config.as_ref().map(|c| {
let used = self.memory_stats().total_bytes;
c.max_bytes.saturating_sub(used)
})
}
pub fn memory_utilization(&self) -> Option<f32> {
self.memory_config.as_ref().map(|c| {
if c.max_vectors == 0 {
0.0
} else {
self.len() as f32 / c.max_vectors as f32
}
})
}
pub fn add(&mut self, vector: &[f32]) -> bool {
debug_assert_eq!(vector.len(), self.dim, "Dimension mismatch");
if !self.can_add() {
return false;
}
self.binary_index.add_f32(vector);
self.int8_index.add_f32(vector);
if let Some(ref mut fp32) = self.fp32_vectors {
fp32.push(vector.to_vec());
}
true
}
pub fn add_unchecked(&mut self, vector: &[f32]) {
debug_assert_eq!(vector.len(), self.dim, "Dimension mismatch");
self.binary_index.add_f32(vector);
self.int8_index.add_f32(vector);
if let Some(ref mut fp32) = self.fp32_vectors {
fp32.push(vector.to_vec());
}
}
pub fn try_add(&mut self, vector: &[f32]) -> Result<(), MemoryLimitError> {
debug_assert_eq!(vector.len(), self.dim, "Dimension mismatch");
if let Some(ref config) = self.memory_config {
if self.len() >= config.max_vectors {
return Err(MemoryLimitError {
current_vectors: self.len(),
max_vectors: config.max_vectors,
current_bytes: self.memory_stats().total_bytes,
max_bytes: config.max_bytes,
});
}
}
self.add_unchecked(vector);
Ok(())
}
pub fn add_batch(&mut self, vectors: &[Vec<f32>]) -> usize {
let mut added = 0;
for v in vectors {
if self.add(v) {
added += 1;
} else {
break;
}
}
added
}
pub fn add_batch_partial(&mut self, vectors: &[Vec<f32>]) -> (usize, usize) {
let added = self.add_batch(vectors);
(added, vectors.len() - added)
}
#[inline]
pub fn len(&self) -> usize {
self.binary_index.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.binary_index.is_empty()
}
#[inline]
pub fn dim(&self) -> usize {
self.dim
}
pub fn memory_stats(&self) -> TieredMemoryStats {
let binary_bytes = self.binary_index.memory_bytes();
let int8_bytes = self.int8_index.memory_bytes();
let fp32_bytes = self
.fp32_vectors
.as_ref()
.map(|v| v.len() * self.dim * 4)
.unwrap_or(0);
TieredMemoryStats {
binary_bytes,
int8_bytes,
fp32_bytes,
total_bytes: binary_bytes + int8_bytes + fp32_bytes,
n_vectors: self.len(),
dim: self.dim,
}
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<TieredSearchResult> {
self.search_with_config(query, k, &TieredSearchConfig::default())
}
pub fn search_with_config(
&self,
query: &[f32],
k: usize,
config: &TieredSearchConfig,
) -> Vec<TieredSearchResult> {
if self.is_empty() {
return Vec::new();
}
let k = k.min(self.len());
let n_binary_candidates = (k * config.rescore_multiplier)
.max(config.min_binary_candidates)
.min(config.max_binary_candidates)
.min(self.len());
let binary_query = BinaryVector::from_f32(query);
let binary_results = self.binary_index.search(&binary_query, n_binary_candidates);
let int8_rescored = self.int8_index.rescore_candidates(&binary_results, query);
let mut results: Vec<TieredSearchResult> = int8_rescored
.iter()
.take(if config.use_fp32_final { k * 2 } else { k })
.map(|&(id, int8_dist)| {
let hamming = binary_results
.iter()
.find(|(i, _)| *i == id)
.map(|(_, d)| *d)
.unwrap_or(0);
let mut result = TieredSearchResult::new(id, hamming);
result.int8_distance = Some(int8_dist);
result.distance = int8_dist;
result
})
.collect();
if config.use_fp32_final {
if let Some(ref fp32_vectors) = self.fp32_vectors {
for result in results.iter_mut() {
if result.id < fp32_vectors.len() {
let fp32_dist = cosine_distance_f32(query, &fp32_vectors[result.id]);
result.fp32_distance = Some(fp32_dist);
result.distance = fp32_dist;
}
}
results.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(Ordering::Equal)
.then_with(|| a.id.cmp(&b.id))
});
}
}
results.truncate(k);
results
}
pub fn search_binary_only(&self, query: &[f32], k: usize) -> Vec<TieredSearchResult> {
let binary_query = BinaryVector::from_f32(query);
let results = self.binary_index.search(&binary_query, k);
results
.into_iter()
.map(|(id, hamming)| TieredSearchResult::new(id, hamming))
.collect()
}
pub fn search_int8(
&self,
query: &[f32],
k: usize,
rescore_multiplier: usize,
) -> Vec<TieredSearchResult> {
let config = TieredSearchConfig {
rescore_multiplier,
use_fp32_final: false,
..Default::default()
};
self.search_with_config(query, k, &config)
}
}
/// Snapshot of the memory used by each tier of a `TieredIndex`.
#[derive(Debug, Clone)]
pub struct TieredMemoryStats {
/// Bytes reported by the binary tier.
pub binary_bytes: usize,
/// Bytes reported by the int8 tier.
pub int8_bytes: usize,
/// Bytes of stored fp32 vectors (0 when fp32 storage is off).
pub fp32_bytes: usize,
/// Sum of the three tier figures.
pub total_bytes: usize,
/// Number of indexed vectors.
pub n_vectors: usize,
/// Vector dimension.
pub dim: usize,
}
impl TieredMemoryStats {
    /// Ratio of a plain fp32 layout (`n_vectors * dim * 4` bytes) to the bytes
    /// actually used; `0.0` when nothing is stored.
    pub fn compression_ratio(&self) -> f32 {
        let fp32_only = self.n_vectors * self.dim * 4;
        if self.total_bytes == 0 {
            return 0.0;
        }
        fp32_only as f32 / self.total_bytes as f32
    }

    /// Renders a multi-line, human-readable summary of the statistics.
    ///
    /// Megabyte figures are decimal (1 MB = 1,000,000 bytes).
    pub fn format(&self) -> String {
        let to_mb = |bytes: usize| bytes as f64 / 1_000_000.0;
        let fp32_only_bytes = self.n_vectors * self.dim * 4;

        format!(
            "Tiered Index: {} vectors × {} dim\n\
             Binary: {} ({:.1} MB)\n\
             int8: {} ({:.1} MB)\n\
             fp32: {} ({:.1} MB)\n\
             Total: {:.1} MB (vs {:.1} MB fp32-only, {:.1}x compression)",
            self.n_vectors,
            self.dim,
            format_bytes(self.binary_bytes),
            to_mb(self.binary_bytes),
            format_bytes(self.int8_bytes),
            to_mb(self.int8_bytes),
            format_bytes(self.fp32_bytes),
            to_mb(self.fp32_bytes),
            to_mb(self.total_bytes),
            to_mb(fp32_only_bytes),
            self.compression_ratio()
        )
    }
}
/// Formats a byte count with the largest decimal unit (KB, MB, GB) it reaches,
/// using two decimals; counts below 1,000 are printed as whole bytes.
fn format_bytes(bytes: usize) -> String {
    // Largest unit first so the first threshold reached is the one used.
    const UNITS: [(usize, &str); 3] = [
        (1_000_000_000, "GB"),
        (1_000_000, "MB"),
        (1_000, "KB"),
    ];

    match UNITS.iter().find(|(threshold, _)| bytes >= *threshold) {
        Some((threshold, suffix)) => {
            format!("{:.2} {}", bytes as f64 / *threshold as f64, suffix)
        }
        None => format!("{} B", bytes),
    }
}
/// Cosine distance (`1 - cosine similarity`) between `a` and `b`.
///
/// Only the overlapping prefix is used when the slices differ in length.
/// Returns `1.0` when the product of the squared norms does not have a
/// positive square root (for example when either input is all zeros or empty).
fn cosine_distance_f32(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate dot product and both squared norms in a single pass.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (x, y)| (dot + x * y, na + x * x, nb + y * y),
    );

    let denom = (norm_a * norm_b).sqrt();
    if denom > 0.0 {
        1.0 - dot / denom
    } else {
        1.0
    }
}
/// Builder for an unconstrained `TieredIndex`.
pub struct TieredIndexBuilder {
/// Vector dimension passed to the index constructor.
dim: usize,
/// Optional capacity to pre-allocate; `None` uses the non-allocating constructors.
capacity: Option<usize>,
/// Whether the built index keeps the original fp32 vectors.
store_fp32: bool,
}
impl TieredIndexBuilder {
    /// Starts a builder for `dim`-dimensional vectors with no pre-allocation
    /// and fp32 storage disabled.
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            capacity: None,
            store_fp32: false,
        }
    }

    /// Requests pre-allocation for `capacity` vectors.
    pub fn with_capacity(self, capacity: usize) -> Self {
        Self {
            capacity: Some(capacity),
            ..self
        }
    }

    /// Requests that the original fp32 vectors be stored.
    pub fn with_fp32_storage(self) -> Self {
        Self {
            store_fp32: true,
            ..self
        }
    }

    /// Builds the index, choosing the constructor that matches the options.
    pub fn build(self) -> TieredIndex {
        let Self {
            dim,
            capacity,
            store_fp32,
        } = self;

        match (capacity, store_fp32) {
            (Some(cap), keep_fp32) => TieredIndex::with_capacity(dim, cap, keep_fp32),
            (None, true) => TieredIndex::with_fp32_storage(dim),
            (None, false) => TieredIndex::new(dim),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in roughly [-1, 1],
    /// derived from `seed` and the component position.
    fn random_vector(dim: usize, seed: usize) -> Vec<f32> {
        let mut components = Vec::with_capacity(dim);
        for i in 0..dim {
            let raw = (seed * 1103515245 + i * 12345) % 2147483648;
            let unit = raw as f32 / 2147483648.0;
            components.push(unit * 2.0 - 1.0);
        }
        components
    }

    /// Adds the vectors for seeds `0..count` to `index`, ignoring refusals.
    fn fill(index: &mut TieredIndex, dim: usize, count: usize) {
        for seed in 0..count {
            index.add(&random_vector(dim, seed));
        }
    }

    #[test]
    fn test_tiered_index_basic() {
        let mut index = TieredIndex::new(64);
        let vectors = [
            random_vector(64, 1),
            random_vector(64, 2),
            random_vector(64, 3),
        ];
        for v in &vectors {
            index.add(v);
        }
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_tiered_search() {
        let mut index = TieredIndex::new(64);
        fill(&mut index, 64, 100);

        // The query equals the vector stored under id 0.
        let query = random_vector(64, 0);
        let results = index.search(&query, 5);
        assert_eq!(results.len(), 5);
        assert_eq!(results[0].id, 0);
    }

    #[test]
    fn test_tiered_with_fp32() {
        let mut index = TieredIndex::with_fp32_storage(64);
        fill(&mut index, 64, 50);

        let query = random_vector(64, 0);
        let config = TieredSearchConfig::quality();
        let results = index.search_with_config(&query, 5, &config);
        assert_eq!(results.len(), 5);
        assert!(results[0].fp32_distance.is_some());
    }

    #[test]
    fn test_memory_stats() {
        let mut index = TieredIndex::new(1024);
        fill(&mut index, 1024, 1000);

        let stats = index.memory_stats();
        assert!(stats.binary_bytes > 100_000);
        assert!(stats.binary_bytes < 200_000);
        assert!(stats.int8_bytes > 1_000_000);
        assert!(stats.int8_bytes < 1_500_000);
        assert!(stats.compression_ratio() > 2.0);
    }

    #[test]
    fn test_binary_only_search() {
        let mut index = TieredIndex::new(128);
        fill(&mut index, 128, 100);

        let query = random_vector(128, 50);
        let results = index.search_binary_only(&query, 10);
        assert_eq!(results.len(), 10);
        assert!(results[0].int8_distance.is_none());
    }

    #[test]
    fn test_search_configs() {
        let mut index = TieredIndex::with_fp32_storage(64);
        fill(&mut index, 64, 100);

        let query = random_vector(64, 0);
        let fast = index.search_with_config(&query, 5, &TieredSearchConfig::fast());
        let quality = index.search_with_config(&query, 5, &TieredSearchConfig::quality());
        let precise = index.search_with_config(&query, 5, &TieredSearchConfig::precise());

        for results in [&fast, &quality, &precise] {
            assert_eq!(results.len(), 5);
        }
        assert!(quality[0].fp32_distance.is_some());
        assert!(precise[0].fp32_distance.is_some());
    }

    #[test]
    fn test_builder() {
        let builder = TieredIndexBuilder::new(256)
            .with_capacity(1000)
            .with_fp32_storage();
        let index = builder.build();
        assert_eq!(index.dim(), 256);
        assert!(index.is_empty());
    }

    #[test]
    fn test_memory_constrained() {
        let mut index = TieredIndex::memory_constrained(64, 100 * 1024);
        assert!(index.is_constrained());
        assert!(index.can_add());

        let limit = index.memory_constraint().unwrap().max_vectors;
        assert!(limit > 1000);
        assert!(limit < 1500);

        // Insert until the first refusal and count the successes.
        let added = (0..2000)
            .take_while(|&seed| index.add(&random_vector(64, seed)))
            .count();
        assert!(added < 2000);
        assert_eq!(index.len(), added);
        assert!(!index.can_add());
    }

    #[test]
    fn test_memory_constrained_batch() {
        let mut index = TieredIndex::memory_constrained(32, 50 * 1024);
        let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(1000);
        for seed in 0..1000 {
            vectors.push(random_vector(32, seed));
        }

        let (added, remaining) = index.add_batch_partial(&vectors);
        assert!(added > 0);
        assert!(added < 1000);
        assert_eq!(added + remaining, 1000);
        assert_eq!(index.len(), added);
    }

    #[test]
    fn test_memory_constrained_try_add() {
        let mut index = TieredIndex::memory_constrained(16, 1024);
        let max = index.memory_constraint().unwrap().max_vectors;

        for seed in 0..max {
            assert!(index.try_add(&random_vector(16, seed)).is_ok());
        }

        // One past the limit must be refused with an accurate report.
        let result = index.try_add(&random_vector(16, max + 1));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.current_vectors, max);
        assert_eq!(err.max_vectors, max);
    }

    #[test]
    fn test_memory_utilization() {
        let mut index = TieredIndex::memory_constrained(64, 10 * 1024);
        assert_eq!(index.memory_utilization(), Some(0.0));

        let half = index.memory_constraint().unwrap().max_vectors / 2;
        fill(&mut index, 64, half);

        let util = index.memory_utilization().unwrap();
        assert!(util > 0.4 && util < 0.6);
    }

    #[test]
    fn test_remaining_capacity() {
        let mut index = TieredIndex::memory_constrained(64, 20 * 1024);
        let initial = index.remaining_capacity().unwrap();
        assert!(initial > 0);

        index.add(&random_vector(64, 0));
        let after = index.remaining_capacity().unwrap();
        assert_eq!(after, initial - 1);
    }

    #[test]
    fn test_bytes_per_vector_calculation() {
        let without_fp32 = MemoryConstraint::bytes_per_vector(1024, false);
        assert_eq!(without_fp32, 128 + 1032);

        let with_fp32 = MemoryConstraint::bytes_per_vector(1024, true);
        assert_eq!(with_fp32, 128 + 1032 + 4096);
    }
}