use crate::list_bounds::{DistanceMetric, SphericalCapMetadata};
/// Storage format for the centroid table.
///
/// The variants are listed from largest to smallest per-centroid footprint;
/// [`CentroidCompression::recommend`] walks them in that order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CentroidCompression {
    /// 32-bit floats: 4 bytes per dimension.
    Fp32,
    /// 16-bit floats: 2 bytes per dimension.
    Fp16,
    /// 8-bit integer codes: 1 byte per dimension.
    Int8,
    /// `n_subquantizers` codes of `n_bits` bits each, independent of `dim`.
    PQ { n_subquantizers: usize, n_bits: u8 },
    /// Sized exactly like [`CentroidCompression::PQ`] by `bytes_per_centroid`.
    OPQ { n_subquantizers: usize, n_bits: u8 },
}

impl CentroidCompression {
    /// Returns the number of bytes one centroid of dimension `dim` occupies.
    ///
    /// For `PQ`/`OPQ` the code bits are rounded up to whole bytes and `dim`
    /// is ignored. Arithmetic saturates instead of overflowing.
    pub fn bytes_per_centroid(&self, dim: usize) -> usize {
        match *self {
            Self::Fp32 => dim.saturating_mul(4),
            Self::Fp16 => dim.saturating_mul(2),
            Self::Int8 => dim,
            Self::PQ {
                n_subquantizers,
                n_bits,
            }
            | Self::OPQ {
                n_subquantizers,
                n_bits,
            } => {
                let total_bits = n_subquantizers.saturating_mul(usize::from(n_bits));
                // Ceiling division without the `+ 7` that could itself overflow.
                total_bits / 8 + usize::from(total_bits % 8 != 0)
            }
        }
    }

    /// Returns `true` when `n_centroids` centroids of dimension `dim` need at
    /// most `cache_bytes` bytes in this format.
    ///
    /// A table whose size does not fit in `usize` is reported as not fitting
    /// (the unchecked product used to panic in debug builds and wrap around
    /// to a small number in release builds).
    pub fn fits_in_cache(&self, n_centroids: usize, dim: usize, cache_bytes: usize) -> bool {
        self.bytes_per_centroid(dim)
            .checked_mul(n_centroids)
            .map_or(false, |total| total <= cache_bytes)
    }

    /// Picks the least lossy format whose table fits in `cache_bytes`.
    ///
    /// Candidates are tried in the order `Fp32`, `Fp16`, `Int8`, then `PQ`
    /// with `dim / 4` 8-bit subquantizers (at least one, so that a tiny `dim`
    /// does not yield an empty zero-byte code that trivially "fits"). When
    /// nothing fits, `PQ { n_subquantizers: 16, n_bits: 4 }` is returned as a
    /// last resort without a fit check.
    pub fn recommend(n_centroids: usize, dim: usize, cache_bytes: usize) -> Self {
        let candidates = [
            Self::Fp32,
            Self::Fp16,
            Self::Int8,
            Self::PQ {
                n_subquantizers: (dim / 4).max(1),
                n_bits: 8,
            },
        ];
        candidates
            .iter()
            .copied()
            .find(|candidate| candidate.fits_in_cache(n_centroids, dim, cache_bytes))
            .unwrap_or(Self::PQ {
                n_subquantizers: 16,
                n_bits: 4,
            })
    }
}
/// Tunable parameters consumed when building and querying a `RoutingLayer`.
#[derive(Debug, Clone)]
pub struct RoutingConfig {
/// Storage format requested for the coarse centroid table.
pub compression: CentroidCompression,
/// How many of the best coarse-scored lists are kept for the refinement pass.
pub refine_top_k: usize,
/// When `true`, an `f32` copy of the centroids is retained and used to re-score
/// the `refine_top_k` survivors.
pub full_precision_refine: bool,
/// Byte budget the compressed centroid table is compared against.
pub target_llc_bytes: usize,
/// Metric whose `higher_is_better()` decides the ranking direction.
pub metric: DistanceMetric,
/// NOTE(review): not read anywhere in this file; presumably a prefetch hint for
/// a consumer elsewhere — confirm.
pub prefetch_depth: usize,
}
impl Default for RoutingConfig {
    /// Builds the baseline configuration: FP16 centroids, cosine metric,
    /// full-precision refinement of the top 64 lists, a 32 MiB cache budget
    /// and a prefetch depth of 4.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        Self {
            metric: DistanceMetric::Cosine,
            compression: CentroidCompression::Fp16,
            full_precision_refine: true,
            refine_top_k: 64,
            target_llc_bytes: 32 * MIB,
            prefetch_depth: 4,
        }
    }
}
impl RoutingConfig {
    /// Returns the configuration with `compression` replaced.
    pub fn compression(self, compression: CentroidCompression) -> Self {
        Self {
            compression,
            ..self
        }
    }

