1use std::collections::HashMap;
5use std::sync::Arc;
6
7use arrow::array::AsArray;
8use arrow::datatypes::{Float16Type, Float32Type, Float64Type, UInt8Type, UInt64Type};
9use arrow_array::{
10 Array, FixedSizeListArray, Float32Array, RecordBatch, UInt8Array, UInt32Array, UInt64Array,
11};
12use arrow_schema::{DataType, SchemaRef};
13use async_trait::async_trait;
14use bytes::{Bytes, BytesMut};
15use deepsize::DeepSizeOf;
16use itertools::Itertools;
17use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, RecordBatchExt};
18use lance_core::{Error, ROW_ID, Result};
19use lance_file::previous::reader::FileReader as PreviousFileReader;
20use lance_linalg::distance::{DistanceType, Dot};
21use lance_linalg::simd::dist_table::{BATCH_SIZE, PERM0, PERM0_INVERSE};
22use lance_linalg::simd::{self};
23use lance_table::utils::LanceIteratorExtension;
24use num_traits::AsPrimitive;
25use prost::Message;
26use serde::{Deserialize, Serialize};
27
28use crate::frag_reuse::FragReuseIndex;
29use crate::pb;
30use crate::vector::bq::RQRotationType;
31use crate::vector::bq::rotation::apply_fast_rotation;
32use crate::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN};
33use crate::vector::pq::storage::transpose;
34use crate::vector::quantizer::{QuantizerMetadata, QuantizerStorage};
35use crate::vector::storage::{DistCalculator, VectorStore};
36
/// Schema-metadata key under which the JSON-serialized RabitQ metadata is stored.
pub const RABIT_METADATA_KEY: &str = "lance:rabit";
/// Name of the record-batch column that holds the RabitQ codes.
pub const RABIT_CODE_COLUMN: &str = "_rabit_codes";
/// Number of query dimensions covered by one distance-table segment.
pub const SEGMENT_LENGTH: usize = 4;
/// Number of entries in one distance-table segment (`2^SEGMENT_LENGTH`).
pub const SEGMENT_NUM_CODES: usize = 1 << SEGMENT_LENGTH;
41
/// Metadata describing how a set of RabitQ codes was produced.
///
/// The struct is (de)serialized as JSON; the dense rotation matrix is excluded from
/// the JSON and travels as a separate tensor buffer instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RabitQuantizationMetadata {
    /// Dense rotation matrix. Skipped by serde; populated by `parse_buffer` from the
    /// tensor buffer when the rotation type is `Matrix`.
    #[serde(skip)]
    pub rotate_mat: Option<FixedSizeListArray>,
    /// Index of the buffer that stores `rotate_mat` (absent in JSON => `None`).
    #[serde(default)]
    pub rotate_mat_position: Option<u32>,
    /// Sign bytes handed to `apply_fast_rotation` when the rotation type is `Fast`.
    #[serde(default)]
    pub fast_rotation_signs: Option<Vec<u8>>,
    /// Which rotation is used; metadata without this field deserializes as `Matrix`.
    #[serde(default = "default_rotation_type_compat")]
    pub rotation_type: RQRotationType,
    /// Dimension of the rotated vector. `0` means "not recorded"; `parse_buffer`
    /// then fills it in from the length of the rotation matrix.
    #[serde(default)]
    pub code_dim: u32,
    /// Number of bits per dimension in the codes.
    pub num_bits: u8,
    /// Whether the stored codes are already in the packed layout produced by `pack_codes`.
    pub packed: bool,
}
60
/// Serde default for `rotation_type`: metadata that lacks the field is read as
/// using the dense `Matrix` rotation.
fn default_rotation_type_compat() -> RQRotationType {
    RQRotationType::Matrix
}
65
66impl DeepSizeOf for RabitQuantizationMetadata {
67 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
68 self.rotate_mat
69 .as_ref()
70 .map(|inv_p| inv_p.get_array_memory_size())
71 .unwrap_or(0)
72 + self
73 .fast_rotation_signs
74 .as_ref()
75 .map(|signs| signs.len())
76 .unwrap_or(0)
77 }
78}
79
80#[async_trait]
81impl QuantizerMetadata for RabitQuantizationMetadata {
82 fn buffer_index(&self) -> Option<u32> {
83 match self.rotation_type {
84 RQRotationType::Matrix => self.rotate_mat_position,
85 RQRotationType::Fast => None,
86 }
87 }
88
89 fn set_buffer_index(&mut self, index: u32) {
90 self.rotate_mat_position = Some(index);
91 }
92
93 fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> {
94 if self.rotation_type != RQRotationType::Matrix {
95 return Ok(());
96 }
97 debug_assert!(!bytes.is_empty());
98 let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?;
99 self.rotate_mat = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
100 if self.code_dim == 0 {
101 self.code_dim = self
102 .rotate_mat
103 .as_ref()
104 .map(|rotate_mat| rotate_mat.len() as u32)
105 .unwrap_or(0);
106 }
107 Ok(())
108 }
109
110 fn extra_metadata(&self) -> Result<Option<Bytes>> {
111 match self.rotation_type {
112 RQRotationType::Matrix => {
113 if let Some(inv_p) = &self.rotate_mat {
114 let inv_p_tensor = pb::Tensor::try_from(inv_p)?;
115 let mut bytes = BytesMut::new();
116 inv_p_tensor.encode(&mut bytes)?;
117 Ok(Some(bytes.freeze()))
118 } else {
119 Ok(None)
120 }
121 }
122 RQRotationType::Fast => Ok(None),
123 }
124 }
125
126 async fn load(reader: &PreviousFileReader) -> Result<Self> {
127 let metadata_str = reader
128 .schema()
129 .metadata
130 .get(RABIT_METADATA_KEY)
131 .ok_or(Error::index(format!(
132 "Reading Rabit metadata: metadata key {} not found",
133 RABIT_METADATA_KEY
134 )))?;
135 serde_json::from_str(metadata_str)
136 .map_err(|_| Error::index(format!("Failed to parse index metadata: {}", metadata_str)))
137 }
138}
139
/// In-memory storage for one partition of RabitQ-quantized vectors.
///
/// `batch` is the source of truth; the remaining array fields are column copies
/// extracted from it so they can be accessed without a lookup by name.
#[derive(Debug, Clone)]
pub struct RabitQuantizationStorage {
    /// Quantizer metadata; `packed` is set to `true` once the storage is built.
    metadata: RabitQuantizationMetadata,
    /// Record batch holding the row ids, codes and per-row factors.
    batch: RecordBatch,
    /// Distance type used when building a distance calculator.
    distance_type: DistanceType,

