lance_index/vector/sq/
transform.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{
5    fmt::{Debug, Formatter},
6    sync::Arc,
7};
8
9use arrow::array::AsArray;
10use arrow_array::{
11    types::{Float16Type, Float32Type, Float64Type},
12    RecordBatch,
13};
14use arrow_schema::{DataType, Field};
15use snafu::location;
16use tracing::instrument;
17
18use crate::vector::transform::Transformer;
19
20use lance_arrow::RecordBatchExt;
21use lance_core::{Error, Result};
22
23use super::ScalarQuantizer;
24
/// A [`Transformer`] that runs a [`ScalarQuantizer`] over one vector column
/// of a record batch and stores the resulting codes under another column.
pub struct SQTransformer {
    /// Quantizer used to encode the input vectors.
    quantizer: ScalarQuantizer,
    /// Name of the column that holds the vectors to quantize; it is dropped
    /// from the batch returned by `transform`.
    input_column: String,
    /// Name of the column under which the quantized codes are stored.
    output_column: String,
}
30
31impl SQTransformer {
32    pub fn new(quantizer: ScalarQuantizer, input_column: String, output_column: String) -> Self {
33        Self {
34            quantizer,
35            input_column,
36            output_column,
37        }
38    }
39}
40
41impl Debug for SQTransformer {
42    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
43        write!(
44            f,
45            "SQTransformer(input={}, output={})",
46            self.input_column, self.output_column
47        )
48    }
49}
50
51impl Transformer for SQTransformer {
52    #[instrument(name = "SQTransformer::transform", level = "debug", skip_all)]
53    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
54        let input = batch
55            .column_by_name(&self.input_column)
56            .ok_or(Error::Index {
57                message: format!(
58                    "SQ Transform: column {} not found in batch",
59                    self.input_column
60                ),
61                location: location!(),
62            })?;
63        let fsl = input.as_fixed_size_list_opt().ok_or(Error::Index {
64            message: "input column is not vector type".to_string(),
65            location: location!(),
66        })?;
67        let sq_code = match fsl.value_type() {
68            DataType::Float16 => self.quantizer.transform::<Float16Type>(input)?,
69            DataType::Float32 => self.quantizer.transform::<Float32Type>(input)?,
70            DataType::Float64 => self.quantizer.transform::<Float64Type>(input)?,
71            _ => {
72                return Err(Error::Index {
73                    message: format!("unsupported data type: {}", fsl.value_type()),
74                    location: location!(),
75                })
76            }
77        };
78
79        let sq_field = Field::new(&self.output_column, sq_code.data_type().clone(), false);
80        let batch = batch
81            .try_with_column(sq_field, Arc::new(sq_code))?
82            .drop_column(&self.input_column)?;
83        Ok(batch)
84    }
85}