1use std::iter::once;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use crate::pb::vector_index_details::RabitQuantization;
11use arrow_array::types::Float32Type;
12use arrow_array::{Array, ArrayRef, UInt8Array, cast::AsArray};
13use lance_core::{Error, Result};
14use num_traits::Float;
15use serde::{Deserialize, Serialize};
16
17use crate::vector::quantizer::QuantizerBuildParams;
18
19pub mod builder;
20pub mod rotation;
21pub mod storage;
22pub mod transform;
23
24#[derive(Clone, Default)]
25pub struct BinaryQuantization {}
26
27impl BinaryQuantization {
28 pub fn transform(&self, data: &dyn Array) -> Result<ArrayRef> {
30 let fsl = data
31 .as_fixed_size_list_opt()
32 .ok_or(Error::index(format!(
33 "Expect to be a float vector array, got: {:?}",
34 data.data_type()
35 )))?
36 .clone();
37
38 let data = fsl
39 .values()
40 .as_primitive_opt::<Float32Type>()
41 .ok_or(Error::index(format!(
42 "Expect to be a float32 vector array, got: {:?}",
43 fsl.values().data_type()
44 )))?;
45 let dim = fsl.value_length() as usize;
46 let code = data
47 .values()
48 .chunks_exact(dim)
49 .flat_map(binary_quantization)
50 .collect::<Vec<_>>();
51
52 Ok(Arc::new(UInt8Array::from(code)))
53 }
54}
55
56fn binary_quantization<T: Float>(data: &[T]) -> impl Iterator<Item = u8> + '_ {
60 let iter = data.chunks_exact(8);
61 iter.clone()
62 .map(|c| {
63 let mut bits: u8 = 0;
66 c.iter().enumerate().for_each(|(idx, v)| {
67 bits |= (v.is_sign_positive() as u8) << idx;
68 });
69 bits
70 })
71 .chain(once(0).map(move |_| {
72 let mut bits: u8 = 0;
73 iter.remainder().iter().enumerate().for_each(|(idx, v)| {
74 bits |= (v.is_sign_positive() as u8) << idx;
75 });
76 bits
77 }))
78}
79
/// Rotation strategy that can be selected in [`RQBuildParams`].
///
/// Serialized by serde in snake_case (`"fast"` / `"matrix"`); `Fast` is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RQRotationType {
    /// Default variant. The string parser also accepts `fht_kac` / `fht-kac` for it.
    /// NOTE(review): presumably a fast Hadamard-style rotation; confirm against the
    /// `rotation` module.
    #[default]
    Fast,
    /// The string parser also accepts `dense` for this variant.
    /// NOTE(review): presumably an explicit dense rotation matrix; confirm against the
    /// `rotation` module.
    Matrix,
}
87
88impl FromStr for RQRotationType {
89 type Err = Error;
90
91 fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
92 match value.to_lowercase().as_str() {
93 "fast" | "fht_kac" | "fht-kac" => Ok(Self::Fast),
94 "matrix" | "dense" => Ok(Self::Matrix),
95 _ => Err(Error::invalid_input(format!(
96 "Unknown RQ rotation type: {}. Expected one of: fast, matrix",
97 value
98 ))),
99 }
100 }
101}
102
103#[derive(Clone, Debug, PartialEq, Eq)]
104pub struct RQBuildParams {
105 pub num_bits: u8,
106 pub rotation_type: RQRotationType,
107}
108
109impl RQBuildParams {
110 pub fn new(num_bits: u8) -> Self {
111 Self {
112 num_bits,
113 rotation_type: RQRotationType::default(),
114 }
115 }
116
117 pub fn with_rotation_type(num_bits: u8, rotation_type: RQRotationType) -> Self {
118 Self {
119 num_bits,
120 rotation_type,
121 }
122 }
123}
124
125impl From<&RQBuildParams> for RabitQuantization {
126 fn from(value: &RQBuildParams) -> Self {
127 use crate::pb::vector_index_details::rabit_quantization::RotationType;
128 Self {
129 num_bits: value.num_bits as u32,
130 rotation_type: match value.rotation_type {
131 RQRotationType::Fast => RotationType::Fast as i32,
132 RQRotationType::Matrix => RotationType::Matrix as i32,
133 },
134 }
135 }
136}
137
impl QuantizerBuildParams for RQBuildParams {
    /// Always reports a sample size of zero.
    ///
    /// NOTE(review): presumably because building this quantizer needs no training
    /// sample; confirm against the `QuantizerBuildParams` contract.
    fn sample_size(&self) -> usize {
        0
    }
}
143
144impl Default for RQBuildParams {
145 fn default() -> Self {
146 Self {
147 num_bits: 1,
148 rotation_type: RQRotationType::default(),
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    use half::{bf16, f16};

    /// Checks the packed sign bits for one float type: eight values fill the first
    /// byte and the three leftover values land in the trailing byte.
    fn test_bq<T: Float>() {
        let raw: [f64; 11] = [1.0, -1.0, 1.0, -5.0, -7.0, -1.0, 1.0, -1.0, -0.2, 1.2, 3.2];
        let data: Vec<T> = raw.iter().map(|&v| T::from(v).unwrap()).collect();

        let codes: Vec<u8> = binary_quantization(&data).collect();

        assert_eq!(codes, vec![0b01000101, 0b00000110]);
    }

    #[test]
    fn test_binary_quantization() {
        test_bq::<bf16>();
        test_bq::<f16>();
        test_bq::<f32>();
        test_bq::<f64>();
    }

    #[test]
    fn test_rotation_type_parse() {
        let accepted = [
            ("fast", RQRotationType::Fast),
            ("matrix", RQRotationType::Matrix),
        ];
        for (text, expected) in accepted {
            assert_eq!(text.parse::<RQRotationType>().unwrap(), expected);
        }

        assert!("invalid".parse::<RQRotationType>().is_err());
    }
}