1use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
4
5use crate::error::{Error, Result};
6
/// Default cap on the memory used by the nearest-centroid score matrix: 1 GiB.
const DEFAULT_MAX_NEAREST_CENTROID_MEMORY: usize = 1024 * 1024 * 1024;

/// Memory budget (in bytes) for nearest-centroid scoring.
///
/// Reads `NEXT_PLAID_MAX_NEAREST_CENTROID_MEMORY_MB` (a positive integer
/// number of mebibytes); any unset, unparsable, or zero value falls back to
/// [`DEFAULT_MAX_NEAREST_CENTROID_MEMORY`].
fn max_nearest_centroid_memory() -> usize {
    match std::env::var("NEXT_PLAID_MAX_NEAREST_CENTROID_MEMORY_MB") {
        Ok(raw) => match raw.parse::<usize>() {
            // Saturate rather than overflow on absurdly large settings.
            Ok(mb) if mb > 0 => mb.saturating_mul(1024 * 1024),
            _ => DEFAULT_MAX_NEAREST_CENTROID_MEMORY,
        },
        Err(_) => DEFAULT_MAX_NEAREST_CENTROID_MEMORY,
    }
}
21
/// Backing storage for the codec's centroid matrix.
///
/// Centroids are either held fully in memory or memory-mapped from the
/// index's `centroids.npy` file (see [`ResidualCodec::load_mmap_from_dir`]).
pub enum CentroidStore {
    /// Centroid matrix owned in RAM, shape `num_centroids x dim`.
    Owned(Array2<f32>),
    /// Centroid matrix memory-mapped from an `.npy` file on disk.
    Mmap(crate::mmap::MmapNpyArray2F32),
}
33
impl CentroidStore {
    /// Returns a read-only 2-D view over all centroids, regardless of the
    /// backing storage.
    pub fn view(&self) -> ArrayView2<'_, f32> {
        match self {
            CentroidStore::Owned(arr) => arr.view(),
            CentroidStore::Mmap(mmap) => mmap.view(),
        }
    }

    /// Number of centroids (matrix rows).
    pub fn nrows(&self) -> usize {
        match self {
            CentroidStore::Owned(arr) => arr.nrows(),
            CentroidStore::Mmap(mmap) => mmap.nrows(),
        }
    }

    /// Embedding dimensionality (matrix columns).
    pub fn ncols(&self) -> usize {
        match self {
            CentroidStore::Owned(arr) => arr.ncols(),
            CentroidStore::Mmap(mmap) => mmap.ncols(),
        }
    }

    /// Returns centroid `idx` as a 1-D view.
    ///
    /// Panics on out-of-bounds `idx` for the owned backend (ndarray);
    /// presumably the mmap backend behaves the same — confirm in
    /// `crate::mmap`.
    pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
        match self {
            CentroidStore::Owned(arr) => arr.row(idx),
            CentroidStore::Mmap(mmap) => mmap.row(idx),
        }
    }

    /// Returns rows `start..end` (end exclusive) as a 2-D view.
    pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, f32> {
        match self {
            CentroidStore::Owned(arr) => arr.slice(s![start..end, ..]),
            CentroidStore::Mmap(mmap) => mmap.slice_rows(start, end),
        }
    }
}
79
impl Clone for CentroidStore {
    fn clone(&self) -> Self {
        match self {
            CentroidStore::Owned(arr) => CentroidStore::Owned(arr.clone()),
            // A memory map cannot be duplicated here, so cloning
            // materializes the mmap-backed data into an owned in-memory
            // copy — note this allocates the full centroid matrix.
            CentroidStore::Mmap(mmap) => CentroidStore::Owned(mmap.to_owned()),
        }
    }
}
90
/// Residual codec used by the next-plaid index.
///
/// Embeddings are represented as a nearest-centroid code plus a
/// per-dimension residual quantized into `2^nbits` buckets and bit-packed.
#[derive(Clone)]
pub struct ResidualCodec {
    /// Bits per quantized residual component; must be a divisor of 8.
    pub nbits: usize,
    /// Centroid codebook (owned or memory-mapped), `num_centroids x dim`.
    pub centroids: CentroidStore,
    /// Average residual vector loaded from `avg_residual.npy`; not
    /// consumed by the encode/decode paths in this file.
    pub avg_residual: Array1<f32>,
    /// Quantization thresholds: a value's bucket id is the number of
    /// cutoffs it strictly exceeds. Required by `quantize_residuals`.
    pub bucket_cutoffs: Option<Array1<f32>>,
    /// Reconstruction value per bucket. Required by `decompress`.
    pub bucket_weights: Option<Array1<f32>>,
    /// 256-entry table reversing the bit order within each nbits-wide
    /// group of a byte; built in `new_with_store`.
    pub byte_reversed_bits_map: Vec<u8>,
    /// `256 x (8 / nbits)` table expanding a byte into its bucket ids,
    /// highest-order group first; built only when `bucket_weights` is set.
    pub bucket_weight_indices_lookup: Option<Array2<usize>>,
}
114
115impl ResidualCodec {
    /// Builds a codec over an in-memory centroid matrix.
    ///
    /// Convenience wrapper around [`Self::new_with_store`] with an owned
    /// `CentroidStore`.
    ///
    /// # Errors
    /// Returns `Error::Codec` if `nbits` is not a divisor of 8.
    pub fn new(
        nbits: usize,
        centroids: Array2<f32>,
        avg_residual: Array1<f32>,
        bucket_cutoffs: Option<Array1<f32>>,
        bucket_weights: Option<Array1<f32>>,
    ) -> Result<Self> {
        Self::new_with_store(
            nbits,
            CentroidStore::Owned(centroids),
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
        )
    }
140
    /// Builds a codec over an arbitrary centroid store and precomputes the
    /// bit-manipulation lookup tables used by quantization/decompression.
    ///
    /// # Errors
    /// Returns `Error::Codec` unless `nbits` is a divisor of 8 (1, 2, 4, 8).
    pub fn new_with_store(
        nbits: usize,
        centroids: CentroidStore,
        avg_residual: Array1<f32>,
        bucket_cutoffs: Option<Array1<f32>>,
        bucket_weights: Option<Array1<f32>>,
    ) -> Result<Self> {
        if nbits == 0 || 8 % nbits != 0 {
            return Err(Error::Codec(format!(
                "nbits must be a divisor of 8, got {}",
                nbits
            )));
        }

        // Mask with the low `nbits` bits set; extracts one group from a byte.
        let nbits_mask = (1u32 << nbits) - 1;
        let mut byte_reversed_bits_map = vec![0u8; 256];

        // For every possible byte value, reverse the bit order *within*
        // each nbits-wide group while keeping the group order unchanged.
        // E.g. for nbits = 2, bits `ab cd ef gh` become `ba dc fe hg`.
        for (i, byte_slot) in byte_reversed_bits_map.iter_mut().enumerate() {
            let val = i as u32;
            let mut out = 0u32;
            // Remaining-bits cursor; groups are consumed MSB-first.
            let mut pos = 8i32;

            while pos >= nbits as i32 {
                // Next nbits-wide group, taken from the high end.
                let segment = (val >> (pos as u32 - nbits as u32)) & nbits_mask;

                // Reverse the bit order inside the group.
                let mut rev_segment = 0u32;
                for k in 0..nbits {
                    if (segment & (1 << k)) != 0 {
                        rev_segment |= 1 << (nbits - 1 - k);
                    }
                }

                out |= rev_segment;

                // Shift to make room for the next group — except after the
                // last group, which must stay in the low bits.
                if pos > nbits as i32 {
                    out <<= nbits;
                }

                pos -= nbits as i32;
            }
            *byte_slot = out as u8;
        }

        // How many packed bucket ids fit in one byte.
        let keys_per_byte = 8 / nbits;
        // Table expanding a byte into its `keys_per_byte` bucket ids,
        // highest-order group first; only needed for decompression.
        let bucket_weight_indices_lookup = if bucket_weights.is_some() {
            let mask = (1usize << nbits) - 1;
            let mut table = Array2::<usize>::zeros((256, keys_per_byte));

            for byte_val in 0..256usize {
                // `k` counts groups from the low end; iterate in reverse so
                // the table stores them MSB-first.
                for k in (0..keys_per_byte).rev() {
                    let shift = k * nbits;
                    let index = (byte_val >> shift) & mask;
                    table[[byte_val, keys_per_byte - 1 - k]] = index;
                }
            }
            Some(table)
        } else {
            None
        };

        Ok(Self {
            nbits,
            centroids,
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
            byte_reversed_bits_map,
            bucket_weight_indices_lookup,
        })
    }
216
    /// Dimensionality of the embeddings (centroid columns).
    pub fn embedding_dim(&self) -> usize {
        self.centroids.ncols()
    }

    /// Number of centroids in the codebook.
    pub fn num_centroids(&self) -> usize {
        self.centroids.nrows()
    }

    /// Read-only view over the full centroid matrix.
    pub fn centroids_view(&self) -> ArrayView2<'_, f32> {
        self.centroids.view()
    }
