1use std::{cmp::min, collections::HashMap, sync::Arc};
9
10use arrow::datatypes::{self, UInt8Type};
11use arrow_array::{
12 cast::AsArray,
13 types::{Float32Type, UInt64Type},
14 FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array,
15};
16use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
17use arrow_schema::{DataType, SchemaRef};
18use async_trait::async_trait;
19use bytes::{Bytes, BytesMut};
20use deepsize::DeepSizeOf;
21use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
22use lance_core::{Error, Result, ROW_ID};
23use lance_file::previous::{
24 reader::FileReader as PreviousFileReader, writer::FileWriter as PreviousFileWriter,
25};
26use lance_io::{object_store::ObjectStore, utils::read_message};
27use lance_linalg::distance::{DistanceType, Dot, L2};
28use lance_table::utils::LanceIteratorExtension;
29use lance_table::{format::SelfDescribingFileReader, io::manifest::ManifestDescribing};
30use object_store::path::Path;
31use prost::Message;
32use serde::{Deserialize, Serialize};
33use snafu::location;
34
35use super::distance::{build_distance_table_dot, build_distance_table_l2, compute_pq_distance};
36use super::ProductQuantizer;
37use crate::frag_reuse::FragReuseIndex;
38use crate::{
39 pb,
40 vector::{
41 pq::transform::PQTransformer,
42 quantizer::{QuantizerMetadata, QuantizerStorage},
43 storage::{DistCalculator, VectorStore},
44 transform::Transformer,
45 PQ_CODE_COLUMN,
46 },
47 IndexMetadata, INDEX_METADATA_SCHEMA_KEY,
48};
49
/// Schema-metadata key under which the JSON-serialized
/// [`ProductQuantizationMetadata`] is stored in legacy index files.
pub const PQ_METADATA_KEY: &str = "lance:pq";
51
/// Metadata describing a product-quantization (PQ) codebook and code layout.
///
/// Serialized as JSON (see [`PQ_METADATA_KEY`]); the codebook itself travels
/// out-of-band, either as a protobuf tensor buffer (`codebook_tensor`) or as a
/// separate message in the file, and is re-attached at load time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizationMetadata {
    // Position of the codebook in the file; also doubles as the buffer index
    // reported by `buffer_index()` (0 means "not set").
    pub codebook_position: usize,
    // Bits per PQ code (e.g. 4 or 8).
    pub nbits: u32,
    // Number of sub-vectors each vector is split into.
    pub num_sub_vectors: usize,
    // Dimension of the original, un-quantized vectors.
    pub dimension: usize,

    // In-memory codebook; never serialized with serde — populated via
    // `parse_buffer` / `load` or from `codebook_tensor`.
    #[serde(skip)]
    pub codebook: Option<FixedSizeListArray>,

    // Protobuf-encoded codebook tensor bytes, used when `codebook` is `None`.
    pub codebook_tensor: Vec<u8>,
    // Whether the stored PQ codes are in the transposed (sub-vector-major) layout.
    pub transposed: bool,
}
68
69impl DeepSizeOf for ProductQuantizationMetadata {
70 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
71 self.codebook
72 .as_ref()
73 .map(|codebook| codebook.get_array_memory_size())
74 .unwrap_or(0)
75 }
76}
77
78impl PartialEq for ProductQuantizationMetadata {
79 fn eq(&self, other: &Self) -> bool {
80 self.num_sub_vectors == other.num_sub_vectors
81 && self.nbits == other.nbits
82 && self.dimension == other.dimension
83 && self.codebook == other.codebook
84 }
85}
86
#[async_trait]
impl QuantizerMetadata for ProductQuantizationMetadata {
    /// Global buffer index of the codebook, when one has been assigned.
    ///
    /// `codebook_position == 0` is treated as "unset" and yields `None` —
    /// NOTE(review): this assumes index 0 is never a valid codebook buffer;
    /// confirm against the file-format writer.
    fn buffer_index(&self) -> Option<u32> {
        if self.codebook_position > 0 {
            Some(self.codebook_position as u32)
        } else {
            None
        }
    }

    fn set_buffer_index(&mut self, index: u32) {
        self.codebook_position = index as usize;
    }

    /// Decode a protobuf tensor buffer into the in-memory codebook.
    /// Expects a non-empty buffer and no codebook attached yet.
    fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> {
        debug_assert!(!bytes.is_empty());
        debug_assert!(self.codebook.is_none());
        let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?;
        self.codebook = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
        Ok(())
    }

    /// Encode the in-memory codebook as a protobuf tensor, for writing as an
    /// out-of-band buffer. Requires `codebook` to be set.
    fn extra_metadata(&self) -> Result<Option<Bytes>> {
        debug_assert!(self.codebook.is_some());
        let codebook_tensor: pb::Tensor = pb::Tensor::try_from(self.codebook.as_ref().unwrap())?;
        let mut bytes = BytesMut::new();
        codebook_tensor.encode(&mut bytes)?;
        Ok(Some(bytes.freeze()))
    }

