Skip to main content

bytesandbrains_codec/pq/
distance.rs

1use std::fmt;
2
3use bb_core::embedding::{Distance, EmbeddingSpace};
4
5use super::code::{PQCode, bytes_for_nbits};
6
7/// Precomputed distance table for Asymmetric Distance Computation (ADC).
8///
9/// Given a query vector, this table stores the squared distance from each
10/// query subvector to each centroid in that subspace. During search, the
11/// distance to an encoded vector is computed by summing table lookups.
12///
13/// Table layout: `table[m * ksub + k]` is the distance from query subvector
14/// `m` to centroid `k` in subspace `m`.
15///
16/// The const generics must match the ProductQuantizer configuration:
17/// - M: number of subquantizers
18/// - NBITS: bits per centroid index
19pub struct PQDistanceTable<S: EmbeddingSpace, const M: usize, const NBITS: usize>
20where
21    [(); bytes_for_nbits(NBITS)]:,
22{
23    table: Vec<S::DistanceValue>,
24    ksub: usize,
25}
26
27impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> fmt::Debug for PQDistanceTable<S, M, NBITS>
28where
29    [(); bytes_for_nbits(NBITS)]:,
30{
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        f.debug_struct("PQDistanceTable")
33            .field("M", &M)
34            .field("NBITS", &NBITS)
35            .field("ksub", &self.ksub)
36            .field("table_len", &self.table.len())
37            .finish()
38    }
39}
40
41impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> PQDistanceTable<S, M, NBITS>
42where
43    [(); bytes_for_nbits(NBITS)]:,
44{
45    /// Number of centroids per subspace (2^NBITS)
46    pub const KSUB: usize = 1 << NBITS;
47
48    pub fn new(table: Vec<S::DistanceValue>, ksub: usize) -> Self {
49        debug_assert_eq!(table.len(), M * ksub);
50        Self { table, ksub }
51    }
52
53    /// Compute approximate distance to an encoded vector using table lookups.
54    ///
55    /// This is the core of ADC: instead of computing the full distance,
56    /// we sum precomputed partial distances from each subspace.
57    pub fn distance(&self, code: &PQCode<M, NBITS>) -> S::DistanceValue {
58        (0..M)
59            .map(|m| {
60                let c = code.get(m) as usize;
61                self.table[m * self.ksub + c]
62            })
63            .fold(S::DistanceValue::zero(), |acc, d| acc + d)
64    }
65
66    pub fn m(&self) -> usize {
67        M
68    }
69
70    pub fn ksub(&self) -> usize {
71        self.ksub
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use bb_core::embedding::{F32Distance, F32L2Space};

    /// Small configuration: 2 subquantizers with 4 centroids each (NBITS=2).
    #[test]
    fn test_distance_table_lookup() {
        let ksub = 4;

        // Row 0 holds the distances for subspace 0, row 1 those for subspace 1;
        // each row lists centroids 0,1,2,3 in order.
        let table: Vec<F32Distance> = vec![
            1.0.into(),
            2.0.into(),
            3.0.into(),
            4.0.into(),
            5.0.into(),
            6.0.into(),
            7.0.into(),
            8.0.into(),
        ];

        let dt = PQDistanceTable::<F32L2Space<4>, 2, 2>::new(table, ksub);

        let cases = [
            // table[0*4 + 0] + table[1*4 + 2] = 1.0 + 7.0
            ([0, 2], 8.0),
            // table[0*4 + 3] + table[1*4 + 1] = 4.0 + 6.0
            ([3, 1], 10.0),
        ];
        for (indices, expected) in cases {
            let code = PQCode::<2, 2>::from_indices(&indices);
            assert_eq!(dt.distance(&code).value(), expected);
        }
    }

    /// Byte-sized codes: 2 subquantizers with 256 centroids each (NBITS=8).
    #[test]
    fn test_distance_table_nbits8() {
        let ksub = 256;
        // Flat position of centroid `k` in subspace `m`.
        let at = |m: usize, k: usize| m * ksub + k;

        // Sparse table: everything is zero apart from two marked entries.
        let mut table: Vec<F32Distance> = vec![0.0.into(); 2 * ksub];
        table[at(0, 100)] = 5.0.into();
        table[at(1, 200)] = 3.0.into();

        let dt = PQDistanceTable::<F32L2Space<4>, 2, 8>::new(table, ksub);

        let cases = [
            // Both marked entries are hit: 5.0 + 3.0
            ([100, 200], 8.0),
            // Only zero entries are hit: 0.0 + 0.0
            ([0, 0], 0.0),
        ];
        for (indices, expected) in cases {
            let code = PQCode::<2, 8>::from_indices(&indices);
            assert_eq!(dt.distance(&code).value(), expected);
        }
    }
}