233
    /// Maps each embedding (row) to the id of its nearest centroid.
    ///
    /// With the `cuda` feature enabled this first tries the GPU path and
    /// falls back to [`Self::compress_into_codes_cpu`] on error — unless
    /// force-GPU mode is on (`crate::is_force_gpu`), in which case any GPU
    /// failure or missing CUDA context panics instead of falling back.
    pub fn compress_into_codes(&self, embeddings: &Array2<f32>) -> Array1<usize> {
        #[cfg(feature = "cuda")]
        {
            let force_gpu = crate::is_force_gpu();
            if let Some(ctx) = crate::cuda::get_global_context() {
                let centroids = self.centroids_view();
                match crate::cuda::compress_into_codes_cuda_batched(
                    &ctx,
                    &embeddings.view(),
                    &centroids,
                    None,
                ) {
                    Ok(codes) => return codes,
                    Err(e) => {
                        if force_gpu {
                            panic!(
                                "FORCE_GPU is set but CUDA compress_into_codes failed: {}",
                                e
                            );
                        }
                        // Non-fatal: log and fall through to the CPU path.
                        eprintln!(
                            "[next-plaid] CUDA compression error: {}. Falling back to CPU.",
                            e
                        );
                    }
                }
            } else if force_gpu {
                panic!("FORCE_GPU is set but CUDA context is unavailable");
            }
        }

        self.compress_into_codes_cpu(embeddings)
    }
