1use std::fmt::Debug;
8use std::sync::Arc;
9
10use arrow::datatypes::UInt64Type;
11use arrow_array::UInt64Array;
12use arrow_array::types::{Float16Type, Float32Type, Float64Type};
13use arrow_array::{Array, ArrowPrimitiveType, RecordBatch, UInt32Array, cast::AsArray};
14use arrow_schema::{DataType, Field, Schema};
15use lance_arrow::RecordBatchExt;
16use num_traits::Float;
17
18use lance_core::{Error, ROW_ID, ROW_ID_FIELD, Result};
19use lance_linalg::kernels::normalize_fsl;
20use tracing::instrument;
21
22pub trait Transformer: Debug + Send + Sync {
26 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch>;
29}
30
/// Settings for normalizing a vector column of a record batch.
///
/// `input_column` names the column that is read. When `output_column` is
/// `Some`, the normalized vectors are stored under that name; when it is
/// `None`, the normalized vectors replace the input column.
#[derive(Debug)]
pub struct NormalizeTransformer {
    input_column: String,
    output_column: Option<String>,
}

impl NormalizeTransformer {
    /// Creates a transformer that normalizes `column` in place.
    pub fn new(column: impl AsRef<str>) -> Self {
        let input_column = String::from(column.as_ref());
        Self {
            input_column,
            output_column: None,
        }
    }

    /// Creates a transformer that reads `input_column` and stores the
    /// normalized vectors in a column named `output_column`.
    pub fn new_with_output(input_column: impl AsRef<str>, output_column: impl AsRef<str>) -> Self {
        let source = input_column.as_ref();
        let target = output_column.as_ref();
        Self {
            input_column: source.into(),
            output_column: Some(target.into()),
        }
    }
}
57
58impl Transformer for NormalizeTransformer {
59 #[instrument(name = "NormalizeTransformer::transform", level = "debug", skip_all)]
60 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
61 let arr = batch.column_by_name(&self.input_column).ok_or_else(|| {
62 Error::index(format!(
63 "Normalize Transform: column {} not found in RecordBatch {}",
64 self.input_column,
65 batch.schema(),
66 ))
67 })?;
68
69 let data = arr.as_fixed_size_list();
70 let norm = normalize_fsl(data)?;
71 let transformed = Arc::new(norm);
72
73 if let Some(output_column) = &self.output_column {
74 let field = Field::new(output_column, transformed.data_type().clone(), true);
75 Ok(batch.try_with_column(field, transformed)?)
76 } else {
77 Ok(batch.replace_column_by_name(&self.input_column, transformed)?)
78 }
79 }
80}
81
/// Transformer that keeps only the rows whose vector in `column` passes the
/// validity check (no null elements and, for float vectors, only finite
/// values).
#[derive(Debug)]
pub(crate) struct KeepFiniteVectors {
    column: String,
}

impl KeepFiniteVectors {
    /// Creates a filter for the vector column named `column`.
    pub fn new(column: impl AsRef<str>) -> Self {
        let column = String::from(column.as_ref());
        Self { column }
    }
}
95
96fn is_all_finite<T: ArrowPrimitiveType>(arr: &dyn Array) -> bool
97where
98 T::Native: Float,
99{
100 arr.null_count() == 0
101 && !arr
102 .as_primitive::<T>()
103 .values()
104 .iter()
105 .any(|&v| !v.is_finite())
106}
107
108impl Transformer for KeepFiniteVectors {
109 #[instrument(name = "KeepFiniteVectors::transform", level = "debug", skip_all)]
110 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
111 let Some(arr) = batch.column_by_name(&self.column) else {
112 return Ok(batch.clone());
113 };
114
115 let data = match arr.data_type() {
116 DataType::FixedSizeList(_, _) => arr.as_fixed_size_list(),
117 DataType::List(_) => arr.as_list::<i32>().values().as_fixed_size_list(),
118 _ => {
119 return Err(Error::index(format!(
120 "KeepFiniteVectors: column {} is not a fixed size list: {}",
121 self.column,
122 arr.data_type()
123 )));
124 }
125 };
126
127 let mut valid = Vec::with_capacity(batch.num_rows());
128 data.iter().enumerate().for_each(|(idx, arr)| {
129 if let Some(data) = arr {
130 let is_valid = match data.data_type() {
131 DataType::Float16 => is_all_finite::<Float16Type>(&data),
133 DataType::Float32 => is_all_finite::<Float32Type>(&data),
135 DataType::Float64 => is_all_finite::<Float64Type>(&data),
137 DataType::UInt8 => data.null_count() == 0,
138 DataType::Int8 => data.null_count() == 0,
139 _ => false,
140 };
141 if is_valid {
142 valid.push(idx as u32);
143 }
144 };
145 });
146 if valid.len() < batch.num_rows() {
147 let indices = UInt32Array::from(valid);
148 Ok(batch.take(&indices)?)
149 } else {
150 Ok(batch.clone())
151 }
152 }
153}
154
/// Transformer that removes the column named `column` from a record batch.
#[derive(Debug)]
pub struct DropColumn {
    column: String,
}

