lance_index/vector/pq/
transform.rs1use std::fmt::{Debug, Formatter};
5use std::sync::Arc;
6
7use arrow_array::{cast::AsArray, Array, RecordBatch};
8use arrow_schema::Field;
9use lance_arrow::RecordBatchExt;
10use lance_core::{Error, Result};
11use snafu::location;
12use tracing::instrument;
13
14use super::ProductQuantizer;
15use crate::vector::quantizer::Quantization;
16use crate::vector::transform::Transformer;
17
18pub struct PQTransformer {
22 quantizer: ProductQuantizer,
23 input_column: String,
24 output_column: String,
25}
26
27impl PQTransformer {
28 pub fn new(quantizer: ProductQuantizer, input_column: &str, output_column: &str) -> Self {
29 Self {
30 quantizer,
31 input_column: input_column.to_owned(),
32 output_column: output_column.to_owned(),
33 }
34 }
35}
36
37impl Debug for PQTransformer {
38 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39 write!(
40 f,
41 "PQTransformer(input={}, output={})",
42 self.input_column, self.output_column
43 )
44 }
45}
46
47impl Transformer for PQTransformer {
48 #[instrument(name = "PQTransformer::transform", level = "debug", skip_all)]
49 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
50 if batch.column_by_name(&self.output_column).is_some() {
51 return Ok(batch.clone());
52 }
53 let input_arr = batch
54 .column_by_name(&self.input_column)
55 .ok_or(Error::Index {
56 message: format!(
57 "PQ Transform: column {} not found in batch",
58 self.input_column
59 ),
60 location: location!(),
61 })?;
62 let data = input_arr.as_fixed_size_list_opt().ok_or(Error::Index {
63 message: format!(
64 "PQ Transform: column {} is not a fixed size list, got {}",
65 self.input_column,
66 input_arr.data_type(),
67 ),
68 location: location!(),
69 })?;
70 let pq_code = self.quantizer.quantize(&data)?;
71 let pq_field = Field::new(&self.output_column, pq_code.data_type().clone(), false);
72 let batch = batch.try_with_column(pq_field, Arc::new(pq_code))?;
73 let batch = batch.drop_column(&self.input_column)?;
74 Ok(batch)
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{FixedSizeListArray, Float32Array, Int32Array};
    use arrow_schema::{DataType, Schema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_linalg::distance::DistanceType;

    use crate::vector::pq::PQBuildParams;

    /// Transforming a batch with a `vec` column and an unrelated `other`
    /// column must drop `vec`, add `pq_code`, keep `other`, and preserve the
    /// row count.
    #[tokio::test]
    async fn test_pq_transform() {
        // 16000 values at a list size of 16 give 1000 rows.
        let values = Float32Array::from_iter((0..16000).map(|v| v as f32));
        let dim = 16;
        let arr = Arc::new(FixedSizeListArray::try_new_from_values(values, dim).unwrap());
        let params = PQBuildParams::new(1, 8);
        let pq = ProductQuantizer::build(arr.as_ref(), DistanceType::L2, &params).unwrap();

        let schema = Schema::new(vec![
            Field::new(
                "vec",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
                true,
            ),
            Field::new("other", DataType::Int32, false),
        ]);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![arr, Arc::new(Int32Array::from_iter_values(0..1000))],
        )
        .unwrap();

        let transformer = PQTransformer::new(pq, "vec", "pq_code");
        let batch = transformer.transform(&batch).unwrap();
        assert!(batch.column_by_name("vec").is_none());
        assert!(batch.column_by_name("pq_code").is_some());
        assert!(batch.column_by_name("other").is_some());
        assert_eq!(batch.num_rows(), 1000);
    }
}