Skip to main content

lance_index/vector/bq/
storage.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use arrow::array::AsArray;
8use arrow::datatypes::{Float16Type, Float32Type, Float64Type, UInt8Type, UInt64Type};
9use arrow_array::{
10    Array, FixedSizeListArray, Float32Array, RecordBatch, UInt8Array, UInt32Array, UInt64Array,
11};
12use arrow_schema::{DataType, SchemaRef};
13use async_trait::async_trait;
14use bytes::{Bytes, BytesMut};
15use deepsize::DeepSizeOf;
16use itertools::Itertools;
17use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, RecordBatchExt};
18use lance_core::{Error, ROW_ID, Result};
19use lance_file::previous::reader::FileReader as PreviousFileReader;
20use lance_linalg::distance::{DistanceType, Dot};
21use lance_linalg::simd::dist_table::{BATCH_SIZE, PERM0, PERM0_INVERSE};
22use lance_linalg::simd::{self};
23use lance_table::utils::LanceIteratorExtension;
24use num_traits::AsPrimitive;
25use prost::Message;
26use serde::{Deserialize, Serialize};
27
28use crate::frag_reuse::FragReuseIndex;
29use crate::pb;
30use crate::vector::bq::RQRotationType;
31use crate::vector::bq::rotation::apply_fast_rotation;
32use crate::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN};
33use crate::vector::pq::storage::transpose;
34use crate::vector::quantizer::{QuantizerMetadata, QuantizerStorage};
35use crate::vector::storage::{DistCalculator, VectorStore};
36
/// Schema-metadata key under which the RabitQ quantizer metadata is looked up.
pub const RABIT_METADATA_KEY: &str = "lance:rabit";
/// Name of the record-batch column that holds the RabitQ codes.
pub const RABIT_CODE_COLUMN: &str = "_rabit_codes";
/// Number of code bits grouped into one distance-table segment.
pub const SEGMENT_LENGTH: usize = 4;
/// Number of distinct bit patterns one segment can take (`2^SEGMENT_LENGTH`).
pub const SEGMENT_NUM_CODES: usize = 2usize.pow(SEGMENT_LENGTH as u32);
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct RabitQuantizationMetadata {
44    // this rotate matrix is large, and lance index would store all metadata in schema metadata,
45    // which is in JSON format, so we skip it in serialization and deserialization, and store it
46    // in the global buffer, which is a binary format (protobuf for now) for efficiency.
47    #[serde(skip)]
48    pub rotate_mat: Option<FixedSizeListArray>,
49    #[serde(default)]
50    pub rotate_mat_position: Option<u32>,
51    #[serde(default)]
52    pub fast_rotation_signs: Option<Vec<u8>>,
53    #[serde(default = "default_rotation_type_compat")]
54    pub rotation_type: RQRotationType,
55    #[serde(default)]
56    pub code_dim: u32,
57    pub num_bits: u8,
58    pub packed: bool,
59}
60
61fn default_rotation_type_compat() -> RQRotationType {
62    // Older metadata does not have this field and always used dense matrices.
63    RQRotationType::Matrix
64}
65
66impl DeepSizeOf for RabitQuantizationMetadata {
67    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
68        self.rotate_mat
69            .as_ref()
70            .map(|inv_p| inv_p.get_array_memory_size())
71            .unwrap_or(0)
72            + self
73                .fast_rotation_signs
74                .as_ref()
75                .map(|signs| signs.len())
76                .unwrap_or(0)
77    }
78}
79
80#[async_trait]
81impl QuantizerMetadata for RabitQuantizationMetadata {
82    fn buffer_index(&self) -> Option<u32> {
83        match self.rotation_type {
84            RQRotationType::Matrix => self.rotate_mat_position,
85            RQRotationType::Fast => None,
86        }
87    }
88
89    fn set_buffer_index(&mut self, index: u32) {
90        self.rotate_mat_position = Some(index);
91    }
92
93    fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> {
94        if self.rotation_type != RQRotationType::Matrix {
95            return Ok(());
96        }
97        debug_assert!(!bytes.is_empty());
98        let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?;
99        self.rotate_mat = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
100        if self.code_dim == 0 {
101            self.code_dim = self
102                .rotate_mat
103                .as_ref()
104                .map(|rotate_mat| rotate_mat.len() as u32)
105                .unwrap_or(0);
106        }
107        Ok(())
108    }
109
110    fn extra_metadata(&self) -> Result<Option<Bytes>> {
111        match self.rotation_type {
112            RQRotationType::Matrix => {
113                if let Some(inv_p) = &self.rotate_mat {
114                    let inv_p_tensor = pb::Tensor::try_from(inv_p)?;
115                    let mut bytes = BytesMut::new();
116                    inv_p_tensor.encode(&mut bytes)?;
117                    Ok(Some(bytes.freeze()))
118                } else {
119                    Ok(None)
120                }
121            }
122            RQRotationType::Fast => Ok(None),
123        }
124    }
125
126    async fn load(reader: &PreviousFileReader) -> Result<Self> {
127        let metadata_str = reader
128            .schema()
129            .metadata
130            .get(RABIT_METADATA_KEY)
131            .ok_or(Error::index(format!(
132                "Reading Rabit metadata: metadata key {} not found",
133                RABIT_METADATA_KEY
134            )))?;
135        serde_json::from_str(metadata_str)
136            .map_err(|_| Error::index(format!("Failed to parse index metadata: {}", metadata_str)))
137    }
138}
139
140#[derive(Debug, Clone)]
141pub struct RabitQuantizationStorage {
142    metadata: RabitQuantizationMetadata,
143    batch: RecordBatch,
144    distance_type: DistanceType,
145
146    // helper fields
147    row_ids: UInt64Array,
148    codes: FixedSizeListArray,
149    add_factors: Float32Array,
150    scale_factors: Float32Array,
151}
152
153impl DeepSizeOf for RabitQuantizationStorage {
154    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
155        self.metadata.deep_size_of_children(context) + self.batch.get_array_memory_size()
156    }
157}
158
159impl RabitQuantizationStorage {
160    fn rotate_query_vector_dense<T: ArrowFloatType>(
161        rotate_mat: &FixedSizeListArray,
162        qr: &dyn Array,
163    ) -> Vec<f32>
164    where
165        T::Native: Dot,
166    {
167        let d = qr.len();
168        let code_dim = rotate_mat.len();
169        let rotate_mat = rotate_mat
170            .values()
171            .as_any()
172            .downcast_ref::<T::ArrayType>()
173            .unwrap()
174            .as_slice();
175
176        let qr = qr
177            .as_any()
178            .downcast_ref::<T::ArrayType>()
179            .unwrap()
180            .as_slice();
181
182        rotate_mat
183            .chunks_exact(code_dim)
184            .map(|chunk| lance_linalg::distance::dot(&chunk[..d], qr))
185            .collect()
186    }
187
188    fn rotate_query_vector_fast<T: ArrowFloatType>(
189        code_dim: usize,
190        signs: &[u8],
191        qr: &dyn Array,
192    ) -> Vec<f32>
193    where
194        T::Native: AsPrimitive<f32>,
195    {
196        let qr = qr
197            .as_any()
198            .downcast_ref::<T::ArrayType>()
199            .unwrap()
200            .as_slice();
201
202        let mut output = vec![0.0f32; code_dim];
203        apply_fast_rotation(qr, &mut output, signs);
204        output
205    }
206}
207
/// Distance calculator between one (rotated) query and the stored RabitQ codes.
pub struct RabitDistCalculator<'a> {
    /// Dimension used to derive the per-vector code length.
    dim: usize,
    /// Number of bits per dimension; always 1 for now.
    num_bits: u8,
    /// Codes of all vectors: `n * d * num_bits / 8` bytes.
    codes: &'a [u8],
    /// Flattened 2D table of size `d/4 * 16`.
    ///
    /// The query is split into `d/4` chunks of 4 elements each, and
    /// `dist_table[i][j]` is the distance between the i-th query chunk and
    /// the code `j`.
    dist_table: Vec<f32>,
    /// Per-vector additive factor.
    add_factors: &'a [f32],
    /// Per-vector multiplicative factor.
    scale_factors: &'a [f32],
    /// Query-dependent constant added to every distance.
    query_factor: f32,

    /// Sum of the query components the table was built from.
    sum_q: f32,
    /// `sqrt(dim * num_bits)`, cached at construction.
    sqrt_d: f32,
}
226
227impl<'a> RabitDistCalculator<'a> {
228    #[allow(clippy::too_many_arguments)]
229    pub fn new(
230        dim: usize,
231        num_bits: u8,
232        dist_table: Vec<f32>,
233        sum_q: f32,
234        codes: &'a [u8],
235        add_factors: &'a [f32],
236        scale_factors: &'a [f32],
237        query_factor: f32,
238    ) -> Self {
239        Self {
240            dim,
241            num_bits,
242            codes,
243            dist_table,
244            add_factors,
245            scale_factors,
246            query_factor,
247            sqrt_d: (dim as f32 * num_bits as f32).sqrt(),
248            sum_q,
249        }
250    }
251}
252
/// Returns the value of the lowest set bit of `x` (e.g. `0b1010 -> 0b10`).
///
/// Returns `0` when `x` is `0`. The two's-complement form is used instead of
/// `1 << x.trailing_zeros()` because for `x == 0` that shift amount equals the
/// bit width, which panics in debug builds and wraps in release builds.
#[inline]
fn lowbit(x: usize) -> usize {
    x & x.wrapping_neg()
}
257
258#[inline]
259pub fn build_dist_table_direct<T: ArrowFloatType>(qc: &[T::Native]) -> Vec<f32>
260where
261    T::Native: AsPrimitive<f32>,
262{
263    // every 4 bits (SEGMENT_LENGTH) is a segment, and we need to compute the distance between the segment and all the codes
264    // so there are dim/4 segments, and the number of codes is 16 (2^{SEGMENT_LENGTH}),
265    // so we have dim/4 * 16 = dim * 4 elements in the dist_table
266    let mut dist_table = vec![0.0; qc.len() * 4];
267    qc.chunks_exact(SEGMENT_LENGTH)
268        .zip(dist_table.chunks_exact_mut(SEGMENT_NUM_CODES))
269        .for_each(|(sub_vec, dist_table)| build_dist_table_for_subvec::<T>(sub_vec, dist_table));
270    dist_table
271}
272
273#[inline(always)]
274fn build_dist_table_for_subvec<T: ArrowFloatType>(sub_vec: &[T::Native], dist_table: &mut [f32])
275where
276    T::Native: AsPrimitive<f32>,
277{
278    // skip 0 because it's always 0
279    (1..SEGMENT_NUM_CODES).for_each(|j| {
280        // this is a little bit tricky,
281        // j represents a subset of 4 bits, that if the i-th bit of `j` is 1,
282        // then we need to add the distance of the i-th dim of the segment.
283        // but we don't need to check all bits of `j`,
284        // because `j` = `j - lowbit(j)` + `lowbit(j)`,
285        // where `j-lowbit(j)` is less than `j`,
286        // which means dist_table[j-lowbit(j)] is already computed,
287        // and we can use it to compute dist_table[j]
288        // for example, if j = 0b1010, then j - lowbit(j) = 0b1000,
289        // and dist_table[0b1000] is already computed,
290        // so dist_table[0b1010] = dist_table[0b1000] + sub_vec[LOWBIT_IDX[0b1010]];
291        // where lowbit(0b1010) = 0b10, LOWBIT_IDX[0b1010] = LOWBIT_IDX[0b10] = 1.
292        dist_table[j] = dist_table[j - lowbit(j)] + sub_vec[LOWBIT_IDX[j]].as_();
293    })
294}
295
/// Quantizes the distance table to `u8` into a caller-owned buffer.
///
/// `quantized_dist_table` is resized to `dist_table.len()` and every entry is
/// mapped linearly from `[qmin, qmax]` onto `[0, 255]`. Returns `(qmin, qmax)`
/// so the caller can undo the mapping.
///
/// Two inputs produce an all-zero output:
/// * an empty table, for which `(0.0, 0.0)` is returned instead of panicking;
/// * a constant table (`qmin == qmax`), e.g. when the query is all zeros.
#[inline]
fn quantize_dist_table_into(dist_table: &[f32], quantized_dist_table: &mut Vec<u8>) -> (f32, f32) {
    quantized_dist_table.clear();
    quantized_dist_table.resize(dist_table.len(), 0);

    let Some(&first) = dist_table.first() else {
        return (0.0, 0.0);
    };

    // `total_cmp` gives a total order even in the presence of NaN.
    let (qmin, qmax) = dist_table
        .iter()
        .skip(1)
        .fold((first, first), |(lo, hi), &d| {
            (
                if d.total_cmp(&lo).is_lt() { d } else { lo },
                if d.total_cmp(&hi).is_ge() { d } else { hi },
            )
        });

    // A constant table has no range to scale by; leave the output zeroed.
    if qmin == qmax {
        return (qmin, qmax);
    }

    let factor = 255.0 / (qmax - qmin);
    for (quantized, &d) in quantized_dist_table.iter_mut().zip(dist_table) {
        *quantized = ((d - qmin) * factor).round() as u8;
    }

    (qmin, qmax)
}
321
322impl DistCalculator for RabitDistCalculator<'_> {
323    #[inline(always)]
324    fn distance(&self, id: u32) -> f32 {
325        let id = id as usize;
326        let code_len = self.dim * (self.num_bits as usize) / u8::BITS as usize;
327        let num_vectors = self.codes.len() / code_len;
328        let dist =
329            compute_single_rq_distance(self.codes, id, num_vectors, code_len, &self.dist_table);
330
331        // distance between quantized vector and query vector
332        let dist_vq_qr = (2.0 * dist - self.sum_q) / self.sqrt_d;
333        dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor
334    }
335
336    #[inline(always)]
337    fn distance_all(&self, _: usize) -> Vec<f32> {
338        let mut dists = Vec::new();
339        let mut quantized_dists = Vec::new();
340        let mut quantized_dists_table = Vec::new();
341        self.distance_all_with_scratch(
342            0,
343            &mut dists,
344            &mut quantized_dists,
345            &mut quantized_dists_table,
346        );
347        dists
348    }
349
350    #[inline(always)]
351    fn distance_all_with_scratch(
352        &self,
353        _: usize,
354        dists: &mut Vec<f32>,
355        quantized_dists: &mut Vec<u16>,
356        quantized_dists_table: &mut Vec<u8>,
357    ) {
358        let code_len = self.dim * (self.num_bits as usize) / u8::BITS as usize;
359        let n = self.codes.len() / code_len;
360        if n == 0 {
361            dists.clear();
362            quantized_dists.clear();
363            return;
364        }
365
366        dists.clear();
367        dists.resize(n, 0.0);
368        let (qmin, qmax) = quantize_dist_table_into(&self.dist_table, quantized_dists_table);
369        quantized_dists.clear();
370        quantized_dists.resize(n, 0);
371
372        let remainder = n % BATCH_SIZE;
373        simd::dist_table::sum_4bit_dist_table(
374            n - remainder,
375            code_len,
376            self.codes,
377            quantized_dists_table,
378            quantized_dists,
379        );
380
381        let range = (qmax - qmin) / 255.0;
382        let num_tables = quantized_dists_table.len() / 16;
383        let sum_min = num_tables as f32 * qmin;
384        dists
385            .iter_mut()
386            .take(n - remainder)
387            .zip(quantized_dists.iter().take(n - remainder))
388            .for_each(|(dist, q_dist)| {
389                *dist = (*q_dist as f32) * range + sum_min;
390            });
391
392        dists
393            .iter_mut()
394            .enumerate()
395            .take(n - remainder)
396            .for_each(|(id, dist)| {
397                let dist_vq_qr = (2.0 * *dist - self.sum_q) / self.sqrt_d;
398                *dist =
399                    dist_vq_qr * self.scale_factors[id] + self.add_factors[id] + self.query_factor;
400            });
401
402        dists
403            .iter_mut()
404            .enumerate()
405            .skip(n - remainder)
406            .for_each(|(id, dist)| {
407                *dist = self.distance(id as u32);
408            });
409    }
410}
411
412impl VectorStore for RabitQuantizationStorage {
413    type DistanceCalculator<'a> = RabitDistCalculator<'a>;
414
415    fn as_any(&self) -> &dyn std::any::Any {
416        self
417    }
418
419    fn schema(&self) -> &SchemaRef {
420        self.batch.schema_ref()
421    }
422
423    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch> + Send> {
424        Ok(std::iter::once(self.batch.clone()))
425    }
426
427    fn append_batch(&self, _batch: RecordBatch, _vector_column: &str) -> Result<Self> {
428        unimplemented!("RabitQ does not support append_batch")
429    }
430
431    fn len(&self) -> usize {
432        self.batch.num_rows()
433    }
434
435    fn row_id(&self, id: u32) -> u64 {
436        self.row_ids.value(id as usize)
437    }
438
439    fn row_ids(&self) -> impl Iterator<Item = &u64> {
440        self.row_ids.values().iter()
441    }
442
443    fn distance_type(&self) -> DistanceType {
444        self.distance_type
445    }
446
447    // qr = (q-c)
448    #[inline(never)]
449    fn dist_calculator(&self, qr: Arc<dyn Array>, dist_q_c: f32) -> Self::DistanceCalculator<'_> {
450        let codes = self.codes.values().as_primitive::<UInt8Type>().values();
451        let code_dim = if self.metadata.code_dim > 0 {
452            self.metadata.code_dim as usize
453        } else {
454            self.metadata
455                .rotate_mat
456                .as_ref()
457                .map(|rotate_mat| rotate_mat.len())
458                .unwrap_or_default()
459        };
460
461        let rotated_qr = match self.metadata.rotation_type {
462            RQRotationType::Matrix => {
463                let rotate_mat = self
464                    .metadata
465                    .rotate_mat
466                    .as_ref()
467                    .expect("RabitQ dense rotation metadata not loaded");
468
469                match rotate_mat.value_type() {
470                    DataType::Float16 => {
471                        Self::rotate_query_vector_dense::<Float16Type>(rotate_mat, &qr)
472                    }
473                    DataType::Float32 => {
474                        Self::rotate_query_vector_dense::<Float32Type>(rotate_mat, &qr)
475                    }
476                    DataType::Float64 => {
477                        Self::rotate_query_vector_dense::<Float64Type>(rotate_mat, &qr)
478                    }
479                    dt => unimplemented!("RabitQ does not support data type: {}", dt),
480                }
481            }
482            RQRotationType::Fast => {
483                let signs = self
484                    .metadata
485                    .fast_rotation_signs
486                    .as_ref()
487                    .expect("RabitQ fast rotation metadata not loaded");
488                match qr.data_type() {
489                    DataType::Float16 => {
490                        Self::rotate_query_vector_fast::<Float16Type>(code_dim, signs, &qr)
491                    }
492                    DataType::Float32 => {
493                        Self::rotate_query_vector_fast::<Float32Type>(code_dim, signs, &qr)
494                    }
495                    DataType::Float64 => {
496                        Self::rotate_query_vector_fast::<Float64Type>(code_dim, signs, &qr)
497                    }
498                    dt => unimplemented!("RabitQ does not support data type: {}", dt),
499                }
500            }
501        };
502
503        let dist_table = build_dist_table_direct::<Float32Type>(&rotated_qr);
504        let sum_q = rotated_qr.into_iter().sum();
505
506        let q_factor = match self.distance_type {
507            DistanceType::L2 => dist_q_c,
508            DistanceType::Cosine | DistanceType::Dot => dist_q_c - 1.0,
509            _ => unimplemented!(
510                "RabitQ does not support distance type: {}",
511                self.distance_type
512            ),
513        };
514        RabitDistCalculator::new(
515            qr.len(),
516            self.metadata.num_bits,
517            dist_table,
518            sum_q,
519            codes,
520            self.add_factors.values(),
521            self.scale_factors.values(),
522            q_factor,
523        )
524    }
525
526    // TODO: implement this
527    // This method is required for HNSW, we can't support HNSW_RABIT before this is implemented
528    fn dist_calculator_from_id(&self, _: u32) -> Self::DistanceCalculator<'_> {
529        unimplemented!("RabitQ does not support dist_calculator_from_id")
530    }
531}
532
/// `LOWBIT_IDX[j]` is the position of the lowest set bit of `j` for
/// `j` in `1..16`; entry 0 is a placeholder and holds 0.
const LOWBIT_IDX: [usize; 16] = [
    0, // 0b0000 (no bit set)
    0, // 0b0001
    1, // 0b0010
    0, // 0b0011
    2, // 0b0100
    0, // 0b0101
    1, // 0b0110
    0, // 0b0111
    3, // 0b1000
    0, // 0b1001
    1, // 0b1010
    0, // 0b1011
    2, // 0b1100
    0, // 0b1101
    1, // 0b1110
    0, // 0b1111
];
542
/// Gathers byte `col_idx` of 32 consecutive codes, starting at vector `row`,
/// into `codes`.
///
/// `quantization_code` is read as a row-major matrix with `code_len` bytes
/// per vector.
///
/// # Panics
/// Panics when any of the 32 source positions is out of bounds.
fn get_column(
    quantization_code: &[u8],
    code_len: usize,
    row: usize,
    col_idx: usize,
    codes: &mut [u8; 32],
) {
    for (slot, vec_idx) in codes.iter_mut().zip(row..) {
        *slot = quantization_code[vec_idx * code_len + col_idx];
    }
}
555
556pub fn pack_codes(codes: &FixedSizeListArray) -> FixedSizeListArray {
557    let code_len = codes.value_length() as usize;
558
559    // round up num of vectors to multiple of batch size (32)
560    let num_blocks = codes.len() / BATCH_SIZE;
561    let num_packed_vectors = num_blocks * BATCH_SIZE;
562
563    // calculate total size for packed blocks
564    // we pack each 32 vectors into a block, each block contains 2 codes (1byte) of each vector
565    // so every 32 vectors would produce code_len blocks
566    // the low 16 bytes of each block is the codes for the low 4 bits of each vector
567    // the high 16 bytes of each block is the codes for the high 4 bits of each vector
568    let mut blocks = vec![0u8; codes.values().len()];
569
570    let codes_values = codes
571        .slice(0, num_packed_vectors)
572        .values()
573        .as_primitive::<UInt8Type>()
574        .clone();
575    let codes_values = codes_values.values();
576
577    // Pack codes batch by batch
578    // Each batch contains codes for 32 vectors
579    let mut col = [0u8; 32];
580    let mut col_0 = [0u8; 32]; // lower 4 bits
581    let mut col_1 = [0u8; 32]; // higher 4 bits
582    for row in (0..num_packed_vectors).step_by(BATCH_SIZE) {
583        // Get quantization codes for each column for each batch
584        // i.e., we get the codes for 8 dims of 32 vectors and reorganize the data layout
585        // based on the shuffle SIMD instruction used during querying
586        for i in 0..code_len {
587            get_column(codes_values, code_len, row, i, &mut col);
588
589            for j in 0..32 {
590                col_0[j] = col[j] & 0xF;
591                col_1[j] = col[j] >> 4;
592            }
593
594            let block_offset = (row / BATCH_SIZE) * code_len * BATCH_SIZE + i * BATCH_SIZE;
595            for j in 0..16 {
596                // The lower 4 bits represent vector 0 to 15
597                // The upper 4 bits represent vector 16 to 31
598                let val0 = col_0[PERM0[j]] | (col_0[PERM0[j] + 16] << 4);
599                let val1 = col_1[PERM0[j]] | (col_1[PERM0[j] + 16] << 4);
600                blocks[block_offset + j] = val0;
601                blocks[block_offset + j + 16] = val1;
602            }
603        }
604    }
605
606    // for the left codes, transpose them for better cache locality
607    let transposed_codes = transpose(
608        &codes.values().as_primitive::<UInt8Type>().slice(
609            num_packed_vectors * code_len,
610            (codes.len() - num_packed_vectors) * code_len,
611        ),
612        codes.len() - num_packed_vectors,
613        code_len,
614    );
615
616    let offset = codes.values().len() - transposed_codes.len();
617    for (i, v) in transposed_codes.values().iter().enumerate() {
618        blocks[offset + i] = *v;
619    }
620
621    assert_eq!(blocks.len(), codes.values().len());
622    FixedSizeListArray::try_new_from_values(UInt8Array::from(blocks), code_len as i32).unwrap()
623}
624
625// Inverse of pack_codes
626pub fn unpack_codes(codes: &FixedSizeListArray) -> FixedSizeListArray {
627    let code_len = codes.value_length() as usize;
628    let num_vectors = codes.len();
629
630    // Calculate number of complete batches
631    let num_blocks = num_vectors / BATCH_SIZE;
632    let num_packed_vectors = num_blocks * BATCH_SIZE;
633
634    let mut unpacked = vec![0u8; codes.values().len()];
635
636    let codes_values = codes.values().as_primitive::<UInt8Type>().values();
637
638    // Unpack complete batches
639    for batch_idx in 0..num_blocks {
640        let block_start = batch_idx * code_len * BATCH_SIZE;
641
642        for i in 0..code_len {
643            let block_offset = block_start + i * BATCH_SIZE;
644            let block = &codes_values[block_offset..block_offset + BATCH_SIZE];
645
646            // Reverse the permutation
647            for j in 0..16 {
648                let val0 = block[j];
649                let val1 = block[j + 16];
650
651                let low_0 = val0 & 0xF;
652                let high_0 = val0 >> 4;
653                let low_1 = val1 & 0xF;
654                let high_1 = val1 >> 4;
655
656                let vec_idx_0 = batch_idx * BATCH_SIZE + PERM0[j];
657                let vec_idx_1 = batch_idx * BATCH_SIZE + PERM0[j] + 16;
658
659                unpacked[vec_idx_0 * code_len + i] = low_0 | (low_1 << 4);
660                unpacked[vec_idx_1 * code_len + i] = high_0 | (high_1 << 4);
661            }
662        }
663    }
664
665    // Transpose back the remainder
666    if num_packed_vectors < num_vectors {
667        let remainder = num_vectors - num_packed_vectors;
668        let offset = num_packed_vectors * code_len;
669        let transposed_data = &codes_values[offset..];
670
671        // Transpose from column-major back to row-major
672        for row in 0..remainder {
673            for col in 0..code_len {
674                unpacked[offset + row * code_len + col] = transposed_data[col * remainder + row];
675            }
676        }
677    }
678
679    FixedSizeListArray::try_new_from_values(UInt8Array::from(unpacked), code_len as i32).unwrap()
680}
681
682#[async_trait]
683impl QuantizerStorage for RabitQuantizationStorage {
684    type Metadata = RabitQuantizationMetadata;
685
686    fn try_from_batch(
687        batch: RecordBatch,
688        metadata: &Self::Metadata,
689        distance_type: DistanceType,
690        _fri: Option<Arc<FragReuseIndex>>,
691    ) -> Result<Self> {
692        let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().clone();
693        let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
694        let add_factors = batch[ADD_FACTORS_COLUMN]
695            .as_primitive::<Float32Type>()
696            .clone();
697        let scale_factors = batch[SCALE_FACTORS_COLUMN]
698            .as_primitive::<Float32Type>()
699            .clone();
700
701        let (batch, codes) = if !metadata.packed {
702            let codes = pack_codes(&codes);
703            let batch = batch.replace_column_by_name(RABIT_CODE_COLUMN, Arc::new(codes))?;
704            let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
705            (batch, codes)
706        } else {
707            (batch, codes)
708        };
709
710        let mut metadata = metadata.clone();
711        metadata.packed = true;
712
713        Ok(Self {
714            metadata,
715            batch,
716            distance_type,
717            row_ids,
718            codes,
719            add_factors,
720            scale_factors,
721        })
722    }
723
724    fn metadata(&self) -> &Self::Metadata {
725        &self.metadata
726    }
727
728    async fn load_partition(
729        reader: &PreviousFileReader,
730        range: std::ops::Range<usize>,
731        distance_type: DistanceType,
732        metadata: &Self::Metadata,
733        frag_reuse_index: Option<Arc<FragReuseIndex>>,
734    ) -> Result<Self> {
735        let schema = reader.schema();
736        let batch = reader.read_range(range, schema).await?;
737        Self::try_from_batch(batch, metadata, distance_type, frag_reuse_index)
738    }
739
740    fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
741        let num_vectors = self.codes.len();
742        let num_code_bytes = self.codes.value_length() as usize;
743        let codes = self.codes.values().as_primitive::<UInt8Type>().values();
744        let mut indices = Vec::with_capacity(num_vectors);
745        let mut new_row_ids = Vec::with_capacity(num_vectors);
746        let mut new_codes = Vec::with_capacity(codes.len());
747
748        let row_ids = self.row_ids.values();
749        for (i, row_id) in row_ids.iter().enumerate() {
750            match mapping.get(row_id) {
751                Some(Some(new_id)) => {
752                    indices.push(i as u32);
753                    new_row_ids.push(*new_id);
754                    new_codes.extend(get_rq_code(codes, i, num_vectors, num_code_bytes));
755                }
756                Some(None) => {}
757                None => {
758                    indices.push(i as u32);
759                    new_row_ids.push(*row_id);
760                    new_codes.extend(get_rq_code(codes, i, num_vectors, num_code_bytes));
761                }
762            }
763        }
764
765        let new_row_ids = UInt64Array::from(new_row_ids);
766        let new_codes = FixedSizeListArray::try_new_from_values(
767            UInt8Array::from(new_codes),
768            num_code_bytes as i32,
769        )?;
770        let batch = if new_row_ids.is_empty() {
771            RecordBatch::new_empty(self.schema().clone())
772        } else {
773            let codes = Arc::new(pack_codes(&new_codes));
774            self.batch
775                .take(&UInt32Array::from(indices))?
776                .replace_column_by_name(ROW_ID, Arc::new(new_row_ids.clone()))?
777                .replace_column_by_name(RABIT_CODE_COLUMN, codes)?
778        };
779        let codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list().clone();
780
781        Ok(Self {
782            metadata: self.metadata.clone(),
783            distance_type: self.distance_type,
784            batch,
785            codes,
786            add_factors: self.add_factors.clone(),
787            scale_factors: self.scale_factors.clone(),
788            row_ids: new_row_ids,
789        })
790    }
791}
792
793/// Compute the raw distance for a single vector without allocating.
794///
795/// Fuses code extraction from the packed layout with distance accumulation
796/// in a single pass, avoiding the intermediate `Vec` allocation that
797/// `get_rq_code` + iterator would require.
798#[inline]
799fn compute_single_rq_distance(
800    codes: &[u8],
801    id: usize,
802    num_vectors: usize,
803    num_code_bytes: usize,
804    dist_table: &[f32],
805) -> f32 {
806    let remainder = num_vectors % BATCH_SIZE;
807    let mut dist_table_iter = dist_table.chunks_exact(SEGMENT_NUM_CODES).tuples();
808
809    if id < num_vectors - remainder {
810        let batch_codes = &codes[id / BATCH_SIZE * BATCH_SIZE * num_code_bytes
811            ..(id / BATCH_SIZE + 1) * BATCH_SIZE * num_code_bytes];
812
813        let id_in_batch = id % BATCH_SIZE;
814        let idx = PERM0_INVERSE[id_in_batch % 16];
815        let is_lower = id_in_batch < 16;
816
817        let mut dist = 0.0f32;
818        for block in batch_codes.chunks_exact(BATCH_SIZE) {
819            let code_byte = if is_lower {
820                (block[idx] & 0xF) | (block[idx + 16] << 4)
821            } else {
822                (block[idx] >> 4) | (block[idx + 16] & 0xF0)
823            };
824            if let Some((current_dt, next_dt)) = dist_table_iter.next() {
825                let current_code = (code_byte & 0x0F) as usize;
826                let next_code = (code_byte >> 4) as usize;
827                dist += current_dt[current_code] + next_dt[next_code];
828            }
829        }
830        dist
831    } else {
832        let offset_id = id - (num_vectors - remainder);
833        let remainder_codes = &codes[(num_vectors - remainder) * num_code_bytes..];
834
835        let mut dist = 0.0f32;
836        for &code_byte in remainder_codes.iter().skip(offset_id).step_by(remainder) {
837            if let Some((current_dt, next_dt)) = dist_table_iter.next() {
838                let current_code = (code_byte & 0x0F) as usize;
839                let next_code = (code_byte >> 4) as usize;
840                dist += current_dt[current_code] + next_dt[next_code];
841            }
842        }
843        dist
844    }
845}
846
847#[inline]
848fn get_rq_code(
849    codes: &[u8],
850    id: usize,
851    num_vectors: usize,
852    num_code_bytes: usize,
853) -> impl Iterator<Item = u8> + '_ {
854    let remainder = num_vectors % BATCH_SIZE;
855
856    if id < num_vectors - remainder {
857        // the codes are packed
858        let codes = &codes[id / BATCH_SIZE * BATCH_SIZE * num_code_bytes
859            ..(id / BATCH_SIZE + 1) * BATCH_SIZE * num_code_bytes];
860
861        let id_in_batch = id % BATCH_SIZE;
862        if id_in_batch < 16 {
863            let idx = PERM0_INVERSE[id_in_batch];
864            codes
865                .chunks_exact(BATCH_SIZE)
866                .map(|block| (block[idx] & 0xF) | (block[idx + 16] << 4))
867                .exact_size(num_code_bytes)
868                .collect_vec()
869                .into_iter()
870        } else {
871            let idx = PERM0_INVERSE[id_in_batch - 16];
872            codes
873                .chunks_exact(BATCH_SIZE)
874                .map(|block| (block[idx] >> 4) | (block[idx + 16] & 0xF0))
875                .exact_size(num_code_bytes)
876                .collect_vec()
877                .into_iter()
878        }
879    } else {
880        let id = id - (num_vectors - remainder);
881        let codes = &codes[(num_vectors - remainder) * num_code_bytes..];
882        codes
883            .iter()
884            .skip(id)
885            .step_by(remainder)
886            .copied()
887            .exact_size(num_code_bytes)
888            .collect_vec()
889            .into_iter()
890    }
891}
892
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use arrow_array::{ArrayRef, Float32Array, UInt64Array};
    use lance_core::ROW_ID;
    use lance_linalg::distance::DistanceType;

    use crate::vector::bq::{RQRotationType, builder::RabitQuantizer};
    use crate::vector::quantizer::{Quantization, QuantizerStorage};

    /// Straightforward reference for the per-segment distance table: entry
    /// `code` accumulates every element of `sub_vec` whose bit is set in `code`.
    fn build_dist_table_not_optimized<T: ArrowFloatType>(
        sub_vec: &[T::Native],
        dist_table: &mut [f32],
    ) where
        T::Native: AsPrimitive<f32>,
    {
        for (code, entry) in dist_table.iter_mut().take(SEGMENT_NUM_CODES).enumerate() {
            for (bit, value) in sub_vec.iter().take(SEGMENT_LENGTH).enumerate() {
                if (code >> bit) & 1 == 1 {
                    *entry += value.as_();
                }
            }
        }
    }

    /// The production table builder must agree with the reference one.
    #[test]
    fn test_build_dist_table_not_optimized() {
        let segment = vec![1.0, 2.0, 3.0, 4.0];

        let mut reference = vec![0.0; SEGMENT_NUM_CODES];
        build_dist_table_not_optimized::<Float32Type>(&segment, &mut reference);

        let mut optimized = vec![0.0; SEGMENT_NUM_CODES];
        build_dist_table_for_subvec::<Float32Type>(&segment, &mut optimized);

        assert_eq!(optimized, reference);
    }

    /// Packing followed by unpacking must give back the original codes.
    #[test]
    fn test_pack_unpack_codes() {
        let code_len: i32 = 8;

        // Several vector counts, so that both the packed and the transposed
        // sections are exercised.
        for num_vectors in [10, 32, 50, 64, 100] {
            // Known pattern: the running byte index, truncated to u8.
            let pattern: Vec<u8> = (0..num_vectors * code_len).map(|v| v as u8).collect();
            let original_codes =
                FixedSizeListArray::try_new_from_values(UInt8Array::from(pattern), code_len)
                    .unwrap();

            // Round trip through the packed layout.
            let round_tripped = unpack_codes(&pack_codes(&original_codes));

            assert_eq!(original_codes.len(), round_tripped.len());
            assert_eq!(original_codes.value_length(), round_tripped.value_length());

            let original_values = original_codes.values().as_primitive::<UInt8Type>().values();
            let unpacked_values = round_tripped.values().as_primitive::<UInt8Type>().values();
            assert_eq!(
                original_values, unpacked_values,
                "Mismatch for num_vectors={}",
                num_vectors
            );
        }
    }

    /// Quantizes `num_vectors` deterministic float vectors of width `code_dim`
    /// and returns the resulting code column.
    fn make_test_codes(num_vectors: usize, code_dim: i32) -> FixedSizeListArray {
        let flat_values = Float32Array::from_iter_values(
            (0..num_vectors * code_dim as usize).map(|idx| idx as f32 / code_dim as f32),
        );
        let vectors = FixedSizeListArray::try_new_from_values(flat_values, code_dim).unwrap();

        let quantizer =
            RabitQuantizer::new_with_rotation::<Float32Type>(1, code_dim, RQRotationType::Fast);
        let quantized = quantizer.quantize(&vectors).unwrap();
        quantized.as_fixed_size_list().clone()
    }

    /// Metadata of a freshly created quantizer for `code_dim` dimensions.
    fn make_test_metadata(code_dim: usize) -> RabitQuantizationMetadata {
        let quantizer = RabitQuantizer::new_with_rotation::<Float32Type>(
            1,
            code_dim as i32,
            RQRotationType::Fast,
        );
        quantizer.metadata(None)
    }

    /// Wraps `codes` in a batch with sequential row ids and synthetic
    /// add/scale factor columns.
    fn make_test_batch(codes: FixedSizeListArray) -> RecordBatch {
        let num_rows = codes.len();

        let row_ids: ArrayRef = Arc::new(UInt64Array::from_iter_values(0..num_rows as u64));
        let code_column: ArrayRef = Arc::new(codes);
        let add_factors: ArrayRef = Arc::new(Float32Array::from_iter_values(
            (0..num_rows).map(|v| v as f32),
        ));
        let scale_factors: ArrayRef = Arc::new(Float32Array::from_iter_values(
            (0..num_rows).map(|v| v as f32 + 0.5),
        ));

        RecordBatch::try_from_iter(vec![
            (ROW_ID, row_ids),
            (RABIT_CODE_COLUMN, code_column),
            (ADD_FACTORS_COLUMN, add_factors),
            (SCALE_FACTORS_COLUMN, scale_factors),
        ])
        .unwrap()
    }

    /// Asserts that two code columns have the same shape and the same bytes.
    fn assert_codes_eq(actual: &FixedSizeListArray, expected: &FixedSizeListArray) {
        assert_eq!(actual.len(), expected.len());
        assert_eq!(actual.value_length(), expected.value_length());

        let actual_bytes = actual.values().as_primitive::<UInt8Type>().values();
        let expected_bytes = expected.values().as_primitive::<UInt8Type>().values();
        assert_eq!(actual_bytes, expected_bytes);
    }

    /// Storage built from unpacked codes must report and hold the packed layout.
    #[test]
    fn test_try_from_batch_canonicalizes_rq_codes_to_packed_layout() {
        let raw_codes = make_test_codes(50, 64);
        let metadata = make_test_metadata(raw_codes.value_length() as usize * 8);
        assert!(!metadata.packed);

        let input_batch = make_test_batch(raw_codes.clone());
        let storage = RabitQuantizationStorage::try_from_batch(
            input_batch,
            &metadata,
            DistanceType::L2,
            None,
        )
        .unwrap();
        assert!(storage.metadata().packed);

        let mut batches = storage.to_batches().unwrap();
        let stored_batch = batches.next().unwrap();
        let stored_codes = stored_batch[RABIT_CODE_COLUMN].as_fixed_size_list();
        assert_codes_eq(stored_codes, &pack_codes(&raw_codes));
    }

    /// Remapping row ids must keep the storage in the packed layout.
    #[test]
    fn test_remap_preserves_packed_rq_storage_layout() {
        let raw_codes = make_test_codes(50, 64);
        let metadata = make_test_metadata(raw_codes.value_length() as usize * 8);
        let input_batch = make_test_batch(raw_codes.clone());
        let storage = RabitQuantizationStorage::try_from_batch(
            input_batch,
            &metadata,
            DistanceType::L2,
            None,
        )
        .unwrap();

        // Rows 1 and 4 get new ids, row 3 is dropped.
        let mapping = HashMap::from([(1, Some(101)), (3, None), (4, Some(104))]);

        let remapped = storage.remap(&mapping).unwrap();
        assert!(remapped.metadata().packed);

        let mut batches = remapped.to_batches().unwrap();
        let remapped_batch = batches.next().unwrap();

        let remapped_row_ids = remapped_batch[ROW_ID].as_primitive::<UInt64Type>().values();
        let expected_row_ids = UInt64Array::from_iter_values(
            [0, 101, 2, 104]
                .into_iter()
                .chain(5..raw_codes.len() as u64),
        );
        assert_eq!(remapped_row_ids, expected_row_ids.values());

        // A packed column must be unchanged by an unpack/pack round trip.
        let remapped_codes = remapped_batch[RABIT_CODE_COLUMN].as_fixed_size_list();
        let unpacked = unpack_codes(remapped_codes);
        let repacked = pack_codes(&unpacked);
        assert_codes_eq(remapped_codes, &repacked);
    }
}