1use std::{cmp::min, collections::HashMap, sync::Arc};
9
10use arrow::datatypes::{self, UInt8Type};
11use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
12use arrow_array::{
13 FixedSizeListArray, RecordBatch, UInt8Array, UInt64Array,
14 cast::AsArray,
15 types::{Float32Type, UInt64Type},
16};
17use arrow_schema::{DataType, SchemaRef};
18use async_trait::async_trait;
19use bytes::{Bytes, BytesMut};
20use deepsize::DeepSizeOf;
21use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
22use lance_core::{Error, ROW_ID, Result};
23use lance_file::previous::{
24 reader::FileReader as PreviousFileReader, writer::FileWriter as PreviousFileWriter,
25};
26use lance_io::{object_store::ObjectStore, utils::read_message};
27use lance_linalg::distance::{DistanceType, Dot, L2};
28use lance_table::utils::LanceIteratorExtension;
29use lance_table::{format::SelfDescribingFileReader, io::manifest::ManifestDescribing};
30use object_store::path::Path;
31use prost::Message;
32use serde::{Deserialize, Serialize};
33
34use super::ProductQuantizer;
35use super::distance::{build_distance_table_dot, build_distance_table_l2, compute_pq_distance};
36use crate::frag_reuse::FragReuseIndex;
37use crate::vector::graph::{OrderedFloat, OrderedNode};
38use crate::{
39 INDEX_METADATA_SCHEMA_KEY, IndexMetadata, pb,
40 vector::{
41 PQ_CODE_COLUMN,
42 pq::transform::PQTransformer,
43 quantizer::{QuantizerMetadata, QuantizerStorage},
44 storage::{DistCalculator, VectorStore},
45 transform::Transformer,
46 },
47};
48
49pub const PQ_METADATA_KEY: &str = "lance:pq";
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ProductQuantizationMetadata {
53 pub codebook_position: usize,
54 pub nbits: u32,
55 pub num_sub_vectors: usize,
56 pub dimension: usize,
57
58 #[serde(skip)]
59 pub codebook: Option<FixedSizeListArray>,
60
61 pub codebook_tensor: Vec<u8>,
65 pub transposed: bool,
66}
67
68impl DeepSizeOf for ProductQuantizationMetadata {
69 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
70 self.codebook
71 .as_ref()
72 .map(|codebook| codebook.get_array_memory_size())
73 .unwrap_or(0)
74 }
75}
76
77impl PartialEq for ProductQuantizationMetadata {
78 fn eq(&self, other: &Self) -> bool {
79 self.num_sub_vectors == other.num_sub_vectors
80 && self.nbits == other.nbits
81 && self.dimension == other.dimension
82 && self.codebook == other.codebook
83 }
84}
85
86#[async_trait]
87impl QuantizerMetadata for ProductQuantizationMetadata {
88 fn buffer_index(&self) -> Option<u32> {
89 if self.codebook_position > 0 {
90 Some(self.codebook_position as u32)
92 } else {
93 None
94 }
95 }
96
97 fn set_buffer_index(&mut self, index: u32) {
98 self.codebook_position = index as usize;
99 }
100
101 fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> {
102 debug_assert!(!bytes.is_empty());
103 debug_assert!(self.codebook.is_none());
104 let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?;
105 self.codebook = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
106 Ok(())
107 }
108
109 fn extra_metadata(&self) -> Result<Option<Bytes>> {
110 if let Some(codebook) = &self.codebook {
111 let codebook_tensor: pb::Tensor = pb::Tensor::try_from(codebook)?;
112 let mut bytes = BytesMut::new();
113 codebook_tensor.encode(&mut bytes)?;
114 Ok(Some(bytes.freeze()))
115 } else if !self.codebook_tensor.is_empty() {
116 Ok(Some(Bytes::from(self.codebook_tensor.clone())))
120 } else {
121 Ok(None)
122 }
123 }
124
125 async fn load(reader: &PreviousFileReader) -> Result<Self> {
126 let metadata = reader
127 .schema()
128 .metadata
129 .get(PQ_METADATA_KEY)
130 .ok_or(Error::index(format!(
131 "Reading PQ storage: metadata key {} not found",
132 PQ_METADATA_KEY
133 )))?;
134 let mut metadata: Self = serde_json::from_str(metadata)
135 .map_err(|_| Error::index(format!("Failed to parse PQ metadata: {}", metadata)))?;
136
137 debug_assert!(metadata.codebook.is_none());
138 debug_assert!(metadata.codebook_tensor.is_empty());
139
140 let codebook_tensor: pb::Tensor =
141 read_message(reader.object_reader.as_ref(), metadata.codebook_position).await?;
142 metadata.codebook = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
143 Ok(metadata)
144 }
145}
146
/// In-memory storage of PQ-encoded vectors together with their row ids.
///
/// Instances created through `ProductQuantizationStorage::new` hold only the row-id
/// and PQ-code columns, and always keep the codes column-major (transposed).
#[derive(Clone, Debug)]
pub struct ProductQuantizationStorage {
    /// Quantizer description, including the codebook.
    metadata: ProductQuantizationMetadata,
    /// Distance type used for search; `Cosine` is stored as `L2`.
    distance_type: DistanceType,
    /// Batch with the row-id and PQ-code columns.
    batch: RecordBatch,

