lance_index/vector/
transform.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Vector Transforms
5//!
6
7use std::fmt::Debug;
8use std::sync::Arc;
9
10use arrow::datatypes::UInt64Type;
11use arrow_array::types::{Float16Type, Float32Type, Float64Type};
12use arrow_array::UInt64Array;
13use arrow_array::{cast::AsArray, Array, ArrowPrimitiveType, RecordBatch, UInt32Array};
14use arrow_schema::{DataType, Field, Schema};
15use lance_arrow::RecordBatchExt;
16use num_traits::Float;
17use snafu::location;
18
19use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
20use lance_linalg::kernels::normalize_fsl;
21use tracing::instrument;
22
/// Transform of a Vector Matrix.
///
/// A `Transformer` takes a [`RecordBatch`] by reference and produces a new
/// [`RecordBatch`]. The implementations in this file normalize vectors, filter
/// out rows, drop columns, or flatten multivector columns.
pub trait Transformer: Debug + Send + Sync {
    /// Transform a [`RecordBatch`] of vectors.
    ///
    /// Returns the transformed batch, or an error if the transform cannot be
    /// applied to `batch`. The input batch is only borrowed.
    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch>;
}
31
/// Normalize Transformer
///
/// L2-normalizes every vector of a column.
#[derive(Debug)]
pub struct NormalizeTransformer {
    /// Name of the column that holds the vectors to normalize.
    input_column: String,
    /// Name of the column that receives the normalized vectors.
    /// When `None`, the input column is replaced.
    output_column: Option<String>,
}

impl NormalizeTransformer {
    /// Create a transformer that normalizes `column` and replaces it with the result.
    pub fn new(column: impl AsRef<str>) -> Self {
        let input_column = String::from(column.as_ref());
        Self {
            input_column,
            output_column: None,
        }
    }

    /// Create a transformer that reads vectors from `input_column` and stores the
    /// normalized vectors in a separate `output_column`.
    pub fn new_with_output(input_column: impl AsRef<str>, output_column: impl AsRef<str>) -> Self {
        let source = String::from(input_column.as_ref());
        let target = String::from(output_column.as_ref());
        Self {
            input_column: source,
            output_column: Some(target),
        }
    }
}
58
59impl Transformer for NormalizeTransformer {
60    #[instrument(name = "NormalizeTransformer::transform", level = "debug", skip_all)]
61    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
62        let arr = batch
63            .column_by_name(&self.input_column)
64            .ok_or_else(|| Error::Index {
65                message: format!(
66                    "Normalize Transform: column {} not found in RecordBatch {}",
67                    self.input_column,
68                    batch.schema(),
69                ),
70                location: location!(),
71            })?;
72
73        let data = arr.as_fixed_size_list();
74        let norm = normalize_fsl(data)?;
75        let transformed = Arc::new(norm);
76
77        if let Some(output_column) = &self.output_column {
78            let field = Field::new(output_column, transformed.data_type().clone(), true);
79            Ok(batch.try_with_column(field, transformed)?)
80        } else {
81            Ok(batch.replace_column_by_name(&self.input_column, transformed)?)
82        }
83    }
84}
85
/// Keep only the vectors whose values are all finite, filtering out rows that
/// contain NaN or Inf.
#[derive(Debug)]
pub(crate) struct KeepFiniteVectors {
    /// Name of the vector column to inspect.
    column: String,
}

