use std::borrow::Borrow;
use std::collections::HashMap;
use std::sync::Arc;
use itertools::Itertools;
use crate::arrow::array::{
Array, ArrayRef, AsArray, ListArray, MapArray, RecordBatch, StructArray,
};
use crate::arrow::datatypes::Schema as ArrowSchema;
use crate::arrow::datatypes::{DataType as ArrowDataType, Field as ArrowField};
use super::super::arrow_conversion::kernel_metadata_to_arrow_metadata;
use super::super::arrow_utils::make_arrow_error;
use crate::engine::ensure_data_types::{ensure_data_types, ValidationMode};
use crate::error::{DeltaResult, Error};
use crate::parquet::arrow::PARQUET_FIELD_ID_META_KEY;
use crate::schema::{ArrayType, DataType, MapType, Schema, StructField};
pub(crate) fn apply_schema(array: &dyn Array, schema: &DataType) -> DeltaResult<RecordBatch> {
let DataType::Struct(struct_schema) = schema else {
return Err(Error::generic(
"apply_schema at top-level must be passed a struct schema",
));
};
let applied = apply_schema_to_struct(array, struct_schema)?;
let (fields, columns, _nulls) = applied.into_parts();
Ok(RecordBatch::try_new(
Arc::new(ArrowSchema::new(fields)),
columns,
)?)
}
/// Build an Arrow field from `field_name`, `data_type` and `nullable`.
///
/// If `metadata` is `Some`, it becomes the field's metadata. If it is `None`, the field is
/// returned exactly as `ArrowField::new` created it.
fn new_field_with_metadata(
    field_name: &str,
    data_type: &ArrowDataType,
    nullable: bool,
    metadata: Option<HashMap<String, String>>,
) -> ArrowField {
    let base_field = ArrowField::new(field_name, data_type.clone(), nullable);
    match metadata {
        Some(field_metadata) => base_field.with_metadata(field_metadata),
        None => base_field,
    }
}
/// Rebuild `struct_array` so that every child column and child field matches the corresponding
/// entry of `target_fields`.
///
/// Children are matched to target fields **by position**; names are not compared. For each
/// child:
/// - the column data is recursively conformed via [`apply_schema_to`];
/// - a new Arrow field is built from the target field's name and nullability, and from its
///   metadata translated by `kernel_metadata_to_arrow_metadata`.
///
/// The struct's own null buffer is carried over unchanged.
///
/// # Errors
/// - `Error::InternalError` if the number of target fields differs from the number of child
///   columns in `struct_array`, in either direction.
/// - A generic error if an input child field and its target field both carry a parquet field ID
///   (`PARQUET_FIELD_ID_META_KEY`) and the two IDs differ.
/// - Any error from transforming a child column or translating metadata.
/// - Any error from `StructArray::try_new`, e.g. a non-nullable target field whose column
///   contains nulls.
fn transform_struct(
    struct_array: &StructArray,
    target_fields: impl Iterator<Item = impl Borrow<StructField>>,
) -> DeltaResult<StructArray> {
    let (input_fields, arrow_cols, nulls) = struct_array.clone().into_parts();
    let input_col_count = arrow_cols.len();

    // Materialize the target fields so their count can be validated up front. `zip` stops at
    // the shorter side, so checking only the number of transformed columns afterwards would
    // miss a target schema with *more* fields than the input: those extra fields would be
    // silently dropped instead of reported.
    let target_fields: Vec<_> = target_fields.collect();
    if target_fields.len() != input_col_count {
        return Err(Error::InternalError(format!(
            "Passed struct had {input_col_count} columns, but target schema has {} fields",
            target_fields.len()
        )));
    }

    let result_iter = arrow_cols
        .into_iter()
        .zip(input_fields.iter())
        .zip(target_fields)
        .map(|((sa_col, input_field), target_field)| -> DeltaResult<_> {
            let target_field = target_field.borrow();
            let transformed_col = apply_schema_to(&sa_col, target_field.data_type())?;
            let arrow_metadata = kernel_metadata_to_arrow_metadata(target_field)?;
            // A field ID present on both sides must agree. An ID on only one side is accepted,
            // and the target's metadata wins.
            if let (Some(input_id), Some(target_id)) = (
                input_field.metadata().get(PARQUET_FIELD_ID_META_KEY),
                arrow_metadata.get(PARQUET_FIELD_ID_META_KEY),
            ) {
                if input_id != target_id {
                    return Err(Error::generic(format!(
                        "Field '{}': input field ID {} conflicts with target field ID {}",
                        target_field.name, input_id, target_id
                    )));
                }
            }
            let transformed_field = new_field_with_metadata(
                &target_field.name,
                transformed_col.data_type(),
                target_field.nullable,
                Some(arrow_metadata),
            );
            Ok((transformed_field, transformed_col))
        });

    // Stops at the first error; otherwise splits the (field, column) pairs into two vectors.
    let (transformed_fields, transformed_cols): (Vec<ArrowField>, Vec<ArrayRef>) =
        result_iter.process_results(|iter| iter.unzip())?;

    Ok(StructArray::try_new(
        transformed_fields.into(),
        transformed_cols,
        nulls,
    )?)
}
/// Apply the kernel struct schema `kernel_fields` to `array`, which must be a struct array.
///
/// # Errors
/// Returns an Arrow error if `array` is not a `StructArray`. Otherwise returns any error from
/// [`transform_struct`].
fn apply_schema_to_struct(array: &dyn Array, kernel_fields: &Schema) -> DeltaResult<StructArray> {
    match array.as_struct_opt() {
        Some(struct_array) => transform_struct(struct_array, kernel_fields.fields()),
        None => Err(make_arrow_error(
            "Arrow claimed to be a struct but isn't a StructArray",
        )),
    }
}
/// Apply the kernel array type `target_inner_type` to `array`, which must be a `ListArray`.
///
/// The list values are recursively conformed to `element_type`. The item field keeps its
/// original name, and takes its nullability from `contains_null`. Offsets and the list-level
/// null buffer are reused unchanged.
///
/// # Errors
/// - An Arrow error if `array` is not a `ListArray`.
/// - Any error from transforming the values or from `ListArray::try_new`.
fn apply_schema_to_list(
    array: &dyn Array,
    target_inner_type: &ArrayType,
) -> DeltaResult<ListArray> {
    let list_array = array
        .as_list_opt::<i32>()
        .ok_or_else(|| make_arrow_error("Arrow claimed to be a list but isn't a ListArray"))?;
    let (item_field, offsets, values, list_nulls) = list_array.clone().into_parts();

    let new_values = apply_schema_to(&values, &target_inner_type.element_type)?;
    let new_item_field = Arc::new(ArrowField::new(
        item_field.name(),
        new_values.data_type().clone(),
        target_inner_type.contains_null,
    ));

    ListArray::try_new(new_item_field, offsets, new_values, list_nulls).map_err(Into::into)
}
/// Apply the kernel map type `kernel_map_type` to `array`, which must be a `MapArray`.
///
/// The entries struct is rebuilt with [`transform_struct`]:
/// - the first child gets `key_type` and is always non-nullable;
/// - the second child gets `value_type`, with nullability taken from `value_contains_null`.
///
/// Child names come from the input. The entries field keeps its original name and nullability.
/// Offsets, the map-level null buffer and the `ordered` flag are reused unchanged.
///
/// # Errors
/// - An Arrow error if `array` is not a `MapArray`.
/// - Any error from transforming the entries or from `MapArray::try_new`.
fn apply_schema_to_map(array: &dyn Array, kernel_map_type: &MapType) -> DeltaResult<MapArray> {
    let map_array = array
        .as_map_opt()
        .ok_or_else(|| make_arrow_error("Arrow claimed to be a map but isn't a MapArray"))?;
    let (entries_field, offsets, entries, map_nulls, ordered) = map_array.clone().into_parts();

    // (target type, nullable) for the key and value children, in that order.
    let key_value_targets = [
        (&kernel_map_type.key_type, false),
        (&kernel_map_type.value_type, kernel_map_type.value_contains_null),
    ];
    let target_fields = entries
        .fields()
        .iter()
        .zip(key_value_targets)
        .map(|(arrow_field, (target_type, nullable))| {
            StructField::new(arrow_field.name(), target_type.clone(), nullable)
        });

    let new_entries = transform_struct(&entries, target_fields)?;
    let new_entries_field = Arc::new(ArrowField::new(
        entries_field.name().clone(),
        new_entries.data_type().clone(),
        entries_field.is_nullable(),
    ));

    MapArray::try_new(new_entries_field, offsets, new_entries, map_nulls, ordered)
        .map_err(Into::into)
}
pub(crate) fn apply_schema_to(array: &ArrayRef, schema: &DataType) -> DeltaResult<ArrayRef> {
use DataType::*;
let array: ArrayRef = match schema {
Struct(stype) => Arc::new(apply_schema_to_struct(array, stype)?),
Array(atype) => Arc::new(apply_schema_to_list(array, atype)?),
Map(mtype) => Arc::new(apply_schema_to_map(array, mtype)?),
_ => {
ensure_data_types(schema, array.data_type(), ValidationMode::Full)?;
array.clone()
}
};
Ok(array)
}
#[cfg(test)]
mod apply_schema_validation_tests {
    use super::*;
    use std::sync::Arc;

