lance_index/vector/bq/
transform.rs1use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, LazyLock};
6
7use arrow::array::AsArray;
8use arrow::datatypes::{Float16Type, Float32Type, Float64Type, UInt32Type};
9use arrow_array::{Array, ArrowNativeTypeOp, FixedSizeListArray, Float32Array, RecordBatch};
10use arrow_schema::DataType;
11use lance_arrow::RecordBatchExt;
12use lance_core::{Error, Result};
13use lance_linalg::distance::{norm_squared_fsl, DistanceType};
14use snafu::location;
15use tracing::instrument;
16
17use crate::vector::bq::builder::RabitQuantizer;
18use crate::vector::bq::storage::RABIT_CODE_COLUMN;
19use crate::vector::quantizer::Quantization;
20use crate::vector::transform::Transformer;
21use crate::vector::{CENTROID_DIST_COLUMN, PART_ID_COLUMN};
22
/// Name of the column in which [`RQTransformer`] stores the per-row additive factor.
pub const ADD_FACTORS_COLUMN: &str = "__add_factors";
/// Name of the column in which [`RQTransformer`] stores the per-row scale factor.
pub const SCALE_FACTORS_COLUMN: &str = "__scale_factors";
27
28pub static ADD_FACTORS_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
29 arrow_schema::Field::new(ADD_FACTORS_COLUMN, arrow_schema::DataType::Float32, true)
30});
31pub static SCALE_FACTORS_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
32 arrow_schema::Field::new(SCALE_FACTORS_COLUMN, arrow_schema::DataType::Float32, true)
33});
34
35pub struct RQTransformer {
36 rq: RabitQuantizer,
37 distance_type: DistanceType,
38 centroids_norm_square: Option<Float32Array>,
39 vector_column: String,
40}
41
42impl RQTransformer {
43 pub fn new(
44 rq: RabitQuantizer,
45 distance_type: DistanceType,
46 centroids: FixedSizeListArray,
47 vector_column: impl Into<String>,
48 ) -> Self {
49 let centroids_norm_square = (distance_type == DistanceType::Dot)
51 .then(|| Float32Array::from(norm_squared_fsl(¢roids)));
52
53 Self {
54 rq,
55 distance_type,
56 centroids_norm_square,
57 vector_column: vector_column.into(),
58 }
59 }
60}
61
62impl Debug for RQTransformer {
63 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
64 write!(f, "RabitTransformer(vector_column={})", self.vector_column)
65 }
66}
67
68impl Transformer for RQTransformer {
69 #[instrument(name = "RQTransformer::transform", level = "debug", skip_all)]
70 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
71 if batch.column_by_name(RABIT_CODE_COLUMN).is_some() {
72 return Ok(batch.clone());
73 }
74
75 let residual_vectors = batch
76 .column_by_name(&self.vector_column)
77 .ok_or(Error::Index {
78 message: format!(
79 "RQ Transform: column {} not found in batch",
80 self.vector_column
81 ),
82 location: location!(),
83 })?;
84 let residual_vectors = residual_vectors
85 .as_fixed_size_list_opt()
86 .ok_or(Error::Index {
87 message: format!(
88 "RQ Transform: column {} is not a fixed size list, got {}",
89 self.vector_column,
90 residual_vectors.data_type(),
91 ),
92 location: location!(),
93 })?;
94
95 let dist_v_c = batch
96 .column_by_name(CENTROID_DIST_COLUMN)
97 .ok_or(Error::Index {
98 message: format!(
99 "RQ Transform: column {} not found in batch",
100 CENTROID_DIST_COLUMN
101 ),
102 location: location!(),
103 })?;
104 let dist_v_c = dist_v_c.as_primitive::<Float32Type>();
105
106 let res_norm_square = match self.distance_type {
107 DistanceType::L2 => dist_v_c.clone(),
109 DistanceType::Dot => Float32Array::from(norm_squared_fsl(residual_vectors)),
110 _ => {
111 return Err(Error::Index {
112 message: format!(
113 "RQ Transform: distance type {} not supported",
114 self.distance_type
115 ),
116 location: location!(),
117 });
118 }
119 };
120
121 let rq_codes = self.rq.quantize(&residual_vectors)?;
122 let codes_fsl = rq_codes.as_fixed_size_list();
123
124 let ip_rq_res = match residual_vectors.value_type() {
125 DataType::Float16 => Float32Array::from(
126 self.rq
127 .codes_res_dot_dists::<Float16Type>(residual_vectors)?,
128 ),
129 DataType::Float32 => Float32Array::from(
130 self.rq
131 .codes_res_dot_dists::<Float32Type>(residual_vectors)?,
132 ),
133 DataType::Float64 => Float32Array::from(
134 self.rq
135 .codes_res_dot_dists::<Float64Type>(residual_vectors)?,
136 ),
137 _ => {
138 return Err(Error::Index {
139 message: format!(
140 "RQ Transform: unsupported residual vector data type: {}",
141 residual_vectors.data_type()
142 ),
143 location: location!(),
144 });
145 }
146 };
147 debug_assert_eq!(codes_fsl.len(), batch.num_rows());
148
149 let add_factors = match self.distance_type {
150 DistanceType::L2 => res_norm_square.clone(),
151 DistanceType::Dot => {
152 let part_ids = &batch[PART_ID_COLUMN];
154 let part_ids = part_ids.as_primitive::<UInt32Type>();
155 let centroids_norm_square =
156 self.centroids_norm_square.as_ref().ok_or(Error::Index {
157 message: "RQ Transform: centroids norm square not found".to_string(),
158 location: location!(),
159 })?;
160 let centroids_norm_square =
161 arrow::compute::take(centroids_norm_square, part_ids, None)?;
162 let centroids_norm_square = centroids_norm_square.as_primitive::<Float32Type>();
163 Float32Array::from_iter_values(
164 dist_v_c
165 .values()
166 .iter()
167 .zip(centroids_norm_square.values().iter())
168 .map(|(dist_v_c, centroids_norm_square)| dist_v_c + centroids_norm_square),
169 )
170 }
171 _ => {
172 return Err(Error::Index {
173 message: format!(
174 "RQ Transform: distance type {} not supported",
175 self.distance_type
176 ),
177 location: location!(),
178 });
179 }
180 };
181
182 let scale_factors = match self.distance_type {
183 DistanceType::L2 => Float32Array::from_iter_values(
184 res_norm_square.values().iter().zip(ip_rq_res.values()).map(
185 |(res_norm_square, ip_rq_res)| {
186 (-2.0 * res_norm_square)
187 .div_checked(*ip_rq_res)
188 .unwrap_or_default()
189 },
190 ),
191 ),
192 DistanceType::Dot => Float32Array::from_iter_values(
193 res_norm_square.values().iter().zip(ip_rq_res.values()).map(
194 |(res_norm_square, ip_rq_res)| {
195 -res_norm_square.div_checked(*ip_rq_res).unwrap_or_default()
196 },
197 ),
198 ),
199 _ => {
200 return Err(Error::Index {
201 message: format!(
202 "RQ Transform: distance type {} not supported",
203 self.distance_type
204 ),
205 location: location!(),
206 });
207 }
208 };
209
210 let batch = batch.try_with_column(self.rq.field(), Arc::new(rq_codes))?;
211 let batch = batch
212 .try_with_column(ADD_FACTORS_FIELD.clone(), Arc::new(add_factors))?
213 .drop_column(CENTROID_DIST_COLUMN)?;
214 let batch = batch.try_with_column(SCALE_FACTORS_FIELD.clone(), Arc::new(scale_factors))?;
215
216 let batch = batch
217 .drop_column(&self.vector_column)?
218 .drop_column(CENTROID_DIST_COLUMN)?;
219 Ok(batch)
220 }
221}