1use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
4
5use crate::error::{Error, Result};
6
/// Backing storage for the centroid matrix used by the codec.
///
/// The matrix is either held in memory as an owned `Array2<f32>` or, when the
/// `npy` feature is enabled, accessed through `crate::mmap::MmapNpyArray2F32`.
pub enum CentroidStore {
    /// Centroids held in an owned, in-memory array.
    Owned(Array2<f32>),
    /// Centroids served by `MmapNpyArray2F32` (presumably a memory-mapped
    /// `.npy` file — confirm against `crate::mmap`).
    #[cfg(feature = "npy")]
    Mmap(crate::mmap::MmapNpyArray2F32),
}
19
20impl CentroidStore {
21 pub fn view(&self) -> ArrayView2<'_, f32> {
25 match self {
26 CentroidStore::Owned(arr) => arr.view(),
27 #[cfg(feature = "npy")]
28 CentroidStore::Mmap(mmap) => mmap.view(),
29 }
30 }
31
32 pub fn nrows(&self) -> usize {
34 match self {
35 CentroidStore::Owned(arr) => arr.nrows(),
36 #[cfg(feature = "npy")]
37 CentroidStore::Mmap(mmap) => mmap.nrows(),
38 }
39 }
40
41 pub fn ncols(&self) -> usize {
43 match self {
44 CentroidStore::Owned(arr) => arr.ncols(),
45 #[cfg(feature = "npy")]
46 CentroidStore::Mmap(mmap) => mmap.ncols(),
47 }
48 }
49
50 pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
52 match self {
53 CentroidStore::Owned(arr) => arr.row(idx),
54 #[cfg(feature = "npy")]
55 CentroidStore::Mmap(mmap) => mmap.row(idx),
56 }
57 }
58
59 pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, f32> {
63 match self {
64 CentroidStore::Owned(arr) => arr.slice(s![start..end, ..]),
65 #[cfg(feature = "npy")]
66 CentroidStore::Mmap(mmap) => mmap.slice_rows(start, end),
67 }
68 }
69}
70
71impl Clone for CentroidStore {
72 fn clone(&self) -> Self {
73 match self {
74 CentroidStore::Owned(arr) => CentroidStore::Owned(arr.clone()),
76 #[cfg(feature = "npy")]
78 CentroidStore::Mmap(mmap) => CentroidStore::Owned(mmap.to_owned()),
79 }
80 }
81}
82
/// Codec that assigns embeddings to centroids and packs / unpacks their
/// quantized residuals.
#[derive(Clone)]
pub struct ResidualCodec {
    /// Number of bits used per quantized residual value. The constructors
    /// only accept divisors of 8.
    pub nbits: usize,
    /// Centroid matrix: one row per centroid, one column per embedding
    /// dimension.
    pub centroids: CentroidStore,
    /// Average residual vector. It is stored as given; none of the methods in
    /// this impl read it (presumably consumed elsewhere — confirm).
    pub avg_residual: Array1<f32>,
    /// Bucket boundaries used by `quantize_residuals`; `None` disables
    /// quantization.
    pub bucket_cutoffs: Option<Array1<f32>>,
    /// Reconstruction value per bucket used by `decompress`; `None` disables
    /// decompression.
    pub bucket_weights: Option<Array1<f32>>,
    /// 256-entry table mapping a byte to the same byte with the bit order
    /// reversed inside each `nbits`-wide group.
    pub byte_reversed_bits_map: Vec<u8>,
    /// `256 x (8 / nbits)` table giving, for every byte value, the bucket
    /// index stored in each `nbits`-wide group (most significant group
    /// first). Built only when `bucket_weights` is present.
    pub bucket_weight_indices_lookup: Option<Array2<usize>>,
}
106
107impl ResidualCodec {
108 pub fn new(
118 nbits: usize,
119 centroids: Array2<f32>,
120 avg_residual: Array1<f32>,
121 bucket_cutoffs: Option<Array1<f32>>,
122 bucket_weights: Option<Array1<f32>>,
123 ) -> Result<Self> {
124 Self::new_with_store(
125 nbits,
126 CentroidStore::Owned(centroids),
127 avg_residual,
128 bucket_cutoffs,
129 bucket_weights,
130 )
131 }
132
133 pub fn new_with_store(
137 nbits: usize,
138 centroids: CentroidStore,
139 avg_residual: Array1<f32>,
140 bucket_cutoffs: Option<Array1<f32>>,
141 bucket_weights: Option<Array1<f32>>,
142 ) -> Result<Self> {
143 if nbits == 0 || 8 % nbits != 0 {
144 return Err(Error::Codec(format!(
145 "nbits must be a divisor of 8, got {}",
146 nbits
147 )));
148 }
149
150 let nbits_mask = (1u32 << nbits) - 1;
152 let mut byte_reversed_bits_map = vec![0u8; 256];
153
154 for (i, byte_slot) in byte_reversed_bits_map.iter_mut().enumerate() {
155 let val = i as u32;
156 let mut out = 0u32;
157 let mut pos = 8i32;
158
159 while pos >= nbits as i32 {
160 let segment = (val >> (pos as u32 - nbits as u32)) & nbits_mask;
161
162 let mut rev_segment = 0u32;
163 for k in 0..nbits {
164 if (segment & (1 << k)) != 0 {
165 rev_segment |= 1 << (nbits - 1 - k);
166 }
167 }
168
169 out |= rev_segment;
170
171 if pos > nbits as i32 {
172 out <<= nbits;
173 }
174
175 pos -= nbits as i32;
176 }
177 *byte_slot = out as u8;
178 }
179
180 let keys_per_byte = 8 / nbits;
182 let bucket_weight_indices_lookup = if bucket_weights.is_some() {
183 let mask = (1usize << nbits) - 1;
184 let mut table = Array2::<usize>::zeros((256, keys_per_byte));
185
186 for byte_val in 0..256usize {
187 for k in (0..keys_per_byte).rev() {
188 let shift = k * nbits;
189 let index = (byte_val >> shift) & mask;
190 table[[byte_val, keys_per_byte - 1 - k]] = index;
191 }
192 }
193 Some(table)
194 } else {
195 None
196 };
197
198 Ok(Self {
199 nbits,
200 centroids,
201 avg_residual,
202 bucket_cutoffs,
203 bucket_weights,
204 byte_reversed_bits_map,
205 bucket_weight_indices_lookup,
206 })
207 }
208
    /// Returns the embedding dimensionality, i.e. the number of columns of
    /// the centroid matrix.
    pub fn embedding_dim(&self) -> usize {
        self.centroids.ncols()
    }
213
    /// Returns the number of centroids, i.e. the number of rows of the
    /// centroid matrix.
    pub fn num_centroids(&self) -> usize {
        self.centroids.nrows()
    }
218
    /// Returns a read-only view of the whole centroid matrix, regardless of
    /// which `CentroidStore` variant backs it.
    pub fn centroids_view(&self) -> ArrayView2<'_, f32> {
        self.centroids.view()
    }