284
285 pub fn compress_into_codes_cpu(&self, embeddings: &Array2<f32>) -> Array1<usize> {
288 use rayon::prelude::*;
289
290 let n = embeddings.nrows();
291 if n == 0 {
292 return Array1::zeros(0);
293 }
294
295 let centroids = self.centroids_view();
297 let num_centroids = centroids.nrows();
298
299 let max_batch_by_memory =
303 max_nearest_centroid_memory() / (num_centroids * std::mem::size_of::<f32>());
304 let batch_size = max_batch_by_memory.clamp(1, 1024);
305 let batch_ranges: Vec<(usize, usize)> = (0..n)
306 .step_by(batch_size)
307 .map(|start| (start, (start + batch_size).min(n)))
308 .collect();
309
310 let chunked_codes: Vec<Vec<usize>> = batch_ranges
311 .into_par_iter()
312 .map(|(start, end)| {
313 let batch = embeddings.slice(ndarray::s![start..end, ..]);
314
315 let scores = batch.dot(¢roids.t());
317
318 scores
320 .axis_iter(Axis(0))
321 .map(|row| {
322 row.iter()
323 .enumerate()
324 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
325 .map(|(idx, _)| idx)
326 .unwrap_or(0)
327 })
328 .collect()
329 })
330 .collect();
331
332 Array1::from_vec(chunked_codes.into_iter().flatten().collect())
333 }
334
335 pub fn quantize_residuals(&self, residuals: &Array2<f32>) -> Result<Array2<u8>> {
347 use rayon::prelude::*;
348
349 let cutoffs = self
350 .bucket_cutoffs
351 .as_ref()
352 .ok_or_else(|| Error::Codec("bucket_cutoffs required for quantization".into()))?;
353
354 let n = residuals.nrows();
355 let dim = residuals.ncols();
356 let packed_dim = dim * self.nbits / 8;
357 let nbits = self.nbits;
358
359 if n == 0 {
360 return Ok(Array2::zeros((0, packed_dim)));
361 }
362
363 let cutoffs_slice = cutoffs.as_slice().unwrap();
365
366 let packed_rows: Vec<Vec<u8>> = residuals
368 .axis_iter(Axis(0))
369 .into_par_iter()
370 .map(|row| {
371 let mut packed_row = vec![0u8; packed_dim];
372 let mut bit_idx = 0;
373
374 for &val in row.iter() {
375 let bucket = cutoffs_slice.iter().filter(|&&c| val > c).count();
377
378 for b in 0..nbits {
380 let bit = ((bucket >> b) & 1) as u8;
381 let byte_idx = bit_idx / 8;
382 let bit_pos = 7 - (bit_idx % 8);
383 packed_row[byte_idx] |= bit << bit_pos;
384 bit_idx += 1;
385 }
386 }
387
388 packed_row
389 })
390 .collect();
391
392 let mut packed = Array2::<u8>::zeros((n, packed_dim));
394 for (i, row) in packed_rows.into_iter().enumerate() {
395 for (j, val) in row.into_iter().enumerate() {
396 packed[[i, j]] = val;
397 }
398 }
399
400 Ok(packed)
401 }
402
    /// Reconstructs embeddings from packed residuals and centroid codes,
    /// then L2-normalizes every output row.
    ///
    /// Per byte, `byte_reversed_bits_map` first restores the natural bit
    /// order within each nbits-wide group (undoing the LSB-first packing
    /// done by `quantize_residuals`), then `bucket_weight_indices_lookup`
    /// expands the byte into bucket ids. Component `d` of row `i` is
    /// `centroids[codes[i]][d] + bucket_weights[bucket]` before
    /// normalization.
    ///
    /// `codes` must provide a valid centroid index for every row of
    /// `packed_residuals`; a too-short `codes` array or an out-of-range
    /// code panics.
    ///
    /// # Errors
    /// Returns `Error::Codec` if `bucket_weights` or the bucket-index
    /// lookup table is absent (codec built without bucket weights).
    pub fn decompress(
        &self,
        packed_residuals: &Array2<u8>,
        codes: &ArrayView1<usize>,
    ) -> Result<Array2<f32>> {
        let bucket_weights = self
            .bucket_weights
            .as_ref()
            .ok_or_else(|| Error::Codec("bucket_weights required for decompression".into()))?;

        let lookup = self
            .bucket_weight_indices_lookup
            .as_ref()
            .ok_or_else(|| Error::Codec("bucket_weight_indices_lookup required".into()))?;

        let n = packed_residuals.nrows();
        let dim = self.embedding_dim();

        let mut output = Array2::<f32>::zeros((n, dim));

        for i in 0..n {
            let centroid = self.centroids.row(codes[i]);

            let mut residual_idx = 0;
            for &byte_val in packed_residuals.row(i).iter() {
                // Undo the in-group bit reversal introduced by packing.
                let reversed = self.byte_reversed_bits_map[byte_val as usize];
                // Bucket ids carried by this byte, highest-order group first.
                let indices = lookup.row(reversed as usize);

                for &bucket_idx in indices.iter() {
                    // Ignore any trailing bits beyond the embedding dim.
                    if residual_idx < dim {
                        output[[i, residual_idx]] =
                            centroid[residual_idx] + bucket_weights[bucket_idx];
                        residual_idx += 1;
                    }
                }
            }
        }

        // L2-normalize; the epsilon keeps all-zero rows from dividing by 0.
        for mut row in output.axis_iter_mut(Axis(0)) {
            let norm = row.dot(&row).sqrt().max(1e-12);
            row /= norm;
        }

        Ok(output)
    }