    /// Flat PQ code bytes taken from `batch`'s PQ-code column, column-major.
    pq_code: Arc<UInt8Array>,
    /// Row ids, where position `i` corresponds to vector `i` in `pq_code`.
    row_ids: Arc<UInt64Array>,
}
162
163impl DeepSizeOf for ProductQuantizationStorage {
164 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
165 self.batch.get_array_memory_size()
166 + self
167 .metadata
168 .codebook
169 .as_ref()
170 .map(|codebook| codebook.get_array_memory_size())
171 .unwrap_or(0)
172 }
173}
174
175impl PartialEq for ProductQuantizationStorage {
176 fn eq(&self, other: &Self) -> bool {
177 self.distance_type == other.distance_type
178 && self.metadata.eq(&other.metadata)
179 && self.batch.columns().eq(other.batch.columns())
180 }
181}
182
183impl ProductQuantizationStorage {
184 #[allow(clippy::too_many_arguments)]
185 pub fn new(
186 codebook: FixedSizeListArray,
187 mut batch: RecordBatch,
188 num_bits: u32,
189 num_sub_vectors: usize,
190 dimension: usize,
191 distance_type: DistanceType,
192 transposed: bool,
193 frag_reuse_index: Option<Arc<FragReuseIndex>>,
194 ) -> Result<Self> {
195 if batch.num_columns() != 2 {
196 log::warn!(
197 "PQ storage should have 2 columns, but got {} columns: {}",
198 batch.num_columns(),
199 batch.schema(),
200 );
201 batch = batch.project(&[
202 batch.schema().index_of(ROW_ID)?,
203 batch.schema().index_of(PQ_CODE_COLUMN)?,
204 ])?;
205 }
206
207 let Some(row_ids) = batch.column_by_name(ROW_ID) else {
208 return Err(Error::index(
209 "Row ID column not found from PQ storage".to_string(),
210 ));
211 };
212 let row_ids: Arc<UInt64Array> = row_ids
213 .as_primitive_opt::<UInt64Type>()
214 .ok_or(Error::index(
215 "Row ID column is not of type UInt64".to_string(),
216 ))?
217 .clone()
218 .into();
219
220 if !transposed {
221 let num_sub_vectors_in_byte = if num_bits == 4 {
222 num_sub_vectors / 2
223 } else {
224 num_sub_vectors
225 };
226 let pq_col = batch[PQ_CODE_COLUMN].as_fixed_size_list();
227 let transposed_code = transpose(
228 pq_col.values().as_primitive::<UInt8Type>(),
229 row_ids.len(),
230 num_sub_vectors_in_byte,
231 );
232 let pq_code_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
233 transposed_code,
234 num_sub_vectors_in_byte as i32,
235 )?);
236 batch = batch.replace_column_by_name(PQ_CODE_COLUMN, pq_code_fsl)?;
237 }
238
239 let mut pq_code: Arc<UInt8Array> = batch[PQ_CODE_COLUMN]
240 .as_fixed_size_list()
241 .values()
242 .as_primitive()
243 .clone()
244 .into();
245
246 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
247 let transposed_codes = pq_code.values();
248 let mut new_row_ids = Vec::with_capacity(row_ids.len());
249 let mut new_codes = Vec::with_capacity(row_ids.len() * num_sub_vectors);
250
251 let row_ids_values = row_ids.values();
252 for (i, row_id) in row_ids_values.iter().enumerate() {
253 if let Some(mapped_value) = frag_reuse_index_ref.remap_row_id(*row_id) {
254 new_row_ids.push(mapped_value);
255 new_codes.extend(get_pq_code(
256 transposed_codes,
257 num_bits,
258 num_sub_vectors,
259 i as u32,
260 ));
261 }
262 }
263
264 let new_row_ids = Arc::new(UInt64Array::from(new_row_ids));
265 let new_codes = UInt8Array::from(new_codes);
266 batch = if new_row_ids.is_empty() {
267 RecordBatch::new_empty(batch.schema())
268 } else {
269 let num_bytes_in_code = new_codes.len() / new_row_ids.len();
270 let new_transposed_codes =
271 transpose(&new_codes, new_row_ids.len(), num_bytes_in_code);
272 let codes_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
273 new_transposed_codes,
274 num_bytes_in_code as i32,
275 )?);
276 RecordBatch::try_new(batch.schema(), vec![new_row_ids, codes_fsl])?
277 };
278 pq_code = batch[PQ_CODE_COLUMN]
279 .as_fixed_size_list()
280 .values()
281 .as_primitive::<UInt8Type>()
282 .clone()
283 .into();
284 }
285
286 let distance_type = match distance_type {
287 DistanceType::Cosine => DistanceType::L2,
288 _ => distance_type,
289 };
290 let metadata = ProductQuantizationMetadata {
291 codebook_position: 0,
292 nbits: num_bits,
293 num_sub_vectors,
294 dimension,
295 codebook: Some(codebook),
296 codebook_tensor: Vec::new(), transposed: true,
298 };
299 Ok(Self {
300 metadata,
301 distance_type,
302 batch,
303 pq_code,
304 row_ids,
305 })
306 }
307
308 pub fn batch(&self) -> &RecordBatch {
309 &self.batch
310 }
311
312 pub async fn build(
323 quantizer: ProductQuantizer,
324 batch: &RecordBatch,
325 vector_col: &str,
326 frag_reuse_index: Option<Arc<FragReuseIndex>>,
327 ) -> Result<Self> {
328 let codebook = quantizer.codebook.clone();
329 let num_bits = quantizer.num_bits;
330 let dimension = quantizer.dimension;
331 let num_sub_vectors = quantizer.num_sub_vectors;
332 let metric_type = quantizer.distance_type;
333 let transform = PQTransformer::new(quantizer, vector_col, PQ_CODE_COLUMN);
334 let batch = transform.transform(batch)?;
335 Self::new(
336 codebook,
337 batch,
338 num_bits,
339 num_sub_vectors,
340 dimension,
341 metric_type,
342 false,
343 frag_reuse_index,
344 )
345 }
346
347 pub fn codebook(&self) -> &FixedSizeListArray {
348 self.metadata.codebook.as_ref().unwrap()
349 }
350
351 pub async fn load(
367 object_store: &ObjectStore,
368 path: &Path,
369 frag_reuse_index: Option<Arc<FragReuseIndex>>,
370 ) -> Result<Self> {
371 let reader = PreviousFileReader::try_new_self_described(object_store, path, None).await?;
372 let schema = reader.schema();
373
374 let metadata_str = schema
375 .metadata
376 .get(INDEX_METADATA_SCHEMA_KEY)
377 .ok_or(Error::index(format!(
378 "Reading PQ storage: index key {} not found",
379 INDEX_METADATA_SCHEMA_KEY
380 )))?;
381 let index_metadata: IndexMetadata = serde_json::from_str(metadata_str).map_err(|_| {
382 Error::index(format!("Failed to parse index metadata: {}", metadata_str))
383 })?;
384 let distance_type: DistanceType =
385 DistanceType::try_from(index_metadata.distance_type.as_str())?;
386
387 let metadata = ProductQuantizationMetadata::load(&reader).await?;
388 Self::load_partition(
389 &reader,
390 0..reader.len(),
391 distance_type,
392 &metadata,
393 frag_reuse_index,
394 )
395 .await
396 }
397
398 pub fn schema(&self) -> SchemaRef {
399 self.batch.schema()
400 }
401
402 pub fn get_row_ids(&self, ids: &[u32]) -> Vec<u64> {
403 ids.iter()
404 .map(|&id| self.row_ids.value(id as usize))
405 .collect()
406 }
407
408 pub async fn write_partition(
412 &self,
413 writer: &mut PreviousFileWriter<ManifestDescribing>,
414 ) -> Result<usize> {
415 let batch_size: usize = 10240; for offset in (0..self.batch.num_rows()).step_by(batch_size) {
417 let length = min(batch_size, self.batch.num_rows() - offset);
418 let slice = self.batch.slice(offset, length);
419 writer.write(&[slice]).await?;
420 }
421 Ok(self.batch.num_rows())
422 }
423}
424
425pub fn transpose<T: ArrowPrimitiveType>(
426 original: &PrimitiveArray<T>,
427 num_rows: usize,
428 num_columns: usize,
429) -> PrimitiveArray<T>
430where
431 PrimitiveArray<T>: From<Vec<T::Native>>,
432{
433 if original.is_empty() {
434 return original.clone();
435 }
436
437 let mut transposed_codes = vec![T::default_value(); original.len()];
438 for (vec_idx, codes) in original.values().chunks_exact(num_columns).enumerate() {
439 for (sub_vec_idx, code) in codes.iter().enumerate() {
440 transposed_codes[sub_vec_idx * num_rows + vec_idx] = *code;
441 }
442 }
443
444 transposed_codes.into()
445}
446
447#[async_trait]
448impl QuantizerStorage for ProductQuantizationStorage {
449 type Metadata = ProductQuantizationMetadata;
450
451 fn try_from_batch(
452 batch: RecordBatch,
453 metadata: &Self::Metadata,
454 distance_type: DistanceType,
455 frag_reuse_index: Option<Arc<FragReuseIndex>>,
456 ) -> Result<Self>
457 where
458 Self: Sized,
459 {
460 let distance_type = match distance_type {
461 DistanceType::Cosine => DistanceType::L2,
462 _ => distance_type,
463 };
464
465 let codebook = match &metadata.codebook {
467 Some(codebook) => codebook.clone(),
468 None => {
469 debug_assert!(!metadata.codebook_tensor.is_empty());
471 let codebook_tensor = pb::Tensor::decode(metadata.codebook_tensor.as_slice())?;
472 FixedSizeListArray::try_from(&codebook_tensor)?
473 }
474 };
475
476 Self::new(
477 codebook,
478 batch,
479 metadata.nbits,
480 metadata.num_sub_vectors,
481 metadata.dimension,
482 distance_type,
483 metadata.transposed,
484 frag_reuse_index,
485 )
486 }
487
488 fn metadata(&self) -> &Self::Metadata {
489 &self.metadata
490 }
491
492 fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
495 let transposed_codes = self.pq_code.values();
496 let mut new_row_ids = Vec::with_capacity(self.len());
497 let mut new_codes = Vec::with_capacity(self.len() * self.metadata.num_sub_vectors);
498
499 let row_ids = self.row_ids.values();
500 for (i, row_id) in row_ids.iter().enumerate() {
501 match mapping.get(row_id) {
502 Some(Some(new_id)) => {
503 new_row_ids.push(*new_id);
504 new_codes.extend(get_pq_code(
505 transposed_codes,
506 self.metadata.nbits,
507 self.metadata.num_sub_vectors,
508 i as u32,
509 ));
510 }
511 Some(None) => {}
512 None => {
513 new_row_ids.push(*row_id);
514 new_codes.extend(get_pq_code(
515 transposed_codes,
516 self.metadata.nbits,
517 self.metadata.num_sub_vectors,
518 i as u32,
519 ));
520 }
521 }
522 }
523
524 let new_row_ids = Arc::new(UInt64Array::from(new_row_ids));
525 let new_codes = UInt8Array::from(new_codes);
526 let batch = if new_row_ids.is_empty() {
527 RecordBatch::new_empty(self.schema())
528 } else {
529 let num_bytes_in_code = new_codes.len() / new_row_ids.len();
530 let new_transposed_codes = transpose(&new_codes, new_row_ids.len(), num_bytes_in_code);
531 let codes_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
532 new_transposed_codes,
533 num_bytes_in_code as i32,
534 )?);
535 RecordBatch::try_new(self.schema(), vec![new_row_ids.clone(), codes_fsl])?
536 };
537 let transposed_codes = batch[PQ_CODE_COLUMN]
538 .as_fixed_size_list()
539 .values()
540 .as_primitive::<UInt8Type>()
541 .clone();
542
543 Ok(Self {
544 metadata: self.metadata.clone(),
545 distance_type: self.distance_type,
546 batch,
547 pq_code: Arc::new(transposed_codes),
548 row_ids: new_row_ids,
549 })
550 }
551
552 async fn load_partition(
558 reader: &PreviousFileReader,
559 range: std::ops::Range<usize>,
560 distance_type: DistanceType,
561 metadata: &Self::Metadata,
562 frag_reuse_index: Option<Arc<FragReuseIndex>>,
563 ) -> Result<Self> {
564 let codebook = metadata
566 .codebook
567 .as_ref()
568 .ok_or(Error::index(
569 "Codebook not found in PQ metadata".to_string(),
570 ))?
571 .values()
572 .as_primitive::<Float32Type>()
573 .clone();
574
575 let codebook =
576 FixedSizeListArray::try_new_from_values(codebook, metadata.dimension as i32)?;
577
578 let schema = reader.schema();
579 let batch = reader.read_range(range, schema).await?;
580
581 Self::new(
582 codebook,
583 batch,
584 metadata.nbits,
585 metadata.num_sub_vectors,
586 metadata.dimension,
587 distance_type,
588 metadata.transposed,
589 frag_reuse_index,
590 )
591 }
592}
593
594impl VectorStore for ProductQuantizationStorage {
595 type DistanceCalculator<'a> = PQDistCalculator;
596
597 fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
598 Ok(std::iter::once(self.batch.clone()))
599 }
600
601 fn append_batch(&self, _batch: RecordBatch, _vector_column: &str) -> Result<Self> {
602 unimplemented!()
603 }
604
605 fn schema(&self) -> &SchemaRef {
606 self.batch.schema_ref()
607 }
608
609 fn as_any(&self) -> &dyn std::any::Any {
610 self
611 }
612
613 fn len(&self) -> usize {
614 self.batch.num_rows()
615 }
616
617 fn distance_type(&self) -> DistanceType {
618 self.distance_type
619 }
620
621 fn row_id(&self, id: u32) -> u64 {
622 self.row_ids.values()[id as usize]
623 }
624
625 fn row_ids(&self) -> impl Iterator<Item = &u64> {
626 self.row_ids.values().iter()
627 }
628
629 fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
630 let codebook = self.metadata.codebook.as_ref().unwrap();
631 match codebook.value_type() {
632 DataType::Float16 => PQDistCalculator::new(
633 codebook
634 .values()
635 .as_primitive::<datatypes::Float16Type>()
636 .values(),
637 self.metadata.nbits,
638 self.metadata.num_sub_vectors,
639 self.pq_code.clone(),
640 query.as_primitive::<datatypes::Float16Type>().values(),
641 self.distance_type,
642 ),
643 DataType::Float32 => PQDistCalculator::new(
644 codebook
645 .values()
646 .as_primitive::<datatypes::Float32Type>()
647 .values(),
648 self.metadata.nbits,
649 self.metadata.num_sub_vectors,
650 self.pq_code.clone(),
651 query.as_primitive::<datatypes::Float32Type>().values(),
652 self.distance_type,
653 ),
654 DataType::Float64 => PQDistCalculator::new(
655 codebook
656 .values()
657 .as_primitive::<datatypes::Float64Type>()
658 .values(),
659 self.metadata.nbits,
660 self.metadata.num_sub_vectors,
661 self.pq_code.clone(),
662 query.as_primitive::<datatypes::Float64Type>().values(),
663 self.distance_type,
664 ),
665 _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
666 }
667 }
668
669 fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
670 let codes = get_pq_code(
671 self.pq_code.values(),
672 self.metadata.nbits,
673 self.metadata.num_sub_vectors,
674 id,
675 );
676 let codebook = self.metadata.codebook.as_ref().unwrap();
677 match codebook.value_type() {
678 DataType::Float16 => {
679 let codebook = codebook
680 .values()
681 .as_primitive::<datatypes::Float16Type>()
682 .values();
683 let query = get_centroids(
684 codebook,
685 self.metadata.nbits,
686 self.metadata.num_sub_vectors,
687 self.metadata.dimension,
688 codes,
689 );
690 PQDistCalculator::new(
691 codebook,
692 self.metadata.nbits,
693 self.metadata.num_sub_vectors,
694 self.pq_code.clone(),
695 &query,
696 self.distance_type,
697 )
698 }
699 DataType::Float32 => {
700 let codebook = codebook
701 .values()
702 .as_primitive::<datatypes::Float32Type>()
703 .values();
704 let query = get_centroids(
705 codebook,
706 self.metadata.nbits,
707 self.metadata.num_sub_vectors,
708 self.metadata.dimension,
709 codes,
710 );
711 PQDistCalculator::new(
712 codebook,
713 self.metadata.nbits,
714 self.metadata.num_sub_vectors,
715 self.pq_code.clone(),
716 &query,
717 self.distance_type,
718 )
719 }
720 DataType::Float64 => {
721 let codebook = codebook
722 .values()
723 .as_primitive::<datatypes::Float64Type>()
724 .values();
725 let query = get_centroids(
726 codebook,
727 self.metadata.nbits,
728 self.metadata.num_sub_vectors,
729 self.metadata.dimension,
730 codes,
731 );
732 PQDistCalculator::new(
733 codebook,
734 self.metadata.nbits,
735 self.metadata.num_sub_vectors,
736 self.pq_code.clone(),
737 &query,
738 self.distance_type,
739 )
740 }
741 _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
742 }
743 }
744
745 fn dist_between(&self, u: u32, v: u32) -> f32 {
746 let pq_codes = self.pq_code.values();
749 let u_codes = get_pq_code(
750 pq_codes,
751 self.metadata.nbits,
752 self.metadata.num_sub_vectors,
753 u,
754 );
755 let v_codes = get_pq_code(
756 pq_codes,
757 self.metadata.nbits,
758 self.metadata.num_sub_vectors,
759 v,
760 );
761 let codebook = self.metadata.codebook.as_ref().unwrap();
762
763 match codebook.value_type() {
764 DataType::Float16 => {
765 let qu = get_centroids(
766 codebook
767 .values()
768 .as_primitive::<datatypes::Float16Type>()
769 .values(),
770 self.metadata.nbits,
771 self.metadata.num_sub_vectors,
772 self.metadata.dimension,
773 u_codes,
774 );
775 let qv = get_centroids(
776 codebook
777 .values()
778 .as_primitive::<datatypes::Float16Type>()
779 .values(),
780 self.metadata.nbits,
781 self.metadata.num_sub_vectors,
782 self.metadata.dimension,
783 v_codes,
784 );
785 self.distance_type.func()(&qu, &qv)
786 }
787 DataType::Float32 => {
788 let qu = get_centroids(
789 codebook
790 .values()
791 .as_primitive::<datatypes::Float32Type>()
792 .values(),
793 self.metadata.nbits,
794 self.metadata.num_sub_vectors,
795 self.metadata.dimension,
796 u_codes,
797 );
798 let qv = get_centroids(
799 codebook
800 .values()
801 .as_primitive::<datatypes::Float32Type>()
802 .values(),
803 self.metadata.nbits,
804 self.metadata.num_sub_vectors,
805 self.metadata.dimension,
806 v_codes,
807 );
808 self.distance_type.func()(&qu, &qv)
809 }
810 DataType::Float64 => {
811 let qu = get_centroids(
812 codebook
813 .values()
814 .as_primitive::<datatypes::Float64Type>()
815 .values(),
816 self.metadata.nbits,
817 self.metadata.num_sub_vectors,
818 self.metadata.dimension,
819 u_codes,
820 );
821 let qv = get_centroids(
822 codebook
823 .values()
824 .as_primitive::<datatypes::Float64Type>()
825 .values(),
826 self.metadata.nbits,
827 self.metadata.num_sub_vectors,
828 self.metadata.dimension,
829 v_codes,
830 );
831 self.distance_type.func()(&qu, &qv)
832 }
833 _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
834 }
835 }
836
837 fn prefers_candidate(&self, candidate: &OrderedNode, selected: &[OrderedNode]) -> bool {
838 selected
839 .iter()
840 .all(|other| candidate.dist < OrderedFloat(self.dist_between(candidate.id, other.id)))
841 }
842}
843
844pub struct PQDistCalculator {
846 distance_table: Vec<f32>,
847 pq_code: Arc<UInt8Array>,
848 num_sub_vectors: usize,
849 num_bits: u32,
850 distance_type: DistanceType,
851}
852
853impl PQDistCalculator {
854 fn new<T: L2 + Dot>(
855 codebook: &[T],
856 num_bits: u32,
857 num_sub_vectors: usize,
858 pq_code: Arc<UInt8Array>,
859 query: &[T],
860 distance_type: DistanceType,
861 ) -> Self {
862 let distance_table = match distance_type {
863 DistanceType::L2 | DistanceType::Cosine => {
864 build_distance_table_l2(codebook, num_bits, num_sub_vectors, query)
865 }
866 DistanceType::Dot => {
867 build_distance_table_dot(codebook, num_bits, num_sub_vectors, query)
868 }
869 _ => unimplemented!("DistanceType is not supported: {:?}", distance_type),
870 };
871 Self {
872 distance_table,
873 num_sub_vectors,
874 pq_code,
875 num_bits,
876 distance_type,
877 }
878 }
879
880 fn get_pq_code(&self, id: u32) -> impl Iterator<Item = usize> + '_ {
881 get_pq_code(
882 self.pq_code.values(),
883 self.num_bits,
884 self.num_sub_vectors,
885 id,
886 )
887 .map(|v| v as usize)
888 }
889}
890
891impl DistCalculator for PQDistCalculator {
892 fn distance(&self, id: u32) -> f32 {
893 let num_centroids = 2_usize.pow(self.num_bits);
894 let pq_code = self.get_pq_code(id);
895 let diff = self.num_sub_vectors as f32 - 1.0;
896 let dist = if self.num_bits == 4 {
897 pq_code
898 .enumerate()
899 .map(|(i, c)| {
900 let current_idx = c & 0x0F;
901 let next_idx = c >> 4;
902
903 self.distance_table[2 * i * num_centroids + current_idx]
904 + self.distance_table[(2 * i + 1) * num_centroids + next_idx]
905 })
906 .sum()
907 } else {
908 pq_code
909 .enumerate()
910 .map(|(i, c)| self.distance_table[i * num_centroids + c])
911 .sum()
912 };
913
914 if self.distance_type == DistanceType::Dot {
915 dist - diff
916 } else {
917 dist
918 }
919 }
920
921 fn distance_all(&self, k_hint: usize) -> Vec<f32> {
922 match self.distance_type {
923 DistanceType::L2 => compute_pq_distance(
924 &self.distance_table,
925 self.num_bits,
926 self.num_sub_vectors,
927 self.pq_code.values(),
928 k_hint,
929 ),
930 DistanceType::Cosine => {
931 debug_assert!(
934 false,
935 "cosine distance should be converted to normalized L2 distance"
936 );
937 let l2_dists = compute_pq_distance(
941 &self.distance_table,
942 self.num_bits,
943 self.num_sub_vectors,
944 self.pq_code.values(),
945 k_hint,
946 );
947 l2_dists.into_iter().map(|v| v / 2.0).collect()
948 }
949 DistanceType::Dot => {
950 let dot_dists = compute_pq_distance(
951 &self.distance_table,
952 self.num_bits,
953 self.num_sub_vectors,
954 self.pq_code.values(),
955 k_hint,
956 );
957 let diff = self.num_sub_vectors as f32 - 1.0;
958 dot_dists.into_iter().map(|v| v - diff).collect()
959 }
960 _ => unimplemented!("distance type is not supported: {:?}", self.distance_type),
961 }
962 }
963}
964
965fn get_pq_code(
966 pq_code: &[u8],
967 num_bits: u32,
968 num_sub_vectors: usize,
969 id: u32,
970) -> impl Iterator<Item = u8> + '_ {
971 let num_bytes = if num_bits == 4 {
972 num_sub_vectors / 2
973 } else {
974 num_sub_vectors
975 };
976
977 let num_vectors = pq_code.len() / num_bytes;
978 pq_code
979 .iter()
980 .skip(id as usize)
981 .step_by(num_vectors)
982 .copied()
983 .exact_size(num_bytes)
984}
985
/// Reconstructs an approximate vector from its PQ codes.
///
/// For each sub-vector, it concatenates the centroid selected by that sub-vector's code.
///
/// `codebook` is laid out sub-vector-major. For sub-vector `s`, centroid `c` occupies
/// `codebook[(s * num_centroids + c) * width..(s * num_centroids + c + 1) * width]`,
/// where `num_centroids = 2^num_bits` and `width = dimension / num_sub_vectors`.
/// With `num_bits == 4`, each code byte packs two sub-vectors (see
/// [`get_centroids_4bit`]).
///
/// # Panics
/// Panics if a code selects a centroid past the end of `codebook`.
fn get_centroids<T: Clone>(
    codebook: &[T],
    num_bits: u32,
    num_sub_vectors: usize,
    dimension: usize,
    codes: impl Iterator<Item = u8>,
) -> Vec<T> {
    if num_bits == 4 {
        return get_centroids_4bit(codebook, num_sub_vectors, dimension, codes);
    }

    // Same table size as `PQDistCalculator::distance` (2^num_bits centroids per
    // sub-vector); this was previously hard-coded to 256.
    let num_centroids = 2_usize.pow(num_bits);
    let sub_vector_width = dimension / num_sub_vectors;
    let mut centroids = Vec::with_capacity(dimension);
    for (sub_vec_idx, code) in codes.enumerate() {
        let start = (sub_vec_idx * num_centroids + code as usize) * sub_vector_width;
        centroids.extend_from_slice(&codebook[start..start + sub_vector_width]);
    }
    centroids
}