    /// Load PQ metadata from a legacy self-describing file: parse the JSON
    /// stored under [`PQ_METADATA_KEY`] in the schema metadata, then read the
    /// codebook tensor message at `codebook_position`.
    async fn load(reader: &PreviousFileReader) -> Result<Self> {
        let metadata = reader
            .schema()
            .metadata
            .get(PQ_METADATA_KEY)
            .ok_or(Error::Index {
                message: format!(
                    "Reading PQ storage: metadata key {} not found",
                    PQ_METADATA_KEY
                ),
                location: location!(),
            })?;
        let mut metadata: Self = serde_json::from_str(metadata).map_err(|_| Error::Index {
            message: format!("Failed to parse PQ metadata: {}", metadata),
            location: location!(),
        })?;

        // Legacy files carry the codebook as a separate message, not inline.
        debug_assert!(metadata.codebook.is_none());
        debug_assert!(metadata.codebook_tensor.is_empty());

        let codebook_tensor: pb::Tensor =
            read_message(reader.object_reader.as_ref(), metadata.codebook_position).await?;
        metadata.codebook = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
        Ok(metadata)
    }
}
144
/// In-memory storage of PQ-encoded vectors: a record batch holding the row-id
/// and (transposed) PQ-code columns, plus cached flattened views of both
/// columns for fast per-row access.
#[derive(Clone, Debug)]
pub struct ProductQuantizationStorage {
    metadata: ProductQuantizationMetadata,
    distance_type: DistanceType,
    batch: RecordBatch,