    use crate::arrow::array::{Int32Array, StructArray};
    use crate::arrow::buffer::{BooleanBuffer, NullBuffer};
    use crate::arrow::datatypes::{
        DataType as ArrowDataType, Field as ArrowField, Schema as ArrowSchema,
    };
    use crate::parquet::arrow::PARQUET_FIELD_ID_META_KEY;
    use crate::schema::{ColumnMetadataKey, DataType, MetadataValue, StructField, StructType};
    use crate::utils::test_utils::assert_result_error_with_message;

    /// A matching two-column schema applies cleanly and preserves rows and columns.
    #[test]
    fn test_apply_schema_basic_functionality() {
        let input_array = create_test_struct_array_2_fields();
        let target_schema = create_target_schema_2_fields();
        let result = apply_schema_to_struct(&input_array, &target_schema);
        assert!(result.is_ok(), "Basic schema application should succeed");
        let result_array = result.unwrap();
        assert_eq!(
            result_array.len(),
            input_array.len(),
            "Row count should be preserved"
        );
        assert_eq!(result_array.num_columns(), 2, "Should have 2 columns");
    }

    /// Struct with two non-nullable Int32 columns `a` = [1, 2, 3] and `b` = [4, 5, 6].
    fn create_test_struct_array_2_fields() -> StructArray {
        let field1 = ArrowField::new("a", ArrowDataType::Int32, false);
        let field2 = ArrowField::new("b", ArrowDataType::Int32, false);
        let schema = ArrowSchema::new(vec![field1, field2]);
        let a_data = Int32Array::from(vec![1, 2, 3]);
        let b_data = Int32Array::from(vec![4, 5, 6]);
        StructArray::try_new(
            schema.fields.clone(),
            vec![Arc::new(a_data), Arc::new(b_data)],
            None,
        )
        .unwrap()
    }