    /// Returns the configuration with `refine_top_k` set to `k`.
    pub fn refine_top_k(self, k: usize) -> Self {
        Self {
            refine_top_k: k,
            ..self
        }
    }

    /// Returns the configuration with the cache budget set to `bytes`.
    pub fn target_llc(self, bytes: usize) -> Self {
        Self {
            target_llc_bytes: bytes,
            ..self
        }
    }

    /// Returns the configuration with `metric` replaced.
    pub fn metric(self, metric: DistanceMetric) -> Self {
        Self { metric, ..self }
    }
}
/// Centroid table stored as IEEE 754 binary16 bit patterns, row-major.
#[derive(Debug, Clone)]
pub struct Fp16Centroids {
    /// `n_centroids * dim` half-precision values, one centroid after another.
    data: Vec<u16>,
    /// Number of complete centroids held in `data`.
    n_centroids: usize,
    /// Number of components per centroid.
    dim: usize,
}

impl Fp16Centroids {
    /// Converts a row-major `f32` centroid table into half precision.
    ///
    /// Only complete centroids are kept: trailing values that do not fill a
    /// whole row of `dim` components are dropped, so `memory_bytes` reports
    /// exactly what can be addressed. A `dim` of zero yields an empty table
    /// instead of a division-by-zero panic.
    pub fn from_fp32(centroids: &[f32], dim: usize) -> Self {
        let n_centroids = centroids.len().checked_div(dim).unwrap_or(0);
        let data: Vec<u16> = centroids[..n_centroids * dim]
            .iter()
            .map(|&x| f32_to_f16(x))
            .collect();
        Self {
            data,
            n_centroids,
            dim,
        }
    }

    /// Returns centroid `idx` widened back to `f32`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not address a stored centroid.
    pub fn get_fp32(&self, idx: usize) -> Vec<f32> {
        let start = idx * self.dim;
        self.data[start..start + self.dim]
            .iter()
            .map(|&x| f16_to_f32(x))
            .collect()
    }

    /// Returns the dot product of `query` with every centroid, in table order.
    ///
    /// The query is first rounded to half precision so both operands carry
    /// the same precision. If `query` and the centroids differ in length, the
    /// product runs over the shorter of the two.
    pub fn dot_products(&self, query: &[f32]) -> Vec<f32> {
        let mut scores = Vec::with_capacity(self.n_centroids);
        if self.dim == 0 {
            // `chunks_exact(0)` would panic; an empty table has no scores.
            return scores;
        }
        let query_f16: Vec<u16> = query.iter().map(|&x| f32_to_f16(x)).collect();
        scores.extend(
            self.data
                .chunks_exact(self.dim)
                .map(|centroid| dot_f16(centroid, &query_f16)),
        );
        scores
    }

    /// Returns the number of bytes occupied by the stored half-precision values.
    pub fn memory_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<u16>()
    }
}
/// Centroid table quantized to 8 bits per component with a separate affine
/// mapping (scale and zero point) for every dimension.
#[derive(Debug, Clone)]
pub struct Int8Centroids {
    /// Row-major codes. `code + 128` is the quantization level in `0..=255`.
    data: Vec<i8>,
    /// Per-dimension step between adjacent quantization levels.
    scales: Vec<f32>,
    /// Per-dimension value represented by level 0 (the column minimum).
    zero_points: Vec<f32>,
    /// Number of complete centroids held in `data`.
    n_centroids: usize,
    /// Number of components per centroid.
    dim: usize,
}

impl Int8Centroids {
    /// Shift that maps the unsigned level range `0..=255` onto `i8`.
    const CODE_OFFSET: i32 = 128;

    /// Maps `val` to its signed code for one dimension.
    fn quantize(val: f32, zero_point: f32, scale: f32) -> i8 {
        // The float-to-int cast saturates and maps NaN to 0.
        let level = ((val - zero_point) / scale).round() as i32;
        (level.clamp(0, 255) - Self::CODE_OFFSET) as i8
    }

