Skip to main content

lance_index/vector/
transform.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Vector Transforms
5//!
6
7use std::fmt::Debug;
8use std::sync::Arc;
9
10use arrow::datatypes::UInt64Type;
11use arrow_array::UInt64Array;
12use arrow_array::types::{Float16Type, Float32Type, Float64Type};
13use arrow_array::{Array, ArrowPrimitiveType, RecordBatch, UInt32Array, cast::AsArray};
14use arrow_schema::{DataType, Field, Schema};
15use lance_arrow::RecordBatchExt;
16use num_traits::Float;
17
18use lance_core::{Error, ROW_ID, ROW_ID_FIELD, Result};
19use lance_linalg::kernels::normalize_fsl;
20use tracing::instrument;
21
22/// Transform of a Vector Matrix.
23///
24///
25pub trait Transformer: Debug + Send + Sync {
26    /// Transform a [`RecordBatch`] of vectors
27    ///
28    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch>;
29}
30
/// Transformer that L2-normalizes each vector of a vector column.
///
/// The normalized vectors either overwrite the input column or, when an output
/// column name is configured, are added to the batch under that name.
#[derive(Debug)]
pub struct NormalizeTransformer {
    /// Column holding the vectors to normalize.
    input_column: String,
    /// Destination column; `None` means the input column is overwritten.
    output_column: Option<String>,
}

impl NormalizeTransformer {
    /// Create a transformer that normalizes `column` in place.
    pub fn new(column: impl AsRef<str>) -> Self {
        let input_column = String::from(column.as_ref());
        Self {
            input_column,
            output_column: None,
        }
    }

    /// Create a transformer that reads vectors from `input_column` and writes the
    /// normalized vectors to a separate column named `output_column`.
    pub fn new_with_output(input_column: impl AsRef<str>, output_column: impl AsRef<str>) -> Self {
        Self {
            output_column: Some(String::from(output_column.as_ref())),
            ..Self::new(input_column)
        }
    }
}
57
58impl Transformer for NormalizeTransformer {
59    #[instrument(name = "NormalizeTransformer::transform", level = "debug", skip_all)]
60    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
61        let arr = batch.column_by_name(&self.input_column).ok_or_else(|| {
62            Error::index(format!(
63                "Normalize Transform: column {} not found in RecordBatch {}",
64                self.input_column,
65                batch.schema(),
66            ))
67        })?;
68
69        let data = arr.as_fixed_size_list();
70        let norm = normalize_fsl(data)?;
71        let transformed = Arc::new(norm);
72
73        if let Some(output_column) = &self.output_column {
74            let field = Field::new(output_column, transformed.data_type().clone(), true);
75            Ok(batch.try_with_column(field, transformed)?)
76        } else {
77            Ok(batch.replace_column_by_name(&self.input_column, transformed)?)
78        }
79    }
80}
81
/// Transformer that drops rows whose vector is null, contains null elements, or
/// contains NaN / infinite values.
#[derive(Debug)]
pub(crate) struct KeepFiniteVectors {
    /// Name of the vector column to inspect.
    column: String,
}