    // Flattened, transposed PQ codes (values of the PQ_CODE_COLUMN list array).
    pq_code: Arc<UInt8Array>,
    // Row ids, parallel to the stored codes.
    row_ids: Arc<UInt64Array>,
}
160
161impl DeepSizeOf for ProductQuantizationStorage {
162 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
163 self.batch.get_array_memory_size()
164 + self
165 .metadata
166 .codebook
167 .as_ref()
168 .map(|codebook| codebook.get_array_memory_size())
169 .unwrap_or(0)
170 }
171}
172
173impl PartialEq for ProductQuantizationStorage {
174 fn eq(&self, other: &Self) -> bool {
175 self.distance_type == other.distance_type
176 && self.metadata.eq(&other.metadata)
177 && self.batch.columns().eq(other.batch.columns())
178 }
179}
180
181impl ProductQuantizationStorage {
182 #[allow(clippy::too_many_arguments)]
183 pub fn new(
184 codebook: FixedSizeListArray,
185 mut batch: RecordBatch,
186 num_bits: u32,
187 num_sub_vectors: usize,
188 dimension: usize,
189 distance_type: DistanceType,
190 transposed: bool,
191 frag_reuse_index: Option<Arc<FragReuseIndex>>,
192 ) -> Result<Self> {
193 if batch.num_columns() != 2 {
194 log::warn!(
195 "PQ storage should have 2 columns, but got {} columns: {}",
196 batch.num_columns(),
197 batch.schema(),
198 );
199 batch = batch.project(&[
200 batch.schema().index_of(ROW_ID)?,
201 batch.schema().index_of(PQ_CODE_COLUMN)?,
202 ])?;
203 }
204
205 let Some(row_ids) = batch.column_by_name(ROW_ID) else {
206 return Err(Error::Index {
207 message: "Row ID column not found from PQ storage".to_string(),
208 location: location!(),
209 });
210 };
211 let row_ids: Arc<UInt64Array> = row_ids
212 .as_primitive_opt::<UInt64Type>()
213 .ok_or(Error::Index {
214 message: "Row ID column is not of type UInt64".to_string(),
215 location: location!(),
216 })?
217 .clone()
218 .into();
219
220 if !transposed {
221 let num_sub_vectors_in_byte = if num_bits == 4 {
222 num_sub_vectors / 2
223 } else {
224 num_sub_vectors
225 };
226 let pq_col = batch[PQ_CODE_COLUMN].as_fixed_size_list();
227 let transposed_code = transpose(
228 pq_col.values().as_primitive::<UInt8Type>(),
229 row_ids.len(),
230 num_sub_vectors_in_byte,
231 );
232 let pq_code_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
233 transposed_code,
234 num_sub_vectors_in_byte as i32,
235 )?);
236 batch = batch.replace_column_by_name(PQ_CODE_COLUMN, pq_code_fsl)?;
237 }
238
239 let mut pq_code: Arc<UInt8Array> = batch[PQ_CODE_COLUMN]
240 .as_fixed_size_list()
241 .values()
242 .as_primitive()
243 .clone()
244 .into();
245
246 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
247 let transposed_codes = pq_code.values();
248 let mut new_row_ids = Vec::with_capacity(row_ids.len());
249 let mut new_codes = Vec::with_capacity(row_ids.len() * num_sub_vectors);
250
251 let row_ids_values = row_ids.values();
252 for (i, row_id) in row_ids_values.iter().enumerate() {
253 if let Some(mapped_value) = frag_reuse_index_ref.remap_row_id(*row_id) {
254 new_row_ids.push(mapped_value);
255 new_codes.extend(get_pq_code(
256 transposed_codes,
257 num_bits,
258 num_sub_vectors,
259 i as u32,
260 ));
261 }
262 }
263
264 let new_row_ids = Arc::new(UInt64Array::from(new_row_ids));
265 let new_codes = UInt8Array::from(new_codes);
266 batch = if new_row_ids.is_empty() {
267 RecordBatch::new_empty(batch.schema())
268 } else {
269 let num_bytes_in_code = new_codes.len() / new_row_ids.len();
270 let new_transposed_codes =
271 transpose(&new_codes, new_row_ids.len(), num_bytes_in_code);
272 let codes_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
273 new_transposed_codes,
274 num_bytes_in_code as i32,
275 )?);
276 RecordBatch::try_new(batch.schema(), vec![new_row_ids, codes_fsl])?
277 };
278 pq_code = batch[PQ_CODE_COLUMN]
279 .as_fixed_size_list()
280 .values()
281 .as_primitive::<UInt8Type>()
282 .clone()
283 .into();
284 }
285
286 let distance_type = match distance_type {
287 DistanceType::Cosine => DistanceType::L2,
288 _ => distance_type,
289 };
290 let metadata = ProductQuantizationMetadata {
291 codebook_position: 0,
292 nbits: num_bits,
293 num_sub_vectors,
294 dimension,
295 codebook: Some(codebook),
296 codebook_tensor: Vec::new(), transposed: true,
298 };
299 Ok(Self {
300 metadata,
301 distance_type,
302 batch,
303 pq_code,
304 row_ids,
305 })
306 }
307
308 pub fn batch(&self) -> &RecordBatch {
309 &self.batch
310 }
311
312 pub async fn build(
323 quantizer: ProductQuantizer,
324 batch: &RecordBatch,
325 vector_col: &str,
326 frag_reuse_index: Option<Arc<FragReuseIndex>>,
327 ) -> Result<Self> {
328 let codebook = quantizer.codebook.clone();
329 let num_bits = quantizer.num_bits;
330 let dimension = quantizer.dimension;
331 let num_sub_vectors = quantizer.num_sub_vectors;
332 let metric_type = quantizer.distance_type;
333 let transform = PQTransformer::new(quantizer, vector_col, PQ_CODE_COLUMN);
334 let batch = transform.transform(batch)?;
335 Self::new(
336 codebook,
337 batch,
338 num_bits,
339 num_sub_vectors,
340 dimension,
341 metric_type,
342 false,
343 frag_reuse_index,
344 )
345 }
346
347 pub fn codebook(&self) -> &FixedSizeListArray {
348 self.metadata.codebook.as_ref().unwrap()
349 }
350
351 pub async fn load(
367 object_store: &ObjectStore,
368 path: &Path,
369 frag_reuse_index: Option<Arc<FragReuseIndex>>,
370 ) -> Result<Self> {
371 let reader = PreviousFileReader::try_new_self_described(object_store, path, None).await?;
372 let schema = reader.schema();
373
374 let metadata_str = schema
375 .metadata
376 .get(INDEX_METADATA_SCHEMA_KEY)
377 .ok_or(Error::Index {
378 message: format!(
379 "Reading PQ storage: index key {} not found",
380 INDEX_METADATA_SCHEMA_KEY
381 ),
382 location: location!(),
383 })?;
384 let index_metadata: IndexMetadata =
385 serde_json::from_str(metadata_str).map_err(|_| Error::Index {
386 message: format!("Failed to parse index metadata: {}", metadata_str),
387 location: location!(),
388 })?;
389 let distance_type: DistanceType =
390 DistanceType::try_from(index_metadata.distance_type.as_str())?;
391
392 let metadata = ProductQuantizationMetadata::load(&reader).await?;
393 Self::load_partition(
394 &reader,
395 0..reader.len(),
396 distance_type,
397 &metadata,
398 frag_reuse_index,
399 )
400 .await
401 }
402
403 pub fn schema(&self) -> SchemaRef {
404 self.batch.schema()
405 }
406
407 pub fn get_row_ids(&self, ids: &[u32]) -> Vec<u64> {
408 ids.iter()
409 .map(|&id| self.row_ids.value(id as usize))
410 .collect()
411 }
412
413 pub async fn write_partition(
417 &self,
418 writer: &mut PreviousFileWriter<ManifestDescribing>,
419 ) -> Result<usize> {
420 let batch_size: usize = 10240; for offset in (0..self.batch.num_rows()).step_by(batch_size) {
422 let length = min(batch_size, self.batch.num_rows() - offset);
423 let slice = self.batch.slice(offset, length);
424 writer.write(&[slice]).await?;
425 }
426 Ok(self.batch.num_rows())
427 }
428}
429
430pub fn transpose<T: ArrowPrimitiveType>(
431 original: &PrimitiveArray<T>,
432 num_rows: usize,
433 num_columns: usize,
434) -> PrimitiveArray<T>
435where
436 PrimitiveArray<T>: From<Vec<T::Native>>,
437{
438 if original.is_empty() {
439 return original.clone();
440 }
441
442 let mut transposed_codes = vec![T::default_value(); original.len()];
443 for (vec_idx, codes) in original.values().chunks_exact(num_columns).enumerate() {
444 for (sub_vec_idx, code) in codes.iter().enumerate() {
445 transposed_codes[sub_vec_idx * num_rows + vec_idx] = *code;
446 }
447 }
448
449 transposed_codes.into()
450}
451
#[async_trait]
impl QuantizerStorage for ProductQuantizationStorage {
    type Metadata = ProductQuantizationMetadata;

