1use crate::error::{Result, _plan_err};
19use arrow::{
20 array::{new_null_array, Array, ArrayRef, StructArray},
21 compute::{cast_with_options, CastOptions},
22 datatypes::{DataType::Struct, Field, FieldRef},
23};
24use std::sync::Arc;
25
26fn cast_struct_column(
53 source_col: &ArrayRef,
54 target_fields: &[Arc<Field>],
55 cast_options: &CastOptions,
56) -> Result<ArrayRef> {
57 if let Some(source_struct) = source_col.as_any().downcast_ref::<StructArray>() {
58 validate_struct_compatibility(source_struct.fields(), target_fields)?;
59
60 let mut fields: Vec<Arc<Field>> = Vec::with_capacity(target_fields.len());
61 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(target_fields.len());
62 let num_rows = source_col.len();
63
64 for target_child_field in target_fields {
65 fields.push(Arc::clone(target_child_field));
66 match source_struct.column_by_name(target_child_field.name()) {
67 Some(source_child_col) => {
68 let adapted_child =
69 cast_column(source_child_col, target_child_field, cast_options)
70 .map_err(|e| {
71 e.context(format!(
72 "While casting struct field '{}'",
73 target_child_field.name()
74 ))
75 })?;
76 arrays.push(adapted_child);
77 }
78 None => {
79 arrays.push(new_null_array(target_child_field.data_type(), num_rows));
80 }
81 }
82 }
83
84 let struct_array =
85 StructArray::new(fields.into(), arrays, source_struct.nulls().cloned());
86 Ok(Arc::new(struct_array))
87 } else {
88 _plan_err!(
90 "Cannot cast column of type {} to struct type. Source must be a struct to cast to struct.",
91 source_col.data_type()
92 )
93 }
94}
95
96pub fn cast_column(
152 source_col: &ArrayRef,
153 target_field: &Field,
154 cast_options: &CastOptions,
155) -> Result<ArrayRef> {
156 match target_field.data_type() {
157 Struct(target_fields) => {
158 cast_struct_column(source_col, target_fields, cast_options)
159 }
160 _ => Ok(cast_with_options(
161 source_col,
162 target_field.data_type(),
163 cast_options,
164 )?),
165 }
166}
167
168pub fn validate_struct_compatibility(
204 source_fields: &[FieldRef],
205 target_fields: &[FieldRef],
206) -> Result<()> {
207 for target_field in target_fields {
209 if let Some(source_field) = source_fields
211 .iter()
212 .find(|f| f.name() == target_field.name())
213 {
214 if source_field.is_nullable() && !target_field.is_nullable() {
218 return _plan_err!(
219 "Cannot cast nullable struct field '{}' to non-nullable field",
220 target_field.name()
221 );
222 }
223 match (source_field.data_type(), target_field.data_type()) {
225 (Struct(source_nested), Struct(target_nested)) => {
227 validate_struct_compatibility(source_nested, target_nested)?;
228 }
229 _ => {
231 if !arrow::compute::can_cast_types(
232 source_field.data_type(),
233 target_field.data_type(),
234 ) {
235 return _plan_err!(
236 "Cannot cast struct field '{}' from type {} to type {}",
237 target_field.name(),
238 source_field.data_type(),
239 target_field.data_type()
240 );
241 }
242 }
243 }
244 }
245 }
247
248 Ok(())
250}
251
#[cfg(test)]
mod tests {

    use super::*;
    use crate::format::DEFAULT_CAST_OPTIONS;
    use arrow::{
        array::{
            BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray,
            MapBuilder, StringArray, StringBuilder,
        },
        buffer::NullBuffer,
        datatypes::{DataType, Field, FieldRef, Int32Type},
    };

    /// Looks up a child column of a struct array by name and downcasts it to
    /// the requested concrete array type, panicking if either step fails.
    macro_rules! get_column_as {
        ($struct_array:expr, $column_name:expr, $array_type:ty) => {
            $struct_array
                .column_by_name($column_name)
                .unwrap()
                .as_any()
                .downcast_ref::<$array_type>()
                .unwrap()
        };
    }

