1use std::{ops::Range, sync::Arc};
5
6use arrow::array::AsArray;
7use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
8use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};
9
10use arrow_schema::{DataType, Field};
11use builder::SQBuildParams;
12use deepsize::DeepSizeOf;
13use itertools::Itertools;
14use lance_arrow::*;
15use lance_core::{Error, Result};
16use lance_linalg::distance::DistanceType;
17use num_traits::*;
18use storage::{SQ_METADATA_KEY, ScalarQuantizationMetadata, ScalarQuantizationStorage};
19
20use super::SQ_CODE_COLUMN;
21use super::quantizer::{Quantization, QuantizationMetadata, QuantizationType, Quantizer};
22
23pub mod builder;
24pub mod storage;
25pub mod transform;
26
27#[derive(Debug, Clone)]
32pub struct ScalarQuantizer {
33 metadata: ScalarQuantizationMetadata,
34}
35
36impl DeepSizeOf for ScalarQuantizer {
37 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
38 0
39 }
40}
41
42impl ScalarQuantizer {
43 pub fn new(num_bits: u16, dim: usize) -> Self {
44 Self {
45 metadata: ScalarQuantizationMetadata {
46 num_bits,
47 dim,
48 bounds: Range::<f64> {
49 start: f64::MAX,
50 end: f64::MIN,
51 },
52 },
53 }
54 }
55
56 pub fn with_bounds(num_bits: u16, dim: usize, bounds: Range<f64>) -> Self {
57 let mut sq = Self::new(num_bits, dim);
58 sq.metadata.bounds = bounds;
59 sq
60 }
61
62 pub fn num_bits(&self) -> u16 {
63 self.metadata.num_bits
64 }
65
66 pub fn update_bounds<T: ArrowFloatType>(
67 &mut self,
68 vectors: &FixedSizeListArray,
69 ) -> Result<Range<f64>> {
70 let data = vectors
71 .values()
72 .as_any()
73 .downcast_ref::<T::ArrayType>()
74 .ok_or(Error::index(format!(
75 "Expect to be a float vector array, got: {:?}",
76 vectors.value_type()
77 )))?
78 .as_slice();
79
80 self.metadata.bounds = data.iter().fold(self.metadata.bounds.clone(), |f, v| {
81 f.start.min(v.as_())..f.end.max(v.as_())
82 });
83
84 Ok(self.metadata.bounds.clone())
85 }
86
87 pub fn transform<T: ArrowFloatType>(&self, data: &dyn Array) -> Result<ArrayRef> {
88 let fsl = data
89 .as_fixed_size_list_opt()
90 .ok_or(Error::index(format!(
91 "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
92 data.data_type()
93 )))?
94 .clone();
95 let data = fsl
96 .values()
97 .as_any()
98 .downcast_ref::<T::ArrayType>()
99 .ok_or(Error::index(format!(
100 "Expect to be a float vector array, got: {:?}",
101 fsl.value_type()
102 )))?
103 .as_slice();
104
105 let builder: Vec<u8> = scale_to_u8::<T>(data, &self.metadata.bounds);
107
108 Ok(Arc::new(FixedSizeListArray::try_new_from_values(
109 UInt8Array::from(builder),
110 fsl.value_length(),
111 )?))
112 }
113
114 pub fn bounds(&self) -> Range<f64> {
115 self.metadata.bounds.clone()
116 }
117
118 pub fn use_residual(&self) -> bool {
120 false
121 }
122}
123
124impl TryFrom<Quantizer> for ScalarQuantizer {
125 type Error = Error;
126 fn try_from(value: Quantizer) -> Result<Self> {
127 match value {
128 Quantizer::Scalar(sq) => Ok(sq),
129 _ => Err(Error::index("Expect to be a ScalarQuantizer".to_string())),
130 }
131 }
132}
133
134impl Quantization for ScalarQuantizer {
135 type BuildParams = SQBuildParams;
136 type Metadata = ScalarQuantizationMetadata;
137 type Storage = ScalarQuantizationStorage;
138
139 fn build(data: &dyn Array, _: DistanceType, params: &Self::BuildParams) -> Result<Self> {
140 let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
141 "SQ builder: input is not a FixedSizeList: {}",
142 data.data_type()
143 )))?;
144
145 let mut quantizer = Self::new(params.num_bits, fsl.value_length() as usize);
146
147 match fsl.value_type() {
148 DataType::Float16 => {
149 quantizer.update_bounds::<Float16Type>(fsl)?;
150 }
151 DataType::Float32 => {
152 quantizer.update_bounds::<Float32Type>(fsl)?;
153 }
154 DataType::Float64 => {
155 quantizer.update_bounds::<Float64Type>(fsl)?;
156 }
157 _ => {
158 return Err(Error::index(format!(
159 "SQ builder: unsupported data type: {}",
160 fsl.value_type()
161 )));
162 }
163 }
164
165 Ok(quantizer)
166 }
167
168 fn retrain(&mut self, data: &dyn Array) -> Result<()> {
169 let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
170 "SQ retrain: input is not a FixedSizeList: {}",
171 data.data_type()
172 )))?;
173
174 match fsl.value_type() {
175 DataType::Float16 => {
176 self.update_bounds::<Float16Type>(fsl)?;
177 }
178 DataType::Float32 => {
179 self.update_bounds::<Float32Type>(fsl)?;
180 }
181 DataType::Float64 => {
182 self.update_bounds::<Float64Type>(fsl)?;
183 }
184 value_type => {
185 return Err(Error::invalid_input(format!(
186 "unsupported data type {} for scalar quantizer",
187 value_type
188 )));
189 }
190 }
191 Ok(())
192 }
193
194 fn code_dim(&self) -> usize {
195 self.metadata.dim
196 }
197
198 fn column(&self) -> &'static str {
199 SQ_CODE_COLUMN
200 }
201
202 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
203 match vectors.as_fixed_size_list().value_type() {
204 DataType::Float16 => self.transform::<Float16Type>(vectors),
205 DataType::Float32 => self.transform::<Float32Type>(vectors),
206 DataType::Float64 => self.transform::<Float64Type>(vectors),
207 value_type => Err(Error::invalid_input(format!(
208 "unsupported data type {} for scalar quantizer",
209 value_type
210 ))),
211 }
212 }
213
214 fn metadata_key() -> &'static str {
215 SQ_METADATA_KEY
216 }
217
218 fn quantization_type() -> QuantizationType {
219 QuantizationType::Scalar
220 }
221
222 fn metadata(&self, _: Option<QuantizationMetadata>) -> Self::Metadata {
223 self.metadata.clone()
224 }
225
226 fn from_metadata(metadata: &Self::Metadata, _: DistanceType) -> Result<Quantizer> {
227 Ok(Quantizer::Scalar(Self {
228 metadata: metadata.clone(),
229 }))
230 }
231
232 fn field(&self) -> Field {
233 Field::new(
234 SQ_CODE_COLUMN,
235 DataType::FixedSizeList(
236 Arc::new(Field::new("item", DataType::UInt8, true)),
237 self.code_dim() as i32,
238 ),
239 true,
240 )
241 }
242}
243
244pub(crate) fn scale_to_u8<T: ArrowFloatType>(values: &[T::Native], bounds: &Range<f64>) -> Vec<u8> {
245 if bounds.start == bounds.end {
246 return vec![0; values.len()];
247 }
248
249 let range = bounds.end - bounds.start;
250 values
251 .iter()
252 .map(|&v| {
253 let v = v.to_f64().unwrap();
254 let v = (v - bounds.start) * 255.0 / range;
255 v as u8 })
257 .collect_vec()
258}
259
#[cfg(test)]
mod tests {
    use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
    use arrow_array::{Float16Array, Float32Array, Float64Array};
    use half::f16;

