use std::sync::Arc;
use arrow::datatypes::{Schema, SchemaRef};
use datafusion_common::{
Result, ScalarValue,
tree_node::{Transformed, TransformedResult, TreeNode},
};
use datafusion_physical_expr::{
expressions::{Column, Literal},
projection::{ProjectionExpr, ProjectionExprs},
};
use futures::{FutureExt, StreamExt};
use itertools::Itertools;
use crate::{
PartitionedFile, TableSchema,
file_stream::{FileOpenFuture, FileOpener},
};
/// A [`FileOpener`] decorator that applies a projection to every record batch
/// produced by an inner opener, rewriting any projected partition columns
/// into per-file literal values before projecting.
pub struct ProjectionOpener {
    /// The wrapped opener that actually opens the underlying file.
    inner: Arc<dyn FileOpener>,
    /// Remapped projection over the narrowed input (file columns first);
    /// partition columns remain column references here and are replaced with
    /// literals per file in `open`.
    projection: ProjectionExprs,
    /// Schema of the file columns actually read (the file schema projected to
    /// `SplitProjection::file_indices`); used to build the batch projector.
    input_schema: SchemaRef,
    /// For each projected partition column: its position in `projection` and
    /// its index into a file's partition values. Empty when the projection
    /// touches no partition columns.
    partition_columns: Vec<PartitionColumnIndex>,
}
impl ProjectionOpener {
pub fn try_new(
projection: SplitProjection,
inner: Arc<dyn FileOpener>,
file_schema: &Schema,
) -> Result<Arc<dyn FileOpener>> {
Ok(Arc::new(ProjectionOpener {
inner,
projection: projection.remapped_projection,
input_schema: Arc::new(file_schema.project(&projection.file_indices)?),
partition_columns: projection.partition_columns,
}))
}
}
impl FileOpener for ProjectionOpener {
    /// Opens `partitioned_file` via the wrapped opener and lazily projects
    /// every batch of the resulting stream.
    fn open(&self, partitioned_file: PartitionedFile) -> Result<FileOpenFuture> {
        // Capture this file's partition values before the file struct is
        // handed off to the inner opener below.
        let partition_values = partitioned_file.partition_values.clone();

        // Without projected partition columns the stored projection is used
        // as-is; otherwise partition-column references are rewritten into
        // literals holding this file's partition values.
        let projection = if self.partition_columns.is_empty() {
            self.projection.clone()
        } else {
            inject_partition_columns_into_projection(
                &self.projection,
                &self.partition_columns,
                partition_values,
            )
        };

        let projector = projection.make_projector(&self.input_schema)?;
        let open_future = self.inner.open(partitioned_file)?;

        Ok(async move {
            let file_stream = open_future.await?;
            // Project each batch as it arrives; errors from the inner stream
            // propagate unchanged via `?`.
            let projected_stream = file_stream.map(move |batch| {
                let projected = projector.project_batch(&batch?)?;
                Ok(projected)
            });
            Ok(projected_stream.boxed())
        }
        .boxed())
    }
}
/// Locates one partition column in the two index spaces it lives in: the
/// remapped projection and a file's partition-value list.
#[derive(Debug, Clone, Copy)]
pub struct PartitionColumnIndex {
    /// Column index this partition column occupies in the remapped
    /// projection (file columns come first, partition columns follow).
    pub in_remainder_projection: usize,
    /// Index into `PartitionedFile::partition_values` holding this column's
    /// value for a given file.
    pub in_partition_values: usize,
}
/// Rewrites `projection` so that every column reference listed in
/// `partition_columns` becomes a [`Literal`] holding the corresponding value
/// from `partition_values`.
///
/// `projection` is expected to be a remapped projection (see
/// [`SplitProjection`]) in which partition columns occupy the trailing
/// indices; references to file columns pass through unchanged.
fn inject_partition_columns_into_projection(
    projection: &ProjectionExprs,
    partition_columns: &[PartitionColumnIndex],
    partition_values: Vec<ScalarValue>,
) -> ProjectionExprs {
    // Wrap each partition value in a shared literal once up front so a value
    // referenced by several expressions is not re-allocated per use.
    let partition_literals: Vec<Arc<Literal>> = partition_values
        .into_iter()
        .map(|value| Arc::new(Literal::new(value)))
        .collect();
    let projections = projection
        .iter()
        .map(|projection| {
            let expr = Arc::clone(&projection.expr)
                .transform(|expr| {
                    // Resolve the (Copy) mapping entry first so the borrow of
                    // `expr` from `downcast_ref` ends before `expr` is moved
                    // below. This avoids the previous pattern of cloning every
                    // visited node up front just to keep the borrow checker
                    // happy.
                    let pci = expr
                        .as_any()
                        .downcast_ref::<Column>()
                        .and_then(|column| {
                            partition_columns
                                .iter()
                                .find(|pci| pci.in_remainder_projection == column.index())
                        })
                        .copied();
                    if let Some(pci) = pci {
                        let literal =
                            Arc::clone(&partition_literals[pci.in_partition_values]);
                        return Ok(Transformed::yes(literal));
                    }
                    Ok(Transformed::no(expr))
                })
                .data()
                // The closure never returns Err, so the transform cannot fail.
                .expect("infallible transform");
            ProjectionExpr::new(expr, projection.alias.clone())
        })
        .collect_vec();
    ProjectionExprs::new(projections)
}
/// The result of splitting a table-level projection into the part that reads
/// file columns and the part that references partition columns.
#[derive(Debug, Clone)]
pub struct SplitProjection {
    /// The original, unmodified projection over the table schema.
    pub source: ProjectionExprs,
    /// Indices (into the logical file schema) of the file columns referenced
    /// by the projection, in ascending order.
    pub file_indices: Vec<usize>,
    /// For each referenced partition column: where it sits in
    /// `remapped_projection` and where its value lives in a file's partition
    /// values.
    pub(crate) partition_columns: Vec<PartitionColumnIndex>,
    /// The projection rewritten against the narrowed input: referenced file
    /// columns are renumbered `0..N` and partition columns follow them.
    pub(crate) remapped_projection: ProjectionExprs,
}
impl SplitProjection {
    /// Builds a `SplitProjection` for the identity projection over the table
    /// schema, i.e. every table column is selected in order.
    pub fn unprojected(table_schema: &TableSchema) -> Self {
        let projection = ProjectionExprs::from_indices(
            &(0..table_schema.table_schema().fields().len()).collect_vec(),
            table_schema.table_schema(),
        );
        Self::new(table_schema.file_schema(), &projection)
    }

