use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;
use arrow::array::RecordBatch;
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion_common::{
Result, ScalarValue, exec_err,
nested_struct::validate_struct_compatibility,
tree_node::{Transformed, TransformedResult, TreeNode},
};
use datafusion_functions::core::getfield::GetFieldFunc;
use datafusion_physical_expr::PhysicalExprSimplifier;
use datafusion_physical_expr::expressions::CastColumnExpr;
use datafusion_physical_expr::projection::{ProjectionExprs, Projector};
use datafusion_physical_expr::{
ScalarFunctionExpr,
expressions::{self, Column},
};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use itertools::Itertools;
/// Replace every [`Column`] whose name appears in `replacements` with a
/// literal expression holding the corresponding [`ScalarValue`].
///
/// Columns whose names are not present in `replacements` are left untouched;
/// all other expression nodes pass through unchanged.
pub fn replace_columns_with_literals<K, V>(
    expr: Arc<dyn PhysicalExpr>,
    replacements: &HashMap<K, V>,
) -> Result<Arc<dyn PhysicalExpr>>
where
    K: Borrow<str> + Eq + Hash,
    V: Borrow<ScalarValue>,
{
    expr.transform_down(|node| {
        let Some(column) = node.as_any().downcast_ref::<Column>() else {
            return Ok(Transformed::no(node));
        };
        match replacements.get(column.name()) {
            Some(value) => Ok(Transformed::yes(expressions::lit(
                value.borrow().clone(),
            ))),
            None => Ok(Transformed::no(node)),
        }
    })
    .data()
}
/// Rewrites a [`PhysicalExpr`] written against a logical (table) schema so it
/// can be evaluated directly against a physical (file) schema.
pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug {
    /// Rewrite `expr` into an equivalent expression over the physical schema.
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>;
}
/// Factory for [`PhysicalExprAdapter`]s bound to a specific pair of schemas.
pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug {
    /// Create an adapter that maps expressions written against
    /// `logical_file_schema` onto `physical_file_schema`.
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter>;
}
/// Factory producing the built-in [`DefaultPhysicalExprAdapter`].
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapterFactory;
impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory {
fn create(
&self,
logical_file_schema: SchemaRef,
physical_file_schema: SchemaRef,
) -> Arc<dyn PhysicalExprAdapter> {
Arc::new(DefaultPhysicalExprAdapter {
logical_file_schema,
physical_file_schema,
})
}
}
/// Default [`PhysicalExprAdapter`]: re-indexes columns, inserts casts where
/// physical and logical types differ, and substitutes typed NULL literals for
/// nullable columns missing from the physical file.
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapter {
    // Schema the incoming expressions reference (the table schema).
    logical_file_schema: SchemaRef,
    // Schema of the actual file being scanned.
    physical_file_schema: SchemaRef,
}
impl DefaultPhysicalExprAdapter {
pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
Self {
logical_file_schema,
physical_file_schema,
}
}
}
impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
    /// Walk the expression tree and rewrite each node via
    /// [`DefaultPhysicalExprAdapterRewriter`].
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
        let rewriter = DefaultPhysicalExprAdapterRewriter {
            logical_file_schema: self.logical_file_schema.as_ref(),
            physical_file_schema: self.physical_file_schema.as_ref(),
        };
        expr.transform(|node| rewriter.rewrite_expr(node)).data()
    }
}
/// Borrowed-schema helper performing the actual per-node rewrites for
/// [`DefaultPhysicalExprAdapter`].
struct DefaultPhysicalExprAdapterRewriter<'a> {
    logical_file_schema: &'a Schema,
    physical_file_schema: &'a Schema,
}
impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
    /// Rewrite a single expression node so it can be evaluated against the
    /// physical file schema.
    ///
    /// Two rewrites are attempted, in order:
    /// 1. `get_field(struct_col, "name")` where the field is missing from the
    ///    physical file folds to a literal
    ///    (see [`Self::try_rewrite_struct_field_access`]);
    /// 2. a bare [`Column`] is re-indexed and/or wrapped in a
    ///    [`CastColumnExpr`] (see [`Self::rewrite_column`]).
    fn rewrite_expr(
        &self,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? {
            return Ok(Transformed::yes(transformed));
        }
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            return self.rewrite_column(Arc::clone(&expr), column);
        }
        Ok(Transformed::no(expr))
    }

    /// Handle `get_field(column, "field")` where `column` is a struct column
    /// present in both schemas but the requested field exists only in the
    /// logical schema.
    ///
    /// Returns:
    /// * `Ok(Some(lit))` — the access is folded to a typed NULL literal;
    /// * `Ok(None)` — the expression is not such an access, so it falls
    ///   through to the normal rewrite;
    /// * `Err` — the missing field is declared non-nullable, so a NULL
    ///   substitute would violate the logical schema.
    fn try_rewrite_struct_field_access(
        &self,
        expr: &Arc<dyn PhysicalExpr>,
    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
        let Some(get_field_expr) =
            ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(expr.as_ref())
        else {
            return Ok(None);
        };
        // get_field(source, name): both arguments must be present.
        let Some(source_expr) = get_field_expr.args().first() else {
            return Ok(None);
        };
        let Some(field_name_expr) = get_field_expr.args().get(1) else {
            return Ok(None);
        };
        // The accessed field name must be a string literal.
        let Some(lit) = field_name_expr
            .as_any()
            .downcast_ref::<expressions::Literal>()
        else {
            return Ok(None);
        };
        let Some(field_name) = lit.value().try_as_str().flatten() else {
            return Ok(None);
        };
        // Only direct column accesses are handled here.
        let Some(column) = source_expr.as_any().downcast_ref::<Column>() else {
            return Ok(None);
        };
        let Ok(physical_field) =
            self.physical_file_schema.field_with_name(column.name())
        else {
            return Ok(None);
        };
        let DataType::Struct(physical_struct_fields) = physical_field.data_type() else {
            return Ok(None);
        };
        // If the physical file already has the field, no rewrite is needed.
        if physical_struct_fields.iter().any(|f| f.name() == field_name) {
            return Ok(None);
        }
        let Ok(logical_field) = self.logical_file_schema.field_with_name(column.name())
        else {
            return Ok(None);
        };
        let DataType::Struct(logical_struct_fields) = logical_field.data_type() else {
            return Ok(None);
        };
        let Some(logical_struct_field) = logical_struct_fields
            .iter()
            .find(|f| f.name() == field_name)
        else {
            return Ok(None);
        };
        // A non-nullable field cannot be backfilled with NULL: fail loudly,
        // mirroring the missing-column handling in `rewrite_column`.
        if !logical_struct_field.is_nullable() {
            return exec_err!(
                "Non-nullable struct field '{}' of column '{}' is missing from the physical schema",
                field_name,
                column.name()
            );
        }
        let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
        Ok(Some(expressions::lit(null_value)))
    }

    /// Rewrite a [`Column`] so it reads from the physical file schema:
    /// * missing nullable columns become typed NULL literals (missing
    ///   non-nullable columns are an error);
    /// * columns at a different physical index are re-indexed;
    /// * columns whose physical type differs from the logical type are
    ///   wrapped in a [`CastColumnExpr`] after validating the cast.
    fn rewrite_column(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        column: &Column,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        // Prefer the logical definition of the column; fall back to the
        // physical one for columns only present in the file (no cast will be
        // needed then, since logical and physical fields coincide).
        let logical_field = match self.logical_file_schema.field_with_name(column.name())
        {
            Ok(field) => field,
            Err(e) => {
                if let Ok(physical_field) =
                    self.physical_file_schema.field_with_name(column.name())
                {
                    physical_field
                } else {
                    // Unknown in both schemas: propagate the lookup error.
                    return Err(e.into());
                }
            }
        };
        let physical_column_index = match self
            .physical_file_schema
            .index_of(column.name())
        {
            Ok(index) => index,
            Err(_) => {
                // Column absent from the file: substitute NULL if allowed.
                if !logical_field.is_nullable() {
                    return exec_err!(
                        "Non-nullable column '{}' is missing from the physical schema",
                        column.name()
                    );
                }
                let null_value = ScalarValue::Null.cast_to(logical_field.data_type())?;
                return Ok(Transformed::yes(expressions::lit(null_value)));
            }
        };
        let physical_field = self.physical_file_schema.field(physical_column_index);
        let column = match (
            column.index() == physical_column_index,
            logical_field.data_type() == physical_field.data_type(),
        ) {
            // Index and type both match: nothing to do.
            (true, true) => return Ok(Transformed::no(expr)),
            (true, _) => column.clone(),
            // Re-point the column at its physical index.
            (false, _) => {
                Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
            }
        };
        if logical_field.data_type() == physical_field.data_type() {
            // Only the index changed; no cast required.
            return Ok(Transformed::yes(Arc::new(column)));
        }
        // Verify the physical -> logical conversion before emitting the cast.
        match (physical_field.data_type(), logical_field.data_type()) {
            (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => {
                validate_struct_compatibility(physical_fields, logical_fields)?;
            }
            _ => {
                let is_compatible =
                    can_cast_types(physical_field.data_type(), logical_field.data_type());
                if !is_compatible {
                    return exec_err!(
                        "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
                        column.name(),
                        physical_field.data_type(),
                        logical_field.data_type()
                    );
                }
            }
        }
        let cast_expr = Arc::new(CastColumnExpr::new(
            Arc::new(column),
            Arc::new(physical_field.clone()),
            Arc::new(logical_field.clone()),
            None,
        ));
        Ok(Transformed::yes(cast_expr))
    }
}
/// Builds [`BatchAdapter`]s that project record batches from arbitrary source
/// schemas onto a fixed target schema.
#[derive(Debug)]
pub struct BatchAdapterFactory {
    // Schema every adapted batch is mapped onto.
    target_schema: SchemaRef,
    // Factory for the expression adapter used to rewrite the projection.
    expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
}
impl BatchAdapterFactory {
    /// Create a factory that adapts batches to `target_schema` using the
    /// default expression adapter.
    pub fn new(target_schema: SchemaRef) -> Self {
        Self {
            target_schema,
            expr_adapter_factory: Arc::new(DefaultPhysicalExprAdapterFactory),
        }
    }