    use super::*;

    /// Checks that `code` is a `FixedSizeList<UInt8>` whose i-th value is
    /// `17 * i`: the SQ8 encoding of the ramp `0, 1, ..., 15` over `[0, 15]`.
    fn assert_sq8_ramp(code: &ArrayRef) {
        let codes = code
            .as_fixed_size_list()
            .values()
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap();
        for (i, &c) in codes.values().iter().enumerate() {
            assert_eq!(c, (i * 17) as u8);
        }
    }

    #[tokio::test]
    async fn test_f16_sq8() {
        let float_values: Vec<f16> = (0..16).map(|v| f16::from_usize(v).unwrap()).collect();
        let dim = float_values.len();
        let vectors = FixedSizeListArray::try_new_from_values(
            Float16Array::from_iter_values(float_values.clone()),
            dim as i32,
        )
        .unwrap();
        let mut sq = ScalarQuantizer::new(8, dim);

        sq.update_bounds::<Float16Type>(&vectors).unwrap();
        let bounds = sq.bounds();
        assert_eq!(bounds.start, float_values[0].to_f64());
        assert_eq!(bounds.end, float_values[dim - 1].to_f64());

        let sq_code = sq.transform::<Float16Type>(&vectors).unwrap();
        assert_sq8_ramp(&sq_code);
    }

    #[tokio::test]
    async fn test_f32_sq8() {
        let float_values: Vec<f32> = (0..16).map(|v| v as f32).collect();
        let dim = float_values.len();
        let vectors = FixedSizeListArray::try_new_from_values(
            Float32Array::from_iter_values(float_values.clone()),
            dim as i32,
        )
        .unwrap();
        let mut sq = ScalarQuantizer::new(8, dim);

        sq.update_bounds::<Float32Type>(&vectors).unwrap();
        let bounds = sq.bounds();
        assert_eq!(bounds.start, float_values[0].to_f64().unwrap());
        assert_eq!(bounds.end, float_values[dim - 1].to_f64().unwrap());

        let sq_code = sq.transform::<Float32Type>(&vectors).unwrap();
        assert_sq8_ramp(&sq_code);
    }

    #[tokio::test]
    async fn test_f64_sq8() {
        let float_values: Vec<f64> = (0..16).map(|v| v as f64).collect();
        let dim = float_values.len();
        let vectors = FixedSizeListArray::try_new_from_values(
            Float64Array::from_iter_values(float_values.clone()),
            dim as i32,
        )
        .unwrap();
        let mut sq = ScalarQuantizer::new(8, dim);

        sq.update_bounds::<Float64Type>(&vectors).unwrap();
        let bounds = sq.bounds();
        assert_eq!(bounds.start, float_values[0]);
        assert_eq!(bounds.end, float_values[dim - 1]);

        let sq_code = sq.transform::<Float64Type>(&vectors).unwrap();
        assert_sq8_ramp(&sq_code);
    }

    #[tokio::test]
    async fn test_scale_to_u8_with_nan() {
        // NaN must not poison the output: the saturating cast maps it to 0.
        let values = vec![0.0, 1.0, 2.0, 3.0, f64::NAN];
        let bounds = 0.0_f64..3.0_f64;
        let codes = scale_to_u8::<Float64Type>(&values, &bounds);
        assert_eq!(codes, vec![0, 85, 170, 255, 0]);
    }
}