impl KeepFiniteVectors {
    /// Create a filter that inspects the vectors stored in `column`.
    pub fn new(column: impl AsRef<str>) -> Self {
        let column = String::from(column.as_ref());
        Self { column }
    }
}
99
100fn is_all_finite<T: ArrowPrimitiveType>(arr: &dyn Array) -> bool
101where
102    T::Native: Float,
103{
104    arr.null_count() == 0
105        && !arr
106            .as_primitive::<T>()
107            .values()
108            .iter()
109            .any(|&v| !v.is_finite())
110}
111
112impl Transformer for KeepFiniteVectors {
113    #[instrument(name = "KeepFiniteVectors::transform", level = "debug", skip_all)]
114    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
115        let Some(arr) = batch.column_by_name(&self.column) else {
116            return Ok(batch.clone());
117        };
118
119        let data = match arr.data_type() {
120            DataType::FixedSizeList(_, _) => arr.as_fixed_size_list(),
121            DataType::List(_) => arr.as_list::<i32>().values().as_fixed_size_list(),
122            _ => {
123                return Err(Error::Index {
124                    message: format!(
125                        "KeepFiniteVectors: column {} is not a fixed size list: {}",
126                        self.column,
127                        arr.data_type()
128                    ),
129                    location: location!(),
130                })
131            }
132        };
133
134        let mut valid = Vec::with_capacity(batch.num_rows());
135        data.iter().enumerate().for_each(|(idx, arr)| {
136            if let Some(data) = arr {
137                let is_valid = match data.data_type() {
138                    // f16 vectors are computed in f32 space, so they will not overflow.
139                    DataType::Float16 => is_all_finite::<Float16Type>(&data),
140                    // f32 vectors must be bounded to avoid overflow in distance computation.
141                    DataType::Float32 => is_all_finite::<Float32Type>(&data),
142                    // f32 vectors are computed in f32 space, so they have the same limit as f64.
143                    DataType::Float64 => is_all_finite::<Float64Type>(&data),
144                    DataType::UInt8 => data.null_count() == 0,
145                    DataType::Int8 => data.null_count() == 0,
146                    _ => false,
147                };
148                if is_valid {
149                    valid.push(idx as u32);
150                }
151            };
152        });
153        if valid.len() < batch.num_rows() {
154            let indices = UInt32Array::from(valid);
155            Ok(batch.take(&indices)?)
156        } else {
157            Ok(batch.clone())
158        }
159    }
160}
161
/// Transformer that removes a named column from a batch.
#[derive(Debug)]
pub struct DropColumn {
    /// Name of the column to remove.
    column: String,
}

impl DropColumn {
    /// Create a transformer that drops `column`.
    pub fn new(column: &str) -> Self {
        let column = String::from(column);
        Self { column }
    }
}
174
175impl Transformer for DropColumn {
176    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
177        Ok(batch.drop_column(&self.column)?)
178    }
179}
180
/// Transformer that flattens a multivector (list of vectors) column so that
/// every vector is stored in its own row.
#[derive(Debug)]
pub struct Flatten {
    /// Name of the vector column to flatten.
    column: String,
}