    /// Replace the expression-adapter factory used when building adapters.
    pub fn with_adapter_factory(
        self,
        factory: Arc<dyn PhysicalExprAdapterFactory>,
    ) -> Self {
        Self {
            target_schema: self.target_schema,
            expr_adapter_factory: factory,
        }
    }

    /// Build a [`BatchAdapter`] mapping batches with `source_schema` onto the
    /// target schema.
    pub fn make_adapter(&self, source_schema: SchemaRef) -> Result<BatchAdapter> {
        // For the expression adapter the target schema plays the "logical"
        // role and the source schema the "physical" role.
        let expr_adapter = self
            .expr_adapter_factory
            .create(Arc::clone(&self.target_schema), Arc::clone(&source_schema));
        let simplifier = PhysicalExprSimplifier::new(&self.target_schema);
        // Start from an identity projection over every target column, then
        // rewrite + simplify each expression so it reads the source schema.
        let indices: Vec<usize> = (0..self.target_schema.fields().len()).collect();
        let projection = ProjectionExprs::from_indices(&indices, &self.target_schema);
        let adapted =
            projection.try_map_exprs(|e| simplifier.simplify(expr_adapter.rewrite(e)?))?;
        Ok(BatchAdapter {
            projector: adapted.make_projector(&source_schema)?,
        })
    }
}
/// Projects record batches onto a target schema; created by
/// [`BatchAdapterFactory::make_adapter`].
#[derive(Debug)]
pub struct BatchAdapter {
    projector: Projector,
}
impl BatchAdapter {
    /// Project `batch` onto the target schema this adapter was built for.
    pub fn adapt_batch(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        self.projector.project_batch(batch)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{
BooleanArray, Int32Array, Int64Array, RecordBatch, RecordBatchOptions,
StringArray, StringViewArray, StructArray,
};
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use datafusion_common::{Result, ScalarValue, assert_contains, record_batch};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{Column, Literal, col, lit};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use itertools::Itertools;
use std::sync::Arc;
/// Build a (physical, logical) schema pair: the logical schema widens `a`
/// to Int64 and adds a column `c` that is absent from the physical file.
fn create_test_schema() -> (Schema, Schema) {
    let physical_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ]);
    let logical_schema = Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Float64, true),
    ]);
    (physical_schema, logical_schema)
}
#[test]
fn test_rewrite_column_with_type_cast() {
    // `a` is Int32 in the file but Int64 logically, so the rewrite must wrap
    // the column in a CastColumnExpr.
    let (physical_schema, logical_schema) = create_test_schema();
    let factory = DefaultPhysicalExprAdapterFactory;
    let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
    let column_expr = Arc::new(Column::new("a", 0));
    let result = adapter.rewrite(column_expr).unwrap();
    assert!(result.as_any().downcast_ref::<CastColumnExpr>().is_some());
}
#[test]
fn test_rewrite_multi_column_expr_with_type_cast() {
    // `(a + 5) OR (c > 0.0)`: `a` needs an Int32 -> Int64 cast and `c` is
    // missing from the file, so it should fold to a NULL Float64 literal.
    let (physical_schema, logical_schema) = create_test_schema();
    let factory = DefaultPhysicalExprAdapterFactory;
    let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
    let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
    let column_c = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
    let sum_expr = expressions::BinaryExpr::new(
        Arc::clone(&column_a),
        Operator::Plus,
        Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
    );
    let full_expr = expressions::BinaryExpr::new(
        Arc::new(sum_expr),
        Operator::Or,
        Arc::new(expressions::BinaryExpr::new(
            Arc::clone(&column_c),
            Operator::Gt,
            Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
        )),
    );
    let result = adapter.rewrite(Arc::new(full_expr)).unwrap();
    println!("Rewritten expression: {result}");
    // Expected: `a` wrapped in a cast, `c` replaced by NULL.
    let expected_left = expressions::BinaryExpr::new(
        Arc::new(CastColumnExpr::new(
            Arc::new(Column::new("a", 0)),
            Arc::new(Field::new("a", DataType::Int32, false)),
            Arc::new(Field::new("a", DataType::Int64, false)),
            None,
        )),
        Operator::Plus,
        Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
    );
    let expected = Arc::new(expressions::BinaryExpr::new(
        Arc::new(expected_left),
        Operator::Or,
        Arc::new(expressions::BinaryExpr::new(
            lit(ScalarValue::Float64(None)),
            Operator::Gt,
            Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
        )),
    )) as Arc<dyn PhysicalExpr>;
    assert_eq!(
        result.to_string(),
        expected.to_string(),
        "The rewritten expression did not match the expected output"
    );
}
#[test]
fn test_rewrite_struct_column_incompatible() {
    // Binary -> Int32 is not a castable struct-field conversion, so the
    // rewrite must fail with a descriptive error.
    let physical_schema = Schema::new(vec![Field::new(
        "data",
        DataType::Struct(vec![Field::new("field1", DataType::Binary, true)].into()),
        true,
    )]);
    let logical_schema = Schema::new(vec![Field::new(
        "data",
        DataType::Struct(vec![Field::new("field1", DataType::Int32, true)].into()),
        true,
    )]);
    let factory = DefaultPhysicalExprAdapterFactory;
    let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
    let column_expr = Arc::new(Column::new("data", 0));
    let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
    assert_contains!(
        error_msg,
        "Cannot cast struct field 'field1' from type Binary to type Int32"
    );
}
#[test]
fn test_rewrite_struct_compatible_cast() {
    // Struct fields widen compatibly (Int32 -> Int64, Utf8 -> Utf8View), so
    // the column is wrapped in a CastColumnExpr between the two struct types.
    let physical_schema = Schema::new(vec![Field::new(
        "data",
        DataType::Struct(
            vec![
                Field::new("id", DataType::Int32, false),
                Field::new("name", DataType::Utf8, true),
            ]
            .into(),
        ),
        false,
    )]);
    let logical_schema = Schema::new(vec![Field::new(
        "data",
        DataType::Struct(
            vec![
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8View, true),
            ]
            .into(),
        ),
        false,
    )]);
    let factory = DefaultPhysicalExprAdapterFactory;
    let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
    let column_expr = Arc::new(Column::new("data", 0));
    let result = adapter.rewrite(column_expr).unwrap();
    let expected = Arc::new(CastColumnExpr::new(
        Arc::new(Column::new("data", 0)),
        Arc::new(Field::new(
            "data",
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int32, false),
                    Field::new("name", DataType::Utf8, true),
                ]
                .into(),
            ),
            false,
        )),
        Arc::new(Field::new(
            "data",
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int64, false),
                    Field::new("name", DataType::Utf8View, true),
                ]
                .into(),
            ),
            false,
        )),
        None,
    )) as Arc<dyn PhysicalExpr>;
    assert_eq!(result.to_string(), expected.to_string());
}
#[test]
fn test_rewrite_missing_column() -> Result<()> {
    // `c` exists only in the logical schema and is nullable, so it should be
    // replaced with a NULL Float64 literal.
    let (physical_schema, logical_schema) = create_test_schema();
    let factory = DefaultPhysicalExprAdapterFactory;
    let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
    let column_expr = Arc::new(Column::new("c", 2));
    let result = adapter.rewrite(column_expr)?;
    if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
        assert_eq!(*literal.value(), ScalarValue::Float64(None));
    } else {
        panic!("Expected literal expression");
    }
    Ok(())
}
#[test]
fn test_rewrite_missing_column_non_nullable_error() {
    // A non-nullable logical column missing from the file cannot be
    // NULL-filled, so the rewrite must fail.
    let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let logical_schema = Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Utf8, false),
    ]);
    let adapter = DefaultPhysicalExprAdapterFactory
        .create(Arc::new(logical_schema), Arc::new(physical_schema));
    let error_msg = adapter
        .rewrite(Arc::new(Column::new("b", 1)))
        .unwrap_err()
        .to_string();
    assert_contains!(error_msg, "Non-nullable column 'b' is missing");
}
#[test]
fn test_rewrite_missing_column_nullable() {
    // A nullable logical column missing from the file is replaced by a typed
    // NULL literal.
    let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let logical_schema = Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Utf8, true),
    ]);
    let adapter = DefaultPhysicalExprAdapterFactory
        .create(Arc::new(logical_schema), Arc::new(physical_schema));
    let result = adapter.rewrite(Arc::new(Column::new("b", 1))).unwrap();
    let expected =
        Arc::new(Literal::new(ScalarValue::Utf8(None))) as Arc<dyn PhysicalExpr>;
    assert_eq!(result.to_string(), expected.to_string());
}
#[test]
fn test_replace_columns_with_literals() -> Result<()> {
    // A column named in the replacement map becomes the mapped literal.
    let partition_value = ScalarValue::Utf8(Some("test_value".to_string()));
    let replacements = HashMap::from([("partition_col", &partition_value)]);
    let column_expr =
        Arc::new(Column::new("partition_col", 0)) as Arc<dyn PhysicalExpr>;
    let result = replace_columns_with_literals(column_expr, &replacements)?;
    let literal = result
        .as_any()
        .downcast_ref::<expressions::Literal>()
        .expect("Expected literal expression");
    assert_eq!(*literal.value(), partition_value);
    Ok(())
}
#[test]
fn test_replace_columns_with_literals_no_match() -> Result<()> {
    // A column not named in the replacement map is left untouched.
    let value = ScalarValue::Utf8(Some("test_value".to_string()));
    let replacements = HashMap::from([("other_col", &value)]);
    let column_expr =
        Arc::new(Column::new("partition_col", 0)) as Arc<dyn PhysicalExpr>;
    let result = replace_columns_with_literals(column_expr, &replacements)?;
    assert!(result.as_any().downcast_ref::<Column>().is_some());
    Ok(())
}
#[test]
fn test_replace_columns_with_literals_nested_expr() -> Result<()> {
    // Replacement recurses into nested expressions: both columns of `a + b`
    // are substituted.
    let value_a = ScalarValue::Int64(Some(10));
    let value_b = ScalarValue::Int64(Some(20));
    let replacements = HashMap::from([("a", &value_a), ("b", &value_b)]);
    let expr = Arc::new(expressions::BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Plus,
        Arc::new(Column::new("b", 1)),
    )) as Arc<dyn PhysicalExpr>;
    let result = replace_columns_with_literals(expr, &replacements)?;
    assert_eq!(result.to_string(), "10 + 20");
    Ok(())
}
#[test]
fn test_rewrite_no_change_needed() -> Result<()> {
    // `b` matches in name, index, and type, so the rewrite must return the
    // very same Arc (pointer equality), not a copy.
    let (physical_schema, logical_schema) = create_test_schema();
    let factory = DefaultPhysicalExprAdapterFactory;
    let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
    let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
    let result = adapter.rewrite(Arc::clone(&column_expr))?;
    assert!(std::ptr::eq(
        column_expr.as_ref() as *const dyn PhysicalExpr,
        result.as_ref() as *const dyn PhysicalExpr
    ));
    Ok(())
}
#[test]
fn test_non_nullable_missing_column_error() {
    // Same as the earlier non-nullable test, but with matching types for `a`,
    // and checking the full error message.
    let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let logical_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]);
    let adapter = DefaultPhysicalExprAdapterFactory
        .create(Arc::new(logical_schema), Arc::new(physical_schema));
    let result = adapter.rewrite(Arc::new(Column::new("b", 1)));
    assert!(result.is_err());
    assert_contains!(
        result.unwrap_err().to_string(),
        "Non-nullable column 'b' is missing from the physical schema"
    );
}
/// Evaluate each projection expression against `batch` and assemble the
/// results into a new batch with `schema`.
fn batch_project(
    expr: Vec<Arc<dyn PhysicalExpr>>,
    batch: &RecordBatch,
    schema: SchemaRef,
) -> Result<RecordBatch> {
    let arrays = expr
        .iter()
        .map(|expr| {
            expr.evaluate(batch)
                .and_then(|v| v.into_array(batch.num_rows()))
        })
        .collect::<Result<Vec<_>>>()?;
    if arrays.is_empty() {
        // An empty projection still needs an explicit row count.
        let options =
            RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
        RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
            .map_err(Into::into)
    } else {
        RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
    }
}
#[test]
fn test_adapt_batches() {
    // Physical file has `a: Int32` plus an `extra` column the query ignores.
    let physical_batch = record_batch!(
        ("a", Int32, vec![Some(1), None, Some(3)]),
        ("extra", Utf8, vec![Some("x"), Some("y"), None])
    )
    .unwrap();
    let physical_schema = physical_batch.schema();
    // Logical view: `a` widened to Int64, plus `b` which the file lacks.
    let logical_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Utf8, true),
    ]));
    let projection = vec![
        col("b", &logical_schema).unwrap(),
        col("a", &logical_schema).unwrap(),
    ];
    let adapter = DefaultPhysicalExprAdapterFactory
        .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));
    let adapted_projection = projection
        .into_iter()
        .map(|expr| adapter.rewrite(expr).unwrap())
        .collect_vec();
    let adapted_schema = Arc::new(Schema::new(
        adapted_projection
            .iter()
            .map(|expr| expr.return_field(&physical_schema).unwrap())
            .collect_vec(),
    ));
    let res = batch_project(
        adapted_projection,
        &physical_batch,
        Arc::clone(&adapted_schema),
    )
    .unwrap();
    assert_eq!(res.num_columns(), 2);
    assert_eq!(res.column(0).data_type(), &DataType::Utf8);
    assert_eq!(res.column(1).data_type(), &DataType::Int64);
    // `b` is missing from the file, so it comes back as all NULLs.
    assert_eq!(
        res.column(0)
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .unwrap()
            .iter()
            .collect_vec(),
        vec![None, None, None]
    );
    // `a` keeps its values, widened to Int64.
    assert_eq!(
        res.column(1)
            .as_any()
            .downcast_ref::<arrow::array::Int64Array>()
            .unwrap()
            .iter()
            .collect_vec(),
        vec![Some(1), None, Some(3)]
    );
}
#[test]
fn test_adapt_struct_batches() {
    // Physical file: struct<id: Int32, name: Utf8>.
    let physical_struct_fields: Fields = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, true),
    ]
    .into();
    let struct_array = StructArray::new(
        physical_struct_fields.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])) as _,
            Arc::new(StringArray::from(vec![
                Some("alice"),
                None,
                Some("charlie"),
            ])) as _,
        ],
        None,
    );
    let physical_schema = Arc::new(Schema::new(vec![Field::new(
        "data",
        DataType::Struct(physical_struct_fields),
        false,
    )]));
    let physical_batch = RecordBatch::try_new(
        Arc::clone(&physical_schema),
        vec![Arc::new(struct_array)],
    )
    .unwrap();
    // Logical view widens `id`, switches `name` to Utf8View, and adds `extra`.
    let logical_struct_fields: Fields = vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8View, true),
        Field::new("extra", DataType::Boolean, true),
    ]
    .into();
    let logical_schema = Arc::new(Schema::new(vec![Field::new(
        "data",
        DataType::Struct(logical_struct_fields),
        false,
    )]));
    let projection = vec![col("data", &logical_schema).unwrap()];
    let adapter = DefaultPhysicalExprAdapterFactory
        .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));
    let adapted_projection = projection
        .into_iter()
        .map(|expr| adapter.rewrite(expr).unwrap())
        .collect_vec();
    let adapted_schema = Arc::new(Schema::new(
        adapted_projection
            .iter()
            .map(|expr| expr.return_field(&physical_schema).unwrap())
            .collect_vec(),
    ));
    let res = batch_project(
        adapted_projection,
        &physical_batch,
        Arc::clone(&adapted_schema),
    )
    .unwrap();
    assert_eq!(res.num_columns(), 1);
    let result_struct = res
        .column(0)
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    // `id` is cast Int32 -> Int64 with values preserved.
    let id_col = result_struct.column_by_name("id").unwrap();
    assert_eq!(id_col.data_type(), &DataType::Int64);
    let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(
        id_values.iter().collect_vec(),
        vec![Some(1), Some(2), Some(3)]
    );
    // `name` is cast Utf8 -> Utf8View with values preserved.
    let name_col = result_struct.column_by_name("name").unwrap();
    assert_eq!(name_col.data_type(), &DataType::Utf8View);
    let name_values = name_col.as_any().downcast_ref::<StringViewArray>().unwrap();
    assert_eq!(
        name_values.iter().collect_vec(),
        vec![Some("alice"), None, Some("charlie")]
    );
    // `extra` is absent from the file and is backfilled with NULLs.
    let extra_col = result_struct.column_by_name("extra").unwrap();
    assert_eq!(extra_col.data_type(), &DataType::Boolean);
    let extra_values = extra_col.as_any().downcast_ref::<BooleanArray>().unwrap();
    assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]);
}
#[test]
fn test_try_rewrite_struct_field_access() {
    // A bare column reference (not a get_field call) is not handled by the
    // struct-field-access rewrite, even when the struct schemas differ.
    let physical_schema = Schema::new(vec![Field::new(
        "struct_col",
        DataType::Struct(
            vec![Field::new("existing_field", DataType::Int32, true)].into(),
        ),
        true,
    )]);
    let logical_schema = Schema::new(vec![Field::new(
        "struct_col",
        DataType::Struct(
            vec![
                Field::new("existing_field", DataType::Int32, true),
                Field::new("missing_field", DataType::Utf8, true),
            ]
            .into(),
        ),
        true,
    )]);
    let rewriter = DefaultPhysicalExprAdapterRewriter {
        logical_file_schema: &logical_schema,
        physical_file_schema: &physical_schema,
    };
    let column = Arc::new(Column::new("struct_col", 0)) as Arc<dyn PhysicalExpr>;
    let result = rewriter.try_rewrite_struct_field_access(&column).unwrap();
    assert!(result.is_none());
}
#[test]
fn test_batch_adapter_factory_basic() {
    // Source batch has the columns reordered and `a` narrower than the
    // target; the adapter must reorder and cast.
    let target_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Utf8, true),
    ]));
    let source_schema = Arc::new(Schema::new(vec![
        Field::new("b", DataType::Utf8, true),
        Field::new("a", DataType::Int32, false),
    ]));
    let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
    let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap();
    let source_batch = RecordBatch::try_new(
        Arc::clone(&source_schema),
        vec![
            Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
            Arc::new(Int32Array::from(vec![1, 2, 3])),
        ],
    )
    .unwrap();
    let adapted = adapter.adapt_batch(&source_batch).unwrap();
    // Output follows the target schema's order and types.
    assert_eq!(adapted.num_columns(), 2);
    assert_eq!(adapted.schema().field(0).name(), "a");
    assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int64);
    assert_eq!(adapted.schema().field(1).name(), "b");
    assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);
    let col_a = adapted
        .column(0)
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(col_a.iter().collect_vec(), vec![Some(1), Some(2), Some(3)]);
    let col_b = adapted
        .column(1)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(
        col_b.iter().collect_vec(),
        vec![Some("hello"), None, Some("world")]
    );
}
#[test]
fn test_batch_adapter_factory_missing_column() {
    // Target column `c` is absent from the source; the adapted batch should
    // contain it as an all-NULL Float64 column.
    let target_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Float64, true),
    ]));
    let source_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ]));
    let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
    let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap();
    let source_batch = RecordBatch::try_new(
        Arc::clone(&source_schema),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(StringArray::from(vec!["x", "y"])),
        ],
    )
    .unwrap();
    let adapted = adapter.adapt_batch(&source_batch).unwrap();
    assert_eq!(adapted.num_columns(), 3);
    let col_c = adapted.column(2);
    assert_eq!(col_c.data_type(), &DataType::Float64);
    assert_eq!(col_c.null_count(), 2);
}
#[test]
fn test_batch_adapter_factory_with_struct() {
    // Struct sub-field `id` widens Int32 -> Int64 through the batch adapter.
    let target_struct_fields: Fields = vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, true),
    ]
    .into();
    let target_schema = Arc::new(Schema::new(vec![Field::new(
        "data",
        DataType::Struct(target_struct_fields),
        false,
    )]));
    let source_struct_fields: Fields = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, true),
    ]
    .into();
    let source_schema = Arc::new(Schema::new(vec![Field::new(
        "data",
        DataType::Struct(source_struct_fields.clone()),
        false,
    )]));
    let struct_array = StructArray::new(
        source_struct_fields,
        vec![
            Arc::new(Int32Array::from(vec![10, 20])) as _,
            Arc::new(StringArray::from(vec!["a", "b"])) as _,
        ],
        None,
    );
    let source_batch = RecordBatch::try_new(
        Arc::clone(&source_schema),
        vec![Arc::new(struct_array)],
    )
    .unwrap();
    let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
    let adapter = factory.make_adapter(source_schema).unwrap();
    let adapted = adapter.adapt_batch(&source_batch).unwrap();
    let result_struct = adapted
        .column(0)
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let id_col = result_struct.column_by_name("id").unwrap();
    assert_eq!(id_col.data_type(), &DataType::Int64);
    let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
    assert_eq!(id_values.iter().collect_vec(), vec![Some(10), Some(20)]);
}
#[test]
fn test_batch_adapter_factory_identity() {
    // Source and target schemas are identical; the adapter passes batches
    // through with schema unchanged.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ]));
    let factory = BatchAdapterFactory::new(Arc::clone(&schema));
    let adapter = factory.make_adapter(Arc::clone(&schema)).unwrap();
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ],
    )
    .unwrap();
    let adapted = adapter.adapt_batch(&batch).unwrap();
    assert_eq!(adapted.num_columns(), 2);
    assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int32);
    assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);
}
#[test]
fn test_batch_adapter_factory_reuse() {
    // One factory can produce independent adapters for several source
    // schemas. The original assertions only checked that the Debug output
    // contained "BatchAdapter", which proved nothing about behavior; this
    // version actually adapts a batch with each adapter.
    let target_schema = Arc::new(Schema::new(vec![
        Field::new("x", DataType::Int64, false),
        Field::new("y", DataType::Utf8, true),
    ]));
    let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
    // Source 1: same column order, narrower integer type (needs a cast).
    let source1 = Arc::new(Schema::new(vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Utf8, true),
    ]));
    let adapter1 = factory.make_adapter(Arc::clone(&source1)).unwrap();
    // Source 2: columns reordered, types already matching.
    let source2 = Arc::new(Schema::new(vec![
        Field::new("y", DataType::Utf8, true),
        Field::new("x", DataType::Int64, false),
    ]));
    let adapter2 = factory.make_adapter(Arc::clone(&source2)).unwrap();
    let batch1 = RecordBatch::try_new(
        source1,
        vec![
            Arc::new(Int32Array::from(vec![1])),
            Arc::new(StringArray::from(vec!["a"])),
        ],
    )
    .unwrap();
    let adapted1 = adapter1.adapt_batch(&batch1).unwrap();
    assert_eq!(adapted1.num_columns(), 2);
    assert_eq!(adapted1.column(0).data_type(), &DataType::Int64);
    assert_eq!(adapted1.column(1).data_type(), &DataType::Utf8);
    let batch2 = RecordBatch::try_new(
        source2,
        vec![
            Arc::new(StringArray::from(vec!["b"])),
            Arc::new(Int64Array::from(vec![2])),
        ],
    )
    .unwrap();
    let adapted2 = adapter2.adapt_batch(&batch2).unwrap();
    // Columns come out in target order regardless of the source layout.
    assert_eq!(adapted2.num_columns(), 2);
    assert_eq!(adapted2.column(0).data_type(), &DataType::Int64);
    assert_eq!(adapted2.column(1).data_type(), &DataType::Utf8);
}
}