1use std::collections::HashMap;
5use std::sync::Arc;
6
7use arrow::array::AsArray;
8use arrow::datatypes::{Float16Type, Float32Type, Float64Type, UInt64Type, UInt8Type};
9use arrow_array::{
10 Array, FixedSizeListArray, Float32Array, RecordBatch, UInt32Array, UInt64Array, UInt8Array,
11};
12use arrow_schema::{DataType, SchemaRef};
13use async_trait::async_trait;
14use bytes::{Bytes, BytesMut};
15use deepsize::DeepSizeOf;
16use itertools::Itertools;
17use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, RecordBatchExt};
18use lance_core::{Error, Result, ROW_ID};
19use lance_file::previous::reader::FileReader as PreviousFileReader;
20use lance_linalg::distance::{DistanceType, Dot};
21use lance_linalg::simd::dist_table::{BATCH_SIZE, PERM0, PERM0_INVERSE};
22use lance_linalg::simd::{self};
23use lance_table::utils::LanceIteratorExtension;
24use num_traits::AsPrimitive;
25use prost::Message;
26use serde::{Deserialize, Serialize};
27use snafu::location;
28
29use crate::frag_reuse::FragReuseIndex;
30use crate::pb;
31use crate::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN};
32use crate::vector::pq::storage::transpose;
33use crate::vector::quantizer::{QuantizerMetadata, QuantizerStorage};
34use crate::vector::storage::{DistCalculator, VectorStore};
35
/// Schema-metadata key under which the RabitQ index metadata (JSON) is stored.
pub const RABIT_METADATA_KEY: &str = "lance:rabit";
/// Name of the column holding the RabitQ codes.
pub const RABIT_CODE_COLUMN: &str = "_rabit_codes";
/// Number of dimensions folded into one lookup-table segment.
pub const SEGMENT_LENGTH: usize = 4;
/// Entries per segment table: one per possible 4-bit code (2^SEGMENT_LENGTH).
pub const SEGMENT_NUM_CODES: usize = 1 << SEGMENT_LENGTH;
40
/// Metadata describing a RabitQ quantizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RabitQuantizationMetadata {
    /// Rotation matrix applied to query vectors. Not serialized as JSON:
    /// it is stored in a separate file buffer (see `rotate_mat_position`)
    /// and loaded via `parse_buffer`.
    #[serde(skip)]
    pub rotate_mat: Option<FixedSizeListArray>,
    /// Index of the global buffer that holds the serialized rotation matrix.
    pub rotate_mat_position: u32,
    /// Bits per dimension used by the codes.
    pub num_bits: u8,
    /// Whether the codes column is already in the packed SIMD block layout
    /// produced by `pack_codes`.
    pub packed: bool,
}
52
53impl DeepSizeOf for RabitQuantizationMetadata {
54 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
55 self.rotate_mat
56 .as_ref()
57 .map(|inv_p| inv_p.get_array_memory_size())
58 .unwrap_or(0)
59 }
60}
61
62#[async_trait]
63impl QuantizerMetadata for RabitQuantizationMetadata {
64 fn buffer_index(&self) -> Option<u32> {
65 Some(self.rotate_mat_position)
66 }
67
68 fn set_buffer_index(&mut self, index: u32) {
69 self.rotate_mat_position = index;
70 }
71
72 fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> {
73 debug_assert!(!bytes.is_empty());
74 let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?;
75 self.rotate_mat = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
76 Ok(())
77 }
78
79 fn extra_metadata(&self) -> Result<Option<Bytes>> {
80 if let Some(inv_p) = &self.rotate_mat {
81 let inv_p_tensor = pb::Tensor::try_from(inv_p)?;
82 let mut bytes = BytesMut::new();
83 inv_p_tensor.encode(&mut bytes)?;
84 Ok(Some(bytes.freeze()))
85 } else {
86 Ok(None)
87 }
88 }
89
90 async fn load(reader: &PreviousFileReader) -> Result<Self> {
91 let metadata_str =
92 reader
93 .schema()
94 .metadata
95 .get(RABIT_METADATA_KEY)
96 .ok_or(Error::Index {
97 message: format!(
98 "Reading Rabit metadata: metadata key {} not found",
99 RABIT_METADATA_KEY
100 ),
101 location: location!(),
102 })?;
103 serde_json::from_str(metadata_str).map_err(|_| Error::Index {
104 message: format!("Failed to parse index metadata: {}", metadata_str),
105 location: location!(),
106 })
107 }
108}
109
/// In-memory storage of RabitQ-quantized vectors for one partition.
#[derive(Debug, Clone)]
pub struct RabitQuantizationStorage {
    metadata: RabitQuantizationMetadata,
    // The underlying record batch; the fields below cache its columns
    // (they share the same buffers, so the batch dominates memory).
    batch: RecordBatch,
    distance_type: DistanceType,