461
    /// Loads a codec from an index directory, reading the centroids fully
    /// into memory.
    ///
    /// Expects `centroids.npy`, `avg_residual.npy` and `metadata.json`
    /// (with an integer `nbits` field) in `index_path`; `bucket_cutoffs.npy`
    /// and `bucket_weights.npy` are loaded only if present.
    ///
    /// # Errors
    /// Returns `Error::IndexLoad` for any missing or unreadable required
    /// file, plus whatever [`Self::new`] returns for an invalid `nbits`.
    pub fn load_from_dir(index_path: &std::path::Path) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;
        use std::fs::File;

        let centroids_path = index_path.join("centroids.npy");
        let centroids: Array2<f32> = Array2::read_npy(
            File::open(&centroids_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open centroids.npy: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to read centroids.npy: {}", e)))?;

        let avg_residual_path = index_path.join("avg_residual.npy");
        let avg_residual: Array1<f32> =
            Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
                Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
            })?)
            .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;

        // Optional: quantization thresholds.
        let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
        let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // Optional: per-bucket reconstruction weights.
        let bucket_weights_path = index_path.join("bucket_weights.npy");
        let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // `nbits` comes from the index metadata.
        let metadata_path = index_path.join("metadata.json");
        let metadata: serde_json::Value = serde_json::from_reader(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;

        let nbits = metadata["nbits"]
            .as_u64()
            .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
            as usize;

        Self::new(
            nbits,
            centroids,
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
        )
    }