/// 4-bit variant of [`get_centroids`] with 16 centroids per sub-vector.
///
/// Code byte `i` holds the centroid index of sub-vector `2 * i` in its low nibble and
/// of sub-vector `2 * i + 1` in its high nibble.
fn get_centroids_4bit<T: Clone>(
    codebook: &[T],
    num_sub_vectors: usize,
    dimension: usize,
    codes: impl Iterator<Item = u8>,
) -> Vec<T> {
    const NUM_CENTROIDS: usize = 16;
    let sub_vector_width = dimension / num_sub_vectors;
    let mut centroids = Vec::with_capacity(dimension);
    for (byte_idx, byte) in codes.enumerate() {
        let low = (byte & 0x0F) as usize;
        let high = (byte >> 4) as usize;
        for (sub_vec_idx, code) in [(2 * byte_idx, low), (2 * byte_idx + 1, high)] {
            let start = (sub_vec_idx * NUM_CENTROIDS + code) * sub_vector_width;
            centroids.extend_from_slice(&codebook[start..start + sub_vector_width]);
        }
    }
    centroids
}
1038
#[cfg(test)]
mod tests {
    use crate::vector::storage::StorageBuilder;

    use super::*;

    use arrow_array::{Float32Array, UInt32Array};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_core::ROW_ID_FIELD;
    use rand::Rng;

