use crate::cast::can_cast_types;
use crate::cast_with_options;
use arrow_array::{Array, ArrayRef, UnionArray};
use arrow_schema::{ArrowError, DataType, FieldRef, UnionFields};
use arrow_select::union_extract::union_extract_by_id;
use super::CastOptions;
/// Reports whether `a` and `b` belong to the same loosely-compatible type
/// family.
///
/// The recognised families are:
/// - strings: `Utf8`, `LargeUtf8`, `Utf8View`
/// - binaries: `Binary`, `LargeBinary`, `BinaryView`
/// - signed integers: `Int8` … `Int64`
/// - unsigned integers: `UInt8` … `UInt64`
/// - floats: `Float16`, `Float32`, `Float64`
///
/// Any type outside these groups belongs to no family. Two such types
/// compare as `false`, even when they are identical.
fn same_type_family(a: &DataType, b: &DataType) -> bool {
    #[derive(PartialEq, Eq)]
    enum Family {
        Text,
        Bytes,
        SignedInt,
        UnsignedInt,
        Float,
    }

    /// Maps a data type to its family, or `None` if it has none.
    fn family_of(t: &DataType) -> Option<Family> {
        use DataType::*;
        let family = match t {
            Utf8 | LargeUtf8 | Utf8View => Family::Text,
            Binary | LargeBinary | BinaryView => Family::Bytes,
            Int8 | Int16 | Int32 | Int64 => Family::SignedInt,
            UInt8 | UInt16 | UInt32 | UInt64 => Family::UnsignedInt,
            Float16 | Float32 | Float64 => Family::Float,
            _ => return None,
        };
        Some(family)
    }

    match (family_of(a), family_of(b)) {
        (Some(left), Some(right)) => left == right,
        _ => false,
    }
}
/// Chooses which child of a union should be used to produce `target_type`.
///
/// Returns the child's type id together with its field. Candidates are
/// searched in three passes, and the first pass that finds anything wins.
/// Within a pass, the earliest field in `fields` order wins.
///
/// 1. A child whose data type equals `target_type` exactly.
/// 2. A child in the same type family as `target_type` (see
///    `same_type_family`), e.g. a `Utf8` child for a `Utf8View` target.
/// 3. Any child that `can_cast_types` reports as castable to `target_type`.
///    This pass is skipped entirely when `target_type` is nested.
///
/// Returns `None` when none of the passes finds a candidate.
pub(crate) fn resolve_child_array<'a>(
    fields: &'a UnionFields,
    target_type: &DataType,
) -> Option<(i8, &'a FieldRef)> {
    // Pass 1: exact type match.
    for (type_id, field) in fields.iter() {
        if field.data_type() == target_type {
            return Some((type_id, field));
        }
    }

    // Pass 2: same family (e.g. Utf8 -> Utf8View, Int32 -> Int64).
    for (type_id, field) in fields.iter() {
        if same_type_family(field.data_type(), target_type) {
            return Some((type_id, field));
        }
    }

    // Pass 3: generic cast fallback, never used for nested targets.
    if target_type.is_nested() {
        return None;
    }
    for (type_id, field) in fields.iter() {
        if can_cast_types(field.data_type(), target_type) {
            return Some((type_id, field));
        }
    }

    None
}
/// Casts a union array to `target_type` by extracting a single child.
///
/// The child is chosen by [`resolve_child_array`] and pulled out with
/// `union_extract_by_id`. In this module's tests, slots whose type id
/// selects a different child come out null. If the extracted array already
/// has `target_type`, it is returned unchanged. Otherwise it is cast with
/// `cast_options`.
///
/// # Errors
///
/// - `ArrowError::CastError` when no child of the union can be resolved
///   for `target_type`. The message lists the data types of all children.
/// - Any error returned by `union_extract_by_id` or `cast_with_options`.
///
/// # Panics
///
/// Panics via `unreachable!` if `union_array.data_type()` is not
/// `DataType::Union`. This is not expected for a `UnionArray`.
pub fn union_extract_by_type(
    union_array: &UnionArray,
    target_type: &DataType,
    cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
    let DataType::Union(fields, _) = union_array.data_type() else {
        unreachable!("union_extract_by_type called on non-union array")
    };

    let (type_id, _) = resolve_child_array(fields, target_type).ok_or_else(|| {
        let child_types: Vec<String> = fields
            .iter()
            .map(|(_, field)| field.data_type().to_string())
            .collect();
        ArrowError::CastError(format!(
            "cannot cast Union with fields {} to {}",
            child_types.join(", "),
            target_type
        ))
    })?;

    let child = union_extract_by_id(union_array, type_id)?;
    if child.data_type() != target_type {
        return cast_with_options(&child, target_type, cast_options);
    }
    Ok(child)
}
// Tests for casting a union array to a single target type via
// `union_extract_by_type` (exercised through the public `cast::cast`).
// Each test checks both `can_cast_types` and the actual cast result,
// usually for both sparse and dense union layouts.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cast;
    use arrow_array::*;
    use arrow_schema::{Field, UnionFields, UnionMode};
    use std::sync::Arc;

    // Union fields shared by most tests:
    // type id 0 -> "int" (Int32), type id 1 -> "str" (Utf8).
    fn int_str_fields() -> UnionFields {
        UnionFields::try_new(
            [0, 1],
            [
                Field::new("int", DataType::Int32, true),
                Field::new("str", DataType::Utf8, true),
            ],
        )
        .unwrap()
    }

    // The Union data type built from `int_str_fields` with the given mode.
    fn int_str_union_type(mode: UnionMode) -> DataType {
        DataType::Union(int_str_fields(), mode)
    }

    // The target type equals one child's type exactly (Utf8), so that child
    // is extracted as-is. Slots selecting the Int32 child become null.
    #[test]
    fn test_exact_type_match() {
        let target = DataType::Utf8;
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));
        // Sparse: every child has the full union length (3). The type ids
        // pick which child is live in each slot.
        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, Some(42), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            ],
        )
        .unwrap();
        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));
        // Dense: children hold only their own values. The offsets index into
        // the child chosen by the corresponding type id.
        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), Some("world")])),
            ],
        )
        .unwrap();
        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    // No child is exactly Utf8View, but the Utf8 child is in the same string
    // family. It is selected and then cast to Utf8View.
    #[test]
    fn test_same_family_utf8_to_utf8view() {
        let target = DataType::Utf8View;
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));
        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, Some(42), None, None])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    Some("agent_alpha"),
                    None,
                    Some("agent_beta"),
                    None,
                ])),
            ],
        )
        .unwrap();
        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(arr.value(0), "agent_alpha");
        // Slot 1 selects the Int32 child, so it is null in the output.
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "agent_beta");
        // Slot 3 selects the Utf8 child, but the child value itself is null.
        assert!(arr.is_null(3));
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));
        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 0, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])),
            ],
        )
        .unwrap();
        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();
        assert_eq!(arr.value(0), "alpha");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "beta");
    }

    // No child is Boolean or in a Boolean family, so the generic cast
    // fallback is used. The assertions show the Int32 child is the one
    // selected (42 -> true, 0 -> false).
    #[test]
    fn test_one_directional_cast() {
        let target = DataType::Boolean;
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));
        let sparse = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1, 0].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None, Some(0)])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("hello"), None])),
            ],
        )
        .unwrap();
        let result = cast::cast(&sparse, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(arr.value(0));
        // Slot 1 selects the Utf8 child, which was not chosen, so it is null.
        assert!(arr.is_null(1));
        assert!(!arr.value(2));
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Dense),
            &target
        ));
        let dense = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1, 0].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42), Some(0)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello")])),
            ],
        )
        .unwrap();
        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(arr.value(0));
        assert!(arr.is_null(1));
        assert!(!arr.value(2));
    }

    // Both children are named "val". Resolution is by data type, not by
    // name, so the Utf8 child is still picked for a Utf8 target.
    #[test]
    fn test_duplicate_field_names() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("val", DataType::Int32, true),
                Field::new("val", DataType::Utf8, true),
            ],
        )
        .unwrap();
        let target = DataType::Utf8;
        let sparse = UnionArray::try_new(
            fields.clone(),
            vec![0_i8, 1, 0, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None, Some(99), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![
                    None,
                    Some("hello"),
                    None,
                    Some("world"),
                ])),
            ],
        )
        .unwrap();
        let result = cast::cast(&sparse, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "hello");
        assert!(arr.is_null(2));
        assert_eq!(arr.value(3), "world");
        let dense = UnionArray::try_new(
            fields,
            vec![0_i8, 1, 1].into(),
            Some(vec![0_i32, 0, 1].into()),
            vec![
                Arc::new(Int32Array::from(vec![Some(42)])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), Some("world")])),
            ],
        )
        .unwrap();
        let result = cast::cast(&dense, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "hello");
        assert_eq!(arr.value(2), "world");
    }

    // A Struct target is nested and matches no child exactly or by family.
    // The cast fallback is skipped for nested targets, so both
    // `can_cast_types` and the cast itself reject it.
    #[test]
    fn test_no_match_errors() {
        let target = DataType::Struct(vec![Field::new("x", DataType::Int32, true)].into());
        assert!(!can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));
        let union = UnionArray::try_new(
            int_str_fields(),
            vec![0_i8, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![Some(42), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![None, Some("hello")])),
            ],
        )
        .unwrap();
        assert!(cast::cast(&union, &target).is_err());
    }

    // Field "a" (Utf8) is only a family match for Utf8View, while field "b"
    // (Utf8View) is an exact match. The exact match wins even though "a" is
    // listed first, so only "b"'s value survives.
    #[test]
    fn test_exact_match_preferred_over_family() {
        let fields = UnionFields::try_new(
            [0, 1],
            [
                Field::new("a", DataType::Utf8, true),
                Field::new("b", DataType::Utf8View, true),
            ],
        )
        .unwrap();
        let target = DataType::Utf8View;
        assert!(can_cast_types(
            &DataType::Union(fields.clone(), UnionMode::Sparse),
            &target,
        ));
        let union = UnionArray::try_new(
            fields,
            vec![0_i8, 1, 0].into(),
            None,
            vec![
                Arc::new(StringArray::from(vec![
                    Some("from_a"),
                    None,
                    Some("also_a"),
                ])) as ArrayRef,
                Arc::new(StringViewArray::from(vec![None, Some("from_b"), None])),
            ],
        )
        .unwrap();
        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        let arr = result.as_any().downcast_ref::<StringViewArray>().unwrap();
        // Slots 0 and 2 select child "a", which was not chosen, so they are null.
        assert!(arr.is_null(0));
        assert_eq!(arr.value(1), "from_b");
        assert!(arr.is_null(2));
    }

    // Every slot selects the Utf8 child. A null stored inside that child
    // must remain null in the output.
    #[test]
    fn test_null_in_selected_child_array() {
        let target = DataType::Utf8;
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));
        let union = UnionArray::try_new(
            int_str_fields(),
            vec![1_i8, 1, 1].into(),
            None,
            vec![
                Arc::new(Int32Array::from(vec![None, None, None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            ],
        )
        .unwrap();
        let result = cast::cast(&union, &target).unwrap();
        let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(arr.value(0), "hello");
        assert!(arr.is_null(1));
        assert_eq!(arr.value(2), "world");
    }

    // A zero-length union still resolves a child and yields a zero-length
    // array of the target type.
    #[test]
    fn test_empty_union() {
        let target = DataType::Utf8View;
        assert!(can_cast_types(
            &int_str_union_type(UnionMode::Sparse),
            &target
        ));
        let union = UnionArray::try_new(
            int_str_fields(),
            Vec::<i8>::new().into(),
            None,
            vec![
                Arc::new(Int32Array::from(Vec::<Option<i32>>::new())) as ArrayRef,
                Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
            ],
        )
        .unwrap();
        let result = cast::cast(&union, &target).unwrap();
        assert_eq!(result.data_type(), &target);
        assert_eq!(result.len(), 0);
    }
}