use super::code::{PQCode, bytes_for_nbits};
/// Symmetric distance computation (SDC) lookup table for product-quantized codes.
///
/// For each of the `M` subspaces it stores the distance between every ordered pair
/// of that subspace's `ksub` centroids, flattened row-major as `[m][i][j]`. The
/// distance between two codes is then the sum of `M` table lookups.
#[derive(Clone, Debug, PartialEq)]
pub struct SDCTable<const M: usize, const NBITS: usize>
where
    [(); bytes_for_nbits(NBITS)]:,
{
    /// Flattened pairwise centroid distances, indexed as `m * ksub * ksub + i * ksub + j`.
    table: Vec<f32>,
    /// Number of centroids per subspace (row/column count of each subspace's matrix).
    ksub: usize,
}
impl<const M: usize, const NBITS: usize> SDCTable<M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
{
    /// Number of centroids per subspace: `2^NBITS`.
    pub const KSUB: usize = 1 << NBITS;

    /// Builds the table using squared Euclidean distance between centroids.
    ///
    /// `centroids` holds `M` subspaces of `KSUB` centroids each. Every centroid is
    /// `dsub` contiguous floats, and centroid `i` of subspace `m` starts at
    /// `(m * KSUB + i) * dsub`.
    ///
    /// # Panics
    ///
    /// Panics if `centroids.len() != M * KSUB * dsub`.
    pub fn from_centroids(centroids: &[f32], dsub: usize) -> Self {
        Self::from_centroids_with_distance(centroids, dsub, |a, b| {
            a.iter()
                .zip(b.iter())
                .map(|(x, y)| {
                    let diff = x - y;
                    diff * diff
                })
                .sum()
        })
    }

    /// Builds the table with a caller-supplied per-subspace distance function.
    ///
    /// Entry `[m][i][j]` is `subspace_distance(c(m, i), c(m, j))`, where `c(m, i)` is
    /// the `i`-th centroid of subspace `m`. The closure is evaluated for every
    /// ordered pair, so it does not need to be symmetric.
    ///
    /// # Panics
    ///
    /// Panics if `centroids.len() != M * KSUB * dsub`.
    pub fn from_centroids_with_distance(
        centroids: &[f32],
        dsub: usize,
        subspace_distance: impl Fn(&[f32], &[f32]) -> f32,
    ) -> Self {
        let ksub = Self::KSUB;
        assert_eq!(
            centroids.len(),
            M * ksub * dsub,
            "centroids length mismatch"
        );

        // Plain index arithmetic (instead of `chunks_exact(dsub)`) keeps `dsub == 0` valid.
        let centroid = |m: usize, i: usize| {
            let start = (m * ksub + i) * dsub;
            &centroids[start..start + dsub]
        };

        // Rows are pushed in (m, i, j) order, which matches the flattened layout.
        let mut table = Vec::with_capacity(M * ksub * ksub);
        for m in 0..M {
            for i in 0..ksub {
                let ci = centroid(m, i);
                table.extend((0..ksub).map(|j| subspace_distance(ci, centroid(m, j))));
            }
        }
        Self { table, ksub }
    }

    /// Returns the symmetric distance between two PQ codes.
    ///
    /// The result is the sum over subspaces of the precomputed distance between the
    /// centroids the two codes select.
    #[inline]
    pub fn distance(&self, code1: &PQCode<M, NBITS>, code2: &PQCode<M, NBITS>) -> f32 {
        let ksub = self.ksub;
        let mut sum = 0.0f32;
        for m in 0..M {
            let i = code1.get(m) as usize;
            let j = code2.get(m) as usize;
            // An index >= ksub would not go out of bounds here. It would silently
            // read a cell from the next row or subspace, so catch it in debug builds.
            debug_assert!(
                i < ksub && j < ksub,
                "PQ code index out of range for SDC table"
            );
            sum += self.table[(m * ksub + i) * ksub + j];
        }
        sum
    }

    /// Number of centroids per subspace that this table was built for.
    pub fn ksub(&self) -> usize {
        self.ksub
    }

    /// Raw flattened table (`[m][i][j]`, row-major), e.g. for serialization.
    pub fn table_data(&self) -> &[f32] {
        &self.table
    }

