use std::{iter::repeat_with, sync::Arc};
use arrow_array::{
Array, ArrowNumericType, FixedSizeListArray, PrimitiveArray, RecordBatch, RecordBatchIterator,
RecordBatchReader,
cast::AsArray,
types::{Float16Type, Float32Type, Float64Type, Int32Type, Int64Type},
};
use arrow_cast::{can_cast_types, cast};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use half::f16;
use lance_arrow::{DataTypeExt, FixedSizeListArrayExt};
use log::warn;
use num_traits::cast::AsPrimitive;
use super::inspect::infer_dimension;
use crate::error::Result;
/// Converts every element of a numeric array into another numeric type.
///
/// Each value is converted with [`AsPrimitive::as_`]. The validity (null)
/// buffer of `arr` is carried over to the result, so a slot that is null in
/// the input is still null in the output.
fn cast_array<I: ArrowNumericType, O: ArrowNumericType>(
    arr: &PrimitiveArray<I>,
) -> Arc<PrimitiveArray<O>>
where
    I::Native: AsPrimitive<O::Native>,
{
    // `unary` maps the values buffer and reuses the input's null buffer.
    // Building the result with `from_iter_values` instead would always produce
    // an array without nulls, silently turning null slots into real values.
    Arc::new(arr.unary::<_, O>(|v| v.as_()))
}
/// Converts a numeric array to the floating-point type named by `dt`.
///
/// Supported targets are `Float16`, `Float32` and `Float64`; the element-wise
/// conversion itself is delegated to `cast_array`.
///
/// # Errors
///
/// Returns [`ArrowError::SchemaError`] when `dt` is not one of the three
/// floating-point types.
fn cast_float_array<I: ArrowNumericType>(
    arr: &PrimitiveArray<I>,
    dt: &DataType,
) -> std::result::Result<Arc<dyn Array>, ArrowError>
where
    I::Native: AsPrimitive<f64> + AsPrimitive<f32> + AsPrimitive<f16>,
{
    let converted: Arc<dyn Array> = match dt {
        DataType::Float16 => cast_array::<I, Float16Type>(arr),
        DataType::Float32 => cast_array::<I, Float32Type>(arr),
        DataType::Float64 => cast_array::<I, Float64Type>(arr),
        unsupported => {
            // Anything that is not a float target is rejected outright.
            return Err(ArrowError::SchemaError(format!(
                "Incompatible change field: unable to coerce {:?} to {:?}",
                arr.data_type(),
                unsupported
            )));
        }
    };
    Ok(converted)
}
/// Coerces `array` to the data type declared by `field`.
///
/// Strategies are tried in this order:
///
/// 1. Identical types: the input is returned as-is (a cheap `Arc` clone).
/// 2. Conversions supported by `arrow_cast`: delegated to [`cast`].
/// 3. A floating-point source that `arrow_cast` cannot convert: converted
///    element-wise through `cast_float_array`.
/// 4. A `FixedSizeList` target: a `FixedSizeList` source with the same
///    dimension, or a `List` / `LargeList` source whose inferred dimension
///    matches, has its child values coerced recursively to the expected item
///    field and is wrapped into a new `FixedSizeListArray`.
///
/// # Errors
///
/// Returns [`ArrowError::SchemaError`] when none of the strategies applies,
/// when the dimension of a list source cannot be inferred, or when it differs
/// from the expected dimension. Errors from [`cast`] and from building the
/// fixed size list are propagated.
fn coerce_array(
    array: &Arc<dyn Array>,
    field: &Field,
) -> std::result::Result<Arc<dyn Array>, ArrowError> {
    if array.data_type() == field.data_type() {
        return Ok(array.clone());
    }
    match (array.data_type(), field.data_type()) {
        // Regular conversions that arrow already knows how to perform.
        (adt, dt) if can_cast_types(adt, dt) => cast(&array, dt),
        // Float sources arrow cannot cast. The guard only looks at the source
        // type: a non-float source with a float target has no conversion here
        // and must fall through to the error arm instead of panicking.
        (adt, dt) if adt.is_floating() => {
            // Only float -> float narrowing is worth a warning; any other
            // target is rejected by `cast_float_array` below anyway.
            if dt.is_floating() && adt.byte_width() > dt.byte_width() {
                warn!(
                    "Coercing field {} {:?} to {:?} might lose precision",
                    field.name(),
                    adt,
                    dt
                );
            }
            match adt {
                DataType::Float16 => cast_float_array(array.as_primitive::<Float16Type>(), dt),
                DataType::Float32 => cast_float_array(array.as_primitive::<Float32Type>(), dt),
                DataType::Float64 => cast_float_array(array.as_primitive::<Float64Type>(), dt),
                // Defensive: report an error rather than panic if the guard
                // ever admits a type that is not handled above.
                _ => Err(ArrowError::SchemaError(format!(
                    "Incompatible change field {}: unable to coerce {:?} to {:?}",
                    field.name(),
                    adt,
                    dt
                ))),
            }
        }
        (adt, DataType::FixedSizeList(exp_field, exp_dim)) => match adt {
            // Same dimension, different item type: coerce the flattened values.
            DataType::FixedSizeList(_, dim) if dim == exp_dim => {
                let actual_sub = array.as_fixed_size_list();
                let values = coerce_array(actual_sub.values(), exp_field)?;
                // NOTE(review): only values and dimension are handed to
                // `try_new_from_values`, so list-level validity of the source
                // is not forwarded here — confirm whether that matters.
                Ok(Arc::new(FixedSizeListArray::try_new_from_values(
                    values, *dim,
                )?) as Arc<dyn Array>)
            }
            DataType::List(_) | DataType::LargeList(_) => {
                // For both offset widths, compute the inferred dimension and
                // the child values covered by this array's offsets (the child
                // array may extend beyond them when the list has been sliced).
                let (inferred, child_values) = match adt {
                    DataType::List(_) => {
                        let list = array.as_list::<i32>();
                        let inferred = infer_dimension::<Int32Type>(list)
                            .map_err(|e| {
                                ArrowError::SchemaError(format!(
                                    "failed to infer dimension from list: {}",
                                    e
                                ))
                            })?
                            .map(|d| d as i64);
                        let offsets = list.value_offsets();
                        // Offsets of a valid list array are non-negative.
                        let start = offsets.first().copied().unwrap_or_default() as usize;
                        let end = offsets.last().copied().unwrap_or_default() as usize;
                        (
                            inferred,
                            list.values().slice(start, end.saturating_sub(start)),
                        )
                    }
                    DataType::LargeList(_) => {
                        let list = array.as_list::<i64>();
                        let inferred = infer_dimension::<Int64Type>(list).map_err(|e| {
                            ArrowError::SchemaError(format!(
                                "failed to infer dimension from large list: {}",
                                e
                            ))
                        })?;
                        let offsets = list.value_offsets();
                        // Offsets of a valid list array are non-negative.
                        let start = offsets.first().copied().unwrap_or_default() as usize;
                        let end = offsets.last().copied().unwrap_or_default() as usize;
                        (
                            inferred,
                            list.values().slice(start, end.saturating_sub(start)),
                        )
                    }
                    // The enclosing pattern only admits List / LargeList.
                    _ => unreachable!(),
                };
                let Some(dim) = inferred else {
                    return Err(ArrowError::SchemaError(format!(
                        "Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
                        field,
                        array.data_type()
                    )));
                };
                if dim != *exp_dim as i64 {
                    return Err(ArrowError::SchemaError(format!(
                        "Incompatible coerce fixed size list: expected dimension {} but got {}",
                        exp_dim, dim
                    )));
                }
                // Coerce the list's child values (not the list array itself)
                // to the expected item field, then re-wrap them.
                // NOTE(review): list-level validity is not forwarded by
                // `try_new_from_values` — confirm whether null vectors need
                // to be preserved.
                let values = coerce_array(&child_values, exp_field)?;
                Ok(Arc::new(FixedSizeListArray::try_new_from_values(
                    values, *exp_dim,
                )?) as Arc<dyn Array>)
            }
            _ => Err(ArrowError::SchemaError(format!(
                "Incompatible coerce fixed size list: unable to coerce {:?} from {:?}",
                field,
                array.data_type()
            ))),
        },
        _ => Err(ArrowError::SchemaError(format!(
            "Incompatible change field {}: unable to coerce {:?} to {:?}",
            field.name(),
            array.data_type(),
            field.data_type()
        ))),
    }
}
/// Rebuilds `batch` so that it conforms to `schema`.
///
/// A batch whose schema already equals `schema` is returned untouched.
/// Otherwise each field of `schema` is looked up in the batch by name and the
/// matching column is passed through `coerce_array`; the resulting columns are
/// assembled, in schema order, into a new batch.
///
/// # Errors
///
/// Returns [`ArrowError::SchemaError`] when a field of `schema` has no column
/// of the same name in `batch`, and propagates errors from `coerce_array` and
/// [`RecordBatch::try_new`].
fn coerce_schema_batch(
    batch: RecordBatch,
    schema: Arc<Schema>,
) -> std::result::Result<RecordBatch, ArrowError> {
    if batch.schema() == schema {
        return Ok(batch);
    }

    let mut coerced = Vec::with_capacity(schema.fields().len());
    for target in schema.fields().iter() {
        let Some(source) = batch.column_by_name(target.name()) else {
            return Err(ArrowError::SchemaError(format!(
                "Column {} not found",
                target.name()
            )));
        };
        coerced.push(coerce_array(source, target)?);
    }

    RecordBatch::try_new(schema, coerced)
}
/// Wraps `reader` in a reader that reports `schema` and yields batches
/// conforming to it.
///
/// When the reader's schema already equals `schema`, its batches are passed
/// through unchanged. Otherwise every batch is run through
/// `coerce_schema_batch` as it is pulled, so coercion failures surface as
/// `Err` items of the returned reader rather than from this function.
pub fn coerce_schema(
    reader: impl RecordBatchReader + Send + 'static,
    schema: Arc<Schema>,
) -> Result<Box<dyn RecordBatchReader + Send>> {
    let coerced: Box<dyn RecordBatchReader + Send> = if reader.schema() == schema {
        Box::new(RecordBatchIterator::new(reader, schema))
    } else {
        // Pair every incoming batch with its own handle to the target schema.
        let target = Arc::clone(&schema);
        let targets = repeat_with(move || Arc::clone(&target));
        let batches = reader
            .zip(targets)
            .map(|(batch, target)| batch.and_then(|b| coerce_schema_batch(b, target)));
        Box::new(RecordBatchIterator::new(batches, schema))
    };
    Ok(coerced)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use arrow_array::{
        FixedSizeListArray, Float16Array, Float32Array, Float64Array, Int8Array, Int32Array,
        RecordBatch, RecordBatchIterator, StringArray,
    };
    use arrow_schema::Field;
    use half::f16;
    use lance_arrow::FixedSizeListArrayExt;

    /// Number of elements in every vector of the `fl` column.
    const DIM: i32 = 64;

    /// Builds the four-column schema shared by the input and expected batches;
    /// only the vector item type and the `f` / `i` column types vary.
    fn schema_of(item: DataType, float: DataType, int: DataType) -> Arc<Schema> {
        let vector = DataType::FixedSizeList(Arc::new(Field::new("item", item, true)), DIM);
        Arc::new(Schema::new(vec![
            Field::new("fl", vector, true),
            Field::new("s", DataType::Utf8, true),
            Field::new("f", float, true),
            Field::new("i", int, true),
        ]))
    }

    /// The string column, which is identical before and after coercion.
    fn words() -> Arc<dyn Array> {
        Arc::new(StringArray::from(vec![
            Some("hello"),
            Some("world"),
            Some("from"),
            Some("lance"),
        ]))
    }

    /// Coerces a batch with a fixed size list of f32 vectors, an f16 column
    /// and an i32 column into fixed size list of f16, f64 and i8 respectively,
    /// and checks both the resulting schema and the data.
    #[test]
    fn test_coerce_list_to_fixed_size_list() {
        // Four rows of DIM-wide vectors.
        let total = DIM * 4;

        let input_schema = schema_of(DataType::Float32, DataType::Float16, DataType::Int32);
        let input_columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(
                FixedSizeListArray::try_new_from_values(
                    Float32Array::from_iter_values((0..total).map(|v| v as f32)),
                    DIM,
                )
                .unwrap(),
            ),
            words(),
            Arc::new(Float16Array::from_iter_values(
                (0..4).map(|v| f16::from_f32(v as f32)),
            )),
            Arc::new(Int32Array::from_iter_values(0..4)),
        ];
        let input = RecordBatch::try_new(input_schema.clone(), input_columns).unwrap();
        let reader = RecordBatchIterator::new(vec![Ok(input)], input_schema);

        let target_schema = schema_of(DataType::Float16, DataType::Float64, DataType::Int8);
        let coerced = coerce_schema(reader, target_schema.clone())
            .unwrap()
            .collect::<Vec<_>>();
        assert_eq!(coerced.len(), 1);
        let actual = coerced[0].as_ref().unwrap();
        assert_eq!(actual.schema(), target_schema);

        let expected_columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(
                FixedSizeListArray::try_new_from_values(
                    Float16Array::from_iter_values((0..total).map(|v| f16::from_f32(v as f32))),
                    DIM,
                )
                .unwrap(),
            ),
            words(),
            Arc::new(Float64Array::from_iter_values((0..4).map(|v| v as f64))),
            Arc::new(Int8Array::from_iter_values(0..4)),
        ];
        let expected = RecordBatch::try_new(target_schema, expected_columns).unwrap();
        assert_eq!(actual, &expected);
    }
}