//! [`SchemaAdapter`] and [`SchemaAdapterFactory`] adapt the schema of data read
//! from individual files to the schema of the table being scanned.
use arrow::compute::{can_cast_types, cast};
use arrow_array::{new_null_array, RecordBatch, RecordBatchOptions};
use arrow_schema::{Schema, SchemaRef};
use datafusion_common::plan_err;
use std::fmt::Debug;
use std::sync::Arc;
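/// Factory for creating [`SchemaAdapter`]s.
///
/// Provides a way to customize, per table, how the schema of individual files is
/// adapted to the table schema (for example to handle schema evolution, where
/// files written at different times may have different columns or types).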
pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static {
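/// Create a [`SchemaAdapter`] that maps file schemas to `projected_table_schema`,
/// the subset of `table_schema` selected by the query's projection.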
fn create(
&self,
projected_table_schema: SchemaRef,
table_schema: SchemaRef,
) -> Box<dyn SchemaAdapter>;
}
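/// Adapts the schema of a single file to the schema of the table it belongs to.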
pub trait SchemaAdapter: Send + Sync {
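/// Map the column at `index` in the (projected) table schema to the index of the
/// corresponding column in `file_schema`, or `None` if the file has no such column.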
fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize>;
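/// Create a [`SchemaMapper`] for `file_schema`, together with the ordered list of
/// file column indices that must be read from the file.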
fn map_schema(
&self,
file_schema: &Schema,
) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)>;
}
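/// Maps [`RecordBatch`]es read from a file so that they match the table schema.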
pub trait SchemaMapper: Debug + Send + Sync {
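/// Adapt `batch`, read with the projected file schema, into a batch with the
/// projected table schema.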
fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result<RecordBatch>;
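/// Adapt a batch that may contain only some of the file's columns, for example
/// the columns needed to evaluate a pushed-down predicate.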
fn map_partial_batch(
&self,
batch: RecordBatch,
) -> datafusion_common::Result<RecordBatch>;
}
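/// Default [`SchemaAdapterFactory`]: maps file columns to table columns by name,
/// casts them to the table's data types where needed, and fills columns missing
/// from the file with nulls.
///
/// The example below is a minimal sketch that mirrors the `default_schema_adapter`
/// test in this file; the `datafusion::` import paths and the
/// `datafusion_common::record_batch!` macro are assumed to be available to callers.
///
/// ```
/// use std::sync::Arc;
/// use datafusion::arrow::datatypes::{DataType, Field, Schema};
/// use datafusion::datasource::schema_adapter::{
///     DefaultSchemaAdapterFactory, SchemaAdapter, SchemaMapper,
/// };
/// use datafusion_common::record_batch;
///
/// // The table has columns "a" (Int32) and "b" (Utf8); the file only has "b",
/// // stored as Float64.
/// let table_schema = Arc::new(Schema::new(vec![
///     Field::new("a", DataType::Int32, true),
///     Field::new("b", DataType::Utf8, true),
/// ]));
/// let file_schema = Schema::new(vec![Field::new("b", DataType::Float64, true)]);
///
/// let adapter = DefaultSchemaAdapterFactory::from_schema(table_schema);
/// let (mapper, projection) = adapter.map_schema(&file_schema).unwrap();
/// // only file column 0 ("b") needs to be read
/// assert_eq!(projection, vec![0]);
///
/// // a batch as read from the file, with the projection applied
/// let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap();
/// // "b" is cast to Utf8 and "a" is filled with nulls to match the table schema
/// let table_batch = mapper.map_batch(file_batch).unwrap();
/// let expected = record_batch!(
///     ("a", Int32, vec![None, None]),
///     ("b", Utf8, vec!["1.0", "2.0"])
/// )
/// .unwrap();
/// assert_eq!(table_batch, expected);
/// ```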
#[derive(Clone, Debug, Default)]
pub struct DefaultSchemaAdapterFactory;
impl DefaultSchemaAdapterFactory {
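/// Create a [`SchemaAdapter`] for a table that has no projection applied, i.e.
/// the projected schema is the full `table_schema`.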
pub fn from_schema(table_schema: SchemaRef) -> Box<dyn SchemaAdapter> {
Self.create(Arc::clone(&table_schema), table_schema)
}
}
impl SchemaAdapterFactory for DefaultSchemaAdapterFactory {
fn create(
&self,
projected_table_schema: SchemaRef,
table_schema: SchemaRef,
) -> Box<dyn SchemaAdapter> {
Box::new(DefaultSchemaAdapter {
projected_table_schema,
table_schema,
})
}
}
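/// Default [`SchemaAdapter`] implementation: matches file columns to table
/// columns by name and requires that a matched file column be castable to the
/// table column's data type.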
#[derive(Clone, Debug)]
pub(crate) struct DefaultSchemaAdapter {
projected_table_schema: SchemaRef,
table_schema: SchemaRef,
}
impl SchemaAdapter for DefaultSchemaAdapter {
fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
let field = self.projected_table_schema.field(index);
Some(file_schema.fields.find(field.name())?.0)
}
fn map_schema(
&self,
file_schema: &Schema,
) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
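// `projection` collects, in order, the indices of the file columns that need to
// be read, while `field_mappings[table_idx]` records where each projected table
// field ends up in the projected file batch (None if the file lacks that column).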
let mut projection = Vec::with_capacity(file_schema.fields().len());
let mut field_mappings = vec![None; self.projected_table_schema.fields().len()];
for (file_idx, file_field) in file_schema.fields.iter().enumerate() {
if let Some((table_idx, table_field)) =
self.projected_table_schema.fields().find(file_field.name())
{
match can_cast_types(file_field.data_type(), table_field.data_type()) {
true => {
field_mappings[table_idx] = Some(projection.len());
projection.push(file_idx);
}
false => {
return plan_err!(
"Cannot cast file schema field {} of type {:?} to table schema field of type {:?}",
file_field.name(),
file_field.data_type(),
table_field.data_type()
)
}
}
}
}
Ok((
Arc::new(SchemaMapping {
projected_table_schema: Arc::clone(&self.projected_table_schema),
field_mappings,
table_schema: Arc::clone(&self.table_schema),
}),
projection,
))
}
}
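/// [`SchemaMapper`] returned by `DefaultSchemaAdapter::map_schema`.
///
/// For field `i` of `projected_table_schema`, `field_mappings[i]` is the index of
/// the corresponding column in a projected file batch, or `None` if the file does
/// not contain that column. `table_schema` is the full (unprojected) table schema,
/// consulted by `map_partial_batch`.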
#[derive(Debug)]
pub struct SchemaMapping {
projected_table_schema: SchemaRef,
field_mappings: Vec<Option<usize>>,
table_schema: SchemaRef,
}
impl SchemaMapper for SchemaMapping {
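/// Adapts `batch` (whose columns follow the projected file schema) to the
/// projected table schema: file columns are cast to the table's data types and
/// table columns missing from the file are filled with null arrays.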
fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result<RecordBatch> {
let batch_rows = batch.num_rows();
let batch_cols = batch.columns().to_vec();
let cols = self
.projected_table_schema
.fields()
.iter()
.zip(&self.field_mappings)
.map(|(field, file_idx)| {
file_idx.map_or_else(
// the field is missing from the file, so fill the column with nulls
|| Ok(new_null_array(field.data_type(), batch_rows)),
// the field exists in the file, so cast it to the table's data type
|batch_idx| cast(&batch_cols[batch_idx], field.data_type()),
)
})
.collect::<datafusion_common::Result<Vec<_>, _>>()?;
let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
let schema = Arc::clone(&self.projected_table_schema);
let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
Ok(record_batch)
}
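/// Adapts a batch that may contain only a subset of the file's columns: columns
/// that also exist in the table schema are cast to the table's data types, all
/// other columns are dropped, and the output schema is built from the matched
/// table fields plus the original batch's metadata.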
fn map_partial_batch(
&self,
batch: RecordBatch,
) -> datafusion_common::Result<RecordBatch> {
let batch_cols = batch.columns().to_vec();
let schema = batch.schema();
let (cols, fields) = schema
.fields()
.iter()
.zip(batch_cols.iter())
.flat_map(|(field, batch_col)| {
self.table_schema
// look up the same field in the full table schema
.field_with_name(field.name())
// if the table does not have this column, silently skip it
.ok()
.map(|table_field| {
// otherwise cast the column to the table's data type, propagating any error
cast(batch_col, table_field.data_type())
.map(|new_col| (new_col, table_field.clone()))
})
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.unzip::<_, _, Vec<_>, Vec<_>>();
let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
let schema =
Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone()));
let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
Ok(record_batch)
}
}
#[cfg(test)]
mod tests {
use std::fs;
use std::sync::Arc;
use crate::assert_batches_sorted_eq;
use arrow::datatypes::{Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::{Int32Array, StringArray};
use arrow_schema::{DataType, SchemaRef};
use object_store::path::Path;
use object_store::ObjectMeta;
use crate::datasource::object_store::ObjectStoreUrl;
use crate::datasource::physical_plan::{FileScanConfig, ParquetExec};
use crate::physical_plan::collect;
use crate::prelude::SessionContext;
use crate::datasource::listing::PartitionedFile;
use crate::datasource::schema_adapter::{
DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper,
};
use datafusion_common::record_batch;
#[cfg(feature = "parquet")]
use parquet::arrow::ArrowWriter;
use tempfile::TempDir;
#[tokio::test]
async fn can_override_schema_adapter() {
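// Shows that a custom SchemaAdapterFactory can add a column that does not exist
// in the record batches read from parquet, which is useful for schema evolution
// where older files may not have all of the table's columns.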
let tmp_dir = TempDir::new().unwrap();
let table_dir = tmp_dir.path().join("parquet_test");
fs::DirBuilder::new().create(table_dir.as_path()).unwrap();
let f1 = Field::new("id", DataType::Int32, true);
let file_schema = Arc::new(Schema::new(vec![f1.clone()]));
let filename = "part.parquet".to_string();
let path = table_dir.as_path().join(filename.clone());
let file = fs::File::create(path.clone()).unwrap();
let mut writer = ArrowWriter::try_new(file, file_schema.clone(), None).unwrap();
let ids = Arc::new(Int32Array::from(vec![1i32]));
let rec_batch = RecordBatch::try_new(file_schema.clone(), vec![ids]).unwrap();
writer.write(&rec_batch).unwrap();
writer.close().unwrap();
let location = Path::parse(path.to_str().unwrap()).unwrap();
let metadata = fs::metadata(path.as_path()).expect("Local file metadata");
let meta = ObjectMeta {
location,
last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(),
size: metadata.len() as usize,
e_tag: None,
version: None,
};
let partitioned_file = PartitionedFile {
object_meta: meta,
partition_values: vec![],
range: None,
statistics: None,
extensions: None,
metadata_size_hint: None,
};
let f1 = Field::new("id", DataType::Int32, true);
let f2 = Field::new("extra_column", DataType::Utf8, true);
let schema = Arc::new(Schema::new(vec![f1.clone(), f2.clone()]));
let parquet_exec = ParquetExec::builder(
FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema)
.with_file(partitioned_file),
)
.build()
.with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {}));
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let read = collect(Arc::new(parquet_exec), task_ctx).await.unwrap();
let expected = [
"+----+--------------+",
"| id | extra_column |",
"+----+--------------+",
"| 1 | foo |",
"+----+--------------+",
];
assert_batches_sorted_eq!(expected, &read);
}
#[test]
fn default_schema_adapter() {
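// The file has an extra column "c" and stores "b" as Float64; mapping should read
// only "b" (file index 1), cast it to Utf8, and fill the missing "a" with nulls.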
let table_schema = Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
]);
let file_schema = Schema::new(vec![
Field::new("c", DataType::Float64, true), Field::new("b", DataType::Float64, true),
]);
let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema));
let (mapper, indices) = adapter.map_schema(&file_schema).unwrap();
assert_eq!(indices, vec![1]);
let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap();
let mapped_batch = mapper.map_batch(file_batch).unwrap();
let expected_batch = record_batch!(
("a", Int32, vec![None, None]), ("b", Utf8, vec!["1.0", "2.0"]) )
.unwrap();
assert_eq!(mapped_batch, expected_batch);
}
#[test]
fn default_schema_adapter_non_nullable_columns() {
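// The file lacks the non-nullable table column "a"; planning succeeds, but
// materializing the mapped batch fails because "a" would be entirely null.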
let table_schema = Schema::new(vec![
Field::new("a", DataType::Int32, false), Field::new("b", DataType::Utf8, true),
]);
let file_schema = Schema::new(vec![
Field::new("b", DataType::Float64, true),
]);
let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema));
let (mapper, indices) = adapter.map_schema(&file_schema).unwrap();
assert_eq!(indices, vec![0]);
let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap();
let err = mapper.map_batch(file_batch).unwrap_err().to_string();
assert!(
err.contains("Invalid argument error: Column 'a' is declared as non-nullable but contains null values"),
"{err}"
);
}
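/// Schema adapter factory used by `can_override_schema_adapter`: produces a
/// mapper that appends an `extra_column` not present in the underlying file.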
#[derive(Debug)]
struct TestSchemaAdapterFactory;
impl SchemaAdapterFactory for TestSchemaAdapterFactory {
fn create(
&self,
projected_table_schema: SchemaRef,
_table_schema: SchemaRef,
) -> Box<dyn SchemaAdapter> {
Box::new(TestSchemaAdapter {
table_schema: projected_table_schema,
})
}
}
struct TestSchemaAdapter {
table_schema: SchemaRef,
}
impl SchemaAdapter for TestSchemaAdapter {
fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
let field = self.table_schema.field(index);
Some(file_schema.fields.find(field.name())?.0)
}
fn map_schema(
&self,
file_schema: &Schema,
) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
let mut projection = Vec::with_capacity(file_schema.fields().len());
for (file_idx, file_field) in file_schema.fields.iter().enumerate() {
if self.table_schema.fields().find(file_field.name()).is_some() {
projection.push(file_idx);
}
}
Ok((Arc::new(TestSchemaMapping {}), projection))
}
}
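/// Test mapper that appends a constant `extra_column` (value "foo") to every batch.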
#[derive(Debug)]
struct TestSchemaMapping {}
impl SchemaMapper for TestSchemaMapping {
fn map_batch(
&self,
batch: RecordBatch,
) -> datafusion_common::Result<RecordBatch> {
let f1 = Field::new("id", DataType::Int32, true);
let f2 = Field::new("extra_column", DataType::Utf8, true);
let schema = Arc::new(Schema::new(vec![f1, f2]));
let extra_column = Arc::new(StringArray::from(vec!["foo"]));
let mut new_columns = batch.columns().to_vec();
new_columns.push(extra_column);
Ok(RecordBatch::try_new(schema, new_columns).unwrap())
}
fn map_partial_batch(
&self,
batch: RecordBatch,
) -> datafusion_common::Result<RecordBatch> {
self.map_batch(batch)
}
}
}