    const DIM: usize = 32;
    const TOTAL: usize = 512;
    const NUM_SUB_VECTORS: usize = 16;

    /// Builds a Dot-product, 8-bit PQ storage over `TOTAL` random `DIM`-dimensional
    /// vectors (row ids `0..TOTAL`) with a random codebook.
    ///
    /// When `with_extra_column` is set, the input batch also carries an `extra`
    /// UInt32 column.
    fn build_random_storage(with_extra_column: bool) -> ProductQuantizationStorage {
        let centroids = Float32Array::from_iter_values((0..256 * DIM).map(|_| rand::random()));
        let codebook = FixedSizeListArray::try_new_from_values(centroids, DIM as i32).unwrap();
        let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::Dot);

        let vector_type = DataType::FixedSizeList(
            Field::new_list_field(DataType::Float32, true).into(),
            DIM as i32,
        );
        let mut fields = vec![Field::new("vec", vector_type, true), ROW_ID_FIELD.clone()];

        let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|_| rand::random()));
        let vectors = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
        let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64));
        let mut columns: Vec<ArrayRef> = vec![Arc::new(vectors), Arc::new(row_ids)];

        if with_extra_column {
            fields.push(Field::new("extra", DataType::UInt32, true));
            columns.push(Arc::new(UInt32Array::from_iter_values(
                (0..TOTAL).map(|v| v as u32),
            )));
        }

        let batch = RecordBatch::try_new(ArrowSchema::new(fields).into(), columns).unwrap();
        StorageBuilder::new("vec".to_owned(), pq.distance_type, pq, None)
            .unwrap()
            .build(vec![batch])
            .unwrap()
    }

    async fn create_pq_storage() -> ProductQuantizationStorage {
        build_random_storage(false)
    }

    async fn create_pq_storage_with_extra_column() -> ProductQuantizationStorage {
        build_random_storage(true)
    }

    #[tokio::test]
    async fn test_build_pq_storage() {
        let storage = create_pq_storage().await;
        let codebook_len = storage.metadata.codebook.as_ref().unwrap().values().len();

        assert_eq!(storage.len(), TOTAL);
        assert_eq!(storage.metadata.num_sub_vectors, NUM_SUB_VECTORS);
        assert_eq!(codebook_len, 256 * DIM);
        assert_eq!(storage.pq_code.len(), TOTAL * NUM_SUB_VECTORS);
        assert_eq!(storage.row_ids.len(), TOTAL);
    }

    #[tokio::test]
    async fn test_distance_all() {
        let storage = create_pq_storage().await;
        let query: ArrayRef = Arc::new(Float32Array::from_iter_values(
            (0..DIM).map(|v| v as f32),
        ));
        let calculator = storage.dist_calculator(query, 0.0);

        let expected: Vec<f32> = (0..storage.len() as u32)
            .map(|id| calculator.distance(id))
            .collect();
        assert_eq!(calculator.distance_all(100), expected);
    }

    #[tokio::test]
    async fn test_dist_between() {
        let mut rng = rand::rng();
        let storage = create_pq_storage().await;
        let first = rng.random_range(0..storage.len() as u32);
        let second = rng.random_range(0..storage.len() as u32);

        let forward = storage.dist_between(first, second);
        let backward = storage.dist_between(second, first);
        assert_eq!(forward, backward);
    }

    #[tokio::test]
    async fn test_remap_with_extra_column() {
        let storage = create_pq_storage_with_extra_column().await;
        let half = TOTAL / 2;
        // First half is remapped to `TOTAL + i`; second half is deleted.
        let mapping: HashMap<u64, Option<u64>> = (0..TOTAL)
            .map(|i| (i as u64, (i < half).then(|| (TOTAL + i) as u64)))
            .collect();

        let new_storage = storage.remap(&mapping).unwrap();

        assert_eq!(new_storage.len(), half);
        assert_eq!(new_storage.row_ids.len(), half);
        for (i, row_id) in new_storage.row_ids().enumerate() {
            assert_eq!(*row_id, (TOTAL + i) as u64);
        }
        assert_eq!(new_storage.batch.num_columns(), 2);
        assert!(new_storage.batch.column_by_name(ROW_ID).is_some());
        assert!(new_storage.batch.column_by_name(PQ_CODE_COLUMN).is_some());
    }
}