    /// Build storage from an already-materialized batch plus its PQ metadata.
    ///
    /// The codebook comes from the metadata: the in-memory array if present,
    /// otherwise decoded from the serialized protobuf tensor bytes.
    fn try_from_batch(
        batch: RecordBatch,
        metadata: &Self::Metadata,
        distance_type: DistanceType,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
    ) -> Result<Self>
    where
        Self: Sized,
    {
        // Cosine is served as L2 over normalized vectors (see `distance_all`).
        let distance_type = match distance_type {
            DistanceType::Cosine => DistanceType::L2,
            _ => distance_type,
        };

        let codebook = match &metadata.codebook {
            Some(codebook) => codebook.clone(),
            None => {
                debug_assert!(!metadata.codebook_tensor.is_empty());

                let codebook_tensor = pb::Tensor::decode(metadata.codebook_tensor.as_slice())?;
                FixedSizeListArray::try_from(&codebook_tensor)?
            }
        };

        Self::new(
            codebook,
            batch,
            metadata.nbits,
            metadata.num_sub_vectors,
            metadata.dimension,
            distance_type,
            metadata.transposed,
            frag_reuse_index,
        )
    }

    fn metadata(&self) -> &Self::Metadata {
        &self.metadata
    }

    /// Remap row ids: ids mapped to `Some(new)` are rewritten, ids mapped to
    /// `None` are dropped (codes included), and ids absent from `mapping` are
    /// kept unchanged.
    fn remap(&self, mapping: &HashMap<u64, Option<u64>>) -> Result<Self> {
        let transposed_codes = self.pq_code.values();
        let mut new_row_ids = Vec::with_capacity(self.len());
        let mut new_codes = Vec::with_capacity(self.len() * self.metadata.num_sub_vectors);

        let row_ids = self.row_ids.values();
        for (i, row_id) in row_ids.iter().enumerate() {
            match mapping.get(row_id) {
                Some(Some(new_id)) => {
                    new_row_ids.push(*new_id);
                    new_codes.extend(get_pq_code(
                        transposed_codes,
                        self.metadata.nbits,
                        self.metadata.num_sub_vectors,
                        i as u32,
                    ));
                }
                // Mapped to None: the row was deleted — skip it entirely.
                Some(None) => {}
                None => {
                    new_row_ids.push(*row_id);
                    new_codes.extend(get_pq_code(
                        transposed_codes,
                        self.metadata.nbits,
                        self.metadata.num_sub_vectors,
                        i as u32,
                    ));
                }
            }
        }

        let new_row_ids = Arc::new(UInt64Array::from(new_row_ids));
        let new_codes = UInt8Array::from(new_codes);
        let batch = if new_row_ids.is_empty() {
            RecordBatch::new_empty(self.schema())
        } else {
            // `get_pq_code` yields row-major codes; transpose them back to
            // the internal sub-vector-major layout.
            let num_bytes_in_code = new_codes.len() / new_row_ids.len();
            let new_transposed_codes = transpose(&new_codes, new_row_ids.len(), num_bytes_in_code);
            let codes_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
                new_transposed_codes,
                num_bytes_in_code as i32,
            )?);
            RecordBatch::try_new(self.schema(), vec![new_row_ids.clone(), codes_fsl])?
        };
        let transposed_codes = batch[PQ_CODE_COLUMN]
            .as_fixed_size_list()
            .values()
            .as_primitive::<UInt8Type>()
            .clone();

        Ok(Self {
            metadata: self.metadata.clone(),
            distance_type: self.distance_type,
            batch,
            pq_code: Arc::new(transposed_codes),
            row_ids: new_row_ids,
        })
    }

    /// Load a row range from a legacy file.
    ///
    /// The codebook from the metadata is re-wrapped with the full vector
    /// `dimension` as list width — NOTE(review): this path reads the codebook
    /// values as Float32; other codebook dtypes would panic here.
    async fn load_partition(
        reader: &PreviousFileReader,
        range: std::ops::Range<usize>,
        distance_type: DistanceType,
        metadata: &Self::Metadata,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
    ) -> Result<Self> {
        let codebook = metadata
            .codebook
            .as_ref()
            .ok_or(Error::Index {
                message: "Codebook not found in PQ metadata".to_string(),
                location: location!(),
            })?
            .values()
            .as_primitive::<Float32Type>()
            .clone();

        let codebook =
            FixedSizeListArray::try_new_from_values(codebook, metadata.dimension as i32)?;

        let schema = reader.schema();
        let batch = reader.read_range(range, schema).await?;

        Self::new(
            codebook,
            batch,
            metadata.nbits,
            metadata.num_sub_vectors,
            metadata.dimension,
            distance_type,
            metadata.transposed,
            frag_reuse_index,
        )
    }
}
599
impl VectorStore for ProductQuantizationStorage {
    type DistanceCalculator<'a> = PQDistCalculator;