impl Flatten {
    /// Create a transformer that flattens `column`.
    pub fn new(column: &str) -> Self {
        let column = String::from(column);
        Self { column }
    }
}
193
194impl Transformer for Flatten {
195    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
196        let Some(arr) = batch.column_by_name(&self.column) else {
197            // this case is that we have precomputed buffers,
198            // so we don't need to flatten the original vectors.
199            return Ok(batch.clone());
200        };
201        match arr.data_type() {
202            DataType::FixedSizeList(_, _) => Ok(batch.clone()),
203            DataType::List(_) => {
204                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
205                let vectors = arr.as_list::<i32>();
206
207                let row_ids = row_ids.values().iter().zip(vectors.iter()).flat_map(
208                    |(row_id, multivector)| {
209                        std::iter::repeat_n(
210                            *row_id,
211                            multivector.map(|multivec| multivec.len()).unwrap_or(0),
212                        )
213                    },
214                );
215                let row_ids = UInt64Array::from_iter_values(row_ids);
216                let vectors = vectors.values().as_fixed_size_list().clone();
217                let schema = Arc::new(Schema::new(vec![
218                    ROW_ID_FIELD.clone(),
219                    Field::new(self.column.as_str(), vectors.data_type().clone(), true),
220                ]));
221                let batch =
222                    RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(vectors)])?;
223                Ok(batch)
224            }
225            _ => Err(Error::Index {
226                message: format!(
227                    "Flatten: column {} is not a vector: {}",
228                    self.column,
229                    arr.data_type()
230                ),
231                location: location!(),
232            }),
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;
    use arrow_array::{FixedSizeListArray, Float16Array, Float32Array, Int32Array};
    use arrow_schema::Schema;
    use half::f16;
    use lance_arrow::*;
    use lance_linalg::distance::L2;

    // Normalizing in place replaces column "v" with unit-length vectors.
    #[tokio::test]
    async fn test_normalize_transformer_f32() {
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
        let transformer = NormalizeTransformer::new("v");
        let output = transformer.transform(&batch).unwrap();
        let actual = output.column_by_name("v").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        assert_relative_eq!(
            act_fsl.value(0).as_primitive::<Float32Type>().values()[..],
            [1.0 / 2.0_f32.sqrt(); 2]
        );
        assert_relative_eq!(
            act_fsl.value(1).as_primitive::<Float32Type>().values()[..],
            [2.0 / 8.0_f32.sqrt(); 2]
        );
    }

    #[tokio::test]
    async fn test_normalize_transformer_16() {
        let data =
            Float16Array::from_iter_values([1.0_f32, 1.0, 2.0, 2.0].into_iter().map(f16::from_f32));
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
        let transformer = NormalizeTransformer::new("v");
        let output = transformer.transform(&batch).unwrap();
        let actual = output.column_by_name("v").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        // Compare the absolute difference: a signed difference would pass
        // trivially whenever the actual value is smaller than expected.
        let expect_1 = [f16::from_f32_const(1.0) / f16::from_f32_const(2.0).sqrt(); 2];
        act_fsl
            .value(0)
            .as_primitive::<Float16Type>()
            .values()
            .iter()
            .zip(expect_1.iter())
            .for_each(|(a, b)| assert!((a - b).abs() <= f16::epsilon()));

        let expect_2 = [f16::from_f32_const(2.0) / f16::from_f32_const(8.0).sqrt(); 2];
        act_fsl
            .value(1)
            .as_primitive::<Float16Type>()
            .values()
            .iter()
            .zip(expect_2.iter())
            .for_each(|(a, b)| assert!((a - b).abs() <= f16::epsilon()));
    }

    // With an output column the input column must stay untouched.
    #[tokio::test]
    async fn test_normalize_transformer_with_output_column() {
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl.clone())]).unwrap();
        let transformer = NormalizeTransformer::new_with_output("v", "o");
        let output = transformer.transform(&batch).unwrap();
        let input = output.column_by_name("v").unwrap();
        assert_eq!(input.as_ref(), &fsl);
        let actual = output.column_by_name("o").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        assert_relative_eq!(
            act_fsl.value(0).as_primitive::<Float32Type>().values()[..],
            [1.0 / 2.0_f32.sqrt(); 2]
        );
        assert_relative_eq!(
            act_fsl.value(1).as_primitive::<Float32Type>().values()[..],
            [2.0 / 8.0_f32.sqrt(); 2]
        );
    }

    #[tokio::test]
    async fn test_drop_column() {
        let i32_array = Int32Array::from_iter_values([1, 2].into_iter());
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![
            Field::new("i32", DataType::Int32, false),
            Field::new(
                "v",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
                true,
            ),
        ]);
        let batch =
            RecordBatch::try_new(schema.into(), vec![Arc::new(i32_array), Arc::new(fsl)]).unwrap();
        let transformer = DropColumn::new("v");
        let output = transformer.transform(&batch).unwrap();
        assert!(output.column_by_name("v").is_none());

        // Dropping a column that is already gone must not fail.
        let dup_drop_result = transformer.transform(&output);
        assert!(dup_drop_result.is_ok());
    }

    // Rows containing a non-finite value are removed; finite rows are kept.
    #[test]
    fn test_keep_finite_vectors_fsl() {
        let data = Float32Array::from_iter_values([1.0, 1.0, f32::NAN, 2.0]);
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
        let transformer = KeepFiniteVectors::new("v");
        let output = transformer.transform(&batch).unwrap();
        assert_eq!(output.num_rows(), 1);
        let kept = output.column_by_name("v").unwrap().as_fixed_size_list();
        assert_eq!(
            kept.value(0).as_primitive::<Float32Type>().values()[..],
            [1.0_f32, 1.0]
        );
    }

    #[test]
    fn test_is_all_finite() {
        let array = Float32Array::from(vec![1.0, 2.0]);
        assert!(is_all_finite::<Float32Type>(&array));

        let failure_values = [f32::INFINITY, f32::NEG_INFINITY, f32::NAN];
        for &v in &failure_values {
            let array = Float32Array::from(vec![1.0, v]);
            assert!(
                !is_all_finite::<Float32Type>(&array),
                "value {} should fail is_all_finite",
                v
            );
        }
    }

    #[test]
    fn test_finite_f16() {
        let v1 = vec![f16::MAX; 10_000];
        let v2 = vec![f16::MAX - f16::from_f32_const(1.0); 10_000];
        let distance = f16::l2(&v1, &v2);
        assert!(distance.is_finite());
    }

    #[test]
    fn test_finite_f32() {
        let v1 = vec![f32::MAX; 10_000];
        let v2 = vec![f32::MAX - 1.0; 10_000];
        let distance = f32::l2(&v1, &v2);
        assert!(distance.is_finite());
    }

    #[test]
    fn test_finite_f64() {
        let v1 = vec![f64::MAX; 10_000];
        let v2 = vec![f64::MAX - 1.0; 10_000];
        let distance = f64::l2(&v1, &v2);
        assert!(distance.is_finite());
    }
}