lance_index/vector/bq/
transform.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, LazyLock};
6
7use arrow::array::AsArray;
8use arrow::datatypes::{Float16Type, Float32Type, Float64Type, UInt32Type};
9use arrow_array::{Array, ArrowNativeTypeOp, FixedSizeListArray, Float32Array, RecordBatch};
10use arrow_schema::DataType;
11use lance_arrow::RecordBatchExt;
12use lance_core::{Error, Result};
13use lance_linalg::distance::{norm_squared_fsl, DistanceType};
14use snafu::location;
15use tracing::instrument;
16
17use crate::vector::bq::builder::RabitQuantizer;
18use crate::vector::bq::storage::RABIT_CODE_COLUMN;
19use crate::vector::quantizer::Quantization;
20use crate::vector::transform::Transformer;
21use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN};
22
/// Column holding the per-row additive correction factor produced by
/// [`RQTransformer`].
///
/// For L2 it is the squared centroid distance `|v - c|^2`; for Dot it is
/// `dist(v, c) + |c|^2`, where `dist(v, c)` is the value of the centroid
/// distance column and `|c|^2` is the squared norm of the row's centroid.
pub const ADD_FACTORS_COLUMN: &str = "__add_factors";
/// Column holding the per-row multiplicative correction factor produced by
/// [`RQTransformer`].
///
/// With `r` the residual vector and `<code, r>` the value returned by
/// `RabitQuantizer::codes_res_dot_dists`, it is `-2 * |r|^2 / <code, r>` for L2
/// and `-|r|^2 / <code, r>` for Dot. Rows where the checked division reports an
/// error are stored as `0.0`.
pub const SCALE_FACTORS_COLUMN: &str = "__scale_factors";
27
28pub static ADD_FACTORS_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
29    arrow_schema::Field::new(ADD_FACTORS_COLUMN, arrow_schema::DataType::Float32, true)
30});
31pub static SCALE_FACTORS_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
32    arrow_schema::Field::new(SCALE_FACTORS_COLUMN, arrow_schema::DataType::Float32, true)
33});
34
35pub struct RQTransformer {
36    rq: RabitQuantizer,
37    distance_type: DistanceType,
38    centroids_norm_square: Option<Float32Array>,
39    vector_column: String,
40}
41
42impl RQTransformer {
43    pub fn new(
44        rq: RabitQuantizer,
45        distance_type: DistanceType,
46        centroids: FixedSizeListArray,
47        vector_column: impl Into<String>,
48    ) -> Self {
49        // for dot product, the add factor is `1 - v*c + |c|^2`, so we need to compute |c|^2
50        let centroids_norm_square = (distance_type == DistanceType::Dot)
51            .then(|| Float32Array::from(norm_squared_fsl(&centroids)));
52
53        Self {
54            rq,
55            distance_type,
56            centroids_norm_square,
57            vector_column: vector_column.into(),
58        }
59    }
60}
61
62impl Debug for RQTransformer {
63    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
64        write!(f, "RabitTransformer(vector_column={})", self.vector_column)
65    }
66}
67
68impl Transformer for RQTransformer {
69    #[instrument(name = "RQTransformer::transform", level = "debug", skip_all)]
70    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
71        if batch.column_by_name(RABIT_CODE_COLUMN).is_some() {
72            return Ok(batch.clone());
73        }
74
75        let residual_vectors = batch
76            .column_by_name(&self.vector_column)
77            .ok_or(Error::Index {
78                message: format!(
79                    "RQ Transform: column {} not found in batch",
80                    self.vector_column
81                ),
82                location: location!(),
83            })?;
84        let residual_vectors = residual_vectors
85            .as_fixed_size_list_opt()
86            .ok_or(Error::Index {
87                message: format!(
88                    "RQ Transform: column {} is not a fixed size list, got {}",
89                    self.vector_column,
90                    residual_vectors.data_type(),
91                ),
92                location: location!(),
93            })?;
94
95        let dist_v_c = batch
96            .column_by_name(CENTROID_DIST_COLUMN)
97            .ok_or(Error::Index {
98                message: format!(
99                    "RQ Transform: column {} not found in batch",
100                    CENTROID_DIST_COLUMN
101                ),
102                location: location!(),
103            })?;
104        let dist_v_c = dist_v_c.as_primitive::<Float32Type>();
105
106        let res_norm_square = match self.distance_type {
107            // for L2, |v-c|^2 is just the distance to the centroid
108            DistanceType::L2 => dist_v_c.clone(),
109            DistanceType::Dot => Float32Array::from(norm_squared_fsl(residual_vectors)),
110            _ => {
111                return Err(Error::Index {
112                    message: format!(
113                        "RQ Transform: distance type {} not supported",
114                        self.distance_type
115                    ),
116                    location: location!(),
117                });
118            }
119        };
120
121        let rq_codes = self.rq.quantize(&residual_vectors)?;
122        let codes_fsl = rq_codes.as_fixed_size_list();
123
124        let ip_rq_res = match residual_vectors.value_type() {
125            DataType::Float16 => Float32Array::from(
126                self.rq
127                    .codes_res_dot_dists::<Float16Type>(residual_vectors)?,
128            ),
129            DataType::Float32 => Float32Array::from(
130                self.rq
131                    .codes_res_dot_dists::<Float32Type>(residual_vectors)?,
132            ),
133            DataType::Float64 => Float32Array::from(
134                self.rq
135                    .codes_res_dot_dists::<Float64Type>(residual_vectors)?,
136            ),
137            _ => {
138                return Err(Error::Index {
139                    message: format!(
140                        "RQ Transform: unsupported residual vector data type: {}",
141                        residual_vectors.data_type()
142                    ),
143                    location: location!(),
144                });
145            }
146        };
147        debug_assert_eq!(codes_fsl.len(), batch.num_rows());
148
149        let add_factors = match self.distance_type {
150            DistanceType::L2 => res_norm_square.clone(),
151            DistanceType::Dot => {
152                // for dot, the add factor is `1 - v*c + |c|^2 = dist_v_c + |c|^2`
153                let part_ids = &batch[PART_ID_COLUMN];
154                let part_ids = part_ids.as_primitive::<UInt32Type>();
155                let centroids_norm_square =
156                    self.centroids_norm_square.as_ref().ok_or(Error::Index {
157                        message: "RQ Transform: centroids norm square not found".to_string(),
158                        location: location!(),
159                    })?;
160                let centroids_norm_square =
161                    arrow::compute::take(centroids_norm_square, part_ids, None)?;
162                let centroids_norm_square = centroids_norm_square.as_primitive::<Float32Type>();
163                Float32Array::from_iter_values(
164                    dist_v_c
165                        .values()
166                        .iter()
167                        .zip(centroids_norm_square.values().iter())
168                        .map(|(dist_v_c, centroids_norm_square)| dist_v_c + centroids_norm_square),
169                )
170            }
171            _ => {
172                return Err(Error::Index {
173                    message: format!(
174                        "RQ Transform: distance type {} not supported",
175                        self.distance_type
176                    ),
177                    location: location!(),
178                });
179            }
180        };
181
182        let scale_factors = match self.distance_type {
183            DistanceType::L2 => Float32Array::from_iter_values(
184                res_norm_square.values().iter().zip(ip_rq_res.values()).map(
185                    |(res_norm_square, ip_rq_res)| {
186                        (-2.0 * res_norm_square)
187                            .div_checked(*ip_rq_res)
188                            .unwrap_or_default()
189                    },
190                ),
191            ),
192            DistanceType::Dot => Float32Array::from_iter_values(
193                res_norm_square.values().iter().zip(ip_rq_res.values()).map(
194                    |(res_norm_square, ip_rq_res)| {
195                        -res_norm_square.div_checked(*ip_rq_res).unwrap_or_default()
196                    },
197                ),
198            ),
199            _ => {
200                return Err(Error::Index {
201                    message: format!(
202                        "RQ Transform: distance type {} not supported",
203                        self.distance_type
204                    ),
205                    location: location!(),
206                });
207            }
208        };
209
210        let batch = batch.try_with_column(self.rq.field(), Arc::new(rq_codes))?;
211        let batch = batch
212            .try_with_column(ADD_FACTORS_FIELD.clone(), Arc::new(add_factors))?
213            .drop_column(CENTROID_DIST_COLUMN)?;
214        let batch = batch.try_with_column(SCALE_FACTORS_FIELD.clone(), Arc::new(scale_factors))?;
215
216        let batch = batch
217            .drop_column(&self.vector_column)?
218            .drop_column(CENTROID_DIST_COLUMN)?;
219        Ok(batch)
220    }
221}