    /// The whole storage fits in a single batch.
    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch>> {
        Ok(std::iter::once(self.batch.clone()))
    }

    fn append_batch(&self, _batch: RecordBatch, _vector_column: &str) -> Result<Self> {
        unimplemented!()
    }

    fn schema(&self) -> &SchemaRef {
        self.batch.schema_ref()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn len(&self) -> usize {
        self.batch.num_rows()
    }

    fn distance_type(&self) -> DistanceType {
        self.distance_type
    }

    /// Row id at storage offset `id`; panics if out of range.
    fn row_id(&self, id: u32) -> u64 {
        self.row_ids.values()[id as usize]
    }

    fn row_ids(&self) -> impl Iterator<Item = &u64> {
        self.row_ids.values().iter()
    }

    /// Build a distance calculator for `query`, dispatching on the codebook's
    /// float type. `query` must use the same value type as the codebook.
    fn dist_calculator(&self, query: ArrayRef, _dist_q_c: f32) -> Self::DistanceCalculator<'_> {
        let codebook = self.metadata.codebook.as_ref().unwrap();
        match codebook.value_type() {
            DataType::Float16 => PQDistCalculator::new(
                codebook
                    .values()
                    .as_primitive::<datatypes::Float16Type>()
                    .values(),
                self.metadata.nbits,
                self.metadata.num_sub_vectors,
                self.pq_code.clone(),
                query.as_primitive::<datatypes::Float16Type>().values(),
                self.distance_type,
            ),
            DataType::Float32 => PQDistCalculator::new(
                codebook
                    .values()
                    .as_primitive::<datatypes::Float32Type>()
                    .values(),
                self.metadata.nbits,
                self.metadata.num_sub_vectors,
                self.pq_code.clone(),
                query.as_primitive::<datatypes::Float32Type>().values(),
                self.distance_type,
            ),
            DataType::Float64 => PQDistCalculator::new(
                codebook
                    .values()
                    .as_primitive::<datatypes::Float64Type>()
                    .values(),
                self.metadata.nbits,
                self.metadata.num_sub_vectors,
                self.pq_code.clone(),
                query.as_primitive::<datatypes::Float64Type>().values(),
                self.distance_type,
            ),
            _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
        }
    }

    /// Build a distance calculator whose query is the vector reconstructed
    /// (centroid concatenation) from the stored codes of element `id`.
    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> {
        let codes = get_pq_code(
            self.pq_code.values(),
            self.metadata.nbits,
            self.metadata.num_sub_vectors,
            id,
        );
        let codebook = self.metadata.codebook.as_ref().unwrap();
        match codebook.value_type() {
            DataType::Float16 => {
                let codebook = codebook
                    .values()
                    .as_primitive::<datatypes::Float16Type>()
                    .values();
                // Decode `id`'s codes back into a full-dimension vector.
                let query = get_centroids(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    codes,
                );
                PQDistCalculator::new(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.pq_code.clone(),
                    &query,
                    self.distance_type,
                )
            }
            DataType::Float32 => {
                let codebook = codebook
                    .values()
                    .as_primitive::<datatypes::Float32Type>()
                    .values();
                let query = get_centroids(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    codes,
                );
                PQDistCalculator::new(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.pq_code.clone(),
                    &query,
                    self.distance_type,
                )
            }
            DataType::Float64 => {
                let codebook = codebook
                    .values()
                    .as_primitive::<datatypes::Float64Type>()
                    .values();
                let query = get_centroids(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    codes,
                );
                PQDistCalculator::new(
                    codebook,
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.pq_code.clone(),
                    &query,
                    self.distance_type,
                )
            }
            _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
        }
    }

