Skip to main content

lance_index/vector/sq/
transform.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{
5    fmt::{Debug, Formatter},
6    sync::Arc,
7};
8
9use arrow::array::AsArray;
10use arrow_array::{
11    RecordBatch,
12    types::{Float16Type, Float32Type, Float64Type},
13};
14use arrow_schema::{DataType, Field};
15use tracing::instrument;
16
17use crate::vector::transform::Transformer;
18
19use lance_arrow::RecordBatchExt;
20use lance_core::{Error, Result};
21
22use super::ScalarQuantizer;
23
24pub struct SQTransformer {
25    quantizer: ScalarQuantizer,
26    input_column: String,
27    output_column: String,
28}
29
30impl SQTransformer {
31    pub fn new(quantizer: ScalarQuantizer, input_column: String, output_column: String) -> Self {
32        Self {
33            quantizer,
34            input_column,
35            output_column,
36        }
37    }
38}
39
40impl Debug for SQTransformer {
41    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
42        write!(
43            f,
44            "SQTransformer(input={}, output={})",
45            self.input_column, self.output_column
46        )
47    }
48}
49
50impl Transformer for SQTransformer {
51    #[instrument(name = "SQTransformer::transform", level = "debug", skip_all)]
52    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
53        let input = batch
54            .column_by_name(&self.input_column)
55            .ok_or(Error::index(format!(
56                "SQ Transform: column {} not found in batch",
57                self.input_column
58            )))?;
59        let fsl = input
60            .as_fixed_size_list_opt()
61            .ok_or(Error::index("input column is not vector type".to_string()))?;
62        let sq_code = match fsl.value_type() {
63            DataType::Float16 => self.quantizer.transform::<Float16Type>(input)?,
64            DataType::Float32 => self.quantizer.transform::<Float32Type>(input)?,
65            DataType::Float64 => self.quantizer.transform::<Float64Type>(input)?,
66            _ => {
67                return Err(Error::index(format!(
68                    "unsupported data type: {}",
69                    fsl.value_type()
70                )));
71            }
72        };
73
74        let sq_field = Field::new(&self.output_column, sq_code.data_type().clone(), false);
75        let batch = batch
76            .try_with_column(sq_field, Arc::new(sq_code))?
77            .drop_column(&self.input_column)?;
78        Ok(batch)
79    }
80}