// lance_index/vector/bq/storage.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use arrow::array::AsArray;
8use arrow::datatypes::{Float16Type, Float32Type, Float64Type, UInt64Type, UInt8Type};
9use arrow_array::{
10    Array, FixedSizeListArray, Float32Array, RecordBatch, UInt32Array, UInt64Array, UInt8Array,
11};
12use arrow_schema::{DataType, SchemaRef};
13use async_trait::async_trait;
14use bytes::{Bytes, BytesMut};
15use deepsize::DeepSizeOf;
16use itertools::Itertools;
17use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, RecordBatchExt};
18use lance_core::{Error, Result, ROW_ID};
19use lance_file::previous::reader::FileReader as PreviousFileReader;
20use lance_linalg::distance::{DistanceType, Dot};
21use lance_linalg::simd::dist_table::{BATCH_SIZE, PERM0, PERM0_INVERSE};
22use lance_linalg::simd::{self};
23use lance_table::utils::LanceIteratorExtension;
24use num_traits::AsPrimitive;
25use prost::Message;
26use serde::{Deserialize, Serialize};
27use snafu::location;
28
29use crate::frag_reuse::FragReuseIndex;
30use crate::pb;
31use crate::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN};
32use crate::vector::pq::storage::transpose;
33use crate::vector::quantizer::{QuantizerMetadata, QuantizerStorage};
34use crate::vector::storage::{DistCalculator, VectorStore};
35
/// Schema-metadata key under which the serialized RabitQ metadata is stored.
pub const RABIT_METADATA_KEY: &str = "lance:rabit";
/// Name of the column holding the RabitQ codes.
pub const RABIT_CODE_COLUMN: &str = "_rabit_codes";
/// Number of dimensions grouped into one distance-table segment (4 bits).
pub const SEGMENT_LENGTH: usize = 4;
/// Number of distinct codes per segment: 2^SEGMENT_LENGTH == 16.
pub const SEGMENT_NUM_CODES: usize = 1 << SEGMENT_LENGTH;
40
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RabitQuantizationMetadata {
    // this rotate matrix is large, and lance index would store all metadata in schema metadata,
    // which is in JSON format, so we skip it in serialization and deserialization, and store it
    // in the global buffer, which is a binary format (protobuf for now) for efficiency.
    #[serde(skip)]
    pub rotate_mat: Option<FixedSizeListArray>,
    // Index of the global buffer that contains the serialized rotation matrix
    // (see `buffer_index` / `set_buffer_index`).
    pub rotate_mat_position: u32,
    // Number of bits per dimension; currently always 1 (see RabitDistCalculator).
    pub num_bits: u8,
    // Whether the codes column is already packed into the SIMD block layout
    // produced by `pack_codes`.
    pub packed: bool,
}
52
53impl DeepSizeOf for RabitQuantizationMetadata {
54    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
55        self.rotate_mat
56            .as_ref()
57            .map(|inv_p| inv_p.get_array_memory_size())
58            .unwrap_or(0)
59    }
60}
61
62#[async_trait]
63impl QuantizerMetadata for RabitQuantizationMetadata {
64    fn buffer_index(&self) -> Option<u32> {
65        Some(self.rotate_mat_position)
66    }
67
68    fn set_buffer_index(&mut self, index: u32) {
69        self.rotate_mat_position = index;
70    }
71
72    fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> {
73        debug_assert!(!bytes.is_empty());
74        let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?;
75        self.rotate_mat = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
76        Ok(())
77    }
78
79    fn extra_metadata(&self) -> Result<Option<Bytes>> {
80        if let Some(inv_p) = &self.rotate_mat {
81            let inv_p_tensor = pb::Tensor::try_from(inv_p)?;
82            let mut bytes = BytesMut::new();
83            inv_p_tensor.encode(&mut bytes)?;
84            Ok(Some(bytes.freeze()))
85        } else {
86            Ok(None)
87        }
88    }
89
90    async fn load(reader: &PreviousFileReader) -> Result<Self> {
91        let metadata_str =
92            reader
93                .schema()
94                .metadata
95                .get(RABIT_METADATA_KEY)
96                .ok_or(Error::Index {
97                    message: format!(
98                        "Reading Rabit metadata: metadata key {} not found",
99                        RABIT_METADATA_KEY
100                    ),
101                    location: location!(),
102                })?;
103        serde_json::from_str(metadata_str).map_err(|_| Error::Index {
104            message: format!("Failed to parse index metadata: {}", metadata_str),
105            location: location!(),
106        })
107    }
108}
109
#[derive(Debug, Clone)]
pub struct RabitQuantizationStorage {
    metadata: RabitQuantizationMetadata,
    // Backing batch: row ids, packed codes, and per-vector correction factors.
    batch: RecordBatch,
    distance_type: DistanceType,

