use alloc::vec::Vec;
use crate::vector_ops::{DistanceMetric, dot_product, euclidean_distance_sq, manhattan_distance};
use super::pq::Codebooks;
/// Per-query lookup table of sub-vector-to-centroid distances (ADC —
/// presumably "asymmetric distance computation"; confirm against the
/// surrounding module docs).
///
/// Built once per query with `AdcTable::build`; afterwards the distance to
/// any PQ-encoded vector is approximated by summing one table entry per
/// code byte (`AdcTable::approximate_distance`).
pub struct AdcTable {
/// Flat, row-major table with 256 entries per sub-vector: entry
/// `m * 256 + k` is the distance between the `m`-th query sub-vector and
/// centroid `k` of sub-quantizer `m`. Empty when the query was too short.
distances: Vec<f32>,
/// Number of sub-vectors (rows) the table was built for; `0` when the
/// query passed to `build` was shorter than `num_subvectors * sub_dim`.
num_subvectors: usize,
}
impl AdcTable {
    /// Precomputes the distance from every sub-vector of `query` to each of
    /// the 256 centroids of the corresponding sub-quantizer in `codebooks`.
    ///
    /// The result is stored row-major: row `s` holds the 256 distances for
    /// query sub-vector `s`, each measured with `metric`.
    ///
    /// If `query` holds fewer than `num_subvectors * sub_dim` elements, an
    /// empty table (zero sub-vectors) is returned instead of panicking; such
    /// a table reports `0.0` for every lookup. Elements of `query` beyond the
    /// required length are ignored.
    pub fn build(query: &[f32], codebooks: &Codebooks, metric: DistanceMetric) -> Self {
        let subvector_count = codebooks.num_subvectors;
        let width = codebooks.sub_dim;

        // Saturating multiply: an overflowing product can never slip past the
        // length check and cause an out-of-bounds slice below.
        if query.len() < subvector_count.saturating_mul(width) {
            return Self {
                distances: Vec::new(),
                num_subvectors: 0,
            };
        }

        let mut table = Vec::with_capacity(subvector_count * 256);
        for row in 0..subvector_count {
            let start = row * width;
            let query_part = &query[start..start + width];
            // One row: distance from this query sub-vector to all 256 centroids.
            table.extend(
                (0..256).map(|k| subvector_distance(query_part, codebooks.centroid(row, k), metric)),
            );
        }

        Self {
            distances: table,
            num_subvectors: subvector_count,
        }
    }

    /// Approximates the distance between the query this table was built for
    /// and a PQ-encoded vector by summing one table entry per code byte.
    ///
    /// Only the first `num_subvectors` codes are used; if `pq_codes` is
    /// shorter, only the codes present contribute. Any lookup that falls
    /// outside the table is skipped rather than panicking.
    #[inline]
    pub fn approximate_distance(&self, pq_codes: &[u8]) -> f32 {
        pq_codes
            .iter()
            .take(self.num_subvectors)
            .enumerate()
            .filter_map(|(row, &code)| self.distances.get(row * 256 + code as usize))
            // Explicit fold from +0.0 keeps the left-to-right summation order.
            .fold(0.0f32, |total, &entry| total + entry)
    }
}
/// Compact debug output: prints the sub-vector count and the number of
/// table entries instead of dumping every stored distance.
impl core::fmt::Debug for AdcTable {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let entry_count = self.distances.len();
        let mut summary = f.debug_struct("AdcTable");
        summary.field("num_subvectors", &self.num_subvectors);
        summary.field("table_entries", &entry_count);
        summary.finish()
    }
}
/// Distance between one query sub-vector and one centroid under `metric`.
///
/// `DotProduct` and `Cosine` both yield the *negated* dot product, so that a
/// larger similarity maps to a smaller "distance". NOTE(review): treating
/// cosine as a plain dot product presumably relies on the inputs being
/// normalised elsewhere — confirm with the caller.
#[inline]
fn subvector_distance(query_sub: &[f32], centroid: &[f32], metric: DistanceMetric) -> f32 {
    match metric {
        DistanceMetric::Manhattan => manhattan_distance(query_sub, centroid),
        DistanceMetric::EuclideanSq => euclidean_distance_sq(query_sub, centroid),
        DistanceMetric::Cosine | DistanceMetric::DotProduct => {
            let similarity = dot_product(query_sub, centroid);
            -similarity
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ivfpq::pq::train_codebooks;

    /// Querying with a vector that was part of the training set, against its
    /// own PQ codes, should give an approximate distance close to zero.
    #[test]
    fn adc_matches_exact_distance() {
        // Four 8-element training rows.
        #[rustfmt::skip]
        let samples: Vec<f32> = vec![
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0,
        ];
        let books = train_codebooks(&samples, 8, 2, 25, DistanceMetric::EuclideanSq).unwrap();

        // The query is the first training row, which is also the encoded vector.
        let first_row = &samples[..8];
        let table = AdcTable::build(first_row, &books, DistanceMetric::EuclideanSq);
        let self_codes = books.encode(first_row);

        let approx_dist = table.approximate_distance(&self_codes);
        assert!(
            approx_dist < 0.5,
            "expected near-zero approx distance for self, got {approx_dist}"
        );
    }

    /// A vector closer to the query must receive a smaller approximate
    /// distance than one that is farther away.
    #[test]
    fn adc_ordering_preserved() {
        // Four 4-element training rows spread along the diagonal.
        #[rustfmt::skip]
        let samples: Vec<f32> = vec![
            0.0, 0.0, 0.0, 0.0,
            1.0, 1.0, 1.0, 1.0,
            5.0, 5.0, 5.0, 5.0,
            10.0, 10.0, 10.0, 10.0,
        ];
        let books = train_codebooks(&samples, 4, 2, 25, DistanceMetric::EuclideanSq).unwrap();

        let origin = [0.0f32; 4];
        let table = AdcTable::build(&origin, &books, DistanceMetric::EuclideanSq);

        let near_codes = books.encode(&[1.0; 4]);
        let far_codes = books.encode(&[10.0; 4]);
        let d_near = table.approximate_distance(&near_codes);
        let d_far = table.approximate_distance(&far_codes);

        assert!(
            d_near < d_far,
            "ordering violated: near={d_near}, far={d_far}"
        );
    }
}