lance_index/vector/sq/
transform.rs1use std::{
5 fmt::{Debug, Formatter},
6 sync::Arc,
7};
8
9use arrow::array::AsArray;
10use arrow_array::{
11 types::{Float16Type, Float32Type, Float64Type},
12 RecordBatch,
13};
14use arrow_schema::{DataType, Field};
15use snafu::location;
16use tracing::instrument;
17
18use crate::vector::transform::Transformer;
19
20use lance_arrow::RecordBatchExt;
21use lance_core::{Error, Result};
22
23use super::ScalarQuantizer;
24
25pub struct SQTransformer {
26 quantizer: ScalarQuantizer,
27 input_column: String,
28 output_column: String,
29}
30
31impl SQTransformer {
32 pub fn new(quantizer: ScalarQuantizer, input_column: String, output_column: String) -> Self {
33 Self {
34 quantizer,
35 input_column,
36 output_column,
37 }
38 }
39}
40
41impl Debug for SQTransformer {
42 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
43 write!(
44 f,
45 "SQTransformer(input={}, output={})",
46 self.input_column, self.output_column
47 )
48 }
49}
50
51impl Transformer for SQTransformer {
52 #[instrument(name = "SQTransformer::transform", level = "debug", skip_all)]
53 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
54 let input = batch
55 .column_by_name(&self.input_column)
56 .ok_or(Error::Index {
57 message: format!(
58 "SQ Transform: column {} not found in batch",
59 self.input_column
60 ),
61 location: location!(),
62 })?;
63 let fsl = input.as_fixed_size_list_opt().ok_or(Error::Index {
64 message: "input column is not vector type".to_string(),
65 location: location!(),
66 })?;
67 let sq_code = match fsl.value_type() {
68 DataType::Float16 => self.quantizer.transform::<Float16Type>(input)?,
69 DataType::Float32 => self.quantizer.transform::<Float32Type>(input)?,
70 DataType::Float64 => self.quantizer.transform::<Float64Type>(input)?,
71 _ => {
72 return Err(Error::Index {
73 message: format!("unsupported data type: {}", fsl.value_type()),
74 location: location!(),
75 })
76 }
77 };
78
79 let sq_field = Field::new(&self.output_column, sq_code.data_type().clone(), false);
80 let batch = batch
81 .try_with_column(sq_field, Arc::new(sq_code))?
82 .drop_column(&self.input_column)?;
83 Ok(batch)
84 }
85}