    /// Reconstructs a table from its raw parts, as returned by [`Self::table_data`]
    /// and [`Self::ksub`].
    ///
    /// # Panics
    ///
    /// Panics if `table.len() != M * ksub * ksub`. Rejecting a mismatched table here
    /// avoids out-of-bounds panics or silently wrong distances later in `distance`.
    pub fn from_raw(table: Vec<f32>, ksub: usize) -> Self {
        assert_eq!(table.len(), M * ksub * ksub, "SDC table length mismatch");
        Self { table, ksub }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Centroids for `M = 2`, `NBITS = 2`, `dsub = 2`. Subspace 0 holds the corners
    /// of the unit square; subspace 1 holds the corners of the square with side 2.
    fn square_centroids() -> Vec<f32> {
        vec![
            0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, // subspace 0
            0.0, 0.0, 2.0, 0.0, 0.0, 2.0, 2.0, 2.0, // subspace 1
        ]
    }

    #[test]
    fn test_sdc_same_codes() {
        let centroids = square_centroids();
        let sdc = SDCTable::<2, 2>::from_centroids(&centroids, 2);
        let code1 = PQCode::<2, 2>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 2>::from_indices(&[0, 0]);
        assert_eq!(sdc.distance(&code1, &code2), 0.0);
        let code3 = PQCode::<2, 2>::from_indices(&[3, 3]);
        let code4 = PQCode::<2, 2>::from_indices(&[3, 3]);
        assert_eq!(sdc.distance(&code3, &code4), 0.0);
    }

    #[test]
    fn test_sdc_different_codes() {
        let centroids = square_centroids();
        let sdc = SDCTable::<2, 2>::from_centroids(&centroids, 2);
        let code1 = PQCode::<2, 2>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 2>::from_indices(&[1, 1]);
        // Subspace 0: (0,0)->(1,0) gives 1. Subspace 1: (0,0)->(2,0) gives 4.
        assert_eq!(sdc.distance(&code1, &code2), 5.0);
    }

    #[test]
    fn test_sdc_symmetry() {
        let centroids: Vec<f32> = vec![
            0.0, 0.0, //
            1.0, 1.0, //
            2.0, 2.0, //
            3.0, 3.0, //
            0.0, 0.0, //
            1.0, 1.0, //
            2.0, 2.0, //
            3.0, 3.0, //
        ];
        let sdc = SDCTable::<2, 2>::from_centroids(&centroids, 2);
        let code1 = PQCode::<2, 2>::from_indices(&[0, 1]);
        let code2 = PQCode::<2, 2>::from_indices(&[2, 3]);
        assert_eq!(sdc.distance(&code1, &code2), sdc.distance(&code2, &code1));
    }

    #[test]
    fn test_sdc_nbits8() {
        let centroids: Vec<f32> = vec![0.0; 2 * 256 * 2];
        let sdc = SDCTable::<2, 8>::from_centroids(&centroids, 2);
        let code1 = PQCode::<2, 8>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 8>::from_indices(&[255, 255]);
        assert_eq!(sdc.distance(&code1, &code2), 0.0);
    }

    #[test]
    fn test_sdc_custom_distance() {
        let centroids = square_centroids();
        // L1 (Manhattan) distance instead of the default squared L2.
        let sdc = SDCTable::<2, 2>::from_centroids_with_distance(&centroids, 2, |a, b| {
            a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
        });
        let code1 = PQCode::<2, 2>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 2>::from_indices(&[3, 3]);
        // Subspace 0: |0-1| + |0-1| = 2. Subspace 1: |0-2| + |0-2| = 4.
        assert_eq!(sdc.distance(&code1, &code2), 6.0);
    }

    #[test]
    fn test_sdc_from_raw_roundtrip() {
        let centroids = square_centroids();
        let sdc = SDCTable::<2, 2>::from_centroids(&centroids, 2);
        let rebuilt = SDCTable::<2, 2>::from_raw(sdc.table_data().to_vec(), sdc.ksub());
        assert_eq!(rebuilt, sdc);
    }

    #[test]
    #[should_panic(expected = "centroids length mismatch")]
    fn test_sdc_centroids_length_mismatch() {
        let centroids = vec![0.0f32; 3];
        let _ = SDCTable::<2, 2>::from_centroids(&centroids, 2);
    }
}