530
    /// Loads a codec from an index directory, memory-mapping the centroid
    /// matrix instead of reading it into RAM.
    ///
    /// Same layout expectations as [`Self::load_from_dir`]; only the
    /// centroids differ — they are served from an mmap of `centroids.npy`.
    /// The smaller arrays are still read eagerly.
    ///
    /// # Errors
    /// Returns `Error::IndexLoad` for any missing or unreadable required
    /// file, plus whatever [`Self::new_with_store`] returns for an invalid
    /// `nbits`.
    pub fn load_mmap_from_dir(index_path: &std::path::Path) -> Result<Self> {
        use ndarray_npy::ReadNpyExt;
        use std::fs::File;

        let centroids_path = index_path.join("centroids.npy");
        let mmap_centroids = crate::mmap::MmapNpyArray2F32::from_npy_file(&centroids_path)?;

        let avg_residual_path = index_path.join("avg_residual.npy");
        let avg_residual: Array1<f32> =
            Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
                Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
            })?)
            .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;

        // Optional: quantization thresholds.
        let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
        let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // Optional: per-bucket reconstruction weights.
        let bucket_weights_path = index_path.join("bucket_weights.npy");
        let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
            Some(
                Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
                    Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
                })?)
                .map_err(|e| {
                    Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
                })?,
            )
        } else {
            None
        };

        // `nbits` comes from the index metadata.
        let metadata_path = index_path.join("metadata.json");
        let metadata: serde_json::Value = serde_json::from_reader(
            File::open(&metadata_path)
                .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
        )
        .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;

        let nbits = metadata["nbits"]
            .as_u64()
            .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
            as usize;

        Self::new_with_store(
            nbits,
            CentroidStore::Mmap(mmap_centroids),
            avg_residual,
            bucket_cutoffs,
            bucket_weights,
        )
    }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a codec with valid parameters succeeds and exposes the
    /// expected dimensions.
    #[test]
    fn test_codec_creation() {
        let centroids =
            Array2::from_shape_vec((4, 8), (0..32).map(|x| x as f32).collect()).unwrap();
        let avg_residual = Array1::zeros(8);
        let bucket_cutoffs = Some(Array1::from_vec(vec![-0.5, 0.0, 0.5]));
        let bucket_weights = Some(Array1::from_vec(vec![-0.75, -0.25, 0.25, 0.75]));

        let codec = ResidualCodec::new(2, centroids, avg_residual, bucket_cutoffs, bucket_weights);
        assert!(codec.is_ok());

        let codec = codec.unwrap();
        assert_eq!(codec.nbits, 2);
        assert_eq!(codec.embedding_dim(), 8);
        assert_eq!(codec.num_centroids(), 4);
    }

    /// Embeddings map to the centroid with the highest dot product.
    #[test]
    fn test_compress_into_codes() {
        // Three axis-aligned unit centroids in 4-D.
        let centroids = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0, //
                0.0, 1.0, 0.0, 0.0, //
                0.0, 0.0, 1.0, 0.0, //
            ],
        )
        .unwrap();

        let avg_residual = Array1::zeros(4);
        let codec = ResidualCodec::new(2, centroids, avg_residual, None, None).unwrap();

        let embeddings = Array2::from_shape_vec(
            (2, 4),
            vec![
                0.9, 0.1, 0.0, 0.0, // nearest to centroid 0
                0.0, 0.0, 0.95, 0.05, // nearest to centroid 2
            ],
        )
        .unwrap();

        let codes = codec.compress_into_codes(&embeddings);
        assert_eq!(codes[0], 0);
        assert_eq!(codes[1], 2);
    }

    /// A 4-bit quantize/decompress round trip preserves the sign of
    /// sufficiently large components. Decompression renormalizes rows, so
    /// magnitudes are not compared directly.
    #[test]
    fn test_quantize_decompress_roundtrip_4bit() {
        let dim = 8;
        // Zero centroids so the reconstruction is the residual alone.
        let centroids = Array2::zeros((4, dim));
        let avg_residual = Array1::zeros(dim);

        // 15 evenly spaced cutoffs in (-1, 1) => 16 buckets; weights are
        // the bucket midpoints.
        let bucket_cutoffs: Vec<f32> = (1..16).map(|i| (i as f32 / 16.0 - 0.5) * 2.0).collect();
        let bucket_weights: Vec<f32> = (0..16)
            .map(|i| ((i as f32 + 0.5) / 16.0 - 0.5) * 2.0)
            .collect();

        let codec = ResidualCodec::new(
            4,
            centroids,
            avg_residual,
            Some(Array1::from_vec(bucket_cutoffs)),
            Some(Array1::from_vec(bucket_weights)),
        )
        .unwrap();

        let residuals = Array2::from_shape_vec(
            (2, dim),
            vec![
                -0.9, -0.7, -0.5, -0.3, 0.0, 0.3, 0.5, 0.9, //
                -0.8, -0.4, 0.0, 0.4, 0.8, -0.6, 0.2, 0.6,
            ],
        )
        .unwrap();

        let packed = codec.quantize_residuals(&residuals).unwrap();
        // 4 bits per component => dim / 2 packed bytes per row.
        assert_eq!(packed.ncols(), dim * 4 / 8);

        // Both rows decode against centroid 0 (the zero vector).
        let codes = Array1::from_vec(vec![0, 0]);

        let decompressed = codec.decompress(&packed, &codes.view()).unwrap();

        for i in 0..residuals.nrows() {
            for j in 0..residuals.ncols() {
                let orig = residuals[[i, j]];
                let recon = decompressed[[i, j]];
                // Only check components that are clearly non-zero; values
                // near zero may legitimately round across the sign boundary.
                if orig.abs() > 0.2 {
                    assert!(
                        (orig > 0.0) == (recon > 0.0) || recon.abs() < 0.1,
                        "Sign mismatch at [{}, {}]: orig={}, recon={}",
                        i,
                        j,
                        orig,
                        recon
                    );
                }
            }
        }
    }
}