1use std::{ops::Range, sync::Arc};
5
6use arrow::array::AsArray;
7use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
8use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};
9
10use arrow_schema::{DataType, Field};
11use builder::SQBuildParams;
12use deepsize::DeepSizeOf;
13use itertools::Itertools;
14use lance_arrow::*;
15use lance_core::{Error, Result};
16use lance_linalg::distance::DistanceType;
17use num_traits::*;
18use snafu::location;
19use storage::{ScalarQuantizationMetadata, ScalarQuantizationStorage, SQ_METADATA_KEY};
20
21use super::quantizer::{Quantization, QuantizationMetadata, QuantizationType, Quantizer};
22use super::SQ_CODE_COLUMN;
23
24pub mod builder;
25pub mod storage;
26pub mod transform;
27
28#[derive(Debug, Clone)]
33pub struct ScalarQuantizer {
34 metadata: ScalarQuantizationMetadata,
35}
36
37impl DeepSizeOf for ScalarQuantizer {
38 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
39 0
40 }
41}
42
43impl ScalarQuantizer {
44 pub fn new(num_bits: u16, dim: usize) -> Self {
45 Self {
46 metadata: ScalarQuantizationMetadata {
47 num_bits,
48 dim,
49 bounds: Range::<f64> {
50 start: f64::MAX,
51 end: f64::MIN,
52 },
53 },
54 }
55 }
56
57 pub fn with_bounds(num_bits: u16, dim: usize, bounds: Range<f64>) -> Self {
58 let mut sq = Self::new(num_bits, dim);
59 sq.metadata.bounds = bounds;
60 sq
61 }
62
63 pub fn num_bits(&self) -> u16 {
64 self.metadata.num_bits
65 }
66
67 pub fn update_bounds<T: ArrowFloatType>(
68 &mut self,
69 vectors: &FixedSizeListArray,
70 ) -> Result<Range<f64>> {
71 let data = vectors
72 .values()
73 .as_any()
74 .downcast_ref::<T::ArrayType>()
75 .ok_or(Error::Index {
76 message: format!(
77 "Expect to be a float vector array, got: {:?}",
78 vectors.value_type()
79 ),
80 location: location!(),
81 })?
82 .as_slice();
83
84 self.metadata.bounds = data.iter().fold(self.metadata.bounds.clone(), |f, v| {
85 f.start.min(v.as_())..f.end.max(v.as_())
86 });
87
88 Ok(self.metadata.bounds.clone())
89 }
90
91 pub fn transform<T: ArrowFloatType>(&self, data: &dyn Array) -> Result<ArrayRef> {
92 let fsl = data
93 .as_fixed_size_list_opt()
94 .ok_or(Error::Index {
95 message: format!(
96 "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
97 data.data_type()
98 ),
99 location: location!(),
100 })?
101 .clone();
102 let data = fsl
103 .values()
104 .as_any()
105 .downcast_ref::<T::ArrayType>()
106 .ok_or(Error::Index {
107 message: format!(
108 "Expect to be a float vector array, got: {:?}",
109 fsl.value_type()
110 ),
111 location: location!(),
112 })?
113 .as_slice();
114
115 let builder: Vec<u8> = scale_to_u8::<T>(data, &self.metadata.bounds);
117
118 Ok(Arc::new(FixedSizeListArray::try_new_from_values(
119 UInt8Array::from(builder),
120 fsl.value_length(),
121 )?))
122 }
123
124 pub fn bounds(&self) -> Range<f64> {
125 self.metadata.bounds.clone()
126 }
127
128 pub fn use_residual(&self) -> bool {
130 false
131 }
132}
133
134impl TryFrom<Quantizer> for ScalarQuantizer {
135 type Error = Error;
136 fn try_from(value: Quantizer) -> Result<Self> {
137 match value {
138 Quantizer::Scalar(sq) => Ok(sq),
139 _ => Err(Error::Index {
140 message: "Expect to be a ScalarQuantizer".to_string(),
141 location: location!(),
142 }),
143 }
144 }
145}
146
147impl Quantization for ScalarQuantizer {
148 type BuildParams = SQBuildParams;
149 type Metadata = ScalarQuantizationMetadata;
150 type Storage = ScalarQuantizationStorage;
151
152 fn build(data: &dyn Array, _: DistanceType, params: &Self::BuildParams) -> Result<Self> {
153 let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
154 message: format!(
155 "SQ builder: input is not a FixedSizeList: {}",
156 data.data_type()
157 ),
158 location: location!(),
159 })?;
160
161 let mut quantizer = Self::new(params.num_bits, fsl.value_length() as usize);
162
163 match fsl.value_type() {
164 DataType::Float16 => {
165 quantizer.update_bounds::<Float16Type>(fsl)?;
166 }
167 DataType::Float32 => {
168 quantizer.update_bounds::<Float32Type>(fsl)?;
169 }
170 DataType::Float64 => {
171 quantizer.update_bounds::<Float64Type>(fsl)?;
172 }
173 _ => {
174 return Err(Error::Index {
175 message: format!("SQ builder: unsupported data type: {}", fsl.value_type()),
176 location: location!(),
177 })
178 }
179 }
180
181 Ok(quantizer)
182 }
183
184 fn retrain(&mut self, data: &dyn Array) -> Result<()> {
185 let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
186 message: format!(
187 "SQ retrain: input is not a FixedSizeList: {}",
188 data.data_type()
189 ),
190 location: location!(),
191 })?;
192
193 match fsl.value_type() {
194 DataType::Float16 => {
195 self.update_bounds::<Float16Type>(fsl)?;
196 }
197 DataType::Float32 => {
198 self.update_bounds::<Float32Type>(fsl)?;
199 }
200 DataType::Float64 => {
201 self.update_bounds::<Float64Type>(fsl)?;
202 }
203 value_type => {
204 return Err(Error::invalid_input(
205 format!("unsupported data type {} for scalar quantizer", value_type),
206 location!(),
207 ))
208 }
209 }
210 Ok(())
211 }
212
213 fn code_dim(&self) -> usize {
214 self.metadata.dim
215 }
216
217 fn column(&self) -> &'static str {
218 SQ_CODE_COLUMN
219 }
220
221 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
222 match vectors.as_fixed_size_list().value_type() {
223 DataType::Float16 => self.transform::<Float16Type>(vectors),
224 DataType::Float32 => self.transform::<Float32Type>(vectors),
225 DataType::Float64 => self.transform::<Float64Type>(vectors),
226 value_type => Err(Error::invalid_input(
227 format!("unsupported data type {} for scalar quantizer", value_type),
228 location!(),
229 )),
230 }
231 }
232
233 fn metadata_key() -> &'static str {
234 SQ_METADATA_KEY
235 }
236
237 fn quantization_type() -> QuantizationType {
238 QuantizationType::Scalar
239 }
240
241 fn metadata(&self, _: Option<QuantizationMetadata>) -> Self::Metadata {
242 self.metadata.clone()
243 }
244
245 fn from_metadata(metadata: &Self::Metadata, _: DistanceType) -> Result<Quantizer> {
246 Ok(Quantizer::Scalar(Self {
247 metadata: metadata.clone(),
248 }))
249 }
250
251 fn field(&self) -> Field {
252 Field::new(
253 SQ_CODE_COLUMN,
254 DataType::FixedSizeList(
255 Arc::new(Field::new("item", DataType::UInt8, true)),
256 self.code_dim() as i32,
257 ),
258 true,
259 )
260 }
261}
262
263pub(crate) fn scale_to_u8<T: ArrowFloatType>(values: &[T::Native], bounds: &Range<f64>) -> Vec<u8> {
264 if bounds.start == bounds.end {
265 return vec![0; values.len()];
266 }
267
268 let range = bounds.end - bounds.start;
269 values
270 .iter()
271 .map(|&v| {
272 let v = v.to_f64().unwrap();
273 let v = (v - bounds.start) * 255.0 / range;
274 v as u8 })
276 .collect_vec()
277}
278
/// Rescales distances computed in `u8` code space back into the original value space.
///
/// Every value is multiplied by `(end - start)^2 / 255^2`, which undoes the per-component
/// `255 / (end - start)` scaling used when encoding, for distances that are quadratic in the
/// coordinates. NOTE(review): presumably callers pass squared-L2 distances — confirm.
pub(crate) fn inverse_scalar_dist(
    values: impl Iterator<Item = f32>,
    bounds: &Range<f64>,
) -> Vec<f32> {
    let span = (bounds.end - bounds.start) as f32;
    let span_sq = span.powi(2);
    // The explicit `f32` suffix makes `powi` resolve to the inherent method instead of
    // depending on a glob-imported trait to disambiguate the float literal.
    let levels_sq = 255.0_f32.powi(2);
    values.map(|dist| dist * span_sq / levels_sq).collect()
}
#[cfg(test)]
mod tests {
    use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
    use arrow_array::{Float16Array, Float32Array, Float64Array};
    use half::f16;