    /// Approximate distance between two stored vectors, computed on their
    /// reconstructed centroid vectors (symmetric by construction).
    fn dist_between(&self, u: u32, v: u32) -> f32 {
        let pq_codes = self.pq_code.values();
        let u_codes = get_pq_code(
            pq_codes,
            self.metadata.nbits,
            self.metadata.num_sub_vectors,
            u,
        );
        let v_codes = get_pq_code(
            pq_codes,
            self.metadata.nbits,
            self.metadata.num_sub_vectors,
            v,
        );
        let codebook = self.metadata.codebook.as_ref().unwrap();

        match codebook.value_type() {
            DataType::Float16 => {
                let qu = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float16Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    u_codes,
                );
                let qv = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float16Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    v_codes,
                );
                self.distance_type.func()(&qu, &qv)
            }
            DataType::Float32 => {
                let qu = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float32Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    u_codes,
                );
                let qv = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float32Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    v_codes,
                );
                self.distance_type.func()(&qu, &qv)
            }
            DataType::Float64 => {
                let qu = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float64Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    u_codes,
                );
                let qv = get_centroids(
                    codebook
                        .values()
                        .as_primitive::<datatypes::Float64Type>()
                        .values(),
                    self.metadata.nbits,
                    self.metadata.num_sub_vectors,
                    self.metadata.dimension,
                    v_codes,
                );
                self.distance_type.func()(&qu, &qv)
            }
            _ => unimplemented!("Unsupported data type: {:?}", codebook.value_type()),
        }
    }
}
843
/// Scores PQ codes against a single query using a precomputed
/// query-to-centroid lookup table.
pub struct PQDistCalculator {
    // Flattened table: one entry per (sub-vector, centroid) pair, i.e.
    // num_sub_vectors * 2^num_bits distances from the query's slices.
    distance_table: Vec<f32>,
    // Transposed PQ codes of all stored vectors.
    pq_code: Arc<UInt8Array>,
    num_sub_vectors: usize,
    num_bits: u32,
    distance_type: DistanceType,
}
852
impl PQDistCalculator {
    /// Precompute the distance table from `query` to every centroid of every
    /// sub-vector. Cosine shares the L2 table (cosine is expected to have been
    /// reduced to L2 over normalized vectors — see `distance_all`).
    fn new<T: L2 + Dot>(
        codebook: &[T],
        num_bits: u32,
        num_sub_vectors: usize,
        pq_code: Arc<UInt8Array>,
        query: &[T],
        distance_type: DistanceType,
    ) -> Self {
        let distance_table = match distance_type {
            DistanceType::L2 | DistanceType::Cosine => {
                build_distance_table_l2(codebook, num_bits, num_sub_vectors, query)
            }
            DistanceType::Dot => {
                build_distance_table_dot(codebook, num_bits, num_sub_vectors, query)
            }
            _ => unimplemented!("DistanceType is not supported: {:?}", distance_type),
        };
        Self {
            distance_table,
            num_sub_vectors,
            pq_code,
            num_bits,
            distance_type,
        }
    }

    /// Codes of stored vector `id`, widened to `usize` for table indexing.
    fn get_pq_code(&self, id: u32) -> impl Iterator<Item = usize> + '_ {
        get_pq_code(
            self.pq_code.values(),
            self.num_bits,
            self.num_sub_vectors,
            id,
        )
        .map(|v| v as usize)
    }
}
890
impl DistCalculator for PQDistCalculator {
    /// Distance between the query and stored vector `id`: the sum of the
    /// precomputed per-sub-vector query-to-centroid distances.
    fn distance(&self, id: u32) -> f32 {
        let num_centroids = 2_usize.pow(self.num_bits);
        let pq_code = self.get_pq_code(id);
        // Offset subtracted for dot distance below — NOTE(review): presumably
        // compensates for a per-sub-vector constant baked into the dot table;
        // confirm against `build_distance_table_dot`.
        let diff = self.num_sub_vectors as f32 - 1.0;
        let dist = if self.num_bits == 4 {
            // 4-bit codes pack two sub-vector indices per byte:
            // low nibble -> sub-vector 2*i, high nibble -> sub-vector 2*i + 1.
            pq_code
                .enumerate()
                .map(|(i, c)| {
                    let current_idx = c & 0x0F;
                    let next_idx = c >> 4;

                    self.distance_table[2 * i * num_centroids + current_idx]
                        + self.distance_table[(2 * i + 1) * num_centroids + next_idx]
                })
                .sum()
        } else {
            pq_code
                .enumerate()
                .map(|(i, c)| self.distance_table[i * num_centroids + c])
                .sum()
        };

        if self.distance_type == DistanceType::Dot {
            dist - diff
        } else {
            dist
        }
    }

