lance_index/vector/pq/
transform.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::fmt::{Debug, Formatter};
5use std::sync::Arc;
6
7use arrow_array::{cast::AsArray, Array, RecordBatch};
8use arrow_schema::Field;
9use lance_arrow::RecordBatchExt;
10use lance_core::{Error, Result};
11use snafu::location;
12use tracing::instrument;
13
14use super::ProductQuantizer;
15use crate::vector::quantizer::Quantization;
16use crate::vector::transform::Transformer;
17
18/// Product Quantizer Transformer
19///
20/// It transforms a column of vectors into a column of PQ codes.
21pub struct PQTransformer {
22    quantizer: ProductQuantizer,
23    input_column: String,
24    output_column: String,
25}
26
27impl PQTransformer {
28    pub fn new(quantizer: ProductQuantizer, input_column: &str, output_column: &str) -> Self {
29        Self {
30            quantizer,
31            input_column: input_column.to_owned(),
32            output_column: output_column.to_owned(),
33        }
34    }
35}
36
37impl Debug for PQTransformer {
38    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39        write!(
40            f,
41            "PQTransformer(input={}, output={})",
42            self.input_column, self.output_column
43        )
44    }
45}
46
47impl Transformer for PQTransformer {
48    #[instrument(name = "PQTransformer::transform", level = "debug", skip_all)]
49    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
50        if batch.column_by_name(&self.output_column).is_some() {
51            return Ok(batch.clone());
52        }
53        let input_arr = batch
54            .column_by_name(&self.input_column)
55            .ok_or(Error::Index {
56                message: format!(
57                    "PQ Transform: column {} not found in batch",
58                    self.input_column
59                ),
60                location: location!(),
61            })?;
62        let data = input_arr.as_fixed_size_list_opt().ok_or(Error::Index {
63            message: format!(
64                "PQ Transform: column {} is not a fixed size list, got {}",
65                self.input_column,
66                input_arr.data_type(),
67            ),
68            location: location!(),
69        })?;
70        let pq_code = self.quantizer.quantize(&data)?;
71        let pq_field = Field::new(&self.output_column, pq_code.data_type().clone(), false);
72        let batch = batch.try_with_column(pq_field, Arc::new(pq_code))?;
73        let batch = batch.drop_column(&self.input_column)?;
74        Ok(batch)
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{FixedSizeListArray, Float32Array, Int32Array};
    use arrow_schema::{DataType, Schema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_linalg::distance::DistanceType;

    use crate::vector::pq::PQBuildParams;

    /// Dimension of each test vector.
    const DIM: i32 = 16;
    /// Number of rows in the test batch.
    const NUM_ROWS: i32 = 1000;

    /// Builds a batch with a `vec` column (`NUM_ROWS` vectors of `DIM` floats) and an
    /// `other` Int32 column, plus a transformer whose quantizer is built from `vec`
    /// and which writes its codes to `pq_code`.
    fn setup() -> (RecordBatch, PQTransformer) {
        let values = Float32Array::from_iter((0..NUM_ROWS * DIM).map(|v| v as f32));
        let arr = Arc::new(FixedSizeListArray::try_new_from_values(values, DIM).unwrap());
        let params = PQBuildParams::new(1, 8);
        let pq = ProductQuantizer::build(arr.as_ref(), DistanceType::L2, &params).unwrap();

        let schema = Schema::new(vec![
            Field::new(
                "vec",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), DIM),
                true,
            ),
            Field::new("other", DataType::Int32, false),
        ]);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![arr, Arc::new(Int32Array::from_iter_values(0..NUM_ROWS))],
        )
        .unwrap();

        (batch, PQTransformer::new(pq, "vec", "pq_code"))
    }

    #[tokio::test]
    async fn test_pq_transform() {
        let (batch, transformer) = setup();
        let batch = transformer.transform(&batch).unwrap();
        assert!(batch.column_by_name("vec").is_none());
        assert!(batch.column_by_name("pq_code").is_some());
        assert!(batch.column_by_name("other").is_some());
        assert_eq!(batch.num_rows(), NUM_ROWS as usize);
    }

    #[tokio::test]
    async fn test_pq_transform_skips_existing_output() {
        let (batch, transformer) = setup();
        let transformed = transformer.transform(&batch).unwrap();
        // `pq_code` is already present, so a second pass must return the batch as-is.
        let again = transformer.transform(&transformed).unwrap();
        assert_eq!(again, transformed);
    }

    #[tokio::test]
    async fn test_pq_transform_invalid_input() {
        let (batch, transformer) = setup();

        // Input column absent.
        let missing = batch.drop_column("vec").unwrap();
        assert!(matches!(
            transformer.transform(&missing),
            Err(Error::Index { .. })
        ));

        // Input column present but not a fixed-size list.
        let not_fsl = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("vec", DataType::Int32, false)])),
            vec![Arc::new(Int32Array::from_iter_values(0..NUM_ROWS))],
        )
        .unwrap();
        assert!(matches!(
            transformer.transform(&not_fsl),
            Err(Error::Index { .. })
        ));
    }
}