1use std::iter::once;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use arrow_array::types::Float32Type;
11use arrow_array::{Array, ArrayRef, UInt8Array, cast::AsArray};
12use lance_core::{Error, Result};
13use num_traits::Float;
14use serde::{Deserialize, Serialize};
15
16use crate::vector::quantizer::QuantizerBuildParams;
17
18pub mod builder;
19pub mod rotation;
20pub mod storage;
21pub mod transform;
22
/// Stateless binary quantizer.
///
/// Carries no configuration; its `transform` method turns a fixed-size-list
/// array of `f32` vectors into sign-bit codes packed into bytes.
#[derive(Clone, Default)]
pub struct BinaryQuantization {}
25
26impl BinaryQuantization {
27 pub fn transform(&self, data: &dyn Array) -> Result<ArrayRef> {
29 let fsl = data
30 .as_fixed_size_list_opt()
31 .ok_or(Error::index(format!(
32 "Expect to be a float vector array, got: {:?}",
33 data.data_type()
34 )))?
35 .clone();
36
37 let data = fsl
38 .values()
39 .as_primitive_opt::<Float32Type>()
40 .ok_or(Error::index(format!(
41 "Expect to be a float32 vector array, got: {:?}",
42 fsl.values().data_type()
43 )))?;
44 let dim = fsl.value_length() as usize;
45 let code = data
46 .values()
47 .chunks_exact(dim)
48 .flat_map(binary_quantization)
49 .collect::<Vec<_>>();
50
51 Ok(Arc::new(UInt8Array::from(code)))
52 }
53}
54
55fn binary_quantization<T: Float>(data: &[T]) -> impl Iterator<Item = u8> + '_ {
59 let iter = data.chunks_exact(8);
60 iter.clone()
61 .map(|c| {
62 let mut bits: u8 = 0;
65 c.iter().enumerate().for_each(|(idx, v)| {
66 bits |= (v.is_sign_positive() as u8) << idx;
67 });
68 bits
69 })
70 .chain(once(0).map(move |_| {
71 let mut bits: u8 = 0;
72 iter.remainder().iter().enumerate().for_each(|(idx, v)| {
73 bits |= (v.is_sign_positive() as u8) << idx;
74 });
75 bits
76 }))
77}
78
/// Rotation strategy carried by `RQBuildParams`.
///
/// Serialized through serde with `snake_case` variant names. `Fast` is the
/// `Default` variant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RQRotationType {
    /// Default rotation. `FromStr` also accepts the aliases `fht_kac` and
    /// `fht-kac`.
    // NOTE(review): the aliases suggest a fast Hadamard-transform based
    // rotation — confirm against the `rotation` module.
    #[default]
    Fast,
    /// Alternative rotation. `FromStr` also accepts the alias `dense`.
    // NOTE(review): presumably a dense rotation matrix — confirm.
    Matrix,
}
86
87impl FromStr for RQRotationType {
88 type Err = Error;
89
90 fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
91 match value.to_lowercase().as_str() {
92 "fast" | "fht_kac" | "fht-kac" => Ok(Self::Fast),
93 "matrix" | "dense" => Ok(Self::Matrix),
94 _ => Err(Error::invalid_input(format!(
95 "Unknown RQ rotation type: {}. Expected one of: fast, matrix",
96 value
97 ))),
98 }
99 }
100}
101
/// Build parameters for the RQ quantizer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RQBuildParams {
    /// Number of bits used by the quantizer; `Default` sets this to 1.
    // NOTE(review): presumably bits per vector dimension — confirm with the builder.
    pub num_bits: u8,
    /// Rotation strategy; defaults to `RQRotationType::default()`.
    pub rotation_type: RQRotationType,
}
107
108impl RQBuildParams {
109 pub fn new(num_bits: u8) -> Self {
110 Self {
111 num_bits,
112 rotation_type: RQRotationType::default(),
113 }
114 }
115
116 pub fn with_rotation_type(num_bits: u8, rotation_type: RQRotationType) -> Self {
117 Self {
118 num_bits,
119 rotation_type,
120 }
121 }
122}
123
impl QuantizerBuildParams for RQBuildParams {
    /// Always reports a sample size of zero.
    // NOTE(review): presumably because building this quantizer needs no
    // training sample — confirm against the `QuantizerBuildParams` contract.
    fn sample_size(&self) -> usize {
        0
    }
}
129
130impl Default for RQBuildParams {
131 fn default() -> Self {
132 Self {
133 num_bits: 1,
134 rotation_type: RQRotationType::default(),
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    use half::{bf16, f16};

    /// Quantizes a fixed 11-value input converted to `T` and checks both
    /// output bytes: one full group of eight plus the three leftover values.
    fn test_bq<T: Float>() {
        let raw = [1.0, -1.0, 1.0, -5.0, -7.0, -1.0, 1.0, -1.0, -0.2, 1.2, 3.2];
        let data = raw
            .iter()
            .map(|&v| T::from(v).unwrap())
            .collect::<Vec<T>>();

        let codes: Vec<u8> = binary_quantization(&data).collect();

        assert_eq!(codes, vec![0b01000101, 0b00000110]);
    }

    /// Runs the sign-bit packing check for every supported float width.
    #[test]
    fn test_binary_quantization() {
        test_bq::<bf16>();
        test_bq::<f16>();
        test_bq::<f32>();
        test_bq::<f64>();
    }

    /// Canonical names parse to their variants; unknown names are rejected.
    #[test]
    fn test_rotation_type_parse() {
        let accepted = [
            ("fast", RQRotationType::Fast),
            ("matrix", RQRotationType::Matrix),
        ];
        for (text, expected) in accepted {
            assert_eq!(text.parse::<RQRotationType>().unwrap(), expected);
        }

        assert!("invalid".parse::<RQRotationType>().is_err());
    }
}