lance_index/vector/sq/
transform.rs1use std::{
5 fmt::{Debug, Formatter},
6 sync::Arc,
7};
8
9use arrow::array::AsArray;
10use arrow_array::{
11 RecordBatch,
12 types::{Float16Type, Float32Type, Float64Type},
13};
14use arrow_schema::{DataType, Field};
15use tracing::instrument;
16
17use crate::vector::transform::Transformer;
18
19use lance_arrow::RecordBatchExt;
20use lance_core::{Error, Result};
21
22use super::ScalarQuantizer;
23
24pub struct SQTransformer {
25 quantizer: ScalarQuantizer,
26 input_column: String,
27 output_column: String,
28}
29
30impl SQTransformer {
31 pub fn new(quantizer: ScalarQuantizer, input_column: String, output_column: String) -> Self {
32 Self {
33 quantizer,
34 input_column,
35 output_column,
36 }
37 }
38}
39
40impl Debug for SQTransformer {
41 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
42 write!(
43 f,
44 "SQTransformer(input={}, output={})",
45 self.input_column, self.output_column
46 )
47 }
48}
49
50impl Transformer for SQTransformer {
51 #[instrument(name = "SQTransformer::transform", level = "debug", skip_all)]
52 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
53 let input = batch
54 .column_by_name(&self.input_column)
55 .ok_or(Error::index(format!(
56 "SQ Transform: column {} not found in batch",
57 self.input_column
58 )))?;
59 let fsl = input
60 .as_fixed_size_list_opt()
61 .ok_or(Error::index("input column is not vector type".to_string()))?;
62 let sq_code = match fsl.value_type() {
63 DataType::Float16 => self.quantizer.transform::<Float16Type>(input)?,
64 DataType::Float32 => self.quantizer.transform::<Float32Type>(input)?,
65 DataType::Float64 => self.quantizer.transform::<Float64Type>(input)?,
66 _ => {
67 return Err(Error::index(format!(
68 "unsupported data type: {}",
69 fsl.value_type()
70 )));
71 }
72 };
73
74 let sq_field = Field::new(&self.output_column, sq_code.data_type().clone(), false);
75 let batch = batch
76 .try_with_column(sq_field, Arc::new(sq_code))?
77 .drop_column(&self.input_column)?;
78 Ok(batch)
79 }
80}