1use std::{ops::Range, sync::Arc};
5
6use arrow::array::AsArray;
7use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
8use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};
9
10use arrow_schema::{DataType, Field};
11use builder::SQBuildParams;
12use deepsize::DeepSizeOf;
13use itertools::Itertools;
14use lance_arrow::*;
15use lance_core::{Error, Result};
16use lance_linalg::distance::DistanceType;
17use num_traits::*;
18use snafu::location;
19use storage::{ScalarQuantizationMetadata, ScalarQuantizationStorage, SQ_METADATA_KEY};
20
21use super::quantizer::{Quantization, QuantizationMetadata, QuantizationType, Quantizer};
22use super::SQ_CODE_COLUMN;
23
24pub mod builder;
25pub mod storage;
26pub mod transform;
27
28#[derive(Debug, Clone)]
33pub struct ScalarQuantizer {
34 metadata: ScalarQuantizationMetadata,
35}
36
37impl DeepSizeOf for ScalarQuantizer {
38 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
39 0
40 }
41}
42
43impl ScalarQuantizer {
44 pub fn new(num_bits: u16, dim: usize) -> Self {
45 Self {
46 metadata: ScalarQuantizationMetadata {
47 num_bits,
48 dim,
49 bounds: Range::<f64> {
50 start: f64::MAX,
51 end: f64::MIN,
52 },
53 },
54 }
55 }
56
57 pub fn with_bounds(num_bits: u16, dim: usize, bounds: Range<f64>) -> Self {
58 let mut sq = Self::new(num_bits, dim);
59 sq.metadata.bounds = bounds;
60 sq
61 }
62
63 pub fn num_bits(&self) -> u16 {
64 self.metadata.num_bits
65 }
66
67 pub fn update_bounds<T: ArrowFloatType>(
68 &mut self,
69 vectors: &FixedSizeListArray,
70 ) -> Result<Range<f64>> {
71 let data = vectors
72 .values()
73 .as_any()
74 .downcast_ref::<T::ArrayType>()
75 .ok_or(Error::Index {
76 message: format!(
77 "Expect to be a float vector array, got: {:?}",
78 vectors.value_type()
79 ),
80 location: location!(),
81 })?
82 .as_slice();
83
84 self.metadata.bounds = data.iter().fold(self.metadata.bounds.clone(), |f, v| {
85 f.start.min(v.as_())..f.end.max(v.as_())
86 });
87
88 Ok(self.metadata.bounds.clone())
89 }
90
91 pub fn transform<T: ArrowFloatType>(&self, data: &dyn Array) -> Result<ArrayRef> {
92 let fsl = data
93 .as_fixed_size_list_opt()
94 .ok_or(Error::Index {
95 message: format!(
96 "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
97 data.data_type()
98 ),
99 location: location!(),
100 })?
101 .clone();
102 let data = fsl
103 .values()
104 .as_any()
105 .downcast_ref::<T::ArrayType>()
106 .ok_or(Error::Index {
107 message: format!(
108 "Expect to be a float vector array, got: {:?}",
109 fsl.value_type()
110 ),
111 location: location!(),
112 })?
113 .as_slice();
114
115 let builder: Vec<u8> = scale_to_u8::<T>(data, &self.metadata.bounds);
117
118 Ok(Arc::new(FixedSizeListArray::try_new_from_values(
119 UInt8Array::from(builder),
120 fsl.value_length(),
121 )?))
122 }
123
124 pub fn bounds(&self) -> Range<f64> {
125 self.metadata.bounds.clone()
126 }
127
128 pub fn use_residual(&self) -> bool {
130 false
131 }
132}
133
134impl TryFrom<Quantizer> for ScalarQuantizer {
135 type Error = Error;
136 fn try_from(value: Quantizer) -> Result<Self> {
137 match value {
138 Quantizer::Scalar(sq) => Ok(sq),
139 _ => Err(Error::Index {
140 message: "Expect to be a ScalarQuantizer".to_string(),
141 location: location!(),
142 }),
143 }
144 }
145}
146
147impl Quantization for ScalarQuantizer {
148 type BuildParams = SQBuildParams;
149 type Metadata = ScalarQuantizationMetadata;
150 type Storage = ScalarQuantizationStorage;
151
152 fn build(data: &dyn Array, _: DistanceType, params: &Self::BuildParams) -> Result<Self> {
153 let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
154 message: format!(
155 "SQ builder: input is not a FixedSizeList: {}",
156 data.data_type()
157 ),
158 location: location!(),
159 })?;
160
161 let mut quantizer = Self::new(params.num_bits, fsl.value_length() as usize);
162
163 match fsl.value_type() {
164 DataType::Float16 => {
165 quantizer.update_bounds::<Float16Type>(fsl)?;
166 }
167 DataType::Float32 => {
168 quantizer.update_bounds::<Float32Type>(fsl)?;
169 }
170 DataType::Float64 => {
171 quantizer.update_bounds::<Float64Type>(fsl)?;
172 }
173 _ => {
174 return Err(Error::Index {
175 message: format!("SQ builder: unsupported data type: {}", fsl.value_type()),
176 location: location!(),
177 })
178 }
179 }
180
181 Ok(quantizer)
182 }
183
184 fn retrain(&mut self, data: &dyn Array) -> Result<()> {
185 let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
186 message: format!(
187 "SQ retrain: input is not a FixedSizeList: {}",
188 data.data_type()
189 ),
190 location: location!(),
191 })?;
192
193 match fsl.value_type() {
194 DataType::Float16 => {
195 self.update_bounds::<Float16Type>(fsl)?;
196 }
197 DataType::Float32 => {
198 self.update_bounds::<Float32Type>(fsl)?;
199 }
200 DataType::Float64 => {
201 self.update_bounds::<Float64Type>(fsl)?;
202 }
203 value_type => {
204 return Err(Error::invalid_input(
205 format!("unsupported data type {} for scalar quantizer", value_type),
206 location!(),
207 ))
208 }
209 }
210 Ok(())
211 }
212
213 fn code_dim(&self) -> usize {
214 self.metadata.dim
215 }
216
217 fn column(&self) -> &'static str {
218 SQ_CODE_COLUMN
219 }
220
221 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
222 match vectors.as_fixed_size_list().value_type() {
223 DataType::Float16 => self.transform::<Float16Type>(vectors),
224 DataType::Float32 => self.transform::<Float32Type>(vectors),
225 DataType::Float64 => self.transform::<Float64Type>(vectors),
226 value_type => Err(Error::invalid_input(
227 format!("unsupported data type {} for scalar quantizer", value_type),
228 location!(),
229 )),
230 }
231 }
232
233 fn metadata_key() -> &'static str {
234 SQ_METADATA_KEY
235 }
236
237 fn quantization_type() -> QuantizationType {
238 QuantizationType::Scalar
239 }
240
241 fn metadata(&self, _: Option<QuantizationMetadata>) -> Self::Metadata {
242 self.metadata.clone()
243 }
244
245 fn from_metadata(metadata: &Self::Metadata, _: DistanceType) -> Result<Quantizer> {
246 Ok(Quantizer::Scalar(Self {
247 metadata: metadata.clone(),
248 }))
249 }
250
251 fn field(&self) -> Field {
252 Field::new(
253 SQ_CODE_COLUMN,
254 DataType::FixedSizeList(
255 Arc::new(Field::new("item", DataType::UInt8, true)),
256 self.code_dim() as i32,
257 ),
258 true,
259 )
260 }
261}
262
263pub(crate) fn scale_to_u8<T: ArrowFloatType>(values: &[T::Native], bounds: &Range<f64>) -> Vec<u8> {
264 if bounds.start == bounds.end {
265 return vec![0; values.len()];
266 }
267
268 let range = bounds.end - bounds.start;
269 values
270 .iter()
271 .map(|&v| {
272 let v = v.to_f64().unwrap();
273 let v = (v - bounds.start) * 255.0 / range;
274 v as u8 })
276 .collect_vec()
277}
278
#[cfg(test)]
mod tests {
    use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
    use arrow_array::{Float16Array, Float32Array, Float64Array};
    use half::f16;

