1use std::sync::Arc;
5
6use arrow::array::AsArray;
7use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
8use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};
9use arrow_schema::{DataType, Field};
10use bitvec::prelude::{BitVec, Lsb0};
11use deepsize::DeepSizeOf;
12use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, FloatType};
13use lance_core::{Error, Result};
14use ndarray::{s, Axis};
15use num_traits::{AsPrimitive, FromPrimitive};
16use rand_distr::Distribution;
17use snafu::location;
18
19use crate::vector::bq::storage::{
20 RabitQuantizationMetadata, RabitQuantizationStorage, RABIT_CODE_COLUMN, RABIT_METADATA_KEY,
21};
22use crate::vector::bq::transform::{ADD_FACTORS_FIELD, SCALE_FACTORS_FIELD};
23use crate::vector::bq::RQBuildParams;
24use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams};
25
/// Build parameters for a RabitQ quantizer.
#[derive(Debug, Clone)]
pub struct RabitBuildParams {
    /// Number of code bits per vector dimension.
    pub num_bits: u8,
}
32
33impl Default for RabitBuildParams {
34 fn default() -> Self {
35 Self { num_bits: 1 }
36 }
37}
38
39impl QuantizerBuildParams for RabitBuildParams {
40 fn sample_size(&self) -> usize {
41 0
43 }
44}
45
/// RabitQ quantizer: encodes each (residual) vector as the sign bits of its randomly
/// rotated coordinates.
///
/// All state is held in `metadata` (rotation matrix, bits per dimension, packing flag);
/// `from_metadata` reconstructs a quantizer from it.
#[derive(Debug, Clone, DeepSizeOf)]
pub struct RabitQuantizer {
    metadata: RabitQuantizationMetadata,
}
50
51impl RabitQuantizer {
52 pub fn new<T: ArrowFloatType>(num_bits: u8, dim: i32) -> Self {
53 let code_dim = dim * num_bits as i32;
56 let rotate_mat = random_orthogonal::<T>(code_dim as usize);
57 let (rotate_mat, _) = rotate_mat.into_raw_vec_and_offset();
58
59 let rotate_mat = match T::FLOAT_TYPE {
60 FloatType::Float16 | FloatType::Float32 | FloatType::Float64 => {
61 let rotate_mat = T::ArrayType::from(rotate_mat);
62 FixedSizeListArray::try_new_from_values(rotate_mat, code_dim).unwrap()
63 }
64 _ => unimplemented!("RabitQ does not support data type: {:?}", T::FLOAT_TYPE),
65 };
66
67 let metadata = RabitQuantizationMetadata {
68 rotate_mat: Some(rotate_mat),
69 rotate_mat_position: 0,
70 num_bits,
71 packed: false,
72 };
73 Self { metadata }
74 }
75
76 pub fn num_bits(&self) -> u8 {
77 self.metadata.num_bits
78 }
79
80 #[inline]
81 fn rotate_mat_flat<T: ArrowFloatType>(&self) -> &[T::Native] {
82 let rotate_mat = self.metadata.rotate_mat.as_ref().unwrap();
83 rotate_mat
84 .values()
85 .as_any()
86 .downcast_ref::<T::ArrayType>()
87 .unwrap()
88 .as_slice()
89 }
90
91 #[inline]
92 fn rotate_mat<T: ArrowFloatType>(&'_ self) -> ndarray::ArrayView2<'_, T::Native> {
93 let code_dim = self.code_dim();
94 ndarray::ArrayView2::from_shape((code_dim, code_dim), self.rotate_mat_flat::<T>()).unwrap()
95 }
96
97 pub fn dim(&self) -> usize {
98 self.code_dim() / self.metadata.num_bits as usize
99 }
100
101 pub fn codes_res_dot_dists<T: ArrowFloatType>(
103 &self,
104 residual_vectors: &FixedSizeListArray,
105 ) -> Result<Vec<f32>>
106 where
107 T::Native: AsPrimitive<f32>,
108 {
109 let dim = self.dim();
110 if residual_vectors.value_length() as usize != dim {
111 return Err(Error::invalid_input(
112 format!(
113 "Vector dimension mismatch: {} != {}",
114 residual_vectors.value_length(),
115 dim
116 ),
117 location!(),
118 ));
119 }
120
121 let vec_mat = ndarray::ArrayView2::from_shape(
123 (residual_vectors.len(), dim),
124 residual_vectors
125 .values()
126 .as_any()
127 .downcast_ref::<T::ArrayType>()
128 .unwrap()
129 .as_slice(),
130 )
131 .map_err(|e| Error::invalid_input(e.to_string(), location!()))?;
132 let vec_mat = vec_mat.t();
133
134 let rotate_mat = self.rotate_mat::<T>();
135 let rotate_mat = rotate_mat.slice(s![.., 0..dim]);
137 let rotated_vectors = rotate_mat.dot(&vec_mat);
138 let sqrt_dim = (dim as f32 * self.metadata.num_bits as f32).sqrt();
139 let norm_dists = rotated_vectors.mapv(|v| v.as_().abs()).sum_axis(Axis(0)) / sqrt_dim;
140 debug_assert_eq!(norm_dists.len(), residual_vectors.len());
141 Ok(norm_dists.to_vec())
142 }
143
144 fn transform<T: ArrowFloatType>(
145 &self,
146 residual_vectors: &FixedSizeListArray,
147 ) -> Result<ArrayRef>
148 where
149 T::Native: AsPrimitive<f32>,
150 {
151 let n = residual_vectors.len();
154 let dim = self.dim();
155 debug_assert_eq!(residual_vectors.values().len(), n * dim);
156
157 let vectors = ndarray::ArrayView2::from_shape(
158 (n, dim),
159 residual_vectors
160 .values()
161 .as_any()
162 .downcast_ref::<T::ArrayType>()
163 .unwrap()
164 .as_slice(),
165 )
166 .map_err(|e| Error::invalid_input(e.to_string(), location!()))?;
167 let vectors = vectors.t();
168 let rotate_mat = self.rotate_mat::<T>();
169 let rotate_mat = rotate_mat.slice(s![.., 0..dim]);
170 let rotated_vectors = rotate_mat.dot(&vectors);
171
172 let quantized_vectors = rotated_vectors.t().mapv(|v| v.as_().is_sign_positive());
173 let bv: BitVec<u8, Lsb0> = BitVec::from_iter(quantized_vectors);
174
175 let codes = UInt8Array::from(bv.into_vec());
176 debug_assert_eq!(codes.len(), n * self.code_dim() / u8::BITS as usize);
177 Ok(Arc::new(FixedSizeListArray::try_new_from_values(
178 codes,
179 self.code_dim() as i32 / u8::BITS as i32, )?))
181 }
182}
183
184impl Quantization for RabitQuantizer {
185 type BuildParams = RQBuildParams;
186 type Metadata = RabitQuantizationMetadata;
187 type Storage = RabitQuantizationStorage;
188
189 fn build(
190 data: &dyn Array,
191 _: lance_linalg::distance::DistanceType,
192 params: &Self::BuildParams,
193 ) -> Result<Self> {
194 let q = match data.as_fixed_size_list().value_type() {
195 DataType::Float16 => {
196 Self::new::<Float16Type>(params.num_bits, data.as_fixed_size_list().value_length())
197 }
198 DataType::Float32 => {
199 Self::new::<Float32Type>(params.num_bits, data.as_fixed_size_list().value_length())
200 }
201 DataType::Float64 => {
202 Self::new::<Float64Type>(params.num_bits, data.as_fixed_size_list().value_length())
203 }
204 dt => {
205 return Err(Error::invalid_input(
206 format!("Unsupported data type: {:?}", dt),
207 location!(),
208 ))
209 }
210 };
211 Ok(q)
212 }
213
214 fn retrain(&mut self, _data: &dyn Array) -> Result<()> {
215 Ok(())
216 }
217
218 fn code_dim(&self) -> usize {
219 self.metadata
220 .rotate_mat
221 .as_ref()
222 .map(|inv_p| inv_p.len())
223 .unwrap_or(0)
224 }
225
226 fn column(&self) -> &'static str {
227 RABIT_CODE_COLUMN
228 }
229
230 fn use_residual(_: lance_linalg::distance::DistanceType) -> bool {
231 true
232 }
233
234 fn quantize(&self, vectors: &dyn Array) -> Result<arrow_array::ArrayRef> {
235 let vectors = vectors.as_fixed_size_list();
236 match vectors.value_type() {
237 DataType::Float16 => self.transform::<Float16Type>(vectors),
238 DataType::Float32 => self.transform::<Float32Type>(vectors),
239 DataType::Float64 => self.transform::<Float64Type>(vectors),
240 value_type => Err(Error::invalid_input(
241 format!("Unsupported data type: {:?}", value_type),
242 location!(),
243 )),
244 }
245 }
246
247 fn metadata_key() -> &'static str {
248 RABIT_METADATA_KEY
249 }
250
251 fn quantization_type() -> crate::vector::quantizer::QuantizationType {
252 crate::vector::quantizer::QuantizationType::Rabit
253 }
254
255 fn metadata(
256 &self,
257 args: Option<crate::vector::quantizer::QuantizationMetadata>,
258 ) -> Self::Metadata {
259 let mut metadata = self.metadata.clone();
260 metadata.packed = args.map(|args| args.transposed).unwrap_or_default();
261 metadata
262 }
263
264 fn from_metadata(
265 metadata: &Self::Metadata,
266 _: lance_linalg::distance::DistanceType,
267 ) -> Result<Quantizer> {
268 Ok(Quantizer::Rabit(Self {
269 metadata: metadata.clone(),
270 }))
271 }
272
273 fn field(&self) -> Field {
274 Field::new(
275 RABIT_CODE_COLUMN,
276 DataType::FixedSizeList(
277 Arc::new(Field::new("item", DataType::UInt8, true)),
278 self.code_dim() as i32 / u8::BITS as i32, ),
280 true,
281 )
282 }
283
284 fn extra_fields(&self) -> Vec<Field> {
285 vec![ADD_FACTORS_FIELD.clone(), SCALE_FACTORS_FIELD.clone()]
286 }
287}
288
289impl TryFrom<Quantizer> for RabitQuantizer {
290 type Error = Error;
291
292 fn try_from(quantizer: Quantizer) -> Result<Self> {
293 match quantizer {
294 Quantizer::Rabit(quantizer) => Ok(quantizer),
295 _ => Err(Error::invalid_input(
296 "Cannot convert non-RabitQuantizer to RabitQuantizer",
297 location!(),
298 )),
299 }
300 }
301}
302
303impl From<RabitQuantizer> for Quantizer {
304 fn from(quantizer: RabitQuantizer) -> Self {
305 Self::Rabit(quantizer)
306 }
307}
308
309fn random_normal_matrix(n: usize) -> ndarray::Array2<f64> {
310 let mut rng = rand::rng();
311 let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
312 ndarray::Array2::from_shape_simple_fn((n, n), || normal.sample(&mut rng))
313}
314
315fn householder_qr(a: ndarray::Array2<f64>) -> (ndarray::Array2<f64>, ndarray::Array2<f64>) {
317 let (m, n) = a.dim();
318 let mut q = ndarray::Array2::eye(m);
319 let mut r = a;
320
321 for k in 0..n.min(m - 1) {
322 let mut x = r.slice(s![k.., k]).to_owned();
323 let x_norm = x.dot(&x).sqrt();
324
325 if x_norm < f64::EPSILON {
326 continue;
327 }
328
329 let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 };
331 x[0] += sign * x_norm;
332 let u = &x / x.dot(&x).sqrt();
333
334 let mut u_outer = ndarray::Array2::zeros((m - k, m - k));
337 for i in 0..(m - k) {
338 for j in 0..(m - k) {
339 u_outer[[i, j]] = u[i] * u[j];
340 }
341 }
342 let h = ndarray::Array2::eye(m - k) - 2.0 * u_outer;
343
344 let r_block = r.slice(s![k.., k..]).to_owned();
346 let h_r = h.dot(&r_block);
347 r.slice_mut(s![k.., k..]).assign(&h_r);
348
349 let q_block = q.slice(s![.., k..]).to_owned();
351 let q_h = q_block.dot(&h);
352 q.slice_mut(s![.., k..]).assign(&q_h);
353 }
354
355 (q, r)
356}
357
358fn random_orthogonal<T: ArrowFloatType>(n: usize) -> ndarray::Array2<T::Native>
359where
360 T::Native: FromPrimitive,
361{
362 let a = random_normal_matrix(n);
363 let (q, _) = householder_qr(a);
364
365 q.mapv(|v| T::Native::from_f64(v).unwrap())
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use rstest::rstest;

    /// Checks that `householder_qr` yields an orthogonal Q and an upper-triangular R whose
    /// product reconstructs the input.
    #[rstest]
    #[case(8)]
    #[case(16)]
    #[case(32)]
    fn test_householder_qr(#[case] size: usize) {
        let a = random_normal_matrix(size);
        let (m, n) = a.dim();

        let (q, r) = householder_qr(a.clone());

        // Q is square m x m and R has the input's shape.
        assert_eq!(q.dim(), (m, m));
        assert_eq!(r.dim(), (m, n));

        // Orthogonality: QᵀQ must equal the identity.
        let gram = q.t().dot(&q);
        for ((i, j), &value) in gram.indexed_iter() {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert_relative_eq!(value, expected, epsilon = 1e-5);
        }

        // Reconstruction: QR must equal A.
        let reconstructed = q.dot(&r);
        for ((i, j), &value) in reconstructed.indexed_iter() {
            assert_relative_eq!(value, a[[i, j]], epsilon = 1e-5);
        }

        // Upper-triangularity: every entry strictly below R's diagonal is zero.
        for i in 1..n.min(m) {
            for j in 0..i {
                assert_relative_eq!(r[[i, j]], 0.0, epsilon = 1e-5);
            }
        }
    }
}