    /// Maps a signed code back to the value it represents for one dimension.
    fn dequantize(code: i8, zero_point: f32, scale: f32) -> f32 {
        (i32::from(code) + Self::CODE_OFFSET) as f32 * scale + zero_point
    }

    /// Quantizes a row-major `f32` centroid table.
    ///
    /// Each dimension gets `scale = (max - min) / 255` and `zero_point = min`,
    /// so the 256 levels cover the observed range. Levels are stored shifted
    /// by 128 to fit `i8`; previously the unshifted level was clamped to
    /// `-128..=127`, which saturated the upper half of every range.
    ///
    /// Only complete centroids are kept, and a `dim` of zero yields an empty
    /// table instead of a division-by-zero panic. A dimension whose range is
    /// at most `1e-10` uses a scale of `1.0`.
    pub fn from_fp32(centroids: &[f32], dim: usize) -> Self {
        let n_centroids = centroids.len().checked_div(dim).unwrap_or(0);
        let used = &centroids[..n_centroids * dim];

        // Per-dimension extrema over all complete centroids.
        let mut mins = vec![f32::MAX; dim];
        let mut maxs = vec![f32::MIN; dim];
        for (idx, &val) in used.iter().enumerate() {
            // `dim > 0` whenever `used` is non-empty.
            let j = idx % dim;
            mins[j] = mins[j].min(val);
            maxs[j] = maxs[j].max(val);
        }

        let mut scales = Vec::with_capacity(dim);
        let mut zero_points = Vec::with_capacity(dim);
        for (&lo, &hi) in mins.iter().zip(maxs.iter()) {
            let range = hi - lo;
            scales.push(if range > 1e-10 { range / 255.0 } else { 1.0 });
            // `lo > hi` only when the column saw no usable value.
            zero_points.push(if lo <= hi { lo } else { 0.0 });
        }

        let mut data = Vec::with_capacity(used.len());
        for (idx, &val) in used.iter().enumerate() {
            let j = idx % dim;
            data.push(Self::quantize(val, zero_points[j], scales[j]));
        }

        Self {
            data,
            scales,
            zero_points,
            n_centroids,
            dim,
        }
    }

    /// Returns the dequantized approximation of centroid `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not address a stored centroid.
    pub fn get_fp32(&self, idx: usize) -> Vec<f32> {
        let start = idx * self.dim;
        self.data[start..start + self.dim]
            .iter()
            .zip(self.scales.iter())
            .zip(self.zero_points.iter())
            .map(|((&code, &scale), &zero_point)| Self::dequantize(code, zero_point, scale))
            .collect()
    }

    /// Returns the dot product of `query` with every dequantized centroid,
    /// in table order.
    ///
    /// With `c_j = (code_j + 128) * s_j + z_j` the product expands to
    /// `sum(q_j * s_j * code_j) + sum(q_j * (z_j + 128 * s_j))`. The second
    /// sum does not depend on the centroid, so it is computed once. The
    /// earlier formula multiplied an integer dot product by `scales[0]^2`,
    /// ignoring zero points and per-dimension scales.
    ///
    /// Components of `query` beyond `dim` are ignored; a shorter query only
    /// contributes the components it has.
    pub fn dot_products(&self, query: &[f32]) -> Vec<f32> {
        let mut scores = Vec::with_capacity(self.n_centroids);
        if self.dim == 0 {
            // `chunks_exact(0)` would panic; an empty table has no scores.
            return scores;
        }

        let mut weights = Vec::with_capacity(self.dim);
        let mut bias = 0.0f32;
        for ((&q, &scale), &zero_point) in query
            .iter()
            .zip(self.scales.iter())
            .zip(self.zero_points.iter())
        {
            weights.push(q * scale);
            bias += q * (zero_point + Self::CODE_OFFSET as f32 * scale);
        }

        scores.extend(self.data.chunks_exact(self.dim).map(|codes| {
            let weighted: f32 = codes
                .iter()
                .zip(weights.iter())
                .map(|(&code, &weight)| f32::from(code) * weight)
                .sum();
            weighted + bias
        }));
        scores
    }

