use crate::error::{_plan_err, Result};
use arrow::{
array::{
Array, ArrayRef, DictionaryArray, GenericListArray, GenericListViewArray,
StructArray, downcast_integer, new_null_array,
},
compute::{CastOptions, can_cast_types, cast_with_options},
datatypes::{DataType, DataType::Struct, Field, FieldRef},
};
use std::{collections::HashSet, sync::Arc};
/// Casts `source_col` to a struct array whose children are described by
/// `target_fields`.
///
/// Children are matched to the source **by name**:
/// - a target field that exists in the source is cast recursively with
///   [`cast_column`];
/// - a target field that is missing from the source is filled with nulls;
/// - source fields that are not named by any target field are dropped.
///
/// The validity (null) buffer of the source struct is carried over unchanged.
///
/// # Errors
///
/// Returns an error when
/// - the source is neither a struct, a `Null`-typed array, nor entirely null,
/// - [`validate_struct_compatibility`] rejects the source/target field pair,
/// - casting any child fails (the error is annotated with the field name), or
/// - arrow refuses to assemble the resulting struct array (for example when a
///   cast left nulls in a child whose target field is non-nullable).
fn cast_struct_column(
    source_col: &ArrayRef,
    target_fields: &[Arc<Field>],
    cast_options: &CastOptions,
) -> Result<ArrayRef> {
    // A `Null`-typed source, or a non-empty source in which every row is null,
    // is answered with an all-null struct of the target type; no child is cast.
    if source_col.data_type() == &DataType::Null
        || (!source_col.is_empty() && source_col.null_count() == source_col.len())
    {
        return Ok(new_null_array(
            &Struct(target_fields.to_vec().into()),
            source_col.len(),
        ));
    }

    let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() else {
        return _plan_err!(
            "Cannot cast column of type {} to struct type. Source must be a struct to cast to struct.",
            source_col.data_type()
        );
    };

    validate_struct_compatibility(source_struct.fields(), target_fields)?;

    let num_rows = source_col.len();

    // Build one child array per target field, in target order.
    let arrays = target_fields
        .iter()
        .map(|target_child_field| {
            match source_struct.column_by_name(target_child_field.name()) {
                Some(source_child_col) => cast_column(
                    source_child_col,
                    target_child_field.data_type(),
                    cast_options,
                )
                .map_err(|e| {
                    e.context(format!(
                        "While casting struct field '{}'",
                        target_child_field.name()
                    ))
                }),
                // Missing in the source: fill the column with nulls.
                None => Ok(new_null_array(target_child_field.data_type(), num_rows)),
            }
        })
        .collect::<Result<Vec<ArrayRef>>>()?;

    // `try_new` instead of `new`: `new` panics when arrow's validation fails,
    // whereas this function is expected to report failures through `Result`.
    let struct_array = StructArray::try_new(
        target_fields.to_vec().into(),
        arrays,
        source_struct.nulls().cloned(),
    )?;
    Ok(Arc::new(struct_array))
}
/// Casts `source_col` to `target_type`, descending into nested containers.
///
/// Dispatch rules:
/// - any source with a `Struct` target goes through the name-based struct
///   cast (`cast_struct_column`);
/// - `List`, `LargeList`, `ListView` and `LargeListView` sources whose target
///   is the same container kind have their values cast recursively;
/// - dictionary-to-dictionary casts go through `cast_dictionary_column`;
/// - every other combination is delegated to arrow's `cast_with_options`.
///
/// # Errors
///
/// Propagates any error produced by the selected cast routine.
pub fn cast_column(
    source_col: &ArrayRef,
    target_type: &DataType,
    cast_options: &CastOptions,
) -> Result<ArrayRef> {
    // A struct target is handled the same way whatever the source type is.
    if let Struct(target_fields) = target_type {
        return cast_struct_column(source_col, target_fields, cast_options);
    }

    match (source_col.data_type(), target_type) {
        (
            DataType::Dictionary(source_key_type, _),
            DataType::Dictionary(target_key_type, target_value_type),
        ) => cast_dictionary_column(
            source_col,
            source_key_type,
            target_key_type,
            target_value_type,
            cast_options,
        ),
        (DataType::List(_), DataType::List(item)) => {
            cast_list_column::<i32>(source_col, item, cast_options)
        }
        (DataType::LargeList(_), DataType::LargeList(item)) => {
            cast_list_column::<i64>(source_col, item, cast_options)
        }
        (DataType::ListView(_), DataType::ListView(item)) => {
            cast_list_view_column::<i32>(source_col, item, cast_options)
        }
        (DataType::LargeListView(_), DataType::LargeListView(item)) => {
            cast_list_view_column::<i64>(source_col, item, cast_options)
        }
        // No nested handling applies: defer to arrow's cast kernel.
        _ => {
            let casted = cast_with_options(source_col, target_type, cast_options)?;
            Ok(casted)
        }
    }
}
/// Casts the values of a `List` / `LargeList` array to the type of
/// `target_inner_field`, reusing the source offsets and validity buffer.
///
/// `O` selects the offset width (`i32` for `List`, `i64` for `LargeList`).
///
/// # Errors
///
/// Returns an error when
/// - `source_col` is not a `GenericListArray<O>`,
/// - casting the child values fails, or
/// - arrow refuses to assemble the list from the cast values (for example
///   when the values contain nulls but `target_inner_field` is non-nullable).
fn cast_list_column<O: arrow::array::OffsetSizeTrait>(
    source_col: &ArrayRef,
    target_inner_field: &FieldRef,
    cast_options: &CastOptions,
) -> Result<ArrayRef> {
    let source_list = source_col
        .as_any()
        .downcast_ref::<GenericListArray<O>>()
        .ok_or_else(|| {
            crate::error::DataFusionError::Plan(format!(
                "Expected list array but got {}",
                source_col.data_type()
            ))
        })?;

    let cast_values = cast_column(
        source_list.values(),
        target_inner_field.data_type(),
        cast_options,
    )?;

    // `try_new` instead of `new`: `new` panics when arrow's validation fails,
    // and nothing on this path has checked the item field's nullability
    // against the cast values. This also matches `cast_list_view_column`.
    let result = GenericListArray::<O>::try_new(
        Arc::clone(target_inner_field),
        source_list.offsets().clone(),
        cast_values,
        source_list.nulls().cloned(),
    )?;
    Ok(Arc::new(result))
}
fn cast_list_view_column<O: arrow::array::OffsetSizeTrait>(
source_col: &ArrayRef,
target_inner_field: &FieldRef,
cast_options: &CastOptions,
) -> Result<ArrayRef> {
let source_list = source_col
.as_any()
.downcast_ref::<GenericListViewArray<O>>()
.ok_or_else(|| {
crate::error::DataFusionError::Plan(format!(
"Expected list view array but got {}",
source_col.data_type()
))
})?;
let cast_values = cast_column(
source_list.values(),
target_inner_field.data_type(),
cast_options,
)?;
let result = GenericListViewArray::<O>::try_new(
Arc::clone(target_inner_field),
source_list.offsets().clone(),
source_list.sizes().clone(),
cast_values,
source_list.nulls().cloned(),
)?;
Ok(Arc::new(result))
}
/// Casts a dictionary array to a dictionary with `target_value_type` values
/// and `target_key_type` keys.
///
/// The cast happens in two steps:
/// 1. the dictionary *values* are cast recursively with [`cast_column`] while
///    the original keys are kept;
/// 2. if the key types differ, the intermediate dictionary is handed to
///    arrow's `cast_with_options` to convert the keys.
///
/// # Errors
///
/// Returns an error when `source_key_type` is not an integer type, when the
/// value cast fails, or when the final key cast fails.
///
/// # Panics
///
/// Panics if `source_col` is not a `DictionaryArray` keyed by
/// `source_key_type`.
fn cast_dictionary_column(
    source_col: &ArrayRef,
    source_key_type: &DataType,
    target_key_type: &DataType,
    target_value_type: &DataType,
    cast_options: &CastOptions,
) -> Result<ArrayRef> {
    // Expanded once per integer key type by `downcast_integer!` below.
    macro_rules! recast_values {
        ($key:ty) => {{
            let dict = source_col
                .as_any()
                .downcast_ref::<DictionaryArray<$key>>()
                .expect("downcast must succeed");
            let new_values =
                cast_column(dict.values(), target_value_type, cast_options)?;
            let rebuilt: ArrayRef =
                Arc::new(DictionaryArray::<$key>::new(dict.keys().clone(), new_values));
            Ok(rebuilt)
        }};
    }

    let values_recast: Result<ArrayRef> = downcast_integer! {
        source_key_type => (recast_values),
        k => _plan_err!("Unsupported dictionary key type: {k}")
    };
    let values_recast = values_recast?;

    // Keys already have the requested type: nothing more to do.
    if source_key_type == target_key_type {
        return Ok(values_recast);
    }

    let target_dict_type = DataType::Dictionary(
        Box::new(target_key_type.clone()),
        Box::new(target_value_type.clone()),
    );
    let rekeyed = cast_with_options(&values_recast, &target_dict_type, cast_options)?;
    Ok(rekeyed)
}
/// Checks whether a struct with `source_fields` can be cast to a struct with
/// `target_fields`.
///
/// Fields are paired by name. The cast is accepted when
/// - at least one field name is shared between source and target,
/// - every target field found in the source passes the per-field
///   compatibility check, and
/// - every target field missing from the source is nullable (so it can be
///   filled with nulls).
///
/// Source fields that no target field refers to are not inspected.
///
/// # Errors
///
/// Returns a plan error describing the first violated rule.
pub fn validate_struct_compatibility(
    source_fields: &[FieldRef],
    target_fields: &[FieldRef],
) -> Result<()> {
    if !has_one_of_more_common_fields(source_fields, target_fields) {
        return _plan_err!(
            "Cannot cast struct with {} fields to {} fields because there is no field name overlap",
            source_fields.len(),
            target_fields.len()
        );
    }

    for target_field in target_fields {
        // The first source field with a matching name is the cast partner.
        let matching_source = source_fields
            .iter()
            .find(|candidate| candidate.name() == target_field.name());

        match matching_source {
            Some(source_field) => {
                validate_field_compatibility(source_field, target_field)?
            }
            // Missing but nullable: the column can be filled with nulls.
            None if target_field.is_nullable() => {}
            None => {
                return _plan_err!(
                    "Cannot cast struct: target field '{}' is non-nullable but missing from source. \
                     Cannot fill with NULL.",
                    target_field.name()
                );
            }
        }
    }

    Ok(())
}
/// Checks that a single `source_field` can be cast to `target_field`.
///
/// - A `Null`-typed source is accepted only when the target is nullable; no
///   further type check is made in that case.
/// - A nullable source cannot be cast to a non-nullable target.
/// - Otherwise the data types are compared with
///   [`validate_data_type_compatibility`].
///
/// # Errors
///
/// Returns a plan error naming the offending field when a rule is violated.
fn validate_field_compatibility(
    source_field: &Field,
    target_field: &Field,
) -> Result<()> {
    let target_nullable = target_field.is_nullable();

    if matches!(source_field.data_type(), DataType::Null) {
        return if target_nullable {
            Ok(())
        } else {
            _plan_err!(
                "Cannot cast NULL struct field '{}' to non-nullable field '{}'",
                source_field.name(),
                target_field.name()
            )
        };
    }

    if !target_nullable && source_field.is_nullable() {
        return _plan_err!(
            "Cannot cast nullable struct field '{}' to non-nullable field",
            target_field.name()
        );
    }

    validate_data_type_compatibility(
        target_field.name(),
        source_field.data_type(),
        target_field.data_type(),
    )
}
/// Checks that `source_type` can be cast to `target_type`, descending into
/// nested containers.
///
/// - struct to struct: delegated to [`validate_struct_compatibility`];
/// - list-like to the same list-like kind: the item fields are compared;
/// - dictionary to dictionary: the key types must be castable according to
///   arrow's `can_cast_types`, and the value types are checked recursively;
/// - anything else: accepted when arrow's `can_cast_types` allows the cast.
///
/// `field_name` is only used to build error messages.
///
/// # Errors
///
/// Returns a plan error when the types are not compatible.
pub fn validate_data_type_compatibility(
    field_name: &str,
    source_type: &DataType,
    target_type: &DataType,
) -> Result<()> {
    match (source_type, target_type) {
        (Struct(source_nested), Struct(target_nested)) => {
            validate_struct_compatibility(source_nested, target_nested)
        }
        (DataType::List(src_item), DataType::List(tgt_item))
        | (DataType::LargeList(src_item), DataType::LargeList(tgt_item))
        | (DataType::ListView(src_item), DataType::ListView(tgt_item))
        | (DataType::LargeListView(src_item), DataType::LargeListView(tgt_item)) => {
            validate_field_compatibility(src_item, tgt_item)
        }
        (
            DataType::Dictionary(src_key, src_value),
            DataType::Dictionary(tgt_key, tgt_value),
        ) => {
            if can_cast_types(src_key, tgt_key) {
                validate_data_type_compatibility(field_name, src_value, tgt_value)
            } else {
                _plan_err!(
                    "Cannot cast dictionary key type {} to {} for field '{}'",
                    src_key,
                    tgt_key,
                    field_name
                )
            }
        }
        _ if can_cast_types(source_type, target_type) => Ok(()),
        _ => _plan_err!(
            "Cannot cast struct field '{}' from type {} to type {}",
            field_name,
            source_type,
            target_type
        ),
    }
}
/// Reports whether casting `source_type` to `target_type` involves a
/// struct-to-struct cast at some nesting level.
///
/// The pair is unwrapped level by level: matching list-like containers step
/// into their item types and dictionary pairs step into their value types.
/// The answer is `true` as soon as both sides are structs and `false` as
/// soon as any other combination is reached.
pub fn requires_nested_struct_cast(
    source_type: &DataType,
    target_type: &DataType,
) -> bool {
    let mut current = (source_type, target_type);
    loop {
        current = match current {
            (Struct(_), Struct(_)) => return true,
            (DataType::List(src), DataType::List(tgt))
            | (DataType::LargeList(src), DataType::LargeList(tgt))
            | (DataType::ListView(src), DataType::ListView(tgt))
            | (DataType::LargeListView(src), DataType::LargeListView(tgt)) => {
                (src.data_type(), tgt.data_type())
            }
            (DataType::Dictionary(_, src_value), DataType::Dictionary(_, tgt_value)) => {
                (&**src_value, &**tgt_value)
            }
            _ => return false,
        };
    }
}
/// Returns `true` when at least one field name appears in both
/// `source_fields` and `target_fields`.
///
/// Names are compared exactly; field types and nullability are ignored.
pub fn has_one_of_more_common_fields(
    source_fields: &[FieldRef],
    target_fields: &[FieldRef],
) -> bool {
    let mut known_source_names: HashSet<&str> =
        HashSet::with_capacity(source_fields.len());
    for source in source_fields {
        known_source_names.insert(source.name().as_str());
    }

    for target in target_fields {
        if known_source_names.contains(target.name().as_str()) {
            return true;
        }
    }
    false
}
// Unit tests for the nested struct casting and validation helpers above.
#[cfg(test)]
mod tests {
use super::*;
use crate::{assert_contains, format::DEFAULT_CAST_OPTIONS};
use arrow::{
array::{
BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, ListViewArray,
MapArray, MapBuilder, NullArray, StringArray, StringBuilder,
},
buffer::{NullBuffer, ScalarBuffer},
datatypes::{DataType, Field, FieldRef, Int32Type},
};
// Looks up the child column `$column_name` of a struct array and downcasts it
// to `$array_type`; panics if the column is missing or has another type.
macro_rules! get_column_as {
($struct_array:expr, $column_name:expr, $array_type:ty) => {
$struct_array
.column_by_name($column_name)
.unwrap()
.as_any()
.downcast_ref::<$array_type>()
.unwrap()
};
}
// Builds a nullable field.
fn field(name: &str, data_type: DataType) -> Field {
Field::new(name, data_type, true)
}
// Builds a non-nullable field.
fn non_null_field(name: &str, data_type: DataType) -> Field {
Field::new(name, data_type, false)
}
// Builds a nullable field wrapped in an `Arc`.
fn arc_field(name: &str, data_type: DataType) -> FieldRef {
Arc::new(field(name, data_type))
}
// Builds a struct data type from the given child fields.
fn struct_type(fields: Vec<Field>) -> DataType {
Struct(fields.into())
}
// Builds a nullable field whose type is a struct of the given child fields.
fn struct_field(name: &str, fields: Vec<Field>) -> Field {
field(name, struct_type(fields))
}
// Same as `struct_field`, wrapped in an `Arc`.
fn arc_struct_field(name: &str, fields: Vec<Field>) -> FieldRef {
Arc::new(struct_field(name, fields))
}
// A non-nested Int32 -> Int64 cast keeps every value.
#[test]
fn test_cast_simple_column() {
let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
let target_field = field("ints", DataType::Int64);
let result =
cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
.unwrap();
let result = result.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result.value(0), 1);
assert_eq!(result.value(1), 2);
assert_eq!(result.value(2), 3);
}
// The cast options are honoured: an overflowing Int64 -> Int32 cast fails with
// `safe: false` and yields a null with `safe: true`.
// NOTE: the local names are the reverse of the flag they set: `safe_opts` has
// `safe: false` and `unsafe_opts` has `safe: true`.
#[test]
fn test_cast_column_with_options() {
let source = Arc::new(Int64Array::from(vec![1, i64::MAX])) as ArrayRef;
let target_field = field("ints", DataType::Int32);
let safe_opts = CastOptions {
safe: false,
..DEFAULT_CAST_OPTIONS
};
assert!(cast_column(&source, target_field.data_type(), &safe_opts).is_err());
let unsafe_opts = CastOptions {
safe: true,
..DEFAULT_CAST_OPTIONS
};
let result =
cast_column(&source, target_field.data_type(), &unsafe_opts).unwrap();
let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(result.value(0), 1);
assert!(result.is_null(1));
}
// A nullable target field that is absent from the source is filled with nulls.
#[test]
fn test_cast_struct_with_missing_field() {
let a_array = Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef;
let source_struct = StructArray::from(vec![(
arc_field("a", DataType::Int32),
Arc::clone(&a_array),
)]);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field(
"s",
vec![field("a", DataType::Int32), field("b", DataType::Utf8)],
);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
.unwrap();
let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
assert_eq!(struct_array.fields().len(), 2);
let a_result = get_column_as!(&struct_array, "a", Int32Array);
assert_eq!(a_result.value(0), 1);
assert_eq!(a_result.value(1), 2);
let b_result = get_column_as!(&struct_array, "b", StringArray);
assert_eq!(b_result.len(), 2);
assert!(b_result.is_null(0));
assert!(b_result.is_null(1));
}
// Casting a non-struct (and not all-null) source to a struct target is an error.
#[test]
fn test_cast_struct_source_not_struct() {
let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef;
let target_field = struct_field("s", vec![field("a", DataType::Int32)]);
let result =
cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("Cannot cast column of type"));
assert!(error_msg.contains("to struct type"));
assert!(error_msg.contains("Source must be a struct"));
}
// A child whose type cannot be cast (Binary -> Int32) makes the struct cast fail.
#[test]
fn test_cast_struct_incompatible_child_type() {
let a_array = Arc::new(BinaryArray::from(vec![
Some(b"a".as_ref()),
Some(b"b".as_ref()),
])) as ArrayRef;
let source_struct =
StructArray::from(vec![(arc_field("a", DataType::Binary), a_array)]);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field("s", vec![field("a", DataType::Int32)]);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("Cannot cast struct field 'a'"));
}
// Validation reports the field name and both types for an impossible cast.
#[test]
fn test_validate_struct_compatibility_incompatible_types() {
let source_fields = vec![
arc_field("field1", DataType::Binary),
arc_field("field2", DataType::Utf8),
];
let target_fields = vec![arc_field("field1", DataType::Int32)];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("Cannot cast struct field 'field1'"));
assert!(error_msg.contains("Binary"));
assert!(error_msg.contains("Int32"));
}
// Castable child types (Int32 -> Int64) pass validation.
#[test]
fn test_validate_struct_compatibility_compatible_types() {
let source_fields = vec![
arc_field("field1", DataType::Int32),
arc_field("field2", DataType::Utf8),
];
let target_fields = vec![arc_field("field1", DataType::Int64)];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_ok());
}
// A nullable target field missing from the source is accepted.
#[test]
fn test_validate_struct_compatibility_missing_field_in_source() {
let source_fields = vec![arc_field("field1", DataType::Int32)];
let target_fields = vec![
arc_field("field1", DataType::Int32),
arc_field("field2", DataType::Utf8),
];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_ok());
}
// Extra source fields that the target does not name are accepted.
#[test]
fn test_validate_struct_compatibility_additional_field_in_source() {
let source_fields = vec![
arc_field("field1", DataType::Int32),
arc_field("field2", DataType::Utf8),
];
let target_fields = vec![arc_field("field1", DataType::Int32)];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_ok());
}
// No shared field name is rejected (source and target differ in field count).
#[test]
fn test_validate_struct_compatibility_no_overlap_mismatch_len() {
let source_fields = vec![
arc_field("left", DataType::Int32),
arc_field("right", DataType::Int32),
];
let target_fields = vec![arc_field("alpha", DataType::Int32)];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert_contains!(error_msg, "no field name overlap");
}
// The validity buffer of the source struct is kept on the cast result.
#[test]
fn test_cast_struct_parent_nulls_retained() {
let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
let fields = vec![arc_field("a", DataType::Int32)];
let nulls = Some(NullBuffer::from(vec![true, false]));
let source_struct = StructArray::new(fields.clone().into(), vec![a_array], nulls);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field("s", vec![field("a", DataType::Int64)]);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
.unwrap();
let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
assert_eq!(struct_array.null_count(), 1);
assert!(struct_array.is_valid(0));
assert!(struct_array.is_null(1));
let a_result = get_column_as!(&struct_array, "a", Int64Array);
assert_eq!(a_result.value(0), 1);
assert_eq!(a_result.value(1), 2);
}
// A nullable source field cannot be cast to a non-nullable target field.
#[test]
fn test_validate_struct_compatibility_nullable_to_non_nullable() {
let source_fields = vec![arc_field("field1", DataType::Int32)];
let target_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("field1"));
assert!(error_msg.contains("non-nullable"));
}
// A non-nullable source field may be cast to a nullable target field.
#[test]
fn test_validate_struct_compatibility_non_nullable_to_nullable() {
let source_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))];
let target_fields = vec![arc_field("field1", DataType::Int32)];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_ok());
}
// The nullability rule is also enforced for fields of a nested struct.
#[test]
fn test_validate_struct_compatibility_nested_nullable_to_non_nullable() {
let source_fields = vec![Arc::new(non_null_field(
"field1",
struct_type(vec![field("nested", DataType::Int32)]),
))];
let target_fields = vec![Arc::new(non_null_field(
"field1",
struct_type(vec![non_null_field("nested", DataType::Int32)]),
))];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert!(error_msg.contains("nested"));
assert!(error_msg.contains("non-nullable"));
}
// Fields are paired by name, so a different field order is accepted.
#[test]
fn test_validate_struct_compatibility_by_name() {
let source_fields = vec![
arc_field("field1", DataType::Int32),
arc_field("field2", DataType::Utf8),
];
let target_fields = vec![
arc_field("field2", DataType::Utf8),
arc_field("field1", DataType::Int64),
];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_ok());
}
// A name match with an impossible type cast yields the full error message.
#[test]
fn test_validate_struct_compatibility_by_name_with_type_mismatch() {
let source_fields = vec![arc_field("field1", DataType::Binary)];
let target_fields = vec![arc_field("field1", DataType::Int32)];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert_contains!(
error_msg,
"Cannot cast struct field 'field1' from type Binary to type Int32"
);
}
// No shared field name is rejected even when the field counts are equal.
#[test]
fn test_validate_struct_compatibility_no_overlap_equal_len() {
let source_fields = vec![
arc_field("left", DataType::Int32),
arc_field("right", DataType::Utf8),
];
let target_fields = vec![
arc_field("alpha", DataType::Int32),
arc_field("beta", DataType::Utf8),
];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert_contains!(error_msg, "no field name overlap");
}
// Shared, extra (source-only) and missing (target-only) fields together pass.
#[test]
fn test_validate_struct_compatibility_mixed_name_overlap() {
let source_fields = vec![
arc_field("a", DataType::Int32),
arc_field("b", DataType::Utf8),
arc_field("extra", DataType::Boolean),
];
let target_fields = vec![
arc_field("b", DataType::Utf8),
arc_field("a", DataType::Int64),
arc_field("c", DataType::Float32),
];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_ok());
}
// A non-nullable target field missing from the source is rejected.
#[test]
fn test_validate_struct_compatibility_by_name_missing_required_field() {
let source_fields = vec![arc_field("field1", DataType::Int32)];
let target_fields = vec![
arc_field("field1", DataType::Int32),
Arc::new(non_null_field("field2", DataType::Int32)),
];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert_contains!(
error_msg,
"Cannot cast struct: target field 'field2' is non-nullable but missing from source. Cannot fill with NULL."
);
}
// Partial name overlap with a different number of fields is accepted.
#[test]
fn test_validate_struct_compatibility_partial_name_overlap_with_count_mismatch() {
let source_fields = vec![arc_field("a", DataType::Int32)];
let target_fields = vec![
arc_field("a", DataType::Int32),
arc_field("b", DataType::Utf8),
];
let result = validate_struct_compatibility(&source_fields, &target_fields);
assert!(result.is_ok());
}
// Nested struct cast: children are reordered and cast, the extra source field
// is dropped and the missing target field is filled with nulls.
#[test]
fn test_cast_nested_struct_with_extra_and_missing_fields() {
let a = Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef;
let b = Arc::new(Int32Array::from(vec![Some(2), Some(3)])) as ArrayRef;
let extra = Arc::new(Int32Array::from(vec![Some(9), Some(10)])) as ArrayRef;
let inner = StructArray::from(vec![
(arc_field("a", DataType::Int32), a),
(arc_field("b", DataType::Int32), b),
(arc_field("extra", DataType::Int32), extra),
]);
let source_struct = StructArray::from(vec![(
arc_struct_field(
"inner",
vec![
field("a", DataType::Int32),
field("b", DataType::Int32),
field("extra", DataType::Int32),
],
),
Arc::new(inner) as ArrayRef,
)]);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field(
"outer",
vec![struct_field(
"inner",
vec![
field("b", DataType::Int64),
field("a", DataType::Int32),
field("missing", DataType::Int32),
],
)],
);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
.unwrap();
let outer = result.as_any().downcast_ref::<StructArray>().unwrap();
let inner = get_column_as!(&outer, "inner", StructArray);
assert_eq!(inner.fields().len(), 3);
let b = get_column_as!(inner, "b", Int64Array);
assert_eq!(b.value(0), 2);
assert_eq!(b.value(1), 3);
assert!(!b.is_null(0));
assert!(!b.is_null(1));
let a = get_column_as!(inner, "a", Int32Array);
assert_eq!(a.value(0), 1);
assert!(a.is_null(1));
let missing = get_column_as!(inner, "missing", Int32Array);
assert!(missing.is_null(0));
assert!(missing.is_null(1));
}
// A `Null`-typed child cast to a struct target becomes an all-null struct.
#[test]
fn test_cast_null_struct_field_to_nested_struct() {
let null_inner = Arc::new(NullArray::new(2)) as ArrayRef;
let source_struct = StructArray::from(vec![(
arc_field("inner", DataType::Null),
Arc::clone(&null_inner),
)]);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field(
"outer",
vec![struct_field("inner", vec![field("a", DataType::Int32)])],
);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
.unwrap();
let outer = result.as_any().downcast_ref::<StructArray>().unwrap();
let inner = get_column_as!(&outer, "inner", StructArray);
assert_eq!(inner.len(), 2);
assert!(inner.is_null(0));
assert!(inner.is_null(1));
let inner_a = get_column_as!(inner, "a", Int32Array);
assert!(inner_a.is_null(0));
assert!(inner_a.is_null(1));
}
// List and map children with identical source and target types keep their
// values and nulls through the struct cast.
#[test]
fn test_cast_struct_with_array_and_map_fields() {
let arr_array = Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(1), Some(2)]),
None,
])) as ArrayRef;
let string_builder = StringBuilder::new();
let int_builder = Int32Builder::new();
let mut map_builder = MapBuilder::new(None, string_builder, int_builder);
map_builder.keys().append_value("a");
map_builder.values().append_value(1);
map_builder.append(true).unwrap();
map_builder.append(false).unwrap();
let map_array = Arc::new(map_builder.finish()) as ArrayRef;
let source_struct = StructArray::from(vec![
(
arc_field(
"arr",
DataType::List(Arc::new(field("item", DataType::Int32))),
),
arr_array,
),
(
arc_field(
"map",
DataType::Map(
Arc::new(non_null_field(
"entries",
struct_type(vec![
non_null_field("keys", DataType::Utf8),
field("values", DataType::Int32),
]),
)),
false,
),
),
map_array,
),
]);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field(
"s",
vec![
field(
"arr",
DataType::List(Arc::new(field("item", DataType::Int32))),
),
field(
"map",
DataType::Map(
Arc::new(non_null_field(
"entries",
struct_type(vec![
non_null_field("keys", DataType::Utf8),
field("values", DataType::Int32),
]),
)),
false,
),
),
],
);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
.unwrap();
let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
let arr = get_column_as!(&struct_array, "arr", ListArray);
assert!(!arr.is_null(0));
assert!(arr.is_null(1));
let arr0 = arr.value(0);
let values = arr0.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(values.value(0), 1);
assert_eq!(values.value(1), 2);
let map = get_column_as!(&struct_array, "map", MapArray);
assert!(!map.is_null(0));
assert!(map.is_null(1));
let map0 = map.value(0);
let entries = map0.as_any().downcast_ref::<StructArray>().unwrap();
let keys = get_column_as!(entries, "keys", StringArray);
let vals = get_column_as!(entries, "values", Int32Array);
assert_eq!(keys.value(0), "a");
assert_eq!(vals.value(0), 1);
}
// Children are matched by name, so a target with swapped field order gets the
// right data in each column.
#[test]
fn test_cast_struct_field_order_differs() {
let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
let b = Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef;
let source_struct = StructArray::from(vec![
(arc_field("a", DataType::Int32), a),
(arc_field("b", DataType::Int32), b),
]);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field(
"s",
vec![field("b", DataType::Int64), field("a", DataType::Int32)],
);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
.unwrap();
let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
let b_col = get_column_as!(&struct_array, "b", Int64Array);
assert_eq!(b_col.value(0), 3);
assert!(b_col.is_null(1));
let a_col = get_column_as!(&struct_array, "a", Int32Array);
assert_eq!(a_col.value(0), 1);
assert_eq!(a_col.value(1), 2);
}
// A struct cast with no shared field names fails; there is no positional fallback.
#[test]
fn test_cast_struct_no_overlap_rejected() {
let first = Arc::new(Int32Array::from(vec![Some(10), Some(20)])) as ArrayRef;
let second =
Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) as ArrayRef;
let source_struct = StructArray::from(vec![
(arc_field("left", DataType::Int32), first),
(arc_field("right", DataType::Utf8), second),
]);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field(
"s",
vec![field("a", DataType::Int64), field("b", DataType::Utf8)],
);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS);
assert!(result.is_err());
let error_msg = result.unwrap_err().to_string();
assert_contains!(error_msg, "no field name overlap");
}
// A non-nullable target field missing from the source makes the cast fail.
#[test]
fn test_cast_struct_missing_non_nullable_field_fails() {
let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field(
"s",
vec![
field("a", DataType::Int32),
non_null_field("b", DataType::Int32),
],
);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string()
.contains("target field 'b' is non-nullable but missing from source"),
"Unexpected error: {err}"
);
}
// A nullable target field missing from the source is filled with nulls.
#[test]
fn test_cast_struct_missing_nullable_field_succeeds() {
let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]);
let source_col = Arc::new(source_struct) as ArrayRef;
let target_field = struct_field(
"s",
vec![field("a", DataType::Int32), field("b", DataType::Int32)],
);
let result =
cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
.unwrap();
let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
let a_col = get_column_as!(&struct_array, "a", Int32Array);
assert_eq!(a_col.value(0), 1);
assert_eq!(a_col.value(1), 2);
let b_col = get_column_as!(&struct_array, "b", Int32Array);
assert!(b_col.is_null(0));
assert!(b_col.is_null(1));
}
// Dictionary values that are structs may gain a nullable field in the target.
#[test]
fn test_validate_dictionary_value_evolution() {
let source_inner = struct_type(vec![field("a", DataType::Int32)]);
let target_inner = struct_type(vec![
field("a", DataType::Int32),
field("b", DataType::Utf8),
]);
let source =
DataType::Dictionary(Box::new(DataType::Int32), Box::new(source_inner));
let target =
DataType::Dictionary(Box::new(DataType::Int32), Box::new(target_inner));
assert!(validate_data_type_compatibility("col", &source, &target).is_ok());
}
// Casting a dictionary of structs casts the values, keeps the key validity
// and fills the added target field with nulls.
#[test]
fn test_cast_dictionary_struct_value() {
let struct_arr = StructArray::from(vec![(
arc_field("a", DataType::Int32),
Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef,
)]);
let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
let source_dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(struct_arr));
let source_col: ArrayRef = Arc::new(source_dict);
let target_type = DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(struct_type(vec![
field("a", DataType::Int64),
field("b", DataType::Utf8),
])),
);
let result =
cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap();
let result_dict = result
.as_any()
.downcast_ref::<DictionaryArray<Int32Type>>()
.unwrap();
assert!(result_dict.is_valid(0));
assert!(result_dict.is_null(1));
assert!(result_dict.is_valid(2));
let struct_values = result_dict
.values()
.as_any()
.downcast_ref::<StructArray>()
.unwrap();
let a_col = get_column_as!(&struct_values, "a", Int64Array);
assert_eq!(a_col.values(), &[10, 20]);
let b_col = get_column_as!(&struct_values, "b", StringArray);
assert!(b_col.iter().all(|v| v.is_none()));
}
// Casting a list view of structs casts the values and fills the added target
// field with nulls; the number of rows is unchanged.
#[test]
fn test_cast_list_view_struct() {
let struct_arr = StructArray::from(vec![(
arc_field("a", DataType::Int32),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
)]);
let source_field =
arc_field("item", struct_type(vec![field("a", DataType::Int32)]));
let target_field = arc_field(
"item",
struct_type(vec![
field("a", DataType::Int64),
field("b", DataType::Utf8),
]),
);
let list_view = ListViewArray::new(
source_field,
ScalarBuffer::from(vec![0i32, 2]),
ScalarBuffer::from(vec![2i32, 1]),
Arc::new(struct_arr),
None,
);
let source_col: ArrayRef = Arc::new(list_view);
let target_type = DataType::ListView(target_field);
let result =
cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap();
let result_lv = result.as_any().downcast_ref::<ListViewArray>().unwrap();
assert_eq!(result_lv.len(), 2);
let struct_values = result_lv
.values()
.as_any()
.downcast_ref::<StructArray>()
.unwrap();
let a_col = get_column_as!(&struct_values, "a", Int64Array);
assert_eq!(a_col.values(), &[1, 2, 3]);
let b_col = get_column_as!(&struct_values, "b", StringArray);
assert!(b_col.iter().all(|v| v.is_none()));
}
// `requires_nested_struct_cast` is true for struct pairs, including those
// nested in lists, list views and dictionary values, and false otherwise.
#[test]
fn test_requires_nested_struct_cast() {
let s1 = struct_type(vec![field("a", DataType::Int32)]);
let s2 = struct_type(vec![field("a", DataType::Int64)]);
assert!(requires_nested_struct_cast(&s1, &s2));
assert!(requires_nested_struct_cast(
&DataType::List(arc_field("item", s1.clone())),
&DataType::List(arc_field("item", s2.clone())),
));
assert!(requires_nested_struct_cast(
&DataType::Dictionary(Box::new(DataType::Int32), Box::new(s1.clone())),
&DataType::Dictionary(Box::new(DataType::Int32), Box::new(s2.clone())),
));
assert!(requires_nested_struct_cast(
&DataType::ListView(arc_field("item", s1)),
&DataType::ListView(arc_field("item", s2)),
));
assert!(!requires_nested_struct_cast(
&DataType::Int32,
&DataType::Int64
));
assert!(!requires_nested_struct_cast(
&DataType::List(arc_field("item", DataType::Int32)),
&DataType::List(arc_field("item", DataType::Int64)),
));
}
}