//! Product quantization (PQ) for compressing vectors in the storage engine.
//! reddb_server/storage/engine/pq.rs

use std::collections::HashMap;

use super::distance::{cmp_f32, l2_squared_simd};
use super::hnsw::NodeId;
/// Configuration for a product quantizer.
#[derive(Clone, Debug)]
pub struct PQConfig {
    /// Full dimensionality of the input vectors.
    pub dimension: usize,
    /// Number of subvectors each vector is split into (one code byte per subvector).
    pub n_subvectors: usize,
    /// Centroids per sub-codebook; must be <= 256 since codes are stored as `u8`.
    pub n_centroids: usize,
    /// Maximum k-means iterations when training each codebook.
    pub max_iterations: usize,
}
49
50impl Default for PQConfig {
51 fn default() -> Self {
52 Self {
53 dimension: 128,
54 n_subvectors: 8,
55 n_centroids: 256,
56 max_iterations: 25,
57 }
58 }
59}
60
61impl PQConfig {
62 pub fn new(dimension: usize, n_subvectors: usize) -> Self {
63 assert!(
64 dimension.is_multiple_of(n_subvectors),
65 "dimension must be divisible by n_subvectors"
66 );
67 Self {
68 dimension,
69 n_subvectors,
70 n_centroids: 256,
71 max_iterations: 25,
72 }
73 }
74
75 pub fn subvector_dim(&self) -> usize {
77 self.dimension / self.n_subvectors
78 }
79}
80
/// One k-means codebook for a single subvector position.
#[derive(Clone)]
struct Codebook {
    /// Centroid vectors, each of length `dim`.
    centroids: Vec<Vec<f32>>,
    /// Dimensionality of the subvectors this codebook quantizes.
    dim: usize,
}
89
impl Codebook {
    /// Creates a codebook of `n_centroids` zero-initialized centroids of length `dim`.
    fn new(dim: usize, n_centroids: usize) -> Self {
        Self {
            centroids: vec![vec![0.0; dim]; n_centroids],
            dim,
        }
    }

    /// Trains the codebook on `subvectors` with Lloyd's k-means, stopping after
    /// `max_iterations` or once every centroid moves less than 1e-4.
    /// A no-op when `subvectors` is empty.
    fn train(&mut self, subvectors: &[Vec<f32>], max_iterations: usize) {
        if subvectors.is_empty() {
            return;
        }

        let k = self.centroids.len();

        // Initialization: seed centroids from training points at a regular stride.
        // NOTE(review): when subvectors.len() < k, `step` is 0 and every centroid
        // starts as subvectors[0]; duplicate centroids are never re-seeded below,
        // so training tiny datasets effectively uses a single cluster.
        let step = subvectors.len().max(1) / k.max(1);
        for (i, centroid) in self.centroids.iter_mut().enumerate() {
            let idx = (i * step).min(subvectors.len() - 1);
            *centroid = subvectors[idx].clone();
        }

        for _ in 0..max_iterations {
            // Assignment step: bucket each training point under its nearest centroid.
            let mut assignments: Vec<Vec<usize>> = vec![Vec::new(); k];
            for (i, sv) in subvectors.iter().enumerate() {
                let nearest = self.find_nearest(sv);
                assignments[nearest].push(i);
            }

            let mut converged = true;
            // Update step: move each centroid to the mean of its assigned points.
            for (ci, indices) in assignments.iter().enumerate() {
                if indices.is_empty() {
                    // Empty cluster: keep the previous centroid unchanged.
                    continue;
                }

                let mut new_centroid = vec![0.0f32; self.dim];
                for &idx in indices {
                    for (j, &val) in subvectors[idx].iter().enumerate() {
                        new_centroid[j] += val;
                    }
                }
                for val in &mut new_centroid {
                    *val /= indices.len() as f32;
                }

                // Euclidean distance this centroid moved in the current iteration.
                let shift = l2_squared_simd(&new_centroid, &self.centroids[ci]).sqrt();
                if shift > 1e-4 {
                    converged = false;
                }

                self.centroids[ci] = new_centroid;
            }

            // Early exit once all centroids have (approximately) stopped moving.
            if converged {
                break;
            }
        }
    }

