/// Shape parameters for a product quantizer.
///
/// A vector of `dimension` components is split into `num_subspaces`
/// contiguous pieces of equal length, and each piece is represented by one of
/// `num_centroids` centroids.
///
/// Build values through [`PqConfig::new`]; the fields are public, but only the
/// constructor checks the invariants that [`PqConfig::subspace_dim`] relies on.
#[derive(Debug, Clone)]
pub struct PqConfig {
    /// Number of equally sized pieces each vector is split into.
    pub num_subspaces: usize,
    /// Number of centroids available in every sub-space.
    pub num_centroids: usize,
    /// Length of the full, unsplit vector.
    pub dimension: usize,
}

impl PqConfig {
    /// Validates the parameters and builds a configuration.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when any parameter is zero or when
    /// `dimension` is not a multiple of `num_subspaces`. The parameters are
    /// checked in the order `num_subspaces`, `num_centroids`, `dimension`,
    /// divisibility; the first failing check determines the message.
    pub fn new(
        dimension: usize,
        num_subspaces: usize,
        num_centroids: usize,
    ) -> Result<Self, String> {
        // Table order is significant: it fixes which message wins when
        // several parameters are zero at once.
        let zero_checks = [
            (num_subspaces, "num_subspaces must be > 0"),
            (num_centroids, "num_centroids must be > 0"),
            (dimension, "dimension must be > 0"),
        ];
        if let Some(&(_, message)) = zero_checks.iter().find(|&&(value, _)| value == 0) {
            return Err(message.to_owned());
        }

        // `num_subspaces` is known to be non-zero here, so `%` cannot panic.
        match dimension % num_subspaces {
            0 => Ok(Self {
                num_subspaces,
                num_centroids,
                dimension,
            }),
            _ => Err(format!(
                "dimension ({}) must be divisible by num_subspaces ({})",
                dimension, num_subspaces
            )),
        }
    }

    /// Number of components in each sub-space (`dimension / num_subspaces`).
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `num_subspaces` is zero, which can only
    /// happen when the struct was built without going through [`PqConfig::new`].
    pub fn subspace_dim(&self) -> usize {
        self.dimension / self.num_subspaces
    }
}
62
63pub struct PqEncoder {
72 config: PqConfig,
74 codebooks: Vec<Vec<Vec<f32>>>,
76}
77
78impl PqEncoder {
79 pub fn new_random(config: PqConfig) -> Self {
82 let sub_dim = config.subspace_dim();
83 let mut seed: u64 = 0xdeadbeef_cafebabe;
84 let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::with_capacity(config.num_subspaces);
85
86 for _ in 0..config.num_subspaces {
87 let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(config.num_centroids);
88 for _ in 0..config.num_centroids {
89 let centroid: Vec<f32> = (0..sub_dim)
90 .map(|_| {
91 seed = seed
93 .wrapping_mul(6_364_136_223_846_793_005)
94 .wrapping_add(1_442_695_040_888_963_407);
95 let bits = (seed >> 11) as f32;
97 bits / (1u64 << 53) as f32 * 2.0 - 1.0
98 })
99 .collect();
100 centroids.push(centroid);
101 }
102 codebooks.push(centroids);
103 }
104
105 Self { config, codebooks }
106 }
107
108 pub fn encode(&self, vector: &[f32]) -> Result<Vec<usize>, String> {
112 if vector.len() != self.config.dimension {
113 return Err(format!(
114 "Vector length {} does not match configured dimension {}",
115 vector.len(),
116 self.config.dimension
117 ));
118 }
119 let sub_dim = self.config.subspace_dim();
120 let mut codes = Vec::with_capacity(self.config.num_subspaces);
121
122 for m in 0..self.config.num_subspaces {
123 let sub_vec = &vector[m * sub_dim..(m + 1) * sub_dim];
124 let best = self.nearest_centroid(m, sub_vec);
125 codes.push(best);
126 }
127 Ok(codes)
128 }
129
130 pub fn decode(&self, codes: &[usize]) -> Result<Vec<f32>, String> {
135 if codes.len() != self.config.num_subspaces {
136 return Err(format!(
137 "Expected {} codes, got {}",
138 self.config.num_subspaces,
139 codes.len()
140 ));
141 }
142 let sub_dim = self.config.subspace_dim();
143 let mut result = vec![0.0f32; self.config.dimension];
144
145 for (m, &code) in codes.iter().enumerate() {
146 if code >= self.config.num_centroids {
147 return Err(format!(
148 "Code {} in sub-space {} exceeds num_centroids {}",
149 code, m, self.config.num_centroids
150 ));
151 }
152 let centroid = &self.codebooks[m][code];
153 let offset = m * sub_dim;
154 result[offset..offset + sub_dim].copy_from_slice(centroid);
155 }
156 Ok(result)
157 }
158
159 pub fn asymmetric_distance(&self, query: &[f32], codes: &[usize]) -> Result<f32, String> {
166 if query.len() != self.config.dimension {
167 return Err(format!(
168 "Query length {} does not match configured dimension {}",
169 query.len(),
170 self.config.dimension
171 ));
172 }
173 if codes.len() != self.config.num_subspaces {
174 return Err(format!(
175 "Expected {} codes, got {}",
176 self.config.num_subspaces,
177 codes.len()
178 ));
179 }
180 let sub_dim = self.config.subspace_dim();
181 let mut total_dist = 0.0f32;
182
183 for (m, &code) in codes.iter().enumerate() {
184 if code >= self.config.num_centroids {
185 return Err(format!(
186 "Code {} in sub-space {} exceeds num_centroids {}",
187 code, m, self.config.num_centroids
188 ));
189 }
190 let centroid = &self.codebooks[m][code];
191 let sub_query = &query[m * sub_dim..(m + 1) * sub_dim];
192 let sq_dist: f32 = sub_query
193 .iter()
194 .zip(centroid.iter())
195 .map(|(q, c)| (q - c) * (q - c))
196 .sum();
197 total_dist += sq_dist;
198 }
199 Ok(total_dist)
200 }
201
202 pub fn config(&self) -> &PqConfig {
204 &self.config
205 }
206
207 fn nearest_centroid(&self, m: usize, sub_vec: &[f32]) -> usize {
211 let centroids = &self.codebooks[m];
212 let mut best_idx = 0usize;
213 let mut best_dist = f32::MAX;
214
215 for (k, centroid) in centroids.iter().enumerate() {
216 let dist: f32 = sub_vec
217 .iter()
218 .zip(centroid.iter())
219 .map(|(a, b)| (a - b) * (a - b))
220 .sum();
221 if dist < best_dist {
222 best_dist = dist;
223 best_idx = k;
224 }
225 }
226 best_idx
227 }
228}
229
#[cfg(test)]
mod tests {
    //! Unit tests for `PqConfig` validation and for `PqEncoder` encoding,
    //! decoding and asymmetric distance computation.