    row_ids: UInt64Array,
    // Codes in the packed SIMD block layout produced by `pack_codes`.
    codes: FixedSizeListArray,
    // Per-vector correction factors produced at indexing time.
    add_factors: Float32Array,
    scale_factors: Float32Array,
}
122
123impl DeepSizeOf for RabitQuantizationStorage {
124 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
125 self.metadata.deep_size_of_children(context) + self.batch.get_array_memory_size()
126 }
127}
128
129impl RabitQuantizationStorage {
130 fn rotate_query_vector<T: ArrowFloatType>(
131 rotate_mat: &FixedSizeListArray,
132 qr: &dyn Array,
133 ) -> Vec<f32>
134 where
135 T::Native: Dot,
136 {
137 let d = qr.len();
138 let code_dim = rotate_mat.len();
139 let rotate_mat = rotate_mat
140 .values()
141 .as_any()
142 .downcast_ref::<T::ArrayType>()
143 .unwrap()
144 .as_slice();
145
146 let qr = qr
147 .as_any()
148 .downcast_ref::<T::ArrayType>()
149 .unwrap()
150 .as_slice();
151
152 rotate_mat
153 .chunks_exact(code_dim)
154 .map(|chunk| lance_linalg::distance::dot(&chunk[..d], qr))
155 .collect()
156 }
157}
158
/// Computes approximate distances between a query and the packed RabitQ
/// codes of one partition.
pub struct RabitDistCalculator<'a> {
    // Dimension of the (un-rotated) query vector.
    dim: usize,
    // Bits per dimension of the codes.
    num_bits: u8,
    // Packed codes; see `pack_codes` for the block layout.
    codes: &'a [u8],
    // Lookup table: for each SEGMENT_LENGTH-dim segment of the rotated
    // query, the partial sums for all SEGMENT_NUM_CODES possible codes.
    dist_table: Vec<f32>,
    // Per-vector correction factors computed at indexing time.
    add_factors: &'a [f32],
    scale_factors: &'a [f32],
    // Query-dependent additive term (derived from the query-to-centroid
    // distance; see `dist_calculator`).
    query_factor: f32,

    // Sum of the rotated query's components.
    sum_q: f32,
    // sqrt(dim * num_bits); normalizes the table sum in `distance`.
    sqrt_d: f32,
}
177
178impl<'a> RabitDistCalculator<'a> {
179 #[allow(clippy::too_many_arguments)]
180 pub fn new(
181 dim: usize,
182 num_bits: u8,
183 dist_table: Vec<f32>,
184 sum_q: f32,
185 codes: &'a [u8],
186 add_factors: &'a [f32],
187 scale_factors: &'a [f32],
188 query_factor: f32,
189 ) -> Self {
190 Self {
191 dim,
192 num_bits,
193 codes,
194 dist_table,
195 add_factors,
196 scale_factors,
197 query_factor,
198 sqrt_d: (dim as f32 * num_bits as f32).sqrt(),
199 sum_q,
200 }
201 }
202}
203
/// Returns the lowest set bit of `x` (e.g. `lowbit(12) == 4`).
///
/// Uses the two's-complement identity `x & -x`, which is branchless and —
/// unlike the previous `1 << x.trailing_zeros()` — well defined for
/// `x == 0` (returns 0) instead of overflowing the shift.
#[inline]
fn lowbit(x: usize) -> usize {
    x & x.wrapping_neg()
}
208
209#[inline]
210pub fn build_dist_table_direct<T: ArrowFloatType>(qc: &[T::Native]) -> Vec<f32>
211where
212 T::Native: AsPrimitive<f32>,
213{
214 let mut dist_table = vec![0.0; qc.len() * 4];
218 qc.chunks_exact(SEGMENT_LENGTH)
219 .zip(dist_table.chunks_exact_mut(SEGMENT_NUM_CODES))
220 .for_each(|(sub_vec, dist_table)| build_dist_table_for_subvec::<T>(sub_vec, dist_table));
221 dist_table
222}
223
224#[inline(always)]
225fn build_dist_table_for_subvec<T: ArrowFloatType>(sub_vec: &[T::Native], dist_table: &mut [f32])
226where
227 T::Native: AsPrimitive<f32>,
228{
229 (1..SEGMENT_NUM_CODES).for_each(|j| {
231 dist_table[j] = dist_table[j - lowbit(j)] + sub_vec[LOWBIT_IDX[j]].as_();
244 })
245}
246
/// Quantizes the f32 distance table to u8 for the SIMD kernel.
///
/// Returns `(qmin, qmax, quantized)`, mapping each entry linearly from
/// `[qmin, qmax]` onto `[0, 255]` (rounded). A constant table maps to all
/// zeros to avoid a division by zero. Panics on an empty table.
#[inline]
fn quantize_dist_table(dist_table: &[f32]) -> (f32, f32, Vec<u8>) {
    let mut values = dist_table.iter().copied();
    let first = values.next().expect("dist_table must not be empty");
    // Min/max under IEEE total order (matches the previous total_cmp use).
    let (qmin, qmax) = values.fold((first, first), |(lo, hi), v| {
        (
            if v.total_cmp(&lo).is_lt() { v } else { lo },
            if v.total_cmp(&hi).is_gt() { v } else { hi },
        )
    });
    if qmin == qmax {
        // Degenerate table: every quantized entry is zero.
        return (qmin, qmax, vec![0; dist_table.len()]);
    }
    let factor = 255.0 / (qmax - qmin);
    let quantized = dist_table
        .iter()
        .map(|&d| ((d - qmin) * factor).round() as u8)
        .collect();

    (qmin, qmax, quantized)
}
268
269#[inline]
270fn compute_rq_distance_flat(
271 dist_table: &[f32],
272 codes: &[u8],
273 offset: usize,
274 length: usize,
275 dists: &mut [f32],
276) {
277 let d = dist_table.len() / 4;
278 let code_len = d / u8::BITS as usize;
279 let codes = &codes[offset * code_len..(offset + length) * code_len];
280 let dists = &mut dists[offset..offset + length];
281
282 for (sub_vec_idx, codes) in codes.chunks_exact(length).enumerate() {
283 let current_dist_table = &dist_table
284 [sub_vec_idx * 2 * SEGMENT_NUM_CODES..(sub_vec_idx * 2 + 1) * SEGMENT_NUM_CODES];
285 let next_dist_table = &dist_table
286 [(sub_vec_idx * 2 + 1) * SEGMENT_NUM_CODES..(sub_vec_idx * 2 + 2) * SEGMENT_NUM_CODES];
287
288 codes.iter().zip(dists.iter_mut()).for_each(|(code, dist)| {
289 let current_code = (code & 0x0F) as usize;
290 let next_code = (code >> 4) as usize;
291 *dist += current_dist_table[current_code] + next_dist_table[next_code];
292 });
293 }
294}
295
impl DistCalculator for RabitDistCalculator<'_> {
    /// Estimated distance between the query and the vector at `id`.
    ///
    /// Sums, for each code byte, the table entries of its two 4-bit
    /// sub-codes, then applies the per-vector correction:
    /// `((2 * sum - sum_q) / sqrt_d) * scale + add + query_factor`.
    #[inline(always)]
    fn distance(&self, id: u32) -> f32 {
        let id = id as usize;
        let code_len = self.dim * (self.num_bits as usize) / u8::BITS as usize;
        let num_vectors = self.codes.len() / code_len;
        // Re-assemble this vector's code bytes out of the packed layout.
        let code = get_rq_code(self.codes, id, num_vectors, code_len);
        let dist = code
            // Pair each code byte with its two consecutive 16-entry tables:
            // low nibble -> first table, high nibble -> second.
            .zip(self.dist_table.chunks_exact(SEGMENT_NUM_CODES).tuples())
            .map(|(code_byte, (dist_table, next_dist_table))| {
                let current_code = (code_byte & 0x0F) as usize;
                let next_code = (code_byte >> 4) as usize;
                dist_table[current_code] + next_dist_table[next_code]
            })
            .sum::<f32>();

        // NOTE(review): this rescaling appears to turn the bit-weighted sum
        // into an inner-product estimate (RabitQ-style) — confirm against
        // the transform that produced add/scale factors.
        let dist_vq_qr = (2.0 * dist - self.sum_q) / self.sqrt_d;
        dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor
    }

    /// Estimated distances for every vector in this storage.
    ///
    /// Full batches go through the SIMD kernel on a u8-quantized copy of
    /// the table; the tail (fewer than BATCH_SIZE vectors) takes the exact
    /// scalar path. The size argument is ignored; the count is derived
    /// from the code buffer length.
    #[inline(always)]
    fn distance_all(&self, _: usize) -> Vec<f32> {
        let code_len = self.dim * (self.num_bits as usize) / u8::BITS as usize;
        let n = self.codes.len() / code_len;
        if n == 0 {
            return Vec::new();
        }

        let mut dists = vec![0.0; n];

        // Quantize the f32 table to u8 so the SIMD kernel can accumulate
        // with integer ops; de-quantized below.
        let (qmin, qmax, quantized_dists_table) = quantize_dist_table(&self.dist_table);
        let mut quantized_dists = vec![0; n];

        let remainder = n % BATCH_SIZE;
        simd::dist_table::sum_4bit_dist_table(
            n - remainder,
            code_len,
            self.codes,
            &quantized_dists_table,
            &mut quantized_dists,
        );
        // Tail vectors were not packed into full blocks; use the exact
        // scalar path over the original f32 table.
        if remainder > 0 {
            compute_rq_distance_flat(
                &self.dist_table,
                self.codes,
                n - remainder,
                remainder,
                &mut dists,
            );
        }

        // De-quantize: each of the `num_tables` lookups contributed an
        // implicit `qmin` offset plus steps of (qmax - qmin) / 255.
        let range = (qmax - qmin) / 255.0;
        let num_tables = quantized_dists_table.len() / 16;
        let sum_min = num_tables as f32 * qmin;
        dists
            .iter_mut()
            .take(n - remainder)
            .zip(quantized_dists.into_iter().take(n - remainder))
            .for_each(|(dist, q_dist)| {
                *dist = (q_dist as f32) * range + sum_min;
            });

        // Apply the same per-vector correction as `distance`.
        dists
            .into_iter()
            .enumerate()
            .map(|(id, dist)| {
                let dist_vq_qr = (2.0 * dist - self.sum_q) / self.sqrt_d;
                dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor
            })
            .collect()
    }
}
371
impl VectorStore for RabitQuantizationStorage {
    type DistanceCalculator<'a> = RabitDistCalculator<'a>;

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.batch.schema_ref()
    }

    /// All rows live in a single batch.
    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch> + Send> {
        Ok(std::iter::once(self.batch.clone()))
    }

    fn append_batch(&self, _batch: RecordBatch, _vector_column: &str) -> Result<Self> {
        unimplemented!("RabitQ does not support append_batch")
    }

    fn len(&self) -> usize {
        self.batch.num_rows()
    }

    fn row_id(&self, id: u32) -> u64 {
        self.row_ids.value(id as usize)
    }

    fn row_ids(&self) -> impl Iterator<Item = &u64> {
        self.row_ids.values().iter()
    }

    fn distance_type(&self) -> DistanceType {
        self.distance_type
    }

    /// Builds a [RabitDistCalculator] for the query `qr`.
    ///
    /// `dist_q_c` is presumably the precomputed query-to-centroid distance;
    /// it feeds the additive `query_factor` term.
    #[inline(never)]
    fn dist_calculator(&self, qr: Arc<dyn Array>, dist_q_c: f32) -> Self::DistanceCalculator<'_> {
        let codes = self.codes.values().as_primitive::<UInt8Type>().values();
        let rotate_mat = self
            .metadata
            .rotate_mat
            .as_ref()
            .expect("RabitQ metadata not loaded");

        // Rotate the query into the quantizer's coordinate system; the
        // result is always f32 regardless of the matrix element type.
        let rotated_qr = match rotate_mat.value_type() {
            DataType::Float16 => Self::rotate_query_vector::<Float16Type>(rotate_mat, &qr),
            DataType::Float32 => Self::rotate_query_vector::<Float32Type>(rotate_mat, &qr),
            DataType::Float64 => Self::rotate_query_vector::<Float64Type>(rotate_mat, &qr),
            dt => unimplemented!("RabitQ does not support data type: {}", dt),
        };

        let dist_table = build_dist_table_direct::<Float32Type>(&rotated_qr);
        let sum_q = rotated_qr.into_iter().sum();

        // NOTE(review): the -1.0 offset for Cosine/Dot presumably matches
        // the convention of the transform that wrote the add/scale factors
        // — confirm there.
        let q_factor = match self.distance_type {
            DistanceType::L2 => dist_q_c,
            DistanceType::Cosine | DistanceType::Dot => dist_q_c - 1.0,
            _ => unimplemented!(
                "RabitQ does not support distance type: {}",
                self.distance_type
            ),
        };
        RabitDistCalculator::new(
            qr.len(),
            self.metadata.num_bits,
            dist_table,
            sum_q,
            codes,
            self.add_factors.values(),
            self.scale_factors.values(),
            q_factor,
        )
    }

    fn dist_calculator_from_id(&self, _: u32) -> Self::DistanceCalculator<'_> {
        unimplemented!("RabitQ does not support dist_calculator_from_id")
    }
}
453
/// `LOWBIT_IDX[j] == j.trailing_zeros()` for `j` in 1..16 (index of the
/// lowest set bit); entry 0 is unused and left as 0. Spelled out as a
/// literal so the invariant is visible at a glance.
const LOWBIT_IDX: [usize; 16] = [0, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0];
463
/// Gathers byte column `col_idx` of 32 consecutive row-major code vectors
/// (rows `row..row + 32`) into `codes`.
fn get_column(
    quantization_code: &[u8],
    code_len: usize,
    row: usize,
    col_idx: usize,
    codes: &mut [u8; 32],
) {
    for (i, slot) in codes.iter_mut().enumerate() {
        *slot = quantization_code[(row + i) * code_len + col_idx];
    }
}
476
/// Packs row-major RabitQ codes into the block layout consumed by the SIMD
/// kernel (`sum_4bit_dist_table`) and by `get_rq_code`.
///
/// Vectors forming complete batches of `BATCH_SIZE` are interleaved column
/// by column with the `PERM0` shuffle: for each byte column, the first 16
/// output bytes carry the low 4-bit sub-codes of the 32 vectors and the
/// next 16 carry the high sub-codes. The remainder (< BATCH_SIZE vectors)
/// is simply transposed (column-major) at the tail. Output length equals
/// input length; `unpack_codes` is the exact inverse.
pub fn pack_codes(codes: &FixedSizeListArray) -> FixedSizeListArray {
    let code_len = codes.value_length() as usize;

    // Vectors that fill complete SIMD batches.
    let num_blocks = codes.len() / BATCH_SIZE;
    let num_packed_vectors = num_blocks * BATCH_SIZE;

    let mut blocks = vec![0u8; codes.values().len()];

    let codes_values = codes
        .slice(0, num_packed_vectors)
        .values()
        .as_primitive::<UInt8Type>()
        .clone();
    let codes_values = codes_values.values();

    // NOTE(review): the fixed-size 32 buffers below hard-code the
    // assumption that BATCH_SIZE == 32 — confirm in lance_linalg.
    let mut col = [0u8; 32];
    let mut col_0 = [0u8; 32];
    let mut col_1 = [0u8; 32];
    for row in (0..num_packed_vectors).step_by(BATCH_SIZE) {
        for i in 0..code_len {
            // Gather byte column `i` of the 32 vectors in this batch.
            get_column(codes_values, code_len, row, i, &mut col);

            // Split each byte into its low and high 4-bit sub-codes.
            for j in 0..32 {
                col_0[j] = col[j] & 0xF;
                col_1[j] = col[j] >> 4;
            }

            let block_offset = (row / BATCH_SIZE) * code_len * BATCH_SIZE + i * BATCH_SIZE;
            // Interleave two PERM0-permuted sub-codes per output byte.
            for j in 0..16 {
                let val0 = col_0[PERM0[j]] | (col_0[PERM0[j] + 16] << 4);
                let val1 = col_1[PERM0[j]] | (col_1[PERM0[j] + 16] << 4);
                blocks[block_offset + j] = val0;
                blocks[block_offset + j + 16] = val1;
            }
        }
    }

    // Tail vectors are stored transposed so the scalar fallback can stride
    // over them column-major.
    let transposed_codes = transpose(
        &codes.values().as_primitive::<UInt8Type>().slice(
            num_packed_vectors * code_len,
            (codes.len() - num_packed_vectors) * code_len,
        ),
        codes.len() - num_packed_vectors,
        code_len,
    );

    let offset = codes.values().len() - transposed_codes.len();
    for (i, v) in transposed_codes.values().iter().enumerate() {
        blocks[offset + i] = *v;
    }

    assert_eq!(blocks.len(), codes.values().len());
    FixedSizeListArray::try_new_from_values(UInt8Array::from(blocks), code_len as i32).unwrap()
}
545
/// Inverse of [pack_codes]: restores the row-major code layout.
pub fn unpack_codes(codes: &FixedSizeListArray) -> FixedSizeListArray {
    let code_len = codes.value_length() as usize;
    let num_vectors = codes.len();

    // Vectors stored in complete SIMD batches.
    let num_blocks = num_vectors / BATCH_SIZE;
    let num_packed_vectors = num_blocks * BATCH_SIZE;

    let mut unpacked = vec![0u8; codes.values().len()];

    let codes_values = codes.values().as_primitive::<UInt8Type>().values();

    for batch_idx in 0..num_blocks {
        let block_start = batch_idx * code_len * BATCH_SIZE;

        for i in 0..code_len {
            let block_offset = block_start + i * BATCH_SIZE;
            let block = &codes_values[block_offset..block_offset + BATCH_SIZE];

            // Undo the PERM0 interleave: bytes 0..16 of the block carry the
            // low sub-codes, bytes 16..32 the high sub-codes (see pack_codes).
            for j in 0..16 {
                let val0 = block[j];
                let val1 = block[j + 16];

                let low_0 = val0 & 0xF;
                let high_0 = val0 >> 4;
                let low_1 = val1 & 0xF;
                let high_1 = val1 >> 4;

                let vec_idx_0 = batch_idx * BATCH_SIZE + PERM0[j];
                let vec_idx_1 = batch_idx * BATCH_SIZE + PERM0[j] + 16;

                unpacked[vec_idx_0 * code_len + i] = low_0 | (low_1 << 4);
                unpacked[vec_idx_1 * code_len + i] = high_0 | (high_1 << 4);
            }
        }
    }

    // The tail is merely transposed; transpose it back in place.
    if num_packed_vectors < num_vectors {
        let remainder = num_vectors - num_packed_vectors;
        let offset = num_packed_vectors * code_len;
        let transposed_data = &codes_values[offset..];

        for row in 0..remainder {
            for col in 0..code_len {
                unpacked[offset + row * code_len + col] = transposed_data[col * remainder + row];
            }
        }
    }

    FixedSizeListArray::try_new_from_values(UInt8Array::from(unpacked), code_len as i32).unwrap()
}
602
603#[async_trait]
604impl QuantizerStorage for RabitQuantizationStorage {
605 type Metadata = RabitQuantizationMetadata;
606
607 fn try_from_batch(
608 batch: RecordBatch,
609 metadata: &Self::Metadata,
610 distance_type: DistanceType,
611 _fri: Option<Arc<FragReuseIndex>>,
612 ) -> Result<Self> {
613 let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().clone();
614 let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
615 let add_factors = batch[ADD_FACTORS_COLUMN]
616 .as_primitive::<Float32Type>()
617 .clone();
618 let scale_factors = batch[SCALE_FACTORS_COLUMN]
619 .as_primitive::<Float32Type>()
620 .clone();
621
622 let (batch, codes) = if !metadata.packed {
623 let codes = pack_codes(&codes);
624 let batch = batch.replace_column_by_name(RABIT_CODE_COLUMN, Arc::new(codes))?;
625 let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
626 (batch, codes)
627 } else {
628 (batch, codes)
629 };
630
631 let mut metadata = metadata.clone();
632 metadata.packed = true;
633
634 Ok(Self {
635 metadata,
636 batch,
637 distance_type,
638 row_ids,
639 codes,
640 add_factors,
641 scale_factors,
642 })
643 }
644
645 fn metadata(&self) -> &Self::Metadata {
646 &self.metadata
647 }
648
649 async fn load_partition(
650 reader: &PreviousFileReader,
651 range: std::ops::Range<usize>,
652 distance_type: DistanceType,
653 metadata: &Self::Metadata,
654 frag_reuse_index: Option<Arc<FragReuseIndex>>,
655 ) -> Result<Self> {
656 let schema = reader.schema();
657 let batch = reader.read_range(range, schema).await?;
658 Self::try_from_batch(batch, metadata, distance_type, frag_reuse_index)
659 }
660
661 fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
662 let num_vectors = self.codes.len();
663 let num_code_bytes = self.codes.value_length() as usize;
664 let codes = self.codes.values().as_primitive::<UInt8Type>().values();
665 let mut indices = Vec::with_capacity(num_vectors);
666 let mut new_row_ids = Vec::with_capacity(num_vectors);
667 let mut new_codes = Vec::with_capacity(codes.len());
668
669 let row_ids = self.row_ids.values();
670 for (i, row_id) in row_ids.iter().enumerate() {
671 match mapping.get(row_id) {
672 Some(Some(new_id)) => {
673 indices.push(i as u32);
674 new_row_ids.push(*new_id);
675 new_codes.extend(get_rq_code(codes, i, num_vectors, num_code_bytes));
676 }
677 Some(None) => {}
678 None => {
679 indices.push(i as u32);
680 new_row_ids.push(*row_id);
681 new_codes.extend(get_rq_code(codes, i, num_vectors, num_code_bytes));
682 }
683 }
684 }
685
686 let new_row_ids = UInt64Array::from(new_row_ids);
687 let new_codes = FixedSizeListArray::try_new_from_values(
688 UInt8Array::from(new_codes),
689 num_code_bytes as i32,
690 )?;
691 let batch = if new_row_ids.is_empty() {
692 RecordBatch::new_empty(self.schema().clone())
693 } else {
694 let codes = Arc::new(pack_codes(&new_codes));
695 self.batch
696 .take(&UInt32Array::from(indices))?
697 .replace_column_by_name(ROW_ID, Arc::new(new_row_ids.clone()))?
698 .replace_column_by_name(RABIT_CODE_COLUMN, codes)?
699 };
700 let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
701
702 Ok(Self {
703 metadata: self.metadata.clone(),
704 distance_type: self.distance_type,
705 batch,
706 codes,
707 add_factors: self.add_factors.clone(),
708 scale_factors: self.scale_factors.clone(),
709 row_ids: new_row_ids,
710 })
711 }
712}
713
/// Yields the `num_code_bytes` code bytes of vector `id` out of the packed
/// layout produced by [pack_codes].
///
/// Vectors inside a full batch are reassembled by inverting the `PERM0`
/// nibble interleave; vectors in the transposed tail are gathered by
/// striding column-major with stride `remainder`.
#[inline]
fn get_rq_code(
    codes: &[u8],
    id: usize,
    num_vectors: usize,
    num_code_bytes: usize,
) -> impl Iterator<Item = u8> + '_ {
    let remainder = num_vectors % BATCH_SIZE;

    if id < num_vectors - remainder {
        // The vector lives in a full packed batch: isolate that batch.
        let codes = &codes[id / BATCH_SIZE * BATCH_SIZE * num_code_bytes
            ..(id / BATCH_SIZE + 1) * BATCH_SIZE * num_code_bytes];

        let id_in_batch = id % BATCH_SIZE;
        if id_in_batch < 16 {
            // Low half of the batch: the byte's nibbles come from the low
            // nibble of block[idx] and block[idx + 16].
            let idx = PERM0_INVERSE[id_in_batch];
            codes
                .chunks_exact(BATCH_SIZE)
                .map(|block| (block[idx] & 0xF) | (block[idx + 16] << 4))
                .exact_size(num_code_bytes)
                .collect_vec()
                .into_iter()
        } else {
            // High half: the byte's nibbles come from the high nibble of
            // block[idx] and block[idx + 16].
            let idx = PERM0_INVERSE[id_in_batch - 16];
            codes
                .chunks_exact(BATCH_SIZE)
                .map(|block| (block[idx] >> 4) | (block[idx + 16] & 0xF0))
                .exact_size(num_code_bytes)
                .collect_vec()
                .into_iter()
        }
    } else {
        // Tail vector: stored transposed (column-major, stride `remainder`).
        let id = id - (num_vectors - remainder);
        let codes = &codes[(num_vectors - remainder) * num_code_bytes..];
        codes
            .iter()
            .skip(id)
            .step_by(remainder)
            .copied()
            .exact_size(num_code_bytes)
            .collect_vec()
            .into_iter()
    }
}
759
#[cfg(test)]
mod tests {
    use super::*;