    /// Kernel schema matching `create_test_struct_array_2_fields`.
    fn create_target_schema_2_fields() -> StructType {
        StructType::new_unchecked([
            StructField::new("a", DataType::INTEGER, false),
            StructField::new("b", DataType::INTEGER, false),
        ])
    }

    /// A struct with top-level nulls (already mirrored in its children) still converts, and the
    /// child columns keep their nulls and values.
    #[test]
    fn test_apply_schema_handles_top_level_nulls() {
        let field_a = ArrowField::new("a", ArrowDataType::Int32, true);
        let field_b = ArrowField::new("b", ArrowDataType::Int32, true);
        let schema = ArrowSchema::new(vec![field_a, field_b]);
        let a_data = Int32Array::from(vec![Some(1), None, Some(3), None]);
        let b_data = Int32Array::from(vec![Some(10), None, Some(30), None]);
        let null_buffer = NullBuffer::new(BooleanBuffer::from(vec![true, false, true, false]));
        let struct_array = StructArray::try_new(
            schema.fields.clone(),
            vec![Arc::new(a_data), Arc::new(b_data)],
            Some(null_buffer),
        )
        .unwrap();
        let target_schema = DataType::Struct(Box::new(StructType::new_unchecked([
            StructField::new("a", DataType::INTEGER, true),
            StructField::new("b", DataType::INTEGER, true),
        ])));
        let result = apply_schema(&struct_array, &target_schema).unwrap();
        assert_eq!(result.num_rows(), 4);
        assert_eq!(result.num_columns(), 2);
        let col_a = result.column(0);
        assert!(col_a.is_valid(0), "Row 0 should be valid");
        assert!(col_a.is_null(1), "Row 1 should be null");
        assert!(col_a.is_valid(2), "Row 2 should be valid");
        assert!(col_a.is_null(3), "Row 3 should be null");
        let col_b = result.column(1);
        assert!(col_b.is_valid(0), "Row 0 should be valid");
        assert!(col_b.is_null(1), "Row 1 should be null");
        assert!(col_b.is_valid(2), "Row 2 should be valid");
        assert!(col_b.is_null(3), "Row 3 should be null");
        let col_a = col_a
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("column a should be Int32Array");
        let col_b = col_b
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("column b should be Int32Array");
        assert_eq!(col_a.value(0), 1);
        assert_eq!(col_a.value(2), 3);
        assert_eq!(col_b.value(0), 10);
        assert_eq!(col_b.value(2), 30);
    }

