bytesandbrains_codec/pq/sdc.rs
1use super::code::{PQCode, bytes_for_nbits};
2
3/// Precomputed centroid-to-centroid distances for Symmetric Distance Computation (SDC).
4///
5/// SDC enables fast distance computation between two PQ codes without needing
6/// the original vectors. During quantizer training, we precompute squared L2
7/// distances between all pairs of centroids in each subspace.
8///
9/// Table layout: `table[m * ksub * ksub + i * ksub + j]` is the squared distance
10/// between centroid i and centroid j in subspace m.
11///
12/// The const generics must match the ProductQuantizer configuration:
13/// - M: number of subquantizers
14/// - NBITS: bits per centroid index
15#[derive(Clone, Debug, PartialEq)]
16pub struct SDCTable<const M: usize, const NBITS: usize>
17where
18 [(); bytes_for_nbits(NBITS)]:,
19{
20 /// Flat storage: M * ksub * ksub entries
21 table: Vec<f32>,
22 /// Number of centroids per subspace (2^NBITS)
23 ksub: usize,
24}
25
26impl<const M: usize, const NBITS: usize> SDCTable<M, NBITS>
27where
28 [(); bytes_for_nbits(NBITS)]:,
29{
30 /// Number of centroids per subspace (2^NBITS)
31 pub const KSUB: usize = 1 << NBITS;
32
33 /// Create an SDCTable from centroids using squared L2 distance.
34 ///
35 /// # Arguments
36 /// * `centroids` - Flat centroid storage: M * ksub * dsub floats
37 /// * `dsub` - Dimension of each subspace
38 pub fn from_centroids(centroids: &[f32], dsub: usize) -> Self {
39 Self::from_centroids_with_distance(centroids, dsub, |a, b| {
40 a.iter()
41 .zip(b.iter())
42 .map(|(x, y)| {
43 let diff = x - y;
44 diff * diff
45 })
46 .sum()
47 })
48 }
49
50 /// Create an SDCTable from centroids using a custom subspace distance function.
51 ///
52 /// The distance function receives two subvector slices of length `dsub`
53 /// and returns a non-negative distance value.
54 ///
55 /// # Arguments
56 /// * `centroids` - Flat centroid storage: M * ksub * dsub floats
57 /// * `dsub` - Dimension of each subspace
58 /// * `subspace_distance` - Distance function for subvector pairs
59 pub fn from_centroids_with_distance(
60 centroids: &[f32],
61 dsub: usize,
62 subspace_distance: impl Fn(&[f32], &[f32]) -> f32,
63 ) -> Self {
64 let ksub = Self::KSUB;
65 assert_eq!(
66 centroids.len(),
67 M * ksub * dsub,
68 "centroids length mismatch"
69 );
70
71 let mut table = vec![0.0f32; M * ksub * ksub];
72
73 for m in 0..M {
74 for i in 0..ksub {
75 let ci_offset = (m * ksub + i) * dsub;
76 let ci = ¢roids[ci_offset..ci_offset + dsub];
77 for j in 0..ksub {
78 let cj_offset = (m * ksub + j) * dsub;
79 let cj = ¢roids[cj_offset..cj_offset + dsub];
80
81 table[m * ksub * ksub + i * ksub + j] = subspace_distance(ci, cj);
82 }
83 }
84 }
85
86 Self { table, ksub }
87 }
88
89 /// Compute approximate squared L2 distance between two PQ codes.
90 ///
91 /// This is the core of SDC: sum precomputed centroid-to-centroid distances
92 /// for each subspace.
93 #[inline]
94 pub fn distance(&self, code1: &PQCode<M, NBITS>, code2: &PQCode<M, NBITS>) -> f32 {
95 let mut sum = 0.0f32;
96 for m in 0..M {
97 let i = code1.get(m) as usize;
98 let j = code2.get(m) as usize;
99 sum += self.table[m * self.ksub * self.ksub + i * self.ksub + j];
100 }
101 sum
102 }
103
104 /// Number of centroids per subspace.
105 pub fn ksub(&self) -> usize {
106 self.ksub
107 }
108
109 /// Access the raw table data for serialization.
110 pub fn table_data(&self) -> &[f32] {
111 &self.table
112 }
113
114 /// Reconstruct an SDCTable from raw data.
115 pub fn from_raw(table: Vec<f32>, ksub: usize) -> Self {
116 Self { table, ksub }
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Two 2-dim subspaces with 4 centroids each (NBITS = 2).
    /// Subspace 1 is subspace 0 scaled by 2.
    fn grid_centroids() -> Vec<f32> {
        vec![
            // Subspace 0: 4 centroids of dim 2
            0.0, 0.0, // centroid 0
            1.0, 0.0, // centroid 1
            0.0, 1.0, // centroid 2
            1.0, 1.0, // centroid 3
            // Subspace 1: 4 centroids of dim 2
            0.0, 0.0, // centroid 0
            2.0, 0.0, // centroid 1
            0.0, 2.0, // centroid 2
            2.0, 2.0, // centroid 3
        ]
    }

    #[test]
    fn test_sdc_same_codes() {
        let sdc = SDCTable::<2, 2>::from_centroids(&grid_centroids(), 2);

        // Same codes should have distance 0
        let code1 = PQCode::<2, 2>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 2>::from_indices(&[0, 0]);
        assert_eq!(sdc.distance(&code1, &code2), 0.0);

        let code3 = PQCode::<2, 2>::from_indices(&[3, 3]);
        let code4 = PQCode::<2, 2>::from_indices(&[3, 3]);
        assert_eq!(sdc.distance(&code3, &code4), 0.0);
    }

    #[test]
    fn test_sdc_different_codes() {
        let sdc = SDCTable::<2, 2>::from_centroids(&grid_centroids(), 2);

        // code1=[0,0] -> centroids (0,0) and (0,0)
        // code2=[1,1] -> centroids (1,0) and (2,0)
        // Distance in subspace 0: ||(0,0) - (1,0)||^2 = 1
        // Distance in subspace 1: ||(0,0) - (2,0)||^2 = 4
        // Total: 5
        let code1 = PQCode::<2, 2>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 2>::from_indices(&[1, 1]);
        assert_eq!(sdc.distance(&code1, &code2), 5.0);
    }

    #[test]
    fn test_sdc_symmetry() {
        let centroids: Vec<f32> = vec![
            0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, // subspace 0
            0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, // subspace 1
        ];

        let sdc = SDCTable::<2, 2>::from_centroids(&centroids, 2);

        let code1 = PQCode::<2, 2>::from_indices(&[0, 1]);
        let code2 = PQCode::<2, 2>::from_indices(&[2, 3]);

        // Subspace 0: ||(0,0) - (2,2)||^2 = 8; subspace 1: ||(1,1) - (3,3)||^2 = 8.
        assert_eq!(sdc.distance(&code1, &code2), 16.0);
        // Distance should be symmetric
        assert_eq!(sdc.distance(&code1, &code2), sdc.distance(&code2, &code1));
    }

    #[test]
    fn test_sdc_nbits8() {
        // NBITS=8 means 256 centroids per subspace; M=2, dsub=2.
        // Every centroid is zero, so every pair is at distance 0. This mainly
        // checks that the extreme indices 0 and 255 are handled.
        let centroids: Vec<f32> = vec![0.0; 2 * 256 * 2];

        let sdc = SDCTable::<2, 8>::from_centroids(&centroids, 2);
        assert_eq!(sdc.ksub(), 256);

        let code1 = PQCode::<2, 8>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 8>::from_indices(&[255, 255]);

        assert_eq!(sdc.distance(&code1, &code2), 0.0);
    }

    #[test]
    fn test_sdc_custom_distance() {
        // L1 (Manhattan) distance instead of squared L2.
        let sdc = SDCTable::<2, 2>::from_centroids_with_distance(&grid_centroids(), 2, |a, b| {
            a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
        });

        // Subspace 0: |(0,0) - (1,1)|_1 = 2; subspace 1: |(0,0) - (2,2)|_1 = 4.
        let code1 = PQCode::<2, 2>::from_indices(&[0, 0]);
        let code2 = PQCode::<2, 2>::from_indices(&[3, 3]);
        assert_eq!(sdc.distance(&code1, &code2), 6.0);
    }

    #[test]
    fn test_sdc_table_layout_and_raw_round_trip() {
        let sdc = SDCTable::<2, 2>::from_centroids(&grid_centroids(), 2);

        assert_eq!(sdc.ksub(), 4);
        assert_eq!(sdc.table_data().len(), 2 * 4 * 4);
        // m=1, i=0, j=1 -> index 1*16 + 0*4 + 1 = 17: ||(0,0) - (2,0)||^2 = 4
        assert_eq!(sdc.table_data()[17], 4.0);

        let restored = SDCTable::<2, 2>::from_raw(sdc.table_data().to_vec(), sdc.ksub());
        assert_eq!(restored, sdc);
    }

    #[test]
    #[should_panic(expected = "centroids length mismatch")]
    fn test_sdc_rejects_wrong_centroid_count() {
        // One subspace's worth of centroids where two are required.
        let centroids: Vec<f32> = vec![0.0; 4 * 2];
        let _ = SDCTable::<2, 2>::from_centroids(&centroids, 2);
    }
}