1use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
4
5use crate::error::{Error, Result};
6
7const MAX_NEAREST_CENTROID_MEMORY: usize = 4 * 1024 * 1024 * 1024; pub enum CentroidStore {
18 Owned(Array2<f32>),
20 Mmap(crate::mmap::MmapNpyArray2F32),
22}
23
24impl CentroidStore {
25 pub fn view(&self) -> ArrayView2<'_, f32> {
29 match self {
30 CentroidStore::Owned(arr) => arr.view(),
31 CentroidStore::Mmap(mmap) => mmap.view(),
32 }
33 }
34
35 pub fn nrows(&self) -> usize {
37 match self {
38 CentroidStore::Owned(arr) => arr.nrows(),
39 CentroidStore::Mmap(mmap) => mmap.nrows(),
40 }
41 }
42
43 pub fn ncols(&self) -> usize {
45 match self {
46 CentroidStore::Owned(arr) => arr.ncols(),
47 CentroidStore::Mmap(mmap) => mmap.ncols(),
48 }
49 }
50
51 pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
53 match self {
54 CentroidStore::Owned(arr) => arr.row(idx),
55 CentroidStore::Mmap(mmap) => mmap.row(idx),
56 }
57 }
58
59 pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, f32> {
63 match self {
64 CentroidStore::Owned(arr) => arr.slice(s![start..end, ..]),
65 CentroidStore::Mmap(mmap) => mmap.slice_rows(start, end),
66 }
67 }
68}
69
70impl Clone for CentroidStore {
71 fn clone(&self) -> Self {
72 match self {
73 CentroidStore::Owned(arr) => CentroidStore::Owned(arr.clone()),
75 CentroidStore::Mmap(mmap) => CentroidStore::Owned(mmap.to_owned()),
77 }
78 }
79}
80
/// Codec that pairs a centroid matrix with bucket tables for packing residual
/// values into `nbits`-wide codes and expanding them again.
#[derive(Clone)]
pub struct ResidualCodec {
    /// Bits used per quantized residual value; the constructor accepts only
    /// divisors of 8.
    pub nbits: usize,
    /// Centroid matrix: one centroid per row, one embedding dimension per column.
    pub centroids: CentroidStore,
    /// Average residual vector, stored as supplied to the constructor.
    // NOTE(review): no method visible in this part of the file reads this field.
    pub avg_residual: Array1<f32>,
    /// Bucket boundaries used when quantizing: a value's bucket is the number of
    /// cutoffs it strictly exceeds. Quantization fails when this is `None`.
    pub bucket_cutoffs: Option<Array1<f32>>,
    /// Per-bucket values added to the centroid component when decompressing.
    /// Decompression fails when this is `None`.
    pub bucket_weights: Option<Array1<f32>>,
    /// 256-entry table mapping a byte to the same byte with the bit order reversed
    /// inside each `nbits`-wide group.
    pub byte_reversed_bits_map: Vec<u8>,
    /// `256 x (8 / nbits)` table giving, for each byte, the bucket index stored in
    /// each `nbits`-wide group (most significant group first). Built only when
    /// `bucket_weights` was supplied at construction.
    pub bucket_weight_indices_lookup: Option<Array2<usize>>,
}
104
105impl ResidualCodec {
106 pub fn new(
116 nbits: usize,
117 centroids: Array2<f32>,
118 avg_residual: Array1<f32>,
119 bucket_cutoffs: Option<Array1<f32>>,
120 bucket_weights: Option<Array1<f32>>,
121 ) -> Result<Self> {
122 Self::new_with_store(
123 nbits,
124 CentroidStore::Owned(centroids),
125 avg_residual,
126 bucket_cutoffs,
127 bucket_weights,
128 )
129 }
130
131 pub fn new_with_store(
135 nbits: usize,
136 centroids: CentroidStore,
137 avg_residual: Array1<f32>,
138 bucket_cutoffs: Option<Array1<f32>>,
139 bucket_weights: Option<Array1<f32>>,
140 ) -> Result<Self> {
141 if nbits == 0 || 8 % nbits != 0 {
142 return Err(Error::Codec(format!(
143 "nbits must be a divisor of 8, got {}",
144 nbits
145 )));
146 }
147
148 let nbits_mask = (1u32 << nbits) - 1;
150 let mut byte_reversed_bits_map = vec![0u8; 256];
151
152 for (i, byte_slot) in byte_reversed_bits_map.iter_mut().enumerate() {
153 let val = i as u32;
154 let mut out = 0u32;
155 let mut pos = 8i32;
156
157 while pos >= nbits as i32 {
158 let segment = (val >> (pos as u32 - nbits as u32)) & nbits_mask;
159
160 let mut rev_segment = 0u32;
161 for k in 0..nbits {
162 if (segment & (1 << k)) != 0 {
163 rev_segment |= 1 << (nbits - 1 - k);
164 }
165 }
166
167 out |= rev_segment;
168
169 if pos > nbits as i32 {
170 out <<= nbits;
171 }
172
173 pos -= nbits as i32;
174 }
175 *byte_slot = out as u8;
176 }
177
178 let keys_per_byte = 8 / nbits;
180 let bucket_weight_indices_lookup = if bucket_weights.is_some() {
181 let mask = (1usize << nbits) - 1;
182 let mut table = Array2::<usize>::zeros((256, keys_per_byte));
183
184 for byte_val in 0..256usize {
185 for k in (0..keys_per_byte).rev() {
186 let shift = k * nbits;
187 let index = (byte_val >> shift) & mask;
188 table[[byte_val, keys_per_byte - 1 - k]] = index;
189 }
190 }
191 Some(table)
192 } else {
193 None
194 };
195
196 Ok(Self {
197 nbits,
198 centroids,
199 avg_residual,
200 bucket_cutoffs,
201 bucket_weights,
202 byte_reversed_bits_map,
203 bucket_weight_indices_lookup,
204 })
205 }
206
/// Returns the embedding dimension, i.e. the number of columns in the centroid
/// store.
pub fn embedding_dim(&self) -> usize {
    self.centroids.ncols()
}
211
/// Returns the number of centroids, i.e. the number of rows in the centroid store.
pub fn num_centroids(&self) -> usize {
    self.centroids.nrows()
}
216
/// Borrows the centroid matrix as a 2-D view, regardless of how it is stored.
pub fn centroids_view(&self) -> ArrayView2<'_, f32> {
    self.centroids.view()
}
223
224 pub fn compress_into_codes(&self, embeddings: &Array2<f32>) -> Array1<usize> {
238 use rayon::prelude::*;
239
240 let n = embeddings.nrows();
241 if n == 0 {
242 return Array1::zeros(0);
243 }
244
245 let centroids = self.centroids_view();
247 let num_centroids = centroids.nrows();
248
249 let max_batch_by_memory =
253 MAX_NEAREST_CENTROID_MEMORY / (num_centroids * std::mem::size_of::<f32>());
254 let batch_size = max_batch_by_memory.clamp(1, 2048);
255
256 let mut all_codes = Vec::with_capacity(n);
257
258 for start in (0..n).step_by(batch_size) {
259 let end = (start + batch_size).min(n);
260 let batch = embeddings.slice(ndarray::s![start..end, ..]);
261
262 let scores = batch.dot(¢roids.t());
264
265 let batch_codes: Vec<usize> = scores
267 .axis_iter(Axis(0))
268 .into_par_iter()
269 .map(|row| {
270 row.iter()
271 .enumerate()
272 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
273 .map(|(idx, _)| idx)
274 .unwrap_or(0)
275 })
276 .collect();
277
278 all_codes.extend(batch_codes);
279 }
280
281 Array1::from_vec(all_codes)
282 }
283
284 pub fn quantize_residuals(&self, residuals: &Array2<f32>) -> Result<Array2<u8>> {
296 use rayon::prelude::*;
297
298 let cutoffs = self
299 .bucket_cutoffs
300 .as_ref()
301 .ok_or_else(|| Error::Codec("bucket_cutoffs required for quantization".into()))?;
302
303 let n = residuals.nrows();
304 let dim = residuals.ncols();
305 let packed_dim = dim * self.nbits / 8;
306 let nbits = self.nbits;
307
308 if n == 0 {
309 return Ok(Array2::zeros((0, packed_dim)));
310 }
311
312 let cutoffs_slice = cutoffs.as_slice().unwrap();
314
315 let packed_rows: Vec<Vec<u8>> = residuals
317 .axis_iter(Axis(0))
318 .into_par_iter()
319 .map(|row| {
320 let mut packed_row = vec![0u8; packed_dim];
321 let mut bit_idx = 0;
322
323 for &val in row.iter() {
324 let bucket = cutoffs_slice.iter().filter(|&&c| val > c).count();
326
327 for b in 0..nbits {
329 let bit = ((bucket >> b) & 1) as u8;
330 let byte_idx = bit_idx / 8;
331 let bit_pos = 7 - (bit_idx % 8);
332 packed_row[byte_idx] |= bit << bit_pos;
333 bit_idx += 1;
334 }
335 }
336
337 packed_row
338 })
339 .collect();
340
341 let mut packed = Array2::<u8>::zeros((n, packed_dim));
343 for (i, row) in packed_rows.into_iter().enumerate() {
344 for (j, val) in row.into_iter().enumerate() {
345 packed[[i, j]] = val;
346 }
347 }
348
349 Ok(packed)
350 }
351
352 pub fn decompress(
363 &self,
364 packed_residuals: &Array2<u8>,
365 codes: &ArrayView1<usize>,
366 ) -> Result<Array2<f32>> {
367 let bucket_weights = self
368 .bucket_weights
369 .as_ref()
370 .ok_or_else(|| Error::Codec("bucket_weights required for decompression".into()))?;
371
372 let lookup = self
373 .bucket_weight_indices_lookup
374 .as_ref()
375 .ok_or_else(|| Error::Codec("bucket_weight_indices_lookup required".into()))?;
376
377 let n = packed_residuals.nrows();
378 let dim = self.embedding_dim();
379
380 let mut output = Array2::<f32>::zeros((n, dim));
381
382 for i in 0..n {
383 let centroid = self.centroids.row(codes[i]);
385
386 let mut residual_idx = 0;
388 for &byte_val in packed_residuals.row(i).iter() {
389 let reversed = self.byte_reversed_bits_map[byte_val as usize];
390 let indices = lookup.row(reversed as usize);
391
392 for &bucket_idx in indices.iter() {
393 if residual_idx < dim {
394 output[[i, residual_idx]] =
395 centroid[residual_idx] + bucket_weights[bucket_idx];
396 residual_idx += 1;
397 }
398 }
399 }
400 }
401
402 for mut row in output.axis_iter_mut(Axis(0)) {
404 let norm = row.dot(&row).sqrt().max(1e-12);
405 row /= norm;
406 }
407
408 Ok(output)
409 }
410
411 pub fn load_from_dir(index_path: &std::path::Path) -> Result<Self> {
413 use ndarray_npy::ReadNpyExt;
414 use std::fs::File;
415
416 let centroids_path = index_path.join("centroids.npy");
417 let centroids: Array2<f32> = Array2::read_npy(
418 File::open(¢roids_path)
419 .map_err(|e| Error::IndexLoad(format!("Failed to open centroids.npy: {}", e)))?,
420 )
421 .map_err(|e| Error::IndexLoad(format!("Failed to read centroids.npy: {}", e)))?;
422
423 let avg_residual_path = index_path.join("avg_residual.npy");
424 let avg_residual: Array1<f32> =
425 Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
426 Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
427 })?)
428 .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;
429
430 let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
431 let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
432 Some(
433 Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
434 Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
435 })?)
436 .map_err(|e| {
437 Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
438 })?,
439 )
440 } else {
441 None
442 };
443
444 let bucket_weights_path = index_path.join("bucket_weights.npy");
445 let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
446 Some(
447 Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
448 Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
449 })?)
450 .map_err(|e| {
451 Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
452 })?,
453 )
454 } else {
455 None
456 };
457
458 let metadata_path = index_path.join("metadata.json");
460 let metadata: serde_json::Value = serde_json::from_reader(
461 File::open(&metadata_path)
462 .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
463 )
464 .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;
465
466 let nbits = metadata["nbits"]
467 .as_u64()
468 .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
469 as usize;
470
471 Self::new(
472 nbits,
473 centroids,
474 avg_residual,
475 bucket_cutoffs,
476 bucket_weights,
477 )
478 }
479
480 pub fn load_mmap_from_dir(index_path: &std::path::Path) -> Result<Self> {
488 use ndarray_npy::ReadNpyExt;
489 use std::fs::File;
490
491 let centroids_path = index_path.join("centroids.npy");
493 let mmap_centroids = crate::mmap::MmapNpyArray2F32::from_npy_file(¢roids_path)?;
494
495 let avg_residual_path = index_path.join("avg_residual.npy");
497 let avg_residual: Array1<f32> =
498 Array1::read_npy(File::open(&avg_residual_path).map_err(|e| {
499 Error::IndexLoad(format!("Failed to open avg_residual.npy: {}", e))
500 })?)
501 .map_err(|e| Error::IndexLoad(format!("Failed to read avg_residual.npy: {}", e)))?;
502
503 let bucket_cutoffs_path = index_path.join("bucket_cutoffs.npy");
504 let bucket_cutoffs: Option<Array1<f32>> = if bucket_cutoffs_path.exists() {
505 Some(
506 Array1::read_npy(File::open(&bucket_cutoffs_path).map_err(|e| {
507 Error::IndexLoad(format!("Failed to open bucket_cutoffs.npy: {}", e))
508 })?)
509 .map_err(|e| {
510 Error::IndexLoad(format!("Failed to read bucket_cutoffs.npy: {}", e))
511 })?,
512 )
513 } else {
514 None
515 };
516
517 let bucket_weights_path = index_path.join("bucket_weights.npy");
518 let bucket_weights: Option<Array1<f32>> = if bucket_weights_path.exists() {
519 Some(
520 Array1::read_npy(File::open(&bucket_weights_path).map_err(|e| {
521 Error::IndexLoad(format!("Failed to open bucket_weights.npy: {}", e))
522 })?)
523 .map_err(|e| {
524 Error::IndexLoad(format!("Failed to read bucket_weights.npy: {}", e))
525 })?,
526 )
527 } else {
528 None
529 };
530
531 let metadata_path = index_path.join("metadata.json");
533 let metadata: serde_json::Value = serde_json::from_reader(
534 File::open(&metadata_path)
535 .map_err(|e| Error::IndexLoad(format!("Failed to open metadata.json: {}", e)))?,
536 )
537 .map_err(|e| Error::IndexLoad(format!("Failed to parse metadata.json: {}", e)))?;
538
539 let nbits = metadata["nbits"]
540 .as_u64()
541 .ok_or_else(|| Error::IndexLoad("nbits not found in metadata".into()))?
542 as usize;
543
544 Self::new_with_store(
545 nbits,
546 CentroidStore::Mmap(mmap_centroids),
547 avg_residual,
548 bucket_cutoffs,
549 bucket_weights,
550 )
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2-bit codec built from well-formed inputs reports the shape it was given.
    #[test]
    fn test_codec_creation() {
        let centroid_values: Vec<f32> = (0..32).map(|x| x as f32).collect();
        let centroids = Array2::from_shape_vec((4, 8), centroid_values).unwrap();
        let cutoffs = Array1::from_vec(vec![-0.5, 0.0, 0.5]);
        let weights = Array1::from_vec(vec![-0.75, -0.25, 0.25, 0.75]);

        let built = ResidualCodec::new(
            2,
            centroids,
            Array1::zeros(8),
            Some(cutoffs),
            Some(weights),
        );
        assert!(built.is_ok());

        let codec = built.unwrap();
        assert_eq!(codec.nbits, 2);
        assert_eq!(codec.embedding_dim(), 8);
        assert_eq!(codec.num_centroids(), 4);
    }

    /// Each embedding is assigned to the axis-aligned centroid it points along.
    #[test]
    fn test_compress_into_codes() {
        let centroids = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 0.0, 0.0, 0.0, // centroid 0
                0.0, 1.0, 0.0, 0.0, // centroid 1
                0.0, 0.0, 1.0, 0.0, // centroid 2
            ],
        )
        .unwrap();
        let codec = ResidualCodec::new(2, centroids, Array1::zeros(4), None, None).unwrap();

        let embeddings = Array2::from_shape_vec(
            (2, 4),
            vec![
                0.9, 0.1, 0.0, 0.0, // closest to centroid 0
                0.0, 0.0, 0.95, 0.05, // closest to centroid 2
            ],
        )
        .unwrap();

        let codes = codec.compress_into_codes(&embeddings);
        assert_eq!(codes[0], 0);
        assert_eq!(codes[1], 2);
    }

    /// With all-zero centroids, 4-bit quantization followed by decompression keeps
    /// the sign of every residual that is clearly away from zero.
    #[test]
    fn test_quantize_decompress_roundtrip_4bit() {
        let dim = 8;

        // 15 cutoffs and 16 bucket weights.
        let bucket_cutoffs: Vec<f32> = (1..16).map(|i| (i as f32 / 16.0 - 0.5) * 2.0).collect();
        let bucket_weights: Vec<f32> = (0..16)
            .map(|i| ((i as f32 + 0.5) / 16.0 - 0.5) * 2.0)
            .collect();

        let codec = ResidualCodec::new(
            4,
            Array2::zeros((4, dim)),
            Array1::zeros(dim),
            Some(Array1::from_vec(bucket_cutoffs)),
            Some(Array1::from_vec(bucket_weights)),
        )
        .unwrap();

        let residuals = Array2::from_shape_vec(
            (2, dim),
            vec![
                -0.9, -0.7, -0.5, -0.3, 0.0, 0.3, 0.5, 0.9, // row 0
                -0.8, -0.4, 0.0, 0.4, 0.8, -0.6, 0.2, 0.6, // row 1
            ],
        )
        .unwrap();

        let packed = codec.quantize_residuals(&residuals).unwrap();
        assert_eq!(packed.ncols(), dim * 4 / 8);

        let codes = Array1::from_vec(vec![0, 0]);
        let decompressed = codec.decompress(&packed, &codes.view()).unwrap();

        for ((i, j), &orig) in residuals.indexed_iter() {
            // Values near zero may legitimately land on either side.
            if orig.abs() <= 0.2 {
                continue;
            }
            let recon = decompressed[[i, j]];
            assert!(
                (orig > 0.0) == (recon > 0.0) || recon.abs() < 0.1,
                "Sign mismatch at [{}, {}]: orig={}, recon={}",
                i,
                j,
                orig,
                recon
            );
        }
    }
}