225
226 pub fn compress_into_codes(&self, embeddings: &Array2<f32>) -> Array1<usize> {
240 use rayon::prelude::*;
241
242 let n = embeddings.nrows();
243 if n == 0 {
244 return Array1::zeros(0);
245 }
246
247 let centroids = self.centroids_view();
249
250 const BATCH_SIZE: usize = 2048;
252
253 let mut all_codes = Vec::with_capacity(n);
254
255 for start in (0..n).step_by(BATCH_SIZE) {
256 let end = (start + BATCH_SIZE).min(n);
257 let batch = embeddings.slice(ndarray::s![start..end, ..]);
258
259 let scores = batch.dot(¢roids.t());
261
262 let batch_codes: Vec<usize> = scores
264 .axis_iter(Axis(0))
265 .into_par_iter()
266 .map(|row| {
267 row.iter()
268 .enumerate()
269 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
270 .map(|(idx, _)| idx)
271 .unwrap_or(0)
272 })
273 .collect();
274
275 all_codes.extend(batch_codes);
276 }
277
278 Array1::from_vec(all_codes)
279 }
280
281 pub fn quantize_residuals(&self, residuals: &Array2<f32>) -> Result<Array2<u8>> {
293 use rayon::prelude::*;
294
295 let cutoffs = self
296 .bucket_cutoffs
297 .as_ref()
298 .ok_or_else(|| Error::Codec("bucket_cutoffs required for quantization".into()))?;
299
300 let n = residuals.nrows();
301 let dim = residuals.ncols();
302 let packed_dim = dim * self.nbits / 8;
303 let nbits = self.nbits;
304
305 if n == 0 {
306 return Ok(Array2::zeros((0, packed_dim)));
307 }
308
309 let cutoffs_slice = cutoffs.as_slice().unwrap();
311
312 let packed_rows: Vec<Vec<u8>> = residuals
314 .axis_iter(Axis(0))
315 .into_par_iter()
316 .map(|row| {
317 let mut packed_row = vec![0u8; packed_dim];
318 let mut bit_idx = 0;
319
320 for &val in row.iter() {
321 let bucket = cutoffs_slice.iter().filter(|&&c| val > c).count();
323
324 for b in 0..nbits {
326 let bit = ((bucket >> b) & 1) as u8;
327 let byte_idx = bit_idx / 8;
328 let bit_pos = 7 - (bit_idx % 8);
329 packed_row[byte_idx] |= bit << bit_pos;
330 bit_idx += 1;
331 }
332 }
333
334 packed_row
335 })
336 .collect();
337
338 let mut packed = Array2::<u8>::zeros((n, packed_dim));
340 for (i, row) in packed_rows.into_iter().enumerate() {
341 for (j, val) in row.into_iter().enumerate() {
342 packed[[i, j]] = val;
343 }
344 }
345
346 Ok(packed)
347 }
348
349 pub fn decompress(
360 &self,
361 packed_residuals: &Array2<u8>,
362 codes: &ArrayView1<usize>,
363 ) -> Result<Array2<f32>> {
364 let bucket_weights = self
365 .bucket_weights
366 .as_ref()
367 .ok_or_else(|| Error::Codec("bucket_weights required for decompression".into()))?;
368
369 let lookup = self
370 .bucket_weight_indices_lookup
371 .as_ref()
372 .ok_or_else(|| Error::Codec("bucket_weight_indices_lookup required".into()))?;
373
374 let n = packed_residuals.nrows();
375 let dim = self.embedding_dim();
376
377 let mut output = Array2::<f32>::zeros((n, dim));
378
379 for i in 0..n {
380 let centroid = self.centroids.row(codes[i]);
382
383 let mut residual_idx = 0;
385 for &byte_val in packed_residuals.row(i).iter() {
386 let reversed = self.byte_reversed_bits_map[byte_val as usize];
387 let indices = lookup.row(reversed as usize);
388
389 for &bucket_idx in indices.iter() {
390 if residual_idx < dim {
391 output[[i, residual_idx]] =
392 centroid[residual_idx] + bucket_weights[bucket_idx];
393 residual_idx += 1;
394 }
395 }
396 }
397 }
398
399 for mut row in output.axis_iter_mut(Axis(0)) {
401 let norm = row.dot(&row).sqrt().max(1e-12);
402 row /= norm;
403 }
404
405 Ok(output)
406 }
407
408 #[cfg(feature = "npy")]
410 pub fn load_from_dir(index_path: &std::path::Path) -> Result<Self> {
411 use ndarray_npy::ReadNpyExt;
412 use std::fs::File;
413
414 let centroids_path = index_path.join("centroids.npy");
415 let centroids: Array2<f32> = Array2::read_npy(
416 File::open(¢roids_path)
417 .map_err(|e| Error::IndexLoad(format!("Failed to open centroids.npy: {}", e)))?,
418 )
419 .map_err(|e| Error::IndexLoad(format!("Failed to read centroids.npy: {}", e)))?;
420
421 let avg_residual_path = index_path.join("avg_residual.npy");
422 let avg_residual: Array1<f32> =
423 Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
424 Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
425 })?)
426 .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;
427
428 let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
429 let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
430 Some(
431 Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
432 Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
433 })?)
434 .map_err(|e| {
435 Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
436 })?,
437 )
438 } else {
439 None
440 };
441
442 let bucket_weights_path = index_path.join("bucket_weights.npy");
443 let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
444 Some(
445 Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
446 Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
447 })?)
448 .map_err(|e| {
449 Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
450 })?,
451 )
452 } else {
453 None
454 };
455
456 let metadata_path = index_path.join("metadata.json");
458 let metadata: serde_json::Value = serde_json::from_reader(
459 File::open(&metadata_path)
460 .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
461 )
462 .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;
463
464 let nbits = metadata["nbits"]
465 .as_u64()
466 .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
467 as usize;
468
469 Self::new(
470 nbits,
471 centroids,
472 avg_residual,
473 bucket_cutoffs,
474 bucket_weights,
475 )
476 }
477
478 #[cfg(feature = "npy")]
486 pub fn load_mmap_from_dir(index_path: &std::path::Path) -> Result<Self> {
487 use ndarray_npy::ReadNpyExt;
488 use std::fs::File;
489
490 let centroids_path = index_path.join("centroids.npy");
492 let mmap_centroids = crate::mmap::MmapNpyArray2F32::from_npy_file(¢roids_path)?;
493
494 let avg_residual_path = index_path.join("avg_residual.npy");
496 let avg_residual: Array1<f32> =
497 Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
498 Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
499 })?)
500 .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;
501
502 let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
503 let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
504 Some(
505 Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
506 Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
507 })?)
508 .map_err(|e| {
509 Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
510 })?,
511 )
512 } else {
513 None
514 };
515
516 let bucket_weights_path = index_path.join("bucket_weights.npy");
517 let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
518 Some(
519 Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
520 Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
521 })?)
522 .map_err(|e| {
523 Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
524 })?,
525 )
526 } else {
527 None
528 };
529
530 let metadata_path = index_path.join("metadata.json");
532 let metadata: serde_json::Value = serde_json::from_reader(
533 File::open(&metadata_path)
534 .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
535 )
536 .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;
537
538 let nbits = metadata["nbits"]
539 .as_u64()
540 .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
541 as usize;
542
543 Self::new_with_store(
544 nbits,
545 CentroidStore::Mmap(mmap_centroids),
546 avg_residual,
547 bucket_cutoffs,
548 bucket_weights,
549 )
550 }
551}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed 2-bit codec reports the shape of its centroid matrix.
    #[test]
    fn test_codec_creation() {
        let centroid_values: Vec<f32> = (0..32).map(|x| x as f32).collect();
        let centroids = Array2::from_shape_vec((4, 8), centroid_values).unwrap();
        let cutoffs = Array1::from_vec(vec![-0.5, 0.0, 0.5]);
        let weights = Array1::from_vec(vec![-0.75, -0.25, 0.25, 0.75]);

        let codec = ResidualCodec::new(
            2,
            centroids,
            Array1::zeros(8),
            Some(cutoffs),
            Some(weights),
        );
        assert!(codec.is_ok());

        let codec = codec.unwrap();
        assert_eq!(codec.nbits, 2);
        assert_eq!(codec.embedding_dim(), 8);
        assert_eq!(codec.num_centroids(), 4);
    }

    /// Each embedding is assigned to the axis-aligned centroid it is closest to.
    #[test]
    fn test_compress_into_codes() {
        // Three unit vectors along the first three axes.
        let centroids = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0, //
                0.0, 1.0, 0.0, 0.0, //
                0.0, 0.0, 1.0, 0.0, //
            ],
        )
        .unwrap();
        let codec = ResidualCodec::new(2, centroids, Array1::zeros(4), None, None).unwrap();

        let embeddings = Array2::from_shape_vec(
            (2, 4),
            vec![
                0.9, 0.1, 0.0, 0.0, //
                0.0, 0.0, 0.95, 0.05, //
            ],
        )
        .unwrap();

        let codes = codec.compress_into_codes(&embeddings);
        assert_eq!(codes[0], 0);
        assert_eq!(codes[1], 2);
    }

    /// Quantizing and decompressing with 4 bits keeps the sign of every
    /// clearly non-zero residual.
    #[test]
    fn test_quantize_decompress_roundtrip_4bit() {
        let dim = 8;

        // Sixteen buckets over [-1, 1]: 15 interior cutoffs, 16 bucket centres.
        let bucket_cutoffs: Vec<f32> = (1..16).map(|i| (i as f32 / 16.0 - 0.5) * 2.0).collect();
        let bucket_weights: Vec<f32> = (0..16)
            .map(|i| ((i as f32 + 0.5) / 16.0 - 0.5) * 2.0)
            .collect();

        // Zero centroids, so the reconstruction depends on the residuals only.
        let codec = ResidualCodec::new(
            4,
            Array2::zeros((4, dim)),
            Array1::zeros(dim),
            Some(Array1::from_vec(bucket_cutoffs)),
            Some(Array1::from_vec(bucket_weights)),
        )
        .unwrap();

        let residuals = Array2::from_shape_vec(
            (2, dim),
            vec![
                -0.9, -0.7, -0.5, -0.3, 0.0, 0.3, 0.5, 0.9, //
                -0.8, -0.4, 0.0, 0.4, 0.8, -0.6, 0.2, 0.6, //
            ],
        )
        .unwrap();

        let packed = codec.quantize_residuals(&residuals).unwrap();
        assert_eq!(packed.ncols(), dim * 4 / 8);

        let codes = Array1::from_vec(vec![0, 0]);
        let decompressed = codec.decompress(&packed, &codes.view()).unwrap();

        for ((i, j), &orig) in residuals.indexed_iter() {
            let recon = decompressed[[i, j]];
            // Values close to zero may legitimately land on either side.
            if orig.abs() > 0.2 {
                assert!(
                    (orig > 0.0) == (recon > 0.0) || recon.abs() < 0.1,
                    "Sign mismatch at [{}, {}]: orig={}, recon={}",
                    i,
                    j,
                    orig,
                    recon
                );
            }
        }
    }
}