use std::sync::Arc;
use datafusion::arrow::compute::SortOptions;
use datafusion::common::Result as DFResult;
use datafusion::error::DataFusionError;
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::expressions::Column;
use datafusion::physical_plan::sorts::sort::SortExec;
use iceberg::arrow::PROJECTED_PARTITION_VALUE_COLUMN;
/// Wraps `input` in a [`SortExec`] ordered by the projected partition value column.
///
/// The column named [`PROJECTED_PARTITION_VALUE_COLUMN`] is looked up in the schema of
/// `input` and used as the single sort key with the default [`SortOptions`]. The sort is
/// built with `with_preserve_partitioning(true)`.
///
/// # Errors
///
/// Returns [`DataFusionError::Plan`] when:
/// - the schema of `input` has no column named [`PROJECTED_PARTITION_VALUE_COLUMN`], or
/// - [`LexOrdering::new`] yields `None` for the sort expression (not expected here, since
///   exactly one expression is supplied).
pub(crate) fn sort_by_partition(input: Arc<dyn ExecutionPlan>) -> DFResult<Arc<dyn ExecutionPlan>> {
    let input_schema = input.schema();

    // Only the position of the partition column is needed; the field itself is ignored.
    let Some((partition_idx, _)) = input_schema.column_with_name(PROJECTED_PARTITION_VALUE_COLUMN)
    else {
        return Err(DataFusionError::Plan(format!(
            "Partition column '{PROJECTED_PARTITION_VALUE_COLUMN}' not found in schema. Ensure the plan has been extended with partition values using project_with_partition."
        )));
    };

    let partition_sort = PhysicalSortExpr {
        expr: Arc::new(Column::new(
            PROJECTED_PARTITION_VALUE_COLUMN,
            partition_idx,
        )),
        options: SortOptions::default(),
    };

    let Some(ordering) = LexOrdering::new(vec![partition_sort]) else {
        return Err(DataFusionError::Plan(
            "Failed to create LexOrdering from sort expression".to_string(),
        ));
    };

    let sorted = SortExec::new(ordering, input).with_preserve_partitioning(true);
    Ok(Arc::new(sorted))
}
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray, StructArray};
    use datafusion::arrow::datatypes::{DataType, Field, Fields, Schema as ArrowSchema};
    use datafusion::datasource::{MemTable, TableProvider};
    use datafusion::prelude::SessionContext;

    use super::*;

    /// Wraps `batch` in a [`MemTable`] holding a single partition and returns the plan
    /// produced by scanning it without projection, filters, or limit.
    async fn scan_batch(ctx: &SessionContext, batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let table = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap();
        table.scan(&ctx.state(), None, &[], None).await.unwrap()
    }

    /// Applies `sort_by_partition` to `plan` and collects every output batch.
    async fn sort_and_collect(
        ctx: &SessionContext,
        plan: Arc<dyn ExecutionPlan>,
    ) -> Vec<RecordBatch> {
        let sorted = sort_by_partition(plan).unwrap();
        datafusion::physical_plan::collect(sorted, ctx.task_ctx())
            .await
            .unwrap()
    }

    /// Returns column `index` of `batch` as an [`Int32Array`], panicking on a type mismatch.
    fn int32_column(batch: &RecordBatch, index: usize) -> &Int32Array {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap()
    }

    /// Rows with single-field partition values 3, 1, 2 must come back with ids 1, 2, 3.
    #[tokio::test]
    async fn test_sort_by_partition_basic() {
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new(
                PROJECTED_PARTITION_VALUE_COLUMN,
                DataType::Struct(Fields::from(vec![Field::new(
                    "id_partition",
                    DataType::Int32,
                    false,
                )])),
                false,
            ),
        ]));

        // The partition value mirrors the id, so sorting by partition also orders the ids.
        let partition_values = StructArray::from(vec![(
            Arc::new(Field::new("id_partition", DataType::Int32, false)),
            Arc::new(Int32Array::from(vec![3, 1, 2])) as _,
        )]);

        let batch = RecordBatch::try_new(schema, vec![
            Arc::new(Int32Array::from(vec![3, 1, 2])),
            Arc::new(StringArray::from(vec!["c", "a", "b"])),
            Arc::new(partition_values),
        ])
        .unwrap();

        let ctx = SessionContext::new();
        let plan = scan_batch(&ctx, batch).await;
        let batches = sort_and_collect(&ctx, plan).await;

        assert_eq!(batches.len(), 1);
        let ids = int32_column(&batches[0], 0);
        for (row, &expected) in [1, 2, 3].iter().enumerate() {
            assert_eq!(ids.value(row), expected);
        }
    }

    /// A plan whose schema lacks the partition column must be rejected with a plan error.
    #[tokio::test]
    async fn test_sort_by_partition_missing_column() {
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(schema, vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ])
        .unwrap();

        let ctx = SessionContext::new();
        let plan = scan_batch(&ctx, batch).await;

        let outcome = sort_by_partition(plan);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Partition column '_partition' not found"));
    }

    /// With a (year, month) partition struct the rows are expected in the order
    /// (2023, 12), (2024, 1), (2024, 1), (2024, 2), i.e. ids 3, 2, 4, 1.
    #[tokio::test]
    async fn test_sort_by_partition_multi_field() {
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("data", DataType::Utf8, false),
            Field::new(
                PROJECTED_PARTITION_VALUE_COLUMN,
                DataType::Struct(Fields::from(vec![
                    Field::new("year", DataType::Int32, false),
                    Field::new("month", DataType::Int32, false),
                ])),
                false,
            ),
        ]));

        let partition_values = StructArray::from(vec![
            (
                Arc::new(Field::new("year", DataType::Int32, false)),
                Arc::new(Int32Array::from(vec![2024, 2024, 2023, 2024])) as _,
            ),
            (
                Arc::new(Field::new("month", DataType::Int32, false)),
                Arc::new(Int32Array::from(vec![2, 1, 12, 1])) as _,
            ),
        ]);

        let batch = RecordBatch::try_new(schema, vec![
            Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
            Arc::new(StringArray::from(vec!["a", "b", "c", "d"])),
            Arc::new(partition_values),
        ])
        .unwrap();

        let ctx = SessionContext::new();
        let plan = scan_batch(&ctx, batch).await;
        let batches = sort_and_collect(&ctx, plan).await;

        assert_eq!(batches.len(), 1);
        let ids = int32_column(&batches[0], 0);
        for (row, &expected) in [3, 2, 4, 1].iter().enumerate() {
            assert_eq!(ids.value(row), expected);
        }
    }
}