lance_index/vector/pq/
transform.rs1use std::fmt::{Debug, Formatter};
5use std::sync::Arc;
6
7use arrow_array::{Array, RecordBatch, cast::AsArray};
8use arrow_schema::Field;
9use lance_arrow::RecordBatchExt;
10use lance_core::{Error, Result};
11use tracing::instrument;
12
13use super::ProductQuantizer;
14use crate::vector::quantizer::Quantization;
15use crate::vector::transform::Transformer;
16
17pub struct PQTransformer {
21 quantizer: ProductQuantizer,
22 input_column: String,
23 output_column: String,
24}
25
26impl PQTransformer {
27 pub fn new(quantizer: ProductQuantizer, input_column: &str, output_column: &str) -> Self {
28 Self {
29 quantizer,
30 input_column: input_column.to_owned(),
31 output_column: output_column.to_owned(),
32 }
33 }
34}
35
36impl Debug for PQTransformer {
37 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38 write!(
39 f,
40 "PQTransformer(input={}, output={})",
41 self.input_column, self.output_column
42 )
43 }
44}
45
46impl Transformer for PQTransformer {
47 #[instrument(name = "PQTransformer::transform", level = "debug", skip_all)]
48 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
49 if batch.column_by_name(&self.output_column).is_some() {
50 return Ok(batch.clone());
51 }
52 let input_arr = batch
53 .column_by_name(&self.input_column)
54 .ok_or(Error::index(format!(
55 "PQ Transform: column {} not found in batch",
56 self.input_column
57 )))?;
58 let data = input_arr
59 .as_fixed_size_list_opt()
60 .ok_or(Error::index(format!(
61 "PQ Transform: column {} is not a fixed size list, got {}",
62 self.input_column,
63 input_arr.data_type(),
64 )))?;
65 let pq_code = self.quantizer.quantize(&data)?;
66 let pq_field = Field::new(&self.output_column, pq_code.data_type().clone(), false);
67 let batch = batch.try_with_column(pq_field, Arc::new(pq_code))?;
68 let batch = batch.drop_column(&self.input_column)?;
69 Ok(batch)
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{FixedSizeListArray, Float32Array, Int32Array};
    use arrow_schema::{DataType, Schema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_linalg::distance::DistanceType;

    use crate::vector::pq::PQBuildParams;

    /// Transforming a batch must drop the vector column, add the PQ code
    /// column, keep unrelated columns, and preserve the row count.
    #[tokio::test]
    async fn test_pq_transform() {
        // 16000 values at 16 per vector gives 1000 rows.
        let dim = 16;
        let values = Float32Array::from_iter((0..16000).map(|v| v as f32));
        let arr = Arc::new(FixedSizeListArray::try_new_from_values(values, dim).unwrap());
        // NOTE(review): arguments are presumably (num_sub_vectors, num_bits) — confirm.
        let params = PQBuildParams::new(1, 8);
        let pq = ProductQuantizer::build(arr.as_ref(), DistanceType::L2, &params).unwrap();

        let schema = Schema::new(vec![
            Field::new(
                "vec",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
                true,
            ),
            Field::new("other", DataType::Int32, false),
        ]);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![arr, Arc::new(Int32Array::from_iter_values(0..1000))],
        )
        .unwrap();

        let transformer = PQTransformer::new(pq, "vec", "pq_code");
        let batch = transformer.transform(&batch).unwrap();
        assert!(batch.column_by_name("vec").is_none());
        assert!(batch.column_by_name("pq_code").is_some());
        assert!(batch.column_by_name("other").is_some());
        assert_eq!(batch.num_rows(), 1000);
    }
}