impl KeepFiniteVectors {
    /// Create a filter over the vector column named `column`.
    pub fn new(column: impl AsRef<str>) -> Self {
        let column = column.as_ref().to_string();
        Self { column }
    }
}
95
96fn is_all_finite<T: ArrowPrimitiveType>(arr: &dyn Array) -> bool
97where
98    T::Native: Float,
99{
100    arr.null_count() == 0
101        && !arr
102            .as_primitive::<T>()
103            .values()
104            .iter()
105            .any(|&v| !v.is_finite())
106}
107
108impl Transformer for KeepFiniteVectors {
109    #[instrument(name = "KeepFiniteVectors::transform", level = "debug", skip_all)]
110    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
111        let Some(arr) = batch.column_by_name(&self.column) else {
112            return Ok(batch.clone());
113        };
114
115        let data = match arr.data_type() {
116            DataType::FixedSizeList(_, _) => arr.as_fixed_size_list(),
117            DataType::List(_) => arr.as_list::<i32>().values().as_fixed_size_list(),
118            _ => {
119                return Err(Error::index(format!(
120                    "KeepFiniteVectors: column {} is not a fixed size list: {}",
121                    self.column,
122                    arr.data_type()
123                )));
124            }
125        };
126
127        let mut valid = Vec::with_capacity(batch.num_rows());
128        data.iter().enumerate().for_each(|(idx, arr)| {
129            if let Some(data) = arr {
130                let is_valid = match data.data_type() {
131                    // f16 vectors are computed in f32 space, so they will not overflow.
132                    DataType::Float16 => is_all_finite::<Float16Type>(&data),
133                    // f32 vectors must be bounded to avoid overflow in distance computation.
134                    DataType::Float32 => is_all_finite::<Float32Type>(&data),
135                    // f32 vectors are computed in f32 space, so they have the same limit as f64.
136                    DataType::Float64 => is_all_finite::<Float64Type>(&data),
137                    DataType::UInt8 => data.null_count() == 0,
138                    DataType::Int8 => data.null_count() == 0,
139                    _ => false,
140                };
141                if is_valid {
142                    valid.push(idx as u32);
143                }
144            };
145        });
146        if valid.len() < batch.num_rows() {
147            let indices = UInt32Array::from(valid);
148            Ok(batch.take(&indices)?)
149        } else {
150            Ok(batch.clone())
151        }
152    }
153}
154
/// Transformer that removes one column from each batch.
#[derive(Debug)]
pub struct DropColumn {
    /// Name of the column to remove.
    column: String,
}

impl DropColumn {
    /// Create a transformer that drops the column named `column`.
    pub fn new(column: &str) -> Self {
        Self {
            column: String::from(column),
        }
    }
}
167
168impl Transformer for DropColumn {
169    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
170        Ok(batch.drop_column(&self.column)?)
171    }
172}
173
/// Transformer that flattens a multivector column into one row per vector.
///
/// Each output row pairs a single vector with the row id of the multivector it
/// came from.
#[derive(Debug)]
pub struct Flatten {
    /// Name of the vector (or multivector) column to flatten.
    column: String,
}

