use arrow::array::{new_null_array, RecordBatch, RecordBatchOptions};
use arrow::compute::{can_cast_types, cast};
use arrow::datatypes::{Field, Schema, SchemaRef};
use datafusion_common::{plan_err, ColumnStatistics};
use std::fmt::Debug;
use std::sync::Arc;
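
/// Factory for creating [`SchemaAdapter`]s, which adapt the schema of a
/// single file to the schema of the table it belongs to.
///
/// # Example
///
/// A minimal sketch of the intended flow, shown with the default factory.
/// The schemas below are illustrative: the table declares column `a` as
/// `Int32`, while the file stores it as `Int64`, so values are cast on read.
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use arrow::array::{Int64Array, RecordBatch};
/// use arrow::datatypes::{DataType, Field, Schema};
///
/// let table_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
/// let file_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
///
/// let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::clone(&table_schema));
/// let (mapper, projection) = adapter.map_schema(&file_schema).unwrap();
/// assert_eq!(projection, vec![0]); // read file column 0
///
/// let file_batch = RecordBatch::try_new(
///     Arc::new(file_schema),
///     vec![Arc::new(Int64Array::from(vec![1, 2]))],
/// )
/// .unwrap();
/// let table_batch = mapper.map_batch(file_batch).unwrap();
/// assert_eq!(table_batch.schema(), table_schema);
/// ```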
pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static {
    /// Create a [`SchemaAdapter`] that maps file columns to
    /// `projected_table_schema` (the columns the query needs), given the
    /// full `table_schema` of the table being scanned.
    fn create(
        &self,
        projected_table_schema: SchemaRef,
        table_schema: SchemaRef,
    ) -> Box<dyn SchemaAdapter>;
}

/// Adapts the schema of a single file to the (projected) table schema.
pub trait SchemaAdapter: Send + Sync {
    /// Map the column at `index` in the table schema to the index of the
    /// corresponding column in `file_schema`, if the file contains it.
    fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize>;

    /// Create a [`SchemaMapper`] for `file_schema`, along with the indices
    /// of the file columns that must be read (the projection).
    fn map_schema(
        &self,
        file_schema: &Schema,
    ) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)>;
}

/// Maps data read from a file to the table schema.
pub trait SchemaMapper: Debug + Send + Sync {
    /// Adapt a `RecordBatch` read from the file to the table schema.
    fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result<RecordBatch>;

    /// Adapt file-level column statistics to the table schema.
    fn map_column_statistics(
        &self,
        file_col_statistics: &[ColumnStatistics],
    ) -> datafusion_common::Result<Vec<ColumnStatistics>>;
}
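
/// Default [`SchemaAdapterFactory`]: matches file columns to table columns
/// by name, casts them to the table's data types, and null-fills columns
/// that are missing from the file.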
#[derive(Clone, Debug, Default)]
pub struct DefaultSchemaAdapterFactory;
impl DefaultSchemaAdapterFactory {
    /// Convenience constructor: adapt files to `table_schema`, using it as
    /// both the projected and the full table schema.
    pub fn from_schema(table_schema: SchemaRef) -> Box<dyn SchemaAdapter> {
        Self.create(Arc::clone(&table_schema), table_schema)
    }
}

impl SchemaAdapterFactory for DefaultSchemaAdapterFactory {
    fn create(
        &self,
        projected_table_schema: SchemaRef,
        // The default adapter only needs the projected schema.
        _table_schema: SchemaRef,
    ) -> Box<dyn SchemaAdapter> {
        Box::new(DefaultSchemaAdapter {
            projected_table_schema,
        })
    }
}

#[derive(Clone, Debug)]
pub(crate) struct DefaultSchemaAdapter {
    /// The table schema, limited to the projected columns.
    projected_table_schema: SchemaRef,
}

/// Returns `Ok(true)` if `file_field` can be cast to `table_field`, or a
/// plan error if the two types are incompatible.
pub(crate) fn can_cast_field(
    file_field: &Field,
    table_field: &Field,
) -> datafusion_common::Result<bool> {
    if can_cast_types(file_field.data_type(), table_field.data_type()) {
        Ok(true)
    } else {
        plan_err!(
            "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}",
            file_field.name(),
            file_field.data_type(),
            table_field.data_type()
        )
    }
}

impl SchemaAdapter for DefaultSchemaAdapter {
    fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
        let field = self.projected_table_schema.field(index);
        Some(file_schema.fields.find(field.name())?.0)
    }

    fn map_schema(
        &self,
        file_schema: &Schema,
    ) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
        let (field_mappings, projection) = create_field_mapping(
            file_schema,
            &self.projected_table_schema,
            can_cast_field,
        )?;
        Ok((
            Arc::new(SchemaMapping::new(
                Arc::clone(&self.projected_table_schema),
                field_mappings,
            )),
            projection,
        ))
    }
}
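
/// Compute, for each column of `projected_table_schema`, the index of the
/// matching column in the projected file batch (`field_mappings`), plus the
/// file column indices that must be read (`projection`). Columns are
/// matched by name; `can_map_field` decides whether a matched pair is
/// convertible.
///
/// An illustrative mapping, for a table `[a, b, c]` and a file `[b, a]`:
///
/// ```text
/// field_mappings = [Some(1), Some(0), None] // table index -> projected column
/// projection     = [0, 1]                   // file columns to read, in file order
/// ```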
pub(crate) fn create_field_mapping<F>(
    file_schema: &Schema,
    projected_table_schema: &SchemaRef,
    can_map_field: F,
) -> datafusion_common::Result<(Vec<Option<usize>>, Vec<usize>)>
where
    F: Fn(&Field, &Field) -> datafusion_common::Result<bool>,
{
    let mut projection = Vec::with_capacity(file_schema.fields().len());
    let mut field_mappings = vec![None; projected_table_schema.fields().len()];
    for (file_idx, file_field) in file_schema.fields.iter().enumerate() {
        if let Some((table_idx, table_field)) =
            projected_table_schema.fields().find(file_field.name())
        {
            if can_map_field(file_field, table_field)? {
                field_mappings[table_idx] = Some(projection.len());
                projection.push(file_idx);
            }
        }
    }
    Ok((field_mappings, projection))
}
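
/// A [`SchemaMapper`] created by [`DefaultSchemaAdapter`]: casts mapped
/// columns to the table's data types and null-fills the rest.
///
/// # Example
///
/// A minimal sketch with illustrative schemas: the table has columns `a`
/// and `b`, but the file only provides `a`, so `b` comes back as an
/// all-null column.
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use arrow::array::{Int32Array, RecordBatch};
/// use arrow::datatypes::{DataType, Field, Schema};
///
/// let table_schema = Arc::new(Schema::new(vec![
///     Field::new("a", DataType::Int32, true),
///     Field::new("b", DataType::Utf8, true),
/// ]));
/// // Table column `a` maps to projected file column 0; `b` is missing.
/// let mapping = SchemaMapping::new(Arc::clone(&table_schema), vec![Some(0), None]);
///
/// let file_batch = RecordBatch::try_new(
///     Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])),
///     vec![Arc::new(Int32Array::from(vec![1, 2]))],
/// )
/// .unwrap();
/// let mapped = mapping.map_batch(file_batch).unwrap();
/// assert_eq!(mapped.num_columns(), 2);
/// assert_eq!(mapped.column(1).null_count(), 2); // `b` is all nulls
/// ```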
#[derive(Debug)]
pub struct SchemaMapping {
    /// The table schema, limited to the projected columns.
    projected_table_schema: SchemaRef,
    /// For each table column, the index of the corresponding column in the
    /// projected file batch, or `None` if the file does not provide it.
    field_mappings: Vec<Option<usize>>,
}

impl SchemaMapping {
    pub fn new(
        projected_table_schema: SchemaRef,
        field_mappings: Vec<Option<usize>>,
    ) -> Self {
        Self {
            projected_table_schema,
            field_mappings,
        }
    }
}

impl SchemaMapper for SchemaMapping {
    fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result<RecordBatch> {
        let batch_rows = batch.num_rows();
        let batch_cols = batch.columns().to_vec();
        let cols = self
            .projected_table_schema
            .fields()
            .iter()
            .zip(&self.field_mappings)
            .map(|(field, file_idx)| {
                file_idx.map_or_else(
                    // Table columns missing from the file are filled with nulls.
                    || Ok(new_null_array(field.data_type(), batch_rows)),
                    // Columns present in the file are cast to the table type.
                    |batch_idx| cast(&batch_cols[batch_idx], field.data_type()),
                )
            })
            .collect::<datafusion_common::Result<Vec<_>, _>>()?;
        // Set an explicit row count so batches with no columns remain valid.
        let options = RecordBatchOptions::new().with_row_count(Some(batch_rows));
        let schema = Arc::clone(&self.projected_table_schema);
        let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
        Ok(record_batch)
    }

    fn map_column_statistics(
        &self,
        file_col_statistics: &[ColumnStatistics],
    ) -> datafusion_common::Result<Vec<ColumnStatistics>> {
        let mut table_col_statistics = vec![];
        // Visit table columns in order, pulling statistics from the mapped
        // file column when one exists.
        for (_, file_col_idx) in self
            .projected_table_schema
            .fields()
            .iter()
            .zip(&self.field_mappings)
        {
            if let Some(file_col_idx) = file_col_idx {
                table_col_statistics.push(
                    file_col_statistics
                        .get(*file_col_idx)
                        .cloned()
                        .unwrap_or_default(),
                );
            } else {
                // Columns absent from the file get unknown statistics.
                table_col_statistics.push(ColumnStatistics::new_unknown());
            }
        }
        Ok(table_col_statistics)
    }
}

#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{stats::Precision, Statistics};

    use super::*;

    #[test]
    fn test_schema_mapping_map_statistics_basic() {
        let table_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ]));
        let file_schema = Schema::new(vec![
            Field::new("b", DataType::Utf8, true),
            Field::new("a", DataType::Int32, true),
        ]);
        let adapter = DefaultSchemaAdapter {
            projected_table_schema: Arc::clone(&table_schema),
        };
        let (mapper, projection) = adapter.map_schema(&file_schema).unwrap();
        assert_eq!(projection, vec![0, 1]);

        let mut file_stats = Statistics::default();
        let b_stats = ColumnStatistics {
            null_count: Precision::Exact(5),
            ..Default::default()
        };
        let a_stats = ColumnStatistics {
            null_count: Precision::Exact(10),
            ..Default::default()
        };
        file_stats.column_statistics = vec![b_stats, a_stats];

        let table_col_stats = mapper
            .map_column_statistics(&file_stats.column_statistics)
            .unwrap();
        assert_eq!(table_col_stats.len(), 3);
        assert_eq!(table_col_stats[0].null_count, Precision::Exact(10));
        assert_eq!(table_col_stats[1].null_count, Precision::Exact(5));
        assert_eq!(table_col_stats[2].null_count, Precision::Absent);
    }

    #[test]
    fn test_schema_mapping_map_statistics_empty() {
        let table_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
        ]));
        let file_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
        ]);
        let adapter = DefaultSchemaAdapter {
            projected_table_schema: Arc::clone(&table_schema),
        };
        let (mapper, _) = adapter.map_schema(&file_schema).unwrap();
        let file_stats = Statistics::default();
        let table_col_stats = mapper
            .map_column_statistics(&file_stats.column_statistics)
            .unwrap();
        assert_eq!(table_col_stats.len(), 2);
        assert_eq!(table_col_stats[0], ColumnStatistics::new_unknown());
        assert_eq!(table_col_stats[1], ColumnStatistics::new_unknown());
    }

    #[test]
    fn test_can_cast_field() {
        let from_field = Field::new("col", DataType::Int32, true);
        let to_field = Field::new("col", DataType::Int32, true);
        assert!(can_cast_field(&from_field, &to_field).unwrap());

        let from_field = Field::new("col", DataType::Int32, true);
        let to_field = Field::new("col", DataType::Float64, true);
        assert!(can_cast_field(&from_field, &to_field).unwrap());

        let from_field = Field::new("col", DataType::Float64, true);
        let to_field = Field::new("col", DataType::Utf8, true);
        assert!(can_cast_field(&from_field, &to_field).unwrap());

        let from_field = Field::new("col", DataType::Binary, true);
        let to_field = Field::new("col", DataType::Decimal128(10, 2), true);
        let result = can_cast_field(&from_field, &to_field);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot cast file schema field col"));
    }

    #[test]
    fn test_create_field_mapping() {
        let table_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ]));
        let file_schema = Schema::new(vec![
            Field::new("b", DataType::Float64, true),
            Field::new("a", DataType::Int32, true),
            Field::new("d", DataType::Boolean, true),
        ]);

        let allow_all = |_: &Field, _: &Field| Ok(true);
        let (field_mappings, projection) =
            create_field_mapping(&file_schema, &table_schema, allow_all).unwrap();
        assert_eq!(field_mappings, vec![Some(1), Some(0), None]);
        assert_eq!(projection, vec![0, 1]);

        let fails_all = |_: &Field, _: &Field| Ok(false);
        let (field_mappings, projection) =
            create_field_mapping(&file_schema, &table_schema, fails_all).unwrap();
        assert_eq!(field_mappings, vec![None, None, None]);
        assert_eq!(projection, Vec::<usize>::new());

        let error_mapper = |_: &Field, _: &Field| plan_err!("Test error");
        let result = create_field_mapping(&file_schema, &table_schema, error_mapper);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Test error"));
    }

    #[test]
    fn test_schema_mapping_new() {
        let projected_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
        ]));
        let field_mappings = vec![Some(1), Some(0)];
        let mapping =
            SchemaMapping::new(Arc::clone(&projected_schema), field_mappings.clone());
        assert_eq!(*mapping.projected_table_schema, *projected_schema);
        assert_eq!(mapping.field_mappings, field_mappings);

        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new("b_file", DataType::Utf8, true),
                Field::new("a_file", DataType::Int32, true),
            ])),
            vec![
                Arc::new(arrow::array::StringArray::from(vec!["hello", "world"])),
                Arc::new(arrow::array::Int32Array::from(vec![1, 2])),
            ],
        )
        .unwrap();
        let mapped_batch = mapping.map_batch(batch).unwrap();
        assert_eq!(*mapped_batch.schema(), *projected_schema);
        assert_eq!(mapped_batch.num_columns(), 2);
        assert_eq!(mapped_batch.column(0).len(), 2);
        assert_eq!(mapped_batch.column(1).len(), 2);
    }

    #[test]
    fn test_map_schema_error_path() {
        let table_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Decimal128(10, 2), true),
        ]));
        let file_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Float64, true),
            Field::new("c", DataType::Binary, true),
        ]);
        let adapter = DefaultSchemaAdapter {
            projected_table_schema: Arc::clone(&table_schema),
        };
        let result = adapter.map_schema(&file_schema);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("Cannot cast file schema field c"));
    }

    #[test]
    fn test_map_schema_happy_path() {
        let table_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Decimal128(10, 2), true),
        ]));
        let adapter = DefaultSchemaAdapter {
            projected_table_schema: Arc::clone(&table_schema),
        };
        let compatible_file_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Float64, true),
        ]);
        let (mapper, projection) = adapter.map_schema(&compatible_file_schema).unwrap();
        assert_eq!(projection, vec![0, 1]);

        let file_batch = RecordBatch::try_new(
            Arc::new(compatible_file_schema.clone()),
            vec![
                Arc::new(arrow::array::Int64Array::from(vec![100, 200])),
                Arc::new(arrow::array::Float64Array::from(vec![1.5, 2.5])),
            ],
        )
        .unwrap();
        let mapped_batch = mapper.map_batch(file_batch).unwrap();
        assert_eq!(*mapped_batch.schema(), *table_schema);
        assert_eq!(mapped_batch.num_columns(), 3);

        let c_array = mapped_batch.column(2);
        assert_eq!(c_array.len(), 2);
        assert_eq!(c_array.null_count(), 2);
    }
}