1use std::sync::Arc;
8
9use arrow::datatypes::{self, ArrowPrimitiveType};
10use arrow_array::{cast::AsArray, Array, FixedSizeListArray, UInt8Array};
11use arrow_array::{ArrayRef, Float32Array, PrimitiveArray};
12use arrow_schema::{DataType, Field};
13use deepsize::DeepSizeOf;
14use distance::build_distance_table_dot;
15use lance_arrow::*;
16use lance_core::{assume_eq, Error, Result};
17use lance_linalg::distance::{DistanceType, Dot, L2};
18use lance_table::utils::LanceIteratorExtension;
19use num_traits::Float;
20use prost::Message;
21use snafu::location;
22use storage::{ProductQuantizationMetadata, ProductQuantizationStorage, PQ_METADATA_KEY};
23use tracing::instrument;
24
25pub mod builder;
26pub mod distance;
27pub mod storage;
28pub mod transform;
29pub(crate) mod utils;
30
31use self::distance::{build_distance_table_l2, compute_pq_distance};
32pub use self::utils::num_centroids;
33use super::quantizer::{
34 Quantization, QuantizationMetadata, QuantizationType, Quantizer, QuantizerBuildParams,
35};
36use super::{pb, PQ_CODE_COLUMN};
37use crate::vector::kmeans::compute_partition;
38pub use builder::PQBuildParams;
39use utils::get_sub_vector_centroids;
40
41#[derive(Debug, Clone)]
42pub struct ProductQuantizer {
43 pub num_sub_vectors: usize,
44 pub num_bits: u32,
45 pub dimension: usize,
46 pub codebook: FixedSizeListArray,
47 pub distance_type: DistanceType,
48}
49
50impl DeepSizeOf for ProductQuantizer {
51 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
52 self.codebook.get_array_memory_size()
53 + self.num_sub_vectors.deep_size_of_children(_context)
54 + self.num_bits.deep_size_of_children(_context)
55 + self.dimension.deep_size_of_children(_context)
56 + self.distance_type.deep_size_of_children(_context)
57 }
58}
59
60impl ProductQuantizer {
61 pub fn new(
62 num_sub_vectors: usize,
63 num_bits: u32,
64 dimension: usize,
65 codebook: FixedSizeListArray,
66 distance_type: DistanceType,
67 ) -> Self {
68 Self {
69 num_bits,
70 num_sub_vectors,
71 dimension,
72 codebook,
73 distance_type,
74 }
75 }
76
77 pub fn from_proto(proto: &pb::Pq, distance_type: DistanceType) -> Result<Self> {
78 let distance_type = match distance_type {
79 DistanceType::Cosine => DistanceType::L2,
80 _ => distance_type,
81 };
82 let codebook = match proto.codebook_tensor.as_ref() {
83 Some(tensor) => FixedSizeListArray::try_from(tensor)?,
84 None => FixedSizeListArray::try_new_from_values(
85 Float32Array::from(proto.codebook.clone()),
86 proto.dimension as i32,
87 )?,
88 };
89 Ok(Self {
90 num_bits: proto.num_bits,
91 num_sub_vectors: proto.num_sub_vectors as usize,
92 dimension: proto.dimension as usize,
93 codebook,
94 distance_type,
95 })
96 }
97
98 #[instrument(name = "ProductQuantizer::transform", level = "debug", skip_all)]
99 fn transform<T: ArrowPrimitiveType>(&self, vectors: &dyn Array) -> Result<ArrayRef>
100 where
101 T::Native: Float + L2 + Dot,
102 {
103 match self.num_bits {
104 4 => self.transform_impl::<4, T>(vectors),
105 8 => self.transform_impl::<8, T>(vectors),
106 _ => Err(Error::Index {
107 message: format!(
108 "ProductQuantization: num_bits {} not supported",
109 self.num_bits
110 ),
111 location: location!(),
112 }),
113 }
114 }
115
116 fn transform_impl<const NUM_BITS: u32, T: ArrowPrimitiveType>(
117 &self,
118 vectors: &dyn Array,
119 ) -> Result<ArrayRef>
120 where
121 T::Native: Float + L2 + Dot,
122 {
123 let fsl = vectors.as_fixed_size_list_opt().ok_or(Error::Index {
124 message: format!(
125 "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
126 vectors.data_type()
127 ),
128 location: location!(),
129 })?;
130 let num_sub_vectors = self.num_sub_vectors;
131 let dim = self.dimension;
132 if NUM_BITS == 4 && num_sub_vectors % 2 != 0 {
133 return Err(Error::Index {
134 message: format!(
135 "PQ: num_sub_vectors must be divisible by 2 for num_bits=4, but got {}",
136 num_sub_vectors,
137 ),
138 location: location!(),
139 });
140 }
141 let codebook = self.codebook.values().as_primitive::<T>();
142
143 let distance_type = self.distance_type;
144
145 let flatten_data = fsl.values().as_primitive::<T>();
146 let sub_dim = dim / num_sub_vectors;
147 let total_code_length = fsl.len() * num_sub_vectors / (8 / NUM_BITS as usize);
148 let values = flatten_data
149 .values()
150 .chunks_exact(dim)
151 .flat_map(|vector| {
152 let sub_vec_code = vector
153 .chunks_exact(sub_dim)
154 .enumerate()
155 .map(|(sub_idx, sub_vector)| {
156 let centroids = get_sub_vector_centroids::<NUM_BITS, _>(
157 codebook.values(),
158 dim,
159 num_sub_vectors,
160 sub_idx,
161 );
162 assume_eq!(centroids.len(), 2_usize.pow(NUM_BITS) * sub_dim);
165 compute_partition(centroids, sub_vector, distance_type).unwrap_or(0) as u8
166 })
167 .collect::<Vec<_>>();
168 if NUM_BITS == 4 {
169 sub_vec_code
170 .chunks_exact(2)
171 .map(|v| (v[1] << 4) | v[0])
172 .collect::<Vec<_>>()
173 } else {
174 sub_vec_code
175 }
176 })
177 .exact_size(total_code_length)
178 .collect::<Vec<_>>();
179
180 let num_sub_vectors_in_byte = if NUM_BITS == 4 {
181 num_sub_vectors / 2
182 } else {
183 num_sub_vectors
184 };
185
186 debug_assert_eq!(values.len(), fsl.len() * num_sub_vectors_in_byte);
187 Ok(Arc::new(FixedSizeListArray::try_new_from_values(
188 UInt8Array::from(values),
189 num_sub_vectors_in_byte as i32,
190 )?))
191 }
192
193 pub fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
195 if code.is_empty() {
196 return Ok(Float32Array::from(Vec::<f32>::new()));
197 }
198
199 match self.distance_type {
200 DistanceType::L2 => self.l2_distances(query, code),
201 DistanceType::Cosine => {
202 debug_assert!(
205 false,
206 "cosine distance should be converted to normalized L2 distance"
207 );
208 let l2_dists = self.l2_distances(query, code)?;
212 Ok(l2_dists.values().iter().map(|v| *v / 2.0).collect())
213 }
214 DistanceType::Dot => self.dot_distances(query, code),
215 _ => panic!(
216 "ProductQuantization: distance type {} not supported",
217 self.distance_type
218 ),
219 }
220 }
221
222 fn l2_distances(&self, key: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
226 let distance_table = self.build_l2_distance_table(key)?;
227
228 #[cfg(target_feature = "avx512f")]
229 {
230 Ok(self.compute_l2_distance(&distance_table, code.values()))
231 }
232 #[cfg(not(target_feature = "avx512f"))]
233 {
234 Ok(self.compute_l2_distance(&distance_table, code.values()))
235 }
236 }
237
238 fn dot_distances(&self, key: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
244 match key.data_type() {
245 DataType::Float16 => {
246 self.dot_distances_impl::<datatypes::Float16Type>(key.as_primitive(), code)
247 }
248 DataType::Float32 => {
249 self.dot_distances_impl::<datatypes::Float32Type>(key.as_primitive(), code)
250 }
251 DataType::Float64 => {
252 self.dot_distances_impl::<datatypes::Float64Type>(key.as_primitive(), code)
253 }
254 _ => Err(Error::Index {
255 message: format!("unsupported data type: {}", key.data_type()),
256 location: location!(),
257 }),
258 }
259 }
260
261 fn dot_distances_impl<T: ArrowPrimitiveType>(
262 &self,
263 key: &PrimitiveArray<T>,
264 code: &UInt8Array,
265 ) -> Result<Float32Array>
266 where
267 T::Native: Dot,
268 {
269 let distance_table = build_distance_table_dot(
270 self.codebook.values().as_primitive::<T>().values(),
271 self.num_bits,
272 self.num_sub_vectors,
273 key.values(),
274 );
275
276 let distances = compute_pq_distance(
277 &distance_table,
278 self.num_bits,
279 self.num_sub_vectors,
280 code.values(),
281 0,
282 );
283
284 let diff = self.num_sub_vectors as f32 - 1.0;
285 let distances = distances.into_iter().map(|d| d - diff).collect::<Vec<_>>();
286 Ok(distances.into())
287 }
288
289 fn build_l2_distance_table(&self, key: &dyn Array) -> Result<Vec<f32>> {
290 match key.data_type() {
291 DataType::Float16 => {
292 Ok(self.build_l2_distance_table_impl::<datatypes::Float16Type>(key.as_primitive()))
293 }
294 DataType::Float32 => {
295 Ok(self.build_l2_distance_table_impl::<datatypes::Float32Type>(key.as_primitive()))
296 }
297 DataType::Float64 => {
298 Ok(self.build_l2_distance_table_impl::<datatypes::Float64Type>(key.as_primitive()))
299 }
300 _ => Err(Error::Index {
301 message: format!("unsupported data type: {}", key.data_type()),
302 location: location!(),
303 }),
304 }
305 }
306
307 fn build_l2_distance_table_impl<T: ArrowPrimitiveType>(
308 &self,
309 key: &PrimitiveArray<T>,
310 ) -> Vec<f32>
311 where
312 T::Native: L2,
313 {
314 build_distance_table_l2(
315 self.codebook.values().as_primitive::<T>().values(),
316 self.num_bits,
317 self.num_sub_vectors,
318 key.values(),
319 )
320 }
321
322 #[inline]
339 fn compute_l2_distance(&self, distance_table: &[f32], code: &[u8]) -> Float32Array {
340 Float32Array::from(compute_pq_distance(
341 distance_table,
342 self.num_bits,
343 self.num_sub_vectors,
344 code,
345 100,
346 ))
347 }
348
349 pub fn centroids<T: ArrowPrimitiveType>(&self, sub_vector_idx: usize) -> &[T::Native] {
353 match self.num_bits {
354 4 => get_sub_vector_centroids::<4, _>(
355 self.codebook.values().as_primitive::<T>().values(),
356 self.dimension,
357 self.num_sub_vectors,
358 sub_vector_idx,
359 ),
360 8 => get_sub_vector_centroids::<8, _>(
361 self.codebook.values().as_primitive::<T>().values(),
362 self.dimension,
363 self.num_sub_vectors,
364 sub_vector_idx,
365 ),
366 _ => panic!(
367 "ProductQuantization: num_bits {} not supported",
368 self.num_bits
369 ),
370 }
371 }
372}
373
374impl Quantization for ProductQuantizer {
375 type BuildParams = PQBuildParams;
376 type Metadata = ProductQuantizationMetadata;
377 type Storage = ProductQuantizationStorage;
378
379 fn build(
380 data: &dyn Array,
381 distance_type: DistanceType,
382 params: &Self::BuildParams,
383 ) -> Result<Self> {
384 assert_eq!(data.null_count(), 0);
385 let fsl = data.as_fixed_size_list_opt().ok_or(Error::Index {
386 message: format!(
387 "PQ builder: input is not a FixedSizeList: {}",
388 data.data_type()
389 ),
390 location: location!(),
391 })?;
392
393 if let Some(codebook) = params.codebook.as_ref() {
394 return Ok(Self::new(
395 params.num_sub_vectors,
396 params.num_bits as u32,
397 fsl.value_length() as usize,
398 FixedSizeListArray::try_new_from_values(codebook.clone(), fsl.value_length())?,
399 distance_type,
400 ));
401 }
402
403 params.build(data, distance_type)
404 }
405
406 fn retrain(&mut self, data: &dyn Array) -> Result<()> {
407 assert_eq!(data.null_count(), 0);
408 let params = PQBuildParams::with_codebook(
409 self.num_sub_vectors,
410 self.num_bits as usize,
411 Arc::new(self.codebook.clone()),
412 );
413
414 *self = params.build(data, self.distance_type)?;
415 Ok(())
416 }
417
418 fn code_dim(&self) -> usize {
419 self.num_sub_vectors
420 }
421
422 fn column(&self) -> &'static str {
423 PQ_CODE_COLUMN
424 }
425
426 fn use_residual(distance_type: DistanceType) -> bool {
427 PQBuildParams::use_residual(distance_type)
428 }
429
430 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
431 let fsl = vectors.as_fixed_size_list_opt().ok_or(Error::Index {
432 message: format!(
433 "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
434 vectors.data_type()
435 ),
436 location: location!(),
437 })?;
438
439 match fsl.value_type() {
440 DataType::Float16 => self.transform::<datatypes::Float16Type>(vectors),
441 DataType::Float32 => self.transform::<datatypes::Float32Type>(vectors),
442 DataType::Float64 => self.transform::<datatypes::Float64Type>(vectors),
443 _ => Err(Error::Index {
444 message: format!("unsupported data type: {}", fsl.value_type()),
445 location: location!(),
446 }),
447 }
448 }
449
450 fn metadata_key() -> &'static str {
451 PQ_METADATA_KEY
452 }
453
454 fn quantization_type() -> QuantizationType {
455 QuantizationType::Product
456 }
457
458 fn metadata(&self, args: Option<QuantizationMetadata>) -> Self::Metadata {
459 let codebook_position = match &args {
460 Some(args) => args.codebook_position,
461 None => Some(0),
462 };
463
464 let codebook_position = codebook_position.expect("codebook position should be set");
465 ProductQuantizationMetadata {
466 codebook_position,
467 nbits: self.num_bits,
468 num_sub_vectors: self.num_sub_vectors,
469 dimension: self.dimension,
470 codebook: Some(self.codebook.clone()),
471 codebook_tensor: Vec::new(),
472 transposed: args.map(|args| args.transposed).unwrap_or_default(),
473 }
474 }
475
476 fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
477 let distance_type = match distance_type {
478 DistanceType::Cosine => DistanceType::L2,
479 _ => distance_type,
480 };
481 let codebook = match metadata.codebook.as_ref() {
482 Some(fsl) => fsl.clone(),
483 None => {
484 let tensor = pb::Tensor::decode(metadata.codebook_tensor.as_ref())?;
485 FixedSizeListArray::try_from(&tensor)?
486 }
487 };
488 Ok(Quantizer::Product(Self::new(
489 metadata.num_sub_vectors,
490 metadata.nbits,
491 metadata.dimension,
492 codebook,
493 distance_type,
494 )))
495 }
496
497 fn field(&self) -> Field {
498 let num_bytes_per_sub_vector = self.num_sub_vectors * self.num_bits as usize / 8;
499 Field::new(
500 PQ_CODE_COLUMN,
501 DataType::FixedSizeList(
502 Arc::new(Field::new("item", DataType::UInt8, true)),
503 num_bytes_per_sub_vector as i32,
504 ),
505 true,
506 )
507 }
508}
509
510impl TryFrom<&ProductQuantizer> for pb::Pq {
511 type Error = Error;
512
513 fn try_from(pq: &ProductQuantizer) -> Result<Self> {
514 let tensor = pb::Tensor::try_from(&pq.codebook)?;
515 Ok(Self {
516 num_bits: pq.num_bits,
517 num_sub_vectors: pq.num_sub_vectors as u32,
518 dimension: pq.dimension as u32,
519 codebook: vec![],
520 codebook_tensor: Some(tensor),
521 })
522 }
523}
524
525impl TryFrom<Quantizer> for ProductQuantizer {
526 type Error = Error;
527 fn try_from(value: Quantizer) -> Result<Self> {
528 match value {
529 Quantizer::Product(pq) => Ok(pq),
530 _ => Err(Error::Index {
531 message: "Expect to be a ProductQuantizer".to_string(),
532 location: location!(),
533 }),
534 }
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::repeat_n;

    use approx::assert_relative_eq;
    use arrow::datatypes::UInt8Type;
    use arrow_array::Float16Array;
    use half::f16;
    use lance_linalg::distance::l2_distance_batch;
    use lance_linalg::kernels::argmin;
    use lance_testing::datagen::generate_random_array;
    use num_traits::Zero;
    use storage::transpose;

    /// A float16 codebook must be serialized as a tensor, not as flat values.
    #[test]
    fn test_f16_pq_to_protobuf() {
        let zeros = Float16Array::from_iter_values(repeat_n(f16::zero(), 256 * 16));
        let codebook = FixedSizeListArray::try_new_from_values(zeros, 16).unwrap();
        let pq = ProductQuantizer::new(4, 8, 16, codebook, DistanceType::L2);

        let proto: pb::Pq = pb::Pq::try_from(&pq).unwrap();
        assert_eq!(proto.num_bits, 8);
        assert_eq!(proto.num_sub_vectors, 4);
        assert_eq!(proto.dimension, 16);
        assert!(proto.codebook.is_empty());
        assert!(proto.codebook_tensor.is_some());

        let tensor = proto.codebook_tensor.as_ref().unwrap();
        assert_eq!(tensor.data_type, pb::tensor::DataType::Float16 as i32);
        assert_eq!(tensor.shape, vec![256, 16]);
    }

    /// PQ distances must match distances computed directly from the centroids
    /// selected by each code.
    #[test]
    fn test_l2_distance() {
        const DIM: usize = 512;
        // Deliberately not a multiple of 64; presumably exercises remainder
        // handling in the distance kernel (TODO confirm).
        const TOTAL: usize = 66;
        let codebook = generate_random_array(256 * DIM);
        let pq = ProductQuantizer::new(
            16,
            8,
            DIM,
            FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(),
            DistanceType::L2,
        );
        let pq_code = UInt8Array::from_iter_values((0..16 * TOTAL).map(|v| v as u8));
        let query = generate_random_array(DIM);

        let transposed_codes = transpose(&pq_code, TOTAL, 16);
        let actual = pq.compute_distances(&query, &transposed_codes).unwrap();

        let sub_vec_len = DIM / 16;
        let expected = pq_code
            .values()
            .chunks(16)
            .map(|row| {
                row.iter()
                    .enumerate()
                    .flat_map(|(sub_idx, code)| {
                        let centroids = pq.centroids::<datatypes::Float32Type>(sub_idx);
                        let query_part =
                            &query.values()[sub_idx * sub_vec_len..(sub_idx + 1) * sub_vec_len];
                        let centroid_start = *code as usize * sub_vec_len;
                        let centroid_end = (*code as usize + 1) * sub_vec_len;
                        l2_distance_batch(
                            query_part,
                            &centroids[centroid_start..centroid_end],
                            sub_vec_len,
                        )
                    })
                    .sum::<f32>()
            })
            .collect::<Vec<_>>();

        for (got, want) in actual.values().iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-4);
        }
    }

    /// Quantized codes must equal the index of the nearest centroid of every
    /// sub-vector.
    #[test]
    fn test_pq_transform() {
        const DIM: usize = 16;
        const TOTAL: usize = 64;
        let codebook = generate_random_array(DIM * 256);
        let pq = ProductQuantizer::new(
            4,
            8,
            DIM,
            FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(),
            DistanceType::L2,
        );

        let vectors = generate_random_array(DIM * TOTAL);
        let fsl = FixedSizeListArray::try_new_from_values(vectors.clone(), DIM as i32).unwrap();
        let pq_code = pq.quantize(&fsl).unwrap();

        let mut expected = Vec::with_capacity(TOTAL * 4);
        for vector in vectors.values().chunks_exact(DIM) {
            for (sub_idx, sub_vec) in vector.chunks_exact(DIM / 4).enumerate() {
                let centroids = pq.centroids::<datatypes::Float32Type>(sub_idx);
                let dists = l2_distance_batch(sub_vec, centroids, DIM / 4);
                expected.push(argmin(dists).unwrap() as u8);
            }
        }

        assert_eq!(pq_code.len(), TOTAL);
        assert_eq!(
            &expected,
            pq_code
                .as_fixed_size_list()
                .values()
                .as_primitive::<UInt8Type>()
                .values()
        );
    }
}