    /// Returns the bytes used by the codes plus both per-dimension tables.
    pub fn memory_bytes(&self) -> usize {
        self.data.len() + self.scales.len() * 4 + self.zero_points.len() * 4
    }
}
/// Maps a query to the inverted lists most worth probing, using a compressed
/// centroid table for coarse scoring and an optional `f32` copy for refinement.
pub struct RoutingLayer {
/// Centroid table in the storage format chosen at build time.
compressed: CompressedCentroids,
/// `f32` copy of the centroids; present only when `full_precision_refine` was set.
full_precision: Option<Vec<f32>>,
/// Per-list cap metadata, one entry per list.
caps: Vec<SphericalCapMetadata>,
/// Configuration the layer was built with.
config: RoutingConfig,
/// Number of components per centroid.
dim: usize,
/// Number of lists, i.e. number of centroids.
n_lists: usize,
}
/// Centroid table in whichever storage format was selected when the layer was built.
enum CompressedCentroids {
/// Row-major `f32` values, stored as given.
Fp32(Vec<f32>),
/// Half-precision table.
Fp16(Fp16Centroids),
/// 8-bit quantized table.
Int8(Int8Centroids),
}
impl RoutingLayer {
    /// Builds a routing layer over a row-major table of `centroids` with
    /// `dim` components each.
    ///
    /// `PQ` and `OPQ` have no dedicated storage here and are stored as FP16;
    /// `stats` still reports the configured format name. Every list starts
    /// with placeholder cap metadata (zero radius, zero vectors) until
    /// [`RoutingLayer::update_cap`] supplies real values.
    ///
    /// # Panics
    ///
    /// Panics if `dim` is zero.
    pub fn build(centroids: &[f32], dim: usize, config: RoutingConfig) -> Self {
        let n_lists = centroids.len() / dim;

        let compressed = match config.compression {
            CentroidCompression::Fp32 => CompressedCentroids::Fp32(centroids.to_vec()),
            CentroidCompression::Fp16 => {
                CompressedCentroids::Fp16(Fp16Centroids::from_fp32(centroids, dim))
            }
            CentroidCompression::Int8 => {
                CompressedCentroids::Int8(Int8Centroids::from_fp32(centroids, dim))
            }
            // Listed explicitly so a new variant forces a decision here.
            CentroidCompression::PQ { .. } | CentroidCompression::OPQ { .. } => {
                CompressedCentroids::Fp16(Fp16Centroids::from_fp32(centroids, dim))
            }
        };

        let full_precision = if config.full_precision_refine {
            Some(centroids.to_vec())
        } else {
            None
        };

        let caps: Vec<SphericalCapMetadata> = (0..n_lists)
            .map(|i| {
                let centroid = &centroids[i * dim..(i + 1) * dim];
                SphericalCapMetadata {
                    centroid: centroid.to_vec(),
                    theta_max: 0.0,
                    min_dot_to_centroid: 1.0,
                    max_dot_to_centroid: 1.0,
                    vector_count: 0,
                    mean_dot_to_centroid: 1.0,
                }
            })
            .collect();

        Self {
            compressed,
            full_precision,
            caps,
            config,
            dim,
            n_lists,
        }
    }

    /// Returns up to `n_probes` lists ranked best-first for `query`.
    ///
    /// All lists are scored against the compressed centroids, the best
    /// `refine_top_k` are kept (never fewer than `n_probes`, so the caller
    /// gets as many lists as requested), those survivors are re-scored in
    /// full precision when a copy is available, and the final ranking is
    /// truncated to `n_probes`. NaN scores rank last instead of panicking.
    ///
    /// Returns an empty vector when `n_probes` is zero or there are no lists.
    ///
    /// NOTE(review): both passes score with dot products regardless of the
    /// metric; only the ranking direction follows `higher_is_better()`.
    /// Confirm that is intended for metrics where lower is better.
    pub fn route(&self, query: &[f32], n_probes: usize) -> Vec<ListCandidate> {
        let n_probes = n_probes.min(self.n_lists);
        if n_probes == 0 {
            return Vec::new();
        }
        // At least one here, so `refine_k - 1` below cannot underflow.
        let refine_k = self.config.refine_top_k.max(n_probes).min(self.n_lists);
        let higher_is_better = self.config.metric.higher_is_better();

        let coarse_scores = self.coarse_scores(query);
        let mut indices: Vec<usize> = (0..self.n_lists).collect();
        indices.select_nth_unstable_by(refine_k - 1, |&a, &b| {
            Self::compare_scores(coarse_scores[a], coarse_scores[b], higher_is_better)
        });
        let top_indices = &indices[..refine_k];

        let refined_scores: Vec<f32> = match self.full_precision {
            Some(ref full) => self.refine_scores(query, top_indices, full),
            None => top_indices.iter().map(|&i| coarse_scores[i]).collect(),
        };

        let mut candidates: Vec<ListCandidate> = top_indices
            .iter()
            .zip(refined_scores.iter())
            .map(|(&idx, &score)| ListCandidate {
                // List counts are assumed to fit in `u32`.
                list_idx: idx as u32,
                score,
                bound: self.compute_bound(idx, query),
                vector_count: self.caps[idx].vector_count,
            })
            .collect();
        candidates.sort_by(|a, b| Self::compare_scores(a.score, b.score, higher_is_better));
        candidates.truncate(n_probes);
        candidates
    }

