1use std::iter::once;
7use std::sync::Arc;
8
9use arrow_array::types::Float32Type;
10use arrow_array::{cast::AsArray, Array, ArrayRef, UInt8Array};
11use lance_core::{Error, Result};
12use num_traits::Float;
13use snafu::location;
14
15use crate::vector::quantizer::QuantizerBuildParams;
16
17pub mod builder;
18pub mod storage;
19pub mod transform;
20
21#[derive(Clone, Default)]
22pub struct BinaryQuantization {}
23
24impl BinaryQuantization {
25 pub fn transform(&self, data: &dyn Array) -> Result<ArrayRef> {
27 let fsl = data
28 .as_fixed_size_list_opt()
29 .ok_or(Error::Index {
30 message: format!(
31 "Expect to be a float vector array, got: {:?}",
32 data.data_type()
33 ),
34 location: location!(),
35 })?
36 .clone();
37
38 let data = fsl
39 .values()
40 .as_primitive_opt::<Float32Type>()
41 .ok_or(Error::Index {
42 message: format!(
43 "Expect to be a float32 vector array, got: {:?}",
44 fsl.values().data_type()
45 ),
46 location: location!(),
47 })?;
48 let dim = fsl.value_length() as usize;
49 let code = data
50 .values()
51 .chunks_exact(dim)
52 .flat_map(binary_quantization)
53 .collect::<Vec<_>>();
54
55 Ok(Arc::new(UInt8Array::from(code)))
56 }
57}
58
59fn binary_quantization<T: Float>(data: &[T]) -> impl Iterator<Item = u8> + '_ {
63 let iter = data.chunks_exact(8);
64 iter.clone()
65 .map(|c| {
66 let mut bits: u8 = 0;
69 c.iter().enumerate().for_each(|(idx, v)| {
70 bits |= (v.is_sign_positive() as u8) << idx;
71 });
72 bits
73 })
74 .chain(once(0).map(move |_| {
75 let mut bits: u8 = 0;
76 iter.remainder().iter().enumerate().for_each(|(idx, v)| {
77 bits |= (v.is_sign_positive() as u8) << idx;
78 });
79 bits
80 }))
81}
82
/// Build parameters for the RQ quantizer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RQBuildParams {
    /// Bit-width setting for the quantizer.
    ///
    /// NOTE(review): presumably the number of bits per dimension; confirm against the RQ
    /// quantizer implementation.
    pub num_bits: u8,
}

impl RQBuildParams {
    /// Creates build parameters that use `num_bits` bits.
    pub fn new(num_bits: u8) -> Self {
        Self { num_bits }
    }
}
93
94impl QuantizerBuildParams for RQBuildParams {
95 fn sample_size(&self) -> usize {
96 0
97 }
98}
99
100impl Default for RQBuildParams {
101 fn default() -> Self {
102 Self { num_bits: 1 }
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    use half::{bf16, f16};

    /// Quantizes an 11-value vector (one full group of eight plus a 3-value tail) as element
    /// type `T` and checks the packed sign bytes.
    fn test_bq<T: Float>() {
        const INPUT: [f64; 11] = [1.0, -1.0, 1.0, -5.0, -7.0, -1.0, 1.0, -1.0, -0.2, 1.2, 3.2];

        let mut data: Vec<T> = Vec::with_capacity(INPUT.len());
        for &x in INPUT.iter() {
            data.push(T::from(x).unwrap());
        }

        // Byte 0: positives at positions 0, 2 and 6. Byte 1: tail positives at positions 1 and 2.
        let expected: Vec<u8> = vec![0b0100_0101, 0b0000_0110];
        let actual: Vec<u8> = binary_quantization(&data).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_binary_quantization() {
        test_bq::<bf16>();
        test_bq::<f16>();
        test_bq::<f32>();
        test_bq::<f64>();
    }
}