1use crate::error::{_plan_err, Result};
19use arrow::{
20 array::{
21 Array, ArrayRef, DictionaryArray, GenericListArray, GenericListViewArray,
22 StructArray, downcast_integer, new_null_array,
23 },
24 compute::{CastOptions, can_cast_types, cast_with_options},
25 datatypes::{DataType, DataType::Struct, Field, FieldRef},
26};
27use std::{collections::HashSet, sync::Arc};
28
29fn cast_struct_column(
57 source_col: &ArrayRef,
58 target_fields: &[Arc<Field>],
59 cast_options: &CastOptions,
60) -> Result<ArrayRef> {
61 if source_col.data_type() == &DataType::Null
62 || (!source_col.is_empty() && source_col.null_count() == source_col.len())
63 {
64 return Ok(new_null_array(
65 &Struct(target_fields.to_vec().into()),
66 source_col.len(),
67 ));
68 }
69
70 if let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() {
71 let source_fields = source_struct.fields();
72 validate_struct_compatibility(source_fields, target_fields)?;
73 let mut fields: Vec<Arc<Field>> = Vec::with_capacity(target_fields.len());
74 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(target_fields.len());
75 let num_rows = source_col.len();
76
77 for target_child_field in target_fields.iter() {
79 fields.push(Arc::clone(target_child_field));
80
81 let source_child_opt =
82 source_struct.column_by_name(target_child_field.name());
83
84 match source_child_opt {
85 Some(source_child_col) => {
86 let adapted_child = cast_column(
87 source_child_col,
88 target_child_field.data_type(),
89 cast_options,
90 )
91 .map_err(|e| {
92 e.context(format!(
93 "While casting struct field '{}'",
94 target_child_field.name()
95 ))
96 })?;
97 arrays.push(adapted_child);
98 }
99 None => {
100 arrays.push(new_null_array(target_child_field.data_type(), num_rows));
101 }
102 }
103 }
104
105 let struct_array =
106 StructArray::new(fields.into(), arrays, source_struct.nulls().cloned());
107 Ok(Arc::new(struct_array))
108 } else {
109 _plan_err!(
111 "Cannot cast column of type {} to struct type. Source must be a struct to cast to struct.",
112 source_col.data_type()
113 )
114 }
115}
116
117pub fn cast_column(
172 source_col: &ArrayRef,
173 target_type: &DataType,
174 cast_options: &CastOptions,
175) -> Result<ArrayRef> {
176 match (source_col.data_type(), target_type) {
177 (_, Struct(target_fields)) => {
178 cast_struct_column(source_col, target_fields, cast_options)
179 }
180 (DataType::List(_), DataType::List(target_inner)) => {
181 cast_list_column::<i32>(source_col, target_inner, cast_options)
182 }
183 (DataType::LargeList(_), DataType::LargeList(target_inner)) => {
184 cast_list_column::<i64>(source_col, target_inner, cast_options)
185 }
186 (DataType::ListView(_), DataType::ListView(target_inner)) => {
187 cast_list_view_column::<i32>(source_col, target_inner, cast_options)
188 }
189 (DataType::LargeListView(_), DataType::LargeListView(target_inner)) => {
190 cast_list_view_column::<i64>(source_col, target_inner, cast_options)
191 }
192 (
193 DataType::Dictionary(source_key_type, _),
194 DataType::Dictionary(target_key_type, target_value_type),
195 ) => cast_dictionary_column(
196 source_col,
197 source_key_type,
198 target_key_type,
199 target_value_type,
200 cast_options,
201 ),
202 _ => Ok(cast_with_options(source_col, target_type, cast_options)?),
203 }
204}
205
206fn cast_list_column<O: arrow::array::OffsetSizeTrait>(
207 source_col: &ArrayRef,
208 target_inner_field: &FieldRef,
209 cast_options: &CastOptions,
210) -> Result<ArrayRef> {
211 let source_list = source_col
212 .as_any()
213 .downcast_ref::<GenericListArray<O>>()
214 .ok_or_else(|| {
215 crate::error::DataFusionError::Plan(format!(
216 "Expected list array but got {}",
217 source_col.data_type()
218 ))
219 })?;
220
221 let cast_values = cast_column(
222 source_list.values(),
223 target_inner_field.data_type(),
224 cast_options,
225 )?;
226
227 let result = GenericListArray::<O>::new(
228 Arc::clone(target_inner_field),
229 source_list.offsets().clone(),
230 cast_values,
231 source_list.nulls().cloned(),
232 );
233 Ok(Arc::new(result))
234}
235
236fn cast_list_view_column<O: arrow::array::OffsetSizeTrait>(
237 source_col: &ArrayRef,
238 target_inner_field: &FieldRef,
239 cast_options: &CastOptions,
240) -> Result<ArrayRef> {
241 let source_list = source_col
242 .as_any()
243 .downcast_ref::<GenericListViewArray<O>>()
244 .ok_or_else(|| {
245 crate::error::DataFusionError::Plan(format!(
246 "Expected list view array but got {}",
247 source_col.data_type()
248 ))
249 })?;
250
251 let cast_values = cast_column(
252 source_list.values(),
253 target_inner_field.data_type(),
254 cast_options,
255 )?;
256
257 let result = GenericListViewArray::<O>::try_new(
258 Arc::clone(target_inner_field),
259 source_list.offsets().clone(),
260 source_list.sizes().clone(),
261 cast_values,
262 source_list.nulls().cloned(),
263 )?;
264 Ok(Arc::new(result))
265}
266
267fn cast_dictionary_column(
268 source_col: &ArrayRef,
269 source_key_type: &DataType,
270 target_key_type: &DataType,
271 target_value_type: &DataType,
272 cast_options: &CastOptions,
273) -> Result<ArrayRef> {
274 macro_rules! cast_dict_values {
277 ($t:ty) => {{
278 let source_dict = source_col
279 .as_any()
280 .downcast_ref::<DictionaryArray<$t>>()
281 .expect("downcast must succeed");
282 let cast_values =
283 cast_column(source_dict.values(), target_value_type, cast_options)?;
284 Ok(Arc::new(DictionaryArray::<$t>::new(
285 source_dict.keys().clone(),
286 cast_values,
287 )) as ArrayRef)
288 }};
289 }
290
291 let result: Result<ArrayRef> = downcast_integer! {
292 source_key_type => (cast_dict_values),
293 k => _plan_err!("Unsupported dictionary key type: {k}")
294 };
295 let result = result?;
296
297 if source_key_type != target_key_type {
299 let target_dict_type = DataType::Dictionary(
300 Box::new(target_key_type.clone()),
301 Box::new(target_value_type.clone()),
302 );
303 Ok(cast_with_options(&result, &target_dict_type, cast_options)?)
304 } else {
305 Ok(result)
306 }
307}
308
309pub fn validate_struct_compatibility(
346 source_fields: &[FieldRef],
347 target_fields: &[FieldRef],
348) -> Result<()> {
349 let has_overlap = has_one_of_more_common_fields(source_fields, target_fields);
350 if !has_overlap {
351 return _plan_err!(
352 "Cannot cast struct with {} fields to {} fields because there is no field name overlap",
353 source_fields.len(),
354 target_fields.len()
355 );
356 }
357
358 for target_field in target_fields {
360 if let Some(source_field) = source_fields
362 .iter()
363 .find(|f| f.name() == target_field.name())
364 {
365 validate_field_compatibility(source_field, target_field)?;
366 } else {
367 if !target_field.is_nullable() {
370 return _plan_err!(
371 "Cannot cast struct: target field '{}' is non-nullable but missing from source. \
372 Cannot fill with NULL.",
373 target_field.name()
374 );
375 }
376 }
377 }
378
379 Ok(())
381}
382
383fn validate_field_compatibility(
384 source_field: &Field,
385 target_field: &Field,
386) -> Result<()> {
387 if source_field.data_type() == &DataType::Null {
388 if !target_field.is_nullable() {
391 return _plan_err!(
392 "Cannot cast NULL struct field '{}' to non-nullable field '{}'",
393 source_field.name(),
394 target_field.name()
395 );
396 }
397 return Ok(());
398 }
399
400 if source_field.is_nullable() && !target_field.is_nullable() {
404 return _plan_err!(
405 "Cannot cast nullable struct field '{}' to non-nullable field",
406 target_field.name()
407 );
408 }
409
410 validate_data_type_compatibility(
411 target_field.name(),
412 source_field.data_type(),
413 target_field.data_type(),
414 )
415}
416
417pub fn validate_data_type_compatibility(
420 field_name: &str,
421 source_type: &DataType,
422 target_type: &DataType,
423) -> Result<()> {
424 match (source_type, target_type) {
425 (Struct(source_nested), Struct(target_nested)) => {
426 validate_struct_compatibility(source_nested, target_nested)?;
427 }
428 (DataType::List(s), DataType::List(t))
429 | (DataType::LargeList(s), DataType::LargeList(t))
430 | (DataType::ListView(s), DataType::ListView(t))
431 | (DataType::LargeListView(s), DataType::LargeListView(t)) => {
432 validate_field_compatibility(s, t)?;
433 }
434 (DataType::Dictionary(s_key, s_val), DataType::Dictionary(t_key, t_val)) => {
435 if !can_cast_types(s_key, t_key) {
436 return _plan_err!(
437 "Cannot cast dictionary key type {} to {} for field '{}'",
438 s_key,
439 t_key,
440 field_name
441 );
442 }
443 validate_data_type_compatibility(field_name, s_val, t_val)?;
444 }
445 _ => {
446 if !can_cast_types(source_type, target_type) {
447 return _plan_err!(
448 "Cannot cast struct field '{}' from type {} to type {}",
449 field_name,
450 source_type,
451 target_type
452 );
453 }
454 }
455 }
456 Ok(())
457}
458
459pub fn requires_nested_struct_cast(
470 source_type: &DataType,
471 target_type: &DataType,
472) -> bool {
473 match (source_type, target_type) {
474 (Struct(_), Struct(_)) => true,
475 (DataType::List(s), DataType::List(t))
476 | (DataType::LargeList(s), DataType::LargeList(t))
477 | (DataType::ListView(s), DataType::ListView(t))
478 | (DataType::LargeListView(s), DataType::LargeListView(t)) => {
479 requires_nested_struct_cast(s.data_type(), t.data_type())
480 }
481 (DataType::Dictionary(_, s_val), DataType::Dictionary(_, t_val)) => {
482 requires_nested_struct_cast(s_val, t_val)
483 }
484 _ => false,
485 }
486}
487
488pub fn has_one_of_more_common_fields(
493 source_fields: &[FieldRef],
494 target_fields: &[FieldRef],
495) -> bool {
496 let source_names: HashSet<&str> = source_fields
497 .iter()
498 .map(|field| field.name().as_str())
499 .collect();
500 target_fields
501 .iter()
502 .any(|field| source_names.contains(field.name().as_str()))
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{assert_contains, format::DEFAULT_CAST_OPTIONS};
    use arrow::{
        array::{
            BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, ListViewArray,
            MapArray, MapBuilder, NullArray, StringArray, StringBuilder,
        },
        buffer::{NullBuffer, ScalarBuffer},
        datatypes::{DataType, Field, FieldRef, Int32Type},
    };
    /// Looks up a child column by name on a struct array and downcasts it to
    /// the given concrete array type, panicking if either step fails.
    macro_rules! get_column_as {
        ($struct_array:expr, $column_name:expr, $array_type:ty) => {
            $struct_array
                .column_by_name($column_name)
                .unwrap()
                .as_any()
                .downcast_ref::<$array_type>()
                .unwrap()
        };
    }

    /// Builds a nullable field.
    fn field(name: &str, data_type: DataType) -> Field {
        Field::new(name, data_type, true)
    }

    /// Builds a non-nullable field.
    fn non_null_field(name: &str, data_type: DataType) -> Field {
        Field::new(name, data_type, false)
    }

    /// Builds a nullable field wrapped in an `Arc`.
    fn arc_field(name: &str, data_type: DataType) -> FieldRef {
        Arc::new(field(name, data_type))
    }

    /// Builds a struct data type from the given child fields.
    fn struct_type(fields: Vec<Field>) -> DataType {
        Struct(fields.into())
    }

    /// Builds a nullable struct-typed field with the given child fields.
    fn struct_field(name: &str, fields: Vec<Field>) -> Field {
        field(name, struct_type(fields))
    }

    /// Builds a nullable struct-typed field wrapped in an `Arc`.
    fn arc_struct_field(name: &str, fields: Vec<Field>) -> FieldRef {
        Arc::new(struct_field(name, fields))
    }

    /// A non-nested column is cast through the plain arrow cast path.
    #[test]
    fn test_cast_simple_column() {
        let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
        let target_field = field("ints", DataType::Int64);
        let result =
            cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
                .unwrap();
        let result = result.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result.value(0), 1);
        assert_eq!(result.value(1), 2);
        assert_eq!(result.value(2), 3);
    }

    /// The supplied `CastOptions` are honoured: an overflowing value is an
    /// error with `safe: false` and becomes null with `safe: true`.
    #[test]
    fn test_cast_column_with_options() {
        let source = Arc::new(Int64Array::from(vec![1, i64::MAX])) as ArrayRef;
        let target_field = field("ints", DataType::Int32);

        let safe_opts = CastOptions {
            // safe: false - the overflow below is expected to produce an Err
            safe: false,
            ..DEFAULT_CAST_OPTIONS
        };
        assert!(cast_column(&source, target_field.data_type(), &safe_opts).is_err());

        let unsafe_opts = CastOptions {
            // safe: true - the overflow below is expected to produce a null
            safe: true,
            ..DEFAULT_CAST_OPTIONS
        };
        let result =
            cast_column(&source, target_field.data_type(), &unsafe_opts).unwrap();
        let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(result.value(0), 1);
        assert!(result.is_null(1));
    }

    /// A nullable target field missing from the source is filled with nulls.
    #[test]
    fn test_cast_struct_with_missing_field() {
        let a_array = Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef;
        let source_struct = StructArray::from(vec![(
            arc_field("a", DataType::Int32),
            Arc::clone(&a_array),
        )]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        let target_field = struct_field(
            "s",
            vec![field("a", DataType::Int32), field("b", DataType::Utf8)],
        );

        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
                .unwrap();
        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(struct_array.fields().len(), 2);
        let a_result = get_column_as!(&struct_array, "a", Int32Array);
        assert_eq!(a_result.value(0), 1);
        assert_eq!(a_result.value(1), 2);

        let b_result = get_column_as!(&struct_array, "b", StringArray);
        assert_eq!(b_result.len(), 2);
        assert!(b_result.is_null(0));
        assert!(b_result.is_null(1));
    }

    /// Casting a non-struct column to a struct type is rejected.
    #[test]
    fn test_cast_struct_source_not_struct() {
        let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef;
        let target_field = struct_field("s", vec![field("a", DataType::Int32)]);

        let result =
            cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot cast column of type"));
        assert!(error_msg.contains("to struct type"));
        assert!(error_msg.contains("Source must be a struct"));
    }

    /// A child whose type cannot be cast to the target child type is rejected.
    #[test]
    fn test_cast_struct_incompatible_child_type() {
        let a_array = Arc::new(BinaryArray::from(vec![
            Some(b"a".as_ref()),
            Some(b"b".as_ref()),
        ])) as ArrayRef;
        let source_struct =
            StructArray::from(vec![(arc_field("a", DataType::Binary), a_array)]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        let target_field = struct_field("s", vec![field("a", DataType::Int32)]);

        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot cast struct field 'a'"));
    }

    /// Validation reports both type names when a shared field cannot be cast.
    #[test]
    fn test_validate_struct_compatibility_incompatible_types() {
        // Source: field1 is Binary, field2 is Utf8
        let source_fields = vec![
            arc_field("field1", DataType::Binary),
            arc_field("field2", DataType::Utf8),
        ];

        // Target: field1 is Int32
        let target_fields = vec![arc_field("field1", DataType::Int32)];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot cast struct field 'field1'"));
        assert!(error_msg.contains("Binary"));
        assert!(error_msg.contains("Int32"));
    }

    /// Validation accepts a shared field whose type is castable (Int32 -> Int64).
    #[test]
    fn test_validate_struct_compatibility_compatible_types() {
        // Source: field1 is Int32, field2 is Utf8
        let source_fields = vec![
            arc_field("field1", DataType::Int32),
            arc_field("field2", DataType::Utf8),
        ];

        // Target: field1 is Int64
        let target_fields = vec![arc_field("field1", DataType::Int64)];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_ok());
    }

    /// A nullable target field that is absent from the source is accepted.
    #[test]
    fn test_validate_struct_compatibility_missing_field_in_source() {
        // Source only has field1
        let source_fields = vec![arc_field("field1", DataType::Int32)];

        // Target has field1 and an extra nullable field2
        let target_fields = vec![
            arc_field("field1", DataType::Int32),
            arc_field("field2", DataType::Utf8),
        ];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        // Accepted: field2 is nullable
        assert!(result.is_ok());
    }

    /// Extra source fields that the target does not mention are accepted.
    #[test]
    fn test_validate_struct_compatibility_additional_field_in_source() {
        // Source has field1 and an extra field2
        let source_fields = vec![
            arc_field("field1", DataType::Int32),
            arc_field("field2", DataType::Utf8),
        ];

        // Target only has field1
        let target_fields = vec![arc_field("field1", DataType::Int32)];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        // Accepted: field2 is simply not part of the target
        assert!(result.is_ok());
    }

    /// No shared field names and differing field counts: rejected.
    #[test]
    fn test_validate_struct_compatibility_no_overlap_mismatch_len() {
        let source_fields = vec![
            arc_field("left", DataType::Int32),
            arc_field("right", DataType::Int32),
        ];
        let target_fields = vec![arc_field("alpha", DataType::Int32)];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert_contains!(error_msg, "no field name overlap");
    }

    /// The struct-level null buffer of the source survives the cast.
    #[test]
    fn test_cast_struct_parent_nulls_retained() {
        let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
        let fields = vec![arc_field("a", DataType::Int32)];
        // Row 0 is valid, row 1 is null at the struct level
        let nulls = Some(NullBuffer::from(vec![true, false]));
        let source_struct = StructArray::new(fields.clone().into(), vec![a_array], nulls);
        let source_col = Arc::new(source_struct) as ArrayRef;

        let target_field = struct_field("s", vec![field("a", DataType::Int64)]);

        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
                .unwrap();
        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(struct_array.null_count(), 1);
        assert!(struct_array.is_valid(0));
        assert!(struct_array.is_null(1));

        let a_result = get_column_as!(&struct_array, "a", Int64Array);
        assert_eq!(a_result.value(0), 1);
        assert_eq!(a_result.value(1), 2);
    }

    /// A nullable source field cannot be cast to a non-nullable target field.
    #[test]
    fn test_validate_struct_compatibility_nullable_to_non_nullable() {
        // Source: nullable field1
        let source_fields = vec![arc_field("field1", DataType::Int32)];

        // Target: non-nullable field1
        let target_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("field1"));
        assert!(error_msg.contains("non-nullable"));
    }

    /// A non-nullable source field may be cast to a nullable target field.
    #[test]
    fn test_validate_struct_compatibility_non_nullable_to_nullable() {
        // Source: non-nullable field1
        let source_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))];

        // Target: nullable field1
        let target_fields = vec![arc_field("field1", DataType::Int32)];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_ok());
    }

    /// The nullability rule is also enforced on fields of nested structs.
    #[test]
    fn test_validate_struct_compatibility_nested_nullable_to_non_nullable() {
        // Source: non-nullable struct containing a nullable `nested` field
        let source_fields = vec![Arc::new(non_null_field(
            "field1",
            struct_type(vec![field("nested", DataType::Int32)]),
        ))];

        // Target: non-nullable struct containing a non-nullable `nested` field
        let target_fields = vec![Arc::new(non_null_field(
            "field1",
            struct_type(vec![non_null_field("nested", DataType::Int32)]),
        ))];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("nested"));
        assert!(error_msg.contains("non-nullable"));
    }

    /// Fields are matched by name, so a different field order is accepted.
    #[test]
    fn test_validate_struct_compatibility_by_name() {
        // Source order: field1, field2
        let source_fields = vec![
            arc_field("field1", DataType::Int32),
            arc_field("field2", DataType::Utf8),
        ];

        // Target order: field2, field1 (with field1 widened to Int64)
        let target_fields = vec![
            arc_field("field2", DataType::Utf8),
            arc_field("field1", DataType::Int64),
        ];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_ok());
    }

    /// A name match with an uncastable type produces the full type-mismatch message.
    #[test]
    fn test_validate_struct_compatibility_by_name_with_type_mismatch() {
        // Source: field1 is Binary
        let source_fields = vec![arc_field("field1", DataType::Binary)];

        // Target: field1 is Int32
        let target_fields = vec![arc_field("field1", DataType::Int32)];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert_contains!(
            error_msg,
            "Cannot cast struct field 'field1' from type Binary to type Int32"
        );
    }

    /// No shared field names is rejected even when the field counts match.
    #[test]
    fn test_validate_struct_compatibility_no_overlap_equal_len() {
        let source_fields = vec![
            arc_field("left", DataType::Int32),
            arc_field("right", DataType::Utf8),
        ];

        let target_fields = vec![
            arc_field("alpha", DataType::Int32),
            arc_field("beta", DataType::Utf8),
        ];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert_contains!(error_msg, "no field name overlap");
    }

    /// Partial overlap with reordering, an extra source field and a missing
    /// nullable target field is accepted.
    #[test]
    fn test_validate_struct_compatibility_mixed_name_overlap() {
        // Source has `a`, `b` and an `extra` field the target does not use
        let source_fields = vec![
            arc_field("a", DataType::Int32),
            arc_field("b", DataType::Utf8),
            arc_field("extra", DataType::Boolean),
        ];

        // Target reorders `a`/`b`, widens `a`, and adds a nullable `c`
        // that the source does not have
        let target_fields = vec![
            arc_field("b", DataType::Utf8),
            arc_field("a", DataType::Int64),
            arc_field("c", DataType::Float32),
        ];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_ok());
    }

    /// A non-nullable target field that is missing from the source is rejected.
    #[test]
    fn test_validate_struct_compatibility_by_name_missing_required_field() {
        // Source only has field1
        let source_fields = vec![arc_field("field1", DataType::Int32)];

        // Target additionally requires a non-nullable field2
        let target_fields = vec![
            arc_field("field1", DataType::Int32),
            Arc::new(non_null_field("field2", DataType::Int32)),
        ];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert_contains!(
            error_msg,
            "Cannot cast struct: target field 'field2' is non-nullable but missing from source. Cannot fill with NULL."
        );
    }

    /// Differing field counts are fine as long as names overlap and the
    /// missing target field is nullable.
    #[test]
    fn test_validate_struct_compatibility_partial_name_overlap_with_count_mismatch() {
        // Source has one field
        let source_fields = vec![arc_field("a", DataType::Int32)];

        // Target has two fields, one of which overlaps by name
        let target_fields = vec![
            arc_field("a", DataType::Int32),
            arc_field("b", DataType::Utf8),
        ];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        // Accepted: `a` overlaps and the missing `b` is nullable
        assert!(result.is_ok());
    }

    /// Nested structs are cast recursively: children are reordered by name,
    /// extra source children are dropped, and missing children become null.
    #[test]
    fn test_cast_nested_struct_with_extra_and_missing_fields() {
        // Source inner struct has fields a, b, extra
        let a = Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef;
        let b = Arc::new(Int32Array::from(vec![Some(2), Some(3)])) as ArrayRef;
        let extra = Arc::new(Int32Array::from(vec![Some(9), Some(10)])) as ArrayRef;

        let inner = StructArray::from(vec![
            (arc_field("a", DataType::Int32), a),
            (arc_field("b", DataType::Int32), b),
            (arc_field("extra", DataType::Int32), extra),
        ]);

        let source_struct = StructArray::from(vec![(
            arc_struct_field(
                "inner",
                vec![
                    field("a", DataType::Int32),
                    field("b", DataType::Int32),
                    field("extra", DataType::Int32),
                ],
            ),
            Arc::new(inner) as ArrayRef,
        )]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        // Target inner struct reorders the fields, widens `b`, and adds `missing`
        let target_field = struct_field(
            "outer",
            vec![struct_field(
                "inner",
                vec![
                    field("b", DataType::Int64),
                    field("a", DataType::Int32),
                    field("missing", DataType::Int32),
                ],
            )],
        );

        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
                .unwrap();
        let outer = result.as_any().downcast_ref::<StructArray>().unwrap();
        let inner = get_column_as!(&outer, "inner", StructArray);
        assert_eq!(inner.fields().len(), 3);

        let b = get_column_as!(inner, "b", Int64Array);
        assert_eq!(b.value(0), 2);
        assert_eq!(b.value(1), 3);
        assert!(!b.is_null(0));
        assert!(!b.is_null(1));

        let a = get_column_as!(inner, "a", Int32Array);
        assert_eq!(a.value(0), 1);
        assert!(a.is_null(1));

        let missing = get_column_as!(inner, "missing", Int32Array);
        assert!(missing.is_null(0));
        assert!(missing.is_null(1));
    }

    /// A NULL-typed child cast to a struct-typed child yields an all-null struct.
    #[test]
    fn test_cast_null_struct_field_to_nested_struct() {
        let null_inner = Arc::new(NullArray::new(2)) as ArrayRef;
        let source_struct = StructArray::from(vec![(
            arc_field("inner", DataType::Null),
            Arc::clone(&null_inner),
        )]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        let target_field = struct_field(
            "outer",
            vec![struct_field("inner", vec![field("a", DataType::Int32)])],
        );

        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
                .unwrap();
        let outer = result.as_any().downcast_ref::<StructArray>().unwrap();
        let inner = get_column_as!(&outer, "inner", StructArray);
        assert_eq!(inner.len(), 2);
        assert!(inner.is_null(0));
        assert!(inner.is_null(1));

        let inner_a = get_column_as!(inner, "a", Int32Array);
        assert!(inner_a.is_null(0));
        assert!(inner_a.is_null(1));
    }

    /// List and map children with unchanged types pass through the struct cast
    /// with their values and nulls intact.
    #[test]
    fn test_cast_struct_with_array_and_map_fields() {
        // List child: one populated row and one null row
        let arr_array = Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2)]),
            None,
        ])) as ArrayRef;

        // Map child: one row with a single entry and one null row
        let string_builder = StringBuilder::new();
        let int_builder = Int32Builder::new();
        let mut map_builder = MapBuilder::new(None, string_builder, int_builder);
        map_builder.keys().append_value("a");
        map_builder.values().append_value(1);
        map_builder.append(true).unwrap();
        map_builder.append(false).unwrap();
        let map_array = Arc::new(map_builder.finish()) as ArrayRef;

        let source_struct = StructArray::from(vec![
            (
                arc_field(
                    "arr",
                    DataType::List(Arc::new(field("item", DataType::Int32))),
                ),
                arr_array,
            ),
            (
                arc_field(
                    "map",
                    DataType::Map(
                        Arc::new(non_null_field(
                            "entries",
                            struct_type(vec![
                                non_null_field("keys", DataType::Utf8),
                                field("values", DataType::Int32),
                            ]),
                        )),
                        false,
                    ),
                ),
                map_array,
            ),
        ]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        // Target uses the same child types as the source
        let target_field = struct_field(
            "s",
            vec![
                field(
                    "arr",
                    DataType::List(Arc::new(field("item", DataType::Int32))),
                ),
                field(
                    "map",
                    DataType::Map(
                        Arc::new(non_null_field(
                            "entries",
                            struct_type(vec![
                                non_null_field("keys", DataType::Utf8),
                                field("values", DataType::Int32),
                            ]),
                        )),
                        false,
                    ),
                ),
            ],
        );

        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
                .unwrap();
        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();

        let arr = get_column_as!(&struct_array, "arr", ListArray);
        assert!(!arr.is_null(0));
        assert!(arr.is_null(1));
        let arr0 = arr.value(0);
        let values = arr0.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(values.value(0), 1);
        assert_eq!(values.value(1), 2);

        let map = get_column_as!(&struct_array, "map", MapArray);
        assert!(!map.is_null(0));
        assert!(map.is_null(1));
        let map0 = map.value(0);
        let entries = map0.as_any().downcast_ref::<StructArray>().unwrap();
        let keys = get_column_as!(entries, "keys", StringArray);
        let vals = get_column_as!(entries, "values", Int32Array);
        assert_eq!(keys.value(0), "a");
        assert_eq!(vals.value(0), 1);
    }

    /// Children are matched by name, so the target may list them in another order.
    #[test]
    fn test_cast_struct_field_order_differs() {
        let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
        let b = Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef;

        let source_struct = StructArray::from(vec![
            (arc_field("a", DataType::Int32), a),
            (arc_field("b", DataType::Int32), b),
        ]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        let target_field = struct_field(
            "s",
            vec![field("b", DataType::Int64), field("a", DataType::Int32)],
        );

        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
                .unwrap();
        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();

        let b_col = get_column_as!(&struct_array, "b", Int64Array);
        assert_eq!(b_col.value(0), 3);
        assert!(b_col.is_null(1));

        let a_col = get_column_as!(&struct_array, "a", Int32Array);
        assert_eq!(a_col.value(0), 1);
        assert_eq!(a_col.value(1), 2);
    }

    /// Casting between structs with no shared field names is rejected.
    #[test]
    fn test_cast_struct_no_overlap_rejected() {
        let first = Arc::new(Int32Array::from(vec![Some(10), Some(20)])) as ArrayRef;
        let second =
            Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) as ArrayRef;

        let source_struct = StructArray::from(vec![
            (arc_field("left", DataType::Int32), first),
            (arc_field("right", DataType::Utf8), second),
        ]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        let target_field = struct_field(
            "s",
            vec![field("a", DataType::Int64), field("b", DataType::Utf8)],
        );

        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert_contains!(error_msg, "no field name overlap");
    }

    /// The cast fails when the target requires a non-nullable field that the
    /// source does not provide.
    #[test]
    fn test_cast_struct_missing_non_nullable_field_fails() {
        // Source has only field `a`
        let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
        let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        // Target has `a` plus a non-nullable `b`
        let target_field = struct_field(
            "s",
            vec![
                field("a", DataType::Int32),
                non_null_field("b", DataType::Int32),
            ],
        );

        // `b` cannot be filled with NULL, so the cast must fail
        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string()
                .contains("target field 'b' is non-nullable but missing from source"),
            "Unexpected error: {err}"
        );
    }

    /// The cast succeeds and NULL-fills a missing target field when it is nullable.
    #[test]
    fn test_cast_struct_missing_nullable_field_succeeds() {
        // Source has only field `a`
        let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
        let source_struct = StructArray::from(vec![(arc_field("a", DataType::Int32), a)]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        // Target has `a` plus a nullable `b`
        let target_field = struct_field(
            "s",
            vec![field("a", DataType::Int32), field("b", DataType::Int32)],
        );

        // `b` is nullable, so it is filled with NULL
        let result =
            cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS)
                .unwrap();
        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();

        let a_col = get_column_as!(&struct_array, "a", Int32Array);
        assert_eq!(a_col.value(0), 1);
        assert_eq!(a_col.value(1), 2);

        let b_col = get_column_as!(&struct_array, "b", Int32Array);
        assert!(b_col.is_null(0));
        assert!(b_col.is_null(1));
    }

    /// Dictionary value types that are structs are validated recursively.
    #[test]
    fn test_validate_dictionary_value_evolution() {
        let source_inner = struct_type(vec![field("a", DataType::Int32)]);
        let target_inner = struct_type(vec![
            field("a", DataType::Int32),
            field("b", DataType::Utf8),
        ]);
        let source =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(source_inner));
        let target =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(target_inner));
        assert!(validate_data_type_compatibility("col", &source, &target).is_ok());
    }

    /// Struct values inside a dictionary are cast while the keys (including
    /// null keys) are preserved.
    #[test]
    fn test_cast_dictionary_struct_value() {
        // Dictionary values: a struct array with a single Int32 child `a`
        let struct_arr = StructArray::from(vec![(
            arc_field("a", DataType::Int32),
            Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef,
        )]);
        // Keys reference value 0, null, value 1
        let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
        let source_dict = DictionaryArray::<Int32Type>::new(keys, Arc::new(struct_arr));
        let source_col: ArrayRef = Arc::new(source_dict);

        let target_type = DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(struct_type(vec![
                field("a", DataType::Int64),
                field("b", DataType::Utf8),
            ])),
        );

        let result =
            cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap();
        let result_dict = result
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();

        assert!(result_dict.is_valid(0));
        assert!(result_dict.is_null(1));
        assert!(result_dict.is_valid(2));

        let struct_values = result_dict
            .values()
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();
        let a_col = get_column_as!(&struct_values, "a", Int64Array);
        assert_eq!(a_col.values(), &[10, 20]);
        let b_col = get_column_as!(&struct_values, "b", StringArray);
        assert!(b_col.iter().all(|v| v.is_none()));
    }

    /// Struct elements inside a list view are cast recursively.
    #[test]
    fn test_cast_list_view_struct() {
        // List-view values: a struct array with a single Int32 child `a`
        let struct_arr = StructArray::from(vec![(
            arc_field("a", DataType::Int32),
            Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
        )]);

        let source_field =
            arc_field("item", struct_type(vec![field("a", DataType::Int32)]));
        let target_field = arc_field(
            "item",
            struct_type(vec![
                field("a", DataType::Int64),
                field("b", DataType::Utf8),
            ]),
        );

        // Two rows: offsets [0, 2] with sizes [2, 1]
        let list_view = ListViewArray::new(
            source_field,
            ScalarBuffer::from(vec![0i32, 2]),
            ScalarBuffer::from(vec![2i32, 1]),
            Arc::new(struct_arr),
            None,
        );
        let source_col: ArrayRef = Arc::new(list_view);

        let target_type = DataType::ListView(target_field);

        let result =
            cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap();
        let result_lv = result.as_any().downcast_ref::<ListViewArray>().unwrap();
        assert_eq!(result_lv.len(), 2);

        let struct_values = result_lv
            .values()
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();
        let a_col = get_column_as!(&struct_values, "a", Int64Array);
        assert_eq!(a_col.values(), &[1, 2, 3]);
        let b_col = get_column_as!(&struct_values, "b", StringArray);
        assert!(b_col.iter().all(|v| v.is_none()));
    }

    /// `requires_nested_struct_cast` is true for structs (directly or inside
    /// matching containers) and false otherwise.
    #[test]
    fn test_requires_nested_struct_cast() {
        let s1 = struct_type(vec![field("a", DataType::Int32)]);
        let s2 = struct_type(vec![field("a", DataType::Int64)]);

        assert!(requires_nested_struct_cast(&s1, &s2));
        assert!(requires_nested_struct_cast(
            &DataType::List(arc_field("item", s1.clone())),
            &DataType::List(arc_field("item", s2.clone())),
        ));
        assert!(requires_nested_struct_cast(
            &DataType::Dictionary(Box::new(DataType::Int32), Box::new(s1.clone())),
            &DataType::Dictionary(Box::new(DataType::Int32), Box::new(s2.clone())),
        ));
        assert!(requires_nested_struct_cast(
            &DataType::ListView(arc_field("item", s1)),
            &DataType::ListView(arc_field("item", s2)),
        ));

        // No struct involved: no nested struct cast is required
        assert!(!requires_nested_struct_cast(
            &DataType::Int32,
            &DataType::Int64
        ));
        assert!(!requires_nested_struct_cast(
            &DataType::List(arc_field("item", DataType::Int32)),
            &DataType::List(arc_field("item", DataType::Int64)),
        ));
    }
}