1use std::sync::Arc;
5
6use arrow::array::AsArray;
7use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
8use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};
9use arrow_schema::{DataType, Field};
10use bitvec::prelude::{BitVec, Lsb0};
11use deepsize::DeepSizeOf;
12use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, FloatType};
13use lance_core::{Error, Result};
14use ndarray::{Axis, ShapeBuilder, s};
15use num_traits::{AsPrimitive, FromPrimitive};
16use rand_distr::Distribution;
17use rayon::prelude::*;
18
19use crate::vector::bq::storage::{
20 RABIT_CODE_COLUMN, RABIT_METADATA_KEY, RabitQuantizationMetadata, RabitQuantizationStorage,
21};
22use crate::vector::bq::transform::{ADD_FACTORS_FIELD, SCALE_FACTORS_FIELD};
23use crate::vector::bq::{
24 RQBuildParams, RQRotationType,
25 rotation::{apply_fast_rotation, random_fast_rotation_signs},
26};
27use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams};
28
29pub struct RabitBuildParams {
33 pub num_bits: u8,
34 pub rotation_type: RQRotationType,
35}
36
37impl Default for RabitBuildParams {
38 fn default() -> Self {
39 Self {
40 num_bits: 1,
41 rotation_type: RQRotationType::default(),
42 }
43 }
44}
45
46impl QuantizerBuildParams for RabitBuildParams {
47 fn sample_size(&self) -> usize {
48 0
50 }
51}
52
53#[derive(Debug, Clone, DeepSizeOf)]
54pub struct RabitQuantizer {
55 metadata: RabitQuantizationMetadata,
56}
57
/// Packs the sign of every value in `rotated` into `codes`, one bit per value.
///
/// `codes` is zeroed first. Bit `i % 8` of byte `i / 8` (least-significant bit
/// first) is then set when `rotated[i].is_sign_positive()`, so `+0.0` sets its
/// bit while `-0.0` does not.
///
/// # Panics
///
/// Panics if a positive-sign value sits at an index beyond `codes.len() * 8`.
#[inline]
fn pack_sign_bits(codes: &mut [u8], rotated: &[f32]) {
    const BITS_PER_BYTE: usize = u8::BITS as usize;

    codes.fill(0);
    let positive_bits = rotated
        .iter()
        .enumerate()
        .filter_map(|(bit_idx, value)| value.is_sign_positive().then_some(bit_idx));
    for bit_idx in positive_bits {
        codes[bit_idx / BITS_PER_BYTE] |= 1u8 << (bit_idx % BITS_PER_BYTE);
    }
}
67
68impl RabitQuantizer {
69 pub fn new<T: ArrowFloatType>(num_bits: u8, dim: i32) -> Self {
70 Self::new_with_rotation::<T>(num_bits, dim, RQRotationType::default())
71 }
72
73 pub fn new_with_rotation<T: ArrowFloatType>(
74 num_bits: u8,
75 dim: i32,
76 rotation_type: RQRotationType,
77 ) -> Self {
78 let code_dim = (dim * num_bits as i32) as usize;
79 let metadata = match rotation_type {
80 RQRotationType::Matrix => {
81 let rotate_mat = random_orthogonal::<T>(code_dim);
83 let (rotate_mat, _) = rotate_mat.into_raw_vec_and_offset();
84 let rotate_mat = match T::FLOAT_TYPE {
85 FloatType::Float16 | FloatType::Float32 | FloatType::Float64 => {
86 let rotate_mat = <T::ArrayType as FloatArray<T>>::from_values(rotate_mat);
87 FixedSizeListArray::try_new_from_values(rotate_mat, code_dim as i32)
88 .unwrap()
89 }
90 _ => unimplemented!("RabitQ does not support data type: {:?}", T::FLOAT_TYPE),
91 };
92 RabitQuantizationMetadata {
93 rotate_mat: Some(rotate_mat),
94 rotate_mat_position: None,
95 fast_rotation_signs: None,
96 rotation_type,
97 code_dim: code_dim as u32,
98 num_bits,
99 packed: false,
100 }
101 }
102 RQRotationType::Fast => RabitQuantizationMetadata {
103 rotate_mat: None,
104 rotate_mat_position: None,
105 fast_rotation_signs: Some(random_fast_rotation_signs(code_dim)),
106 rotation_type,
107 code_dim: code_dim as u32,
108 num_bits,
109 packed: false,
110 },
111 };
112 Self { metadata }
113 }
114
115 pub fn num_bits(&self) -> u8 {
116 self.metadata.num_bits
117 }
118
119 pub fn rotation_type(&self) -> RQRotationType {
120 self.metadata.rotation_type
121 }
122
123 #[inline]
124 fn fast_rotation_signs(&self) -> &[u8] {
125 self.metadata
126 .fast_rotation_signs
127 .as_ref()
128 .expect("RabitQ fast rotation signs missing")
129 .as_slice()
130 }
131
132 #[inline]
133 fn rotate_mat_flat<T: ArrowFloatType>(&self) -> &[T::Native] {
134 let rotate_mat = self.metadata.rotate_mat.as_ref().unwrap();
135 rotate_mat
136 .values()
137 .as_any()
138 .downcast_ref::<T::ArrayType>()
139 .unwrap()
140 .as_slice()
141 }
142
143 #[inline]
144 fn rotate_mat<T: ArrowFloatType>(&'_ self) -> ndarray::ArrayView2<'_, T::Native> {
145 let code_dim = self.code_dim();
146 ndarray::ArrayView2::from_shape((code_dim, code_dim), self.rotate_mat_flat::<T>()).unwrap()
147 }
148
149 fn rotate_vectors<T: ArrowFloatType>(
150 &self,
151 vectors: ndarray::ArrayView2<'_, T::Native>,
152 ) -> ndarray::Array2<f32>
153 where
154 T::Native: AsPrimitive<f32>,
155 {
156 let dim = vectors.nrows();
157 let code_dim = self.code_dim();
158 match self.rotation_type() {
159 RQRotationType::Matrix => {
160 let rotate_mat = self.rotate_mat::<T>();
161 let rotate_mat = rotate_mat.slice(s![.., 0..dim]);
162 rotate_mat.dot(&vectors).mapv(|v| v.as_())
163 }
164 RQRotationType::Fast => {
165 let signs = self.fast_rotation_signs();
166 let ncols = vectors.ncols();
167 let mut rotated_data = vec![0.0f32; code_dim * ncols];
168 rotated_data
169 .par_chunks_mut(code_dim)
170 .enumerate()
171 .for_each_init(
172 || vec![0.0f32; code_dim],
173 |scratch, (col_idx, dst)| {
174 let column = vectors.column(col_idx);
175 let input = column
176 .as_slice()
177 .expect("RabitQ input vectors should be contiguous");
178 apply_fast_rotation(input, scratch, signs);
179 dst.copy_from_slice(scratch);
180 },
181 );
182
183 ndarray::Array2::from_shape_vec((code_dim, ncols).f(), rotated_data).unwrap()
184 }
185 }
186 }
187
188 pub fn dim(&self) -> usize {
189 self.code_dim() / self.metadata.num_bits as usize
190 }
191
192 pub fn codes_res_dot_dists<T: ArrowFloatType>(
194 &self,
195 residual_vectors: &FixedSizeListArray,
196 ) -> Result<Vec<f32>>
197 where
198 T::Native: AsPrimitive<f32> + Sync,
199 {
200 let dim = self.dim();
201 if residual_vectors.value_length() as usize != dim {
202 return Err(Error::invalid_input(format!(
203 "Vector dimension mismatch: {} != {}",
204 residual_vectors.value_length(),
205 dim
206 )));
207 }
208
209 let sqrt_dim = (dim as f32 * self.metadata.num_bits as f32).sqrt();
210 let values = residual_vectors
211 .values()
212 .as_any()
213 .downcast_ref::<T::ArrayType>()
214 .unwrap()
215 .as_slice();
216
217 match self.rotation_type() {
218 RQRotationType::Matrix => {
219 let vec_mat =
221 ndarray::ArrayView2::from_shape((residual_vectors.len(), dim), values)
222 .map_err(|e| Error::invalid_input(e.to_string()))?;
223 let vec_mat = vec_mat.t();
224 let rotated_vectors = self.rotate_vectors::<T>(vec_mat);
225 let norm_dists = rotated_vectors.mapv(f32::abs).sum_axis(Axis(0)) / sqrt_dim;
226 debug_assert_eq!(norm_dists.len(), residual_vectors.len());
227 Ok(norm_dists.to_vec())
228 }
229 RQRotationType::Fast => {
230 let code_dim = self.code_dim();
231 let signs = self.fast_rotation_signs();
232 let mut norm_dists = vec![0.0f32; residual_vectors.len()];
233 norm_dists
234 .par_iter_mut()
235 .zip(values.par_chunks_exact(dim))
236 .for_each_init(
237 || vec![0.0f32; code_dim],
238 |scratch, (dst, input)| {
239 apply_fast_rotation(input, scratch, signs);
240 *dst = scratch.iter().map(|v| v.abs()).sum::<f32>() / sqrt_dim;
241 },
242 );
243 Ok(norm_dists)
244 }
245 }
246 }
247
248 fn transform<T: ArrowFloatType>(
249 &self,
250 residual_vectors: &FixedSizeListArray,
251 ) -> Result<ArrayRef>
252 where
253 T::Native: AsPrimitive<f32> + Sync,
254 {
255 let n = residual_vectors.len();
258 let dim = self.dim();
259 debug_assert_eq!(residual_vectors.values().len(), n * dim);
260 let values = residual_vectors
261 .values()
262 .as_any()
263 .downcast_ref::<T::ArrayType>()
264 .unwrap()
265 .as_slice();
266 let code_dim = self.code_dim();
267 let code_bytes = code_dim / u8::BITS as usize;
268
269 match self.rotation_type() {
270 RQRotationType::Matrix => {
271 let vectors = ndarray::ArrayView2::from_shape((n, dim), values)
272 .map_err(|e| Error::invalid_input(e.to_string()))?;
273 let vectors = vectors.t();
274 let rotated_vectors = self.rotate_vectors::<T>(vectors);
275
276 let quantized_vectors = rotated_vectors.t().mapv(|v| v.is_sign_positive());
277 let bv: BitVec<u8, Lsb0> = BitVec::from_iter(quantized_vectors);
278
279 let codes = UInt8Array::from(bv.into_vec());
280 debug_assert_eq!(codes.len(), n * code_bytes);
281 Ok(Arc::new(FixedSizeListArray::try_new_from_values(
282 codes,
283 code_bytes as i32, )?))
285 }
286 RQRotationType::Fast => {
287 let signs = self.fast_rotation_signs();
288 let mut encoded_codes = vec![0u8; n * code_bytes];
289 encoded_codes
290 .par_chunks_mut(code_bytes)
291 .zip(values.par_chunks_exact(dim))
292 .for_each_init(
293 || vec![0.0f32; code_dim],
294 |scratch, (code_dst, input)| {
295 apply_fast_rotation(input, scratch, signs);
296 pack_sign_bits(code_dst, scratch);
297 },
298 );
299 let codes = UInt8Array::from(encoded_codes);
300 debug_assert_eq!(codes.len(), n * code_bytes);
301 Ok(Arc::new(FixedSizeListArray::try_new_from_values(
302 codes,
303 code_bytes as i32,
304 )?))
305 }
306 }
307 }
308}
309
310impl Quantization for RabitQuantizer {
311 type BuildParams = RQBuildParams;
312 type Metadata = RabitQuantizationMetadata;
313 type Storage = RabitQuantizationStorage;
314
315 fn build(
316 data: &dyn Array,
317 _: lance_linalg::distance::DistanceType,
318 params: &Self::BuildParams,
319 ) -> Result<Self> {
320 let dim = data.as_fixed_size_list().value_length() as usize;
321 if !dim.is_multiple_of(u8::BITS as usize) {
322 return Err(Error::invalid_input(
323 "vector dimension must be divisible by 8 for IVF_RQ",
324 ));
325 }
326
327 let q = match data.as_fixed_size_list().value_type() {
328 DataType::Float16 => Self::new_with_rotation::<Float16Type>(
329 params.num_bits,
330 data.as_fixed_size_list().value_length(),
331 params.rotation_type,
332 ),
333 DataType::Float32 => Self::new_with_rotation::<Float32Type>(
334 params.num_bits,
335 data.as_fixed_size_list().value_length(),
336 params.rotation_type,
337 ),
338 DataType::Float64 => Self::new_with_rotation::<Float64Type>(
339 params.num_bits,
340 data.as_fixed_size_list().value_length(),
341 params.rotation_type,
342 ),
343 dt => {
344 return Err(Error::invalid_input(format!(
345 "Unsupported data type: {:?}",
346 dt
347 )));
348 }
349 };
350 Ok(q)
351 }
352
353 fn retrain(&mut self, _data: &dyn Array) -> Result<()> {
354 Ok(())
355 }
356
357 fn code_dim(&self) -> usize {
358 if self.metadata.code_dim > 0 {
359 self.metadata.code_dim as usize
360 } else {
361 self.metadata
362 .rotate_mat
363 .as_ref()
364 .map(|rotate_mat| rotate_mat.len())
365 .unwrap_or(0)
366 }
367 }
368
369 fn column(&self) -> &'static str {
370 RABIT_CODE_COLUMN
371 }
372
373 fn use_residual(_: lance_linalg::distance::DistanceType) -> bool {
374 true
375 }
376
377 fn quantize(&self, vectors: &dyn Array) -> Result<arrow_array::ArrayRef> {
378 let vectors = vectors.as_fixed_size_list();
379 match vectors.value_type() {
380 DataType::Float16 => self.transform::<Float16Type>(vectors),
381 DataType::Float32 => self.transform::<Float32Type>(vectors),
382 DataType::Float64 => self.transform::<Float64Type>(vectors),
383 value_type => Err(Error::invalid_input(format!(
384 "Unsupported data type: {:?}",
385 value_type
386 ))),
387 }
388 }
389
390 fn metadata_key() -> &'static str {
391 RABIT_METADATA_KEY
392 }
393
394 fn quantization_type() -> crate::vector::quantizer::QuantizationType {
395 crate::vector::quantizer::QuantizationType::Rabit
396 }
397
398 fn metadata(
399 &self,
400 args: Option<crate::vector::quantizer::QuantizationMetadata>,
401 ) -> Self::Metadata {
402 let mut metadata = self.metadata.clone();
403 metadata.packed = args.map(|args| args.transposed).unwrap_or_default();
404 metadata
405 }
406
407 fn from_metadata(
408 metadata: &Self::Metadata,
409 _: lance_linalg::distance::DistanceType,
410 ) -> Result<Quantizer> {
411 Ok(Quantizer::Rabit(Self {
412 metadata: metadata.clone(),
413 }))
414 }
415
416 fn field(&self) -> Field {
417 Field::new(
418 RABIT_CODE_COLUMN,
419 DataType::FixedSizeList(
420 Arc::new(Field::new("item", DataType::UInt8, true)),
421 self.code_dim() as i32 / u8::BITS as i32, ),
423 true,
424 )
425 }
426
427 fn extra_fields(&self) -> Vec<Field> {
428 vec![ADD_FACTORS_FIELD.clone(), SCALE_FACTORS_FIELD.clone()]
429 }
430}
431
432impl TryFrom<Quantizer> for RabitQuantizer {
433 type Error = Error;
434
435 fn try_from(quantizer: Quantizer) -> Result<Self> {
436 match quantizer {
437 Quantizer::Rabit(quantizer) => Ok(quantizer),
438 _ => Err(Error::invalid_input(
439 "Cannot convert non-RabitQuantizer to RabitQuantizer",
440 )),
441 }
442 }
443}
444
445impl From<RabitQuantizer> for Quantizer {
446 fn from(quantizer: RabitQuantizer) -> Self {
447 Self::Rabit(quantizer)
448 }
449}
450
451fn random_normal_matrix(n: usize) -> ndarray::Array2<f64> {
452 let mut rng = rand::rng();
453 let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
454 ndarray::Array2::from_shape_simple_fn((n, n), || normal.sample(&mut rng))
455}
456
457fn householder_qr(a: ndarray::Array2<f64>) -> (ndarray::Array2<f64>, ndarray::Array2<f64>) {
459 let (m, n) = a.dim();
460 let mut q = ndarray::Array2::eye(m);
461 let mut r = a;
462
463 for k in 0..n.min(m - 1) {
464 let mut x = r.slice(s![k.., k]).to_owned();
465 let x_norm = x.dot(&x).sqrt();
466
467 if x_norm < f64::EPSILON {
468 continue;
469 }
470
471 let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 };
473 x[0] += sign * x_norm;
474 let u = &x / x.dot(&x).sqrt();
475
476 let mut u_outer = ndarray::Array2::zeros((m - k, m - k));
479 for i in 0..(m - k) {
480 for j in 0..(m - k) {
481 u_outer[[i, j]] = u[i] * u[j];
482 }
483 }
484 let h = ndarray::Array2::eye(m - k) - 2.0 * u_outer;
485
486 let r_block = r.slice(s![k.., k..]).to_owned();
488 let h_r = h.dot(&r_block);
489 r.slice_mut(s![k.., k..]).assign(&h_r);
490
491 let q_block = q.slice(s![.., k..]).to_owned();
493 let q_h = q_block.dot(&h);
494 q.slice_mut(s![.., k..]).assign(&q_h);
495 }
496
497 (q, r)
498}
499
500fn random_orthogonal<T: ArrowFloatType>(n: usize) -> ndarray::Array2<T::Native>
501where
502 T::Native: FromPrimitive,
503{
504 let a = random_normal_matrix(n);
505 let (q, _) = householder_qr(a);
506
507 q.mapv(|v| T::Native::from_f64(v).unwrap())
509}
510
511#[cfg(test)]
512mod tests {
513 use super::*;
514 use approx::assert_relative_eq;
515 use arrow::datatypes::Float32Type;
516 use arrow_array::{FixedSizeListArray, Float32Array};
517 use lance_linalg::distance::DistanceType;
518 use rstest::rstest;
519
520 #[rstest]
521 #[case(8)]
522 #[case(16)]
523 #[case(32)]
524 fn test_householder_qr(#[case] n: usize) {
525 let a = random_normal_matrix(n);
526 let (m, n) = a.dim();
527
528 let (q, r) = householder_qr(a.clone());
529
530 let q_t_q = q.t().dot(&q);
532 for i in 0..m {
533 for j in 0..m {
534 let expected = if i == j { 1.0 } else { 0.0 };
535 assert_relative_eq!(q_t_q[[i, j]], expected, epsilon = 1e-5);
536 }
537 }
538
539 let qr = q.dot(&r);
541 for i in 0..m {
542 for j in 0..n {
543 assert_relative_eq!(qr[[i, j]], a[[i, j]], epsilon = 1e-5);
544 }
545 }
546
547 for i in 1..n.min(m) {
549 for j in 0..i {
550 assert_relative_eq!(r[[i, j]], 0.0, epsilon = 1e-5);
551 }
552 }
553
554 assert_eq!(q.dim(), (m, m));
556 assert_eq!(r.dim(), (m, n));
557 }
558
/// Both rotation modes must report the rotation type and dimension they were
/// constructed with.
#[test]
fn test_rabit_quantizer_rotation_modes() {
    for rotation in [RQRotationType::Fast, RQRotationType::Matrix] {
        let quantizer = RabitQuantizer::new_with_rotation::<Float32Type>(1, 128, rotation);
        assert_eq!(quantizer.rotation_type(), rotation);
        assert_eq!(quantizer.dim(), 128);
    }
}
570
/// Building a quantizer for vectors whose dimension is not a multiple of 8
/// must fail with a descriptive error.
#[test]
fn test_rabit_quantizer_requires_dim_divisible_by_8() {
    // Four 30-dimensional vectors; 30 is not divisible by 8.
    let vectors = Float32Array::from(vec![0.0f32; 4 * 30]);
    let fsl = FixedSizeListArray::try_new_from_values(vectors, 30).unwrap();
    let params = RQBuildParams::new(1);

    // The argument had been mangled into `¶ms`; it is a borrow of `params`.
    let err = RabitQuantizer::build(&fsl, DistanceType::L2, &params).unwrap_err();
    assert!(
        err.to_string()
            .contains("vector dimension must be divisible by 8 for IVF_RQ"),
        "{}",
        err
    );
}
585}