    /// Orders two scores best-first; NaN always sorts after every number.
    fn compare_scores(a: f32, b: f32, higher_is_better: bool) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => {
                let ordering = if higher_is_better {
                    b.partial_cmp(&a)
                } else {
                    a.partial_cmp(&b)
                };
                // Neither operand is NaN, so the comparison is always defined.
                ordering.unwrap_or(Ordering::Equal)
            }
        }
    }

    /// Scores `query` against every centroid using the compressed table.
    fn coarse_scores(&self, query: &[f32]) -> Vec<f32> {
        match &self.compressed {
            CompressedCentroids::Fp32(data) => self.dot_products_fp32(query, data),
            CompressedCentroids::Fp16(fp16) => fp16.dot_products(query),
            CompressedCentroids::Int8(int8) => int8.dot_products(query),
        }
    }

    /// Dot product of `query` with every centroid of a row-major `f32` table.
    fn dot_products_fp32(&self, query: &[f32], centroids: &[f32]) -> Vec<f32> {
        (0..self.n_lists)
            .map(|i| {
                let centroid = &centroids[i * self.dim..(i + 1) * self.dim];
                dot_product_f32(query, centroid)
            })
            .collect()
    }

    /// Dot product of `query` with the centroids selected by `indices`.
    fn refine_scores(&self, query: &[f32], indices: &[usize], centroids: &[f32]) -> Vec<f32> {
        indices
            .iter()
            .map(|&i| {
                let centroid = &centroids[i * self.dim..(i + 1) * self.dim];
                dot_product_f32(query, centroid)
            })
            .collect()
    }

    /// Returns the cosine of the smallest angle between `query` and the cap
    /// of list `idx`: the query-to-centroid angle reduced by `theta_max` and
    /// floored at zero.
    ///
    /// The dot product is clamped to `[-1, 1]` before `acos`, which presumes
    /// unit-length query and centroid — TODO confirm against callers.
    fn compute_bound(&self, idx: usize, query: &[f32]) -> f32 {
        let cap = &self.caps[idx];
        let dot = dot_product_f32(query, &cap.centroid);
        let angle = dot.clamp(-1.0, 1.0).acos();
        let min_angle = (angle - cap.theta_max).max(0.0);
        min_angle.cos()
    }

    /// Replaces the cap metadata of list `list_idx`; out-of-range indices are
    /// ignored.
    pub fn update_cap(&mut self, list_idx: usize, cap: SphericalCapMetadata) {
        if let Some(slot) = self.caps.get_mut(list_idx) {
            *slot = cap;
        }
    }

    /// Bytes occupied by the compressed centroid table alone.
    fn compressed_bytes(&self) -> usize {
        match &self.compressed {
            CompressedCentroids::Fp32(data) => data.len() * std::mem::size_of::<f32>(),
            CompressedCentroids::Fp16(fp16) => fp16.memory_bytes(),
            CompressedCentroids::Int8(int8) => int8.memory_bytes(),
        }
    }

    /// Returns the bytes used by the compressed table, the optional
    /// full-precision copy and the cap metadata.
    ///
    /// The cap figure includes each cap's heap-allocated centroid, which
    /// `size_of` alone does not account for.
    pub fn memory_bytes(&self) -> usize {
        let full_bytes = self
            .full_precision
            .as_ref()
            .map_or(0, |full| full.len() * std::mem::size_of::<f32>());
        let cap_bytes: usize = self
            .caps
            .iter()
            .map(|cap| {
                std::mem::size_of::<SphericalCapMetadata>()
                    + cap.centroid.len() * std::mem::size_of::<f32>()
            })
            .sum();
        self.compressed_bytes() + full_bytes + cap_bytes
    }

    /// Returns `true` when the compressed table is within `target_llc_bytes`.
    pub fn fits_in_cache(&self) -> bool {
        self.compressed_bytes() <= self.config.target_llc_bytes
    }

    /// Returns a snapshot of the layer's shape and memory footprint.
    pub fn stats(&self) -> RoutingStats {
        RoutingStats {
            n_lists: self.n_lists,
            dim: self.dim,
            compression: format!("{:?}", self.config.compression),
            compressed_bytes: self.compressed_bytes(),
            total_bytes: self.memory_bytes(),
            fits_in_cache: self.fits_in_cache(),
            target_cache_bytes: self.config.target_llc_bytes,
        }
    }
}
/// One list selected by `RoutingLayer::route`.
#[derive(Debug, Clone)]
pub struct ListCandidate {
/// Index of the list (row of the centroid table).
pub list_idx: u32,
/// Final score used for ranking: the refined score when a full-precision copy
/// exists, otherwise the coarse score.
pub score: f32,
/// Value produced by `compute_bound` for this list and query.
pub bound: f32,
/// `vector_count` copied from the list's cap metadata.
pub vector_count: u32,
}
/// Summary of a routing layer's shape and memory use, as returned by `stats`.
#[derive(Debug, Clone)]
pub struct RoutingStats {
/// Number of lists (centroids).
pub n_lists: usize,
/// Number of components per centroid.
pub dim: usize,
/// `Debug` rendering of the configured compression format.
pub compression: String,
/// Bytes used by the compressed centroid table.
pub compressed_bytes: usize,
/// Value of `RoutingLayer::memory_bytes` at the time of the snapshot.
pub total_bytes: usize,
/// Whether `compressed_bytes` is within `target_cache_bytes`.
pub fits_in_cache: bool,
/// Cache budget taken from the configuration.
pub target_cache_bytes: usize,
}
/// Converts an `f32` to the bit pattern of the nearest IEEE 754 binary16 value.
///
/// Rounding is to nearest, ties to even. Values too large for binary16 become
/// infinity, values in the binary16 subnormal range are encoded as subnormals
/// rather than flushed to zero, and anything smaller becomes a signed zero.
/// NaN stays NaN: the quiet bit is forced so a payload that lives only in the
/// low mantissa bits cannot collapse into the infinity encoding.
#[inline]
fn f32_to_f16(x: f32) -> u16 {
    let bits = x.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xff) as i32;
    let frac = bits & 0x007f_ffff;

    // Infinity and NaN share the all-ones exponent.
    if exp == 0xff {
        let quiet: u16 = if frac != 0 { 0x0200 } else { 0 };
        return sign | 0x7c00 | quiet | (frac >> 13) as u16;
    }

    // Re-bias the exponent from 127 (binary32) to 15 (binary16).
    let half_exp = exp - 127 + 15;
    if half_exp >= 31 {
        return sign | 0x7c00;
    }

    if half_exp <= 0 {
        // Result is subnormal or zero. Shifting the 24-bit significand
        // (implicit one restored) by more than 24 leaves nothing, not even a
        // rounding bit; this also covers binary32 zeros and subnormals.
        let shift = (14 - half_exp) as u32;
        if shift > 24 {
            return sign;
        }
        let significand = frac | 0x0080_0000;
        let round_bit = 1u32 << (shift - 1);
        // Round up when the first dropped bit is set and either a lower
        // dropped bit is set or the kept value is odd (`3 * round_bit - 1`
        // masks exactly those bits).
        let round_up =
            (significand & round_bit) != 0 && (significand & (3 * round_bit - 1)) != 0;
        return sign | ((significand >> shift) + u32::from(round_up)) as u16;
    }

    let truncated = ((half_exp as u32) << 10) | (frac >> 13);
    let round_bit = 0x1000u32;
    let round_up = (frac & round_bit) != 0 && (frac & (3 * round_bit - 1)) != 0;
    // A carry out of the mantissa bumps the exponent, which is the correct
    // result (up to and including rounding to infinity).
    sign | (truncated + u32::from(round_up)) as u16
}
/// Widens an IEEE 754 binary16 bit pattern to `f32`.
///
/// Zeros keep their sign, subnormals are scaled by `2^-14`, infinities keep
/// their sign, and every NaN encoding maps to `f32::NAN`.
#[inline]
fn f16_to_f32(x: u16) -> f32 {
    let raw = u32::from(x);
    let sign_bit = (raw >> 15) & 1;
    let exponent = (raw >> 10) & 0x1f;
    let mantissa = raw & 0x3ff;
    let sign_mask = sign_bit << 31;

    match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => f32::from_bits(sign_mask),
        // Subnormal: no implicit leading one, fixed exponent of -14.
        (0, _) => {
            let magnitude = (mantissa as f32) / 1024.0 * 2.0f32.powi(-14);
            if sign_bit == 1 {
                -magnitude
            } else {
                magnitude
            }
        }
        // Signed infinity.
        (31, 0) => f32::from_bits(sign_mask | 0x7f800000),
        (31, _) => f32::NAN,
        // Normal number: re-bias the exponent (15 -> 127) and widen the mantissa.
        _ => {
            let rebased = (exponent as i32 - 15 + 127) as u32;
            f32::from_bits(sign_mask | (rebased << 23) | (mantissa << 13))
        }
    }
}
/// Dot product of two half-precision vectors given as bit patterns.
///
/// Each pair is widened to `f32` before multiplying and the products are
/// accumulated in `f32`. Extra elements of the longer slice are ignored.
#[inline]
fn dot_f16(a: &[u16], b: &[u16]) -> f32 {
    let products = a.iter().copied().zip(b.iter().copied()).map(|(lhs, rhs)| {
        let left = f16_to_f32(lhs);
        let right = f16_to_f32(rhs);
        left * right
    });
    products.sum()
}
/// Dot product of two `f32` slices; extra elements of the longer slice are
/// ignored.
#[inline]
fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-centroid sizes of the fixed-width formats at a typical dimension.
    #[test]
    fn test_compression_bytes() {
        let dim = 768;
        assert_eq!(CentroidCompression::Fp32.bytes_per_centroid(dim), 3072);
        assert_eq!(CentroidCompression::Fp16.bytes_per_centroid(dim), 1536);
        assert_eq!(CentroidCompression::Int8.bytes_per_centroid(dim), 768);
    }

    /// The recommendation steps down one format each time the table outgrows
    /// the cache budget.
    #[test]
    fn test_compression_recommendation() {
        let cache_32mb = 32 * 1024 * 1024;
        let dim = 768;
        let rec1 = CentroidCompression::recommend(10_000, dim, cache_32mb);
        assert!(matches!(rec1, CentroidCompression::Fp32));
        let rec2 = CentroidCompression::recommend(20_000, dim, cache_32mb);
        assert!(matches!(rec2, CentroidCompression::Fp16));
        let rec3 = CentroidCompression::recommend(40_000, dim, cache_32mb);
        assert!(matches!(rec3, CentroidCompression::Int8));
    }

    /// Round-tripping through half precision stays within 1% relative error.
    #[test]
    fn test_fp16_conversion() {
        let values = [0.0, 1.0, -1.0, 0.5, 0.123, 100.0, -100.0];
        for &x in &values {
            let f16 = f32_to_f16(x);
            let back = f16_to_f32(f16);
            let rel_error = if x.abs() > 1e-10 {
                (x - back).abs() / x.abs()
            } else {
                (x - back).abs()
            };
            assert!(
                rel_error < 0.01,
                "FP16 roundtrip error too high: {} -> {} -> {}",
                x,
                f16,
                back
            );
        }
    }

    /// Routing returns exactly the requested number of probes and a tiny
    /// table fits the default cache budget.
    #[test]
    fn test_routing_layer() {
        let dim = 4;
        let n_centroids = 10;
        let centroids: Vec<f32> = (0..n_centroids * dim)
            .map(|i| (i as f32 / (n_centroids * dim) as f32))
            .collect();
        let config = RoutingConfig::default()
            .compression(CentroidCompression::Fp16)
            .refine_top_k(5);
        let routing = RoutingLayer::build(&centroids, dim, config);
        let query = vec![0.5, 0.5, 0.5, 0.5];
        let candidates = routing.route(&query, 3);
        assert_eq!(candidates.len(), 3);
        assert!(routing.fits_in_cache());
    }

    /// The first centroid survives INT8 quantization within a loose tolerance.
    #[test]
    fn test_int8_centroids() {
        let dim = 4;
        let centroids = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let int8 = Int8Centroids::from_fp32(&centroids, dim);
        let recovered = int8.get_fp32(0);
        for i in 0..dim {
            let error = (recovered[i] - centroids[i]).abs();
            assert!(error < 0.1, "Int8 quantization error too high");
        }
    }
}