    /// Distances from the query to every stored vector.
    fn distance_all(&self, k_hint: usize) -> Vec<f32> {
        match self.distance_type {
            DistanceType::L2 => compute_pq_distance(
                &self.distance_table,
                self.num_bits,
                self.num_sub_vectors,
                self.pq_code.values(),
                k_hint,
            ),
            DistanceType::Cosine => {
                // Cosine should never reach here: it is supposed to have been
                // converted to L2 over normalized vectors upstream.
                debug_assert!(
                    false,
                    "cosine distance should be converted to normalized L2 distance"
                );
                // Fallback for release builds: halve the L2 distances.
                let l2_dists = compute_pq_distance(
                    &self.distance_table,
                    self.num_bits,
                    self.num_sub_vectors,
                    self.pq_code.values(),
                    k_hint,
                );
                l2_dists.into_iter().map(|v| v / 2.0).collect()
            }
            DistanceType::Dot => {
                let dot_dists = compute_pq_distance(
                    &self.distance_table,
                    self.num_bits,
                    self.num_sub_vectors,
                    self.pq_code.values(),
                    k_hint,
                );
                // Same per-sub-vector offset correction as in `distance`.
                let diff = self.num_sub_vectors as f32 - 1.0;
                dot_dists.into_iter().map(|v| v - diff).collect()
            }
            _ => unimplemented!("distance type is not supported: {:?}", self.distance_type),
        }
    }
}
964
965fn get_pq_code(
966 pq_code: &[u8],
967 num_bits: u32,
968 num_sub_vectors: usize,
969 id: u32,
970) -> impl Iterator<Item = u8> + '_ {
971 let num_bytes = if num_bits == 4 {
972 num_sub_vectors / 2
973 } else {
974 num_sub_vectors
975 };
976
977 let num_vectors = pq_code.len() / num_bytes;
978 pq_code
979 .iter()
980 .skip(id as usize)
981 .step_by(num_vectors)
982 .copied()
983 .exact_size(num_bytes)
984}
985
986fn get_centroids<T: Clone>(
987 codebook: &[T],
988 num_bits: u32,
989 num_sub_vectors: usize,
990 dimension: usize,
991 codes: impl Iterator<Item = u8>,
992) -> Vec<T> {
993 if num_bits == 4 {
997 return get_centroids_4bit(codebook, num_sub_vectors, dimension, codes);
998 }
999
1000 let num_centroids: usize = 2_usize.pow(8);
1001 let sub_vector_width = dimension / num_sub_vectors;
1002 let mut centroids = Vec::with_capacity(dimension);
1003 for (sub_vec_idx, centroid_idx) in codes.enumerate() {
1004 let centroid_idx = centroid_idx as usize;
1005 let centroid = &codebook[sub_vec_idx * num_centroids * sub_vector_width
1006 + centroid_idx * sub_vector_width
1007 ..sub_vec_idx * num_centroids * sub_vector_width
1008 + (centroid_idx + 1) * sub_vector_width];
1009 centroids.extend_from_slice(centroid);
1010 }
1011 centroids
1012}
1013
/// Reconstruct centroids for 4-bit PQ codes.
///
/// Each code byte packs two centroid indices — low nibble for sub-vector
/// `2 * i`, high nibble for sub-vector `2 * i + 1` — and each sub-vector owns
/// a table of 16 centroids of width `dimension / num_sub_vectors`.
fn get_centroids_4bit<T: Clone>(
    codebook: &[T],
    num_sub_vectors: usize,
    dimension: usize,
    codes: impl Iterator<Item = u8>,
) -> Vec<T> {
    const NUM_CENTROIDS: usize = 16;
    let width = dimension / num_sub_vectors;
    let mut centroids = Vec::with_capacity(dimension);
    for (byte_idx, packed) in codes.enumerate() {
        // Low nibble: centroid for sub-vector 2 * byte_idx.
        let low = (packed & 0x0F) as usize;
        let low_table = 2 * byte_idx * NUM_CENTROIDS * width;
        let low_start = low_table + low * width;
        centroids.extend_from_slice(&codebook[low_start..low_start + width]);

        // High nibble: centroid for sub-vector 2 * byte_idx + 1.
        let high = (packed >> 4) as usize;
        let high_table = (2 * byte_idx + 1) * NUM_CENTROIDS * width;
        let high_start = high_table + high * width;
        centroids.extend_from_slice(&codebook[high_start..high_start + width]);
    }
    centroids
}
1038
#[cfg(test)]
mod tests {
    use crate::vector::storage::StorageBuilder;

    use super::*;

    use arrow_array::{Float32Array, UInt32Array};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_core::ROW_ID_FIELD;
    use rand::Rng;

