datafusion-distributed 2.0.0

Framework for enhancing Apache DataFusion with distributed capabilities
Documentation
use crate::distributed_planner::NetworkBoundaryExt;
use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion::error::Result;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::metrics::MetricsSet;
use std::sync::Arc;

/// Collects per-node metrics from the given plan via pre-order traversal,
/// stopping at network boundary nodes (which have no subtree on the coordinator side).
pub fn collect_plan_metrics(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<MetricsSet>> {
    let mut metrics = Vec::new();
    plan.apply(|node| {
        metrics.push(node.metrics().unwrap_or_default());
        if node.is_network_boundary() {
            return Ok(TreeNodeRecursion::Jump);
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(metrics)
}

#[cfg(test)]
mod tests {

    use super::*;
    use arrow::datatypes::UInt16Type;
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::record_batch::RecordBatch;
    use futures::StreamExt;

    use crate::coordinator::DistributedExec;
    use crate::test_utils::in_memory_channel_resolver::{
        InMemoryChannelResolver, InMemoryWorkerResolver,
    };
    use crate::test_utils::parquet::register_parquet_tables;
    use crate::test_utils::plans::{
        count_plan_nodes_up_to_network_boundary, get_stages_and_task_keys,
    };
    use crate::test_utils::session_context::register_temp_parquet_table;
    use crate::{DistributedExt, SessionStateBuilderExt};
    use datafusion::execution::{SessionStateBuilder, context::SessionContext};
    use datafusion::prelude::SessionConfig;
    use datafusion::{
        arrow::datatypes::{DataType, Field, Schema},
        physical_plan::display::DisplayableExecutionPlan,
    };
    use std::sync::Arc;

    /// Creates a session context and registers two tables:
    /// - table1 (id: int, name: string)
    /// - table2 (id: int, name: string, phone: string, balance: float64)
    async fn make_test_ctx() -> SessionContext {
        // Create distributed session state with in-memory channel resolver
        let config = SessionConfig::new().with_target_partitions(2);

        let state = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config)
            .with_distributed_worker_resolver(InMemoryWorkerResolver::new(10))
            .with_distributed_channel_resolver(InMemoryChannelResolver::default())
            .with_distributed_planner()
            .with_distributed_task_estimator(2)
            .with_distributed_metrics_collection(true)
            .unwrap()
            .build();

        let ctx = SessionContext::from(state);

        // Create test data for table1
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let batches1 = vec![
            RecordBatch::try_new(
                schema1.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec!["a", "b", "c"])),
                ],
            )
            .unwrap(),
        ];

        // Create test data for table2 with extended schema
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("phone", DataType::Utf8, false),
            Field::new("balance", DataType::Float64, false),
            Field::new(
                "company",
                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
                false,
            ),
        ]));

        let batches2 = vec![
            RecordBatch::try_new(
                schema2.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec![
                        "customer1",
                        "customer2",
                        "customer3",
                    ])),
                    Arc::new(StringArray::from(vec![
                        "13-123-4567",
                        "31-456-7890",
                        "23-789-0123",
                    ])),
                    Arc::new(datafusion::arrow::array::Float64Array::from(vec![
                        100.5, 250.0, 50.25,
                    ])),
                    Arc::new(
                        vec!["company1", "company1", "company1"]
                            .into_iter()
                            .collect::<arrow::array::DictionaryArray<UInt16Type>>(),
                    ),
                ],
            )
            .unwrap(),
        ];

        // Register the test data as parquet tables
        let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx)
            .await
            .unwrap();

        let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx)
            .await
            .unwrap();

        ctx
    }

    async fn execute_plan(stage_exec: Arc<dyn ExecutionPlan>, ctx: &SessionContext) {
        let task_ctx = ctx.task_ctx();
        let stream = stage_exec.execute(0, task_ctx).unwrap();

        let schema = stream.schema();

        let mut stream = stream;
        while let Some(batch) = stream.next().await {
            let batch = batch.unwrap();

            assert_eq!(schema, batch.schema())
        }
    }

    /// Asserts that we can collect metrics from a distributed plan generated from the
    /// SQL query. It ensures that metrics are collected for all stages and are propagated
    /// through network boundaries.
    async fn run_metrics_collection_e2e_test(sql: &str) {
        // Plan and execute the query
        let ctx = make_test_ctx().await;
        let df = ctx.sql(sql).await.unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        execute_plan(plan.clone(), &ctx).await;

        let dist_exec = plan
            .downcast_ref::<DistributedExec>()
            .expect("expected DistributedExec");

        // Assert to ensure the distributed test case is sufficiently complex.
        let (stages, expected_task_keys) = get_stages_and_task_keys(dist_exec);
        assert!(
            expected_task_keys.len() > 1,
            "expected more than 1 task key in test. the plan was not distributed):\n{}",
            DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
        );

        // Per-task metrics are delivered asynchronously over the `WorkerToCoordinator` side
        // channel after execution completes; await that delivery instead of racing it (see #487).
        dist_exec.wait_for_metrics().await;

        let metrics_store = dist_exec.metrics_store.as_ref().unwrap();

        // Ensure that there's metrics for each node for each task for each stage.
        for expected_task_key in expected_task_keys {
            let actual_metrics = metrics_store.get(&expected_task_key).unwrap();

            // Verify that metrics were collected for all nodes. Some nodes may legitimately have
            // empty metrics (e.g., custom execution plans without metrics), which is fine - we
            // just verify that a metrics set exists for each node. The count assertion above
            // ensures all nodes are included in the metrics collection.
            let stage = stages.get(&(expected_task_key.stage_id as usize)).unwrap();
            let stage_plan = stage.local_plan().unwrap();
            assert_eq!(
                actual_metrics.pre_order_plan_metrics.len(),
                count_plan_nodes_up_to_network_boundary(stage_plan),
                "Mismatch between collected metrics and actual nodes for {expected_task_key:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_metrics_collection_e2e_1() {
        run_metrics_collection_e2e_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10").await;
    }

    // Skip this test, it's failing after upgrading to datafusion 50
    // See https://github.com/datafusion-contrib/datafusion-distributed/pull/146#issuecomment-3356621629
    #[tokio::test]
    async fn test_metrics_collection_e2e_2() {
        run_metrics_collection_e2e_test(
            "SELECT sum(balance) / 7.0 as avg_yearly
            FROM table2
            WHERE name LIKE 'customer%'
              AND balance < (
                SELECT 0.2 * avg(balance)
                FROM table2 t2_inner
                WHERE t2_inner.id = table2.id
              )",
        )
        .await;
    }

    // DF 54 split ScalarSubqueryExpr from its surrounding ScalarSubqueryExec: the expression
    // holds an index into a ScalarSubqueryResults container owned by the surrounding node, and
    // deserialization needs that node present. Our stage boundaries can land between them,
    // shipping a FilterExec subtree without its ScalarSubqueryExec, which fails to decode.
    // Tracked as follow-up to the DF 54 upgrade.
    #[tokio::test]
    #[ignore]
    async fn test_metrics_collection_e2e_3() {
        run_metrics_collection_e2e_test(
            "SELECT
                substring(phone, 1, 2) as country_code,
                count(*) as num_customers,
                sum(balance) as total_balance
            FROM table2
            WHERE substring(phone, 1, 2) IN ('13', '31', '23', '29', '30', '18')
              AND balance > (
                SELECT avg(balance)
                FROM table2
                WHERE balance > 0.00
              )
            GROUP BY substring(phone, 1, 2)
            ORDER BY country_code",
        )
        .await;
    }

    /// Skipped due to https://github.com/apache/datafusion/issues/14218
    ///
    /// When aggregating on a dictionary column (ex. `company` in this case which is Dict<UInt16, Utf8>),
    /// the aggregation seems to be outputting Utf8. Some assertion fails due to this, even in
    /// single node execution:
    /// "column types must match schema types, expected Dictionary(UInt16, Utf8) but found Utf8 at column index 0"
    #[tokio::test]
    async fn test_metrics_collection_e2e_4() {
        run_metrics_collection_e2e_test("SELECT distinct company from table2").await;
    }

    /// Tests whether metrics are preserved when a LIMIT causes early stream termination.
    ///
    /// Issue: https://github.com/datafusion-contrib/datafusion-distributed/issues/187
    ///
    /// Previously, metrics were piggybacked on the last FlightData message of the last partition
    /// stream. If a LIMIT caused the client-side stream to be dropped before the worker finished,
    /// the last message (carrying metrics) was never received.
    ///
    /// Now metrics are sent via the WorkerToCoordinator side channel, so they are always
    /// delivered regardless of early stream termination.
    ///
    /// This uses the `flights_1m` dataset (1M rows) so the worker is still producing data
    /// when the LIMIT causes the client to drop the stream.
    #[tokio::test]
    async fn test_metrics_collection_with_limit_causing_early_stream_termination() {
        let ctx = make_test_ctx().await;
        register_parquet_tables(&ctx).await.unwrap();

        // GROUP BY forces a network shuffle; LIMIT 1 causes early stream termination.
        let sql =
            "SELECT \"FL_DATE\", COUNT(*) as cnt FROM flights_1m GROUP BY \"FL_DATE\" LIMIT 1";

        let df = ctx.sql(sql).await.unwrap();
        let plan = df.create_physical_plan().await.unwrap();

        let dist_exec = plan
            .downcast_ref::<DistributedExec>()
            .expect("expected DistributedExec");

        let (stages, expected_task_keys) = get_stages_and_task_keys(dist_exec);
        assert!(
            expected_task_keys.len() > 1,
            "expected more than 1 task key. Plan was not distributed:\n{}",
            DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
        );

        execute_plan(plan.clone(), &ctx).await;

        // Metrics are delivered via the WorkerToCoordinator side channel in a background task.
        // Wait for that delivery to complete before asserting, rather than racing it.
        dist_exec.wait_for_metrics().await;
        let metrics_store = dist_exec.metrics_store.as_ref().unwrap();

        for expected_task_key in &expected_task_keys {
            let actual_metrics = metrics_store.get(expected_task_key).unwrap_or_else(|| {
                panic!(
                    "Missing metrics for task key {expected_task_key:?}. \
                         The LIMIT caused the stream to be dropped before the worker \
                         sent metrics via the coordinator channel."
                )
            });
            let stage = stages.get(&(expected_task_key.stage_id as usize)).unwrap();
            let stage_plan = stage.local_plan().unwrap();
            assert_eq!(
                actual_metrics.pre_order_plan_metrics.len(),
                count_plan_nodes_up_to_network_boundary(stage_plan),
                "Mismatch between collected metrics and actual nodes for {expected_task_key:?}"
            );
        }
    }
}