    /// Splits `projection` — expressed against the logical table schema, in
    /// which partition columns follow the file columns — into its file-column
    /// and partition-column parts.
    ///
    /// Column indices `< logical_file_schema.fields().len()` are treated as
    /// file columns; all higher indices as partition columns. The returned
    /// `remapped_projection` re-targets every column reference at a narrowed
    /// input where the referenced file columns occupy indices `0..N` (in
    /// ascending table order) followed by the referenced partition columns.
    pub fn new(logical_file_schema: &Schema, projection: &ProjectionExprs) -> Self {
        // Any referenced column index >= this count is a partition column.
        let num_file_schema_columns = logical_file_schema.fields().len();
        let mut file_columns = Vec::new();
        let mut partition_columns = Vec::new();
        // Every distinct column index referenced anywhere inside the
        // projection expressions, mapped to its name (the map dedupes
        // columns referenced more than once).
        let mut all_columns = std::collections::HashMap::new();
        for proj_expr in projection {
            proj_expr
                .expr
                .apply(|expr| {
                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
                        all_columns
                            .entry(column.index())
                            .or_insert_with(|| column.name().to_string());
                    }
                    Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
                })
                // The visitor closure never returns Err.
                .expect("infallible apply");
        }
        // Sort by original table index so all file columns (lower indices)
        // are renumbered before any partition column, and relative order is
        // preserved within each group.
        let mut sorted_columns: Vec<_> = all_columns
            .into_iter()
            .map(|(idx, name)| (name, idx))
            .collect();
        sorted_columns.sort_by_key(|(_, idx)| *idx);
        // Maps original table index -> replacement Column expression over the
        // narrowed input.
        let mut column_mapping = std::collections::HashMap::new();
        let mut file_idx = 0;
        let mut partition_idx = 0;
        for (name, original_index) in sorted_columns {
            let new_index = if original_index < num_file_schema_columns {
                // File column: takes the next slot in 0..N.
                file_columns.push(original_index);
                let idx = file_idx;
                file_idx += 1;
                idx
            } else {
                // Partition column: placed after all file columns. Because
                // iteration is in ascending index order, `file_idx` is final
                // by the time the first partition column is seen.
                partition_columns.push(original_index);
                let idx = file_idx + partition_idx;
                partition_idx += 1;
                idx
            };
            let new_column: Arc<dyn datafusion_physical_plan::PhysicalExpr> =
                Arc::new(Column::new(&name, new_index));
            column_mapping.insert(original_index, new_column);
        }
        // Rewrite each projection expression, swapping every referenced
        // column for its renumbered counterpart.
        let remapped_projection = projection
            .iter()
            .map(|proj_expr| {
                let expr = Arc::clone(&proj_expr.expr)
                    .transform(|expr| {
                        let original_expr = Arc::clone(&expr);
                        if let Some(column) = expr.as_any().downcast_ref::<Column>()
                            && let Some(new_column) = column_mapping.get(&column.index())
                        {
                            return Ok(Transformed::yes(Arc::clone(new_column)));
                        }
                        Ok(Transformed::no(original_expr))
                    })
                    .data()
                    // The closure never returns Err.
                    .expect("infallible transform");
                ProjectionExpr::new(expr, proj_expr.alias.clone())
            })
            .collect_vec();
        let num_file_columns = file_columns.len();
        // Record where each partition column landed in the remapped
        // projection and where its value lives in a file's partition values.
        let partition_column_mappings = partition_columns
            .iter()
            .enumerate()
            .map(|(partition_idx, &table_index)| PartitionColumnIndex {
                in_remainder_projection: num_file_columns + partition_idx,
                in_partition_values: table_index - num_file_schema_columns,
            })
            .collect_vec();
        Self {
            source: projection.clone(),
            file_indices: file_columns,
            partition_columns: partition_column_mappings,
            remapped_projection: ProjectionExprs::from(remapped_projection),
        }
    }
}
#[cfg(test)]
mod test {
    //! Tests for splitting projections into file/partition parts and for
    //! injecting per-file partition values as literals.
    use std::sync::Arc;
    use arrow::array::AsArray;
    use arrow::datatypes::{DataType, SchemaRef};
    use datafusion_common::{DFSchema, ScalarValue, record_batch};
    use datafusion_expr::{Expr, col, execution_props::ExecutionProps};
    use datafusion_physical_expr::{create_physical_exprs, projection::ProjectionExpr};
    use itertools::Itertools;
    use super::*;

