// nodedb_codec/vector_quant/opq.rs
use nalgebra::{DMatrix, SVD};

41use crate::vector_quant::codec::{AdcLut, VectorCodec};
42use crate::vector_quant::layout::{QuantHeader, QuantMode, UnifiedQuantizedVector};
43use crate::vector_quant::opq_kmeans::l2_sq;
44use crate::vector_quant::opq_kmeans::lloyd;
45
/// Optimized Product Quantization (OPQ) codec: a learned orthogonal
/// rotation followed by per-subspace product quantization.
pub struct OpqCodec {
    /// Full vector dimensionality.
    pub dim: usize,
    /// Number of subspaces (one sub-codebook, one code byte per subspace).
    pub m: usize,
    /// Number of centroids per sub-codebook.
    pub k: usize,
    /// `dim / m` — length of each subspace slice.
    pub sub_dim: usize,
    // dim×dim rotation matrix stored row-major (see `matvec`).
    rotation: Vec<f32>,
    // codebooks[s][c] is centroid c (length sub_dim) of subspace s.
    codebooks: Vec<Vec<Vec<f32>>>,
}
64
65impl OpqCodec {
66 pub fn train(
75 vectors: &[&[f32]],
76 dim: usize,
77 m: usize,
78 k: usize,
79 opq_iters: usize,
80 kmeans_iters: usize,
81 ) -> Self {
82 assert!(!vectors.is_empty(), "training set must be non-empty");
83 assert!(dim > 0 && m > 0 && k > 0, "dim/m/k must be positive");
84 assert!(
85 dim.is_multiple_of(m),
86 "dim ({dim}) must be divisible by m ({m})"
87 );
88 let sub_dim = dim / m;
89 let seed = dim as u64 ^ ((m as u64) << 16) ^ ((k as u64) << 32);
90
91 let mut rotation = identity(dim);
92 let mut codebooks: Vec<Vec<Vec<f32>>> = Vec::new();
93
94 let iters = opq_iters.max(1);
95
96 for iter in 0..iters {
97 let rotated: Vec<Vec<f32>> =
99 vectors.iter().map(|v| matvec(&rotation, v, dim)).collect();
100 codebooks = train_codebooks(&rotated, m, k, sub_dim, kmeans_iters, seed ^ iter as u64);
101
102 if iter + 1 < iters {
113 let n = vectors.len();
114 let x_mat = DMatrix::from_fn(dim, n, |row, col| vectors[col][row]);
117 let y_mat = {
118 let recon: Vec<Vec<f32>> = rotated
119 .iter()
120 .map(|rv| {
121 let codes = pq_encode(rv, &codebooks, m, sub_dim);
122 dequantize_codes(&codes, &codebooks)
123 })
124 .collect();
125 DMatrix::from_fn(dim, n, |row, col| recon[col][row])
126 };
127
128 let m_mat = &x_mat * y_mat.transpose();
130
131 let has_nan = m_mat.iter().any(|x| x.is_nan());
134 if !has_nan {
135 let svd = SVD::new(m_mat, true, true);
136 if let (Some(u), Some(v_t)) = (svd.u, svd.v_t) {
137 let r_new = v_t.transpose() * u.transpose();
139 let mut buf = Vec::with_capacity(dim * dim);
141 for i in 0..dim {
142 for j in 0..dim {
143 buf.push(r_new[(i, j)]);
144 }
145 }
146 rotation = buf;
147 }
148 }
149 }
150 }
151
152 Self {
153 dim,
154 m,
155 k,
156 sub_dim,
157 rotation,
158 codebooks,
159 }
160 }
161
162 pub fn apply_rotation(&self, v: &[f32]) -> Vec<f32> {
164 matvec(&self.rotation, v, self.dim)
165 }
166
167 fn encode_inner(&self, v: &[f32]) -> (Vec<u8>, UnifiedQuantizedVector) {
168 let rotated = self.apply_rotation(v);
169 let codes = pq_encode(&rotated, &self.codebooks, self.m, self.sub_dim);
170 let uqv = make_uqv(&codes, self.dim as u16);
171 (codes, uqv)
172 }
173
174 fn dequantize(&self, codes: &[u8]) -> Vec<f32> {
175 dequantize_codes(codes, &self.codebooks)
176 }
177}
178
/// Builds a `dim`×`dim` identity matrix in row-major order.
fn identity(dim: usize) -> Vec<f32> {
    // Ones sit on the diagonal, i.e. at flat indices i*(dim+1) for i < dim.
    (0..dim * dim)
        .map(|idx| if idx % (dim + 1) == 0 { 1.0 } else { 0.0 })
        .collect()
}
189
/// Reconstructs a vector by concatenating, for each subspace `s`, the
/// centroid selected by `codes[s]` from `codebooks[s]`.
fn dequantize_codes(codes: &[u8], codebooks: &[Vec<Vec<f32>>]) -> Vec<f32> {
    // Capacity assumes uniform sub-dimension across codebooks.
    let sub_dim = codebooks[0][0].len();
    let mut reconstructed = Vec::with_capacity(codebooks.len() * sub_dim);
    for (subspace, &code) in codes.iter().enumerate() {
        reconstructed.extend_from_slice(&codebooks[subspace][code as usize]);
    }
    reconstructed
}
198
/// Row-major matrix–vector product: `out[i] = r[i,·] · v` for a `dim`×`dim`
/// matrix `r` stored as a flat slice.
#[inline]
fn matvec(r: &[f32], v: &[f32], dim: usize) -> Vec<f32> {
    (0..dim)
        .map(|row| {
            r[row * dim..(row + 1) * dim]
                .iter()
                .zip(v)
                .map(|(a, b)| a * b)
                .sum()
        })
        .collect()
}
209
210fn pq_encode(v: &[f32], codebooks: &[Vec<Vec<f32>>], m: usize, sub_dim: usize) -> Vec<u8> {
211 let mut codes = Vec::with_capacity(m);
212 #[allow(clippy::needless_range_loop)]
213 for s in 0..m {
214 let offset = s * sub_dim;
215 let sub = &v[offset..offset + sub_dim];
216 let best = codebooks[s]
217 .iter()
218 .enumerate()
219 .min_by(|(_, a), (_, b)| {
220 l2_sq(sub, a)
221 .partial_cmp(&l2_sq(sub, b))
222 .unwrap_or(std::cmp::Ordering::Equal)
223 })
224 .map(|(i, _)| i)
225 .unwrap_or(0);
226 codes.push(best as u8);
227 }
228 codes
229}
230
231fn train_codebooks(
232 rotated: &[Vec<f32>],
233 m: usize,
234 k: usize,
235 sub_dim: usize,
236 kmeans_iters: usize,
237 seed: u64,
238) -> Vec<Vec<Vec<f32>>> {
239 let mut codebooks = Vec::with_capacity(m);
240 for s in 0..m {
241 let offset = s * sub_dim;
242 let sub_vecs: Vec<Vec<f32>> = rotated
243 .iter()
244 .map(|v| v[offset..offset + sub_dim].to_vec())
245 .collect();
246 let centroids = lloyd(
247 &sub_vecs,
248 sub_dim,
249 k,
250 kmeans_iters,
251 seed ^ (s as u64 * 0x1234567),
252 );
253 codebooks.push(centroids);
254 }
255 codebooks
256}
257
/// Wraps raw PQ code bytes in the unified quantized-vector layout.
///
/// Header fields PQ does not use are filled with fixed defaults
/// (scale 1.0, norms 0.0, no outliers, zeroed reserved words).
fn make_uqv(codes: &[u8], dim: u16) -> UnifiedQuantizedVector {
    let header = QuantHeader {
        quant_mode: QuantMode::Pq as u16,
        dim,
        global_scale: 1.0,
        residual_norm: 0.0,
        dot_quantized: 0.0,
        outlier_bitmask: 0,
        reserved: [0; 8],
    };
    // Second payload slice is empty: PQ stores only the code bytes.
    UnifiedQuantizedVector::new(header, codes, &[])
        .expect("make_uqv: layout construction must not fail for valid inputs")
}
271
/// A vector quantized by [`OpqCodec`]: the raw PQ codes plus the unified
/// layout wrapper built from them.
pub struct OpqQuantized {
    // One u8 centroid index per subspace (length m).
    codes: Vec<u8>,
    // Unified representation produced by `make_uqv` from `codes`.
    uqv: UnifiedQuantizedVector,
}
279
impl AsRef<UnifiedQuantizedVector> for OpqQuantized {
    /// Borrows the unified layout wrapper, so generic code operating on
    /// [`UnifiedQuantizedVector`] can accept an `OpqQuantized` directly.
    fn as_ref(&self) -> &UnifiedQuantizedVector {
        &self.uqv
    }
}
285
/// Precomputed per-query state for asymmetric (ADC) distance evaluation.
pub struct OpqQuery {
    /// Flattened m×k table: entry `s * k + c` is the squared L2 distance
    /// from the rotated query's s-th sub-vector to centroid c of
    /// sub-codebook s (see `prepare_query`).
    pub distance_table: Vec<f32>,
    // NOTE(review): retained but unused in this file (hence the allow);
    // presumably kept for exact re-ranking elsewhere — confirm with callers.
    #[allow(dead_code)]
    rotated: Vec<f32>,
}
292
293impl VectorCodec for OpqCodec {
296 type Quantized = OpqQuantized;
297 type Query = OpqQuery;
298
299 fn encode(&self, v: &[f32]) -> Self::Quantized {
300 let (codes, uqv) = self.encode_inner(v);
301 OpqQuantized { codes, uqv }
302 }
303
304 fn prepare_query(&self, q: &[f32]) -> Self::Query {
306 let rotated = self.apply_rotation(q);
307 let mut table = vec![0.0f32; self.m * self.k];
308 for s in 0..self.m {
309 let offset = s * self.sub_dim;
310 let sub_q = &rotated[offset..offset + self.sub_dim];
311 for c in 0..self.k {
312 table[s * self.k + c] = l2_sq(sub_q, &self.codebooks[s][c]);
313 }
314 }
315 OpqQuery {
316 distance_table: table,
317 rotated,
318 }
319 }
320
321 fn adc_lut(&self, q: &Self::Query) -> Option<AdcLut> {
322 let mut lut = AdcLut::new(self.m as u16, self.k as u16);
323 lut.table.copy_from_slice(&q.distance_table);
324 Some(lut)
325 }
326
327 fn fast_symmetric_distance(&self, q: &Self::Quantized, v: &Self::Quantized) -> f32 {
329 let qv = self.dequantize(&q.codes);
330 let vv = self.dequantize(&v.codes);
331 l2_sq(&qv, &vv)
332 }
333
334 fn exact_asymmetric_distance(&self, q: &Self::Query, v: &Self::Quantized) -> f32 {
336 v.codes
337 .iter()
338 .enumerate()
339 .map(|(s, &code)| q.distance_table[s * self.k + code as usize])
340 .sum()
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten 8-dim vectors spread along a line so points are well separated.
    fn tiny_dataset() -> Vec<Vec<f32>> {
        (0..10)
            .map(|i| {
                let base = i as f32 * 2.0;
                let half = base * 0.5;
                vec![
                    base,
                    base + 0.1,
                    base - 0.1,
                    base + 0.2,
                    half,
                    half + 0.1,
                    half - 0.1,
                    half + 0.05,
                ]
            })
            .collect()
    }

    /// Trains a small codec (dim=8, m=2, k=4) on the tiny dataset.
    fn train_tiny() -> OpqCodec {
        let data = tiny_dataset();
        let slices: Vec<&[f32]> = data.iter().map(|v| v.as_slice()).collect();
        OpqCodec::train(&slices, 8, 2, 4, 10, 30)
    }

    #[test]
    fn encode_produces_m_bytes() {
        let codec = train_tiny();
        for v in &tiny_dataset() {
            let quantized = codec.encode(v);
            assert_eq!(quantized.codes.len(), codec.m);
        }
    }

    #[test]
    fn distance_is_non_negative() {
        let codec = train_tiny();
        for v in &tiny_dataset() {
            let quantized = codec.encode(v);
            let query = codec.prepare_query(v);
            let asym = codec.exact_asymmetric_distance(&query, &quantized);
            let sym = codec.fast_symmetric_distance(&quantized, &quantized);
            assert!(
                asym >= 0.0,
                "asymmetric distance must be non-negative, got {asym}"
            );
            assert!(
                sym >= 0.0,
                "symmetric distance must be non-negative, got {sym}"
            );
        }
    }

    #[test]
    fn top1_recall_on_training_set() {
        let vecs = tiny_dataset();
        let codec = train_tiny();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let encoded: Vec<_> = refs.iter().map(|v| codec.encode(v)).collect();

        let mut correct = 0usize;
        for (i, v) in refs.iter().enumerate() {
            let query = codec.prepare_query(v);
            // Argmin over ADC distances; ties resolve to the first minimum,
            // matching Iterator::min_by semantics.
            let mut best = usize::MAX;
            let mut best_dist = f32::INFINITY;
            for (idx, candidate) in encoded.iter().enumerate() {
                let d = codec.exact_asymmetric_distance(&query, candidate);
                if d < best_dist {
                    best_dist = d;
                    best = idx;
                }
            }
            if best == i {
                correct += 1;
            }
        }
        let recall = correct as f64 / vecs.len() as f64;
        assert!(
            recall >= 0.70,
            "top-1 recall on training set too low: {correct}/{} = {recall:.2}",
            vecs.len()
        );
    }

    #[test]
    fn more_iterations_reduce_reconstruction_error() {
        let vecs = tiny_dataset();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();

        let codec_1 = OpqCodec::train(&refs, 8, 2, 4, 1, 10);
        let codec_5 = OpqCodec::train(&refs, 8, 2, 4, 5, 10);

        // Mean squared reconstruction error measured in rotated space.
        let mean_recon_error = |codec: &OpqCodec| -> f32 {
            let total: f32 = refs
                .iter()
                .map(|v| {
                    let rotated = codec.apply_rotation(v);
                    let codes = pq_encode(&rotated, &codec.codebooks, codec.m, codec.sub_dim);
                    let recon = dequantize_codes(&codes, &codec.codebooks);
                    l2_sq(&rotated, &recon)
                })
                .sum();
            total / refs.len() as f32
        };

        let err_1 = mean_recon_error(&codec_1);
        let err_5 = mean_recon_error(&codec_5);

        assert!(
            err_5 <= err_1 * 1.05,
            "5-iter OPQ (err={err_5:.4}) should have ≤ reconstruction error than 1-iter (err={err_1:.4})"
        );
    }
}