use std::sync::Arc;
use datafusion::arrow::array::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
use datafusion::common::{DataFusionError, Result as DFResult};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
use iceberg::arrow::{
PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator, schema_to_arrow_schema,
strip_metadata_from_schema,
};
use iceberg::spec::PartitionSpec;
use iceberg::table::Table;
use crate::to_datafusion_error;
/// Wraps `input` in a projection that appends the Iceberg partition-value column.
///
/// The projection forwards every input column unchanged (same name, same position) and
/// adds one trailing column named [`PROJECTED_PARTITION_VALUE_COLUMN`] whose values are
/// produced by a [`PartitionValueCalculator`] built from the table's default partition
/// spec and current schema.
///
/// When the default partition spec is unpartitioned, `input` is returned as-is.
///
/// # Errors
///
/// * `DataFusionError::Plan` if the input schema and the Arrow form of the table schema
///   differ after both have been passed through `strip_metadata_from_schema`.
/// * Any Iceberg error raised while converting schemas or building the calculator,
///   converted with `to_datafusion_error`.
/// * Any error returned by `ProjectionExec::try_new`.
pub fn project_with_partition(
    input: Arc<dyn ExecutionPlan>,
    table: &Table,
) -> DFResult<Arc<dyn ExecutionPlan>> {
    let metadata = table.metadata();
    let spec = metadata.default_partition_spec();
    let iceberg_schema = metadata.current_schema();

    // Nothing to compute for an unpartitioned table: hand the plan back untouched.
    if spec.is_unpartitioned() {
        return Ok(input);
    }

    let plan_schema = input.schema();

    // Both sides go through `strip_metadata_from_schema` before being compared, so the
    // check is made on the stripped schemas rather than on the raw ones.
    let table_arrow_schema =
        schema_to_arrow_schema(iceberg_schema.as_ref()).map_err(to_datafusion_error)?;
    let actual = strip_metadata_from_schema(&plan_schema).map_err(to_datafusion_error)?;
    let expected = strip_metadata_from_schema(&table_arrow_schema).map_err(to_datafusion_error)?;
    if actual != expected {
        return Err(DataFusionError::Plan(format!(
            "Input schema does not match Iceberg table schema.\n\
             Expected schema: {expected}\n\
             Input schema: {actual}"
        )));
    }

    let calculator = PartitionValueCalculator::try_new(spec.as_ref(), iceberg_schema.as_ref())
        .map_err(to_datafusion_error)?;

    // One pass-through expression per input column, plus one slot for the partition column.
    let fields = plan_schema.fields();
    let mut columns: Vec<(Arc<dyn PhysicalExpr>, String)> = Vec::with_capacity(fields.len() + 1);
    columns.extend(fields.iter().enumerate().map(|(position, field)| {
        let passthrough: Arc<dyn PhysicalExpr> = Arc::new(Column::new(field.name(), position));
        (passthrough, field.name().clone())
    }));

    let partition_values: Arc<dyn PhysicalExpr> =
        Arc::new(PartitionExpr::new(calculator, spec.clone()));
    columns.push((partition_values, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));

    let projected = ProjectionExec::try_new(columns, input)?;
    Ok(Arc::new(projected))
}
/// Physical expression that produces the Iceberg partition-value column for a batch.
///
/// The actual computation is delegated to a [`PartitionValueCalculator`]; the partition
/// spec is held alongside it and is read when the expression is formatted.
#[derive(Debug, Clone)]
struct PartitionExpr {
    /// Computes the partition-value array and reports its Arrow data type.
    calculator: Arc<PartitionValueCalculator>,
    /// Partition spec whose fields are listed when the expression is formatted.
    partition_spec: Arc<PartitionSpec>,
}

impl PartitionExpr {
    /// Builds the expression, moving `calculator` behind an `Arc` and storing
    /// `partition_spec` as given.
    fn new(calculator: PartitionValueCalculator, partition_spec: Arc<PartitionSpec>) -> Self {
        let calculator = Arc::new(calculator);
        Self {
            calculator,
            partition_spec,
        }
    }
}
/// Identity-based equality: two expressions are equal only when they point at the very
/// same calculator allocation and the very same partition-spec allocation.
impl PartialEq for PartitionExpr {
    fn eq(&self, other: &Self) -> bool {
        let same_calculator = Arc::ptr_eq(&self.calculator, &other.calculator);
        let same_spec = Arc::ptr_eq(&self.partition_spec, &other.partition_spec);
        same_calculator && same_spec
    }
}