    /// Converts logical expressions into a `ProjectionExprs`, aliasing the
    /// i-th expression as `col{i}`.
    fn create_projection_exprs<'a>(
        exprs: impl IntoIterator<Item = &'a Expr>,
        schema: &SchemaRef,
    ) -> ProjectionExprs {
        let df_schema = DFSchema::try_from(Arc::clone(schema)).unwrap();
        let physical_exprs =
            create_physical_exprs(exprs, &df_schema, &ExecutionProps::default()).unwrap();
        let projection_exprs = physical_exprs
            .into_iter()
            .enumerate()
            .map(|(i, e)| ProjectionExpr::new(Arc::clone(&e), format!("col{i}")))
            .collect_vec();
        ProjectionExprs::from(projection_exprs)
    }

    /// End-to-end: split a projection that interleaves a partition column
    /// (`date`, table index 3) with file columns, inject a partition value,
    /// and check the projected batch column-by-column.
    #[test]
    fn test_split_projection_with_partition_columns() {
        use arrow::array::AsArray;
        use arrow::datatypes::Field;
        let file_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("bool_col", DataType::Boolean, false),
            Field::new("tinyint_col", DataType::Int8, false),
        ]));
        let table_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("bool_col", DataType::Boolean, false),
            Field::new("tinyint_col", DataType::Int8, false),
            Field::new("date", DataType::Utf8, false),
        ]));
        // Selects [id, bool_col, date, tinyint_col]; index 3 is the
        // partition column.
        let projection_indices = vec![0, 1, 3, 2];
        let projection =
            ProjectionExprs::from_indices(&projection_indices, &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        assert_eq!(split.file_indices, vec![0, 1, 2]);
        assert_eq!(split.partition_columns.len(), 1);
        assert_eq!(split.partition_columns[0].in_partition_values, 0);
        let file_batch = record_batch!(
            ("id", Int32, vec![4]),
            ("bool_col", Boolean, vec![true]),
            ("tinyint_col", Int8, vec![0])
        )
        .unwrap();
        let partition_values = vec![ScalarValue::from("2021-10-26")];
        // In the remapped projection the partition column sits after the
        // three file columns, i.e. at index 3.
        let partition_columns = vec![PartitionColumnIndex {
            in_remainder_projection: 3,
            in_partition_values: 0,
        }];
        let injected_projection = inject_partition_columns_into_projection(
            &split.remapped_projection,
            &partition_columns,
            partition_values,
        );
        let projector = injected_projection
            .make_projector(&file_batch.schema())
            .unwrap();
        let result = projector.project_batch(&file_batch).unwrap();
        assert_eq!(result.num_columns(), 4);
        assert_eq!(
            result
                .column(0)
                .as_primitive::<arrow::datatypes::Int32Type>()
                .value(0),
            4
        );
        assert!(result.column(1).as_boolean().value(0));
        // The injected literal shows up in output position 2 (the `date`
        // slot of the original projection order).
        assert_eq!(result.column(2).as_string::<i32>().value(0), "2021-10-26");
        assert_eq!(
            result
                .column(3)
                .as_primitive::<arrow::datatypes::Int8Type>()
                .value(0),
            0
        );
    }

    /// Builds a (file schema, table schema) pair with `file_cols` Int32
    /// columns `col_{i}` followed (in the table schema only) by
    /// `partition_cols` Utf8 columns `part_{i}`.
    fn create_test_schemas(
        file_cols: usize,
        partition_cols: usize,
    ) -> (SchemaRef, SchemaRef) {
        use arrow::datatypes::Field;
        let file_fields: Vec<_> = (0..file_cols)
            .map(|i| Field::new(format!("col_{i}"), DataType::Int32, false))
            .collect();
        let mut table_fields = file_fields.clone();
        table_fields.extend(
            (0..partition_cols)
                .map(|i| Field::new(format!("part_{i}"), DataType::Utf8, false)),
        );
        (
            Arc::new(Schema::new(file_fields)),
            Arc::new(Schema::new(table_fields)),
        )
    }

    /// A projection over only file columns yields no partition mappings.
    #[test]
    fn test_split_projection_only_file_columns() {
        let (file_schema, table_schema) = create_test_schemas(3, 2);
        let projection = ProjectionExprs::from_indices(&[0, 1, 2], &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        assert_eq!(split.file_indices, vec![0, 1, 2]);
        assert_eq!(split.partition_columns.len(), 0);
    }

    /// A projection over only partition columns yields no file indices.
    #[test]
    fn test_split_projection_only_partition_columns() {
        let (file_schema, table_schema) = create_test_schemas(3, 2);
        let projection = ProjectionExprs::from_indices(&[3, 4], &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        assert_eq!(split.file_indices, Vec::<usize>::new());
        assert_eq!(split.partition_columns.len(), 2);
        assert_eq!(split.partition_columns[0].in_partition_values, 0);
        assert_eq!(split.partition_columns[1].in_partition_values, 1);
    }

    /// Multiple partition columns keep their ascending-table-index order in
    /// the partition mappings.
    #[test]
    fn test_split_projection_multiple_partition_columns() {
        let (file_schema, table_schema) = create_test_schemas(2, 3);
        let projection = ProjectionExprs::from_indices(&[0, 2, 4, 1, 3], &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        assert_eq!(split.file_indices, vec![0, 1]);
        assert_eq!(split.partition_columns.len(), 3);
        assert_eq!(split.partition_columns[0].in_partition_values, 0);
        assert_eq!(split.partition_columns[1].in_partition_values, 1);
        assert_eq!(split.partition_columns[2].in_partition_values, 2);
        assert_eq!(split.remapped_projection.iter().count(), 5);
    }

    /// Partition columns listed in reverse projection order are still mapped
    /// by ascending table index.
    #[test]
    fn test_split_projection_partition_columns_reverse_order() {
        let (file_schema, table_schema) = create_test_schemas(2, 2);
        let projection = ProjectionExprs::from_indices(&[3, 2], &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        assert_eq!(split.file_indices, Vec::<usize>::new());
        assert_eq!(split.partition_columns.len(), 2);
        assert_eq!(split.partition_columns[0].in_partition_values, 0);
        assert_eq!(split.partition_columns[1].in_partition_values, 1);
    }

    /// Interleaving file and partition columns in the projection splits
    /// cleanly into the two groups.
    #[test]
    fn test_split_projection_interleaved_file_and_partition() {
        let (file_schema, table_schema) = create_test_schemas(3, 3);
        let projection =
            ProjectionExprs::from_indices(&[0, 3, 1, 4, 2, 5], &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        assert_eq!(split.file_indices, vec![0, 1, 2]);
        assert_eq!(split.partition_columns.len(), 3);
        assert_eq!(split.partition_columns[0].in_partition_values, 0);
        assert_eq!(split.partition_columns[1].in_partition_values, 1);
        assert_eq!(split.partition_columns[2].in_partition_values, 2);
    }

    /// An expression referencing both a file column and a partition column
    /// contributes to both sides of the split.
    #[test]
    fn test_split_projection_expression_with_file_and_partition_columns() {
        use arrow::datatypes::Field;
        let file_schema = Arc::new(Schema::new(vec![
            Field::new("file_a", DataType::Int32, false),
            Field::new("file_b", DataType::Int32, false),
        ]));
        let table_schema = Arc::new(Schema::new(vec![
            Field::new("file_a", DataType::Int32, false),
            Field::new("file_b", DataType::Int32, false),
            Field::new("part_c", DataType::Int32, false),
        ]));
        let exprs = [col("file_a") + col("part_c")];
        let projection = create_projection_exprs(exprs.iter(), &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        assert_eq!(split.file_indices, vec![0]);
        assert_eq!(split.partition_columns.len(), 1);
        assert_eq!(split.partition_columns[0].in_partition_values, 0);
    }

    /// Boundary: the last file-schema index is still a file column.
    #[test]
    fn test_split_projection_boundary_last_file_column() {
        let (file_schema, table_schema) = create_test_schemas(3, 2);
        let projection = ProjectionExprs::from_indices(&[2], &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        assert_eq!(split.file_indices, vec![2]);
        assert_eq!(split.partition_columns.len(), 0);
    }

    /// Boundary: the first index past the file schema is a partition column.
    #[test]
    fn test_split_projection_boundary_first_partition_column() {
        let (file_schema, table_schema) = create_test_schemas(3, 2);
        let projection = ProjectionExprs::from_indices(&[3], &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        assert_eq!(split.file_indices, Vec::<usize>::new());
        assert_eq!(split.partition_columns.len(), 1);
        assert_eq!(split.partition_columns[0].in_partition_values, 0);
    }

    /// Injecting two partition values places each literal in its projected
    /// output slot, interleaved with file columns.
    #[test]
    fn test_inject_partition_columns_multiple_partitions() {
        let data =
            record_batch!(("col_0", Int32, vec![1]), ("col_1", Int32, vec![2])).unwrap();
        let (file_schema, table_schema) = create_test_schemas(2, 2);
        // Projection order: col_0, part_0, col_1, part_1.
        let projection = ProjectionExprs::from_indices(&[0, 2, 1, 3], &table_schema);
        let split = SplitProjection::new(&file_schema, &projection);
        let partition_columns = vec![
            PartitionColumnIndex {
                in_remainder_projection: 2,
                in_partition_values: 0,
            },
            PartitionColumnIndex {
                in_remainder_projection: 3,
                in_partition_values: 1,
            },
        ];
        let partition_values =
            vec![ScalarValue::from("part_a"), ScalarValue::from("part_b")];
        let injected = inject_partition_columns_into_projection(
            &split.remapped_projection,
            &partition_columns,
            partition_values,
        );
        let projector = injected.make_projector(&data.schema()).unwrap();
        let result = projector.project_batch(&data).unwrap();
        assert_eq!(result.num_columns(), 4);
        assert_eq!(
            result
                .column(0)
                .as_primitive::<arrow::datatypes::Int32Type>()
                .value(0),
            1
        );
        assert_eq!(result.column(1).as_string::<i32>().value(0), "part_a");
        assert_eq!(
            result
                .column(2)
                .as_primitive::<arrow::datatypes::Int32Type>()
                .value(0),
            2
        );
        assert_eq!(result.column(3).as_string::<i32>().value(0), "part_b");
    }
}