use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use itertools::Itertools;
use tracing::warn;
use crate::arrow::array::types::*;
use crate::arrow::array::{
self as arrow_array, make_array, new_null_array, Array, ArrayBuilder, ArrayData, ArrayRef,
AsArray, BooleanArray, Datum, MapArray, MutableArrayData, NullBufferBuilder, RecordBatch,
StringArray, StructArray,
};
use crate::arrow::buffer::{NullBuffer, OffsetBuffer};
use crate::arrow::compute::kernels::cmp::{distinct, eq, gt, gt_eq, lt, lt_eq, neq, not_distinct};
use crate::arrow::compute::kernels::comparison::in_list_utf8;
use crate::arrow::compute::kernels::numeric::{add, div, mul, sub};
use crate::arrow::compute::{and_kleene, cast, is_not_null, is_null, not, or_kleene};
use crate::arrow::datatypes::{
DataType as ArrowDataType, Field as ArrowField, Fields as ArrowFields, IntervalUnit,
Schema as ArrowSchema, TimeUnit,
};
use crate::arrow::error::ArrowError;
use crate::arrow::json::writer::{make_encoder, EncoderOptions};
use crate::arrow::json::StructMode;
use crate::engine::arrow_conversion::{TryFromKernel, TryIntoArrow};
use crate::engine::arrow_expression::opaque::{
ArrowOpaqueExpressionOpAdaptor, ArrowOpaquePredicateOpAdaptor,
};
use crate::engine::arrow_utils::parse_json_impl;
use crate::engine::arrow_utils::prim_array_cmp;
use crate::engine::ensure_data_types::{ensure_data_types, ValidationMode};
use crate::error::{DeltaResult, Error};
use crate::expressions::{
BinaryExpression, BinaryExpressionOp, BinaryPredicate, BinaryPredicateOp, Expression,
ExpressionRef, JunctionPredicate, JunctionPredicateOp, OpaqueExpression, OpaquePredicate,
Predicate, Scalar, Transform, UnaryExpression, UnaryExpressionOp, UnaryPredicate,
UnaryPredicateOp, VariadicExpression, VariadicExpressionOp,
};
use crate::schema::{DataType, PrimitiveType, StructField, StructType};
/// Abstracts over containers that expose named top-level columns ([`RecordBatch`] and
/// [`StructArray`]) so that column extraction can recurse uniformly through either.
pub(super) trait ProvidesColumnByName {
    /// Returns the arrow fields describing this container's top-level columns.
    fn schema_fields(&self) -> &ArrowFields;
    /// Looks up a top-level column by name, returning `None` if no such column exists.
    fn column_by_name(&self, name: &str) -> Option<&ArrayRef>;
}
impl ProvidesColumnByName for RecordBatch {
    fn schema_fields(&self) -> &ArrowFields {
        self.schema_ref().fields()
    }
    fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
        // Delegates to RecordBatch's inherent method of the same name.
        self.column_by_name(name)
    }
}
impl ProvidesColumnByName for StructArray {
    fn schema_fields(&self) -> &ArrowFields {
        self.fields()
    }
    fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
        // Delegates to StructArray's inherent method of the same name.
        self.column_by_name(name)
    }
}
/// Walks a (possibly nested) column path `col` down through `parent`, returning the array at
/// the end of the path.
///
/// Intermediate path segments must resolve to struct columns. Returns a `SchemaError` when the
/// path is empty, a segment is missing, or a non-terminal segment is not a struct.
pub(super) fn extract_column(
    mut parent: &dyn ProvidesColumnByName,
    col: &[impl AsRef<str>],
) -> DeltaResult<ArrayRef> {
    let mut field_names = col.iter();
    let Some(field_name) = field_names.next() else {
        return Err(ArrowError::SchemaError("Empty column path".to_string()))?;
    };
    let mut field_name = field_name.as_ref();
    loop {
        let child = parent
            .column_by_name(field_name)
            .ok_or_else(|| ArrowError::SchemaError(format!("No such field: {field_name}")))?;
        // Remember which name produced `child` so the "Not a struct" error below names the
        // offending field rather than the next path segment (which `field_name` is about to
        // be advanced to).
        let child_name = field_name;
        field_name = match field_names.next() {
            Some(name) => name.as_ref(),
            None => return Ok(child.clone()),
        };
        parent = child
            .as_any()
            .downcast_ref::<StructArray>()
            .ok_or_else(|| ArrowError::SchemaError(format!("Not a struct: {child_name}")))?;
    }
}
/// Evaluates a `Struct(...)` expression: each child expression becomes one output column,
/// matched positionally against `output_schema`.
///
/// If `nullability_predicate` is given, it is evaluated to a boolean column and a row of the
/// resulting struct is null wherever that predicate is false OR null.
fn evaluate_struct_expression(
    fields: &[ExpressionRef],
    batch: &RecordBatch,
    output_schema: &StructType,
    nullability_predicate: Option<&ExpressionRef>,
) -> DeltaResult<ArrayRef> {
    // Positional matching requires the arities to agree exactly.
    if fields.len() != output_schema.num_fields() {
        return Err(Error::generic(format!(
            "Struct expression field count mismatch: {} fields in expression but {} in schema",
            fields.len(),
            output_schema.num_fields()
        )));
    }
    // Evaluate each child expression against its target field type.
    let output_cols: Vec<ArrayRef> = fields
        .iter()
        .zip(output_schema.fields())
        .map(|(expr, field)| evaluate_expression(expr, batch, Some(field.data_type())))
        .try_collect()?;
    // Output fields take name/nullability from the schema but the concrete arrow type from
    // the evaluated column.
    let output_fields: Vec<ArrowField> = output_cols
        .iter()
        .zip(output_schema.fields())
        .map(|(output_col, output_field)| {
            ArrowField::new(
                output_field.name(),
                output_col.data_type().clone(),
                output_field.nullable,
            )
        })
        .collect();
    let null_buffer = if let Some(predicate_expr) = nullability_predicate {
        let predicate_array = evaluate_expression(predicate_expr, batch, Some(&DataType::BOOLEAN))?;
        let bool_array = predicate_array
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| Error::generic("Nullability predicate must evaluate to boolean"))?;
        let values = bool_array.values();
        // AND the predicate's values with its own validity buffer so that a NULL predicate
        // result also marks the struct row as null.
        let combined = match bool_array.nulls() {
            Some(nulls) => values & nulls.inner(),
            None => values.clone(),
        };
        Some(NullBuffer::new(combined))
    } else {
        None
    };
    let data = StructArray::try_new(output_fields.into(), output_cols, null_buffer)?;
    Ok(Arc::new(data))
}
/// Evaluates a [`Transform`] expression: passes fields through from the input (either the
/// whole batch, or a nested struct selected by the transform's input path), applying per-field
/// replacements/insertions plus any prepended expressions, validated against `output_schema`.
fn evaluate_transform_expression(
    transform: &Transform,
    batch: &RecordBatch,
    output_schema: &StructType,
) -> DeltaResult<ArrayRef> {
    let mut used_field_transforms = 0;
    let mut output_cols = Vec::with_capacity(output_schema.num_fields());
    let mut output_schema_iter = output_schema.fields();
    // Pulls the next expected output type, failing if the schema has fewer fields than the
    // transform produces.
    let mut next_output_type = || {
        output_schema_iter
            .next()
            .map(|field| field.data_type())
            .ok_or_else(|| Error::generic("Too few fields in output schema"))
    };
    // Columns inserted before any input field.
    for expr in &transform.prepended_fields {
        output_cols.push(evaluate_expression(expr, batch, Some(next_output_type()?))?);
    }
    // Resolve the transform's input: a nested struct column if a path is given, else the batch.
    let source_array = transform
        .input_path()
        .map(|path| extract_column(batch, path))
        .transpose()?;
    // Preserve the source struct's row-level null mask on the output, if there is one.
    let source_null_buffer = source_array.as_ref().and_then(|arr| {
        arr.as_any()
            .downcast_ref::<StructArray>()
            .and_then(|s| s.nulls().cloned())
    });
    let source_data: &dyn ProvidesColumnByName = match source_array {
        Some(ref array) => array
            .as_any()
            .downcast_ref::<StructArray>()
            .ok_or_else(|| Error::generic("Input path must point to a struct"))?,
        None => batch,
    };
    for input_field in source_data.schema_fields() {
        let field_name: &str = input_field.name();
        let field_transform = transform.field_transforms.get(field_name);
        // Pass the input column through unless a transform replaces (or drops) it.
        if !field_transform.is_some_and(|t| t.is_replace) {
            output_cols.push(extract_column(source_data, &[field_name])?);
            // Consume the matching output schema slot for the passed-through column.
            let _ = next_output_type()?;
        }
        // Emit the expressions attached to this field (replacement and/or insertions after it).
        if let Some(field_transform) = field_transform {
            for expr in &field_transform.exprs {
                output_cols.push(evaluate_expression(expr, batch, Some(next_output_type()?))?);
            }
            used_field_transforms += 1;
        }
    }
    // Every non-optional field transform must have matched some input field name.
    let required_count = transform
        .field_transforms
        .values()
        .filter(|ft| !ft.optional)
        .count();
    if used_field_transforms < required_count {
        return Err(Error::generic(
            "Some non-optional field transforms reference invalid input field names",
        ));
    }
    // The output schema must be fully consumed, too.
    if output_schema_iter.next().is_some() {
        return Err(Error::generic("Too many fields in output schema"));
    }
    // Field names come from the output schema; nullability from the evaluated columns.
    let output_fields: Vec<ArrowField> = output_cols
        .iter()
        .zip(output_schema.fields())
        .map(|(output_col, output_field)| {
            ArrowField::new(
                output_field.name(),
                output_col.data_type().clone(),
                output_col.is_nullable(),
            )
        })
        .collect();
    let data = StructArray::try_new(output_fields.into(), output_cols, source_null_buffer)?;
    Ok(Arc::new(data))
}
/// Evaluates a kernel [`Expression`] against a [`RecordBatch`], producing one column of output.
///
/// `result_type` is optional: some expressions (`Transform`, `MapToStruct`, `Struct`) require
/// it to evaluate at all, and when present the produced array is validated against it.
pub fn evaluate_expression(
    expression: &Expression,
    batch: &RecordBatch,
    result_type: Option<&DataType>,
) -> DeltaResult<ArrayRef> {
    use BinaryExpressionOp::*;
    use Expression::*;
    use UnaryExpressionOp::*;
    use VariadicExpressionOp::*;
    match (expression, result_type) {
        // A scalar literal expands to a constant column, one entry per batch row.
        (Literal(scalar), _) => {
            validate_array_type(scalar.to_array(batch.num_rows())?, result_type)
        }
        (Column(name), _) => {
            let arr = extract_column(batch, name)?;
            // Column references validate types only (field names need not match).
            if let Some(expected) = result_type {
                ensure_data_types(expected, arr.data_type(), ValidationMode::TypesOnly)?;
            }
            Ok(arr)
        }
        (Struct(fields, nullability), Some(DataType::Struct(output_schema))) => {
            evaluate_struct_expression(fields, batch, output_schema, nullability.as_ref())
        }
        (Struct(..), dt) => Err(Error::Generic(format!(
            "Struct expression expects a DataType::Struct result, but got {dt:?}"
        ))),
        (Transform(transform), Some(DataType::Struct(output_schema))) => {
            evaluate_transform_expression(transform, batch, output_schema)
        }
        (Transform(_), _) => Err(Error::generic(
            "Data type is required to evaluate transform expressions",
        )),
        // Predicates always produce a boolean column.
        (Predicate(pred), None | Some(&DataType::BOOLEAN)) => {
            let result = evaluate_predicate(pred, batch, false)?;
            Ok(Arc::new(result))
        }
        (Predicate(_), Some(data_type)) => Err(Error::generic(format!(
            "Predicate evaluation produces boolean output, but caller expects {data_type:?}"
        ))),
        (Unary(UnaryExpression { op: ToJson, expr }), result_type) => match result_type {
            None | Some(&DataType::STRING) => {
                let input = evaluate_expression(expr, batch, None)?;
                Ok(to_json(&input)?)
            }
            Some(data_type) => Err(Error::generic(format!(
                "ToJson operator requires STRING output, but got {data_type:?}"
            ))),
        },
        (Binary(BinaryExpression { op, left, right }), _) => {
            let left_arr = evaluate_expression(left.as_ref(), batch, None)?;
            let right_arr = evaluate_expression(right.as_ref(), batch, None)?;
            // Arithmetic dispatches to the matching arrow compute kernel.
            type Operation = fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>;
            let eval: Operation = match op {
                Plus => add,
                Minus => sub,
                Multiply => mul,
                Divide => div,
            };
            validate_array_type(eval(&left_arr, &right_arr)?, result_type)
        }
        (
            Variadic(VariadicExpression {
                op: Coalesce,
                exprs,
            }),
            result_type,
        ) => {
            // Evaluate operands lazily: once an operand has no nulls, later operands can
            // never be selected, so evaluation stops early.
            let mut arrays: Vec<ArrayRef> = Vec::with_capacity(exprs.len());
            for expr in exprs {
                let array = evaluate_expression(expr, batch, result_type)?;
                let null_count = array.null_count();
                arrays.push(array);
                if null_count == 0 {
                    break;
                }
            }
            Ok(coalesce_arrays(&arrays, result_type)?)
        }
        (Opaque(OpaqueExpression { op, exprs }), _) => {
            // Only opaque ops registered through the arrow adaptor can be evaluated here.
            match op
                .any_ref()
                .downcast_ref::<ArrowOpaqueExpressionOpAdaptor>()
            {
                Some(op) => op.eval_expr(exprs, batch, result_type),
                None => Err(Error::unsupported(format!(
                    "Unsupported opaque expression: {op:?}"
                ))),
            }
        }
        (ParseJson(p), _) => {
            let json_arr = evaluate_expression(&p.json_expr, batch, Some(&DataType::STRING))?;
            let json_strings =
                json_arr
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .ok_or_else(|| {
                        Error::generic("ParseJson input must evaluate to a STRING column")
                    })?;
            let arrow_schema = Arc::new(ArrowSchema::try_from_kernel(p.output_schema.as_ref())?);
            // Parse failures are non-fatal: log a warning and fall back to an all-null struct.
            match parse_json_impl(json_strings, arrow_schema.clone()) {
                Ok(batch) => Ok(Arc::new(StructArray::from(batch)) as ArrayRef),
                Err(e) => {
                    warn!(
                        "Failed to parse JSON stats as {}: {e}. Using null stats.",
                        p.output_schema,
                    );
                    Ok(new_null_array(
                        &ArrowDataType::Struct(arrow_schema.fields().clone()),
                        json_strings.len(),
                    ))
                }
            }
        }
        (MapToStruct(m), Some(DataType::Struct(output_schema))) => {
            let map_arr = evaluate_expression(&m.map_expr, batch, None)?;
            let result = evaluate_map_to_struct(&map_arr, output_schema)?;
            Ok(Arc::new(result) as ArrayRef)
        }
        (MapToStruct(_), dt) => Err(Error::Generic(format!(
            "MapToStruct expression requires a DataType::Struct result type, but got {dt:?}"
        ))),
        (Unknown(name), _) => Err(Error::unsupported(format!("Unknown expression: {name:?}"))),
    }
}
/// Direction for converting string/binary data between arrow "view" representations
/// (Utf8View/BinaryView) and the traditional offset-based ones.
#[derive(Clone, Copy)]
enum ViewCast {
    ToView,
    ToNonView,
}
/// Casts the element type of a list-like array between view and non-view string/binary
/// representations. The outer container kind is preserved, except that `ToNonView` also
/// converts ListView/LargeListView containers to List/LargeList.
fn cast_list_elements(
    vals: &Arc<dyn Array>,
    field: &Arc<ArrowField>,
    dir: ViewCast,
) -> DeltaResult<Arc<dyn Array>> {
    // Determine the target element type; bail out (no-op) when no element conversion applies.
    let to_type = match dir {
        ViewCast::ToView => match field.data_type() {
            ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => ArrowDataType::Utf8View,
            ArrowDataType::Binary | ArrowDataType::LargeBinary => ArrowDataType::BinaryView,
            _ => return Ok(vals.clone()),
        },
        ViewCast::ToNonView => match field.data_type() {
            ArrowDataType::Utf8View => ArrowDataType::Utf8,
            ArrowDataType::BinaryView => ArrowDataType::Binary,
            other => {
                // Elements need no change, but a view *container* must still be converted
                // to its non-view counterpart below; otherwise there is nothing to do.
                if !matches!(
                    vals.data_type(),
                    ArrowDataType::ListView(_) | ArrowDataType::LargeListView(_)
                ) {
                    return Ok(vals.clone());
                }
                other.clone()
            }
        },
    };
    let new_field = Arc::new(field.as_ref().clone().with_data_type(to_type));
    // Rebuild the outer container type around the converted element field.
    let container = match (vals.data_type(), dir) {
        (ArrowDataType::List(_), _) => ArrowDataType::List(new_field),
        (ArrowDataType::LargeList(_), _) => ArrowDataType::LargeList(new_field),
        (ArrowDataType::ListView(_), ViewCast::ToView) => ArrowDataType::ListView(new_field),
        (ArrowDataType::ListView(_), ViewCast::ToNonView) => ArrowDataType::List(new_field),
        (ArrowDataType::LargeListView(_), ViewCast::ToView) => {
            ArrowDataType::LargeListView(new_field)
        }
        (ArrowDataType::LargeListView(_), ViewCast::ToNonView) => {
            ArrowDataType::LargeList(new_field)
        }
        (dt, _) => {
            return Err(Error::generic(format!(
                "cast_list_elements: expected a list type, got {dt:?}"
            )))
        }
    };
    Ok(cast(vals, &container)?)
}
/// Converts view-typed data (Utf8View/BinaryView, possibly nested inside a list container)
/// to its traditional offset-based counterpart; non-view inputs pass through untouched.
fn arrow_convert_to_non_view_type(vals: Arc<dyn Array>) -> DeltaResult<Arc<dyn Array>> {
    match vals.data_type() {
        // All four list-like containers share the same element-conversion path.
        ArrowDataType::List(field)
        | ArrowDataType::LargeList(field)
        | ArrowDataType::ListView(field)
        | ArrowDataType::LargeListView(field) => {
            cast_list_elements(&vals, field, ViewCast::ToNonView)
        }
        ArrowDataType::Utf8View => Ok(cast(&vals, &ArrowDataType::Utf8)?),
        ArrowDataType::BinaryView => Ok(cast(&vals, &ArrowDataType::Binary)?),
        _ => Ok(vals),
    }
}
/// Converts traditional string/binary data (possibly nested inside a list container) to the
/// corresponding arrow view representation; inputs with no view equivalent pass through.
fn arrow_convert_to_view_type(vals: Arc<dyn Array>) -> DeltaResult<Arc<dyn Array>> {
    match vals.data_type() {
        // All four list-like containers share the same element-conversion path.
        ArrowDataType::List(field)
        | ArrowDataType::LargeList(field)
        | ArrowDataType::ListView(field)
        | ArrowDataType::LargeListView(field) => {
            cast_list_elements(&vals, field, ViewCast::ToView)
        }
        ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => {
            Ok(cast(&vals, &ArrowDataType::Utf8View)?)
        }
        ArrowDataType::Binary | ArrowDataType::LargeBinary => {
            Ok(cast(&vals, &ArrowDataType::BinaryView)?)
        }
        _ => Ok(vals),
    }
}
/// Evaluates a kernel [`Predicate`] against a [`RecordBatch`], producing a boolean column.
///
/// `inverted` requests the logical negation of the predicate. The inversion is pushed down
/// into operators where possible (`<` becomes `>=`, AND/OR swap via De Morgan, IS NULL becomes
/// IS NOT NULL) so Kleene null semantics are handled by the kernels rather than a trailing NOT.
pub fn evaluate_predicate(
    predicate: &Predicate,
    batch: &RecordBatch,
    inverted: bool,
) -> DeltaResult<BooleanArray> {
    use BinaryPredicateOp::*;
    use Predicate::*;
    // Fallback for results whose operator could not absorb the inversion itself.
    let maybe_inverted = |result: Cow<'_, BooleanArray>| match inverted {
        true => not(&result),
        false => Ok(result.into_owned()),
    };
    match predicate {
        BooleanExpression(expr) => {
            let arr = evaluate_expression(expr, batch, Some(&DataType::BOOLEAN))?;
            match arr.as_any().downcast_ref::<BooleanArray>() {
                Some(arr) => Ok(maybe_inverted(Cow::Borrowed(arr))?),
                None => Err(Error::generic("expected boolean array")),
            }
        }
        // NOT just flips the inversion flag for the child predicate.
        Not(pred) => evaluate_predicate(pred, batch, !inverted),
        Unary(UnaryPredicate { op, expr }) => {
            let arr = evaluate_expression(expr.as_ref(), batch, None)?;
            let eval_op_fn = match (op, inverted) {
                (UnaryPredicateOp::IsNull, false) => is_null,
                (UnaryPredicateOp::IsNull, true) => is_not_null,
            };
            Ok(eval_op_fn(&arr)?)
        }
        Binary(BinaryPredicate { op, left, right }) => {
            let (left, right) = (left.as_ref(), right.as_ref());
            // IN-list evaluation; only invoked by the `In` arm of `eval_fn` below.
            let eval_in = || match (left, right) {
                (Expression::Literal(_), Expression::Column(_)) => {
                    // Comparison kernels below want non-view representations.
                    let left = evaluate_expression(left, batch, None)?;
                    let left = arrow_convert_to_non_view_type(left)?;
                    let right = evaluate_expression(right, batch, None)?;
                    let right = arrow_convert_to_non_view_type(right)?;
                    // Fast path: string needle against a list-of-strings column.
                    if let Some(string_arr) = left.as_string_opt::<i32>() {
                        if let Some(list_arr) = right.as_list_opt::<i32>() {
                            if list_arr.value_type() == ArrowDataType::Utf8 {
                                let result = in_list_utf8(string_arr, list_arr)?;
                                return Ok(result);
                            }
                        }
                    }
                    use ArrowDataType::*;
                    // Generic path: macro-expanded per-primitive-type IN comparison.
                    prim_array_cmp! {
                        left, right,
                        (Int8, Int8Type),
                        (Int16, Int16Type),
                        (Int32, Int32Type),
                        (Int64, Int64Type),
                        (UInt8, UInt8Type),
                        (UInt16, UInt16Type),
                        (UInt32, UInt32Type),
                        (UInt64, UInt64Type),
                        (Float16, Float16Type),
                        (Float32, Float32Type),
                        (Float64, Float64Type),
                        (Timestamp(TimeUnit::Second, _), TimestampSecondType),
                        (Timestamp(TimeUnit::Millisecond, _), TimestampMillisecondType),
                        (Timestamp(TimeUnit::Microsecond, _), TimestampMicrosecondType),
                        (Timestamp(TimeUnit::Nanosecond, _), TimestampNanosecondType),
                        (Date32, Date32Type),
                        (Date64, Date64Type),
                        (Time32(TimeUnit::Second), Time32SecondType),
                        (Time32(TimeUnit::Millisecond), Time32MillisecondType),
                        (Time64(TimeUnit::Microsecond), Time64MicrosecondType),
                        (Time64(TimeUnit::Nanosecond), Time64NanosecondType),
                        (Duration(TimeUnit::Second), DurationSecondType),
                        (Duration(TimeUnit::Millisecond), DurationMillisecondType),
                        (Duration(TimeUnit::Microsecond), DurationMicrosecondType),
                        (Duration(TimeUnit::Nanosecond), DurationNanosecondType),
                        (Interval(IntervalUnit::DayTime), IntervalDayTimeType),
                        (Interval(IntervalUnit::YearMonth), IntervalYearMonthType),
                        (Interval(IntervalUnit::MonthDayNano), IntervalMonthDayNanoType),
                        (Decimal128(_, _), Decimal128Type),
                        (Decimal256(_, _), Decimal256Type)
                    }
                }
                // Both sides literal: a single constant membership test.
                (Expression::Literal(lit), Expression::Literal(Scalar::Array(ad))) => {
                    let exists = ad.array_elements().contains(lit);
                    Ok(BooleanArray::from(vec![exists]))
                }
                (l, r) => Err(Error::invalid_expression(format!(
                    "Invalid right value for (NOT) IN comparison, left is: {l} right is: {r}"
                ))),
            };
            // Pick the comparison kernel, absorbing inversion into the operator where possible.
            let eval_fn = match (op, inverted) {
                (LessThan, false) => lt,
                (LessThan, true) => gt_eq,
                (GreaterThan, false) => gt,
                (GreaterThan, true) => lt_eq,
                (Equal, false) => eq,
                (Equal, true) => neq,
                (Distinct, false) => distinct,
                (Distinct, true) => not_distinct,
                // IN cannot absorb the inversion; apply NOT afterwards.
                (In, _) => return Ok(maybe_inverted(Cow::Owned(eval_in()?))?),
            };
            let left = evaluate_expression(left, batch, None)?;
            let right = evaluate_expression(right, batch, None)?;
            // Mismatched types (e.g. Utf8 vs Utf8View) are reconciled by converting both
            // sides to the view representation before comparing.
            let (left, right) = if left.data_type() == right.data_type() {
                (left, right)
            } else {
                (
                    arrow_convert_to_view_type(left)?,
                    arrow_convert_to_view_type(right)?,
                )
            };
            Ok(eval_fn(&left, &right)?)
        }
        Junction(JunctionPredicate { op, preds }) => {
            use JunctionPredicateOp::*;
            type Operation = fn(&BooleanArray, &BooleanArray) -> Result<BooleanArray, ArrowError>;
            // De Morgan: an inverted AND becomes an OR of inverted children (and vice versa).
            // `default` is the value of the empty junction (AND -> true, OR -> false).
            let (reducer, default): (Operation, _) = match (op, inverted) {
                (And, false) | (Or, true) => (and_kleene, true),
                (Or, false) | (And, true) => (or_kleene, false),
            };
            preds
                .iter()
                .map(|pred| evaluate_predicate(pred, batch, inverted))
                .reduce(|l, r| Ok(reducer(&l?, &r?)?))
                .unwrap_or_else(|| Ok(BooleanArray::from(vec![default; batch.num_rows()])))
        }
        Opaque(OpaquePredicate { op, exprs }) => {
            // Only opaque ops registered through the arrow adaptor can be evaluated here.
            match op.any_ref().downcast_ref::<ArrowOpaquePredicateOpAdaptor>() {
                Some(op) => op.eval_pred(exprs, batch, inverted),
                None => Err(Error::unsupported(format!(
                    "Unsupported opaque predicate: {op:?}"
                ))),
            }
        }
        Unknown(name) => Err(Error::unsupported(format!("Unknown predicate: {name:?}"))),
    }
}
/// Serializes each row of a struct array into a JSON object string.
///
/// Null input rows produce null output strings. Only struct-typed inputs are supported;
/// anything else yields an `InvalidArgumentError`.
pub fn to_json(input: &dyn Datum) -> Result<ArrayRef, ArrowError> {
    let (array_ref, _is_scalar) = input.get();
    match array_ref.data_type() {
        ArrowDataType::Struct(_) => {
            let struct_array = array_ref.as_struct_opt().ok_or_else(|| {
                ArrowError::InvalidArgumentError(format!(
                    "Failed to convert {} to StructArray",
                    array_ref.data_type(),
                ))
            })?;
            let num_rows = struct_array.len();
            if num_rows == 0 {
                return Ok(Arc::new(StringArray::from(Vec::<Option<String>>::new())));
            }
            // Wrap the struct's fields in a synthetic nullable "root" field for the encoder.
            let field = Arc::new(ArrowField::new_struct(
                "root",
                struct_array.fields().iter().cloned().collect_vec(),
                true,
            ));
            // ObjectOnly: encode each struct row as a JSON object.
            let options = EncoderOptions::default().with_struct_mode(StructMode::ObjectOnly);
            let mut encoder = make_encoder(&field, struct_array, &options)?;
            // Rough per-row size guess to avoid repeated reallocation of the byte buffer.
            const ROW_SIZE_ESTIMATE: usize = 64;
            let mut data = Vec::with_capacity(num_rows * ROW_SIZE_ESTIMATE);
            let mut offsets = Vec::with_capacity(num_rows + 1);
            offsets.push(0);
            let mut nulls = NullBufferBuilder::new(num_rows);
            for i in 0..num_rows {
                if struct_array.is_null(i) {
                    nulls.append_null();
                } else {
                    encoder.encode(i, &mut data);
                    nulls.append_non_null();
                }
                // An offset is pushed for every row, so null rows get an empty slice.
                let offset = i32::try_from(data.len()).map_err(|_| {
                    ArrowError::InvalidArgumentError("Failed to convert offset".to_string())
                })?;
                offsets.push(offset);
            }
            let array = StringArray::try_new(
                OffsetBuffer::new(offsets.into()),
                data.into(),
                nulls.finish(),
            )?;
            Ok(Arc::new(array))
        }
        _ => Err(ArrowError::InvalidArgumentError(format!(
            "TO_JSON can only be applied to struct arrays, got {:?}",
            array_ref.data_type()
        ))),
    }
}
/// COALESCE across equal-length, same-typed arrays: for each row, picks the first non-null
/// value among `arrays`, producing null only when every array is null at that row.
///
/// When `result_type` is given, the arrays' common type must match it exactly. Errors on an
/// empty input slice, or on any length/type mismatch between arrays.
pub fn coalesce_arrays(
    arrays: &[ArrayRef],
    result_type: Option<&DataType>,
) -> Result<ArrayRef, ArrowError> {
    let Some((first, rest)) = arrays.split_first() else {
        return Err(ArrowError::InvalidArgumentError(
            "The default engine currently does not support empty COALESCE statements".into(),
        ));
    };
    // If a result type was requested, the first array must already match it exactly.
    if let Some(result_type) = result_type {
        let result_type = result_type.try_into_arrow()?;
        if first.data_type() != &result_type {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Requested result type {result_type:?} does not match arrays' data type {:?}",
                first.data_type()
            )));
        }
    }
    // Single-operand coalesce is the identity (nulls and all).
    if rest.is_empty() {
        return Ok(first.clone());
    }
    // Every remaining array must agree with the first in both length and type.
    for (offset, candidate) in rest.iter().enumerate() {
        if candidate.len() != first.len() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Array at index {} has length {}, expected {}",
                offset + 1,
                candidate.len(),
                first.len()
            )));
        }
        if candidate.data_type() != first.data_type() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Array at index {} has type {:?}, but expected {:?}",
                offset + 1,
                candidate.data_type(),
                first.data_type()
            )));
        }
    }
    // Copy row-by-row from whichever source array is first to hold a non-null value.
    let backing: Vec<ArrayData> = arrays.iter().map(|arr| arr.to_data()).collect();
    let mut builder = MutableArrayData::new(backing.iter().collect(), false, first.len());
    for row in 0..first.len() {
        match arrays.iter().position(|arr| arr.is_valid(row)) {
            Some(source_idx) => builder.extend(source_idx, row, row + 1),
            None => builder.extend_nulls(1),
        }
    }
    Ok(make_array(builder.freeze()))
}
/// Converts a string-keyed, string-valued [`MapArray`] into a [`StructArray`]: for each output
/// field, the map entry whose key matches the field name is parsed into the field's primitive
/// type; missing or null entries produce null.
fn evaluate_map_to_struct(
    map_arr: &ArrayRef,
    output_schema: &StructType,
) -> DeltaResult<StructArray> {
    let map_array = map_arr
        .as_any()
        .downcast_ref::<MapArray>()
        .ok_or_else(|| Error::generic("MapToStruct requires a MapArray as input"))?;
    let map_keys = map_array
        .keys()
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| Error::generic("MapToStruct requires maps with string keys"))?;
    let map_values = map_array
        .values()
        .as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| Error::generic("MapToStruct requires maps with string values"))?;
    let num_rows = map_array.len();
    let fields: Vec<&StructField> = output_schema.fields().collect();
    // One builder and one target primitive type per output field; only primitive targets
    // are supported.
    let mut builders: Vec<Box<dyn ArrayBuilder>> = Vec::with_capacity(fields.len());
    let mut target_types: Vec<&PrimitiveType> = Vec::with_capacity(fields.len());
    for field in &fields {
        let prim = match field.data_type() {
            DataType::Primitive(p) => p,
            other => {
                return Err(Error::generic(format!(
                    "MapToStruct only supports primitive target types, got {other:?}"
                )));
            }
        };
        target_types.push(prim);
        let arrow_type = ArrowDataType::try_from_kernel(field.data_type())?;
        builders.push(arrow_array::make_builder(&arrow_type, num_rows));
    }
    // Reverse lookup: field name -> output column index.
    let field_indices: HashMap<&str, usize> = HashMap::from_iter(
        fields
            .iter()
            .enumerate()
            .map(|(i, f)| (f.name().as_str(), i)),
    );
    // Latest matching map-entry index per field. Deliberately NOT reset between rows: stale
    // matches from earlier rows are rejected below because their index falls before the
    // current row's entry_start (offsets increase monotonically). -1 is the "never matched"
    // sentinel.
    let mut matched_entry_idx: Vec<i32> = vec![-1; fields.len()];
    let offsets = map_array.value_offsets();
    let mut entry_end = offsets[0];
    for row in 0..num_rows {
        let entry_start = entry_end;
        entry_end = offsets[row + 1];
        if map_array.is_valid(row) {
            for entry_idx in entry_start..entry_end {
                let key = map_keys.value(entry_idx as usize);
                if let Some(&i) = field_indices.get(key) {
                    // Later duplicate keys within a row overwrite earlier ones.
                    matched_entry_idx[i] = entry_idx;
                }
            }
        }
        for (i, field) in fields.iter().enumerate() {
            let entry_idx = matched_entry_idx[i];
            let builder = builders[i].as_mut();
            // `entry_idx >= entry_start` filters out both the -1 sentinel and matches that
            // belong to previous rows.
            if entry_idx >= entry_start && map_values.is_valid(entry_idx as usize) {
                let raw = map_values.value(entry_idx as usize);
                let scalar = target_types[i].parse_scalar(raw)?;
                scalar.append_to(builder, 1)?;
            } else {
                Scalar::append_null(builder, field.data_type(), 1)?;
            }
        }
    }
    let output_columns: Vec<ArrayRef> = builders.iter_mut().map(|b| b.finish()).collect();
    let arrow_fields: Vec<ArrowField> = fields
        .iter()
        .map(|f| ArrowField::try_from_kernel(*f))
        .try_collect()?;
    Ok(StructArray::try_new(
        arrow_fields.into(),
        output_columns,
        None,
    )?)
}
/// Checks that `array`'s arrow type matches the `expected` kernel type (names included) when
/// one is given, then passes the array through unchanged.
fn validate_array_type(array: ArrayRef, expected: Option<&DataType>) -> DeltaResult<ArrayRef> {
    match expected {
        Some(expected_type) => {
            ensure_data_types(expected_type, array.data_type(), ValidationMode::TypesAndNames)?;
            Ok(array)
        }
        None => Ok(array),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::arrow::array::{
ArrayRef, BooleanArray, Int32Array, Int64Array, StringArray, StructArray,
};
use crate::arrow::datatypes::{
DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema,
};
use crate::expressions::column_expr;
use crate::expressions::{column_expr_ref, BinaryExpressionOp, Expression as Expr, Transform};
use crate::schema::{DataType, StructField, StructType};
use crate::utils::test_utils::assert_result_error_with_message;
use rstest::rstest;
use std::sync::Arc;
fn create_test_batch() -> RecordBatch {
let schema = ArrowSchema::new(vec![
ArrowField::new("a", ArrowDataType::Int32, false),
ArrowField::new("b", ArrowDataType::Int32, false),
ArrowField::new("c", ArrowDataType::Int32, false),
]);
let a_values = Int32Array::from(vec![1, 2, 3]);
let b_values = Int32Array::from(vec![10, 20, 30]);
let c_values = Int32Array::from(vec![100, 200, 300]);
RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(a_values), Arc::new(b_values), Arc::new(c_values)],
)
.unwrap()
}
/// Asserts that column `idx` of `result` is an Int32Array with exactly `expected` values.
fn validate_i32_column(result: &StructArray, idx: usize, expected: &[i32]) {
    let column = result.column(idx);
    let ints = column.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(ints.values(), expected);
}
/// Builds a batch with an int32 column "a" = [100,200,300] and a struct column "nested"
/// whose children are x = [1,2,3] and y = [10,20,30].
fn create_nested_test_batch() -> RecordBatch {
    let inner_schema = ArrowSchema::new(vec![
        ArrowField::new("x", ArrowDataType::Int32, false),
        ArrowField::new("y", ArrowDataType::Int32, false),
    ]);
    let nested_type = ArrowDataType::Struct(inner_schema.fields().clone());
    let schema = ArrowSchema::new(vec![
        ArrowField::new("a", ArrowDataType::Int32, false),
        ArrowField::new("nested", nested_type, false),
    ]);
    let x_values = Int32Array::from(vec![1, 2, 3]);
    let y_values = Int32Array::from(vec![10, 20, 30]);
    let nested_struct = StructArray::from(vec![
        (
            Arc::new(ArrowField::new("x", ArrowDataType::Int32, false)),
            Arc::new(x_values) as ArrayRef,
        ),
        (
            Arc::new(ArrowField::new("y", ArrowDataType::Int32, false)),
            Arc::new(y_values) as ArrayRef,
        ),
    ]);
    let a_values = Int32Array::from(vec![100, 200, 300]);
    RecordBatch::try_new(
        Arc::new(schema),
        vec![Arc::new(a_values), Arc::new(nested_struct)],
    )
    .unwrap()
}
/// An empty transform must pass every input column through unchanged, both at top level and
/// when the transform targets a nested struct via an input path.
#[test]
fn test_identity_transforms() {
    let batch = create_test_batch();
    let transform = Transform::new_top_level();
    let output_schema = StructType::new_unchecked(vec![
        StructField::new("a", DataType::INTEGER, false),
        StructField::new("b", DataType::INTEGER, false),
        StructField::new("c", DataType::INTEGER, false),
    ]);
    let expr = Expr::Transform(transform);
    let result = evaluate_expression(
        &expr,
        &batch,
        Some(&DataType::Struct(Box::new(output_schema))),
    )
    .unwrap();
    let struct_result = result.as_any().downcast_ref::<StructArray>().unwrap();
    // Each output column must be the very same data as the input column.
    for i in 0..3 {
        assert_eq!(struct_result.column(i).as_ref(), batch.column(i).as_ref());
    }
    // Same check, but with the transform rooted at the "nested" struct column.
    let nested_batch = create_nested_test_batch();
    let transform_nested = Transform::new_nested(["nested"]);
    let nested_output_schema = StructType::new_unchecked(vec![
        StructField::new("x", DataType::INTEGER, false),
        StructField::new("y", DataType::INTEGER, false),
    ]);
    let expr_nested = Expr::Transform(transform_nested);
    let result_nested = evaluate_expression(
        &expr_nested,
        &nested_batch,
        Some(&DataType::Struct(Box::new(nested_output_schema))),
    )
    .unwrap();
    let original_nested = nested_batch
        .column_by_name("nested")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let nested_result = result_nested
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    for i in 0..2 {
        assert_eq!(
            nested_result.column(i).as_ref(),
            original_nested.column(i).as_ref()
        );
    }
}
/// Exercises replace, drop, prepend, and multiple insertions after one anchor field, and
/// checks the resulting column order and values.
#[test]
fn test_field_operations_and_multiple_insertions() {
    let batch = create_test_batch();
    // a -> replaced by column b; b -> dropped; three columns prepended; three inserted after c.
    let transform = Transform::new_top_level()
        .with_replaced_field("a", column_expr_ref!("b"))
        .with_dropped_field("b")
        .with_inserted_field(None::<&str>, Expr::literal(1).into())
        .with_inserted_field(None::<&str>, Expr::literal(2).into())
        .with_inserted_field(None::<&str>, column_expr_ref!("c"))
        .with_inserted_field(Some("c"), Expr::literal(42).into())
        .with_inserted_field(Some("c"), column_expr_ref!("a"))
        .with_inserted_field(Some("c"), Expr::literal(99).into());
    let output_schema = StructType::new_unchecked(vec![
        StructField::new("pre1", DataType::INTEGER, false),
        StructField::new("pre2", DataType::INTEGER, false),
        StructField::new("pre3", DataType::INTEGER, false),
        StructField::new("a", DataType::INTEGER, false),
        StructField::new("c", DataType::INTEGER, false),
        StructField::new("after_c1", DataType::INTEGER, false),
        StructField::new("after_c2", DataType::INTEGER, false),
        StructField::new("after_c3", DataType::INTEGER, false),
    ]);
    let expr = Expr::Transform(transform);
    let result = evaluate_expression(
        &expr,
        &batch,
        Some(&DataType::Struct(Box::new(output_schema))),
    )
    .unwrap();
    let struct_result = result.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(struct_result.num_columns(), 8);
    assert_eq!(struct_result.len(), 3);
    // Prepended literals/column, replaced a (= b's data), passed-through c, then insertions.
    validate_i32_column(struct_result, 0, &[1, 1, 1]);
    validate_i32_column(struct_result, 1, &[2, 2, 2]);
    validate_i32_column(struct_result, 2, &[100, 200, 300]);
    validate_i32_column(struct_result, 3, &[10, 20, 30]);
    validate_i32_column(struct_result, 4, &[100, 200, 300]);
    validate_i32_column(struct_result, 5, &[42, 42, 42]);
    validate_i32_column(struct_result, 6, &[1, 2, 3]);
    validate_i32_column(struct_result, 7, &[99, 99, 99]);
}
/// Transforms rooted at a nested struct: identity copy first, then a replace + insert combo.
#[test]
fn test_nested_path_transforms() {
    let nested_batch = create_nested_test_batch();
    // Identity transform over the nested struct copies its columns untouched.
    let transform_copy = Transform::new_nested(["nested"]);
    let copy_output_schema = StructType::new_unchecked(vec![
        StructField::new("x", DataType::INTEGER, false),
        StructField::new("y", DataType::INTEGER, false),
    ]);
    let expr_copy = Expr::Transform(transform_copy);
    let result_copy = evaluate_expression(
        &expr_copy,
        &nested_batch,
        Some(&DataType::Struct(Box::new(copy_output_schema))),
    )
    .unwrap();
    let copy_result = result_copy.as_any().downcast_ref::<StructArray>().unwrap();
    let original_nested = nested_batch
        .column_by_name("nested")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    for i in 0..2 {
        assert_eq!(
            copy_result.column(i).as_ref(),
            original_nested.column(i).as_ref()
        );
    }
    // Replace nested.x with a literal and insert a new field after nested.y.
    let transform_modify = Transform::new_nested(["nested"])
        .with_replaced_field("x".to_string(), Expr::literal(777).into())
        .with_inserted_field(Some("y"), Expr::literal(555).into());
    let modify_output_schema = StructType::new_unchecked(vec![
        StructField::new("x", DataType::INTEGER, false),
        StructField::new("y", DataType::INTEGER, false),
        StructField::new("new_field", DataType::INTEGER, false),
    ]);
    let expr_modify = Expr::Transform(transform_modify);
    let result_modify = evaluate_expression(
        &expr_modify,
        &nested_batch,
        Some(&DataType::Struct(Box::new(modify_output_schema))),
    )
    .unwrap();
    let modify_result = result_modify
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    assert_eq!(modify_result.num_columns(), 3);
    assert_eq!(modify_result.len(), 3);
    validate_i32_column(modify_result, 0, &[777, 777, 777]);
    validate_i32_column(modify_result, 1, &[10, 20, 30]);
    validate_i32_column(modify_result, 2, &[555, 555, 555]);
}
/// Error paths: transforms referencing missing input fields, output schemas with too many or
/// too few fields, and evaluation without a result type.
#[test]
fn test_transform_validation() {
    let batch = create_test_batch();
    // Replacing a field that does not exist in the input must fail.
    let transform =
        Transform::new_top_level().with_replaced_field("missing", Expr::literal(1).into());
    let output_schema = StructType::new_unchecked(vec![
        StructField::not_null("a", DataType::INTEGER),
        StructField::not_null("b", DataType::INTEGER),
        StructField::not_null("c", DataType::INTEGER),
    ]);
    let expr = Expr::Transform(transform);
    let result = evaluate_expression(
        &expr,
        &batch,
        Some(&DataType::Struct(Box::new(output_schema.clone()))),
    );
    assert!(result
        .unwrap_err()
        .to_string()
        .contains("reference invalid input field names"));
    // Inserting after a nonexistent anchor field must fail the same way.
    let transform2 = Transform::new_top_level()
        .with_inserted_field(Some("nonexistent"), Expr::literal(1).into());
    let expr2 = Expr::Transform(transform2);
    let result2 = evaluate_expression(
        &expr2,
        &batch,
        Some(&DataType::Struct(Box::new(output_schema.clone()))),
    );
    assert!(result2.is_err());
    assert!(result2
        .unwrap_err()
        .to_string()
        .contains("reference invalid input field names"));
    // Dropping a field while the schema still lists it -> schema has too many fields.
    let transform3 = Transform::new_top_level().with_dropped_field("a");
    let wrong_output_schema = StructType::new_unchecked(vec![
        StructField::not_null("a", DataType::INTEGER),
        StructField::not_null("b", DataType::INTEGER),
        StructField::not_null("c", DataType::INTEGER),
    ]);
    let expr3 = Expr::Transform(transform3);
    let result3 = evaluate_expression(
        &expr3,
        &batch,
        Some(&DataType::Struct(Box::new(wrong_output_schema))),
    );
    assert!(result3.is_err());
    assert!(result3
        .unwrap_err()
        .to_string()
        .contains("Too many fields in output schema"));
    // Schema missing a passed-through field -> too few fields.
    let transform3 = Transform::new_top_level().with_dropped_field("a");
    let wrong_output_schema =
        StructType::new_unchecked(vec![StructField::not_null("c", DataType::INTEGER)]);
    let expr3 = Expr::Transform(transform3);
    let result3 = evaluate_expression(
        &expr3,
        &batch,
        Some(&DataType::Struct(Box::new(wrong_output_schema))),
    );
    assert!(result3.is_err());
    assert!(result3
        .unwrap_err()
        .to_string()
        .contains("Too few fields in output schema"));
    // Transforms cannot be evaluated without a result type at all.
    let transform4 = Transform::new_top_level();
    let expr4 = Expr::Transform(transform4);
    let result4 = evaluate_expression(&expr4, &batch, None);
    assert!(result4.is_err());
    assert!(result4
        .unwrap_err()
        .to_string()
        .contains("Data type is required"));
}
#[test]
fn test_drop_field_if_exists_present() {
    // Optionally dropping an existing field behaves like an ordinary drop.
    let batch = create_test_batch();
    let expr = Expr::Transform(Transform::new_top_level().with_dropped_field_if_exists("a"));
    let remaining = StructType::new_unchecked(vec![
        StructField::not_null("b", DataType::INTEGER),
        StructField::not_null("c", DataType::INTEGER),
    ]);
    let evaluated = evaluate_expression(
        &expr,
        &batch,
        Some(&DataType::Struct(Box::new(remaining))),
    )
    .unwrap();
    let structs = evaluated.as_any().downcast_ref::<StructArray>().unwrap();
    validate_i32_column(structs, 0, &[10, 20, 30]);
    validate_i32_column(structs, 1, &[100, 200, 300]);
}
#[test]
fn test_drop_field_if_exists_missing() {
    // Optionally dropping a missing field is a no-op: all columns pass through.
    let batch = create_test_batch();
    let expr = Expr::Transform(
        Transform::new_top_level().with_dropped_field_if_exists("nonexistent"),
    );
    let output_schema = StructType::new_unchecked(vec![
        StructField::not_null("a", DataType::INTEGER),
        StructField::not_null("b", DataType::INTEGER),
        StructField::not_null("c", DataType::INTEGER),
    ]);
    let evaluated = evaluate_expression(
        &expr,
        &batch,
        Some(&DataType::Struct(Box::new(output_schema))),
    )
    .unwrap();
    let structs = evaluated.as_any().downcast_ref::<StructArray>().unwrap();
    validate_i32_column(structs, 0, &[1, 2, 3]);
    validate_i32_column(structs, 1, &[10, 20, 30]);
    validate_i32_column(structs, 2, &[100, 200, 300]);
}
#[test]
fn test_drop_field_non_optional_missing_still_errors() {
    // A non-optional drop of a missing field is a validation error.
    let batch = create_test_batch();
    let expr = Expr::Transform(Transform::new_top_level().with_dropped_field("nonexistent"));
    let output_schema = StructType::new_unchecked(vec![
        StructField::not_null("a", DataType::INTEGER),
        StructField::not_null("b", DataType::INTEGER),
        StructField::not_null("c", DataType::INTEGER),
    ]);
    let err = evaluate_expression(
        &expr,
        &batch,
        Some(&DataType::Struct(Box::new(output_schema))),
    )
    .unwrap_err();
    assert!(err.to_string().contains("reference invalid input field names"));
}
#[test]
fn test_struct_expression_schema_validation() {
    // A struct expression must produce exactly as many columns as the
    // requested output schema has fields.
    let batch = create_test_batch();
    let int_field = |name| StructField::not_null(name, DataType::INTEGER);
    let cases = [
        (
            "too many schema fields",
            Expr::struct_from([column_expr_ref!("a"), column_expr_ref!("b")]),
            StructType::new_unchecked(vec![int_field("a"), int_field("b"), int_field("c")]),
        ),
        (
            "too few schema fields",
            Expr::struct_from([
                column_expr_ref!("a"),
                column_expr_ref!("b"),
                column_expr_ref!("c"),
            ]),
            StructType::new_unchecked(vec![int_field("a"), int_field("b")]),
        ),
    ];
    for (name, expr, schema) in cases {
        let outcome =
            evaluate_expression(&expr, &batch, Some(&DataType::Struct(Box::new(schema))));
        assert!(outcome.is_err(), "Test case '{name}' should fail");
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("field count mismatch"),
            "Test case '{name}' should contain 'field count mismatch' error"
        );
    }
}
#[test]
fn test_coalesce_arrays_same_type() {
    // Int32: each output slot takes the first non-null value across inputs.
    let a = Int32Array::from(vec![Some(1), None, Some(3), None, None, Some(8), None]);
    let b = Int32Array::from(vec![None, Some(2), Some(4), None, Some(6), None, None]);
    let c = Int32Array::from(vec![None, None, None, Some(5), Some(7), Some(9), None]);
    let merged = coalesce_arrays(&[Arc::new(a), Arc::new(b), Arc::new(c)], None).unwrap();
    let merged = merged.as_any().downcast_ref::<Int32Array>().unwrap();
    let expected = [Some(1), Some(2), Some(3), Some(5), Some(6), Some(8), None];
    assert_eq!(merged.len(), expected.len());
    for (row, want) in expected.iter().enumerate() {
        match want {
            Some(v) => assert_eq!(merged.value(row), *v),
            None => assert!(merged.is_null(row)),
        }
    }
    // Strings follow the same first-non-null rule.
    let s1 = Arc::new(StringArray::from(vec![Some("a"), None, Some("c")]));
    let s2 = Arc::new(StringArray::from(vec![None, Some("b"), None]));
    let merged = coalesce_arrays(&[s1, s2], None).unwrap();
    let merged = merged.as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(merged.len(), 3);
    assert_eq!(merged.value(0), "a");
    assert_eq!(merged.value(1), "b");
    assert_eq!(merged.value(2), "c");
}
#[test]
fn test_coalesce_arrays_all_nulls() {
    // When every input is null at a position, the output stays null.
    let first = Arc::new(Int32Array::from(vec![None, None, None]));
    let second = Arc::new(Int32Array::from(vec![None, None, None]));
    let merged = coalesce_arrays(&[first, second], None).unwrap();
    let merged = merged.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(merged.len(), 3);
    for row in 0..3 {
        assert!(merged.is_null(row));
    }
}
#[test]
fn test_coalesce_arrays_single_array() {
    // A single input is returned unchanged.
    let only: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let merged = coalesce_arrays(std::slice::from_ref(&only), None).unwrap();
    assert_eq!(merged.as_ref(), only.as_ref());
}
#[test]
fn test_coalesce_arrays_type_mismatch_error() {
    // All inputs must share one data type; the first array sets the expectation.
    let ints = Arc::new(Int32Array::from(vec![Some(1), None]));
    let longs = Arc::new(Int64Array::from(vec![None, Some(2)]));
    assert_result_error_with_message(
        coalesce_arrays(&[ints, longs], None),
        "Array at index 1 has type Int64, but expected Int32",
    );
    let ints = Arc::new(Int32Array::from(vec![Some(1)]));
    let strings = Arc::new(StringArray::from(vec![Some("hello")]));
    assert_result_error_with_message(
        coalesce_arrays(&[ints, strings], None),
        "Array at index 1 has type Utf8, but expected Int32",
    );
}
#[test]
fn test_coalesce_arrays_length_mismatch_error() {
    // Inputs of differing lengths are rejected.
    let two_rows = Arc::new(Int32Array::from(vec![Some(1), Some(2)]));
    let three_rows = Arc::new(Int32Array::from(vec![Some(3), Some(4), Some(5)]));
    let outcome = coalesce_arrays(&[two_rows, three_rows], None);
    assert_result_error_with_message(outcome, "Array at index 1 has length 3, expected 2");
}
#[test]
fn test_coalesce_arrays_empty_input_error() {
    // COALESCE requires at least one argument.
    assert_result_error_with_message(coalesce_arrays(&[], None), "empty COALESCE statements");
}
#[test]
fn test_coalesce_arrays_result_type_validation() {
    // An explicit result type must match the arrays' actual type.
    let first = Arc::new(Int32Array::from(vec![Some(1), None]));
    let second = Arc::new(Int32Array::from(vec![None, Some(2)]));
    assert!(coalesce_arrays(&[first.clone(), second.clone()], Some(&DataType::INTEGER)).is_ok());
    assert_result_error_with_message(
        coalesce_arrays(&[first, second], Some(&DataType::STRING)),
        "Requested result type Utf8 does not match arrays' data type Int32",
    );
}
#[test]
fn test_coalesce_arrays_first_no_nulls() {
    // If the first array has no nulls, its values win everywhere.
    let primary = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let fallback = Arc::new(Int32Array::from(vec![10, 20, 30]));
    let merged = coalesce_arrays(&[primary, fallback], None).unwrap();
    let merged = merged.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(merged.len(), 3);
    assert_eq!(merged.value(0), 1);
    assert_eq!(merged.value(1), 2);
    assert_eq!(merged.value(2), 3);
}
#[test]
fn test_coalesce_arrays_second_no_nulls() {
    // Nulls in the first array fall back to the second array's values.
    let primary = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let fallback = Arc::new(Int32Array::from(vec![10, 20, 30]));
    let merged = coalesce_arrays(&[primary, fallback], None).unwrap();
    let merged = merged.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(merged.len(), 3);
    assert_eq!(merged.value(0), 1);
    assert_eq!(merged.value(1), 20);
    assert_eq!(merged.value(2), 3);
}
#[test]
fn test_coalesce_expression_short_circuit_first() {
    // The first argument is non-nullable, and the later invalid column
    // reference does not cause an error — so it was never evaluated.
    let schema = ArrowSchema::new(vec![ArrowField::new("a", ArrowDataType::Int32, false)]);
    let batch = RecordBatch::try_new(
        Arc::new(schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();
    let expr = Expression::coalesce([
        Expression::column(["a"]),
        Expression::column(["nonexistent"]),
    ]);
    let evaluated = evaluate_expression(&expr, &batch, Some(&DataType::INTEGER)).unwrap();
    let values = evaluated.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(values.values(), &[1, 2, 3]);
}
#[test]
fn test_coalesce_expression_short_circuit_second() {
    // "b" is non-nullable, so the third argument (an invalid column
    // reference) never needs evaluating — the call still succeeds.
    let schema = ArrowSchema::new(vec![
        ArrowField::new("a", ArrowDataType::Int32, true),
        ArrowField::new("b", ArrowDataType::Int32, false),
    ]);
    let batch = RecordBatch::try_new(
        Arc::new(schema),
        vec![
            Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef,
            Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef,
        ],
    )
    .unwrap();
    let expr = Expression::coalesce([
        Expression::column(["a"]),
        Expression::column(["b"]),
        Expression::column(["nonexistent"]),
    ]);
    let evaluated = evaluate_expression(&expr, &batch, Some(&DataType::INTEGER)).unwrap();
    let values = evaluated.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(values.len(), 3);
    assert_eq!(values.value(0), 1);
    assert_eq!(values.value(1), 20);
    assert_eq!(values.value(2), 3);
}
#[test]
fn test_coalesce_expression_short_circuit_type_mismatch() {
    // Even a single-argument coalesce still validates the requested type.
    let schema = ArrowSchema::new(vec![ArrowField::new("a", ArrowDataType::Int32, false)]);
    let batch = RecordBatch::try_new(
        Arc::new(schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();
    let expr = Expression::coalesce([Expression::column(["a"])]);
    assert!(evaluate_expression(&expr, &batch, Some(&DataType::STRING)).is_err());
}
#[test]
fn test_nested_transforms() {
    // Insert a transformed copy of "nested" (with x replaced by 999) after
    // field "a", keeping the untouched "nested" column in place.
    let batch = create_nested_test_batch();
    let inner =
        Transform::new_nested(["nested"]).with_replaced_field("x", Expr::literal(999).into());
    let outer = Transform::new_top_level()
        .with_inserted_field(Some("a"), Expr::Transform(inner).into());
    let xy_schema = StructType::new_unchecked(vec![
        StructField::not_null("x", DataType::INTEGER),
        StructField::not_null("y", DataType::INTEGER),
    ]);
    let output_schema = StructType::new_unchecked(vec![
        StructField::not_null("a", DataType::INTEGER),
        StructField::not_null("transformed", xy_schema.clone()),
        StructField::not_null("nested", xy_schema),
    ]);
    let evaluated = evaluate_expression(
        &Expr::Transform(outer),
        &batch,
        Some(&DataType::Struct(Box::new(output_schema))),
    )
    .unwrap();
    let top = evaluated.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(top.num_columns(), 3);
    assert_eq!(top.len(), 3);
    validate_i32_column(top, 0, &[100, 200, 300]);
    // Column 1 is the transformed struct: x replaced, y untouched.
    let transformed = top.column(1).as_any().downcast_ref::<StructArray>().unwrap();
    validate_i32_column(transformed, 0, &[999, 999, 999]);
    validate_i32_column(transformed, 1, &[10, 20, 30]);
    // Column 2 is the original nested struct, unmodified.
    let untouched = top.column(2).as_any().downcast_ref::<StructArray>().unwrap();
    validate_i32_column(untouched, 0, &[1, 2, 3]);
    validate_i32_column(untouched, 1, &[10, 20, 30]);
}
#[test]
fn test_literal_type_validation() {
    // A literal passes when the requested type matches, and errors otherwise.
    let batch = create_test_batch();
    assert!(evaluate_expression(&Expr::literal(42), &batch, Some(&DataType::INTEGER)).is_ok());
    let mismatch = evaluate_expression(&Expr::literal(42), &batch, Some(&DataType::STRING));
    assert_result_error_with_message(mismatch, "Incorrect datatype");
}
#[test]
fn test_column_type_validation() {
    // Column "a" extracts as INTEGER; requesting STRING must fail.
    let batch = create_test_batch();
    assert!(
        evaluate_expression(&column_expr_ref!("a"), &batch, Some(&DataType::INTEGER)).is_ok()
    );
    let mismatch = evaluate_expression(&column_expr_ref!("a"), &batch, Some(&DataType::STRING));
    assert_result_error_with_message(mismatch, "Incorrect datatype");
}
#[test]
fn test_binary_type_validation() {
    // a + b evaluates as INTEGER; requesting STRING must fail.
    let batch = create_test_batch();
    let sum = Expr::binary(
        BinaryExpressionOp::Plus,
        Expr::column(["a"]),
        Expr::column(["b"]),
    );
    assert!(evaluate_expression(&sum, &batch, Some(&DataType::INTEGER)).is_ok());
    assert_result_error_with_message(
        evaluate_expression(&sum, &batch, Some(&DataType::STRING)),
        "Incorrect datatype",
    );
}
fn create_json_batch() -> RecordBatch {
    // Single nullable Utf8 column holding one JSON object per row.
    let schema = ArrowSchema::new(vec![ArrowField::new("json_col", ArrowDataType::Utf8, true)]);
    let rows = [
        r#"{"a": 1, "b": "hello"}"#,
        r#"{"a": 2, "b": "world"}"#,
        r#"{"a": 3, "b": "test"}"#,
    ];
    let json_col = StringArray::from(rows.map(Some).to_vec());
    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(json_col)]).unwrap()
}
#[test]
fn test_parse_json_basic() {
    // Flat JSON objects parse into a two-column struct array.
    let batch = create_json_batch();
    let output_schema = Arc::new(StructType::new_unchecked(vec![
        StructField::new("a", DataType::LONG, true),
        StructField::new("b", DataType::STRING, true),
    ]));
    let expr = Expr::parse_json(column_expr!("json_col"), output_schema);
    let parsed = evaluate_expression(&expr, &batch, None).unwrap();
    let structs = parsed.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(structs.num_columns(), 2);
    assert_eq!(structs.len(), 3);
    // Field "a" parses as Int64 values.
    let a_vals = structs.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(a_vals.values(), &[1, 2, 3]);
    // Field "b" parses as the corresponding strings.
    let b_vals = structs.column(1).as_any().downcast_ref::<StringArray>().unwrap();
    for (row, expected) in ["hello", "world", "test"].into_iter().enumerate() {
        assert_eq!(b_vals.value(row), expected);
    }
}
#[test]
fn test_parse_json_nested_struct() {
    // JSON with a nested object parses into a nested StructArray.
    let schema = ArrowSchema::new(vec![ArrowField::new("json_col", ArrowDataType::Utf8, true)]);
    let json_col = StringArray::from(vec![
        Some(r#"{"outer": 10, "inner": {"x": 1, "y": 2}}"#),
        Some(r#"{"outer": 20, "inner": {"x": 3, "y": 4}}"#),
    ]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(json_col)]).unwrap();
    let inner_schema = StructType::new_unchecked(vec![
        StructField::new("x", DataType::LONG, true),
        StructField::new("y", DataType::LONG, true),
    ]);
    let output_schema = Arc::new(StructType::new_unchecked(vec![
        StructField::new("outer", DataType::LONG, true),
        StructField::new("inner", DataType::Struct(Box::new(inner_schema)), true),
    ]));
    let expr = Expr::parse_json(column_expr!("json_col"), output_schema);
    let parsed = evaluate_expression(&expr, &batch, None).unwrap();
    let top = parsed.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(top.num_columns(), 2);
    assert_eq!(top.len(), 2);
    let outer = top.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(outer.values(), &[10, 20]);
    // The nested struct's children carry the inner x/y values per row.
    let inner = top.column(1).as_any().downcast_ref::<StructArray>().unwrap();
    let xs = inner.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
    let ys = inner.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(xs.values(), &[1, 3]);
    assert_eq!(ys.values(), &[2, 4]);
}
#[test]
fn test_parse_json_with_nulls() {
    // A null JSON input row yields a null parsed value for that row.
    let schema = ArrowSchema::new(vec![ArrowField::new("json_col", ArrowDataType::Utf8, true)]);
    let json_col = StringArray::from(vec![Some(r#"{"a": 1}"#), None, Some(r#"{"a": 3}"#)]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(json_col)]).unwrap();
    let output_schema = Arc::new(StructType::new_unchecked(vec![StructField::new(
        "a",
        DataType::LONG,
        true,
    )]));
    let expr = Expr::parse_json(column_expr!("json_col"), output_schema);
    let parsed = evaluate_expression(&expr, &batch, None).unwrap();
    let structs = parsed.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(structs.len(), 3);
    let a_vals = structs.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
    assert!(!a_vals.is_null(0));
    assert_eq!(a_vals.value(0), 1);
    assert!(a_vals.is_null(1));
    assert!(!a_vals.is_null(2));
    assert_eq!(a_vals.value(2), 3);
}
#[test]
fn test_parse_json_empty_batch() {
    // Zero input rows produce a zero-length struct array.
    let schema = ArrowSchema::new(vec![ArrowField::new("json_col", ArrowDataType::Utf8, true)]);
    let empty: StringArray = StringArray::from(Vec::<Option<&str>>::new());
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(empty)]).unwrap();
    let output_schema = Arc::new(StructType::new_unchecked(vec![StructField::new(
        "a",
        DataType::LONG,
        true,
    )]));
    let expr = Expr::parse_json(column_expr!("json_col"), output_schema);
    let parsed = evaluate_expression(&expr, &batch, None).unwrap();
    let structs = parsed.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(structs.len(), 0);
}
#[test]
fn test_parse_json_missing_field() {
    // A schema field absent from a JSON row comes back null for that row.
    let schema = ArrowSchema::new(vec![ArrowField::new("json_col", ArrowDataType::Utf8, true)]);
    let json_col = StringArray::from(vec![
        Some(r#"{"a": 1}"#),
        Some(r#"{"a": 2, "b": "hi"}"#),
        Some(r#"{"a": 3}"#),
    ]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(json_col)]).unwrap();
    let output_schema = Arc::new(StructType::new_unchecked(vec![
        StructField::new("a", DataType::LONG, true),
        StructField::new("b", DataType::STRING, true),
    ]));
    let expr = Expr::parse_json(column_expr!("json_col"), output_schema);
    let parsed = evaluate_expression(&expr, &batch, None).unwrap();
    let structs = parsed.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(structs.len(), 3);
    let a_vals = structs.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(a_vals.values(), &[1, 2, 3]);
    // "b" only appears in the middle row.
    let b_vals = structs.column(1).as_any().downcast_ref::<StringArray>().unwrap();
    assert!(b_vals.is_null(0));
    assert_eq!(b_vals.value(1), "hi");
    assert!(b_vals.is_null(2));
}
#[test]
fn test_parse_json_extra_field_ignored() {
    // JSON keys not present in the output schema are silently dropped.
    let schema = ArrowSchema::new(vec![ArrowField::new("json_col", ArrowDataType::Utf8, true)]);
    let json_col = StringArray::from(vec![
        Some(r#"{"a": 1, "b": "x", "c": "extra"}"#),
        Some(r#"{"a": 2, "b": "y", "ignored": 999}"#),
    ]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(json_col)]).unwrap();
    let output_schema = Arc::new(StructType::new_unchecked(vec![
        StructField::new("a", DataType::LONG, true),
        StructField::new("b", DataType::STRING, true),
    ]));
    let expr = Expr::parse_json(column_expr!("json_col"), output_schema);
    let parsed = evaluate_expression(&expr, &batch, None).unwrap();
    let structs = parsed.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(structs.num_columns(), 2);
    assert_eq!(structs.len(), 2);
    let a_vals = structs.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(a_vals.values(), &[1, 2]);
    let b_vals = structs.column(1).as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(b_vals.value(0), "x");
    assert_eq!(b_vals.value(1), "y");
}
#[test]
fn test_parse_json_errors_return_nulls() {
    // Helper: parse the given JSON strings against `output_schema` and require
    // that every resulting row is null (parse failures surface as nulls).
    fn assert_parse_json_result_all_nulls(
        json_strings: Vec<Option<&str>>,
        output_schema: Arc<StructType>,
    ) {
        let arrow_schema =
            ArrowSchema::new(vec![ArrowField::new("json_col", ArrowDataType::Utf8, true)]);
        let row_count = json_strings.len();
        let json_col = StringArray::from(json_strings);
        let batch =
            RecordBatch::try_new(Arc::new(arrow_schema), vec![Arc::new(json_col)]).unwrap();
        let expr = Expr::parse_json(column_expr!("json_col"), output_schema);
        let parsed = evaluate_expression(&expr, &batch, None).unwrap();
        assert_eq!(parsed.len(), row_count);
        assert_eq!(parsed.null_count(), row_count);
    }
    // A string where a LONG is expected fails to parse.
    let long_schema = Arc::new(StructType::new_unchecked(vec![StructField::new(
        "a",
        DataType::LONG,
        true,
    )]));
    assert_parse_json_result_all_nulls(vec![Some(r#"{"a": "not_a_number"}"#)], long_schema);
    // A value that does not fit decimal(4, 2) also fails to parse.
    let decimal_schema = Arc::new(StructType::new_unchecked(vec![StructField::new(
        "a",
        DataType::decimal(4, 2).unwrap(),
        true,
    )]));
    assert_parse_json_result_all_nulls(vec![Some(r#"{"a": 99999}"#)], decimal_schema);
}
fn create_partition_map_batch() -> RecordBatch {
    use crate::arrow::array::{MapBuilder, StringBuilder};
    // Three map rows: two populated string->string maps and one null map.
    let rows: [Option<&[(&str, &str)]>; 3] = [
        Some(&[("date", "2024-01-15"), ("region", "us"), ("id", "42")]),
        Some(&[("date", ""), ("region", "eu"), ("id", "-7")]),
        None,
    ];
    let mut builder = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new());
    for row in rows {
        match row {
            Some(entries) => {
                for (key, value) in entries {
                    builder.keys().append_value(*key);
                    builder.values().append_value(*value);
                }
                builder.append(true).unwrap();
            }
            None => builder.append(false).unwrap(),
        }
    }
    let map_array = builder.finish();
    let field = ArrowField::new("pv", map_array.data_type().clone(), true);
    let schema = ArrowSchema::new(vec![field]);
    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(map_array)]).unwrap()
}
#[test]
fn test_map_to_struct_basic() {
    use crate::arrow::array::Date32Array;
    // Map entries are pivoted into struct columns, with the string values
    // converted to the requested field types (STRING, INTEGER, DATE).
    let batch = create_partition_map_batch();
    let output_schema = StructType::new_unchecked(vec![
        StructField::nullable("region", DataType::STRING),
        StructField::nullable("id", DataType::INTEGER),
        StructField::nullable("date", DataType::DATE),
    ]);
    let result_type = DataType::Struct(Box::new(output_schema));
    let expr = Expr::map_to_struct(column_expr!("pv"));
    let evaluated = evaluate_expression(&expr, &batch, Some(&result_type)).unwrap();
    let structs = evaluated.as_any().downcast_ref::<StructArray>().unwrap();
    let regions = structs.column(0).as_any().downcast_ref::<StringArray>().unwrap();
    let ids = structs.column(1).as_any().downcast_ref::<Int32Array>().unwrap();
    let dates = structs.column(2).as_any().downcast_ref::<Date32Array>().unwrap();
    // Row 0: all three keys present and convertible.
    assert_eq!(regions.value(0), "us");
    assert_eq!(ids.value(0), 42);
    assert_eq!(dates.value(0), 19737); // 2024-01-15 as days since the epoch
    // Row 1: the empty date string becomes null.
    assert_eq!(regions.value(1), "eu");
    assert_eq!(ids.value(1), -7);
    assert!(dates.is_null(1));
    // Row 2: the map itself is null, so every field is null.
    assert!(regions.is_null(2));
    assert!(ids.is_null(2));
    assert!(dates.is_null(2));
}
#[test]
fn test_map_to_struct_missing_key() {
    // Requesting a key absent from every map yields an all-null column.
    let batch = create_partition_map_batch();
    let output_schema =
        StructType::new_unchecked(vec![StructField::nullable("nonexistent", DataType::STRING)]);
    let result_type = DataType::Struct(Box::new(output_schema));
    let expr = Expr::map_to_struct(column_expr!("pv"));
    let evaluated = evaluate_expression(&expr, &batch, Some(&result_type)).unwrap();
    let structs = evaluated.as_any().downcast_ref::<StructArray>().unwrap();
    let missing = structs.column(0).as_any().downcast_ref::<StringArray>().unwrap();
    for row in 0..3 {
        assert!(missing.is_null(row));
    }
}
#[test]
fn test_map_to_struct_parse_error() {
    use crate::arrow::array::{MapBuilder, StringBuilder};
    // A map value that cannot convert to the requested INTEGER type is an error.
    let mut builder = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new());
    builder.keys().append_value("count");
    builder.values().append_value("not_a_number");
    builder.append(true).unwrap();
    let map_array = builder.finish();
    let field = ArrowField::new("pv", map_array.data_type().clone(), true);
    let batch = RecordBatch::try_new(
        Arc::new(ArrowSchema::new(vec![field])),
        vec![Arc::new(map_array)],
    )
    .unwrap();
    let output_schema =
        StructType::new_unchecked(vec![StructField::nullable("count", DataType::INTEGER)]);
    let result_type = DataType::Struct(Box::new(output_schema));
    let expr = Expr::map_to_struct(column_expr!("pv"));
    assert!(evaluate_expression(&expr, &batch, Some(&result_type)).is_err());
}
#[test]
fn test_map_to_struct_duplicate_keys() {
    use crate::arrow::array::{MapBuilder, StringBuilder};
    // When a key appears twice in one map, the last value wins.
    let mut builder = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new());
    for value in ["first", "last"] {
        builder.keys().append_value("x");
        builder.values().append_value(value);
    }
    builder.append(true).unwrap();
    let map_array = builder.finish();
    let field = ArrowField::new("pv", map_array.data_type().clone(), true);
    let batch = RecordBatch::try_new(
        Arc::new(ArrowSchema::new(vec![field])),
        vec![Arc::new(map_array)],
    )
    .unwrap();
    let output_schema =
        StructType::new_unchecked(vec![StructField::nullable("x", DataType::STRING)]);
    let result_type = DataType::Struct(Box::new(output_schema));
    let expr = Expr::map_to_struct(column_expr!("pv"));
    let evaluated = evaluate_expression(&expr, &batch, Some(&result_type)).unwrap();
    let structs = evaluated.as_any().downcast_ref::<StructArray>().unwrap();
    let col = structs.column(0).as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(col.value(0), "last");
}
#[test]
fn test_map_to_struct_non_map_input() {
    // map_to_struct over a plain string column is rejected.
    let schema = ArrowSchema::new(vec![ArrowField::new("s", ArrowDataType::Utf8, true)]);
    let strings = StringArray::from(vec![Some("hello")]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(strings)]).unwrap();
    let output_schema =
        StructType::new_unchecked(vec![StructField::nullable("x", DataType::STRING)]);
    let result_type = DataType::Struct(Box::new(output_schema));
    let expr = Expr::map_to_struct(column_expr!("s"));
    assert!(evaluate_expression(&expr, &batch, Some(&result_type)).is_err());
}
fn create_batch_with_bool_col(
    a_vals: Vec<Option<i32>>,
    is_valid_vals: Vec<Option<bool>>,
) -> RecordBatch {
    // Two nullable columns: "a" (Int32) and "is_valid" (Boolean).
    let fields = vec![
        ArrowField::new("a", ArrowDataType::Int32, true),
        ArrowField::new("is_valid", ArrowDataType::Boolean, true),
    ];
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(a_vals)),
        Arc::new(BooleanArray::from(is_valid_vals)),
    ];
    RecordBatch::try_new(Arc::new(ArrowSchema::new(fields)), columns).unwrap()
}
#[rstest]
#[case::fast_path(
    vec![Some(1), Some(2), Some(3)],
    vec![Some(true), Some(false), Some(true)],
    vec![true, false, true],
)]
#[case::slow_path(
    vec![Some(1), Some(2), Some(3), Some(4)],
    vec![Some(true), Some(false), None, Some(true)],
    vec![true, false, false, true],
)]
fn test_struct_with_nullability_predicate(
    #[case] a_vals: Vec<Option<i32>>,
    #[case] pred_vals: Vec<Option<bool>>,
    #[case] expected_valid: Vec<bool>,
) {
    // The struct's row validity must follow the boolean predicate column
    // (a null predicate row makes the struct row null).
    let batch = create_batch_with_bool_col(a_vals, pred_vals);
    let result_type = DataType::Struct(Box::new(StructType::new_unchecked(vec![
        StructField::new("a", DataType::INTEGER, true),
    ])));
    let expr = Expr::struct_with_nullability_from(
        [column_expr_ref!("a")],
        column_expr_ref!("is_valid"),
    );
    let evaluated = evaluate_expression(&expr, &batch, Some(&result_type)).unwrap();
    let structs = evaluated.as_any().downcast_ref::<StructArray>().unwrap();
    for (row, expect_valid) in expected_valid.into_iter().enumerate() {
        assert_eq!(structs.is_valid(row), expect_valid, "row {row}");
    }
}
#[test]
fn test_struct_with_nullability_predicate_nested_schema() {
    // The nullability predicate applies to the outer struct; its nested
    // struct child keeps the full row count.
    let batch = create_batch_with_bool_col(
        vec![Some(1), Some(2), Some(3)],
        vec![Some(true), Some(false), Some(true)],
    );
    let inner_schema =
        StructType::new_unchecked(vec![StructField::new("a", DataType::INTEGER, true)]);
    let result_type = DataType::Struct(Box::new(StructType::new_unchecked(vec![
        StructField::new("nested", DataType::Struct(Box::new(inner_schema)), true),
    ])));
    let inner_expr = Expr::struct_from([column_expr_ref!("a")]);
    let expr = Expr::struct_with_nullability_from([inner_expr], column_expr_ref!("is_valid"));
    let evaluated = evaluate_expression(&expr, &batch, Some(&result_type)).unwrap();
    let outer = evaluated.as_any().downcast_ref::<StructArray>().unwrap();
    assert!(outer.is_valid(0));
    assert!(outer.is_null(1));
    assert!(outer.is_valid(2));
    let nested = outer.column(0).as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(nested.len(), 3);
}
#[test]
fn test_struct_with_nullability_predicate_multiple_fields() {
    // The predicate column drives row validity while both child columns keep
    // their values.
    let arrow_schema = ArrowSchema::new(vec![
        ArrowField::new("a", ArrowDataType::Int32, true),
        ArrowField::new("b", ArrowDataType::Int32, true),
        ArrowField::new("is_valid", ArrowDataType::Boolean, true),
    ]);
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
        Arc::new(Int32Array::from(vec![Some(10), Some(20), Some(30)])),
        Arc::new(BooleanArray::from(vec![Some(true), Some(false), Some(true)])),
    ];
    let batch = RecordBatch::try_new(Arc::new(arrow_schema), columns).unwrap();
    let result_type = DataType::Struct(Box::new(StructType::new_unchecked(vec![
        StructField::new("a", DataType::INTEGER, true),
        StructField::new("b", DataType::INTEGER, true),
    ])));
    let expr = Expr::struct_with_nullability_from(
        [column_expr_ref!("a"), column_expr_ref!("b")],
        column_expr_ref!("is_valid"),
    );
    let evaluated = evaluate_expression(&expr, &batch, Some(&result_type)).unwrap();
    let structs = evaluated.as_any().downcast_ref::<StructArray>().unwrap();
    assert!(structs.is_valid(0), "row 0 should be valid");
    assert!(structs.is_null(1), "row 1 should be null");
    assert!(structs.is_valid(2), "row 2 should be valid");
    validate_i32_column(structs, 0, &[1, 2, 3]);
    validate_i32_column(structs, 1, &[10, 20, 30]);
}
#[test]
fn test_struct_nullability_non_boolean_predicate_errors() {
    // Using the Int32 column "a" as the nullability predicate must fail.
    let batch = create_batch_with_bool_col(
        vec![Some(1), Some(2), Some(3)],
        vec![Some(true), Some(false), Some(true)],
    );
    let result_type = DataType::Struct(Box::new(StructType::new_unchecked(vec![
        StructField::new("a", DataType::INTEGER, true),
    ])));
    let expr =
        Expr::struct_with_nullability_from([column_expr_ref!("a")], column_expr_ref!("a"));
    assert_result_error_with_message(
        evaluate_expression(&expr, &batch, Some(&result_type)),
        "Incorrect datatype",
    );
}
#[test]
fn test_struct_no_result_type_errors() {
    // Struct expressions cannot be evaluated without an output type.
    let batch = create_test_batch();
    let expr = Expr::struct_from([column_expr_ref!("a")]);
    assert!(evaluate_expression(&expr, &batch, None).is_err());
}
fn make_struct_batch(arrow_fields: Vec<ArrowField>, arrays: Vec<ArrayRef>) -> RecordBatch {
    // Wrap the given children in a single nullable "stats" struct column.
    let stats = StructArray::try_new(arrow_fields.clone().into(), arrays, None).unwrap();
    let stats_field =
        ArrowField::new("stats", ArrowDataType::Struct(arrow_fields.into()), true);
    let schema = ArrowSchema::new(vec![stats_field]);
    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(stats)]).unwrap()
}
#[test]
fn column_extract_struct_with_mismatched_field_names() {
    // Physical child names differ from the logical schema's field names;
    // extraction still succeeds because the types line up.
    let batch = make_struct_batch(
        vec![
            ArrowField::new("col-abc-001", ArrowDataType::Int64, true),
            ArrowField::new("col-abc-002", ArrowDataType::Int64, true),
        ],
        vec![
            Arc::new(Int64Array::from(vec![Some(1), Some(2)])),
            Arc::new(Int64Array::from(vec![Some(10), Some(20)])),
        ],
    );
    let logical_type = DataType::try_struct_type([
        StructField::nullable("my_column", DataType::LONG),
        StructField::nullable("other_column", DataType::LONG),
    ])
    .unwrap();
    let extracted = evaluate_expression(&column_expr!("stats"), &batch, Some(&logical_type))
        .expect("should succeed with mismatched names but matching types");
    let structs = extracted.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(structs.num_columns(), 2);
    assert_eq!(structs.len(), 2);
}
#[test]
fn column_extract_struct_rejects_mismatched_field_count() {
    // One physical child vs. two logical fields must be rejected.
    let batch = make_struct_batch(
        vec![ArrowField::new("col-abc-001", ArrowDataType::Int64, true)],
        vec![Arc::new(Int64Array::from(vec![Some(1), Some(2)]))],
    );
    let logical_type = DataType::try_struct_type([
        StructField::nullable("a", DataType::LONG),
        StructField::nullable("b", DataType::LONG),
    ])
    .unwrap();
    let outcome = evaluate_expression(&column_expr!("stats"), &batch, Some(&logical_type));
    assert_result_error_with_message(outcome, "Struct field count mismatch");
}
#[test]
fn column_extract_struct_rejects_mismatched_child_types() {
    // A Utf8 physical child cannot satisfy a LONG logical field.
    let batch = make_struct_batch(
        vec![
            ArrowField::new("col-abc-001", ArrowDataType::Int64, true),
            ArrowField::new("col-abc-002", ArrowDataType::Utf8, true),
        ],
        vec![
            Arc::new(Int64Array::from(vec![Some(1)])),
            Arc::new(StringArray::from(vec![Some("x")])),
        ],
    );
    let logical_type = DataType::try_struct_type([
        StructField::nullable("a", DataType::LONG),
        StructField::nullable("b", DataType::LONG),
    ])
    .unwrap();
    let outcome = evaluate_expression(&column_expr!("stats"), &batch, Some(&logical_type));
    assert_result_error_with_message(outcome, "Incorrect datatype");
}
#[test]
fn column_extract_struct_with_matching_names_still_works() {
    // Sanity check: the name-mismatch tolerance must not break the ordinary
    // case where physical and logical names already agree.
    let batch = make_struct_batch(
        vec![
            ArrowField::new("a", ArrowDataType::Int64, true),
            ArrowField::new("b", ArrowDataType::Int64, true),
        ],
        vec![
            Arc::new(Int64Array::from(vec![Some(1)])),
            Arc::new(Int64Array::from(vec![Some(2)])),
        ],
    );
    let logical_type = DataType::try_struct_type([
        StructField::nullable("a", DataType::LONG),
        StructField::nullable("b", DataType::LONG),
    ])
    .unwrap();
    let expr = column_expr!("stats");
    let result = evaluate_expression(&expr, &batch, Some(&logical_type));
    // Previously this only asserted `is_ok()`; also verify the output shape so
    // a silently-wrong (but Ok) result would be caught.
    let arr = result.expect("matching names and types should succeed");
    let struct_arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(struct_arr.num_columns(), 2);
    assert_eq!(struct_arr.len(), 1);
}
#[test]
fn struct_from_with_column_tolerates_nested_name_mismatch() {
    // Build an `add` struct column whose nested `stats_parsed` child uses
    // physical field names ("col-abc-...") unrelated to the logical ones.
    let parsed_fields: Vec<ArrowField> = vec![
        ArrowField::new("col-abc-001", ArrowDataType::Int64, true),
        ArrowField::new("col-abc-002", ArrowDataType::Int64, true),
    ];
    let parsed_children: Vec<ArrayRef> = vec![
        Arc::new(Int64Array::from(vec![Some(1)])),
        Arc::new(Int64Array::from(vec![Some(10)])),
    ];
    let parsed_array =
        StructArray::try_new(parsed_fields.clone().into(), parsed_children, None).unwrap();

    let add_fields: Vec<ArrowField> = vec![
        ArrowField::new("path", ArrowDataType::Utf8, true),
        ArrowField::new(
            "stats_parsed",
            ArrowDataType::Struct(parsed_fields.into()),
            true,
        ),
    ];
    let add_children: Vec<ArrayRef> = vec![
        Arc::new(StringArray::from(vec![Some("file.parquet")])),
        Arc::new(parsed_array),
    ];
    let add_struct = StructArray::try_new(add_fields.clone().into(), add_children, None).unwrap();
    let schema = ArrowSchema::new(vec![ArrowField::new(
        "add",
        ArrowDataType::Struct(add_fields.into()),
        true,
    )]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(add_struct)]).unwrap();

    // Logical output renames the nested children to "id"/"value"; only the
    // types have to line up.
    let output_type = DataType::try_struct_type([
        StructField::nullable("path", DataType::STRING),
        StructField::nullable(
            "stats_parsed",
            DataType::struct_type_unchecked([
                StructField::nullable("id", DataType::LONG),
                StructField::nullable("value", DataType::LONG),
            ]),
        ),
    ])
    .unwrap();
    let expr = Expr::struct_from([
        column_expr_ref!("add.path"),
        column_expr_ref!("add.stats_parsed"),
    ]);
    let result = evaluate_expression(&expr, &batch, Some(&output_type));
    result.expect("struct_from with Column sub-expression should tolerate field name mismatch");
}
#[test]
fn column_extract_nested_struct_with_mismatched_names() {
    // Name mismatch at BOTH nesting levels ("phys-*" vs "logical_*"); the
    // extraction should still succeed because the types match throughout.
    let inner_fields = vec![ArrowField::new("phys-inner", ArrowDataType::Int64, true)];
    let inner_struct = ArrowDataType::Struct(inner_fields.clone().into());
    let batch = make_struct_batch(
        vec![ArrowField::new("phys-outer", inner_struct, true)],
        vec![Arc::new(
            StructArray::try_new(
                inner_fields.into(),
                vec![Arc::new(Int64Array::from(vec![Some(42)]))],
                None,
            )
            .unwrap(),
        )],
    );
    let logical_type = DataType::try_struct_type([StructField::nullable(
        "logical_outer",
        DataType::struct_type_unchecked([StructField::nullable(
            "logical_inner",
            DataType::LONG,
        )]),
    )])
    .unwrap();
    let expr = column_expr!("stats");
    let result = evaluate_expression(&expr, &batch, Some(&logical_type));
    // Previously only `is_ok()` was checked; also verify the nested output
    // shape so a structurally-wrong Ok result would be caught.
    let arr = result.expect("nested name mismatch with matching types should succeed");
    let outer = arr.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(outer.num_columns(), 1);
    assert_eq!(outer.len(), 1);
    let inner = outer.column(0).as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(inner.num_columns(), 1);
    assert_eq!(inner.len(), 1);
}
}