    /// Builds a nullable field.
    fn field(name: &str, data_type: DataType) -> Field {
        Field::new(name, data_type, true)
    }

    /// Builds a non-nullable field.
    fn non_null_field(name: &str, data_type: DataType) -> Field {
        Field::new(name, data_type, false)
    }

    /// Builds a nullable field wrapped in an `Arc`.
    fn arc_field(name: &str, data_type: DataType) -> FieldRef {
        Arc::new(field(name, data_type))
    }

    /// Builds a struct data type from its child fields.
    fn struct_type(fields: Vec<Field>) -> DataType {
        Struct(fields.into())
    }

    /// Builds a nullable struct field from its child fields.
    fn struct_field(name: &str, fields: Vec<Field>) -> Field {
        field(name, struct_type(fields))
    }

    /// Builds a nullable struct field wrapped in an `Arc`.
    fn arc_struct_field(name: &str, fields: Vec<Field>) -> FieldRef {
        Arc::new(struct_field(name, fields))
    }

    #[test]
    fn test_cast_simple_column() {
        let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
        let target_field = field("ints", DataType::Int64);
        let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS).unwrap();
        let result = result.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result.value(0), 1);
        assert_eq!(result.value(1), 2);
        assert_eq!(result.value(2), 3);
    }

    #[test]
    fn test_cast_column_with_options() {
        // `i64::MAX` does not fit into an Int32, so the second value cannot be
        // converted.
        let source = Arc::new(Int64Array::from(vec![1, i64::MAX])) as ArrayRef;
        let target_field = field("ints", DataType::Int32);

        // `safe: false`: a value that cannot be converted fails the cast.
        let error_on_failure_opts = CastOptions {
            safe: false,
            ..DEFAULT_CAST_OPTIONS
        };
        assert!(cast_column(&source, &target_field, &error_on_failure_opts).is_err());

        // `safe: true`: a value that cannot be converted becomes null.
        let null_on_failure_opts = CastOptions {
            safe: true,
            ..DEFAULT_CAST_OPTIONS
        };
        let result =
            cast_column(&source, &target_field, &null_on_failure_opts).unwrap();
        let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(result.value(0), 1);
        assert!(result.is_null(1));
    }

    #[test]
    fn test_cast_struct_with_missing_field() {
        let a_array = Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef;
        let source_struct = StructArray::from(vec![(
            arc_field("a", DataType::Int32),
            Arc::clone(&a_array),
        )]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        // The target has an extra nullable field "b" the source lacks.
        let target_field = struct_field(
            "s",
            vec![field("a", DataType::Int32), field("b", DataType::Utf8)],
        );

        let result =
            cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap();
        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(struct_array.fields().len(), 2);
        let a_result = get_column_as!(struct_array, "a", Int32Array);
        assert_eq!(a_result.value(0), 1);
        assert_eq!(a_result.value(1), 2);

        // The missing field is filled with nulls.
        let b_result = get_column_as!(struct_array, "b", StringArray);
        assert_eq!(b_result.len(), 2);
        assert!(b_result.is_null(0));
        assert!(b_result.is_null(1));
    }

    #[test]
    fn test_cast_struct_source_not_struct() {
        let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef;
        let target_field = struct_field("s", vec![field("a", DataType::Int32)]);

        let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot cast column of type"));
        assert!(error_msg.contains("to struct type"));
        assert!(error_msg.contains("Source must be a struct"));
    }

    #[test]
    fn test_cast_struct_incompatible_child_type() {
        let a_array = Arc::new(BinaryArray::from(vec![
            Some(b"a".as_ref()),
            Some(b"b".as_ref()),
        ])) as ArrayRef;
        let source_struct =
            StructArray::from(vec![(arc_field("a", DataType::Binary), a_array)]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        let target_field = struct_field("s", vec![field("a", DataType::Int32)]);

        let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot cast struct field 'a'"));
    }

    #[test]
    fn test_validate_struct_compatibility_incompatible_types() {
        let source_fields = vec![
            arc_field("field1", DataType::Binary),
            arc_field("field2", DataType::Utf8),
        ];

        let target_fields = vec![arc_field("field1", DataType::Int32)];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot cast struct field 'field1'"));
        assert!(error_msg.contains("Binary"));
        assert!(error_msg.contains("Int32"));
    }

    #[test]
    fn test_validate_struct_compatibility_compatible_types() {
        let source_fields = vec![
            arc_field("field1", DataType::Int32),
            arc_field("field2", DataType::Utf8),
        ];

        let target_fields = vec![arc_field("field1", DataType::Int64)];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_struct_compatibility_missing_field_in_source() {
        let source_fields = vec![arc_field("field2", DataType::Utf8)];

        let target_fields = vec![arc_field("field1", DataType::Int32)];

        // A nullable target field that is absent from the source is accepted.
        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_struct_compatibility_additional_field_in_source() {
        let source_fields = vec![
            arc_field("field1", DataType::Int32),
            arc_field("field2", DataType::Utf8),
        ];

        let target_fields = vec![arc_field("field1", DataType::Int32)];

        // Extra source fields that the target does not mention are accepted.
        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_ok());
    }

    #[test]
    fn test_cast_struct_parent_nulls_retained() {
        let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
        let fields = vec![arc_field("a", DataType::Int32)];
        // Row 0 is valid, row 1 is null at the struct level.
        let nulls = Some(NullBuffer::from(vec![true, false]));
        let source_struct = StructArray::new(fields.into(), vec![a_array], nulls);
        let source_col = Arc::new(source_struct) as ArrayRef;

        let target_field = struct_field("s", vec![field("a", DataType::Int64)]);

        let result =
            cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap();
        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(struct_array.null_count(), 1);
        assert!(struct_array.is_valid(0));
        assert!(struct_array.is_null(1));

        let a_result = get_column_as!(struct_array, "a", Int64Array);
        assert_eq!(a_result.value(0), 1);
        assert_eq!(a_result.value(1), 2);
    }

    #[test]
    fn test_validate_struct_compatibility_nullable_to_non_nullable() {
        let source_fields = vec![arc_field("field1", DataType::Int32)];

        let target_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("field1"));
        assert!(error_msg.contains("non-nullable"));
    }

    #[test]
    fn test_validate_struct_compatibility_non_nullable_to_nullable() {
        let source_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))];

        let target_fields = vec![arc_field("field1", DataType::Int32)];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_struct_compatibility_nested_nullable_to_non_nullable() {
        // The outer fields agree; only the nested child differs in nullability.
        let source_fields = vec![Arc::new(non_null_field(
            "field1",
            struct_type(vec![field("nested", DataType::Int32)]),
        ))];

        let target_fields = vec![Arc::new(non_null_field(
            "field1",
            struct_type(vec![non_null_field("nested", DataType::Int32)]),
        ))];

        let result = validate_struct_compatibility(&source_fields, &target_fields);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("nested"));
        assert!(error_msg.contains("non-nullable"));
    }

    #[test]
    fn test_cast_nested_struct_with_extra_and_missing_fields() {
        let a = Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef;
        let b = Arc::new(Int32Array::from(vec![Some(2), Some(3)])) as ArrayRef;
        let extra = Arc::new(Int32Array::from(vec![Some(9), Some(10)])) as ArrayRef;

        let inner = StructArray::from(vec![
            (arc_field("a", DataType::Int32), a),
            (arc_field("b", DataType::Int32), b),
            (arc_field("extra", DataType::Int32), extra),
        ]);

        let source_struct = StructArray::from(vec![(
            arc_struct_field(
                "inner",
                vec![
                    field("a", DataType::Int32),
                    field("b", DataType::Int32),
                    field("extra", DataType::Int32),
                ],
            ),
            Arc::new(inner) as ArrayRef,
        )]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        // Target inner struct: reordered fields, "b" widened to Int64,
        // "extra" dropped and "missing" added.
        let target_field = struct_field(
            "outer",
            vec![struct_field(
                "inner",
                vec![
                    field("b", DataType::Int64),
                    field("a", DataType::Int32),
                    field("missing", DataType::Int32),
                ],
            )],
        );

        let result =
            cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap();
        let outer = result.as_any().downcast_ref::<StructArray>().unwrap();
        let inner = get_column_as!(outer, "inner", StructArray);
        assert_eq!(inner.fields().len(), 3);

        let b = get_column_as!(inner, "b", Int64Array);
        assert_eq!(b.value(0), 2);
        assert_eq!(b.value(1), 3);
        assert!(!b.is_null(0));
        assert!(!b.is_null(1));

        let a = get_column_as!(inner, "a", Int32Array);
        assert_eq!(a.value(0), 1);
        assert!(a.is_null(1));

        let missing = get_column_as!(inner, "missing", Int32Array);
        assert!(missing.is_null(0));
        assert!(missing.is_null(1));
    }

    #[test]
    fn test_cast_struct_with_array_and_map_fields() {
        // List column: one populated entry followed by a null entry.
        let arr_array = Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2)]),
            None,
        ])) as ArrayRef;

        // Map column: one entry {"a": 1} followed by a null entry.
        let string_builder = StringBuilder::new();
        let int_builder = Int32Builder::new();
        let mut map_builder = MapBuilder::new(None, string_builder, int_builder);
        map_builder.keys().append_value("a");
        map_builder.values().append_value(1);
        map_builder.append(true).unwrap();
        map_builder.append(false).unwrap();
        let map_array = Arc::new(map_builder.finish()) as ArrayRef;

        let source_struct = StructArray::from(vec![
            (
                arc_field(
                    "arr",
                    DataType::List(Arc::new(field("item", DataType::Int32))),
                ),
                arr_array,
            ),
            (
                arc_field(
                    "map",
                    DataType::Map(
                        Arc::new(non_null_field(
                            "entries",
                            struct_type(vec![
                                non_null_field("keys", DataType::Utf8),
                                field("values", DataType::Int32),
                            ]),
                        )),
                        false,
                    ),
                ),
                map_array,
            ),
        ]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        // The target declares the same list and map types as the source.
        let target_field = struct_field(
            "s",
            vec![
                field(
                    "arr",
                    DataType::List(Arc::new(field("item", DataType::Int32))),
                ),
                field(
                    "map",
                    DataType::Map(
                        Arc::new(non_null_field(
                            "entries",
                            struct_type(vec![
                                non_null_field("keys", DataType::Utf8),
                                field("values", DataType::Int32),
                            ]),
                        )),
                        false,
                    ),
                ),
            ],
        );

        let result =
            cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap();
        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();

        let arr = get_column_as!(struct_array, "arr", ListArray);
        assert!(!arr.is_null(0));
        assert!(arr.is_null(1));
        let arr0 = arr.value(0);
        let values = arr0.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(values.value(0), 1);
        assert_eq!(values.value(1), 2);

        let map = get_column_as!(struct_array, "map", MapArray);
        assert!(!map.is_null(0));
        assert!(map.is_null(1));
        let map0 = map.value(0);
        let entries = map0.as_any().downcast_ref::<StructArray>().unwrap();
        let keys = get_column_as!(entries, "keys", StringArray);
        let vals = get_column_as!(entries, "values", Int32Array);
        assert_eq!(keys.value(0), "a");
        assert_eq!(vals.value(0), 1);
    }

    #[test]
    fn test_cast_struct_field_order_differs() {
        let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
        let b = Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef;

        let source_struct = StructArray::from(vec![
            (arc_field("a", DataType::Int32), a),
            (arc_field("b", DataType::Int32), b),
        ]);
        let source_col = Arc::new(source_struct) as ArrayRef;

        // Same fields as the source but in the opposite order, with "b"
        // widened to Int64.
        let target_field = struct_field(
            "s",
            vec![field("b", DataType::Int64), field("a", DataType::Int32)],
        );

        let result =
            cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap();
        let struct_array = result.as_any().downcast_ref::<StructArray>().unwrap();

        let b_col = get_column_as!(struct_array, "b", Int64Array);
        assert_eq!(b_col.value(0), 3);
        assert!(b_col.is_null(1));

        let a_col = get_column_as!(struct_array, "a", Int32Array);
        assert_eq!(a_col.value(0), 1);
        assert_eq!(a_col.value(1), 2);
    }
}