1use std::mem::size_of;
18use std::sync::Arc;
19
20use nodedb_mem::{EngineId, MemoryGovernor};
21use serde::{Deserialize, Serialize};
22
23use crate::error::VectorError;
24
25#[inline]
29fn try_reserve_or_skip(
30 governor: &Option<Arc<MemoryGovernor>>,
31 bytes: usize,
32) -> Result<Option<nodedb_mem::BudgetGuard>, VectorError> {
33 match governor {
34 Some(g) => Ok(Some(g.reserve(EngineId::Vector, bytes)?)),
35 None => Ok(None),
36 }
37}
38
39#[derive(
41 Clone, Debug, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
42)]
43pub struct PqCodec {
44 pub dim: usize,
46 pub m: usize,
48 pub k: usize,
50 pub sub_dim: usize,
52 codebooks: Vec<Vec<Vec<f32>>>,
55
56 #[serde(skip, default)]
59 #[msgpack(ignore)]
60 governor: Option<Arc<MemoryGovernor>>,
61}
62
63impl PqCodec {
64 pub fn with_governor(mut self, governor: Arc<MemoryGovernor>) -> Self {
75 self.governor = Some(governor);
76 self
77 }
78
79 pub fn train(vectors: &[&[f32]], dim: usize, m: usize, k: usize, max_iter: usize) -> Self {
85 assert!(!vectors.is_empty());
86 assert!(dim > 0 && m > 0 && k > 0);
87 assert!(
88 dim.is_multiple_of(m),
89 "dim ({dim}) must be divisible by m ({m})"
90 );
91
92 let sub_dim = dim / m;
93 let mut codebooks = Vec::with_capacity(m);
95
96 for sub in 0..m {
97 let offset = sub * sub_dim;
98 let sub_vectors: Vec<&[f32]> = vectors
100 .iter()
101 .map(|v| &v[offset..offset + sub_dim])
102 .collect();
103
104 let centroids = kmeans(&sub_vectors, sub_dim, k, max_iter);
105 codebooks.push(centroids);
106 }
107
108 Self {
109 dim,
110 m,
111 k,
112 sub_dim,
113 codebooks,
114 governor: None,
115 }
116 }
117
118 pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
125 debug_assert_eq!(vector.len(), self.dim);
126 let mut code = Vec::with_capacity(self.m);
128 for sub in 0..self.m {
129 let offset = sub * self.sub_dim;
130 let sub_vec = &vector[offset..offset + self.sub_dim];
131 let nearest = self.nearest_centroid(sub, sub_vec);
132 code.push(nearest as u8);
133 }
134 code
135 }
136
137 pub fn encode_batch(&self, vectors: &[&[f32]]) -> Result<Vec<u8>, VectorError> {
143 let capacity = self.m * vectors.len();
144 let _g = try_reserve_or_skip(&self.governor, capacity * size_of::<u8>())?;
145 let mut out = Vec::with_capacity(capacity);
147 for v in vectors {
148 out.extend(self.encode(v));
149 }
150 Ok(out)
151 }
152
153 pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>, VectorError> {
162 debug_assert_eq!(query.len(), self.dim);
163 let total_bytes = self.m * self.k * size_of::<f32>();
164 let _g = try_reserve_or_skip(&self.governor, total_bytes)?;
165 let mut table = Vec::with_capacity(self.m);
167 for sub in 0..self.m {
168 let offset = sub * self.sub_dim;
169 let sub_query = &query[offset..offset + self.sub_dim];
170 let mut dists = Vec::with_capacity(self.k);
172 for centroid in &self.codebooks[sub] {
173 let d = l2_sub(sub_query, centroid);
174 dists.push(d);
175 }
176 table.push(dists);
177 }
178 Ok(table)
179 }
180
181 #[inline]
185 pub fn asymmetric_distance(&self, table: &[Vec<f32>], code: &[u8]) -> f32 {
186 debug_assert_eq!(code.len(), self.m);
187 let mut dist = 0.0f32;
188 for (sub, &c) in code.iter().enumerate() {
189 dist += table[sub][c as usize];
190 }
191 dist
192 }
193
194 pub fn decode(&self, code: &[u8]) -> Result<Vec<f32>, VectorError> {
199 debug_assert_eq!(code.len(), self.m);
200 let _g = try_reserve_or_skip(&self.governor, self.dim * size_of::<f32>())?;
201 let mut out = Vec::with_capacity(self.dim);
203 for (sub, &c) in code.iter().enumerate() {
204 out.extend_from_slice(&self.codebooks[sub][c as usize]);
205 }
206 Ok(out)
207 }
208
209 pub fn to_bytes(&self) -> Result<Vec<u8>, VectorError> {
217 const MAGIC: &[u8; 6] = b"NDPQ\0\0";
218 const VERSION: u8 = 1;
219 let estimated = self.m * self.k * self.sub_dim * size_of::<f32>() + 64;
220 let _g = try_reserve_or_skip(&self.governor, estimated)?;
221 let payload = zerompk::to_msgpack_vec(self).unwrap_or_default();
222 let mut out = Vec::with_capacity(7 + payload.len());
224 out.extend_from_slice(MAGIC);
225 out.push(VERSION);
226 out.extend_from_slice(&payload);
227 Ok(out)
228 }
229
230 pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
235 const MAGIC: &[u8; 6] = b"NDPQ\0\0";
236 const PQ_FORMAT_VERSION: u8 = 1;
237
238 if bytes.len() < 7 || &bytes[0..6] != MAGIC {
239 return Err(VectorError::InvalidMagic);
240 }
241 let version = bytes[6];
242 if version != PQ_FORMAT_VERSION {
243 return Err(VectorError::UnsupportedVersion {
244 found: version,
245 expected: PQ_FORMAT_VERSION,
246 });
247 }
248 zerompk::from_msgpack::<Self>(&bytes[7..])
249 .map_err(|e| VectorError::DeserializationFailed(e.to_string()))
250 }
251
252 fn nearest_centroid(&self, subspace: usize, sub_vec: &[f32]) -> usize {
253 let mut best_idx = 0;
254 let mut best_dist = f32::MAX;
255 for (i, centroid) in self.codebooks[subspace].iter().enumerate() {
256 let d = l2_sub(sub_vec, centroid);
257 if d < best_dist {
258 best_dist = d;
259 best_idx = i;
260 }
261 }
262 best_idx
263 }
264}
265
/// Squared Euclidean distance between `a` and the first `a.len()` components
/// of `b`.
///
/// Extra trailing components of `b` are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[inline]
fn l2_sub(a: &[f32], b: &[f32]) -> f32 {
    // One up-front bounds check: panics when `b` is too short and lets the
    // loop below run without per-element checks.
    let rhs = &b[..a.len()];
    a.iter().zip(rhs).fold(0.0f32, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    })
}
276
277fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
282 let n = data.len();
283 if n == 0 || k == 0 {
284 return Vec::new();
285 }
286 let k = k.min(n); let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42);
290
291 let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
293 centroids.push(data[0].to_vec());
294
295 let mut min_dists = vec![f32::MAX; n];
296 for (i, point) in data.iter().enumerate() {
298 let d = l2_sub(point, ¢roids[0]);
299 if d < min_dists[i] {
300 min_dists[i] = d;
301 }
302 }
303
304 for _ in 1..k {
305 let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
306 let next_idx = if total < f64::EPSILON {
307 0
309 } else {
310 let target = rng.next_f64() * total;
311 let mut acc = 0.0f64;
312 let mut chosen = n - 1;
313 for (i, &d) in min_dists.iter().enumerate() {
314 acc += d as f64;
315 if acc >= target {
316 chosen = i;
317 break;
318 }
319 }
320 chosen
321 };
322 centroids.push(data[next_idx].to_vec());
323 let last = centroids.last().expect("just pushed");
325 for (i, point) in data.iter().enumerate() {
326 let d = l2_sub(point, last);
327 if d < min_dists[i] {
328 min_dists[i] = d;
329 }
330 }
331 }
332
333 let mut assignments = vec![0usize; n];
335 for _ in 0..max_iter {
336 let mut changed = false;
338 for (i, point) in data.iter().enumerate() {
339 let mut best = 0;
340 let mut best_d = f32::MAX;
341 for (c, centroid) in centroids.iter().enumerate() {
342 let d = l2_sub(point, centroid);
343 if d < best_d {
344 best_d = d;
345 best = c;
346 }
347 }
348 if assignments[i] != best {
349 assignments[i] = best;
350 changed = true;
351 }
352 }
353 if !changed {
354 break;
355 }
356
357 let mut sums = vec![vec![0.0f32; dim]; k];
359 let mut counts = vec![0usize; k];
360 for (i, point) in data.iter().enumerate() {
361 let c = assignments[i];
362 counts[c] += 1;
363 for d in 0..dim {
364 sums[c][d] += point[d];
365 }
366 }
367 for c in 0..k {
368 if counts[c] > 0 {
369 for d in 0..dim {
370 centroids[c][d] = sums[c][d] / counts[c] as f32;
371 }
372 }
373 }
374 }
375
376 centroids
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds 200 four-dimensional vectors: four clusters of 50 points, with
    /// cluster centers spaced 10 apart on every axis.
    fn make_clustered_data() -> Vec<Vec<f32>> {
        (0..4u32)
            .flat_map(|cluster| {
                let center = cluster as f32 * 10.0;
                (0..50u32).map(move |i| {
                    let step = i as f32;
                    vec![
                        center + step * 0.1,
                        center + step * 0.05,
                        center - step * 0.1,
                        center + step * 0.02,
                    ]
                })
            })
            .collect()
    }

    /// Borrows every owned vector as a slice.
    fn as_slices(vecs: &[Vec<f32>]) -> Vec<&[f32]> {
        vecs.iter().map(Vec::as_slice).collect()
    }

    /// Trains the configuration shared by all tests: dim 4, 2 subspaces,
    /// 16 centroids, 10 iterations.
    fn train_default(refs: &[&[f32]]) -> PqCodec {
        PqCodec::train(refs, 4, 2, 16, 10)
    }

    /// Returns the indices of `scored` ordered by ascending distance.
    fn ranked(mut scored: Vec<(usize, f32)>) -> Vec<usize> {
        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        scored.into_iter().map(|(idx, _)| idx).collect()
    }

    #[test]
    fn encode_decode_roundtrip() {
        let vecs = make_clustered_data();
        let codec = train_default(&as_slices(&vecs));

        for v in &vecs {
            let code = codec.encode(v);
            // One byte per subspace.
            assert_eq!(code.len(), 2);
            let decoded = codec.decode(&code).unwrap();
            assert_eq!(decoded.len(), 4);
        }
    }

    #[test]
    fn distance_table_gives_correct_ordering() {
        let vecs = make_clustered_data();
        let codec = train_default(&as_slices(&vecs));

        let query: [f32; 4] = [5.0, 5.0, 5.0, 5.0];
        let table = codec.build_distance_table(&query).unwrap();

        let pq_order = ranked(
            vecs.iter()
                .map(|v| codec.encode(v))
                .enumerate()
                .map(|(i, code)| (i, codec.asymmetric_distance(&table, &code)))
                .collect(),
        );
        let exact_order = ranked(
            vecs.iter()
                .enumerate()
                .map(|(i, v)| (i, l2_sub(&query, v)))
                .collect(),
        );

        // The PQ top-5 should mostly fall inside the exact top-10.
        let pq_top: HashSet<usize> = pq_order[..5].iter().copied().collect();
        let exact_top: HashSet<usize> = exact_order[..10].iter().copied().collect();
        let overlap = pq_top.intersection(&exact_top).count();
        assert!(overlap >= 3, "PQ recall too low: {overlap}/5 in top-10");
    }

    #[test]
    fn batch_encode() {
        let vecs = make_clustered_data();
        let refs = as_slices(&vecs);
        let codec = train_default(&refs);

        let batch = codec.encode_batch(&refs).unwrap();
        // 2 bytes per vector, 200 vectors.
        assert_eq!(batch.len(), 2 * 200);
    }

    #[test]
    fn pq_codec_golden_format() {
        let vecs = make_clustered_data();
        let codec = train_default(&as_slices(&vecs));

        let bytes = codec.to_bytes().unwrap();
        let (header, payload) = bytes.split_at(7);

        assert_eq!(&header[0..6], b"NDPQ\0\0", "magic mismatch");
        assert_eq!(header[6], 1u8, "version must be 1");

        let restored = zerompk::from_msgpack::<PqCodec>(payload)
            .expect("msgpack payload at offset 7 must decode");
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.m, codec.m);
    }

    #[test]
    fn pq_version_mismatch_returns_error() {
        // Valid magic, version byte 0 (unsupported), then a one-byte payload.
        let mut crafted = b"NDPQ\0\0".to_vec();
        crafted.extend_from_slice(&[0u8, 0x80]);

        let err = PqCodec::from_bytes(&crafted).unwrap_err();
        let is_expected = matches!(
            err,
            VectorError::UnsupportedVersion {
                found: 0,
                expected: 1
            }
        );
        assert!(is_expected, "expected UnsupportedVersion, got: {err:?}");
    }

    #[test]
    fn pq_invalid_magic_returns_error() {
        let bad: &[u8] = b"JUNK\0\0\x01some-payload";
        let err = PqCodec::from_bytes(bad).unwrap_err();
        let is_expected = matches!(err, VectorError::InvalidMagic);
        assert!(is_expected, "expected InvalidMagic, got: {err:?}");
    }
}