1use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
4
5use crate::error::{Error, Result};
6
/// Byte budget used by `compress_into_codes_cpu` to size its batches: the budget is divided
/// by the byte size of one row of centroid scores, and the result is then clamped to
/// `1..=2048` rows.
///
/// NOTE(review): 4 GiB does not fit in `usize` on 32-bit targets, so this constant only
/// builds where `usize` is 64 bits wide — confirm that is the intended platform set.
const MAX_NEAREST_CENTROID_MEMORY: usize = 4 * 1024 * 1024 * 1024;

/// Backing storage for the centroid matrix used by `ResidualCodec`.
pub enum CentroidStore {
    /// Centroids held in an owned in-memory array.
    Owned(Array2<f32>),
    /// Centroids accessed through the crate's `MmapNpyArray2F32` wrapper
    /// (presumably a memory-mapped `.npy` file — see `crate::mmap`).
    Mmap(crate::mmap::MmapNpyArray2F32),
}
23
24impl CentroidStore {
25 pub fn view(&self) -> ArrayView2<'_, f32> {
29 match self {
30 CentroidStore::Owned(arr) => arr.view(),
31 CentroidStore::Mmap(mmap) => mmap.view(),
32 }
33 }
34
35 pub fn nrows(&self) -> usize {
37 match self {
38 CentroidStore::Owned(arr) => arr.nrows(),
39 CentroidStore::Mmap(mmap) => mmap.nrows(),
40 }
41 }
42
43 pub fn ncols(&self) -> usize {
45 match self {
46 CentroidStore::Owned(arr) => arr.ncols(),
47 CentroidStore::Mmap(mmap) => mmap.ncols(),
48 }
49 }
50
51 pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
53 match self {
54 CentroidStore::Owned(arr) => arr.row(idx),
55 CentroidStore::Mmap(mmap) => mmap.row(idx),
56 }
57 }
58
59 pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, f32> {
63 match self {
64 CentroidStore::Owned(arr) => arr.slice(s![start..end, ..]),
65 CentroidStore::Mmap(mmap) => mmap.slice_rows(start, end),
66 }
67 }
68}
69
70impl Clone for CentroidStore {
71 fn clone(&self) -> Self {
72 match self {
73 CentroidStore::Owned(arr) => CentroidStore::Owned(arr.clone()),
75 CentroidStore::Mmap(mmap) => CentroidStore::Owned(mmap.to_owned()),
77 }
78 }
79}
80
/// Codec that maps embeddings to centroid codes and packs/unpacks quantized residuals.
///
/// Cloning a codec also clones its centroid store (see `CentroidStore::clone`).
#[derive(Clone)]
pub struct ResidualCodec {
    /// Number of bits used to store each quantized residual value; the constructor
    /// only accepts divisors of 8.
    pub nbits: usize,
    /// Centroid matrix: one centroid per row, one embedding dimension per column.
    pub centroids: CentroidStore,
    /// Average residual values supplied at construction; stored as given and not read
    /// by the methods in this block.
    pub avg_residual: Array1<f32>,
    /// Thresholds used by `quantize_residuals` to assign each residual value a bucket;
    /// `None` disables quantization.
    pub bucket_cutoffs: Option<Array1<f32>>,
    /// Per-bucket reconstruction values added to the centroid in `decompress`;
    /// `None` disables decompression.
    pub bucket_weights: Option<Array1<f32>>,
    /// 256-entry table mapping a byte to the byte with the bits of each `nbits`-wide
    /// group reversed.
    pub byte_reversed_bits_map: Vec<u8>,
    /// 256 x (8 / nbits) table giving, for each byte, its `nbits`-wide groups from the
    /// most significant to the least significant; only built when `bucket_weights` is set.
    pub bucket_weight_indices_lookup: Option<Array2<usize>>,
}
104
105impl ResidualCodec {
106 pub fn new(
116 nbits: usize,
117 centroids: Array2<f32>,
118 avg_residual: Array1<f32>,
119 bucket_cutoffs: Option<Array1<f32>>,
120 bucket_weights: Option<Array1<f32>>,
121 ) -> Result<Self> {
122 Self::new_with_store(
123 nbits,
124 CentroidStore::Owned(centroids),
125 avg_residual,
126 bucket_cutoffs,
127 bucket_weights,
128 )
129 }
130
131 pub fn new_with_store(
135 nbits: usize,
136 centroids: CentroidStore,
137 avg_residual: Array1<f32>,
138 bucket_cutoffs: Option<Array1<f32>>,
139 bucket_weights: Option<Array1<f32>>,
140 ) -> Result<Self> {
141 if nbits == 0 || 8 % nbits != 0 {
142 return Err(Error::Codec(format!(
143 "nbits must be a divisor of 8, got {}",
144 nbits
145 )));
146 }
147
148 let nbits_mask = (1u32 << nbits) - 1;
150 let mut byte_reversed_bits_map = vec![0u8; 256];
151
152 for (i, byte_slot) in byte_reversed_bits_map.iter_mut().enumerate() {
153 let val = i as u32;
154 let mut out = 0u32;
155 let mut pos = 8i32;
156
157 while pos >= nbits as i32 {
158 let segment = (val >> (pos as u32 - nbits as u32)) & nbits_mask;
159
160 let mut rev_segment = 0u32;
161 for k in 0..nbits {
162 if (segment & (1 << k)) != 0 {
163 rev_segment |= 1 << (nbits - 1 - k);
164 }
165 }
166
167 out |= rev_segment;
168
169 if pos > nbits as i32 {
170 out <<= nbits;
171 }
172
173 pos -= nbits as i32;
174 }
175 *byte_slot = out as u8;
176 }
177
178 let keys_per_byte = 8 / nbits;
180 let bucket_weight_indices_lookup = if bucket_weights.is_some() {
181 let mask = (1usize << nbits) - 1;
182 let mut table = Array2::<usize>::zeros((256, keys_per_byte));
183
184 for byte_val in 0..256usize {
185 for k in (0..keys_per_byte).rev() {
186 let shift = k * nbits;
187 let index = (byte_val >> shift) & mask;
188 table[[byte_val, keys_per_byte - 1 - k]] = index;
189 }
190 }
191 Some(table)
192 } else {
193 None
194 };
195
196 Ok(Self {
197 nbits,
198 centroids,
199 avg_residual,
200 bucket_cutoffs,
201 bucket_weights,
202 byte_reversed_bits_map,
203 bucket_weight_indices_lookup,
204 })
205 }
206
/// Embedding dimensionality, i.e. the number of columns of the centroid matrix.
pub fn embedding_dim(&self) -> usize {
    self.centroids.ncols()
}
211
/// Number of centroids, i.e. the number of rows of the centroid matrix.
pub fn num_centroids(&self) -> usize {
    self.centroids.nrows()
}
216
/// Borrows the centroid matrix as a 2-D view, independent of how it is stored.
pub fn centroids_view(&self) -> ArrayView2<'_, f32> {
    self.centroids.view()
}
223
224 pub fn compress_into_codes(&self, embeddings: &Array2<f32>) -> Array1<usize> {
241 #[cfg(feature = "cuda")]
243 {
244 if let Some(ctx) = crate::cuda::get_global_context() {
245 let centroids = self.centroids_view();
246 match crate::cuda::compress_into_codes_cuda_batched(
247 ctx,
248 &embeddings.view(),
249 ¢roids,
250 None,
251 ) {
252 Ok(codes) => return codes,
253 Err(e) => {
254 eprintln!(
255 "[next-plaid] CUDA compress_into_codes failed: {}, falling back to CPU",
256 e
257 );
258 }
259 }
260 }
261 }
262
263 self.compress_into_codes_cpu(embeddings)
264 }
265
266 pub fn compress_into_codes_cpu(&self, embeddings: &Array2<f32>) -> Array1<usize> {
269 use rayon::prelude::*;
270
271 let n = embeddings.nrows();
272 if n == 0 {
273 return Array1::zeros(0);
274 }
275
276 let centroids = self.centroids_view();
278 let num_centroids = centroids.nrows();
279
280 let max_batch_by_memory =
284 MAX_NEAREST_CENTROID_MEMORY / (num_centroids * std::mem::size_of::<f32>());
285 let batch_size = max_batch_by_memory.clamp(1, 2048);
286
287 let mut all_codes = Vec::with_capacity(n);
288
289 for start in (0..n).step_by(batch_size) {
290 let end = (start + batch_size).min(n);
291 let batch = embeddings.slice(ndarray::s![start..end, ..]);
292
293 let scores = batch.dot(¢roids.t());
295
296 let batch_codes: Vec<usize> = scores
298 .axis_iter(Axis(0))
299 .into_par_iter()
300 .map(|row| {
301 row.iter()
302 .enumerate()
303 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
304 .map(|(idx, _)| idx)
305 .unwrap_or(0)
306 })
307 .collect();
308
309 all_codes.extend(batch_codes);
310 }
311
312 Array1::from_vec(all_codes)
313 }
314
315 pub fn quantize_residuals(&self, residuals: &Array2<f32>) -> Result<Array2<u8>> {
327 use rayon::prelude::*;
328
329 let cutoffs = self
330 .bucket_cutoffs
331 .as_ref()
332 .ok_or_else(|| Error::Codec("bucket_cutoffs required for quantization".into()))?;
333
334 let n = residuals.nrows();
335 let dim = residuals.ncols();
336 let packed_dim = dim * self.nbits / 8;
337 let nbits = self.nbits;
338
339 if n == 0 {
340 return Ok(Array2::zeros((0, packed_dim)));
341 }
342
343 let cutoffs_slice = cutoffs.as_slice().unwrap();
345
346 let packed_rows: Vec<Vec<u8>> = residuals
348 .axis_iter(Axis(0))
349 .into_par_iter()
350 .map(|row| {
351 let mut packed_row = vec![0u8; packed_dim];
352 let mut bit_idx = 0;
353
354 for &val in row.iter() {
355 let bucket = cutoffs_slice.iter().filter(|&&c| val > c).count();
357
358 for b in 0..nbits {
360 let bit = ((bucket >> b) & 1) as u8;
361 let byte_idx = bit_idx / 8;
362 let bit_pos = 7 - (bit_idx % 8);
363 packed_row[byte_idx] |= bit << bit_pos;
364 bit_idx += 1;
365 }
366 }
367
368 packed_row
369 })
370 .collect();
371
372 let mut packed = Array2::<u8>::zeros((n, packed_dim));
374 for (i, row) in packed_rows.into_iter().enumerate() {
375 for (j, val) in row.into_iter().enumerate() {
376 packed[[i, j]] = val;
377 }
378 }
379
380 Ok(packed)
381 }
382
383 pub fn decompress(
394 &self,
395 packed_residuals: &Array2<u8>,
396 codes: &ArrayView1<usize>,
397 ) -> Result<Array2<f32>> {
398 let bucket_weights = self
399 .bucket_weights
400 .as_ref()
401 .ok_or_else(|| Error::Codec("bucket_weights required for decompression".into()))?;
402
403 let lookup = self
404 .bucket_weight_indices_lookup
405 .as_ref()
406 .ok_or_else(|| Error::Codec("bucket_weight_indices_lookup required".into()))?;
407
408 let n = packed_residuals.nrows();
409 let dim = self.embedding_dim();
410
411 let mut output = Array2::<f32>::zeros((n, dim));
412
413 for i in 0..n {
414 let centroid = self.centroids.row(codes[i]);
416
417 let mut residual_idx = 0;
419 for &byte_val in packed_residuals.row(i).iter() {
420 let reversed = self.byte_reversed_bits_map[byte_val as usize];
421 let indices = lookup.row(reversed as usize);
422
423 for &bucket_idx in indices.iter() {
424 if residual_idx < dim {
425 output[[i, residual_idx]] =
426 centroid[residual_idx] + bucket_weights[bucket_idx];
427 residual_idx += 1;
428 }
429 }
430 }
431 }
432
433 for mut row in output.axis_iter_mut(Axis(0)) {
435 let norm = row.dot(&row).sqrt().max(1e-12);
436 row /= norm;
437 }
438
439 Ok(output)
440 }
441
442 pub fn load_from_dir(index_path: &std::path::Path) -> Result<Self> {
444 use ndarray_npy::ReadNpyExt;
445 use std::fs::File;
446
447 let centroids_path = index_path.join("centroids.npy");
448 let centroids: Array2<f32> = Array2::read_npy(
449 File::open(¢roids_path)
450 .map_err(|e| Error::IndexLoad(format!("Failed to open centroids.npy: {}", e)))?,
451 )
452 .map_err(|e| Error::IndexLoad(format!("Failed to read centroids.npy: {}", e)))?;
453
454 let avg_residual_path = index_path.join("avg_residual.npy");
455 let avg_residual: Array1<f32> =
456 Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
457 Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
458 })?)
459 .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;
460
461 let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
462 let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
463 Some(
464 Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
465 Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
466 })?)
467 .map_err(|e| {
468 Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
469 })?,
470 )
471 } else {
472 None
473 };
474
475 let bucket_weights_path = index_path.join("bucket_weights.npy");
476 let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
477 Some(
478 Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
479 Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
480 })?)
481 .map_err(|e| {
482 Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
483 })?,
484 )
485 } else {
486 None
487 };
488
489 let metadata_path = index_path.join("metadata.json");
491 let metadata: serde_json::Value = serde_json::from_reader(
492 File::open(&metadata_path)
493 .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
494 )
495 .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;
496
497 let nbits = metadata["nbits"]
498 .as_u64()
499 .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
500 as usize;
501
502 Self::new(
503 nbits,
504 centroids,
505 avg_residual,
506 bucket_cutoffs,
507 bucket_weights,
508 )
509 }
510
511 pub fn load_mmap_from_dir(index_path: &std::path::Path) -> Result<Self> {
519 use ndarray_npy::ReadNpyExt;
520 use std::fs::File;
521
522 let centroids_path = index_path.join("centroids.npy");
524 let mmap_centroids = crate::mmap::MmapNpyArray2F32::from_npy_file(¢roids_path)?;
525
526 let avg_residual_path = index_path.join("avg_residual.npy");
528 let avg_residual: Array1<f32> =
529 Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
530 Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
531 })?)
532 .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;
533
534 let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
535 let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
536 Some(
537 Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
538 Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
539 })?)
540 .map_err(|e| {
541 Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
542 })?,
543 )
544 } else {
545 None
546 };
547
548 let bucket_weights_path = index_path.join("bucket_weights.npy");
549 let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
550 Some(
551 Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
552 Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
553 })?)
554 .map_err(|e| {
555 Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
556 })?,
557 )
558 } else {
559 None
560 };
561
562 let metadata_path = index_path.join("metadata.json");
564 let metadata: serde_json::Value = serde_json::from_reader(
565 File::open(&metadata_path)
566 .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
567 )
568 .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;
569
570 let nbits = metadata["nbits"]
571 .as_u64()
572 .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
573 as usize;
574
575 Self::new_with_store(
576 nbits,
577 CentroidStore::Mmap(mmap_centroids),
578 avg_residual,
579 bucket_cutoffs,
580 bucket_weights,
581 )
582 }
583}
584
585#[cfg(test)]
586mod tests {
587 use super::*;
588
#[test]
fn test_codec_creation() {
    // 4 centroids of dimension 8, filled with 0.0..=31.0 in row-major order.
    let values: Vec<f32> = (0..32).map(|x| x as f32).collect();
    let centroids = Array2::from_shape_vec((4, 8), values).unwrap();

    let codec = ResidualCodec::new(
        2,
        centroids,
        Array1::zeros(8),
        Some(Array1::from_vec(vec![-0.5, 0.0, 0.5])),
        Some(Array1::from_vec(vec![-0.75, -0.25, 0.25, 0.75])),
    );
    assert!(codec.is_ok());

    // The accessors must reflect the constructor arguments.
    let codec = codec.unwrap();
    assert_eq!(codec.nbits, 2);
    assert_eq!(codec.embedding_dim(), 8);
    assert_eq!(codec.num_centroids(), 4);
}
605
#[test]
fn test_compress_into_codes() {
    // Three centroids, each a unit vector along one of the first three axes.
    let centroid_values = vec![
        1.0, 0.0, 0.0, 0.0, //
        0.0, 1.0, 0.0, 0.0, //
        0.0, 0.0, 1.0, 0.0, //
    ];
    let centroids = Array2::from_shape_vec((3, 4), centroid_values).unwrap();
    let codec = ResidualCodec::new(2, centroids, Array1::zeros(4), None, None).unwrap();

    // The first row points mostly along axis 0, the second mostly along axis 2.
    let embedding_values = vec![
        0.9, 0.1, 0.0, 0.0, //
        0.0, 0.0, 0.95, 0.05, //
    ];
    let embeddings = Array2::from_shape_vec((2, 4), embedding_values).unwrap();

    let codes = codec.compress_into_codes(&embeddings);
    assert_eq!(codes[0], 0);
    assert_eq!(codes[1], 2);
}
634
#[test]
fn test_quantize_decompress_roundtrip_4bit() {
    let dim = 8;

    // 15 evenly spaced cutoffs and 16 bucket midpoints covering [-1, 1].
    let cutoffs: Vec<f32> = (1..16).map(|i| (i as f32 / 16.0 - 0.5) * 2.0).collect();
    let weights: Vec<f32> = (0..16)
        .map(|i| ((i as f32 + 0.5) / 16.0 - 0.5) * 2.0)
        .collect();

    // All-zero centroids make the reconstruction depend on the residuals alone.
    let codec = ResidualCodec::new(
        4,
        Array2::zeros((4, dim)),
        Array1::zeros(dim),
        Some(Array1::from_vec(cutoffs)),
        Some(Array1::from_vec(weights)),
    )
    .unwrap();

    let residual_values = vec![
        -0.9, -0.7, -0.5, -0.3, 0.0, 0.3, 0.5, 0.9, //
        -0.8, -0.4, 0.0, 0.4, 0.8, -0.6, 0.2, 0.6, //
    ];
    let residuals = Array2::from_shape_vec((2, dim), residual_values).unwrap();

    let packed = codec.quantize_residuals(&residuals).unwrap();
    assert_eq!(packed.ncols(), dim * 4 / 8);

    let codes = Array1::from_vec(vec![0, 0]);
    let decompressed = codec.decompress(&packed, &codes.view()).unwrap();

    // Decompression normalizes each row, so only the sign of clearly non-zero
    // residuals is compared.
    for ((i, j), &orig) in residuals.indexed_iter() {
        if orig.abs() > 0.2 {
            let recon = decompressed[[i, j]];
            assert!(
                (orig > 0.0) == (recon > 0.0) || recon.abs() < 0.1,
                "Sign mismatch at [{}, {}]: orig={}, recon={}",
                i,
                j,
                orig,
                recon
            );
        }
    }
}
701}