    /// The kernel `parquet.field.id` metadata key is translated to Arrow's parquet field-id key.
    #[test]
    fn test_apply_schema_transforms_parquet_field_id_metadata() {
        let field_id_key = ColumnMetadataKey::ParquetFieldId.as_ref();
        let target_schema =
            StructType::new_unchecked([StructField::new("a", DataType::INTEGER, false)
                .with_metadata([(field_id_key.to_string(), MetadataValue::Number(42))])]);
        let arrow_field = ArrowField::new("a", ArrowDataType::Int32, false);
        let input_array = StructArray::try_new(
            vec![arrow_field].into(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
            None,
        )
        .unwrap();
        let result = apply_schema_to_struct(&input_array, &target_schema).unwrap();
        let (_, output_field) = result.fields().find("a").unwrap();
        assert_eq!(
            output_field
                .metadata()
                .get(PARQUET_FIELD_ID_META_KEY)
                .map(String::as_str),
            Some("42"),
            "parquet.field.id should be translated to PARQUET:field_id"
        );
        assert!(
            !output_field.metadata().contains_key(field_id_key),
            "original parquet.field.id key should not be present after translation"
        );
    }

    /// Identical field IDs on input and target are accepted.
    #[test]
    fn test_apply_schema_matching_field_ids_succeed() {
        let field_id_key = ColumnMetadataKey::ParquetFieldId.as_ref();
        let target_schema =
            StructType::new_unchecked([StructField::new("a", DataType::INTEGER, false)
                .with_metadata([(field_id_key.to_string(), MetadataValue::Number(42))])]);
        let arrow_field = ArrowField::new("a", ArrowDataType::Int32, false)
            .with_metadata([(PARQUET_FIELD_ID_META_KEY.to_string(), "42".to_string())].into());
        let input_array = StructArray::try_new(
            vec![arrow_field].into(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
            None,
        )
        .unwrap();
        let result = apply_schema_to_struct(&input_array, &target_schema);
        assert!(result.is_ok(), "Matching field IDs should succeed");
    }

    /// Differing field IDs on input and target are rejected.
    #[test]
    fn test_apply_schema_conflicting_field_ids_fail() {
        let field_id_key = ColumnMetadataKey::ParquetFieldId.as_ref();
        let target_schema =
            StructType::new_unchecked([StructField::new("a", DataType::INTEGER, false)
                .with_metadata([(field_id_key.to_string(), MetadataValue::Number(42))])]);
        let arrow_field = ArrowField::new("a", ArrowDataType::Int32, false)
            .with_metadata([(PARQUET_FIELD_ID_META_KEY.to_string(), "99".to_string())].into());
        let input_array = StructArray::try_new(
            vec![arrow_field].into(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
            None,
        )
        .unwrap();
        assert_result_error_with_message(
            apply_schema_to_struct(&input_array, &target_schema),
            "conflicts with",
        );
    }

    /// `apply_schema` requires a struct schema at the top level.
    #[test]
    fn test_apply_schema_rejects_non_struct_top_level_schema() {
        let input_array = create_test_struct_array_2_fields();
        assert_result_error_with_message(
            apply_schema(&input_array, &DataType::INTEGER),
            "must be passed a struct schema",
        );
    }

    /// A target schema with fewer fields than the input has columns is an error, not a silent
    /// truncation of the output.
    #[test]
    fn test_apply_schema_rejects_too_few_target_fields() {
        let input_array = create_test_struct_array_2_fields();
        let target_schema =
            StructType::new_unchecked([StructField::new("a", DataType::INTEGER, false)]);
        assert_result_error_with_message(
            apply_schema_to_struct(&input_array, &target_schema),
            "Passed struct had 2 columns",
        );
    }
}