// datafusion-physical-expr 53.1.0
//
// Physical expression implementation for the DataFusion query engine.
// Documentation
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This benchmark reproduces some of the predicates generated by TPC-DS query #76
//! and measures how long it takes to simplify them.

use arrow::datatypes::{DataType, Field, Schema};
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::simplifier::PhysicalExprSimplifier;
use std::hint::black_box;
use std::sync::Arc;

use datafusion_common::ScalarValue;
use datafusion_expr::Operator;

use datafusion_physical_expr::expressions::{
    BinaryExpr, CaseExpr, Column, IsNullExpr, Literal,
};

fn catalog_sales_schema() -> Schema {
    Schema::new(vec![
        Field::new("cs_sold_date_sk", DataType::Int64, true), // 0
        Field::new("cs_sold_time_sk", DataType::Int64, true), // 1
        Field::new("cs_ship_date_sk", DataType::Int64, true), // 2
        Field::new("cs_bill_customer_sk", DataType::Int64, true), // 3
        Field::new("cs_bill_cdemo_sk", DataType::Int64, true), // 4
        Field::new("cs_bill_hdemo_sk", DataType::Int64, true), // 5
        Field::new("cs_bill_addr_sk", DataType::Int64, true), // 6
        Field::new("cs_ship_customer_sk", DataType::Int64, true), // 7
        Field::new("cs_ship_cdemo_sk", DataType::Int64, true), // 8
        Field::new("cs_ship_hdemo_sk", DataType::Int64, true), // 9
        Field::new("cs_ship_addr_sk", DataType::Int64, true), // 10
        Field::new("cs_call_center_sk", DataType::Int64, true), // 11
        Field::new("cs_catalog_page_sk", DataType::Int64, true), // 12
        Field::new("cs_ship_mode_sk", DataType::Int64, true), // 13
        Field::new("cs_warehouse_sk", DataType::Int64, true), // 14
        Field::new("cs_item_sk", DataType::Int64, true),      // 15
        Field::new("cs_promo_sk", DataType::Int64, true),     // 16
        Field::new("cs_order_number", DataType::Int64, true), // 17
        Field::new("cs_quantity", DataType::Int64, true),     // 18
        Field::new("cs_wholesale_cost", DataType::Decimal128(7, 2), true),
        Field::new("cs_list_price", DataType::Decimal128(7, 2), true),
        Field::new("cs_sales_price", DataType::Decimal128(7, 2), true),
        Field::new("cs_ext_discount_amt", DataType::Decimal128(7, 2), true),
        Field::new("cs_ext_sales_price", DataType::Decimal128(7, 2), true),
        Field::new("cs_ext_wholesale_cost", DataType::Decimal128(7, 2), true),
        Field::new("cs_ext_list_price", DataType::Decimal128(7, 2), true),
        Field::new("cs_ext_tax", DataType::Decimal128(7, 2), true),
        Field::new("cs_coupon_amt", DataType::Decimal128(7, 2), true),
        Field::new("cs_ext_ship_cost", DataType::Decimal128(7, 2), true),
        Field::new("cs_net_paid", DataType::Decimal128(7, 2), true),
        Field::new("cs_net_paid_inc_tax", DataType::Decimal128(7, 2), true),
        Field::new("cs_net_paid_inc_ship", DataType::Decimal128(7, 2), true),
        Field::new("cs_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true),
        Field::new("cs_net_profit", DataType::Decimal128(7, 2), true),
    ])
}

fn web_sales_schema() -> Schema {
    Schema::new(vec![
        Field::new("ws_sold_date_sk", DataType::Int64, true),
        Field::new("ws_sold_time_sk", DataType::Int64, true),
        Field::new("ws_ship_date_sk", DataType::Int64, true),
        Field::new("ws_item_sk", DataType::Int64, true),
        Field::new("ws_bill_customer_sk", DataType::Int64, true),
        Field::new("ws_bill_cdemo_sk", DataType::Int64, true),
        Field::new("ws_bill_hdemo_sk", DataType::Int64, true),
        Field::new("ws_bill_addr_sk", DataType::Int64, true),
        Field::new("ws_ship_customer_sk", DataType::Int64, true),
        Field::new("ws_ship_cdemo_sk", DataType::Int64, true),
        Field::new("ws_ship_hdemo_sk", DataType::Int64, true),
        Field::new("ws_ship_addr_sk", DataType::Int64, true),
        Field::new("ws_web_page_sk", DataType::Int64, true),
        Field::new("ws_web_site_sk", DataType::Int64, true),
        Field::new("ws_ship_mode_sk", DataType::Int64, true),
        Field::new("ws_warehouse_sk", DataType::Int64, true),
        Field::new("ws_promo_sk", DataType::Int64, true),
        Field::new("ws_order_number", DataType::Int64, true),
        Field::new("ws_quantity", DataType::Int64, true),
        Field::new("ws_wholesale_cost", DataType::Decimal128(7, 2), true),
        Field::new("ws_list_price", DataType::Decimal128(7, 2), true),
        Field::new("ws_sales_price", DataType::Decimal128(7, 2), true),
        Field::new("ws_ext_discount_amt", DataType::Decimal128(7, 2), true),
        Field::new("ws_ext_sales_price", DataType::Decimal128(7, 2), true),
        Field::new("ws_ext_wholesale_cost", DataType::Decimal128(7, 2), true),
        Field::new("ws_ext_list_price", DataType::Decimal128(7, 2), true),
        Field::new("ws_ext_tax", DataType::Decimal128(7, 2), true),
        Field::new("ws_coupon_amt", DataType::Decimal128(7, 2), true),
        Field::new("ws_ext_ship_cost", DataType::Decimal128(7, 2), true),
        Field::new("ws_net_paid", DataType::Decimal128(7, 2), true),
        Field::new("ws_net_paid_inc_tax", DataType::Decimal128(7, 2), true),
        Field::new("ws_net_paid_inc_ship", DataType::Decimal128(7, 2), true),
        Field::new("ws_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true),
        Field::new("ws_net_profit", DataType::Decimal128(7, 2), true),
    ])
}

/// Wraps `val` in a non-null `Int64` literal expression.
fn lit_i64(val: i64) -> Arc<dyn PhysicalExpr> {
    let scalar = ScalarValue::Int64(Some(val));
    Arc::new(Literal::new(scalar))
}

/// Wraps `val` in a non-null `Int32` literal expression.
fn lit_i32(val: i32) -> Arc<dyn PhysicalExpr> {
    let scalar = ScalarValue::Int32(Some(val));
    Arc::new(Literal::new(scalar))
}

/// Wraps `val` in a non-null `Boolean` literal expression.
fn lit_bool(val: bool) -> Arc<dyn PhysicalExpr> {
    let scalar = ScalarValue::Boolean(Some(val));
    Arc::new(Literal::new(scalar))
}

/// Builds the binary expression `left AND right`.
fn and(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let conjunction = BinaryExpr::new(left, Operator::And, right);
    Arc::new(conjunction)
}

/// Builds the binary expression `left >= right`.
fn gte(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let comparison = BinaryExpr::new(left, Operator::GtEq, right);
    Arc::new(comparison)
}

/// Builds the binary expression `left <= right`.
fn lte(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let comparison = BinaryExpr::new(left, Operator::LtEq, right);
    Arc::new(comparison)
}

/// Builds the binary expression `left % right`.
fn modulo(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let remainder = BinaryExpr::new(left, Operator::Modulo, right);
    Arc::new(remainder)
}

/// Builds the binary expression `left = right`.
fn eq(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
    let equality = BinaryExpr::new(left, Operator::Eq, right);
    Arc::new(equality)
}

/// Build a predicate similar to TPC-DS q76 catalog_sales filter.
/// Uses placeholder columns instead of hash expressions.
pub fn catalog_sales_predicate(num_partitions: usize) -> Arc<dyn PhysicalExpr> {
    let cs_sold_date_sk: Arc<dyn PhysicalExpr> =
        Arc::new(Column::new("cs_sold_date_sk", 0));
    let cs_ship_addr_sk: Arc<dyn PhysicalExpr> =
        Arc::new(Column::new("cs_ship_addr_sk", 10));
    let cs_item_sk: Arc<dyn PhysicalExpr> = Arc::new(Column::new("cs_item_sk", 15));

    // Use a simple modulo expression as placeholder for hash
    let item_hash_mod = modulo(cs_item_sk.clone(), lit_i64(num_partitions as i64));
    let date_hash_mod = modulo(cs_sold_date_sk.clone(), lit_i64(num_partitions as i64));

    // cs_ship_addr_sk IS NULL
    let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(cs_ship_addr_sk));

    // Build item_sk CASE expression with num_partitions branches
    let item_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
        ..num_partitions)
        .map(|partition| {
            let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32));
            let then_expr = and(
                gte(cs_item_sk.clone(), lit_i64(partition as i64)),
                lte(cs_item_sk.clone(), lit_i64(18000)),
            );
            (when_expr, then_expr)
        })
        .collect();

    let item_case_expr: Arc<dyn PhysicalExpr> =
        Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap());

    // Build sold_date_sk CASE expression with num_partitions branches
    let date_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
        ..num_partitions)
        .map(|partition| {
            let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32));
            let then_expr = and(
                gte(cs_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)),
                lte(cs_sold_date_sk.clone(), lit_i64(2488070)),
            );
            (when_expr, then_expr)
        })
        .collect();

    let date_case_expr: Arc<dyn PhysicalExpr> =
        Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap());

    // Final: is_null AND item_case AND date_case
    and(and(is_null_expr, item_case_expr), date_case_expr)
}
/// Build a predicate similar to TPC-DS q76 web_sales filter.
/// Uses placeholder columns instead of hash expressions.
fn web_sales_predicate(num_partitions: usize) -> Arc<dyn PhysicalExpr> {
    let ws_sold_date_sk: Arc<dyn PhysicalExpr> =
        Arc::new(Column::new("ws_sold_date_sk", 0));
    let ws_item_sk: Arc<dyn PhysicalExpr> = Arc::new(Column::new("ws_item_sk", 3));
    let ws_ship_customer_sk: Arc<dyn PhysicalExpr> =
        Arc::new(Column::new("ws_ship_customer_sk", 8));

    // Use simple modulo expression as placeholder for hash
    let item_hash_mod = modulo(ws_item_sk.clone(), lit_i64(num_partitions as i64));
    let date_hash_mod = modulo(ws_sold_date_sk.clone(), lit_i64(num_partitions as i64));

    // ws_ship_customer_sk IS NULL
    let is_null_expr: Arc<dyn PhysicalExpr> =
        Arc::new(IsNullExpr::new(ws_ship_customer_sk));

    // Build item_sk CASE expression with num_partitions branches
    let item_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
        ..num_partitions)
        .map(|partition| {
            let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32));
            let then_expr = and(
                gte(ws_item_sk.clone(), lit_i64(partition as i64)),
                lte(ws_item_sk.clone(), lit_i64(18000)),
            );
            (when_expr, then_expr)
        })
        .collect();

    let item_case_expr: Arc<dyn PhysicalExpr> =
        Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap());

    // Build sold_date_sk CASE expression with num_partitions branches
    let date_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0
        ..num_partitions)
        .map(|partition| {
            let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32));
            let then_expr = and(
                gte(ws_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)),
                lte(ws_sold_date_sk.clone(), lit_i64(2488070)),
            );
            (when_expr, then_expr)
        })
        .collect();

    let date_case_expr: Arc<dyn PhysicalExpr> =
        Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap());

    and(and(is_null_expr, item_case_expr), date_case_expr)
}