impl Flatten {
    /// Create a transformer that flattens the column named `column`.
    pub fn new(column: &str) -> Self {
        let column = column.to_string();
        Self { column }
    }
}
186
187impl Transformer for Flatten {
188    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
189        let Some(arr) = batch.column_by_name(&self.column) else {
190            // this case is that we have precomputed buffers,
191            // so we don't need to flatten the original vectors.
192            return Ok(batch.clone());
193        };
194        match arr.data_type() {
195            DataType::FixedSizeList(_, _) => Ok(batch.clone()),
196            DataType::List(_) => {
197                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
198                let vectors = arr.as_list::<i32>();
199
200                let row_ids = row_ids.values().iter().zip(vectors.iter()).flat_map(
201                    |(row_id, multivector)| {
202                        std::iter::repeat_n(
203                            *row_id,
204                            multivector.map(|multivec| multivec.len()).unwrap_or(0),
205                        )
206                    },
207                );
208                let row_ids = UInt64Array::from_iter_values(row_ids);
209                let vectors = vectors.values().as_fixed_size_list().clone();
210                let schema = Arc::new(Schema::new(vec![
211                    ROW_ID_FIELD.clone(),
212                    Field::new(self.column.as_str(), vectors.data_type().clone(), true),
213                ]));
214                let batch =
215                    RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(vectors)])?;
216                Ok(batch)
217            }
218            _ => Err(Error::index(format!(
219                "Flatten: column {} is not a vector: {}",
220                self.column,
221                arr.data_type()
222            ))),
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;
    use arrow_array::{FixedSizeListArray, Float16Array, Float32Array, Int32Array};
    use arrow_schema::Schema;
    use half::f16;
    use lance_arrow::*;
    use lance_linalg::distance::L2;

    #[tokio::test]
    async fn test_normalize_transformer_f32() {
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
        let transformer = NormalizeTransformer::new("v");
        let output = transformer.transform(&batch).unwrap();
        let actual = output.column_by_name("v").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        assert_relative_eq!(
            act_fsl.value(0).as_primitive::<Float32Type>().values()[..],
            [1.0 / 2.0_f32.sqrt(); 2]
        );
        assert_relative_eq!(
            act_fsl.value(1).as_primitive::<Float32Type>().values()[..],
            [2.0 / 8.0_f32.sqrt(); 2]
        );
    }

    #[tokio::test]
    async fn test_normalize_transformer_16() {
        let data =
            Float16Array::from_iter_values([1.0_f32, 1.0, 2.0, 2.0].into_iter().map(f16::from_f32));
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
        let transformer = NormalizeTransformer::new("v");
        let output = transformer.transform(&batch).unwrap();
        let actual = output.column_by_name("v").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);

        // Compare absolute differences: a one-sided `a - b <= eps` check would accept
        // any result that is too small.
        let expect_1 = [f16::from_f32_const(1.0) / f16::from_f32_const(2.0).sqrt(); 2];
        act_fsl
            .value(0)
            .as_primitive::<Float16Type>()
            .values()
            .iter()
            .zip(expect_1.iter())
            .for_each(|(&a, &b)| {
                assert!((a.to_f32() - b.to_f32()).abs() <= f16::epsilon().to_f32())
            });

        let expect_2 = [f16::from_f32_const(2.0) / f16::from_f32_const(8.0).sqrt(); 2];
        act_fsl
            .value(1)
            .as_primitive::<Float16Type>()
            .values()
            .iter()
            .zip(expect_2.iter())
            .for_each(|(&a, &b)| {
                assert!((a.to_f32() - b.to_f32()).abs() <= f16::epsilon().to_f32())
            });
    }

    #[tokio::test]
    async fn test_normalize_transformer_with_output_column() {
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl.clone())]).unwrap();
        let transformer = NormalizeTransformer::new_with_output("v", "o");
        let output = transformer.transform(&batch).unwrap();
        let input = output.column_by_name("v").unwrap();
        assert_eq!(input.as_ref(), &fsl);
        let actual = output.column_by_name("o").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        assert_relative_eq!(
            act_fsl.value(0).as_primitive::<Float32Type>().values()[..],
            [1.0 / 2.0_f32.sqrt(); 2]
        );
        assert_relative_eq!(
            act_fsl.value(1).as_primitive::<Float32Type>().values()[..],
            [2.0 / 8.0_f32.sqrt(); 2]
        );
    }

    #[tokio::test]
    async fn test_drop_column() {
        let i32_array = Int32Array::from_iter_values([1, 2].into_iter());
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![
            Field::new("i32", DataType::Int32, false),
            Field::new(
                "v",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
                true,
            ),
        ]);
        let batch =
            RecordBatch::try_new(schema.into(), vec![Arc::new(i32_array), Arc::new(fsl)]).unwrap();
        let transformer = DropColumn::new("v");
        let output = transformer.transform(&batch).unwrap();
        assert!(output.column_by_name("v").is_none());

        // Dropping a column that is already gone must still succeed.
        let dup_drop_result = transformer.transform(&output);
        assert!(dup_drop_result.is_ok());
    }

    #[test]
    fn test_is_all_finite() {
        let array = Float32Array::from(vec![1.0, 2.0]);
        assert!(is_all_finite::<Float32Type>(&array));

        let failure_values = [f32::INFINITY, f32::NEG_INFINITY, f32::NAN];
        for &v in &failure_values {
            let array = Float32Array::from(vec![1.0, v]);
            assert!(
                !is_all_finite::<Float32Type>(&array),
                "value {} should fail is_all_finite",
                v
            );
        }
    }

    #[test]
    fn test_finite_f16() {
        let v1 = vec![f16::MAX; 10_000];
        let v2 = vec![f16::MAX - f16::from_f32_const(1.0); 10_000];
        let distance = f16::l2(&v1, &v2);
        assert!(distance.is_finite());
    }

    #[test]
    fn test_finite_f32() {
        let v1 = vec![f32::MAX; 10_000];
        let v2 = vec![f32::MAX - 1.0; 10_000];
        let distance = f32::l2(&v1, &v2);
        assert!(distance.is_finite());
    }

    #[test]
    fn test_finite_f64() {
        let v1 = vec![f64::MAX; 10_000];
        let v2 = vec![f64::MAX - 1.0; 10_000];
        let distance = f64::l2(&v1, &v2);
        assert!(distance.is_finite());
    }
}