    use super::*;

    /// Codes expected for the evenly spaced inputs `0..16` over bounds `0..15`:
    /// `i * 255 / 15 == i * 17`.
    fn expected_codes(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i * 17) as u8).collect()
    }

    /// Flattens a quantized `FixedSizeList<UInt8>` array into its raw codes.
    fn codes_of(quantized: &ArrayRef) -> Vec<u8> {
        quantized
            .as_fixed_size_list()
            .values()
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap()
            .values()
            .to_vec()
    }

    #[tokio::test]
    async fn test_f16_sq8() {
        let float_values: Vec<f16> = (0..16).map(|v| f16::from_usize(v).unwrap()).collect();
        let dim = float_values.len();
        let vectors = FixedSizeListArray::try_new_from_values(
            Float16Array::from_iter_values(float_values.iter().copied()),
            dim as i32,
        )
        .unwrap();

        let mut sq = ScalarQuantizer::new(8, dim);
        sq.update_bounds::<Float16Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0].to_f64());
        assert_eq!(sq.bounds().end, float_values[dim - 1].to_f64());

        let quantized = sq.transform::<Float16Type>(&vectors).unwrap();
        assert_eq!(codes_of(&quantized), expected_codes(dim));
    }

    #[tokio::test]
    async fn test_f32_sq8() {
        let float_values: Vec<f32> = (0..16).map(|v| v as f32).collect();
        let dim = float_values.len();
        let vectors = FixedSizeListArray::try_new_from_values(
            Float32Array::from_iter_values(float_values.iter().copied()),
            dim as i32,
        )
        .unwrap();

        let mut sq = ScalarQuantizer::new(8, dim);
        sq.update_bounds::<Float32Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0] as f64);
        assert_eq!(sq.bounds().end, float_values[dim - 1] as f64);

        let quantized = sq.transform::<Float32Type>(&vectors).unwrap();
        assert_eq!(codes_of(&quantized), expected_codes(dim));
    }

    #[tokio::test]
    async fn test_f64_sq8() {
        let float_values: Vec<f64> = (0..16).map(|v| v as f64).collect();
        let dim = float_values.len();
        let vectors = FixedSizeListArray::try_new_from_values(
            Float64Array::from_iter_values(float_values.iter().copied()),
            dim as i32,
        )
        .unwrap();

        let mut sq = ScalarQuantizer::new(8, dim);
        sq.update_bounds::<Float64Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0]);
        assert_eq!(sq.bounds().end, float_values[dim - 1]);

        let quantized = sq.transform::<Float64Type>(&vectors).unwrap();
        assert_eq!(codes_of(&quantized), expected_codes(dim));
    }

    #[tokio::test]
    async fn test_scale_to_u8_with_nan() {
        let values = [0.0, 1.0, 2.0, 3.0, f64::NAN];
        let bounds = 0.0..3.0;
        let codes = scale_to_u8::<Float64Type>(&values, &bounds);
        // NaN saturates to 0 under the float-to-int `as` cast.
        assert_eq!(codes, vec![0, 85, 170, 255, 0]);
    }
}