    /// Brute-force reference: table entry `j` is the sum of the sub-vector
    /// components whose bit is set in `j`.
    fn build_dist_table_not_optimized<T: ArrowFloatType>(
        sub_vec: &[T::Native],
        dist_table: &mut [f32],
    ) where
        T::Native: AsPrimitive<f32>,
    {
        dist_table
            .iter_mut()
            .enumerate()
            .take(SEGMENT_NUM_CODES)
            .for_each(|(j, dist)| {
                sub_vec
                    .iter()
                    .enumerate()
                    .take(SEGMENT_LENGTH)
                    .filter(|(k, _)| j & (1 << *k) != 0)
                    .for_each(|(_, v)| *dist += v.as_());
            });
    }

    #[test]
    fn test_build_dist_table_not_optimized() {
        let sub_vec = [1.0, 2.0, 3.0, 4.0];
        let mut expected = vec![0.0; SEGMENT_NUM_CODES];
        build_dist_table_not_optimized::<Float32Type>(&sub_vec, &mut expected);
        let mut actual = vec![0.0; SEGMENT_NUM_CODES];
        build_dist_table_for_subvec::<Float32Type>(&sub_vec, &mut actual);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_pack_unpack_codes() {
        // Cover: remainder only, one exact batch, batch + remainder,
        // two exact batches, and a larger mixed case.
        for num_vectors in [10usize, 32, 50, 64, 100] {
            let code_len = 8usize;

            // Row-major codes: flat index as u8 (wrapping).
            let codes_data: Vec<u8> = (0..num_vectors * code_len).map(|x| x as u8).collect();

            let original_codes = FixedSizeListArray::try_new_from_values(
                UInt8Array::from(codes_data.clone()),
                code_len as i32,
            )
            .unwrap();

            let unpacked = unpack_codes(&pack_codes(&original_codes));

            assert_eq!(original_codes.len(), unpacked.len());
            assert_eq!(original_codes.value_length(), unpacked.value_length());

            let original_values = original_codes.values().as_primitive::<UInt8Type>().values();
            let unpacked_values = unpacked.values().as_primitive::<UInt8Type>().values();

            assert_eq!(
                original_values, unpacked_values,
                "Mismatch for num_vectors={}",
                num_vectors
            );
        }
    }
}