1use std::mem::size_of;
18use std::sync::Arc;
19
20use nodedb_mem::{EngineId, MemoryGovernor};
21use serde::{Deserialize, Serialize};
22
23use crate::error::VectorError;
24
25#[inline]
29fn try_reserve_or_skip(
30 governor: &Option<Arc<MemoryGovernor>>,
31 bytes: usize,
32) -> Result<Option<nodedb_mem::BudgetGuard>, VectorError> {
33 match governor {
34 Some(g) => Ok(Some(g.reserve(EngineId::Vector, bytes)?)),
35 None => Ok(None),
36 }
37}
38
/// Product-quantization (PQ) codec.
///
/// A `dim`-dimensional vector is split into `m` contiguous sub-vectors of
/// `sub_dim` components each. Every sub-vector is encoded as one byte: the index
/// of its nearest centroid in that subspace's codebook.
// NOTE(review): field order likely defines the msgpack layout produced by the
// zerompk derives — confirm before reordering or inserting fields.
#[derive(
    Clone, Debug, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
)]
pub struct PqCodec {
    /// Dimensionality of the full (uncompressed) vectors.
    pub dim: usize,
    /// Number of subspaces; also the number of bytes in each code.
    pub m: usize,
    /// Requested centroids per subspace. A trained codebook may hold fewer
    /// (k-means caps it at the number of training vectors).
    pub k: usize,
    /// Components per subspace (`dim / m`).
    pub sub_dim: usize,
    /// `codebooks[sub][centroid]` is a centroid of length `sub_dim`.
    codebooks: Vec<Vec<Vec<f32>>>,

