1use std::sync::Arc;
8
9use arrow::datatypes::{self, ArrowPrimitiveType};
10use arrow_array::{Array, FixedSizeListArray, UInt8Array, cast::AsArray};
11use arrow_array::{ArrayRef, Float32Array, PrimitiveArray};
12use arrow_schema::{DataType, Field};
13use deepsize::DeepSizeOf;
14use distance::build_distance_table_dot;
15use lance_arrow::*;
16use lance_core::{Error, Result, assume_eq};
17use lance_linalg::distance::{DistanceType, Dot, L2, l2::L2Prepared};
18use lance_table::utils::LanceIteratorExtension;
19use num_traits::Float;
20use prost::Message;
21use storage::{PQ_METADATA_KEY, ProductQuantizationMetadata, ProductQuantizationStorage};
22use tracing::instrument;
23
24pub mod builder;
25pub mod distance;
26pub mod storage;
27pub mod transform;
28pub(crate) mod utils;
29
30use self::distance::{
31 build_distance_table_l2, build_distance_table_l2_prepared, compute_pq_distance,
32};
33pub use self::utils::num_centroids;
34use super::quantizer::{
35 Quantization, QuantizationMetadata, QuantizationType, Quantizer, QuantizerBuildParams,
36};
37use super::{PQ_CODE_COLUMN, pb};
38use crate::vector::kmeans::compute_partition;
39pub use builder::PQBuildParams;
40use utils::get_sub_vector_centroids;
41
42#[derive(Debug, Clone)]
43pub struct ProductQuantizer {
44 pub num_sub_vectors: usize,
45 pub num_bits: u32,
46 pub dimension: usize,
47 pub codebook: FixedSizeListArray,
48 pub distance_type: DistanceType,
49 l2_targets: Option<Vec<L2Prepared>>,
53}
54
55impl DeepSizeOf for ProductQuantizer {
56 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
57 self.codebook.get_array_memory_size()
58 + self.num_sub_vectors.deep_size_of_children(_context)
59 + self.num_bits.deep_size_of_children(_context)
60 + self.dimension.deep_size_of_children(_context)
61 + self.distance_type.deep_size_of_children(_context)
62 + self
63 .l2_targets
64 .as_ref()
65 .map_or(0, |v| v.iter().map(|t| t.size_bytes()).sum())
66 }
67}
68
69impl ProductQuantizer {
70 fn build_l2_targets(
72 codebook: &FixedSizeListArray,
73 distance_type: DistanceType,
74 num_sub_vectors: usize,
75 num_bits: u32,
76 dimension: usize,
77 ) -> Option<Vec<L2Prepared>> {
78 if codebook.value_type() != DataType::Float32 || distance_type != DistanceType::L2 {
79 return None;
80 }
81 let values = codebook
82 .values()
83 .as_primitive::<datatypes::Float32Type>()
84 .values();
85 let sub_dim = dimension / num_sub_vectors;
86 let num_centroids = 2_usize.pow(num_bits);
87 let block_size = sub_dim * num_centroids;
88
89 let targets = (0..num_sub_vectors)
90 .map(|sub_idx| {
91 let block_start = sub_idx * block_size;
92 let block = &values[block_start..block_start + block_size];
93 L2Prepared::new(block, sub_dim)
94 })
95 .collect();
96 Some(targets)
97 }
98
99 pub fn new(
100 num_sub_vectors: usize,
101 num_bits: u32,
102 dimension: usize,
103 codebook: FixedSizeListArray,
104 distance_type: DistanceType,
105 ) -> Self {
106 let l2_targets = Self::build_l2_targets(
107 &codebook,
108 distance_type,
109 num_sub_vectors,
110 num_bits,
111 dimension,
112 );
113 Self {
114 num_bits,
115 num_sub_vectors,
116 dimension,
117 codebook,
118 distance_type,
119 l2_targets,
120 }
121 }
122
123 pub fn from_proto(proto: &pb::Pq, distance_type: DistanceType) -> Result<Self> {
124 let distance_type = match distance_type {
125 DistanceType::Cosine => DistanceType::L2,
126 _ => distance_type,
127 };
128 let codebook = match proto.codebook_tensor.as_ref() {
129 Some(tensor) => FixedSizeListArray::try_from(tensor)?,
130 None => FixedSizeListArray::try_new_from_values(
131 Float32Array::from(proto.codebook.clone()),
132 proto.dimension as i32,
133 )?,
134 };
135 Ok(Self::new(
136 proto.num_sub_vectors as usize,
137 proto.num_bits,
138 proto.dimension as usize,
139 codebook,
140 distance_type,
141 ))
142 }
143
144 #[instrument(name = "ProductQuantizer::transform", level = "debug", skip_all)]
145 fn transform<T: ArrowPrimitiveType>(&self, vectors: &dyn Array) -> Result<ArrayRef>
146 where
147 T::Native: Float + L2 + Dot,
148 {
149 match self.num_bits {
150 4 => self.transform_impl::<4, T>(vectors),
151 8 => self.transform_impl::<8, T>(vectors),
152 _ => Err(Error::index(format!(
153 "ProductQuantization: num_bits {} not supported",
154 self.num_bits
155 ))),
156 }
157 }
158
159 fn transform_impl<const NUM_BITS: u32, T: ArrowPrimitiveType>(
160 &self,
161 vectors: &dyn Array,
162 ) -> Result<ArrayRef>
163 where
164 T::Native: Float + L2 + Dot,
165 {
166 let fsl = vectors
167 .as_fixed_size_list_opt()
168 .ok_or(Error::index(format!(
169 "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
170 vectors.data_type()
171 )))?;
172 let num_sub_vectors = self.num_sub_vectors;
173 let dim = self.dimension;
174 if NUM_BITS == 4 && !num_sub_vectors.is_multiple_of(2) {
175 return Err(Error::index(format!(
176 "PQ: num_sub_vectors must be divisible by 2 for num_bits=4, but got {}",
177 num_sub_vectors,
178 )));
179 }
180 let codebook = self.codebook.values().as_primitive::<T>();
181
182 let distance_type = self.distance_type;
183
184 let flatten_data = fsl.values().as_primitive::<T>();
185 let sub_dim = dim / num_sub_vectors;
186 let num_centroids = 2_usize.pow(NUM_BITS);
187 let total_code_length = fsl.len() * num_sub_vectors / (8 / NUM_BITS as usize);
188
189 let values = if let Some(targets) = &self.l2_targets {
190 let flat_f32: &[f32] = unsafe {
193 std::slice::from_raw_parts(
194 flatten_data.values().as_ptr() as *const f32,
195 flatten_data.values().len(),
196 )
197 };
198 let mut values = vec![0u8; total_code_length];
199 let bytes_per_vector = num_sub_vectors / (8 / NUM_BITS as usize);
200 let mut dist_buf = vec![0.0f32; num_centroids];
201 for (vec_idx, vector) in flat_f32.chunks_exact(dim).enumerate() {
202 let out = &mut values[vec_idx * bytes_per_vector..][..bytes_per_vector];
203 if NUM_BITS == 4 {
204 for (pair_idx, pair) in vector.chunks_exact(sub_dim * 2).enumerate() {
205 let lo = targets[pair_idx * 2]
206 .nearest_into(&pair[..sub_dim], &mut dist_buf)
207 .unwrap_or(0) as u8;
208 let hi = targets[pair_idx * 2 + 1]
209 .nearest_into(&pair[sub_dim..], &mut dist_buf)
210 .unwrap_or(0) as u8;
211 out[pair_idx] = (hi << 4) | lo;
212 }
213 } else {
214 for (sub_idx, sv) in vector.chunks_exact(sub_dim).enumerate() {
215 out[sub_idx] = targets[sub_idx]
216 .nearest_into(sv, &mut dist_buf)
217 .unwrap_or(0) as u8;
218 }
219 }
220 }
221 values
222 } else {
223 flatten_data
224 .values()
225 .chunks_exact(dim)
226 .flat_map(|vector| {
227 let sub_vec_code: Vec<u8> = vector
228 .chunks_exact(sub_dim)
229 .enumerate()
230 .map(|(sub_idx, sub_vector)| {
231 let centroids = get_sub_vector_centroids::<NUM_BITS, _>(
232 codebook.values(),
233 dim,
234 num_sub_vectors,
235 sub_idx,
236 );
237 assume_eq!(centroids.len(), num_centroids * sub_dim);
238 compute_partition(centroids, sub_vector, distance_type).unwrap_or(0)
239 as u8
240 })
241 .collect();
242 if NUM_BITS == 4 {
243 sub_vec_code
244 .chunks_exact(2)
245 .map(|v| (v[1] << 4) | v[0])
246 .collect::<Vec<_>>()
247 } else {
248 sub_vec_code
249 }
250 })
251 .exact_size(total_code_length)
252 .collect::<Vec<_>>()
253 };
254
255 let num_sub_vectors_in_byte = if NUM_BITS == 4 {
256 num_sub_vectors / 2
257 } else {
258 num_sub_vectors
259 };
260
261 debug_assert_eq!(values.len(), fsl.len() * num_sub_vectors_in_byte);
262 Ok(Arc::new(FixedSizeListArray::try_new_from_values(
263 UInt8Array::from(values),
264 num_sub_vectors_in_byte as i32,
265 )?))
266 }
267
268 pub fn compute_distances(&self, query: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
270 if code.is_empty() {
271 return Ok(Float32Array::from(Vec::<f32>::new()));
272 }
273
274 match self.distance_type {
275 DistanceType::L2 => self.l2_distances(query, code),
276 DistanceType::Cosine => {
277 debug_assert!(
280 false,
281 "cosine distance should be converted to normalized L2 distance"
282 );
283 let l2_dists = self.l2_distances(query, code)?;
287 Ok(l2_dists.values().iter().map(|v| *v / 2.0).collect())
288 }
289 DistanceType::Dot => self.dot_distances(query, code),
290 _ => panic!(
291 "ProductQuantization: distance type {} not supported",
292 self.distance_type
293 ),
294 }
295 }
296
297 fn l2_distances(&self, key: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
301 let distance_table = self.build_l2_distance_table(key)?;
302 Ok(self.compute_l2_distance(&distance_table, code.values()))
303 }
304
305 fn dot_distances(&self, key: &dyn Array, code: &UInt8Array) -> Result<Float32Array> {
311 match key.data_type() {
312 DataType::Float16 => {
313 self.dot_distances_impl::<datatypes::Float16Type>(key.as_primitive(), code)
314 }
315 DataType::Float32 => {
316 self.dot_distances_impl::<datatypes::Float32Type>(key.as_primitive(), code)
317 }
318 DataType::Float64 => {
319 self.dot_distances_impl::<datatypes::Float64Type>(key.as_primitive(), code)
320 }
321 _ => Err(Error::index(format!(
322 "unsupported data type: {}",
323 key.data_type()
324 ))),
325 }
326 }
327
328 fn dot_distances_impl<T: ArrowPrimitiveType>(
329 &self,
330 key: &PrimitiveArray<T>,
331 code: &UInt8Array,
332 ) -> Result<Float32Array>
333 where
334 T::Native: Dot,
335 {
336 let distance_table = build_distance_table_dot(
337 self.codebook.values().as_primitive::<T>().values(),
338 self.num_bits,
339 self.num_sub_vectors,
340 key.values(),
341 );
342
343 let distances = compute_pq_distance(
344 &distance_table,
345 self.num_bits,
346 self.num_sub_vectors,
347 code.values(),
348 0,
349 );
350
351 let diff = self.num_sub_vectors as f32 - 1.0;
352 let distances = distances.into_iter().map(|d| d - diff).collect::<Vec<_>>();
353 Ok(distances.into())
354 }
355
356 fn build_l2_distance_table(&self, key: &dyn Array) -> Result<Vec<f32>> {
357 match key.data_type() {
358 DataType::Float16 => {
359 Ok(self.build_l2_distance_table_impl::<datatypes::Float16Type>(key.as_primitive()))
360 }
361 DataType::Float32 => {
362 if let Some(targets) = &self.l2_targets {
363 let query = key.as_primitive::<datatypes::Float32Type>().values();
364 Ok(build_distance_table_l2_prepared(targets, query))
365 } else {
366 Ok(self
367 .build_l2_distance_table_impl::<datatypes::Float32Type>(key.as_primitive()))
368 }
369 }
370 DataType::Float64 => {
371 Ok(self.build_l2_distance_table_impl::<datatypes::Float64Type>(key.as_primitive()))
372 }
373 _ => Err(Error::index(format!(
374 "unsupported data type: {}",
375 key.data_type()
376 ))),
377 }
378 }
379
380 fn build_l2_distance_table_impl<T: ArrowPrimitiveType>(
381 &self,
382 key: &PrimitiveArray<T>,
383 ) -> Vec<f32>
384 where
385 T::Native: L2,
386 {
387 build_distance_table_l2(
388 self.codebook.values().as_primitive::<T>().values(),
389 self.num_bits,
390 self.num_sub_vectors,
391 key.values(),
392 )
393 }
394
395 #[inline]
412 fn compute_l2_distance(&self, distance_table: &[f32], code: &[u8]) -> Float32Array {
413 Float32Array::from(compute_pq_distance(
414 distance_table,
415 self.num_bits,
416 self.num_sub_vectors,
417 code,
418 100,
419 ))
420 }
421
422 pub fn centroids<T: ArrowPrimitiveType>(&self, sub_vector_idx: usize) -> &[T::Native] {
426 match self.num_bits {
427 4 => get_sub_vector_centroids::<4, _>(
428 self.codebook.values().as_primitive::<T>().values(),
429 self.dimension,
430 self.num_sub_vectors,
431 sub_vector_idx,
432 ),
433 8 => get_sub_vector_centroids::<8, _>(
434 self.codebook.values().as_primitive::<T>().values(),
435 self.dimension,
436 self.num_sub_vectors,
437 sub_vector_idx,
438 ),
439 _ => panic!(
440 "ProductQuantization: num_bits {} not supported",
441 self.num_bits
442 ),
443 }
444 }
445}
446
447impl Quantization for ProductQuantizer {
448 type BuildParams = PQBuildParams;
449 type Metadata = ProductQuantizationMetadata;
450 type Storage = ProductQuantizationStorage;
451
452 fn build(
453 data: &dyn Array,
454 distance_type: DistanceType,
455 params: &Self::BuildParams,
456 ) -> Result<Self> {
457 assert_eq!(data.null_count(), 0);
458 let fsl = data.as_fixed_size_list_opt().ok_or(Error::index(format!(
459 "PQ builder: input is not a FixedSizeList: {}",
460 data.data_type()
461 )))?;
462
463 if let Some(codebook) = params.codebook.as_ref() {
464 return Ok(Self::new(
465 params.num_sub_vectors,
466 params.num_bits as u32,
467 fsl.value_length() as usize,
468 FixedSizeListArray::try_new_from_values(codebook.clone(), fsl.value_length())?,
469 distance_type,
470 ));
471 }
472
473 params.build(data, distance_type)
474 }
475
476 fn retrain(&mut self, data: &dyn Array) -> Result<()> {
477 assert_eq!(data.null_count(), 0);
478 let params = PQBuildParams::with_codebook(
479 self.num_sub_vectors,
480 self.num_bits as usize,
481 Arc::new(self.codebook.clone()),
482 );
483
484 *self = params.build(data, self.distance_type)?;
485 Ok(())
486 }
487
488 fn code_dim(&self) -> usize {
489 self.num_sub_vectors
490 }
491
492 fn column(&self) -> &'static str {
493 PQ_CODE_COLUMN
494 }
495
496 fn use_residual(distance_type: DistanceType) -> bool {
497 PQBuildParams::use_residual(distance_type)
498 }
499
500 fn quantize(&self, vectors: &dyn Array) -> Result<ArrayRef> {
501 let fsl = vectors
502 .as_fixed_size_list_opt()
503 .ok_or(Error::index(format!(
504 "Expect to be a FixedSizeList<float> vector array, got: {:?} array",
505 vectors.data_type()
506 )))?;
507
508 match fsl.value_type() {
509 DataType::Float16 => self.transform::<datatypes::Float16Type>(vectors),
510 DataType::Float32 => self.transform::<datatypes::Float32Type>(vectors),
511 DataType::Float64 => self.transform::<datatypes::Float64Type>(vectors),
512 _ => Err(Error::index(format!(
513 "unsupported data type: {}",
514 fsl.value_type()
515 ))),
516 }
517 }
518
519 fn metadata_key() -> &'static str {
520 PQ_METADATA_KEY
521 }
522
523 fn quantization_type() -> QuantizationType {
524 QuantizationType::Product
525 }
526
527 fn metadata(&self, args: Option<QuantizationMetadata>) -> Self::Metadata {
528 let codebook_position = match &args {
529 Some(args) => args.codebook_position,
530 None => Some(0),
531 };
532
533 let codebook_position = codebook_position.expect("codebook position should be set");
534 ProductQuantizationMetadata {
535 codebook_position,
536 nbits: self.num_bits,
537 num_sub_vectors: self.num_sub_vectors,
538 dimension: self.dimension,
539 codebook: Some(self.codebook.clone()),
540 codebook_tensor: Vec::new(),
541 transposed: args.map(|args| args.transposed).unwrap_or_default(),
542 }
543 }
544
545 fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result<Quantizer> {
546 let distance_type = match distance_type {
547 DistanceType::Cosine => DistanceType::L2,
548 _ => distance_type,
549 };
550 let codebook = match metadata.codebook.as_ref() {
551 Some(fsl) => fsl.clone(),
552 None => {
553 let tensor = pb::Tensor::decode(metadata.codebook_tensor.as_ref())?;
554 FixedSizeListArray::try_from(&tensor)?
555 }
556 };
557 Ok(Quantizer::Product(Self::new(
558 metadata.num_sub_vectors,
559 metadata.nbits,
560 metadata.dimension,
561 codebook,
562 distance_type,
563 )))
564 }
565
566 fn field(&self) -> Field {
567 let num_bytes_per_sub_vector = self.num_sub_vectors * self.num_bits as usize / 8;
568 Field::new(
569 PQ_CODE_COLUMN,
570 DataType::FixedSizeList(
571 Arc::new(Field::new("item", DataType::UInt8, true)),
572 num_bytes_per_sub_vector as i32,
573 ),
574 true,
575 )
576 }
577}
578
579impl TryFrom<&ProductQuantizer> for pb::Pq {
580 type Error = Error;
581
582 fn try_from(pq: &ProductQuantizer) -> Result<Self> {
583 let tensor = pb::Tensor::try_from(&pq.codebook)?;
584 Ok(Self {
585 num_bits: pq.num_bits,
586 num_sub_vectors: pq.num_sub_vectors as u32,
587 dimension: pq.dimension as u32,
588 codebook: vec![],
589 codebook_tensor: Some(tensor),
590 })
591 }
592}
593
594impl TryFrom<Quantizer> for ProductQuantizer {
595 type Error = Error;
596 fn try_from(value: Quantizer) -> Result<Self> {
597 match value {
598 Quantizer::Product(pq) => Ok(pq),
599 _ => Err(Error::index("Expect to be a ProductQuantizer".to_string())),
600 }
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    use std::iter::repeat_n;

    use approx::assert_relative_eq;
    use arrow::datatypes::UInt8Type;
    use arrow_array::Float16Array;
    use half::f16;
    use lance_linalg::distance::l2_distance_batch;
    use lance_linalg::kernels::argmin;
    use lance_testing::datagen::generate_random_array;
    use num_traits::Zero;
    use storage::transpose;

    /// A Float16 codebook is serialized as a tensor and the flat float list
    /// stays empty.
    #[test]
    fn test_f16_pq_to_protobuf() {
        let zeros = Float16Array::from_iter_values(repeat_n(f16::zero(), 256 * 16));
        let codebook = FixedSizeListArray::try_new_from_values(zeros, 16).unwrap();
        let pq = ProductQuantizer::new(4, 8, 16, codebook, DistanceType::L2);

        let proto: pb::Pq = pb::Pq::try_from(&pq).unwrap();
        assert_eq!(proto.num_bits, 8);
        assert_eq!(proto.num_sub_vectors, 4);
        assert_eq!(proto.dimension, 16);
        assert!(proto.codebook.is_empty());
        assert!(proto.codebook_tensor.is_some());

        let tensor = proto.codebook_tensor.as_ref().unwrap();
        assert_eq!(tensor.data_type, pb::tensor::DataType::Float16 as i32);
        assert_eq!(tensor.shape, vec![256, 16]);
    }

    /// PQ distances must match the sum of exact L2 distances between each
    /// query sub-vector and the centroid selected by the code.
    #[test]
    fn test_l2_distance() {
        const DIM: usize = 512;
        const TOTAL: usize = 66;
        let codebook = generate_random_array(256 * DIM);
        let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap();
        let pq = ProductQuantizer::new(16, 8, DIM, codebook, DistanceType::L2);
        let pq_code = UInt8Array::from_iter_values((0..16 * TOTAL).map(|v| v as u8));
        let query = generate_random_array(DIM);

        let transposed_pq_codes = transpose(&pq_code, TOTAL, 16);
        let dists = pq.compute_distances(&query, &transposed_pq_codes).unwrap();

        let sub_vec_len = DIM / 16;
        let query_values = query.values();
        // Reference distance for one row of 16 codes.
        let reference_distance = |code: &[u8]| -> f32 {
            code.iter()
                .enumerate()
                .flat_map(|(sub_idx, &centroid_id)| {
                    let centroids = pq.centroids::<datatypes::Float32Type>(sub_idx);
                    let centroid_start = centroid_id as usize * sub_vec_len;
                    let sub_query =
                        &query_values[sub_idx * sub_vec_len..(sub_idx + 1) * sub_vec_len];
                    l2_distance_batch(
                        sub_query,
                        &centroids[centroid_start..centroid_start + sub_vec_len],
                        sub_vec_len,
                    )
                })
                .sum::<f32>()
        };
        let expected: Vec<f32> = pq_code
            .values()
            .chunks(16)
            .map(reference_distance)
            .collect();

        for (actual, wanted) in dists.values().iter().zip(expected.iter()) {
            assert_relative_eq!(*actual, *wanted, epsilon = 1e-4);
        }
    }

    /// Quantized codes must equal the argmin of exact L2 distances to the
    /// centroids of every sub-vector.
    #[test]
    fn test_pq_transform() {
        const DIM: usize = 16;
        const TOTAL: usize = 64;
        const SUB_DIM: usize = DIM / 4;
        let codebook = generate_random_array(DIM * 256);
        let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap();
        let pq = ProductQuantizer::new(4, 8, DIM, codebook, DistanceType::L2);

        let vectors = generate_random_array(DIM * TOTAL);
        let fsl = FixedSizeListArray::try_new_from_values(vectors.clone(), DIM as i32).unwrap();
        let pq_code = pq.quantize(&fsl).unwrap();

        let mut expected = Vec::with_capacity(TOTAL * 4);
        for vector in vectors.values().chunks_exact(DIM) {
            for (sub_idx, sub_vec) in vector.chunks_exact(SUB_DIM).enumerate() {
                let centroids = pq.centroids::<datatypes::Float32Type>(sub_idx);
                let dists = l2_distance_batch(sub_vec, centroids, SUB_DIM);
                expected.push(argmin(dists).unwrap() as u8);
            }
        }

        assert_eq!(pq_code.len(), TOTAL);
        let actual_codes = pq_code
            .as_fixed_size_list()
            .values()
            .as_primitive::<UInt8Type>()
            .values();
        assert_eq!(&expected, actual_codes);
    }
}