impl Eq for PartitionExpr {}
impl PhysicalExpr for PartitionExpr {
    /// Exposes the expression as `Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Returns the calculator's partition Arrow type; the input schema is not consulted.
    fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
        let partition_type = self.calculator.partition_arrow_type();
        Ok(partition_type.clone())
    }

    /// Always reports the column as non-nullable; the input schema is not consulted.
    fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
        Ok(false)
    }

    /// Runs the calculator over `batch` and returns its result as an array value.
    ///
    /// Errors from the calculator are converted with `to_datafusion_error`.
    fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
        self.calculator
            .calculate(batch)
            .map(ColumnarValue::Array)
            .map_err(to_datafusion_error)
    }

    /// The expression has no child expressions.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        Vec::new()
    }

    /// Ignores the supplied children and returns this same expression.
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> DFResult<Arc<dyn PhysicalExpr>> {
        let unchanged: Arc<dyn PhysicalExpr> = self;
        Ok(unchanged)
    }

    /// Writes `iceberg_partition_values[...]`, listing each partition field as
    /// `transform(name)` separated by `", "`.
    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("iceberg_partition_values[")?;
        for (position, partition_field) in self.partition_spec.fields().iter().enumerate() {
            if position != 0 {
                f.write_str(", ")?;
            }
            write!(f, "{}({})", partition_field.transform, partition_field.name)?;
        }
        f.write_str("]")
    }
}
/// Renders as `iceberg_partition_values(...)`, listing the partition field names
/// separated by `", "`.
impl std::fmt::Display for PartitionExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("iceberg_partition_values(")?;
        for (position, partition_field) in self.partition_spec.fields().iter().enumerate() {
            if position != 0 {
                f.write_str(", ")?;
            }
            f.write_str(partition_field.name.as_str())?;
        }
        f.write_str(")")
    }
}
impl std::hash::Hash for PartitionExpr {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
Arc::as_ptr(&self.calculator).hash(state);
Arc::as_ptr(&self.partition_spec).hash(state);
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use datafusion::arrow::array::{ArrayRef, Int32Array, StringArray, StructArray};
    use datafusion::arrow::datatypes::{DataType, Field, Fields};
    use datafusion::physical_plan::empty::EmptyExec;
    use iceberg::TableIdent;
    use iceberg::io::FileIO;
    use iceberg::spec::{
        FormatVersion, NestedField, PrimitiveType, Schema, SortOrder, StructType,
        TableMetadataBuilder, Transform, Type,
    };

    use super::*;