    use super::*;

    /// Builds a randomly initialised encoder from parameters that must be valid.
    fn make_encoder(dim: usize, m: usize, k: usize) -> PqEncoder {
        let cfg = PqConfig::new(dim, m, k).expect("valid config");
        PqEncoder::new_random(cfg)
    }

    // ---- PqConfig ----------------------------------------------------------

    #[test]
    fn test_config_valid() {
        let cfg = PqConfig::new(64, 4, 256).expect("ok");
        assert_eq!(cfg.dimension, 64);
        assert_eq!(cfg.num_subspaces, 4);
        assert_eq!(cfg.num_centroids, 256);
    }

    #[test]
    fn test_config_subspace_dim() {
        let cfg = PqConfig::new(64, 4, 256).expect("ok");
        assert_eq!(cfg.subspace_dim(), 16);
    }

    #[test]
    fn test_config_subspace_dim_small() {
        let cfg = PqConfig::new(8, 2, 4).expect("ok");
        assert_eq!(cfg.subspace_dim(), 4);
    }

    #[test]
    fn test_config_invalid_not_divisible() {
        let result = PqConfig::new(7, 4, 256);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_invalid_zero_subspaces() {
        let result = PqConfig::new(64, 0, 256);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_invalid_zero_centroids() {
        let result = PqConfig::new(64, 4, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_invalid_zero_dimension() {
        let result = PqConfig::new(0, 4, 256);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_single_subspace() {
        let cfg = PqConfig::new(16, 1, 8).expect("ok");
        assert_eq!(cfg.subspace_dim(), 16);
    }

    // ---- encode ------------------------------------------------------------

    #[test]
    fn test_encode_returns_m_codes() {
        let enc = make_encoder(16, 4, 8);
        let vec: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        assert_eq!(codes.len(), 4);
    }

    #[test]
    fn test_encode_codes_in_range() {
        let enc = make_encoder(16, 4, 8);
        let vec: Vec<f32> = (0..16).map(|i| i as f32 * 0.5).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        for code in codes {
            assert!(code < 8, "code {} should be < 8", code);
        }
    }

    #[test]
    fn test_encode_wrong_dimension_error() {
        let enc = make_encoder(16, 4, 8);
        let result = enc.encode(&[1.0, 2.0, 3.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_encode_zero_vector() {
        let enc = make_encoder(8, 2, 4);
        let vec = vec![0.0f32; 8];
        let codes = enc.encode(&vec).expect("encode ok");
        assert_eq!(codes.len(), 2);
    }

    #[test]
    fn test_encode_deterministic() {
        let enc = make_encoder(16, 4, 8);
        let vec: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let codes1 = enc.encode(&vec).expect("ok");
        let codes2 = enc.encode(&vec).expect("ok");
        assert_eq!(codes1, codes2);
    }

    // ---- decode ------------------------------------------------------------

    #[test]
    fn test_decode_returns_full_dimension() {
        let enc = make_encoder(16, 4, 8);
        let codes = vec![0usize; 4];
        let decoded = enc.decode(&codes).expect("decode ok");
        assert_eq!(decoded.len(), 16);
    }

    #[test]
    fn test_decode_wrong_code_count_error() {
        let enc = make_encoder(16, 4, 8);
        let codes = vec![0usize; 3];
        assert!(enc.decode(&codes).is_err());
    }

    #[test]
    fn test_decode_out_of_range_code_error() {
        let enc = make_encoder(16, 4, 8);
        let codes = vec![0, 0, 0, 100];
        assert!(enc.decode(&codes).is_err());
    }

    #[test]
    fn test_encode_decode_roundtrip_shape() {
        let enc = make_encoder(32, 4, 16);
        let vec: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        let decoded = enc.decode(&codes).expect("decode ok");
        assert_eq!(decoded.len(), 32);
        assert_eq!(codes.len(), 4);
    }

    // ---- asymmetric_distance -----------------------------------------------

    #[test]
    fn test_asymmetric_distance_non_negative() {
        let enc = make_encoder(16, 4, 8);
        let vec: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let codes = enc.encode(&vec).expect("encode ok");
        let dist = enc.asymmetric_distance(&vec, &codes).expect("dist ok");
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_asymmetric_distance_zero_for_centroid_query() {
        let enc = make_encoder(8, 2, 4);
        // Use the centroids selected by `codes` as the query itself: every
        // component difference is then exactly zero, so the distance must be
        // exactly 0.0 rather than merely non-negative.
        let codes = vec![1usize, 3];
        let query = enc.decode(&codes).expect("decode ok");
        let dist = enc.asymmetric_distance(&query, &codes).expect("dist ok");
        assert_eq!(dist, 0.0);
    }

    #[test]
    fn test_asymmetric_distance_wrong_query_dim() {
        let enc = make_encoder(16, 4, 8);
        let codes = vec![0usize; 4];
        let result = enc.asymmetric_distance(&[1.0, 2.0], &codes);
        assert!(result.is_err());
    }

    #[test]
    fn test_asymmetric_distance_wrong_code_count() {
        let enc = make_encoder(16, 4, 8);
        let vec = vec![0.0f32; 16];
        let result = enc.asymmetric_distance(&vec, &[0, 0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_asymmetric_distance_out_of_range_code_error() {
        let enc = make_encoder(16, 4, 8);
        let vec = vec![0.0f32; 16];
        // 8 is the first invalid code when there are 8 centroids per sub-space.
        let result = enc.asymmetric_distance(&vec, &[0, 0, 0, 8]);
        assert!(result.is_err());
    }

    // ---- accessors and construction ----------------------------------------

    #[test]
    fn test_config_accessor() {
        let enc = make_encoder(32, 8, 16);
        let cfg = enc.config();
        assert_eq!(cfg.dimension, 32);
        assert_eq!(cfg.num_subspaces, 8);
        assert_eq!(cfg.num_centroids, 16);
        assert_eq!(cfg.subspace_dim(), 4);
    }

    #[test]
    fn test_new_random_reproducible() {
        let cfg1 = PqConfig::new(16, 4, 8).expect("ok");
        let cfg2 = PqConfig::new(16, 4, 8).expect("ok");
        let enc1 = PqEncoder::new_random(cfg1);
        let enc2 = PqEncoder::new_random(cfg2);
        let vec: Vec<f32> = (0..16).map(|i| i as f32).collect();
        assert_eq!(
            enc1.encode(&vec).expect("ok"),
            enc2.encode(&vec).expect("ok")
        );
    }
}