    // Fixture parameters: vector dimension, row count, and PQ sub-vector count.
    const DIM: usize = 32;
    const TOTAL: usize = 512;
    const NUM_SUB_VECTORS: usize = 16;

    /// Build PQ storage over TOTAL random vectors with a random 8-bit codebook.
    async fn create_pq_storage() -> ProductQuantizationStorage {
        let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|_| rand::random()));
        let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap();
        let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::Dot);

        let schema = ArrowSchema::new(vec![
            Field::new(
                "vec",
                DataType::FixedSizeList(
                    Field::new_list_field(DataType::Float32, true).into(),
                    DIM as i32,
                ),
                true,
            ),
            ROW_ID_FIELD.clone(),
        ]);
        let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|_| rand::random()));
        let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64));
        let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
        let batch =
            RecordBatch::try_new(schema.into(), vec![Arc::new(fsl), Arc::new(row_ids)]).unwrap();

        StorageBuilder::new("vec".to_owned(), pq.distance_type, pq, None)
            .unwrap()
            .build(vec![batch])
            .unwrap()
    }

    /// Like `create_pq_storage`, but the input batch carries an extra column
    /// that the storage is expected to drop.
    async fn create_pq_storage_with_extra_column() -> ProductQuantizationStorage {
        let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|_| rand::random()));
        let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap();
        let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::Dot);

        let schema = ArrowSchema::new(vec![
            Field::new(
                "vec",
                DataType::FixedSizeList(
                    Field::new_list_field(DataType::Float32, true).into(),
                    DIM as i32,
                ),
                true,
            ),
            ROW_ID_FIELD.clone(),
            Field::new("extra", DataType::UInt32, true),
        ]);
        let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|_| rand::random()));
        let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64));
        let extra_column = UInt32Array::from_iter_values((0..TOTAL).map(|v| v as u32));
        let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
        let batch = RecordBatch::try_new(
            schema.into(),
            vec![Arc::new(fsl), Arc::new(row_ids), Arc::new(extra_column)],
        )
        .unwrap();

        StorageBuilder::new("vec".to_owned(), pq.distance_type, pq, None)
            .unwrap()
            .build(vec![batch])
            .unwrap()
    }

    /// Sanity check of the built storage's shapes.
    #[tokio::test]
    async fn test_build_pq_storage() {
        let storage = create_pq_storage().await;
        assert_eq!(storage.len(), TOTAL);
        assert_eq!(storage.metadata.num_sub_vectors, NUM_SUB_VECTORS);
        assert_eq!(
            storage.metadata.codebook.as_ref().unwrap().values().len(),
            256 * DIM
        );
        assert_eq!(storage.pq_code.len(), TOTAL * NUM_SUB_VECTORS);
        assert_eq!(storage.row_ids.len(), TOTAL);
    }

    /// `distance_all` must agree with element-wise `distance`.
    #[tokio::test]
    async fn test_distance_all() {
        let storage = create_pq_storage().await;
        let query = Arc::new(Float32Array::from_iter_values((0..DIM).map(|v| v as f32)));
        let dist_calc = storage.dist_calculator(query, 0.0);
        let expected = (0..storage.len())
            .map(|id| dist_calc.distance(id as u32))
            .collect::<Vec<_>>();
        let distances = dist_calc.distance_all(100);
        assert_eq!(distances, expected);
    }

    /// `dist_between` must be symmetric.
    #[tokio::test]
    async fn test_dist_between() {
        let mut rng = rand::rng();
        let storage = create_pq_storage().await;
        let u = rng.random_range(0..storage.len() as u32);
        let v = rng.random_range(0..storage.len() as u32);
        let dist1 = storage.dist_between(u, v);
        let dist2 = storage.dist_between(v, u);
        assert_eq!(dist1, dist2);
    }

    /// Remap the first half of the rows to new ids, delete the second half,
    /// and verify the result keeps only the two expected columns.
    #[tokio::test]
    async fn test_remap_with_extra_column() {
        let storage = create_pq_storage_with_extra_column().await;
        let mut mapping = HashMap::new();
        for i in 0..TOTAL / 2 {
            mapping.insert(i as u64, Some((TOTAL + i) as u64));
        }
        for i in TOTAL / 2..TOTAL {
            mapping.insert(i as u64, None);
        }
        let new_storage = storage.remap(&mapping).unwrap();
        assert_eq!(new_storage.len(), TOTAL / 2);
        assert_eq!(new_storage.row_ids.len(), TOTAL / 2);
        for (i, row_id) in new_storage.row_ids().enumerate() {
            assert_eq!(*row_id, (TOTAL + i) as u64);
        }
        assert_eq!(new_storage.batch.num_columns(), 2);
        assert!(new_storage.batch.column_by_name(ROW_ID).is_some());
        assert!(new_storage.batch.column_by_name(PQ_CODE_COLUMN).is_some());
    }
}
1172}