Skip to main content

lance_index/vector/pq/
transform.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::fmt::{Debug, Formatter};
5use std::sync::Arc;
6
7use arrow_array::{Array, RecordBatch, cast::AsArray};
8use arrow_schema::Field;
9use lance_arrow::RecordBatchExt;
10use lance_core::{Error, Result};
11use tracing::instrument;
12
13use super::ProductQuantizer;
14use crate::vector::quantizer::Quantization;
15use crate::vector::transform::Transformer;
16
17/// Product Quantizer Transformer
18///
19/// It transforms a column of vectors into a column of PQ codes.
20pub struct PQTransformer {
21    quantizer: ProductQuantizer,
22    input_column: String,
23    output_column: String,
24}
25
26impl PQTransformer {
27    pub fn new(quantizer: ProductQuantizer, input_column: &str, output_column: &str) -> Self {
28        Self {
29            quantizer,
30            input_column: input_column.to_owned(),
31            output_column: output_column.to_owned(),
32        }
33    }
34}
35
36impl Debug for PQTransformer {
37    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38        write!(
39            f,
40            "PQTransformer(input={}, output={})",
41            self.input_column, self.output_column
42        )
43    }
44}
45
46impl Transformer for PQTransformer {
47    #[instrument(name = "PQTransformer::transform", level = "debug", skip_all)]
48    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
49        if batch.column_by_name(&self.output_column).is_some() {
50            return Ok(batch.clone());
51        }
52        let input_arr = batch
53            .column_by_name(&self.input_column)
54            .ok_or(Error::index(format!(
55                "PQ Transform: column {} not found in batch",
56                self.input_column
57            )))?;
58        let data = input_arr
59            .as_fixed_size_list_opt()
60            .ok_or(Error::index(format!(
61                "PQ Transform: column {} is not a fixed size list, got {}",
62                self.input_column,
63                input_arr.data_type(),
64            )))?;
65        let pq_code = self.quantizer.quantize(&data)?;
66        let pq_field = Field::new(&self.output_column, pq_code.data_type().clone(), false);
67        let batch = batch.try_with_column(pq_field, Arc::new(pq_code))?;
68        let batch = batch.drop_column(&self.input_column)?;
69        Ok(batch)
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{FixedSizeListArray, Float32Array, Int32Array};
    use arrow_schema::{DataType, Schema};
    use lance_arrow::FixedSizeListArrayExt;
    use lance_linalg::distance::DistanceType;

    use crate::vector::pq::PQBuildParams;

    /// Vector dimension used by the fixture; shared by the data and the schema.
    const DIM: i32 = 16;
    /// Number of vectors (rows) in the fixture batch.
    const NUM_ROWS: usize = 1000;

    /// Builds a batch with a `vec` vector column and an `other` Int32 column,
    /// plus a product quantizer trained on those vectors.
    fn build_fixture() -> (RecordBatch, ProductQuantizer) {
        let values = Float32Array::from_iter((0..NUM_ROWS * DIM as usize).map(|v| v as f32));
        let arr = Arc::new(FixedSizeListArray::try_new_from_values(values, DIM).unwrap());
        let params = PQBuildParams::new(1, 8);
        let pq = ProductQuantizer::build(arr.as_ref(), DistanceType::L2, &params).unwrap();

        let schema = Schema::new(vec![
            Field::new(
                "vec",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), DIM),
                true,
            ),
            Field::new("other", DataType::Int32, false),
        ]);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![arr, Arc::new(Int32Array::from_iter_values(0..NUM_ROWS as i32))],
        )
        .unwrap();
        (batch, pq)
    }

    #[tokio::test]
    async fn test_pq_transform() {
        let (batch, pq) = build_fixture();
        let transformer = PQTransformer::new(pq, "vec", "pq_code");
        let batch = transformer.transform(&batch).unwrap();

        assert!(batch.column_by_name("vec").is_none());
        assert!(batch.column_by_name("other").is_some());
        let codes = batch
            .column_by_name("pq_code")
            .expect("pq_code column should be added");
        assert_eq!(codes.len(), NUM_ROWS);
        let code_field = batch.schema().field_with_name("pq_code").unwrap().clone();
        assert!(!code_field.is_nullable());
        assert_eq!(batch.num_rows(), NUM_ROWS);
    }

    #[tokio::test]
    async fn test_pq_transform_skips_when_output_exists() {
        let (batch, pq) = build_fixture();
        let transformer = PQTransformer::new(pq, "vec", "pq_code");
        let transformed = transformer.transform(&batch).unwrap();

        // `pq_code` is already present, so the second pass must return the batch
        // untouched, even though the input column was dropped by the first pass.
        let again = transformer.transform(&transformed).unwrap();
        assert_eq!(again, transformed);
    }
}