use arrow::datatypes::{DataType, Field, Schema};
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::simplifier::PhysicalExprSimplifier;
use std::hint::black_box;
use std::sync::Arc;
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{
BinaryExpr, CaseExpr, Column, IsNullExpr, Literal,
};
fn catalog_sales_schema() -> Schema {
Schema::new(vec![
Field::new("cs_sold_date_sk", DataType::Int64, true), Field::new("cs_sold_time_sk", DataType::Int64, true), Field::new("cs_ship_date_sk", DataType::Int64, true), Field::new("cs_bill_customer_sk", DataType::Int64, true), Field::new("cs_bill_cdemo_sk", DataType::Int64, true), Field::new("cs_bill_hdemo_sk", DataType::Int64, true), Field::new("cs_bill_addr_sk", DataType::Int64, true), Field::new("cs_ship_customer_sk", DataType::Int64, true), Field::new("cs_ship_cdemo_sk", DataType::Int64, true), Field::new("cs_ship_hdemo_sk", DataType::Int64, true), Field::new("cs_ship_addr_sk", DataType::Int64, true), Field::new("cs_call_center_sk", DataType::Int64, true), Field::new("cs_catalog_page_sk", DataType::Int64, true), Field::new("cs_ship_mode_sk", DataType::Int64, true), Field::new("cs_warehouse_sk", DataType::Int64, true), Field::new("cs_item_sk", DataType::Int64, true), Field::new("cs_promo_sk", DataType::Int64, true), Field::new("cs_order_number", DataType::Int64, true), Field::new("cs_quantity", DataType::Int64, true), Field::new("cs_wholesale_cost", DataType::Decimal128(7, 2), true),
Field::new("cs_list_price", DataType::Decimal128(7, 2), true),
Field::new("cs_sales_price", DataType::Decimal128(7, 2), true),
Field::new("cs_ext_discount_amt", DataType::Decimal128(7, 2), true),
Field::new("cs_ext_sales_price", DataType::Decimal128(7, 2), true),
Field::new("cs_ext_wholesale_cost", DataType::Decimal128(7, 2), true),
Field::new("cs_ext_list_price", DataType::Decimal128(7, 2), true),
Field::new("cs_ext_tax", DataType::Decimal128(7, 2), true),
Field::new("cs_coupon_amt", DataType::Decimal128(7, 2), true),
Field::new("cs_ext_ship_cost", DataType::Decimal128(7, 2), true),
Field::new("cs_net_paid", DataType::Decimal128(7, 2), true),
Field::new("cs_net_paid_inc_tax", DataType::Decimal128(7, 2), true),
Field::new("cs_net_paid_inc_ship", DataType::Decimal128(7, 2), true),
Field::new("cs_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true),
Field::new("cs_net_profit", DataType::Decimal128(7, 2), true),
])
}
fn web_sales_schema() -> Schema {
Schema::new(vec![
Field::new("ws_sold_date_sk", DataType::Int64, true),
Field::new("ws_sold_time_sk", DataType::Int64, true),
Field::new("ws_ship_date_sk", DataType::Int64, true),
Field::new("ws_item_sk", DataType::Int64, true),
Field::new("ws_bill_customer_sk", DataType::Int64, true),
Field::new("ws_bill_cdemo_sk", DataType::Int64, true),
Field::new("ws_bill_hdemo_sk", DataType::Int64, true),
Field::new("ws_bill_addr_sk", DataType::Int64, true),
Field::new("ws_ship_customer_sk", DataType::Int64, true),
Field::new("ws_ship_cdemo_sk", DataType::Int64, true),
Field::new("ws_ship_hdemo_sk", DataType::Int64, true),
Field::new("ws_ship_addr_sk", DataType::Int64, true),
Field::new("ws_web_page_sk", DataType::Int64, true),
Field::new("ws_web_site_sk", DataType::Int64, true),
Field::new("ws_ship_mode_sk", DataType::Int64, true),
Field::new("ws_warehouse_sk", DataType::Int64, true),
Field::new("ws_promo_sk", DataType::Int64, true),
Field::new("ws_order_number", DataType::Int64, true),
Field::new("ws_quantity", DataType::Int64, true),
Field::new("ws_wholesale_cost", DataType::Decimal128(7, 2), true),
Field::new("ws_list_price", DataType::Decimal128(7, 2), true),
Field::new("ws_sales_price", DataType::Decimal128(7, 2), true),
Field::new("ws_ext_discount_amt", DataType::Decimal128(7, 2), true),
Field::new("ws_ext_sales_price", DataType::Decimal128(7, 2), true),
Field::new("ws_ext_wholesale_cost", DataType::Decimal128(7, 2), true),
Field::new("ws_ext_list_price", DataType::Decimal128(7, 2), true),
Field::new("ws_ext_tax", DataType::Decimal128(7, 2), true),
Field::new("ws_coupon_amt", DataType::Decimal128(7, 2), true),
Field::new("ws_ext_ship_cost", DataType::Decimal128(7, 2), true),
Field::new("ws_net_paid", DataType::Decimal128(7, 2), true),
Field::new("ws_net_paid_inc_tax", DataType::Decimal128(7, 2), true),
Field::new("ws_net_paid_inc_ship", DataType::Decimal128(7, 2), true),
Field::new("ws_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true),
Field::new("ws_net_profit", DataType::Decimal128(7, 2), true),
])
}
/// Wraps `val` in a non-null `Int64` literal expression.
fn lit_i64(val: i64) -> Arc<dyn PhysicalExpr> {
    let scalar = ScalarValue::Int64(Some(val));
    Arc::new(Literal::new(scalar))
}
/// Wraps `val` in a non-null `Int32` literal expression.
fn lit_i32(val: i32) -> Arc<dyn PhysicalExpr> {
    let scalar = ScalarValue::Int32(Some(val));
    Arc::new(Literal::new(scalar))
}
/// Wraps `val` in a non-null `Boolean` literal expression.
fn lit_bool(val: bool) -> Arc<dyn PhysicalExpr> {
    let scalar = ScalarValue::Boolean(Some(val));
    Arc::new(Literal::new(scalar))
}
/// Builds the binary expression `left AND right`.
fn and(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let conjunction = BinaryExpr::new(left, Operator::And, right);
    Arc::new(conjunction)
}
/// Builds the binary expression `left >= right`.
fn gte(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let comparison = BinaryExpr::new(left, Operator::GtEq, right);
    Arc::new(comparison)
}
/// Builds the binary expression `left <= right`.
fn lte(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let comparison = BinaryExpr::new(left, Operator::LtEq, right);
    Arc::new(comparison)
}
/// Builds the binary expression `left % right`.
fn modulo(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let remainder = BinaryExpr::new(left, Operator::Modulo, right);
    Arc::new(remainder)
}
/// Builds the binary expression `left = right`.
fn eq(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let equality = BinaryExpr::new(left, Operator::Eq, right);
    Arc::new(equality)
}
pub fn catalog_sales_predicate(num_partitions: usize) -> Arc<dyn PhysicalExpr> {
let cs_sold_date_sk: Arc<dyn PhysicalExpr> =
Arc::new(Column::new("cs_sold_date_sk", 0));
let cs_ship_addr_sk: Arc<dyn PhysicalExpr> =
Arc::new(Column::new("cs_ship_addr_sk", 10));
let cs_item_sk: Arc<dyn PhysicalExpr> = Arc::new(Column::new("cs_item_sk", 15));
let item_hash_mod = modulo(cs_item_sk.clone(), lit_i64(num_partitions as i64));
let date_hash_mod = modulo(cs_sold_date_sk.clone(), lit_i64(num_partitions as i64));
let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(cs_ship_addr_sk));
let item_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
..num_partitions)
.map(|partition| {
let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32));
let then_expr = and(
gte(cs_item_sk.clone(), lit_i64(partition as i64)),
lte(cs_item_sk.clone(), lit_i64(18000)),
);
(when_expr, then_expr)
})
.collect();
let item_case_expr: Arc<dyn PhysicalExpr> =
Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap());
let date_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
..num_partitions)
.map(|partition| {
let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32));
let then_expr = and(
gte(cs_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)),
lte(cs_sold_date_sk.clone(), lit_i64(2488070)),
);
(when_expr, then_expr)
})
.collect();
let date_case_expr: Arc<dyn PhysicalExpr> =
Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap());
and(and(is_null_expr, item_case_expr), date_case_expr)
}
fn web_sales_predicate(num_partitions: usize) -> Arc<dyn PhysicalExpr> {
let ws_sold_date_sk: Arc<dyn PhysicalExpr> =
Arc::new(Column::new("ws_sold_date_sk", 0));
let ws_item_sk: Arc<dyn PhysicalExpr> = Arc::new(Column::new("ws_item_sk", 3));
let ws_ship_customer_sk: Arc<dyn PhysicalExpr> =
Arc::new(Column::new("ws_ship_customer_sk", 8));
let item_hash_mod = modulo(ws_item_sk.clone(), lit_i64(num_partitions as i64));
let date_hash_mod = modulo(ws_sold_date_sk.clone(), lit_i64(num_partitions as i64));
let is_null_expr: Arc<dyn PhysicalExpr> =
Arc::new(IsNullExpr::new(ws_ship_customer_sk));
let item_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
..num_partitions)
.map(|partition| {
let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32));
let then_expr = and(
gte(ws_item_sk.clone(), lit_i64(partition as i64)),
lte(ws_item_sk.clone(), lit_i64(18000)),
);
(when_expr, then_expr)
})
.collect();
let item_case_expr: Arc<dyn PhysicalExpr> =
Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap());
let date_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
..num_partitions)
.map(|partition| {
let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32));
let then_expr = and(
gte(ws_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)),
lte(ws_sold_date_sk.clone(), lit_i64(2488070)),
);
(when_expr, then_expr)
})
.collect();
let date_case_expr: Arc<dyn PhysicalExpr> =
Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap());
and(and(is_null_expr, item_case_expr), date_case_expr)
}
/// Registers a Criterion benchmark called `name` that repeatedly simplifies
/// `expr` with a `PhysicalExprSimplifier` built once for `schema`.
///
/// Each iteration clones the `Arc` handle to `expr`, so the expression tree is
/// shared across iterations. Input and output both pass through `black_box`.
///
/// # Panics
///
/// Panics inside the benchmark loop if `simplify` returns an error.
fn bench_simplify(
    c: &mut Criterion,
    name: &str,
    schema: &Schema,
    expr: &Arc<dyn PhysicalExpr>,
) {
    // Constructed outside the measured closure so only `simplify` is timed.
    let simplifier = PhysicalExprSimplifier::new(schema);
    c.bench_function(name, |bencher| {
        bencher.iter(|| {
            let input = black_box(Arc::clone(expr));
            let simplified = simplifier.simplify(input).unwrap();
            black_box(simplified)
        })
    });
}
/// Registers the simplifier benchmarks.
///
/// For each partition count (16 and 128) this benchmarks the catalog sales
/// predicate as `tpc-ds/q76/cs/<n>` and then the web sales predicate as
/// `tpc-ds/q76/ws/<n>`.
fn criterion_benchmark(c: &mut Criterion) {
    let cs_schema = catalog_sales_schema();
    let ws_schema = web_sales_schema();

    for num_partitions in [16, 128] {
        let cs_name = format!("tpc-ds/q76/cs/{num_partitions}");
        let cs_predicate = catalog_sales_predicate(num_partitions);
        bench_simplify(c, &cs_name, &cs_schema, &cs_predicate);

        let ws_name = format!("tpc-ds/q76/ws/{num_partitions}");
        let ws_predicate = web_sales_predicate(num_partitions);
        bench_simplify(c, &ws_name, &ws_schema, &ws_predicate);
    }
}
// Declare a Criterion benchmark group named `benches` containing
// `criterion_benchmark`, and generate the benchmark binary's entry point for it.
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);