1use std::fmt::Debug;
8use std::sync::Arc;
9
10use arrow::datatypes::UInt64Type;
11use arrow_array::types::{Float16Type, Float32Type, Float64Type};
12use arrow_array::UInt64Array;
13use arrow_array::{cast::AsArray, Array, ArrowPrimitiveType, RecordBatch, UInt32Array};
14use arrow_schema::{DataType, Field, Schema};
15use lance_arrow::RecordBatchExt;
16use num_traits::Float;
17use snafu::location;
18
19use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
20use lance_linalg::kernels::normalize_fsl;
21use tracing::instrument;
22
/// A transformation applied to a [`RecordBatch`], producing a new batch.
///
/// Implementors must be `Debug + Send + Sync`.
pub trait Transformer: Debug + Send + Sync {
    /// Transforms `batch` and returns the resulting [`RecordBatch`].
    ///
    /// The input batch is only borrowed; implementations build and return a new batch.
    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch>;
}
31
/// Normalizes the vectors stored in a `FixedSizeList` column, scaling each vector
/// to unit L2 length.
#[derive(Debug)]
pub struct NormalizeTransformer {
    /// Column that holds the vectors to normalize.
    input_column: String,
    /// When set, the normalized vectors are added as a new column with this name and
    /// `input_column` is left untouched; when `None`, `input_column` is replaced.
    output_column: Option<String>,
}
40
41impl NormalizeTransformer {
42 pub fn new(column: impl AsRef<str>) -> Self {
43 Self {
44 input_column: column.as_ref().to_owned(),
45 output_column: None,
46 }
47 }
48
49 pub fn new_with_output(input_column: impl AsRef<str>, output_column: impl AsRef<str>) -> Self {
52 Self {
53 input_column: input_column.as_ref().to_owned(),
54 output_column: Some(output_column.as_ref().to_owned()),
55 }
56 }
57}
58
59impl Transformer for NormalizeTransformer {
60 #[instrument(name = "NormalizeTransformer::transform", level = "debug", skip_all)]
61 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
62 let arr = batch
63 .column_by_name(&self.input_column)
64 .ok_or_else(|| Error::Index {
65 message: format!(
66 "Normalize Transform: column {} not found in RecordBatch {}",
67 self.input_column,
68 batch.schema(),
69 ),
70 location: location!(),
71 })?;
72
73 let data = arr.as_fixed_size_list();
74 let norm = normalize_fsl(data)?;
75 let transformed = Arc::new(norm);
76
77 if let Some(output_column) = &self.output_column {
78 let field = Field::new(output_column, transformed.data_type().clone(), true);
79 Ok(batch.try_with_column(field, transformed)?)
80 } else {
81 Ok(batch.replace_column_by_name(&self.input_column, transformed)?)
82 }
83 }
84}
85
/// Filters batches down to the rows whose vectors in `column` are usable: non-null
/// and, for floating point vectors, free of NaN and infinite values. Batches that do
/// not contain `column` pass through unchanged.
#[derive(Debug)]
pub(crate) struct KeepFiniteVectors {
    /// Name of the vector column to check.
    column: String,
}
91
92impl KeepFiniteVectors {
93 pub fn new(column: impl AsRef<str>) -> Self {
94 Self {
95 column: column.as_ref().to_owned(),
96 }
97 }
98}
99
100fn is_all_finite<T: ArrowPrimitiveType>(arr: &dyn Array) -> bool
101where
102 T::Native: Float,
103{
104 arr.null_count() == 0
105 && !arr
106 .as_primitive::<T>()
107 .values()
108 .iter()
109 .any(|&v| !v.is_finite())
110}
111
112impl Transformer for KeepFiniteVectors {
113 #[instrument(name = "KeepFiniteVectors::transform", level = "debug", skip_all)]
114 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
115 let Some(arr) = batch.column_by_name(&self.column) else {
116 return Ok(batch.clone());
117 };
118
119 let data = match arr.data_type() {
120 DataType::FixedSizeList(_, _) => arr.as_fixed_size_list(),
121 DataType::List(_) => arr.as_list::<i32>().values().as_fixed_size_list(),
122 _ => {
123 return Err(Error::Index {
124 message: format!(
125 "KeepFiniteVectors: column {} is not a fixed size list: {}",
126 self.column,
127 arr.data_type()
128 ),
129 location: location!(),
130 })
131 }
132 };
133
134 let mut valid = Vec::with_capacity(batch.num_rows());
135 data.iter().enumerate().for_each(|(idx, arr)| {
136 if let Some(data) = arr {
137 let is_valid = match data.data_type() {
138 DataType::Float16 => is_all_finite::<Float16Type>(&data),
140 DataType::Float32 => is_all_finite::<Float32Type>(&data),
142 DataType::Float64 => is_all_finite::<Float64Type>(&data),
144 DataType::UInt8 => data.null_count() == 0,
145 DataType::Int8 => data.null_count() == 0,
146 _ => false,
147 };
148 if is_valid {
149 valid.push(idx as u32);
150 }
151 };
152 });
153 if valid.len() < batch.num_rows() {
154 let indices = UInt32Array::from(valid);
155 Ok(batch.take(&indices)?)
156 } else {
157 Ok(batch.clone())
158 }
159 }
160}
161
/// Removes a single column from every batch. Dropping a column that is not present
/// is not an error.
#[derive(Debug)]
pub struct DropColumn {
    /// Name of the column to remove.
    column: String,
}
166
167impl DropColumn {
168 pub fn new(column: &str) -> Self {
169 Self {
170 column: column.to_owned(),
171 }
172 }
173}
174
175impl Transformer for DropColumn {
176 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
177 Ok(batch.drop_column(&self.column)?)
178 }
179}
180
/// Flattens a multivector column (a `List` of `FixedSizeList` vectors) into a batch
/// with one row per vector, each paired with its owning row's `ROW_ID`.
/// `FixedSizeList` columns and batches without the column pass through unchanged.
#[derive(Debug)]
pub struct Flatten {
    /// Name of the vector column to flatten.
    column: String,
}
185
186impl Flatten {
187 pub fn new(column: &str) -> Self {
188 Self {
189 column: column.to_owned(),
190 }
191 }
192}
193
194impl Transformer for Flatten {
195 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
196 let Some(arr) = batch.column_by_name(&self.column) else {
197 return Ok(batch.clone());
200 };
201 match arr.data_type() {
202 DataType::FixedSizeList(_, _) => Ok(batch.clone()),
203 DataType::List(_) => {
204 let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
205 let vectors = arr.as_list::<i32>();
206
207 let row_ids = row_ids.values().iter().zip(vectors.iter()).flat_map(
208 |(row_id, multivector)| {
209 std::iter::repeat_n(
210 *row_id,
211 multivector.map(|multivec| multivec.len()).unwrap_or(0),
212 )
213 },
214 );
215 let row_ids = UInt64Array::from_iter_values(row_ids);
216 let vectors = vectors.values().as_fixed_size_list().clone();
217 let schema = Arc::new(Schema::new(vec![
218 ROW_ID_FIELD.clone(),
219 Field::new(self.column.as_str(), vectors.data_type().clone(), true),
220 ]));
221 let batch =
222 RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(vectors)])?;
223 Ok(batch)
224 }
225 _ => Err(Error::Index {
226 message: format!(
227 "Flatten: column {} is not a vector: {}",
228 self.column,
229 arr.data_type()
230 ),
231 location: location!(),
232 }),
233 }
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;
    use arrow_array::{FixedSizeListArray, Float16Array, Float32Array, Int32Array};
    use arrow_schema::Schema;
    use half::f16;
    use lance_arrow::*;
    use lance_linalg::distance::L2;

    #[tokio::test]
    async fn test_normalize_transformer_f32() {
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
        let transformer = NormalizeTransformer::new("v");
        let output = transformer.transform(&batch).unwrap();
        let actual = output.column_by_name("v").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        assert_relative_eq!(
            act_fsl.value(0).as_primitive::<Float32Type>().values()[..],
            [1.0 / 2.0_f32.sqrt(); 2]
        );
        assert_relative_eq!(
            act_fsl.value(1).as_primitive::<Float32Type>().values()[..],
            [2.0 / 8.0_f32.sqrt(); 2]
        );
    }

    #[tokio::test]
    async fn test_normalize_transformer_16() {
        let data =
            Float16Array::from_iter_values([1.0_f32, 1.0, 2.0, 2.0].into_iter().map(f16::from_f32));
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
        let transformer = NormalizeTransformer::new("v");
        let output = transformer.transform(&batch).unwrap();
        let actual = output.column_by_name("v").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        // Compare absolute differences: a signed `a - b <= eps` would accept any
        // actual value that is smaller than the expected one.
        let expect_1 = [f16::from_f32_const(1.0) / f16::from_f32_const(2.0).sqrt(); 2];
        act_fsl
            .value(0)
            .as_primitive::<Float16Type>()
            .values()
            .iter()
            .zip(expect_1.iter())
            .for_each(|(a, b)| assert!((a - b).abs() <= f16::epsilon()));

        let expect_2 = [f16::from_f32_const(2.0) / f16::from_f32_const(8.0).sqrt(); 2];
        act_fsl
            .value(1)
            .as_primitive::<Float16Type>()
            .values()
            .iter()
            .zip(expect_2.iter())
            .for_each(|(a, b)| assert!((a - b).abs() <= f16::epsilon()));
    }

    #[tokio::test]
    async fn test_normalize_transformer_with_output_column() {
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl.clone())]).unwrap();
        let transformer = NormalizeTransformer::new_with_output("v", "o");
        let output = transformer.transform(&batch).unwrap();
        let input = output.column_by_name("v").unwrap();
        assert_eq!(input.as_ref(), &fsl);
        let actual = output.column_by_name("o").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        assert_relative_eq!(
            act_fsl.value(0).as_primitive::<Float32Type>().values()[..],
            [1.0 / 2.0_f32.sqrt(); 2]
        );
        assert_relative_eq!(
            act_fsl.value(1).as_primitive::<Float32Type>().values()[..],
            [2.0 / 8.0_f32.sqrt(); 2]
        );
    }

    #[tokio::test]
    async fn test_drop_column() {
        let i32_array = Int32Array::from_iter_values([1, 2].into_iter());
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![
            Field::new("i32", DataType::Int32, false),
            Field::new(
                "v",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
                true,
            ),
        ]);
        let batch =
            RecordBatch::try_new(schema.into(), vec![Arc::new(i32_array), Arc::new(fsl)]).unwrap();
        let transformer = DropColumn::new("v");
        let output = transformer.transform(&batch).unwrap();
        assert!(output.column_by_name("v").is_none());

        let dup_drop_result = transformer.transform(&output);
        assert!(dup_drop_result.is_ok());
    }

    #[test]
    fn test_is_all_finite() {
        let array = Float32Array::from(vec![1.0, 2.0]);
        assert!(is_all_finite::<Float32Type>(&array));

        let failure_values = [f32::INFINITY, f32::NEG_INFINITY, f32::NAN];
        for &v in &failure_values {
            let array = Float32Array::from(vec![1.0, v]);
            assert!(
                !is_all_finite::<Float32Type>(&array),
                "value {} should fail is_all_finite",
                v
            );
        }
    }

    #[test]
    fn test_finite_f16() {
        let v1 = vec![f16::MAX; 10_000];
        let v2 = vec![f16::MAX - f16::from_f32_const(1.0); 10_000];
        let distance = f16::l2(&v1, &v2);
        assert!(distance.is_finite());
    }

    #[test]
    fn test_finite_f32() {
        let v1 = vec![f32::MAX; 10_000];
        let v2 = vec![f32::MAX - 1.0; 10_000];
        let distance = f32::l2(&v1, &v2);
        assert!(distance.is_finite());
    }

    #[test]
    fn test_finite_f64() {
        let v1 = vec![f64::MAX; 10_000];
        let v2 = vec![f64::MAX - 1.0; 10_000];
        let distance = f64::l2(&v1, &v2);
        assert!(distance.is_finite());
    }
}