    /// Iceberg schema (schema id 0) with a required int `id` column followed by a
    /// required string column named `text_column`.
    fn schema_with_id_zero(text_column: &str) -> Schema {
        Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, text_column, Type::Primitive(PrimitiveType::String))
                    .into(),
            ])
            .build()
            .unwrap()
    }

    /// Partition spec with a single identity-transformed field `source` -> `partition_name`.
    fn identity_spec(schema: &Schema, source: &str, partition_name: &str) -> PartitionSpec {
        PartitionSpec::builder(Arc::new(schema.clone()))
            .add_partition_field(source, partition_name, Transform::Identity)
            .unwrap()
            .build()
            .unwrap()
    }

    /// Arrow schema with a non-null Int32 `id` column and a non-null Utf8 column named
    /// `text_column`.
    fn id_and_text_arrow_schema(text_column: &str) -> Arc<ArrowSchema> {
        Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new(text_column, DataType::Utf8, false),
        ]))
    }

    /// Format-V2 table with columns `id` / `name`, identity-partitioned on `id`.
    fn table_partitioned_by_id() -> Table {
        let table_schema = Arc::new(
            Schema::builder()
                .with_fields(vec![
                    NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                    NestedField::required(2, "name", Type::Primitive(PrimitiveType::String))
                        .into(),
                ])
                .build()
                .unwrap(),
        );
        let partition_spec = PartitionSpec::builder(table_schema.clone())
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();
        let sort_order = SortOrder::builder().build(&table_schema).unwrap();
        let build_result = TableMetadataBuilder::new(
            (*table_schema).clone(),
            partition_spec,
            sort_order,
            "/test/table".to_string(),
            FormatVersion::V2,
            HashMap::new(),
        )
        .unwrap()
        .build()
        .unwrap();

        Table::builder()
            .metadata(build_result.metadata)
            .identifier(TableIdent::from_strs(["test", "table"]).unwrap())
            .file_io(FileIO::new_with_fs())
            .metadata_location("/test/metadata.json".to_string())
            .build()
            .unwrap()
    }

    /// Runs `project_with_partition` over an empty plan exposing `fields`, against the
    /// table from `table_partitioned_by_id`.
    fn project_over(fields: Vec<Field>) -> DFResult<Arc<dyn ExecutionPlan>> {
        let input = Arc::new(EmptyExec::new(Arc::new(ArrowSchema::new(fields))));
        project_with_partition(input, &table_partitioned_by_id())
    }

    #[test]
    fn test_partition_calculator_basic() {
        let table_schema = schema_with_id_zero("name");
        let partition_spec = identity_spec(&table_schema, "id", "id_partition");

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        assert_eq!(calculator.partition_type().fields().len(), 1);
        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
    }

    #[test]
    fn test_partition_expr_with_projection() {
        let table_schema = schema_with_id_zero("name");
        let partition_spec = Arc::new(identity_spec(&table_schema, "id", "id_partition"));
        let arrow_schema = id_and_text_arrow_schema("name");
        let input = Arc::new(EmptyExec::new(arrow_schema.clone()));
        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();

        // Same shape as the production projection: pass-through columns, then `_partition`.
        let mut columns: Vec<(Arc<dyn PhysicalExpr>, String)> =
            Vec::with_capacity(arrow_schema.fields().len() + 1);
        columns.extend(
            arrow_schema
                .fields()
                .iter()
                .enumerate()
                .map(|(position, field)| {
                    let passthrough: Arc<dyn PhysicalExpr> =
                        Arc::new(Column::new(field.name(), position));
                    (passthrough, field.name().clone())
                }),
        );
        let partition_values: Arc<dyn PhysicalExpr> =
            Arc::new(PartitionExpr::new(calculator, partition_spec));
        columns.push((partition_values, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));

        let projection = ProjectionExec::try_new(columns, input).unwrap();

        let output_schema = projection.schema();
        let names: Vec<&str> = output_schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(names, ["id", "name", "_partition"]);
    }

    #[test]
    fn test_partition_expr_evaluate() {
        let table_schema = schema_with_id_zero("data");
        let partition_spec = Arc::new(identity_spec(&table_schema, "id", "id_partition"));
        let arrow_schema = id_and_text_arrow_schema("data");
        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![10, 20, 30])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ])
        .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let expected_type = calculator.partition_arrow_type().clone();
        let expr = PartitionExpr::new(calculator, partition_spec);

        assert_eq!(expr.data_type(&arrow_schema).unwrap(), expected_type);
        assert!(!expr.nullable(&arrow_schema).unwrap());

        let ColumnarValue::Array(array) = expr.evaluate(&batch).unwrap() else {
            panic!("Expected array result");
        };
        let partition_struct = array.as_any().downcast_ref::<StructArray>().unwrap();
        let ids = partition_struct
            .column_by_name("id_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        for (row, expected) in [10, 20, 30].into_iter().enumerate() {
            assert_eq!(ids.value(row), expected);
        }
    }

    #[test]
    fn test_nested_partition() {
        let address = StructType::new(vec![
            NestedField::required(3, "street", Type::Primitive(PrimitiveType::String)).into(),
            NestedField::required(4, "city", Type::Primitive(PrimitiveType::String)).into(),
        ]);
        let table_schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "address", Type::Struct(address)).into(),
            ])
            .build()
            .unwrap();
        let partition_spec = identity_spec(&table_schema, "address.city", "city_partition");

        let address_fields = Fields::from(vec![
            Field::new("street", DataType::Utf8, false),
            Field::new("city", DataType::Utf8, false),
        ]);
        let arrow_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("address", DataType::Struct(address_fields), false),
        ]));

        let streets: ArrayRef = Arc::new(StringArray::from(vec!["123 Main St", "456 Oak Ave"]));
        let cities: ArrayRef = Arc::new(StringArray::from(vec!["New York", "Los Angeles"]));
        let addresses = StructArray::from(vec![
            (Arc::new(Field::new("street", DataType::Utf8, false)), streets),
            (Arc::new(Field::new("city", DataType::Utf8, false)), cities),
        ]);
        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(addresses),
        ])
        .unwrap();

        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
        let array = calculator.calculate(&batch).unwrap();

        let partition_struct = array.as_any().downcast_ref::<StructArray>().unwrap();
        let city_partition = partition_struct
            .column_by_name("city_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(city_partition.value(0), "New York");
        assert_eq!(city_partition.value(1), "Los Angeles");
    }

    #[test]
    fn test_schema_validation_matching_schemas() {
        let outcome = project_over(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]);

        assert!(outcome.is_ok(), "Schema validation should pass");
    }

    #[test]
    fn test_schema_validation_mismatched_schemas() {
        // Second column is deliberately named differently from the table's `name` column.
        let outcome = project_over(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("different_name", DataType::Utf8, false),
        ]);

        assert!(
            outcome.is_err(),
            "Schema validation should fail for mismatched schemas"
        );
        assert!(
            outcome
                .unwrap_err()
                .to_string()
                .contains("Input schema does not match Iceberg table schema")
        );
    }

    #[test]
    fn test_schema_validation_with_metadata_differences() {
        let mut extra_metadata = HashMap::new();
        extra_metadata.insert("extra".to_string(), "metadata".to_string());

        let outcome = project_over(vec![
            Field::new("id", DataType::Int32, false).with_metadata(extra_metadata.clone()),
            Field::new("name", DataType::Utf8, false).with_metadata(extra_metadata),
        ]);

        assert!(
            outcome.is_ok(),
            "Schema validation should pass even with metadata differences"
        );
    }
}