    /// Row ids, one per stored vector.
    row_ids: UInt64Array,
    /// Codes in the packed layout produced by `pack_codes`.
    codes: FixedSizeListArray,
    /// Per-row additive factor used when converting a raw code distance.
    add_factors: Float32Array,
    /// Per-row multiplicative factor used when converting a raw code distance.
    scale_factors: Float32Array,
}
152
153impl DeepSizeOf for RabitQuantizationStorage {
154 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
155 self.metadata.deep_size_of_children(context) + self.batch.get_array_memory_size()
156 }
157}
158
159impl RabitQuantizationStorage {
160 fn rotate_query_vector_dense<T: ArrowFloatType>(
161 rotate_mat: &FixedSizeListArray,
162 qr: &dyn Array,
163 ) -> Vec<f32>
164 where
165 T::Native: Dot,
166 {
167 let d = qr.len();
168 let code_dim = rotate_mat.len();
169 let rotate_mat = rotate_mat
170 .values()
171 .as_any()
172 .downcast_ref::<T::ArrayType>()
173 .unwrap()
174 .as_slice();
175
176 let qr = qr
177 .as_any()
178 .downcast_ref::<T::ArrayType>()
179 .unwrap()
180 .as_slice();
181
182 rotate_mat
183 .chunks_exact(code_dim)
184 .map(|chunk| lance_linalg::distance::dot(&chunk[..d], qr))
185 .collect()
186 }
187
188 fn rotate_query_vector_fast<T: ArrowFloatType>(
189 code_dim: usize,
190 signs: &[u8],
191 qr: &dyn Array,
192 ) -> Vec<f32>
193 where
194 T::Native: AsPrimitive<f32>,
195 {
196 let qr = qr
197 .as_any()
198 .downcast_ref::<T::ArrayType>()
199 .unwrap()
200 .as_slice();
201
202 let mut output = vec![0.0f32; code_dim];
203 apply_fast_rotation(qr, &mut output, signs);
204 output
205 }
206}
207
/// Distance calculator for one query against a set of RabitQ codes.
pub struct RabitDistCalculator<'a> {
    /// Dimension used, together with `num_bits`, to derive the byte length of a code.
    dim: usize,
    /// Number of bits per dimension in the codes.
    num_bits: u8,
    /// Flat code bytes, in the packed layout produced by `pack_codes`.
    codes: &'a [u8],
    /// Per-segment lookup table built from the rotated query.
    dist_table: Vec<f32>,
    /// Per-row additive factor, indexed by vector id.
    add_factors: &'a [f32],
    /// Per-row multiplicative factor, indexed by vector id.
    scale_factors: &'a [f32],
    /// Query-dependent constant added to every distance.
    query_factor: f32,

    /// Sum of the components of the rotated query.
    sum_q: f32,
    /// `sqrt(dim * num_bits)`, precomputed in `new`.
    sqrt_d: f32,
}
226
227impl<'a> RabitDistCalculator<'a> {
228 #[allow(clippy::too_many_arguments)]
229 pub fn new(
230 dim: usize,
231 num_bits: u8,
232 dist_table: Vec<f32>,
233 sum_q: f32,
234 codes: &'a [u8],
235 add_factors: &'a [f32],
236 scale_factors: &'a [f32],
237 query_factor: f32,
238 ) -> Self {
239 Self {
240 dim,
241 num_bits,
242 codes,
243 dist_table,
244 add_factors,
245 scale_factors,
246 query_factor,
247 sqrt_d: (dim as f32 * num_bits as f32).sqrt(),
248 sum_q,
249 }
250 }
251}
252
/// Returns the value of the lowest set bit of `x` (e.g. `12 -> 4`).
///
/// Returns `0` for `x == 0`. The two's-complement identity `x & -x` is used so
/// the zero case cannot overflow a shift.
#[inline]
fn lowbit(x: usize) -> usize {
    x & x.wrapping_neg()
}
257
258#[inline]
259pub fn build_dist_table_direct<T: ArrowFloatType>(qc: &[T::Native]) -> Vec<f32>
260where
261 T::Native: AsPrimitive<f32>,
262{
263 let mut dist_table = vec![0.0; qc.len() * 4];
267 qc.chunks_exact(SEGMENT_LENGTH)
268 .zip(dist_table.chunks_exact_mut(SEGMENT_NUM_CODES))
269 .for_each(|(sub_vec, dist_table)| build_dist_table_for_subvec::<T>(sub_vec, dist_table));
270 dist_table
271}
272
273#[inline(always)]
274fn build_dist_table_for_subvec<T: ArrowFloatType>(sub_vec: &[T::Native], dist_table: &mut [f32])
275where
276 T::Native: AsPrimitive<f32>,
277{
278 (1..SEGMENT_NUM_CODES).for_each(|j| {
280 dist_table[j] = dist_table[j - lowbit(j)] + sub_vec[LOWBIT_IDX[j]].as_();
293 })
294}
295
/// Quantizes `dist_table` into `u8` values, reusing `quantized_dist_table` as output.
///
/// The output is resized to `dist_table.len()`. Values are mapped linearly so that
/// the minimum becomes `0` and the maximum becomes `255`, rounding to nearest.
///
/// Returns `(min, max)` of the input so callers can undo the quantization.
/// When all values are equal the output is all zeros. An empty table yields an
/// empty output and `(0.0, 0.0)` instead of panicking.
#[inline]
fn quantize_dist_table_into(dist_table: &[f32], quantized_dist_table: &mut Vec<u8>) -> (f32, f32) {
    quantized_dist_table.clear();
    quantized_dist_table.resize(dist_table.len(), 0);

    let Some((&first, rest)) = dist_table.split_first() else {
        return (0.0, 0.0);
    };
    // `total_cmp` gives a total order, so NaNs cannot break the scan.
    let (qmin, qmax) = rest.iter().fold((first, first), |(lo, hi), &d| {
        (
            if d.total_cmp(&lo).is_lt() { d } else { lo },
            if d.total_cmp(&hi).is_lt() { hi } else { d },
        )
    });
    if qmin == qmax {
        // Degenerate range: every entry quantizes to zero (already filled in).
        return (qmin, qmax);
    }

    let factor = 255.0 / (qmax - qmin);
    for (quantized, &d) in quantized_dist_table.iter_mut().zip(dist_table) {
        *quantized = ((d - qmin) * factor).round() as u8;
    }

    (qmin, qmax)
}
321
322impl DistCalculator for RabitDistCalculator<'_> {
323 #[inline(always)]
324 fn distance(&self, id: u32) -> f32 {
325 let id = id as usize;
326 let code_len = self.dim * (self.num_bits as usize) / u8::BITS as usize;
327 let num_vectors = self.codes.len() / code_len;
328 let dist =
329 compute_single_rq_distance(self.codes, id, num_vectors, code_len, &self.dist_table);
330
331 let dist_vq_qr = (2.0 * dist - self.sum_q) / self.sqrt_d;
333 dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor
334 }
335
336 #[inline(always)]
337 fn distance_all(&self, _: usize) -> Vec<f32> {
338 let mut dists = Vec::new();
339 let mut quantized_dists = Vec::new();
340 let mut quantized_dists_table = Vec::new();
341 self.distance_all_with_scratch(
342 0,
343 &mut dists,
344 &mut quantized_dists,
345 &mut quantized_dists_table,
346 );
347 dists
348 }
349
350 #[inline(always)]
351 fn distance_all_with_scratch(
352 &self,
353 _: usize,
354 dists: &mut Vec<f32>,
355 quantized_dists: &mut Vec<u16>,
356 quantized_dists_table: &mut Vec<u8>,
357 ) {
358 let code_len = self.dim * (self.num_bits as usize) / u8::BITS as usize;
359 let n = self.codes.len() / code_len;
360 if n == 0 {
361 dists.clear();
362 quantized_dists.clear();
363 return;
364 }
365
366 dists.clear();
367 dists.resize(n, 0.0);
368 let (qmin, qmax) = quantize_dist_table_into(&self.dist_table, quantized_dists_table);
369 quantized_dists.clear();
370 quantized_dists.resize(n, 0);
371
372 let remainder = n % BATCH_SIZE;
373 simd::dist_table::sum_4bit_dist_table(
374 n - remainder,
375 code_len,
376 self.codes,
377 quantized_dists_table,
378 quantized_dists,
379 );
380
381 let range = (qmax - qmin) / 255.0;
382 let num_tables = quantized_dists_table.len() / 16;
383 let sum_min = num_tables as f32 * qmin;
384 dists
385 .iter_mut()
386 .take(n - remainder)
387 .zip(quantized_dists.iter().take(n - remainder))
388 .for_each(|(dist, q_dist)| {
389 *dist = (*q_dist as f32) * range + sum_min;
390 });
391
392 dists
393 .iter_mut()
394 .enumerate()
395 .take(n - remainder)
396 .for_each(|(id, dist)| {
397 let dist_vq_qr = (2.0 * *dist - self.sum_q) / self.sqrt_d;
398 *dist =
399 dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor;
400 });
401
402 dists
403 .iter_mut()
404 .enumerate()
405 .skip(n - remainder)
406 .for_each(|(id, dist)| {
407 *dist = self.distance(id as u32);
408 });
409 }
410}
411
412impl VectorStore for RabitQuantizationStorage {
413 type DistanceCalculator<'a> = RabitDistCalculator<'a>;
414
415 fn as_any(&self) -> &dyn std::any::Any {
416 self
417 }
418
419 fn schema(&self) -> &SchemaRef {
420 self.batch.schema_ref()
421 }
422
423 fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch> + Send> {
424 Ok(std::iter::once(self.batch.clone()))
425 }
426
427 fn append_batch(&self, _batch: RecordBatch, _vector_column: &str) -> Result<Self> {
428 unimplemented!("RabitQ does not support append_batch")
429 }
430
431 fn len(&self) -> usize {
432 self.batch.num_rows()
433 }
434
435 fn row_id(&self, id: u32) -> u64 {
436 self.row_ids.value(id as usize)
437 }
438
439 fn row_ids(&self) -> impl Iterator<Item = &u64> {
440 self.row_ids.values().iter()
441 }
442
443 fn distance_type(&self) -> DistanceType {
444 self.distance_type
445 }
446
447 #[inline(never)]
449 fn dist_calculator(&self, qr: Arc<dyn Array>, dist_q_c: f32) -> Self::DistanceCalculator<'_> {
450 let codes = self.codes.values().as_primitive::<UInt8Type>().values();
451 let code_dim = if self.metadata.code_dim > 0 {
452 self.metadata.code_dim as usize
453 } else {
454 self.metadata
455 .rotate_mat
456 .as_ref()
457 .map(|rotate_mat| rotate_mat.len())
458 .unwrap_or_default()
459 };
460
461 let rotated_qr = match self.metadata.rotation_type {
462 RQRotationType::Matrix => {
463 let rotate_mat = self
464 .metadata
465 .rotate_mat
466 .as_ref()
467 .expect("RabitQ dense rotation metadata not loaded");
468
469 match rotate_mat.value_type() {
470 DataType::Float16 => {
471 Self::rotate_query_vector_dense::<Float16Type>(rotate_mat, &qr)
472 }
473 DataType::Float32 => {
474 Self::rotate_query_vector_dense::<Float32Type>(rotate_mat, &qr)
475 }
476 DataType::Float64 => {
477 Self::rotate_query_vector_dense::<Float64Type>(rotate_mat, &qr)
478 }
479 dt => unimplemented!("RabitQ does not support data type: {}", dt),
480 }
481 }
482 RQRotationType::Fast => {
483 let signs = self
484 .metadata
485 .fast_rotation_signs
486 .as_ref()
487 .expect("RabitQ fast rotation metadata not loaded");
488 match qr.data_type() {
489 DataType::Float16 => {
490 Self::rotate_query_vector_fast::<Float16Type>(code_dim, signs, &qr)
491 }
492 DataType::Float32 => {
493 Self::rotate_query_vector_fast::<Float32Type>(code_dim, signs, &qr)
494 }
495 DataType::Float64 => {
496 Self::rotate_query_vector_fast::<Float64Type>(code_dim, signs, &qr)
497 }
498 dt => unimplemented!("RabitQ does not support data type: {}", dt),
499 }
500 }
501 };
502
503 let dist_table = build_dist_table_direct::<Float32Type>(&rotated_qr);
504 let sum_q = rotated_qr.into_iter().sum();
505
506 let q_factor = match self.distance_type {
507 DistanceType::L2 => dist_q_c,
508 DistanceType::Cosine | DistanceType::Dot => dist_q_c - 1.0,
509 _ => unimplemented!(
510 "RabitQ does not support distance type: {}",
511 self.distance_type
512 ),
513 };
514 RabitDistCalculator::new(
515 qr.len(),
516 self.metadata.num_bits,
517 dist_table,
518 sum_q,
519 codes,
520 self.add_factors.values(),
521 self.scale_factors.values(),
522 q_factor,
523 )
524 }
525
526 fn dist_calculator_from_id(&self, _: u32) -> Self::DistanceCalculator<'_> {
529 unimplemented!("RabitQ does not support dist_calculator_from_id")
530 }
531}
532
/// For every 4-bit code, the index of its lowest set bit (entry `0` is `0`).
/// Computed at compile time.
const LOWBIT_IDX: [usize; 16] = {
    let mut table = [0usize; 16];
    let mut code = 1usize;
    while code < table.len() {
        table[code] = code.trailing_zeros() as usize;
        code += 1;
    }
    table
};
542
/// Gathers byte `col_idx` of 32 consecutive codes, starting at vector `row`.
///
/// `quantization_code` is row-major with `code_len` bytes per vector; `codes[i]`
/// receives byte `col_idx` of vector `row + i`.
///
/// # Panics
/// Panics when any of the 32 reads falls outside `quantization_code`.
fn get_column(
    quantization_code: &[u8],
    code_len: usize,
    row: usize,
    col_idx: usize,
    codes: &mut [u8; 32],
) {
    let mut vector = row;
    for slot in codes.iter_mut() {
        *slot = quantization_code[vector * code_len + col_idx];
        vector += 1;
    }
}
555
556pub fn pack_codes(codes: &FixedSizeListArray) -> FixedSizeListArray {
557 let code_len = codes.value_length() as usize;
558
559 let num_blocks = codes.len() / BATCH_SIZE;
561 let num_packed_vectors = num_blocks * BATCH_SIZE;
562
563 let mut blocks = vec![0u8; codes.values().len()];
569
570 let codes_values = codes
571 .slice(0, num_packed_vectors)
572 .values()
573 .as_primitive::<UInt8Type>()
574 .clone();
575 let codes_values = codes_values.values();
576
577 let mut col = [0u8; 32];
580 let mut col_0 = [0u8; 32]; let mut col_1 = [0u8; 32]; for row in (0..num_packed_vectors).step_by(BATCH_SIZE) {
583 for i in 0..code_len {
587 get_column(codes_values, code_len, row, i, &mut col);
588
589 for j in 0..32 {
590 col_0[j] = col[j] & 0xF;
591 col_1[j] = col[j] >> 4;
592 }
593
594 let block_offset = (row / BATCH_SIZE) * code_len * BATCH_SIZE + i * BATCH_SIZE;
595 for j in 0..16 {
596 let val0 = col_0[PERM0[j]] | (col_0[PERM0[j] + 16] << 4);
599 let val1 = col_1[PERM0[j]] | (col_1[PERM0[j] + 16] << 4);
600 blocks[block_offset + j] = val0;
601 blocks[block_offset + j + 16] = val1;
602 }
603 }
604 }
605
606 let transposed_codes = transpose(
608 &codes.values().as_primitive::<UInt8Type>().slice(
609 num_packed_vectors * code_len,
610 (codes.len() - num_packed_vectors) * code_len,
611 ),
612 codes.len() - num_packed_vectors,
613 code_len,
614 );
615
616 let offset = codes.values().len() - transposed_codes.len();
617 for (i, v) in transposed_codes.values().iter().enumerate() {
618 blocks[offset + i] = *v;
619 }
620
621 assert_eq!(blocks.len(), codes.values().len());
622 FixedSizeListArray::try_new_from_values(UInt8Array::from(blocks), code_len as i32).unwrap()
623}
624
625pub fn unpack_codes(codes: &FixedSizeListArray) -> FixedSizeListArray {
627 let code_len = codes.value_length() as usize;
628 let num_vectors = codes.len();
629
630 let num_blocks = num_vectors / BATCH_SIZE;
632 let num_packed_vectors = num_blocks * BATCH_SIZE;
633
634 let mut unpacked = vec![0u8; codes.values().len()];
635
636 let codes_values = codes.values().as_primitive::<UInt8Type>().values();
637
638 for batch_idx in 0..num_blocks {
640 let block_start = batch_idx * code_len * BATCH_SIZE;
641
642 for i in 0..code_len {
643 let block_offset = block_start + i * BATCH_SIZE;
644 let block = &codes_values[block_offset..block_offset + BATCH_SIZE];
645
646 for j in 0..16 {
648 let val0 = block[j];
649 let val1 = block[j + 16];
650
651 let low_0 = val0 & 0xF;
652 let high_0 = val0 >> 4;
653 let low_1 = val1 & 0xF;
654 let high_1 = val1 >> 4;
655
656 let vec_idx_0 = batch_idx * BATCH_SIZE + PERM0[j];
657 let vec_idx_1 = batch_idx * BATCH_SIZE + PERM0[j] + 16;
658
659 unpacked[vec_idx_0 * code_len + i] = low_0 | (low_1 << 4);
660 unpacked[vec_idx_1 * code_len + i] = high_0 | (high_1 << 4);
661 }
662 }
663 }
664
665 if num_packed_vectors < num_vectors {
667 let remainder = num_vectors - num_packed_vectors;
668 let offset = num_packed_vectors * code_len;
669 let transposed_data = &codes_values[offset..];
670
671 for row in 0..remainder {
673 for col in 0..code_len {
674 unpacked[offset + row * code_len + col] = transposed_data[col * remainder + row];
675 }
676 }
677 }
678
679 FixedSizeListArray::try_new_from_values(UInt8Array::from(unpacked), code_len as i32).unwrap()
680}
681
682#[async_trait]
683impl QuantizerStorage for RabitQuantizationStorage {
684 type Metadata = RabitQuantizationMetadata;
685
686 fn try_from_batch(
687 batch: RecordBatch,
688 metadata: &Self::Metadata,
689 distance_type: DistanceType,
690 _fri: Option<Arc<FragReuseIndex>>,
691 ) -> Result<Self> {
692 let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().clone();
693 let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
694 let add_factors = batch[ADD_FACTORS_COLUMN]
695 .as_primitive::<Float32Type>()
696 .clone();
697 let scale_factors = batch[SCALE_FACTORS_COLUMN]
698 .as_primitive::<Float32Type>()
699 .clone();
700
701 let (batch, codes) = if !metadata.packed {
702 let codes = pack_codes(&codes);
703 let batch = batch.replace_column_by_name(RABIT_CODE_COLUMN, Arc::new(codes))?;
704 let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
705 (batch, codes)
706 } else {
707 (batch, codes)
708 };
709
710 let mut metadata = metadata.clone();
711 metadata.packed = true;
712
713 Ok(Self {
714 metadata,
715 batch,
716 distance_type,
717 row_ids,
718 codes,
719 add_factors,
720 scale_factors,
721 })
722 }
723
724 fn metadata(&self) -> &Self::Metadata {
725 &self.metadata
726 }
727
728 async fn load_partition(
729 reader: &PreviousFileReader,
730 range: std::ops::Range<usize>,
731 distance_type: DistanceType,
732 metadata: &Self::Metadata,
733 frag_reuse_index: Option<Arc<FragReuseIndex>>,
734 ) -> Result<Self> {
735 let schema = reader.schema();
736 let batch = reader.read_range(range, schema).await?;
737 Self::try_from_batch(batch, metadata, distance_type, frag_reuse_index)
738 }
739
740 fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
741 let num_vectors = self.codes.len();
742 let num_code_bytes = self.codes.value_length() as usize;
743 let codes = self.codes.values().as_primitive::<UInt8Type>().values();
744 let mut indices = Vec::with_capacity(num_vectors);
745 let mut new_row_ids = Vec::with_capacity(num_vectors);
746 let mut new_codes = Vec::with_capacity(codes.len());
747
748 let row_ids = self.row_ids.values();
749 for (i, row_id) in row_ids.iter().enumerate() {
750 match mapping.get(row_id) {
751 Some(Some(new_id)) => {
752 indices.push(i as u32);
753 new_row_ids.push(*new_id);
754 new_codes.extend(get_rq_code(codes, i, num_vectors, num_code_bytes));
755 }
756 Some(None) => {}
757 None => {
758 indices.push(i as u32);
759 new_row_ids.push(*row_id);
760 new_codes.extend(get_rq_code(codes, i, num_vectors, num_code_bytes));
761 }
762 }
763 }
764
765 let new_row_ids = UInt64Array::from(new_row_ids);
766 let new_codes = FixedSizeListArray::try_new_from_values(
767 UInt8Array::from(new_codes),
768 num_code_bytes as i32,
769 )?;
770 let batch = if new_row_ids.is_empty() {
771 RecordBatch::new_empty(self.schema().clone())
772 } else {
773 let codes = Arc::new(pack_codes(&new_codes));
774 self.batch
775 .take(&UInt32Array::from(indices))?
776 .replace_column_by_name(ROW_ID, Arc::new(new_row_ids.clone()))?
777 .replace_column_by_name(RABIT_CODE_COLUMN, codes)?
778 };
779 let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
780
781 Ok(Self {
782 metadata: self.metadata.clone(),
783 distance_type: self.distance_type,
784 batch,
785 codes,
786 add_factors: self.add_factors.clone(),
787 scale_factors: self.scale_factors.clone(),
788 row_ids: new_row_ids,
789 })
790 }
791}
792
793#[inline]
799fn compute_single_rq_distance(
800 codes: &[u8],
801 id: usize,
802 num_vectors: usize,
803 num_code_bytes: usize,
804 dist_table: &[f32],
805) -> f32 {
806 let remainder = num_vectors % BATCH_SIZE;
807 let mut dist_table_iter = dist_table.chunks_exact(SEGMENT_NUM_CODES).tuples();
808
809 if id < num_vectors - remainder {
810 let batch_codes = &codes[id / BATCH_SIZE * BATCH_SIZE * num_code_bytes
811 ..(id / BATCH_SIZE + 1) * BATCH_SIZE * num_code_bytes];
812
813 let id_in_batch = id % BATCH_SIZE;
814 let idx = PERM0_INVERSE[id_in_batch % 16];
815 let is_lower = id_in_batch < 16;
816
817 let mut dist = 0.0f32;
818 for block in batch_codes.chunks_exact(BATCH_SIZE) {
819 let code_byte = if is_lower {
820 (block[idx] & 0xF) | (block[idx + 16] << 4)
821 } else {
822 (block[idx] >> 4) | (block[idx + 16] & 0xF0)
823 };
824 if let Some((current_dt, next_dt)) = dist_table_iter.next() {
825 let current_code = (code_byte & 0x0F) as usize;
826 let next_code = (code_byte >> 4) as usize;
827 dist += current_dt[current_code] + next_dt[next_code];
828 }
829 }
830 dist
831 } else {
832 let offset_id = id - (num_vectors - remainder);
833 let remainder_codes = &codes[(num_vectors - remainder) * num_code_bytes..];
834
835 let mut dist = 0.0f32;
836 for &code_byte in remainder_codes.iter().skip(offset_id).step_by(remainder) {
837 if let Some((current_dt, next_dt)) = dist_table_iter.next() {
838 let current_code = (code_byte & 0x0F) as usize;
839 let next_code = (code_byte >> 4) as usize;
840 dist += current_dt[current_code] + next_dt[next_code];
841 }
842 }
843 dist
844 }
845}
846
847#[inline]
848fn get_rq_code(
849 codes: &[u8],
850 id: usize,
851 num_vectors: usize,
852 num_code_bytes: usize,
853) -> impl Iterator<Item = u8> + '_ {
854 let remainder = num_vectors % BATCH_SIZE;
855
856 if id < num_vectors - remainder {
857 let codes = &codes[id / BATCH_SIZE * BATCH_SIZE * num_code_bytes
859 ..(id / BATCH_SIZE + 1) * BATCH_SIZE * num_code_bytes];
860
861 let id_in_batch = id % BATCH_SIZE;
862 if id_in_batch < 16 {
863 let idx = PERM0_INVERSE[id_in_batch];
864 codes
865 .chunks_exact(BATCH_SIZE)
866 .map(|block| (block[idx] & 0xF) | (block[idx + 16] << 4))
867 .exact_size(num_code_bytes)
868 .collect_vec()
869 .into_iter()
870 } else {
871 let idx = PERM0_INVERSE[id_in_batch - 16];
872 codes
873 .chunks_exact(BATCH_SIZE)
874 .map(|block| (block[idx] >> 4) | (block[idx + 16] & 0xF0))
875 .exact_size(num_code_bytes)
876 .collect_vec()
877 .into_iter()
878 }
879 } else {
880 let id = id - (num_vectors - remainder);
881 let codes = &codes[(num_vectors - remainder) * num_code_bytes..];
882 codes
883 .iter()
884 .skip(id)
885 .step_by(remainder)
886 .copied()
887 .exact_size(num_code_bytes)
888 .collect_vec()
889 .into_iter()
890 }
891}
892
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use arrow_array::{ArrayRef, Float32Array, UInt64Array};
    use lance_core::ROW_ID;
    use lance_linalg::distance::DistanceType;

    use crate::vector::bq::{RQRotationType, builder::RabitQuantizer};
    use crate::vector::quantizer::{Quantization, QuantizerStorage};

    /// Reference implementation of the per-segment table: for every code `j`,
    /// add up the sub-vector elements whose bit is set in `j`.
    fn build_dist_table_not_optimized<T: ArrowFloatType>(
        sub_vec: &[T::Native],
        dist_table: &mut [f32],
    ) where
        T::Native: AsPrimitive<f32>,
    {
        for (j, dist) in dist_table.iter_mut().enumerate().take(SEGMENT_NUM_CODES) {
            for (k, v) in sub_vec.iter().enumerate().take(SEGMENT_LENGTH) {
                if j & (1 << k) != 0 {
                    *dist += v.as_();
                }
            }
        }
    }

    /// The incremental table builder must agree with the reference implementation.
    #[test]
    fn test_build_dist_table_not_optimized() {
        let sub_vec = vec![1.0, 2.0, 3.0, 4.0];
        let mut expected = vec![0.0; SEGMENT_NUM_CODES];
        build_dist_table_not_optimized::<Float32Type>(&sub_vec, &mut expected);
        let mut dist_table = vec![0.0; SEGMENT_NUM_CODES];
        build_dist_table_for_subvec::<Float32Type>(&sub_vec, &mut dist_table);
        assert_eq!(dist_table, expected);
    }

    /// `unpack_codes` must invert `pack_codes` for vector counts below, equal to,
    /// and above multiples of the batch size.
    #[test]
    fn test_pack_unpack_codes() {
        for num_vectors in [10, 32, 50, 64, 100] {
            let code_len = 8;

            // Deterministic byte pattern so every position is distinguishable
            // (values wrap around at 256).
            let mut codes_data = Vec::new();
            for i in 0..num_vectors {
                for j in 0..code_len {
                    codes_data.push((i * code_len + j) as u8);
                }
            }

            let original_codes = FixedSizeListArray::try_new_from_values(
                UInt8Array::from(codes_data.clone()),
                code_len,
            )
            .unwrap();

            // Round-trip through the packed layout.
            let packed = pack_codes(&original_codes);
            let unpacked = unpack_codes(&packed);

            // Shape must be preserved.
            assert_eq!(original_codes.len(), unpacked.len());
            assert_eq!(original_codes.value_length(), unpacked.value_length());

            let original_values = original_codes.values().as_primitive::<UInt8Type>().values();
            let unpacked_values = unpacked.values().as_primitive::<UInt8Type>().values();

            assert_eq!(
                original_values, unpacked_values,
                "Mismatch for num_vectors={}",
                num_vectors
            );
        }
    }

    /// Quantizes `num_vectors` synthetic vectors with a 1-bit fast-rotation
    /// quantizer and returns the resulting codes.
    fn make_test_codes(num_vectors: usize, code_dim: i32) -> FixedSizeListArray {
        let quantizer =
            RabitQuantizer::new_with_rotation::<Float32Type>(1, code_dim, RQRotationType::Fast);
        let values = Float32Array::from_iter_values(
            (0..num_vectors * code_dim as usize).map(|idx| idx as f32 / code_dim as f32),
        );
        let vectors = FixedSizeListArray::try_new_from_values(values, code_dim).unwrap();
        quantizer
            .quantize(&vectors)
            .unwrap()
            .as_fixed_size_list()
            .clone()
    }

    /// Metadata of a 1-bit fast-rotation quantizer with the given code dimension.
    fn make_test_metadata(code_dim: usize) -> RabitQuantizationMetadata {
        RabitQuantizer::new_with_rotation::<Float32Type>(1, code_dim as i32, RQRotationType::Fast)
            .metadata(None)
    }

    /// Wraps `codes` in a record batch with sequential row ids and synthetic
    /// add/scale factors.
    fn make_test_batch(codes: FixedSizeListArray) -> RecordBatch {
        let num_rows = codes.len();
        RecordBatch::try_from_iter(vec![
            (
                ROW_ID,
                Arc::new(UInt64Array::from_iter_values(0..num_rows as u64)) as ArrayRef,
            ),
            (RABIT_CODE_COLUMN, Arc::new(codes) as ArrayRef),
            (
                ADD_FACTORS_COLUMN,
                Arc::new(Float32Array::from_iter_values(
                    (0..num_rows).map(|v| v as f32),
                )) as ArrayRef,
            ),
            (
                SCALE_FACTORS_COLUMN,
                Arc::new(Float32Array::from_iter_values(
                    (0..num_rows).map(|v| v as f32 + 0.5),
                )) as ArrayRef,
            ),
        ])
        .unwrap()
    }

    /// Asserts that two code arrays have the same shape and the same bytes.
    fn assert_codes_eq(actual: &FixedSizeListArray, expected: &FixedSizeListArray) {
        assert_eq!(actual.len(), expected.len());
        assert_eq!(actual.value_length(), expected.value_length());
        assert_eq!(
            actual.values().as_primitive::<UInt8Type>().values(),
            expected.values().as_primitive::<UInt8Type>().values()
        );
    }

    /// Building storage from unpacked codes must pack them and mark the metadata
    /// as packed.
    #[test]
    fn test_try_from_batch_canonicalizes_rq_codes_to_packed_layout() {
        let original_codes = make_test_codes(50, 64);
        let metadata = make_test_metadata(original_codes.value_length() as usize * 8);
        assert!(!metadata.packed);

        let storage = RabitQuantizationStorage::try_from_batch(
            make_test_batch(original_codes.clone()),
            &metadata,
            DistanceType::L2,
            None,
        )
        .unwrap();

        assert!(storage.metadata().packed);
        let stored_batch = storage.to_batches().unwrap().next().unwrap();
        let stored_codes = stored_batch[RABIT_CODE_COLUMN].as_fixed_size_list();
        let expected_codes = pack_codes(&original_codes);
        assert_codes_eq(stored_codes, &expected_codes);
    }

    /// Remapping must rewrite/drop row ids as requested and leave the codes in a
    /// valid packed layout.
    #[test]
    fn test_remap_preserves_packed_rq_storage_layout() {
        let original_codes = make_test_codes(50, 64);
        let metadata = make_test_metadata(original_codes.value_length() as usize * 8);
        let storage = RabitQuantizationStorage::try_from_batch(
            make_test_batch(original_codes.clone()),
            &metadata,
            DistanceType::L2,
            None,
        )
        .unwrap();

        // Rename rows 1 and 4, delete row 3, leave everything else untouched.
        let mut mapping = HashMap::new();
        mapping.insert(1, Some(101));
        mapping.insert(3, None);
        mapping.insert(4, Some(104));

        let remapped = storage.remap(&mapping).unwrap();
        assert!(remapped.metadata().packed);

        let remapped_batch = remapped.to_batches().unwrap().next().unwrap();
        let remapped_row_ids = remapped_batch[ROW_ID].as_primitive::<UInt64Type>().values();
        let expected_row_ids = UInt64Array::from_iter_values(
            [0, 101, 2, 104]
                .into_iter()
                .chain(5..original_codes.len() as u64),
        );
        assert_eq!(remapped_row_ids, expected_row_ids.values());

        // A packed array must be a fixed point of unpack-then-pack.
        let remapped_codes = remapped_batch[RABIT_CODE_COLUMN].as_fixed_size_list();
        let repacked = pack_codes(&unpack_codes(remapped_codes));
        assert_codes_eq(remapped_codes, &repacked);
    }
}