    // helper fields: cached, strongly-typed views of the columns in `batch`;
    // they must stay aligned row-for-row with `batch`.
    row_ids: UInt64Array,
    codes: FixedSizeListArray,
    add_factors: Float32Array,
    scale_factors: Float32Array,
}
122
123impl DeepSizeOf for RabitQuantizationStorage {
124    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
125        self.metadata.deep_size_of_children(context) + self.batch.get_array_memory_size()
126    }
127}
128
129impl RabitQuantizationStorage {
130    fn rotate_query_vector<T: ArrowFloatType>(
131        rotate_mat: &FixedSizeListArray,
132        qr: &dyn Array,
133    ) -> Vec<f32>
134    where
135        T::Native: Dot,
136    {
137        let d = qr.len();
138        let code_dim = rotate_mat.len();
139        let rotate_mat = rotate_mat
140            .values()
141            .as_any()
142            .downcast_ref::<T::ArrayType>()
143            .unwrap()
144            .as_slice();
145
146        let qr = qr
147            .as_any()
148            .downcast_ref::<T::ArrayType>()
149            .unwrap()
150            .as_slice();
151
152        rotate_mat
153            .chunks_exact(code_dim)
154            .map(|chunk| lance_linalg::distance::dot(&chunk[..d], qr))
155            .collect()
156    }
157}
158
/// Distance calculator for RabitQ-quantized vectors against one rotated query.
pub struct RabitDistCalculator<'a> {
    // dimension of the vectors
    dim: usize,
    // num_bits is the number of bits per dimension,
    // it's always 1 for now
    num_bits: u8,
    // packed codes for all vectors: n * d * num_bits / 8 bytes
    codes: &'a [u8],
    // this is a flattened 2D array of size d/4 * 16,
    // we split the query codes into d/4 chunks, each chunk is with 4 elements,
    // then dist_table[i][j] is the distance between the i-th query code and the code j
    dist_table: Vec<f32>,
    // per-vector correction factors, computed at indexing time
    add_factors: &'a [f32],
    scale_factors: &'a [f32],
    // distance-type dependent term added to every returned distance
    query_factor: f32,

    // sum of the rotated query components
    sum_q: f32,
    // sqrt(dim * num_bits), rescales the table sum into a distance estimate
    sqrt_d: f32,
}
177
178impl<'a> RabitDistCalculator<'a> {
179    #[allow(clippy::too_many_arguments)]
180    pub fn new(
181        dim: usize,
182        num_bits: u8,
183        dist_table: Vec<f32>,
184        sum_q: f32,
185        codes: &'a [u8],
186        add_factors: &'a [f32],
187        scale_factors: &'a [f32],
188        query_factor: f32,
189    ) -> Self {
190        Self {
191            dim,
192            num_bits,
193            codes,
194            dist_table,
195            add_factors,
196            scale_factors,
197            query_factor,
198            sqrt_d: (dim as f32 * num_bits as f32).sqrt(),
199            sum_q,
200        }
201    }
202}
203
/// Lowest set bit of `x` (x must be non-zero; e.g. lowbit(0b1010) == 0b10).
#[inline]
fn lowbit(x: usize) -> usize {
    let shift = x.trailing_zeros();
    1usize << shift
}
208
209#[inline]
210pub fn build_dist_table_direct<T: ArrowFloatType>(qc: &[T::Native]) -> Vec<f32>
211where
212    T::Native: AsPrimitive<f32>,
213{
214    // every 4 bits (SEGMENT_LENGTH) is a segment, and we need to compute the distance between the segment and all the codes
215    // so there are dim/4 segments, and the number of codes is 16 (2^{SEGMENT_LENGTH}),
216    // so we have dim/4 * 16 = dim * 4 elements in the dist_table
217    let mut dist_table = vec![0.0; qc.len() * 4];
218    qc.chunks_exact(SEGMENT_LENGTH)
219        .zip(dist_table.chunks_exact_mut(SEGMENT_NUM_CODES))
220        .for_each(|(sub_vec, dist_table)| build_dist_table_for_subvec::<T>(sub_vec, dist_table));
221    dist_table
222}
223
/// Fill the 16-entry `dist_table` for one 4-element segment: entry `j` ends up
/// holding the sum of `sub_vec[k]` over every bit `k` set in `j`, built
/// incrementally from already-computed smaller subsets (subset-sum DP).
#[inline(always)]
fn build_dist_table_for_subvec<T: ArrowFloatType>(sub_vec: &[T::Native], dist_table: &mut [f32])
where
    T::Native: AsPrimitive<f32>,
{
    // skip 0 because it's always 0
    (1..SEGMENT_NUM_CODES).for_each(|j| {
        // this is a little bit tricky,
        // j represents a subset of 4 bits, that if the i-th bit of `j` is 1,
        // then we need to add the distance of the i-th dim of the segment.
        // but we don't need to check all bits of `j`,
        // because `j` = `j - lowbit(j)` + `lowbit(j)`,
        // where `j-lowbit(j)` is less than `j`,
        // which means dist_table[j-lowbit(j)] is already computed,
        // and we can use it to compute dist_table[j]
        // for example, if j = 0b1010, then j - lowbit(j) = 0b1000,
        // and dist_table[0b1000] is already computed,
        // so dist_table[0b1010] = dist_table[0b1000] + sub_vec[LOWBIT_IDX[0b1010]];
        // where lowbit(0b1010) = 0b10, LOWBIT_IDX[0b1010] = LOWBIT_IDX[0b10] = 1.
        dist_table[j] = dist_table[j - lowbit(j)] + sub_vec[LOWBIT_IDX[j]].as_();
    })
}
246
// Quantize the distance table to u8, map distance `d` to `(d-qmin) * 255 / (qmax-qmin)`
#[inline]
fn quantize_dist_table(dist_table: &[f32]) -> (f32, f32, Vec<u8>) {
    // Scan once for min/max, using the same total ordering as f32::total_cmp
    // so NaN handling is deterministic. Panics on an empty table (as before).
    let (&first, rest) = dist_table.split_first().unwrap();
    let mut qmin = first;
    let mut qmax = first;
    for &d in rest {
        if d.total_cmp(&qmin).is_lt() {
            qmin = d;
        }
        if d.total_cmp(&qmax).is_gt() {
            qmax = d;
        }
    }

    // this happens if the query is all zeros
    if qmin == qmax {
        return (qmin, qmax, vec![0; dist_table.len()]);
    }

    let factor = 255.0 / (qmax - qmin);
    let quantized = dist_table
        .iter()
        .map(|&d| ((d - qmin) * factor).round() as u8)
        .collect();
    (qmin, qmax, quantized)
}
268
/// Accumulate table-lookup distances for vectors stored in the *transposed*
/// (column-major) tail section of the packed codes, adding into
/// `dists[offset..offset + length]`.
#[inline]
fn compute_rq_distance_flat(
    dist_table: &[f32],
    codes: &[u8],
    offset: usize,
    length: usize,
    dists: &mut [f32],
) {
    // dist_table holds (d/4) tables of 16 entries each, i.e. d * 4 floats,
    // so `d` here is the vector dimension.
    let d = dist_table.len() / 4;
    let code_len = d / u8::BITS as usize;
    let codes = &codes[offset * code_len..(offset + length) * code_len];
    let dists = &mut dists[offset..offset + length];

    // The tail is transposed: each `length`-sized chunk is one byte position
    // (column) across all `length` vectors. Each byte carries two 4-bit codes,
    // looked up in tables 2*i (low nibble) and 2*i+1 (high nibble).
    for (sub_vec_idx, codes) in codes.chunks_exact(length).enumerate() {
        let current_dist_table = &dist_table
            [sub_vec_idx * 2 * SEGMENT_NUM_CODES..(sub_vec_idx * 2 + 1) * SEGMENT_NUM_CODES];
        let next_dist_table = &dist_table
            [(sub_vec_idx * 2 + 1) * SEGMENT_NUM_CODES..(sub_vec_idx * 2 + 2) * SEGMENT_NUM_CODES];

        codes.iter().zip(dists.iter_mut()).for_each(|(code, dist)| {
            let current_code = (code & 0x0F) as usize;
            let next_code = (code >> 4) as usize;
            *dist += current_dist_table[current_code] + next_dist_table[next_code];
        });
    }
}
295
impl DistCalculator for RabitDistCalculator<'_> {
    /// Distance estimate for a single vector `id`:
    /// sum per-segment table lookups, rescale by sqrt(D), then apply the
    /// per-vector scale/add factors and the query factor.
    #[inline(always)]
    fn distance(&self, id: u32) -> f32 {
        let id = id as usize;
        let code_len = self.dim * (self.num_bits as usize) / u8::BITS as usize;
        let num_vectors = self.codes.len() / code_len;
        let code = get_rq_code(self.codes, id, num_vectors, code_len);
        let dist = code
            .zip(self.dist_table.chunks_exact(SEGMENT_NUM_CODES).tuples())
            .map(|(code_byte, (dist_table, next_dist_table))| {
                // code is a bit vector, we iterate over 8 bits at a time,
                // every 4 bits is a sub-vector, we need to extract the bits
                let current_code = (code_byte & 0x0F) as usize;
                let next_code = (code_byte >> 4) as usize;
                dist_table[current_code] + next_dist_table[next_code]
            })
            .sum::<f32>();

        // distance between quantized vector and query vector
        let dist_vq_qr = (2.0 * dist - self.sum_q) / self.sqrt_d;
        dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor
    }

    /// Distances for all vectors.
    ///
    /// The full batches (multiples of BATCH_SIZE) use the SIMD kernel with a
    /// u8-quantized distance table; the tail uses the exact f32 tables via
    /// `compute_rq_distance_flat`. Quantized sums are then dequantized with
    /// `q * range + num_tables * qmin` before the final correction.
    #[inline(always)]
    fn distance_all(&self, _: usize) -> Vec<f32> {
        let code_len = self.dim * (self.num_bits as usize) / u8::BITS as usize;
        let n = self.codes.len() / code_len;
        if n == 0 {
            return Vec::new();
        }

        let mut dists = vec![0.0; n];

        let (qmin, qmax, quantized_dists_table) = quantize_dist_table(&self.dist_table);
        let mut quantized_dists = vec![0; n];

        // SIMD path over the packed (block-layout) prefix.
        let remainder = n % BATCH_SIZE;
        simd::dist_table::sum_4bit_dist_table(
            n - remainder,
            code_len,
            self.codes,
            &quantized_dists_table,
            &mut quantized_dists,
        );
        // Exact path over the transposed tail.
        if remainder > 0 {
            compute_rq_distance_flat(
                &self.dist_table,
                self.codes,
                n - remainder,
                remainder,
                &mut dists,
            );
        }

        // Dequantize: each of the `num_tables` segment lookups contributed at
        // least `qmin`, quantized with step `range`.
        let range = (qmax - qmin) / 255.0;
        let num_tables = quantized_dists_table.len() / 16;
        let sum_min = num_tables as f32 * qmin;
        dists
            .iter_mut()
            .take(n - remainder)
            .zip(quantized_dists.into_iter().take(n - remainder))
            .for_each(|(dist, q_dist)| {
                *dist = (q_dist as f32) * range + sum_min;
            });

        // Apply the same final correction as `distance`.
        dists
            .into_iter()
            .enumerate()
            .map(|(id, dist)| {
                let dist_vq_qr = (2.0 * dist - self.sum_q) / self.sqrt_d;
                dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor
            })
            .collect()
    }
}
371
impl VectorStore for RabitQuantizationStorage {
    type DistanceCalculator<'a> = RabitDistCalculator<'a>;

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.batch.schema_ref()
    }

    // The whole storage is kept as a single in-memory batch.
    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch> + Send> {
        Ok(std::iter::once(self.batch.clone()))
    }

    fn append_batch(&self, _batch: RecordBatch, _vector_column: &str) -> Result<Self> {
        unimplemented!("RabitQ does not support append_batch")
    }

    fn len(&self) -> usize {
        self.batch.num_rows()
    }

    fn row_id(&self, id: u32) -> u64 {
        self.row_ids.value(id as usize)
    }

    fn row_ids(&self) -> impl Iterator<Item = &u64> {
        self.row_ids.values().iter()
    }

    fn distance_type(&self) -> DistanceType {
        self.distance_type
    }

    // qr = (q-c)
    // Builds a distance calculator for the residual query vector `qr`, where
    // `dist_q_c` is the distance between the query and the partition centroid.
    #[inline(never)]
    fn dist_calculator(&self, qr: Arc<dyn Array>, dist_q_c: f32) -> Self::DistanceCalculator<'_> {
        let codes = self.codes.values().as_primitive::<UInt8Type>().values();
        // The rotation matrix must have been loaded via `parse_buffer` first.
        let rotate_mat = self
            .metadata
            .rotate_mat
            .as_ref()
            .expect("RabitQ metadata not loaded");

        // Rotate the query in the matrix's native float type.
        let rotated_qr = match rotate_mat.value_type() {
            DataType::Float16 => Self::rotate_query_vector::<Float16Type>(rotate_mat, &qr),
            DataType::Float32 => Self::rotate_query_vector::<Float32Type>(rotate_mat, &qr),
            DataType::Float64 => Self::rotate_query_vector::<Float64Type>(rotate_mat, &qr),
            dt => unimplemented!("RabitQ does not support data type: {}", dt),
        };

        let dist_table = build_dist_table_direct::<Float32Type>(&rotated_qr);
        let sum_q = rotated_qr.into_iter().sum();

        // Distance-type dependent constant folded into every distance.
        let q_factor = match self.distance_type {
            DistanceType::L2 => dist_q_c,
            DistanceType::Cosine | DistanceType::Dot => dist_q_c - 1.0,
            _ => unimplemented!(
                "RabitQ does not support distance type: {}",
                self.distance_type
            ),
        };
        RabitDistCalculator::new(
            qr.len(),
            self.metadata.num_bits,
            dist_table,
            sum_q,
            codes,
            self.add_factors.values(),
            self.scale_factors.values(),
            q_factor,
        )
    }

    // TODO: implement this
    // This method is required for HNSW, we can't support HNSW_RABIT before this is implemented
    fn dist_calculator_from_id(&self, _: u32) -> Self::DistanceCalculator<'_> {
        unimplemented!("RabitQ does not support dist_calculator_from_id")
    }
}
453
/// LOWBIT_IDX[j] is the index of the lowest set bit of `j` (entry 0 stays 0).
const LOWBIT_IDX: [usize; 16] = {
    let mut table = [0usize; 16];
    let mut j = 1;
    while j < 16 {
        // Find the lowest set bit by scanning upward.
        let mut bit = 0;
        while (j >> bit) & 1 == 0 {
            bit += 1;
        }
        table[j] = bit;
        j += 1;
    }
    table
};
463
/// Gather column `col_idx` from 32 consecutive row-major vectors starting at
/// `row`, writing one byte per vector into `codes`.
fn get_column(
    quantization_code: &[u8],
    code_len: usize,
    row: usize,
    col_idx: usize,
    codes: &mut [u8; 32],
) {
    for (i, slot) in codes.iter_mut().enumerate() {
        *slot = quantization_code[(row + i) * code_len + col_idx];
    }
}
476
/// Pack row-major codes into the SIMD-friendly block layout expected by
/// `sum_4bit_dist_table`; the tail (vectors beyond a multiple of 32) is stored
/// transposed. Inverse of [`unpack_codes`].
pub fn pack_codes(codes: &FixedSizeListArray) -> FixedSizeListArray {
    let code_len = codes.value_length() as usize;

    // round DOWN num of vectors to a multiple of batch size (32);
    // the remainder is handled separately below
    let num_blocks = codes.len() / BATCH_SIZE;
    let num_packed_vectors = num_blocks * BATCH_SIZE;

    // calculate total size for packed blocks
    // we pack each 32 vectors into a block, each block contains 2 codes (1byte) of each vector
    // so every 32 vectors would produce code_len blocks
    // the low 16 bytes of each block is the codes for the low 4 bits of each vector
    // the high 16 bytes of each block is the codes for the high 4 bits of each vector
    let mut blocks = vec![0u8; codes.values().len()];

    let codes_values = codes
        .slice(0, num_packed_vectors)
        .values()
        .as_primitive::<UInt8Type>()
        .clone();
    let codes_values = codes_values.values();

    // Pack codes batch by batch
    // Each batch contains codes for 32 vectors
    let mut col = [0u8; 32];
    let mut col_0 = [0u8; 32]; // lower 4 bits
    let mut col_1 = [0u8; 32]; // higher 4 bits
    for row in (0..num_packed_vectors).step_by(BATCH_SIZE) {
        // Get quantization codes for each column for each batch
        // i.e., we get the codes for 8 dims of 32 vectors and reorganize the data layout
        // based on the shuffle SIMD instruction used during querying
        for i in 0..code_len {
            get_column(codes_values, code_len, row, i, &mut col);

            // Split every byte into its two nibbles.
            for j in 0..32 {
                col_0[j] = col[j] & 0xF;
                col_1[j] = col[j] >> 4;
            }

            let block_offset = (row / BATCH_SIZE) * code_len * BATCH_SIZE + i * BATCH_SIZE;
            for j in 0..16 {
                // The lower 4 bits represent vector 0 to 15
                // The upper 4 bits represent vector 16 to 31
                let val0 = col_0[PERM0[j]] | (col_0[PERM0[j] + 16] << 4);
                let val1 = col_1[PERM0[j]] | (col_1[PERM0[j] + 16] << 4);
                blocks[block_offset + j] = val0;
                blocks[block_offset + j + 16] = val1;
            }
        }
    }

    // for the left codes, transpose them for better cache locality
    let transposed_codes = transpose(
        &codes.values().as_primitive::<UInt8Type>().slice(
            num_packed_vectors * code_len,
            (codes.len() - num_packed_vectors) * code_len,
        ),
        codes.len() - num_packed_vectors,
        code_len,
    );

    // The transposed tail goes right after the packed blocks.
    let offset = codes.values().len() - transposed_codes.len();
    for (i, v) in transposed_codes.values().iter().enumerate() {
        blocks[offset + i] = *v;
    }

    assert_eq!(blocks.len(), codes.values().len());
    FixedSizeListArray::try_new_from_values(UInt8Array::from(blocks), code_len as i32).unwrap()
}
545
// Inverse of pack_codes: restores the plain row-major code layout.
pub fn unpack_codes(codes: &FixedSizeListArray) -> FixedSizeListArray {
    let code_len = codes.value_length() as usize;
    let num_vectors = codes.len();

    // Calculate number of complete batches (the rest is stored transposed)
    let num_blocks = num_vectors / BATCH_SIZE;
    let num_packed_vectors = num_blocks * BATCH_SIZE;

    let mut unpacked = vec![0u8; codes.values().len()];

    let codes_values = codes.values().as_primitive::<UInt8Type>().values();

    // Unpack complete batches
    for batch_idx in 0..num_blocks {
        let block_start = batch_idx * code_len * BATCH_SIZE;

        for i in 0..code_len {
            let block_offset = block_start + i * BATCH_SIZE;
            let block = &codes_values[block_offset..block_offset + BATCH_SIZE];

            // Reverse the permutation: the first 16 bytes of a block hold the
            // low nibbles, the last 16 bytes the high nibbles (see pack_codes)
            for j in 0..16 {
                let val0 = block[j];
                let val1 = block[j + 16];

                let low_0 = val0 & 0xF;
                let high_0 = val0 >> 4;
                let low_1 = val1 & 0xF;
                let high_1 = val1 >> 4;

                // Lower nibbles belong to vectors 0..16 of the batch,
                // upper nibbles to vectors 16..32.
                let vec_idx_0 = batch_idx * BATCH_SIZE + PERM0[j];
                let vec_idx_1 = batch_idx * BATCH_SIZE + PERM0[j] + 16;

                unpacked[vec_idx_0 * code_len + i] = low_0 | (low_1 << 4);
                unpacked[vec_idx_1 * code_len + i] = high_0 | (high_1 << 4);
            }
        }
    }

    // Transpose back the remainder
    if num_packed_vectors < num_vectors {
        let remainder = num_vectors - num_packed_vectors;
        let offset = num_packed_vectors * code_len;
        let transposed_data = &codes_values[offset..];

        // Transpose from column-major back to row-major
        for row in 0..remainder {
            for col in 0..code_len {
                unpacked[offset + row * code_len + col] = transposed_data[col * remainder + row];
            }
        }
    }

    FixedSizeListArray::try_new_from_values(UInt8Array::from(unpacked), code_len as i32).unwrap()
}
602
603#[async_trait]
604impl QuantizerStorage for RabitQuantizationStorage {
605    type Metadata = RabitQuantizationMetadata;
606
607    fn try_from_batch(
608        batch: RecordBatch,
609        metadata: &Self::Metadata,
610        distance_type: DistanceType,
611        _fri: Option<Arc<FragReuseIndex>>,
612    ) -> Result<Self> {
613        let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().clone();
614        let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
615        let add_factors = batch[ADD_FACTORS_COLUMN]
616            .as_primitive::<Float32Type>()
617            .clone();
618        let scale_factors = batch[SCALE_FACTORS_COLUMN]
619            .as_primitive::<Float32Type>()
620            .clone();
621
622        let (batch, codes) = if !metadata.packed {
623            let codes = pack_codes(&codes);
624            let batch = batch.replace_column_by_name(RABIT_CODE_COLUMN, Arc::new(codes))?;
625            let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
626            (batch, codes)
627        } else {
628            (batch, codes)
629        };
630
631        let mut metadata = metadata.clone();
632        metadata.packed = true;
633
634        Ok(Self {
635            metadata,
636            batch,
637            distance_type,
638            row_ids,
639            codes,
640            add_factors,
641            scale_factors,
642        })
643    }
644
645    fn metadata(&self) -> &Self::Metadata {
646        &self.metadata
647    }
648
649    async fn load_partition(
650        reader: &PreviousFileReader,
651        range: std::ops::Range<usize>,
652        distance_type: DistanceType,
653        metadata: &Self::Metadata,
654        frag_reuse_index: Option<Arc<FragReuseIndex>>,
655    ) -> Result<Self> {
656        let schema = reader.schema();
657        let batch = reader.read_range(range, schema).await?;
658        Self::try_from_batch(batch, metadata, distance_type, frag_reuse_index)
659    }
660
661    fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
662        let num_vectors = self.codes.len();
663        let num_code_bytes = self.codes.value_length() as usize;
664        let codes = self.codes.values().as_primitive::<UInt8Type>().values();
665        let mut indices = Vec::with_capacity(num_vectors);
666        let mut new_row_ids = Vec::with_capacity(num_vectors);
667        let mut new_codes = Vec::with_capacity(codes.len());
668
669        let row_ids = self.row_ids.values();
670        for (i, row_id) in row_ids.iter().enumerate() {
671            match mapping.get(row_id) {
672                Some(Some(new_id)) => {
673                    indices.push(i as u32);
674                    new_row_ids.push(*new_id);
675                    new_codes.extend(get_rq_code(codes, i, num_vectors, num_code_bytes));
676                }
677                Some(None) => {}
678                None => {
679                    indices.push(i as u32);
680                    new_row_ids.push(*row_id);
681                    new_codes.extend(get_rq_code(codes, i, num_vectors, num_code_bytes));
682                }
683            }
684        }
685
686        let new_row_ids = UInt64Array::from(new_row_ids);
687        let new_codes = FixedSizeListArray::try_new_from_values(
688            UInt8Array::from(new_codes),
689            num_code_bytes as i32,
690        )?;
691        let batch = if new_row_ids.is_empty() {
692            RecordBatch::new_empty(self.schema().clone())
693        } else {
694            let codes = Arc::new(pack_codes(&new_codes));
695            self.batch
696                .take(&UInt32Array::from(indices))?
697                .replace_column_by_name(ROW_ID, Arc::new(new_row_ids.clone()))?
698                .replace_column_by_name(RABIT_CODE_COLUMN, codes)?
699        };
700        let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
701
702        Ok(Self {
703            metadata: self.metadata.clone(),
704            distance_type: self.distance_type,
705            batch,
706            codes,
707            add_factors: self.add_factors.clone(),
708            scale_factors: self.scale_factors.clone(),
709            row_ids: new_row_ids,
710        })
711    }
712}
713
/// Extract the `num_code_bytes`-byte code of vector `id` from the packed
/// layout produced by [`pack_codes`]: full batches of 32 vectors are stored
/// nibble-interleaved with the PERM0 shuffle, the tail is stored transposed.
#[inline]
fn get_rq_code(
    codes: &[u8],
    id: usize,
    num_vectors: usize,
    num_code_bytes: usize,
) -> impl Iterator<Item = u8> + '_ {
    let remainder = num_vectors % BATCH_SIZE;

    if id < num_vectors - remainder {
        // the codes are packed; narrow down to this vector's 32-vector block
        let codes = &codes[id / BATCH_SIZE * BATCH_SIZE * num_code_bytes
            ..(id / BATCH_SIZE + 1) * BATCH_SIZE * num_code_bytes];

        let id_in_batch = id % BATCH_SIZE;
        if id_in_batch < 16 {
            // Vectors 0..16 of a batch: their low nibble lives in the first
            // 16 bytes of each block, their high nibble in the last 16 bytes.
            let idx = PERM0_INVERSE[id_in_batch];
            codes
                .chunks_exact(BATCH_SIZE)
                .map(|block| (block[idx] & 0xF) | (block[idx + 16] << 4))
                .exact_size(num_code_bytes)
                .collect_vec()
                .into_iter()
        } else {
            // Vectors 16..32: stored in the upper nibbles of the same bytes.
            let idx = PERM0_INVERSE[id_in_batch - 16];
            codes
                .chunks_exact(BATCH_SIZE)
                .map(|block| (block[idx] >> 4) | (block[idx + 16] & 0xF0))
                .exact_size(num_code_bytes)
                .collect_vec()
                .into_iter()
        }
    } else {
        // Tail section: codes are transposed (column-major), so consecutive
        // bytes of one vector are `remainder` apart.
        let id = id - (num_vectors - remainder);
        let codes = &codes[(num_vectors - remainder) * num_code_bytes..];
        codes
            .iter()
            .skip(id)
            .step_by(remainder)
            .copied()
            .exact_size(num_code_bytes)
            .collect_vec()
            .into_iter()
    }
}
759
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference implementation: entry `j` of the table accumulates the
    /// sub-vector elements whose bit is set in `j`.
    fn build_dist_table_not_optimized<T: ArrowFloatType>(
        sub_vec: &[T::Native],
        dist_table: &mut [f32],
    ) where
        T::Native: AsPrimitive<f32>,
    {
        for j in 0..SEGMENT_NUM_CODES {
            let mut acc = 0.0f32;
            for (k, v) in sub_vec.iter().enumerate().take(SEGMENT_LENGTH) {
                if (j >> k) & 1 == 1 {
                    acc += v.as_();
                }
            }
            dist_table[j] += acc;
        }
    }

    #[test]
    fn test_build_dist_table_not_optimized() {
        let sub_vec = vec![1.0, 2.0, 3.0, 4.0];

        let mut expected = vec![0.0; SEGMENT_NUM_CODES];
        build_dist_table_not_optimized::<Float32Type>(&sub_vec, &mut expected);

        let mut actual = vec![0.0; SEGMENT_NUM_CODES];
        build_dist_table_for_subvec::<Float32Type>(&sub_vec, &mut actual);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_pack_unpack_codes() {
        // Cover both the packed (multiples of 32) and transposed-tail paths
        for num_vectors in [10usize, 32, 50, 64, 100] {
            let code_len = 8usize;

            // Known, distinct byte pattern per (vector, dim) position
            let codes_data: Vec<u8> = (0..num_vectors * code_len).map(|v| v as u8).collect();

            let original_codes = FixedSizeListArray::try_new_from_values(
                UInt8Array::from(codes_data),
                code_len as i32,
            )
            .unwrap();

            // Round-trip through pack and unpack
            let unpacked = unpack_codes(&pack_codes(&original_codes));

            assert_eq!(original_codes.len(), unpacked.len());
            assert_eq!(original_codes.value_length(), unpacked.value_length());

            assert_eq!(
                original_codes.values().as_primitive::<UInt8Type>().values(),
                unpacked.values().as_primitive::<UInt8Type>().values(),
                "Mismatch for num_vectors={}",
                num_vectors
            );
        }
    }
}