/// Registers a Criterion benchmark called `name` that times
/// `PhysicalExprSimplifier::simplify` on `expr`.
///
/// The simplifier is created once from `schema`, outside the timed closure.
/// Every iteration simplifies a fresh `Arc` clone of `expr` and panics if
/// simplification returns an error.
fn bench_simplify(
    c: &mut Criterion,
    name: &str,
    schema: &Schema,
    expr: &Arc<dyn PhysicalExpr>,
) {
    let simplifier = PhysicalExprSimplifier::new(schema);
    c.bench_function(name, |bencher| {
        bencher.iter(|| {
            // `black_box` on both the input and the output keeps the optimizer
            // from eliding the work being measured.
            let input = black_box(Arc::clone(expr));
            let simplified = simplifier.simplify(input).unwrap();
            black_box(simplified)
        })
    });
}

fn criterion_benchmark(c: &mut Criterion) {
    let cs_schema = catalog_sales_schema();
    let ws_schema = web_sales_schema();

    for num_partitions in [16, 128] {
        bench_simplify(
            c,
            &format!("tpc-ds/q76/cs/{num_partitions}"),
            &cs_schema,
            &catalog_sales_predicate(num_partitions),
        );
        bench_simplify(
            c,
            &format!("tpc-ds/q76/ws/{num_partitions}"),
            &ws_schema,
            &web_sales_predicate(num_partitions),
        );
    }
}

// Register `criterion_benchmark` in a Criterion group named `benches`, then
// let Criterion generate the benchmark entry point that runs that group.
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);