Skip to main content

lance_index/vector/bq/
transform.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, LazyLock};
6
7use arrow::array::AsArray;
8use arrow::datatypes::{Float16Type, Float32Type, Float64Type, UInt32Type};
9use arrow_array::{Array, ArrowNativeTypeOp, FixedSizeListArray, Float32Array, RecordBatch};
10use arrow_schema::DataType;
11use lance_arrow::RecordBatchExt;
12use lance_core::{Error, Result};
13use lance_linalg::distance::{DistanceType, norm_squared_fsl};
14use tracing::instrument;
15
16use crate::vector::bq::builder::RabitQuantizer;
17use crate::vector::bq::storage::RABIT_CODE_COLUMN;
18use crate::vector::quantizer::Quantization;
19use crate::vector::transform::Transformer;
20use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN};
21
/// Column holding the per-row additive factor written by `RQTransformer::transform`:
/// the centroid distance `dist_v_c` (treated as `|v - c|^2`) for L2, and
/// `dist_v_c + |c|^2` for Dot, where `c` is the centroid of the row's partition.
pub const ADD_FACTORS_COLUMN: &str = "__add_factors";
/// Column holding the per-row scale factor written by `RQTransformer::transform`:
/// `-2 * r / ip` for L2 and `-r / ip` for Dot, where `r` is the squared residual
/// norm and `ip` is the per-row value from `RabitQuantizer::codes_res_dot_dists`
/// (the factor is zero when `ip` is zero).
pub const SCALE_FACTORS_COLUMN: &str = "__scale_factors";
26
27pub static ADD_FACTORS_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
28    arrow_schema::Field::new(ADD_FACTORS_COLUMN, arrow_schema::DataType::Float32, true)
29});
30pub static SCALE_FACTORS_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
31    arrow_schema::Field::new(SCALE_FACTORS_COLUMN, arrow_schema::DataType::Float32, true)
32});
33
34pub struct RQTransformer {
35    rq: RabitQuantizer,
36    distance_type: DistanceType,
37    centroids_norm_square: Option<Float32Array>,
38    vector_column: String,
39}
40
41impl RQTransformer {
42    pub fn new(
43        rq: RabitQuantizer,
44        distance_type: DistanceType,
45        centroids: FixedSizeListArray,
46        vector_column: impl Into<String>,
47    ) -> Self {
48        // for dot product, the add factor is `1 - v*c + |c|^2`, so we need to compute |c|^2
49        let centroids_norm_square = (distance_type == DistanceType::Dot)
50            .then(|| Float32Array::from(norm_squared_fsl(&centroids)));
51
52        Self {
53            rq,
54            distance_type,
55            centroids_norm_square,
56            vector_column: vector_column.into(),
57        }
58    }
59}
60
61impl Debug for RQTransformer {
62    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
63        write!(f, "RabitTransformer(vector_column={})", self.vector_column)
64    }
65}
66
67impl Transformer for RQTransformer {
68    #[instrument(name = "RQTransformer::transform", level = "debug", skip_all)]
69    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
70        if batch.column_by_name(RABIT_CODE_COLUMN).is_some() {
71            return Ok(batch.clone());
72        }
73
74        let residual_vectors = batch
75            .column_by_name(&self.vector_column)
76            .ok_or(Error::index(format!(
77                "RQ Transform: column {} not found in batch",
78                self.vector_column
79            )))?;
80        let residual_vectors = residual_vectors
81            .as_fixed_size_list_opt()
82            .ok_or(Error::index(format!(
83                "RQ Transform: column {} is not a fixed size list, got {}",
84                self.vector_column,
85                residual_vectors.data_type(),
86            )))?;
87
88        let dist_v_c = batch
89            .column_by_name(CENTROID_DIST_COLUMN)
90            .ok_or(Error::index(format!(
91                "RQ Transform: column {} not found in batch",
92                CENTROID_DIST_COLUMN
93            )))?;
94        let dist_v_c = dist_v_c.as_primitive::<Float32Type>();
95
96        let res_norm_square = match self.distance_type {
97            // for L2, |v-c|^2 is just the distance to the centroid
98            DistanceType::L2 => dist_v_c.clone(),
99            DistanceType::Dot => Float32Array::from(norm_squared_fsl(residual_vectors)),
100            _ => {
101                return Err(Error::index(format!(
102                    "RQ Transform: distance type {} not supported",
103                    self.distance_type
104                )));
105            }
106        };
107
108        let rq_codes = self.rq.quantize(&residual_vectors)?;
109        let codes_fsl = rq_codes.as_fixed_size_list();
110
111        let ip_rq_res = match residual_vectors.value_type() {
112            DataType::Float16 => Float32Array::from(
113                self.rq
114                    .codes_res_dot_dists::<Float16Type>(residual_vectors)?,
115            ),
116            DataType::Float32 => Float32Array::from(
117                self.rq
118                    .codes_res_dot_dists::<Float32Type>(residual_vectors)?,
119            ),
120            DataType::Float64 => Float32Array::from(
121                self.rq
122                    .codes_res_dot_dists::<Float64Type>(residual_vectors)?,
123            ),
124            _ => {
125                return Err(Error::index(format!(
126                    "RQ Transform: unsupported residual vector data type: {}",
127                    residual_vectors.data_type()
128                )));
129            }
130        };
131        debug_assert_eq!(codes_fsl.len(), batch.num_rows());
132
133        let add_factors = match self.distance_type {
134            DistanceType::L2 => res_norm_square.clone(),
135            DistanceType::Dot => {
136                // for dot, the add factor is `1 - v*c + |c|^2 = dist_v_c + |c|^2`
137                let part_ids = &batch[PART_ID_COLUMN];
138                let part_ids = part_ids.as_primitive::<UInt32Type>();
139                let centroids_norm_square = self.centroids_norm_square.as_ref().ok_or(
140                    Error::index("RQ Transform: centroids norm square not found".to_string()),
141                )?;
142                let centroids_norm_square =
143                    arrow::compute::take(centroids_norm_square, part_ids, None)?;
144                let centroids_norm_square = centroids_norm_square.as_primitive::<Float32Type>();
145                Float32Array::from_iter_values(
146                    dist_v_c
147                        .values()
148                        .iter()
149                        .zip(centroids_norm_square.values().iter())
150                        .map(|(dist_v_c, centroids_norm_square)| dist_v_c + centroids_norm_square),
151                )
152            }
153            _ => {
154                return Err(Error::index(format!(
155                    "RQ Transform: distance type {} not supported",
156                    self.distance_type
157                )));
158            }
159        };
160
161        let scale_factors = match self.distance_type {
162            DistanceType::L2 => Float32Array::from_iter_values(
163                res_norm_square.values().iter().zip(ip_rq_res.values()).map(
164                    |(res_norm_square, ip_rq_res)| {
165                        (-2.0 * res_norm_square)
166                            .div_checked(*ip_rq_res)
167                            .unwrap_or_default()
168                    },
169                ),
170            ),
171            DistanceType::Dot => Float32Array::from_iter_values(
172                res_norm_square.values().iter().zip(ip_rq_res.values()).map(
173                    |(res_norm_square, ip_rq_res)| {
174                        -res_norm_square.div_checked(*ip_rq_res).unwrap_or_default()
175                    },
176                ),
177            ),
178            _ => {
179                return Err(Error::index(format!(
180                    "RQ Transform: distance type {} not supported",
181                    self.distance_type
182                )));
183            }
184        };
185
186        let batch = batch.try_with_column(self.rq.field(), Arc::new(rq_codes))?;
187        let batch = batch
188            .try_with_column(ADD_FACTORS_FIELD.clone(), Arc::new(add_factors))?
189            .drop_column(CENTROID_DIST_COLUMN)?;
190        let batch = batch.try_with_column(SCALE_FACTORS_FIELD.clone(), Arc::new(scale_factors))?;
191
192        let batch = batch
193            .drop_column(&self.vector_column)?
194            .drop_column(CENTROID_DIST_COLUMN)?;
195        Ok(batch)
196    }
197}