    /// Optional memory governor used to account temporary allocations.
    /// Runtime-only: skipped by both serde and msgpack serialization.
    #[serde(skip, default)]
    #[msgpack(ignore)]
    governor: Option<Arc<MemoryGovernor>>,
}
62
63impl PqCodec {
64 pub fn with_governor(mut self, governor: Arc<MemoryGovernor>) -> Self {
75 self.governor = Some(governor);
76 self
77 }
78
79 pub fn train(vectors: &[&[f32]], dim: usize, m: usize, k: usize, max_iter: usize) -> Self {
85 assert!(!vectors.is_empty());
86 assert!(dim > 0 && m > 0 && k > 0);
87 assert!(
88 dim.is_multiple_of(m),
89 "dim ({dim}) must be divisible by m ({m})"
90 );
91
92 let sub_dim = dim / m;
93 let mut codebooks = Vec::with_capacity(m);
94
95 for sub in 0..m {
96 let offset = sub * sub_dim;
97 let sub_vectors: Vec<&[f32]> = vectors
99 .iter()
100 .map(|v| &v[offset..offset + sub_dim])
101 .collect();
102
103 let centroids = kmeans(&sub_vectors, sub_dim, k, max_iter);
104 codebooks.push(centroids);
105 }
106
107 Self {
108 dim,
109 m,
110 k,
111 sub_dim,
112 codebooks,
113 governor: None,
114 }
115 }
116
117 pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
124 debug_assert_eq!(vector.len(), self.dim);
125 let mut code = Vec::with_capacity(self.m);
126 for sub in 0..self.m {
127 let offset = sub * self.sub_dim;
128 let sub_vec = &vector[offset..offset + self.sub_dim];
129 let nearest = self.nearest_centroid(sub, sub_vec);
130 code.push(nearest as u8);
131 }
132 code
133 }
134
135 pub fn encode_batch(&self, vectors: &[&[f32]]) -> Result<Vec<u8>, VectorError> {
141 let capacity = self.m * vectors.len();
142 let _g = try_reserve_or_skip(&self.governor, capacity * size_of::<u8>())?;
143 let mut out = Vec::with_capacity(capacity);
144 for v in vectors {
145 out.extend(self.encode(v));
146 }
147 Ok(out)
148 }
149
150 pub fn build_distance_table(&self, query: &[f32]) -> Result<Vec<Vec<f32>>, VectorError> {
159 debug_assert_eq!(query.len(), self.dim);
160 let total_bytes = self.m * self.k * size_of::<f32>();
161 let _g = try_reserve_or_skip(&self.governor, total_bytes)?;
162 let mut table = Vec::with_capacity(self.m);
163 for sub in 0..self.m {
164 let offset = sub * self.sub_dim;
165 let sub_query = &query[offset..offset + self.sub_dim];
166 let mut dists = Vec::with_capacity(self.k);
167 for centroid in &self.codebooks[sub] {
168 let d = l2_sub(sub_query, centroid);
169 dists.push(d);
170 }
171 table.push(dists);
172 }
173 Ok(table)
174 }
175
176 #[inline]
180 pub fn asymmetric_distance(&self, table: &[Vec<f32>], code: &[u8]) -> f32 {
181 debug_assert_eq!(code.len(), self.m);
182 let mut dist = 0.0f32;
183 for (sub, &c) in code.iter().enumerate() {
184 dist += table[sub][c as usize];
185 }
186 dist
187 }
188
189 pub fn decode(&self, code: &[u8]) -> Result<Vec<f32>, VectorError> {
194 debug_assert_eq!(code.len(), self.m);
195 let _g = try_reserve_or_skip(&self.governor, self.dim * size_of::<f32>())?;
196 let mut out = Vec::with_capacity(self.dim);
197 for (sub, &c) in code.iter().enumerate() {
198 out.extend_from_slice(&self.codebooks[sub][c as usize]);
199 }
200 Ok(out)
201 }
202
203 pub fn to_bytes(&self) -> Result<Vec<u8>, VectorError> {
211 const MAGIC: &[u8; 6] = b"NDPQ\0\0";
212 const VERSION: u8 = 1;
213 let estimated = self.m * self.k * self.sub_dim * size_of::<f32>() + 64;
214 let _g = try_reserve_or_skip(&self.governor, estimated)?;
215 let payload = zerompk::to_msgpack_vec(self).unwrap_or_default();
216 let mut out = Vec::with_capacity(7 + payload.len());
217 out.extend_from_slice(MAGIC);
218 out.push(VERSION);
219 out.extend_from_slice(&payload);
220 Ok(out)
221 }
222
223 pub fn from_bytes(bytes: &[u8]) -> Result<Self, VectorError> {
228 const MAGIC: &[u8; 6] = b"NDPQ\0\0";
229 const PQ_FORMAT_VERSION: u8 = 1;
230
231 if bytes.len() < 7 || &bytes[0..6] != MAGIC {
232 return Err(VectorError::InvalidMagic);
233 }
234 let version = bytes[6];
235 if version != PQ_FORMAT_VERSION {
236 return Err(VectorError::UnsupportedVersion {
237 found: version,
238 expected: PQ_FORMAT_VERSION,
239 });
240 }
241 zerompk::from_msgpack::<Self>(&bytes[7..])
242 .map_err(|e| VectorError::DeserializationFailed(e.to_string()))
243 }
244
245 fn nearest_centroid(&self, subspace: usize, sub_vec: &[f32]) -> usize {
246 let mut best_idx = 0;
247 let mut best_dist = f32::MAX;
248 for (i, centroid) in self.codebooks[subspace].iter().enumerate() {
249 let d = l2_sub(sub_vec, centroid);
250 if d < best_dist {
251 best_dist = d;
252 best_idx = i;
253 }
254 }
255 best_idx
256 }
257}
258
/// Squared Euclidean (L2²) distance between two equal-length slices.
///
/// No square root is taken. The squared distance preserves nearest-neighbour
/// ordering and is what the PQ distance tables sum over subspaces.
#[inline]
fn l2_sub(a: &[f32], b: &[f32]) -> f32 {
    // A length mismatch is a caller bug; the old index loop silently ignored
    // extra components in `b`.
    debug_assert_eq!(a.len(), b.len(), "l2_sub: slice length mismatch");
    // Explicit fold from +0.0 keeps the exact left-to-right summation of the original.
    a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
        let d = x - y;
        acc + d * d
    })
}
269
270fn kmeans(data: &[&[f32]], dim: usize, k: usize, max_iter: usize) -> Vec<Vec<f32>> {
275 let n = data.len();
276 if n == 0 || k == 0 {
277 return Vec::new();
278 }
279 let k = k.min(n); let mut rng = crate::hnsw::Xorshift64::new(0xC0FF_EEDE_ADBE_EF42);
283
284 let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
285 centroids.push(data[0].to_vec());
286
287 let mut min_dists = vec![f32::MAX; n];
288 for (i, point) in data.iter().enumerate() {
290 let d = l2_sub(point, ¢roids[0]);
291 if d < min_dists[i] {
292 min_dists[i] = d;
293 }
294 }
295
296 for _ in 1..k {
297 let total: f64 = min_dists.iter().map(|&d| d as f64).sum();
298 let next_idx = if total < f64::EPSILON {
299 0
301 } else {
302 let target = rng.next_f64() * total;
303 let mut acc = 0.0f64;
304 let mut chosen = n - 1;
305 for (i, &d) in min_dists.iter().enumerate() {
306 acc += d as f64;
307 if acc >= target {
308 chosen = i;
309 break;
310 }
311 }
312 chosen
313 };
314 centroids.push(data[next_idx].to_vec());
315 let last = centroids.last().expect("just pushed");
317 for (i, point) in data.iter().enumerate() {
318 let d = l2_sub(point, last);
319 if d < min_dists[i] {
320 min_dists[i] = d;
321 }
322 }
323 }
324
325 let mut assignments = vec![0usize; n];
327 for _ in 0..max_iter {
328 let mut changed = false;
330 for (i, point) in data.iter().enumerate() {
331 let mut best = 0;
332 let mut best_d = f32::MAX;
333 for (c, centroid) in centroids.iter().enumerate() {
334 let d = l2_sub(point, centroid);
335 if d < best_d {
336 best_d = d;
337 best = c;
338 }
339 }
340 if assignments[i] != best {
341 assignments[i] = best;
342 changed = true;
343 }
344 }
345 if !changed {
346 break;
347 }
348
349 let mut sums = vec![vec![0.0f32; dim]; k];
351 let mut counts = vec![0usize; k];
352 for (i, point) in data.iter().enumerate() {
353 let c = assignments[i];
354 counts[c] += 1;
355 for d in 0..dim {
356 sums[c][d] += point[d];
357 }
358 }
359 for c in 0..k {
360 if counts[c] > 0 {
361 for d in 0..dim {
362 centroids[c][d] = sums[c][d] / counts[c] as f32;
363 }
364 }
365 }
366 }
367
368 centroids
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// 200 four-dimensional vectors: 4 clusters (centres 0, 10, 20, 30) of 50 points.
    fn make_clustered_data() -> Vec<Vec<f32>> {
        let mut vecs = Vec::new();
        for cluster in 0..4 {
            let center = cluster as f32 * 10.0;
            for i in 0..50 {
                vecs.push(vec![
                    center + (i as f32) * 0.1,
                    center + (i as f32) * 0.05,
                    center - (i as f32) * 0.1,
                    center + (i as f32) * 0.02,
                ]);
            }
        }
        vecs
    }

    #[test]
    fn encode_decode_roundtrip() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        for v in &vecs {
            let code = codec.encode(v);
            // One byte per subspace.
            assert_eq!(code.len(), 2);
            let decoded = codec.decode(&code).unwrap();
            assert_eq!(decoded.len(), 4);
        }
    }

    #[test]
    fn distance_table_gives_correct_ordering() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        let codes: Vec<Vec<u8>> = vecs.iter().map(|v| codec.encode(v)).collect();
        let query = &[5.0, 5.0, 5.0, 5.0];
        let table = codec.build_distance_table(query).unwrap();

        let mut pq_dists: Vec<(usize, f32)> = codes
            .iter()
            .enumerate()
            .map(|(i, c)| (i, codec.asymmetric_distance(&table, c)))
            .collect();
        pq_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        let mut exact_dists: Vec<(usize, f32)> = vecs
            .iter()
            .enumerate()
            .map(|(i, v)| (i, l2_sub(query, v)))
            .collect();
        exact_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

        // Most of PQ's top-5 should appear in the exact top-10.
        let pq_top: std::collections::HashSet<usize> = pq_dists[..5].iter().map(|x| x.0).collect();
        let exact_top: std::collections::HashSet<usize> =
            exact_dists[..10].iter().map(|x| x.0).collect();
        let overlap = pq_top.intersection(&exact_top).count();
        assert!(overlap >= 3, "PQ recall too low: {overlap}/5 in top-10");
    }

    #[test]
    fn batch_encode() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        let batch = codec.encode_batch(&refs).unwrap();
        // m = 2 bytes per vector, 200 vectors.
        assert_eq!(batch.len(), 2 * 200);
    }

    #[test]
    fn batch_encode_matches_per_vector_encode() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        let batch = codec.encode_batch(&refs).unwrap();
        let concatenated: Vec<u8> = refs.iter().flat_map(|v| codec.encode(v)).collect();
        assert_eq!(batch, concatenated);
    }

    #[test]
    fn pq_codec_golden_format() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        let bytes = codec.to_bytes().unwrap();

        assert_eq!(&bytes[0..6], b"NDPQ\0\0", "magic mismatch");
        assert_eq!(bytes[6], 1u8, "version must be 1");
        let restored = zerompk::from_msgpack::<PqCodec>(&bytes[7..])
            .expect("msgpack payload at offset 7 must decode");
        assert_eq!(restored.dim, codec.dim);
        assert_eq!(restored.m, codec.m);
    }

    #[test]
    fn pq_bytes_roundtrip_preserves_codes() {
        let vecs = make_clustered_data();
        let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect();
        let codec = PqCodec::train(&refs, 4, 2, 16, 10);

        let restored = PqCodec::from_bytes(&codec.to_bytes().unwrap()).unwrap();
        assert_eq!(
            (restored.dim, restored.m, restored.k, restored.sub_dim),
            (codec.dim, codec.m, codec.k, codec.sub_dim)
        );
        for v in &vecs {
            assert_eq!(restored.encode(v), codec.encode(v));
        }
    }

    #[test]
    fn pq_version_mismatch_returns_error() {
        let mut crafted = b"NDPQ\0\0".to_vec();
        // Version byte 0, followed by an empty msgpack map.
        crafted.push(0u8);
        crafted.extend_from_slice(b"\x80");
        let err = PqCodec::from_bytes(&crafted).unwrap_err();
        assert!(
            matches!(
                err,
                VectorError::UnsupportedVersion {
                    found: 0,
                    expected: 1
                }
            ),
            "expected UnsupportedVersion, got: {err:?}"
        );
    }

    #[test]
    fn pq_invalid_magic_returns_error() {
        let bad: &[u8] = b"JUNK\0\0\x01some-payload";
        let err = PqCodec::from_bytes(bad).unwrap_err();
        assert!(
            matches!(err, VectorError::InvalidMagic),
            "expected InvalidMagic, got: {err:?}"
        );
    }

    #[test]
    fn pq_truncated_header_returns_invalid_magic() {
        // Magic present but the version byte is missing.
        let err = PqCodec::from_bytes(b"NDPQ\0\0").unwrap_err();
        assert!(
            matches!(err, VectorError::InvalidMagic),
            "expected InvalidMagic, got: {err:?}"
        );
    }
}