Skip to main content

bytesandbrains_codec/pq/
sdc.rs

1use super::code::{PQCode, bytes_for_nbits};
2
3/// Precomputed centroid-to-centroid distances for Symmetric Distance Computation (SDC).
4///
5/// SDC enables fast distance computation between two PQ codes without needing
6/// the original vectors. During quantizer training, we precompute squared L2
7/// distances between all pairs of centroids in each subspace.
8///
9/// Table layout: `table[m * ksub * ksub + i * ksub + j]` is the squared distance
10/// between centroid i and centroid j in subspace m.
11///
12/// The const generics must match the ProductQuantizer configuration:
13/// - M: number of subquantizers
14/// - NBITS: bits per centroid index
15#[derive(Clone, Debug, PartialEq)]
16pub struct SDCTable<const M: usize, const NBITS: usize>
17where
18    [(); bytes_for_nbits(NBITS)]:,
19{
20    /// Flat storage: M * ksub * ksub entries
21    table: Vec<f32>,
22    /// Number of centroids per subspace (2^NBITS)
23    ksub: usize,
24}
25
26impl<const M: usize, const NBITS: usize> SDCTable<M, NBITS>
27where
28    [(); bytes_for_nbits(NBITS)]:,
29{
30    /// Number of centroids per subspace (2^NBITS)
31    pub const KSUB: usize = 1 << NBITS;
32
33    /// Create an SDCTable from centroids using squared L2 distance.
34    ///
35    /// # Arguments
36    /// * `centroids` - Flat centroid storage: M * ksub * dsub floats
37    /// * `dsub` - Dimension of each subspace
38    pub fn from_centroids(centroids: &[f32], dsub: usize) -> Self {
39        Self::from_centroids_with_distance(centroids, dsub, |a, b| {
40            a.iter()
41                .zip(b.iter())
42                .map(|(x, y)| {
43                    let diff = x - y;
44                    diff * diff
45                })
46                .sum()
47        })
48    }
49
50    /// Create an SDCTable from centroids using a custom subspace distance function.
51    ///
52    /// The distance function receives two subvector slices of length `dsub`
53    /// and returns a non-negative distance value.
54    ///
55    /// # Arguments
56    /// * `centroids` - Flat centroid storage: M * ksub * dsub floats
57    /// * `dsub` - Dimension of each subspace
58    /// * `subspace_distance` - Distance function for subvector pairs
59    pub fn from_centroids_with_distance(
60        centroids: &[f32],
61        dsub: usize,
62        subspace_distance: impl Fn(&[f32], &[f32]) -> f32,
63    ) -> Self {
64        let ksub = Self::KSUB;
65        assert_eq!(
66            centroids.len(),
67            M * ksub * dsub,
68            "centroids length mismatch"
69        );
70
71        let mut table = vec![0.0f32; M * ksub * ksub];
72
73        for m in 0..M {
74            for i in 0..ksub {
75                let ci_offset = (m * ksub + i) * dsub;
76                let ci = &centroids[ci_offset..ci_offset + dsub];
77                for j in 0..ksub {
78                    let cj_offset = (m * ksub + j) * dsub;
79                    let cj = &centroids[cj_offset..cj_offset + dsub];
80
81                    table[m * ksub * ksub + i * ksub + j] = subspace_distance(ci, cj);
82                }
83            }
84        }
85
86        Self { table, ksub }
87    }
88
89    /// Compute approximate squared L2 distance between two PQ codes.
90    ///
91    /// This is the core of SDC: sum precomputed centroid-to-centroid distances
92    /// for each subspace.
93    #[inline]
94    pub fn distance(&self, code1: &PQCode<M, NBITS>, code2: &PQCode<M, NBITS>) -> f32 {
95        let mut sum = 0.0f32;
96        for m in 0..M {
97            let i = code1.get(m) as usize;
98            let j = code2.get(m) as usize;
99            sum += self.table[m * self.ksub * self.ksub + i * self.ksub + j];
100        }
101        sum
102    }
103
104    /// Number of centroids per subspace.
105    pub fn ksub(&self) -> usize {
106        self.ksub
107    }
108
109    /// Access the raw table data for serialization.
110    pub fn table_data(&self) -> &[f32] {
111        &self.table
112    }
113
114    /// Reconstruct an SDCTable from raw data.
115    pub fn from_raw(table: Vec<f32>, ksub: usize) -> Self {
116        Self { table, ksub }
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: 2 subspaces, 4 centroids each (NBITS=2), 2-dim subspace.
    fn square_centroids() -> Vec<f32> {
        vec![
            // Subspace 0: 4 centroids of dim 2
            0.0, 0.0, // centroid 0
            1.0, 0.0, // centroid 1
            0.0, 1.0, // centroid 2
            1.0, 1.0, // centroid 3
            // Subspace 1: 4 centroids of dim 2
            0.0, 0.0, // centroid 0
            2.0, 0.0, // centroid 1
            0.0, 2.0, // centroid 2
            2.0, 2.0, // centroid 3
        ]
    }

    #[test]
    fn test_sdc_same_codes() {
        let sdc = SDCTable::<2, 2>::from_centroids(&square_centroids(), 2);

        // Same codes should have distance 0
        let code1 = PQCode::<2, 2>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 2>::from_indices(&[0, 0]);
        assert_eq!(sdc.distance(&code1, &code2), 0.0);

        let code3 = PQCode::<2, 2>::from_indices(&[3, 3]);
        let code4 = PQCode::<2, 2>::from_indices(&[3, 3]);
        assert_eq!(sdc.distance(&code3, &code4), 0.0);
    }

    #[test]
    fn test_sdc_different_codes() {
        let sdc = SDCTable::<2, 2>::from_centroids(&square_centroids(), 2);

        // code1=[0,0] -> centroids (0,0) and (0,0)
        // code2=[1,1] -> centroids (1,0) and (2,0)
        // Distance in subspace 0: ||(0,0) - (1,0)||^2 = 1
        // Distance in subspace 1: ||(0,0) - (2,0)||^2 = 4
        // Total: 5
        let code1 = PQCode::<2, 2>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 2>::from_indices(&[1, 1]);
        assert_eq!(sdc.distance(&code1, &code2), 5.0);
    }

    #[test]
    fn test_sdc_symmetry() {
        let centroids: Vec<f32> = vec![
            0.0, 0.0,
            1.0, 1.0,
            2.0, 2.0,
            3.0, 3.0,
            0.0, 0.0,
            1.0, 1.0,
            2.0, 2.0,
            3.0, 3.0,
        ];

        let sdc = SDCTable::<2, 2>::from_centroids(&centroids, 2);

        let code1 = PQCode::<2, 2>::from_indices(&[0, 1]);
        let code2 = PQCode::<2, 2>::from_indices(&[2, 3]);

        // Distance should be symmetric
        assert_eq!(sdc.distance(&code1, &code2), sdc.distance(&code2, &code1));
    }

    #[test]
    fn test_sdc_custom_distance() {
        // L1 (Manhattan) distance instead of the default squared L2.
        let sdc = SDCTable::<2, 2>::from_centroids_with_distance(
            &square_centroids(),
            2,
            |a, b| a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
        );

        // Subspace 0: |(0,0) - (1,1)|_1 = 2
        // Subspace 1: |(0,0) - (2,2)|_1 = 4
        // Total: 6
        let code1 = PQCode::<2, 2>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 2>::from_indices(&[3, 3]);
        assert_eq!(sdc.distance(&code1, &code2), 6.0);
    }

    #[test]
    fn test_sdc_raw_round_trip() {
        let sdc = SDCTable::<2, 2>::from_centroids(&square_centroids(), 2);
        assert_eq!(sdc.ksub(), 4);
        assert_eq!(sdc.table_data().len(), 2 * 4 * 4);

        // Rebuilding from the serialized pieces must give an equal table.
        let rebuilt = SDCTable::<2, 2>::from_raw(sdc.table_data().to_vec(), sdc.ksub());
        assert_eq!(rebuilt, sdc);
    }

    #[test]
    fn test_sdc_nbits8() {
        // nbits=8 means 256 centroids per subspace. Centroid k is (k, 0) in
        // both subspaces, so a wrong table index produces a wrong distance
        // (an all-zero codebook would hide indexing mistakes).
        let ksub = 256usize;
        let mut centroids: Vec<f32> = Vec::with_capacity(2 * ksub * 2); // M=2, ksub=256, dsub=2
        for _subspace in 0..2 {
            for k in 0..ksub {
                centroids.push(k as f32);
                centroids.push(0.0);
            }
        }

        let sdc = SDCTable::<2, 8>::from_centroids(&centroids, 2);
        assert_eq!(sdc.ksub(), 256);

        let code1 = PQCode::<2, 8>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 8>::from_indices(&[255, 255]);

        // Identical codes are at distance 0.
        assert_eq!(sdc.distance(&code1, &code1), 0.0);
        // Each subspace contributes 255^2 = 65025; both values are exactly
        // representable in f32, so an exact comparison is safe.
        assert_eq!(sdc.distance(&code1, &code2), 130050.0);
    }
}