    use super::*;

    /// Asserts that `sq_code` is a `FixedSizeList<UInt8>` with `expected_len` codes,
    /// where the i-th code equals `i * 17`.
    ///
    /// That is the SQ8 encoding of the inputs `0..=15` under bounds `0..15`.
    fn assert_linear_sq8_codes(sq_code: &ArrayRef, expected_len: usize) {
        let sq_values = sq_code
            .as_fixed_size_list()
            .values()
            .as_any()
            .downcast_ref::<UInt8Array>()
            .unwrap();

        // Guard against the loop below passing vacuously on an empty result.
        assert_eq!(sq_values.len(), expected_len);
        sq_values.values().iter().enumerate().for_each(|(i, v)| {
            assert_eq!(*v, (i * 17) as u8);
        });
    }

    // These tests never await, so plain `#[test]` is used instead of spinning up a runtime.
    #[test]
    fn test_f16_sq8() {
        let float_values = Vec::from_iter((0..16).map(|v| f16::from_usize(v).unwrap()));
        let float_array = Float16Array::from_iter_values(float_values.clone());
        let vectors =
            FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
                .unwrap();
        let mut sq = ScalarQuantizer::new(8, float_values.len());

        sq.update_bounds::<Float16Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0].to_f64());
        assert_eq!(
            sq.bounds().end,
            float_values.last().cloned().unwrap().to_f64()
        );

        let sq_code = sq.transform::<Float16Type>(&vectors).unwrap();
        assert_linear_sq8_codes(&sq_code, float_values.len());
    }

    #[test]
    fn test_f32_sq8() {
        let float_values = Vec::from_iter((0..16).map(|v| v as f32));
        let float_array = Float32Array::from_iter_values(float_values.clone());
        let vectors =
            FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
                .unwrap();
        let mut sq = ScalarQuantizer::new(8, float_values.len());

        sq.update_bounds::<Float32Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0].to_f64().unwrap());
        assert_eq!(
            sq.bounds().end,
            float_values.last().cloned().unwrap().to_f64().unwrap()
        );

        let sq_code = sq.transform::<Float32Type>(&vectors).unwrap();
        assert_linear_sq8_codes(&sq_code, float_values.len());
    }

    #[test]
    fn test_f64_sq8() {
        let float_values = Vec::from_iter((0..16).map(|v| v as f64));
        let float_array = Float64Array::from_iter_values(float_values.clone());
        let vectors =
            FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
                .unwrap();
        let mut sq = ScalarQuantizer::new(8, float_values.len());

        sq.update_bounds::<Float64Type>(&vectors).unwrap();
        assert_eq!(sq.bounds().start, float_values[0]);
        assert_eq!(sq.bounds().end, float_values.last().cloned().unwrap());

        let sq_code = sq.transform::<Float64Type>(&vectors).unwrap();
        assert_linear_sq8_codes(&sq_code, float_values.len());
    }

    #[test]
    fn test_scale_to_u8_with_nan() {
        let values = vec![0.0, 1.0, 2.0, 3.0, f64::NAN];
        let bounds = Range::<f64> {
            start: 0.0,
            end: 3.0,
        };
        let u8_values = scale_to_u8::<Float64Type>(&values, &bounds);
        assert_eq!(u8_values, vec![0, 85, 170, 255, 0]);
    }

    /// Values outside the bounds saturate to 0 / 255 rather than wrapping.
    #[test]
    fn test_scale_to_u8_saturates_out_of_range() {
        let values = vec![-1.0, 0.0, 3.0, 4.0];
        let bounds = Range::<f64> {
            start: 0.0,
            end: 3.0,
        };
        let u8_values = scale_to_u8::<Float64Type>(&values, &bounds);
        assert_eq!(u8_values, vec![0, 0, 255, 255]);
    }

    /// A zero-width range maps every value (including NaN) to 0.
    #[test]
    fn test_scale_to_u8_degenerate_bounds() {
        let values = vec![1.0, 2.0, f64::NAN];
        let bounds = Range::<f64> {
            start: 1.0,
            end: 1.0,
        };
        let u8_values = scale_to_u8::<Float64Type>(&values, &bounds);
        assert_eq!(u8_values, vec![0, 0, 0]);
    }
}