impl DropColumn {
    /// Creates a transformer that drops the column named `column`.
    pub fn new(column: &str) -> Self {
        let column = String::from(column);
        Self { column }
    }
}
167
168impl Transformer for DropColumn {
169 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
170 Ok(batch.drop_column(&self.column)?)
171 }
172}
173
/// Transformer that flattens a list-of-vectors column named `column` into one
/// vector per row.
#[derive(Debug)]
pub struct Flatten {
    column: String,
}

impl Flatten {
    /// Creates a transformer that flattens the column named `column`.
    pub fn new(column: &str) -> Self {
        let column = String::from(column);
        Self { column }
    }
}
186
187impl Transformer for Flatten {
188 fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
189 let Some(arr) = batch.column_by_name(&self.column) else {
190 return Ok(batch.clone());
193 };
194 match arr.data_type() {
195 DataType::FixedSizeList(_, _) => Ok(batch.clone()),
196 DataType::List(_) => {
197 let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
198 let vectors = arr.as_list::<i32>();
199
200 let row_ids = row_ids.values().iter().zip(vectors.iter()).flat_map(
201 |(row_id, multivector)| {
202 std::iter::repeat_n(
203 *row_id,
204 multivector.map(|multivec| multivec.len()).unwrap_or(0),
205 )
206 },
207 );
208 let row_ids = UInt64Array::from_iter_values(row_ids);
209 let vectors = vectors.values().as_fixed_size_list().clone();
210 let schema = Arc::new(Schema::new(vec![
211 ROW_ID_FIELD.clone(),
212 Field::new(self.column.as_str(), vectors.data_type().clone(), true),
213 ]));
214 let batch =
215 RecordBatch::try_new(schema, vec![Arc::new(row_ids), Arc::new(vectors)])?;
216 Ok(batch)
217 }
218 _ => Err(Error::index(format!(
219 "Flatten: column {} is not a vector: {}",
220 self.column,
221 arr.data_type()
222 ))),
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;
    use arrow_array::{FixedSizeListArray, Float16Array, Float32Array, Int32Array};
    use arrow_schema::Schema;
    use half::f16;
    use lance_arrow::*;
    use lance_linalg::distance::L2;

    /// Normalizing f32 vectors in place yields unit-length vectors.
    #[tokio::test]
    async fn test_normalize_transformer_f32() {
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
        let transformer = NormalizeTransformer::new("v");
        let output = transformer.transform(&batch).unwrap();
        let actual = output.column_by_name("v").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        assert_relative_eq!(
            act_fsl.value(0).as_primitive::<Float32Type>().values()[..],
            [1.0 / 2.0_f32.sqrt(); 2]
        );
        assert_relative_eq!(
            act_fsl.value(1).as_primitive::<Float32Type>().values()[..],
            [2.0 / 8.0_f32.sqrt(); 2]
        );
    }

    /// Normalizing f16 vectors in place yields unit-length vectors, within
    /// one f16 epsilon of the expected value in either direction.
    #[tokio::test]
    async fn test_normalize_transformer_16() {
        let data =
            Float16Array::from_iter_values([1.0_f32, 1.0, 2.0, 2.0].into_iter().map(f16::from_f32));
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float16, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl)]).unwrap();
        let transformer = NormalizeTransformer::new("v");
        let output = transformer.transform(&batch).unwrap();
        let actual = output.column_by_name("v").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);

        // Compare the absolute difference: a signed `a - b <= eps` check
        // would accept any result that is too small.
        let tolerance = f16::EPSILON.to_f32();

        let expect_1 = [f16::from_f32_const(1.0) / f16::from_f32_const(2.0).sqrt(); 2];
        act_fsl
            .value(0)
            .as_primitive::<Float16Type>()
            .values()
            .iter()
            .zip(expect_1.iter())
            .for_each(|(a, b)| assert!((a.to_f32() - b.to_f32()).abs() <= tolerance));

        let expect_2 = [f16::from_f32_const(2.0) / f16::from_f32_const(8.0).sqrt(); 2];
        act_fsl
            .value(1)
            .as_primitive::<Float16Type>()
            .values()
            .iter()
            .zip(expect_2.iter())
            .for_each(|(a, b)| assert!((a.to_f32() - b.to_f32()).abs() <= tolerance));
    }

    /// With an output column configured, the input column is preserved and
    /// the normalized vectors are stored in the new column.
    #[tokio::test]
    async fn test_normalize_transformer_with_output_column() {
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![Field::new(
            "v",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
            true,
        )]);
        let batch = RecordBatch::try_new(schema.into(), vec![Arc::new(fsl.clone())]).unwrap();
        let transformer = NormalizeTransformer::new_with_output("v", "o");
        let output = transformer.transform(&batch).unwrap();
        let input = output.column_by_name("v").unwrap();
        assert_eq!(input.as_ref(), &fsl);
        let actual = output.column_by_name("o").unwrap();
        let act_fsl = actual.as_fixed_size_list();
        assert_eq!(act_fsl.len(), 2);
        assert_relative_eq!(
            act_fsl.value(0).as_primitive::<Float32Type>().values()[..],
            [1.0 / 2.0_f32.sqrt(); 2]
        );
        assert_relative_eq!(
            act_fsl.value(1).as_primitive::<Float32Type>().values()[..],
            [2.0 / 8.0_f32.sqrt(); 2]
        );
    }

    /// Dropping a column removes it, and dropping it again is not an error.
    #[tokio::test]
    async fn test_drop_column() {
        let i32_array = Int32Array::from_iter_values([1, 2].into_iter());
        let data = Float32Array::from_iter_values([1.0, 1.0, 2.0, 2.0].into_iter());
        let fsl = FixedSizeListArray::try_new_from_values(data, 2).unwrap();
        let schema = Schema::new(vec![
            Field::new("i32", DataType::Int32, false),
            Field::new(
                "v",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
                true,
            ),
        ]);
        let batch =
            RecordBatch::try_new(schema.into(), vec![Arc::new(i32_array), Arc::new(fsl)]).unwrap();
        let transformer = DropColumn::new("v");
        let output = transformer.transform(&batch).unwrap();
        assert!(output.column_by_name("v").is_none());

        let dup_drop_result = transformer.transform(&output);
        assert!(dup_drop_result.is_ok());
    }

    /// Infinities and NaN must make `is_all_finite` return false.
    #[test]
    fn test_is_all_finite() {
        let array = Float32Array::from(vec![1.0, 2.0]);
        assert!(is_all_finite::<Float32Type>(&array));

        let failure_values = [f32::INFINITY, f32::NEG_INFINITY, f32::NAN];
        for &v in &failure_values {
            let array = Float32Array::from(vec![1.0, v]);
            assert!(
                !is_all_finite::<Float32Type>(&array),
                "value {} should fail is_all_finite",
                v
            );
        }
    }

    /// L2 distance between large finite f16 vectors stays finite.
    #[test]
    fn test_finite_f16() {
        let v1 = vec![f16::MAX; 10_000];
        let v2 = vec![f16::MAX - f16::from_f32_const(1.0); 10_000];
        let distance = f16::l2(&v1, &v2);
        assert!(distance.is_finite());
    }

    /// L2 distance between large finite f32 vectors stays finite.
    #[test]
    fn test_finite_f32() {
        let v1 = vec![f32::MAX; 10_000];
        let v2 = vec![f32::MAX - 1.0; 10_000];
        let distance = f32::l2(&v1, &v2);
        assert!(distance.is_finite());
    }

    /// L2 distance between large finite f64 vectors stays finite.
    #[test]
    fn test_finite_f64() {
        let v1 = vec![f64::MAX; 10_000];
        let v2 = vec![f64::MAX - 1.0; 10_000];
        let distance = f64::l2(&v1, &v2);
        assert!(distance.is_finite());
    }
}