    /// Index of the centroid nearest (squared L2) to `subvector`;
    /// ties break toward the lower index. Returns 0 if there are no centroids.
    fn find_nearest(&self, subvector: &[f32]) -> usize {
        self.centroids
            .iter()
            .enumerate()
            .map(|(i, c)| (i, l2_squared_simd(subvector, c)))
            .min_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Squared-L2 distance from `query_subvector` to every centroid,
    /// indexed by centroid id (used as an ADC lookup table).
    fn compute_distance_table(&self, query_subvector: &[f32]) -> Vec<f32> {
        self.centroids
            .iter()
            .map(|c| l2_squared_simd(query_subvector, c))
            .collect()
    }
}
173
/// A PQ code: one `u8` centroid index per subvector.
pub type PQCode = Vec<u8>;
/// Product quantizer: compresses vectors to one byte per subvector.
pub struct ProductQuantizer {
    /// Quantization parameters.
    config: PQConfig,
    /// One codebook per subvector position (`config.n_subvectors` total).
    codebooks: Vec<Codebook>,
    /// Set once `train` has run on a non-empty dataset.
    trained: bool,
}
185
186impl ProductQuantizer {
187 pub fn new(config: PQConfig) -> Self {
189 let subdim = config.subvector_dim();
190 let codebooks = (0..config.n_subvectors)
191 .map(|_| Codebook::new(subdim, config.n_centroids))
192 .collect();
193
194 Self {
195 config,
196 codebooks,
197 trained: false,
198 }
199 }
200
201 pub fn with_dimension(dimension: usize) -> Self {
203 let n_subvectors = if dimension >= 64 { 8 } else { 4 };
205 Self::new(PQConfig::new(dimension, n_subvectors))
206 }
207
208 pub fn train(&mut self, vectors: &[Vec<f32>]) {
210 if vectors.is_empty() {
211 return;
212 }
213
214 let subdim = self.config.subvector_dim();
215
216 for (m, codebook) in self.codebooks.iter_mut().enumerate() {
218 let subvectors: Vec<Vec<f32>> = vectors
220 .iter()
221 .map(|v| v[m * subdim..(m + 1) * subdim].to_vec())
222 .collect();
223
224 codebook.train(&subvectors, self.config.max_iterations);
225 }
226
227 self.trained = true;
228 }
229
230 pub fn encode(&self, vector: &[f32]) -> PQCode {
232 let subdim = self.config.subvector_dim();
233
234 self.codebooks
235 .iter()
236 .enumerate()
237 .map(|(m, codebook)| {
238 let subvector = &vector[m * subdim..(m + 1) * subdim];
239 codebook.find_nearest(subvector) as u8
240 })
241 .collect()
242 }
243
244 pub fn encode_batch(&self, vectors: &[Vec<f32>]) -> Vec<PQCode> {
246 vectors.iter().map(|v| self.encode(v)).collect()
247 }
248
249 pub fn decode(&self, code: &PQCode) -> Vec<f32> {
251 let subdim = self.config.subvector_dim();
252 let mut vector = Vec::with_capacity(self.config.dimension);
253
254 for (m, &c) in code.iter().enumerate() {
255 let centroid = &self.codebooks[m].centroids[c as usize];
256 vector.extend_from_slice(centroid);
257 }
258
259 vector
260 }
261
262 pub fn compute_distances(&self, query: &[f32], codes: &[PQCode]) -> Vec<f32> {
264 let subdim = self.config.subvector_dim();
266 let tables: Vec<Vec<f32>> = self
267 .codebooks
268 .iter()
269 .enumerate()
270 .map(|(m, codebook)| {
271 let subquery = &query[m * subdim..(m + 1) * subdim];
272 codebook.compute_distance_table(subquery)
273 })
274 .collect();
275
276 codes
278 .iter()
279 .map(|code| {
280 code.iter()
281 .enumerate()
282 .map(|(m, &c)| tables[m][c as usize])
283 .sum::<f32>()
284 .sqrt()
285 })
286 .collect()
287 }
288
289 pub fn compression_ratio(&self) -> f32 {
291 let original_bytes = self.config.dimension * 4; let compressed_bytes = self.config.n_subvectors; original_bytes as f32 / compressed_bytes as f32
294 }
295
296 pub fn config(&self) -> &PQConfig {
298 &self.config
299 }
300
301 pub fn is_trained(&self) -> bool {
303 self.trained
304 }
305}
306
/// Flat (exhaustive-scan) index over PQ-compressed vectors.
pub struct PQIndex {
    /// The quantizer used to encode and score vectors.
    pq: ProductQuantizer,
    /// PQ code for each stored vector; parallel to `ids`.
    codes: Vec<PQCode>,
    /// External id of each stored vector; parallel to `codes`.
    ids: Vec<NodeId>,
    /// Reverse lookup from external id to slot index in `codes`/`ids`.
    id_to_idx: HashMap<NodeId, usize>,
    /// Uncompressed vectors, retained only when built via `with_originals`
    /// (needed by `search_rerank`).
    originals: Option<Vec<Vec<f32>>>,
    /// Next id handed out by `add`.
    next_id: NodeId,
}
322
323impl PQIndex {
324 pub fn new(config: PQConfig) -> Self {
326 Self {
327 pq: ProductQuantizer::new(config),
328 codes: Vec::new(),
329 ids: Vec::new(),
330 id_to_idx: HashMap::new(),
331 originals: None,
332 next_id: 0,
333 }
334 }
335
336 pub fn with_originals(mut self) -> Self {
338 self.originals = Some(Vec::new());
339 self
340 }
341
342 pub fn train(&mut self, vectors: &[Vec<f32>]) {
344 self.pq.train(vectors);
345 }
346
347 pub fn add(&mut self, vector: Vec<f32>) -> NodeId {
349 let id = self.next_id;
350 self.next_id += 1;
351 self.add_with_id(id, vector);
352 id
353 }
354
355 pub fn add_with_id(&mut self, id: NodeId, vector: Vec<f32>) {
357 let code = self.pq.encode(&vector);
358 let idx = self.codes.len();
359
360 self.codes.push(code);
361 self.ids.push(id);
362 self.id_to_idx.insert(id, idx);
363
364 if let Some(ref mut originals) = self.originals {
365 originals.push(vector);
366 }
367 }
368
369 pub fn add_batch(&mut self, vectors: Vec<Vec<f32>>) -> Vec<NodeId> {
371 vectors.into_iter().map(|v| self.add(v)).collect()
372 }
373
374 pub fn search(&self, query: &[f32], k: usize) -> Vec<(NodeId, f32)> {
376 if self.codes.is_empty() {
377 return Vec::new();
378 }
379
380 let distances = self.pq.compute_distances(query, &self.codes);
381
382 let mut results: Vec<(usize, f32)> = distances.into_iter().enumerate().collect();
383
384 results.sort_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)));
385 results.truncate(k);
386
387 results
388 .into_iter()
389 .map(|(idx, dist)| (self.ids[idx], dist))
390 .collect()
391 }
392
393 pub fn search_rerank(&self, query: &[f32], k: usize, rerank_k: usize) -> Vec<(NodeId, f32)> {
395 let originals = match &self.originals {
396 Some(o) => o,
397 None => return self.search(query, k),
398 };
399
400 let candidates = self.search(query, rerank_k);
402
403 let mut reranked: Vec<(NodeId, f32)> = candidates
405 .into_iter()
406 .map(|(id, _)| {
407 let idx = self.id_to_idx[&id];
408 let dist = l2_squared_simd(query, &originals[idx]).sqrt();
409 (id, dist)
410 })
411 .collect();
412
413 reranked.sort_by(|(li, la), (ri, lb)| cmp_f32(*la, *lb).then_with(|| li.cmp(ri)));
414 reranked.truncate(k);
415 reranked
416 }
417
418 pub fn len(&self) -> usize {
420 self.codes.len()
421 }
422
423 pub fn is_empty(&self) -> bool {
425 self.codes.is_empty()
426 }
427
428 pub fn compression_ratio(&self) -> f32 {
430 self.pq.compression_ratio()
431 }
432
433 pub fn memory_usage(&self) -> usize {
435 let code_bytes = self.codes.len() * self.pq.config.n_subvectors;
436 let original_bytes = self
437 .originals
438 .as_ref()
439 .map(|o| o.len() * self.pq.config.dimension * 4)
440 .unwrap_or(0);
441 code_bytes + original_bytes
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in [0, 1),
    /// derived from a simple linear-congruential mix of `seed` and index.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        (0..dim)
            .map(|i| ((seed * 1103515245 + i as u64 * 12345) % 1000) as f32 / 1000.0)
            .collect()
    }

    #[test]
    fn test_pq_encode_decode() {
        let config = PQConfig::new(16, 4);
        let mut pq = ProductQuantizer::new(config);

        let training: Vec<Vec<f32>> = (0..100).map(|i| random_vector(16, i)).collect();

        pq.train(&training);
        assert!(pq.is_trained());

        let original = random_vector(16, 999);
        let code = pq.encode(&original);
        let decoded = pq.decode(&code);

        // 4 subvectors -> 4-byte code; decode restores the full 16 dims.
        assert_eq!(code.len(), 4);
        assert_eq!(decoded.len(), 16);

        // Quantization is lossy; only require a loose reconstruction bound.
        let reconstruction_error: f32 = original
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        assert!(reconstruction_error < 1.0);
    }

    #[test]
    fn test_pq_compression_ratio() {
        // 128 dims * 4 bytes = 512 original bytes vs 8 code bytes -> 64x.
        let pq = ProductQuantizer::new(PQConfig::new(128, 8));
        assert_eq!(pq.compression_ratio(), 64.0);
    }

    #[test]
    fn test_pq_index_search() {
        let mut index = PQIndex::new(PQConfig::new(8, 4));

        let training: Vec<Vec<f32>> = (0..50).map(|i| random_vector(8, i)).collect();

        index.train(&training);

        for (i, v) in training.iter().enumerate() {
            index.add_with_id(i as u64, v.clone());
        }

        // Querying with an indexed vector (seed 0) should return its id first.
        let query = random_vector(8, 0);
        let results = index.search(&query, 5);

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_pq_distance_tables() {
        let config = PQConfig::new(8, 2);
        let mut pq = ProductQuantizer::new(config);

        let training: Vec<Vec<f32>> = vec![
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        ];

        pq.train(&training);

        // The query sits midway between the two training points, so the
        // approximate distances to both codes should be (nearly) equal.
        let query = vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
        let codes = pq.encode_batch(&training);
        let distances = pq.compute_distances(&query, &codes);

        assert_eq!(distances.len(), 2);
        assert!((distances[0] - distances[1]).abs() < 0.1);
    }
}