lance_index/vector/bq/
transform.rs1use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, LazyLock};
6
7use arrow::array::AsArray;
8use arrow::datatypes::{Float16Type, Float32Type, Float64Type, UInt32Type};
9use arrow_array::{Array, ArrowNativeTypeOp, FixedSizeListArray, Float32Array, RecordBatch};
10use arrow_schema::DataType;
11use lance_arrow::RecordBatchExt;
12use lance_core::{Error, Result};
13use lance_linalg::distance::{DistanceType, norm_squared_fsl};
14use tracing::instrument;
15
16use crate::vector::bq::builder::RabitQuantizer;
17use crate::vector::bq::storage::RABIT_CODE_COLUMN;
18use crate::vector::quantizer::Quantization;
19use crate::vector::transform::Transformer;
20use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN};
21
/// Name of the per-row additive correction column produced by [`RQTransformer`].
pub const ADD_FACTORS_COLUMN: &str = "__add_factors";
/// Name of the per-row multiplicative correction column produced by [`RQTransformer`].
pub const SCALE_FACTORS_COLUMN: &str = "__scale_factors";
26
27pub static ADD_FACTORS_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
28 arrow_schema::Field::new(ADD_FACTORS_COLUMN, arrow_schema::DataType::Float32, true)
29});
30pub static SCALE_FACTORS_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
31 arrow_schema::Field::new(SCALE_FACTORS_COLUMN, arrow_schema::DataType::Float32, true)
32});
33
34pub struct RQTransformer {
35 rq: RabitQuantizer,
36 distance_type: DistanceType,
37 centroids_norm_square: Option<Float32Array>,
38 vector_column: String,
39}
40
41impl RQTransformer {
42 pub fn new(
43 rq: RabitQuantizer,
44 distance_type: DistanceType,
45 centroids: FixedSizeListArray,
46 vector_column: impl Into<String>,
47 ) -> Self {
48 let centroids_norm_square = (distance_type == DistanceType::Dot)
50 .then(|| Float32Array::from(norm_squared_fsl(¢roids)));
51
52 Self {
53 rq,
54 distance_type,
55 centroids_norm_square,
56 vector_column: vector_column.into(),
57 }
58 }
59}
60
61impl Debug for RQTransformer {
62 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
63 write!(f, "RabitTransformer(vector_column={})", self.vector_column)
64 }
65}
66
67impl Transformer for RQTransformer {
68 #[instrument(name = "RQTransformer::transform", level = "debug", skip_all)]
69 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
70 if batch.column_by_name(RABIT_CODE_COLUMN).is_some() {
71 return Ok(batch.clone());
72 }
73
74 let residual_vectors = batch
75 .column_by_name(&self.vector_column)
76 .ok_or(Error::index(format!(
77 "RQ Transform: column {} not found in batch",
78 self.vector_column
79 )))?;
80 let residual_vectors = residual_vectors
81 .as_fixed_size_list_opt()
82 .ok_or(Error::index(format!(
83 "RQ Transform: column {} is not a fixed size list, got {}",
84 self.vector_column,
85 residual_vectors.data_type(),
86 )))?;
87
88 let dist_v_c = batch
89 .column_by_name(CENTROID_DIST_COLUMN)
90 .ok_or(Error::index(format!(
91 "RQ Transform: column {} not found in batch",
92 CENTROID_DIST_COLUMN
93 )))?;
94 let dist_v_c = dist_v_c.as_primitive::<Float32Type>();
95
96 let res_norm_square = match self.distance_type {
97 DistanceType::L2 => dist_v_c.clone(),
99 DistanceType::Dot => Float32Array::from(norm_squared_fsl(residual_vectors)),
100 _ => {
101 return Err(Error::index(format!(
102 "RQ Transform: distance type {} not supported",
103 self.distance_type
104 )));
105 }
106 };
107
108 let rq_codes = self.rq.quantize(&residual_vectors)?;
109 let codes_fsl = rq_codes.as_fixed_size_list();
110
111 let ip_rq_res = match residual_vectors.value_type() {
112 DataType::Float16 => Float32Array::from(
113 self.rq
114 .codes_res_dot_dists::<Float16Type>(residual_vectors)?,
115 ),
116 DataType::Float32 => Float32Array::from(
117 self.rq
118 .codes_res_dot_dists::<Float32Type>(residual_vectors)?,
119 ),
120 DataType::Float64 => Float32Array::from(
121 self.rq
122 .codes_res_dot_dists::<Float64Type>(residual_vectors)?,
123 ),
124 _ => {
125 return Err(Error::index(format!(
126 "RQ Transform: unsupported residual vector data type: {}",
127 residual_vectors.data_type()
128 )));
129 }
130 };
131 debug_assert_eq!(codes_fsl.len(), batch.num_rows());
132
133 let add_factors = match self.distance_type {
134 DistanceType::L2 => res_norm_square.clone(),
135 DistanceType::Dot => {
136 let part_ids = &batch[PART_ID_COLUMN];
138 let part_ids = part_ids.as_primitive::<UInt32Type>();
139 let centroids_norm_square = self.centroids_norm_square.as_ref().ok_or(
140 Error::index("RQ Transform: centroids norm square not found".to_string()),
141 )?;
142 let centroids_norm_square =
143 arrow::compute::take(centroids_norm_square, part_ids, None)?;
144 let centroids_norm_square = centroids_norm_square.as_primitive::<Float32Type>();
145 Float32Array::from_iter_values(
146 dist_v_c
147 .values()
148 .iter()
149 .zip(centroids_norm_square.values().iter())
150 .map(|(dist_v_c, centroids_norm_square)| dist_v_c + centroids_norm_square),
151 )
152 }
153 _ => {
154 return Err(Error::index(format!(
155 "RQ Transform: distance type {} not supported",
156 self.distance_type
157 )));
158 }
159 };
160
161 let scale_factors = match self.distance_type {
162 DistanceType::L2 => Float32Array::from_iter_values(
163 res_norm_square.values().iter().zip(ip_rq_res.values()).map(
164 |(res_norm_square, ip_rq_res)| {
165 (-2.0 * res_norm_square)
166 .div_checked(*ip_rq_res)
167 .unwrap_or_default()
168 },
169 ),
170 ),
171 DistanceType::Dot => Float32Array::from_iter_values(
172 res_norm_square.values().iter().zip(ip_rq_res.values()).map(
173 |(res_norm_square, ip_rq_res)| {
174 -res_norm_square.div_checked(*ip_rq_res).unwrap_or_default()
175 },
176 ),
177 ),
178 _ => {
179 return Err(Error::index(format!(
180 "RQ Transform: distance type {} not supported",
181 self.distance_type
182 )));
183 }
184 };
185
186 let batch = batch.try_with_column(self.rq.field(), Arc::new(rq_codes))?;
187 let batch = batch
188 .try_with_column(ADD_FACTORS_FIELD.clone(), Arc::new(add_factors))?
189 .drop_column(CENTROID_DIST_COLUMN)?;
190 let batch = batch.try_with_column(SCALE_FACTORS_FIELD.clone(), Arc::new(scale_factors))?;
191
192 let batch = batch
193 .drop_column(&self.vector_column)?
194 .drop_column(CENTROID_DIST_COLUMN)?;
195 Ok(batch)
196 }
197}