use crate::NetworkBroadcastExec;
use crate::execution_plans::NetworkCoalesceExec;
use crate::execution_plans::NetworkShuffleExec;
use crate::worker::generated::worker as pb;
use crate::worker::generated::worker::TaskKey;
use datafusion::common::HashMap;
use datafusion::common::tree_node::Transformed;
use datafusion::common::tree_node::TreeNode;
use datafusion::common::tree_node::TreeNodeRecursion;
use datafusion::common::tree_node::TreeNodeRewriter;
use datafusion::error::DataFusionError;
use datafusion::error::Result;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::internal_err;
use datafusion::physical_plan::metrics::MetricsSet;
use std::sync::Arc;
/// Gathers execution metrics from a physical plan by walking it top-down.
///
/// The walk is driven through the [`TreeNodeRewriter`] implementation (see
/// [`TaskMetricsCollector::collect`]). It records the metrics of every visited node and, at
/// network boundary nodes, takes over the per-task metrics those nodes hold instead of
/// descending into their children.
#[derive(Debug)]
pub struct TaskMetricsCollector {
    /// One entry per visited node, in visit (pre-order) order; nodes that report no metrics
    /// contribute an empty [`MetricsSet`].
    task_metrics: Vec<MetricsSet>,
    /// Metrics moved out of network boundary nodes, keyed by [`TaskKey`].
    input_task_metrics: HashMap<TaskKey, Vec<pb::MetricsSet>>,
}
/// Output of [`TaskMetricsCollector::collect`].
#[derive(Debug)]
pub struct MetricsCollectorResult {
    /// Metrics of the nodes visited in the collected plan, in pre-order, up to and including
    /// network boundary nodes. Nodes without metrics are represented by an empty set.
    pub task_metrics: Vec<MetricsSet>,
    /// Metrics taken from network boundary nodes, keyed by [`TaskKey`].
    pub input_task_metrics: HashMap<TaskKey, Vec<pb::MetricsSet>>,
}
impl TreeNodeRewriter for TaskMetricsCollector {
    type Node = Arc<dyn ExecutionPlan>;

    /// Records the metrics of `plan` and, at network boundaries, harvests per-task metrics.
    ///
    /// For every visited node, `plan.metrics()` is appended to `task_metrics`. A node with no
    /// metrics contributes an empty [`MetricsSet`], so `task_metrics` holds exactly one entry
    /// per visited node.
    ///
    /// When the node is a [`NetworkShuffleExec`], [`NetworkCoalesceExec`] or
    /// [`NetworkBroadcastExec`], every entry of its `metrics_collection` is moved into
    /// `input_task_metrics`; the entry left behind holds an empty `Vec`. The traversal then
    /// returns [`TreeNodeRecursion::Jump`], so that node's children are not visited.
    ///
    /// The plan itself is never modified (`Transformed` is always `false`).
    ///
    /// # Errors
    ///
    /// Returns an internal error if a [`TaskKey`] found at a network boundary has already
    /// been collected during this traversal. The duplicate entry is checked before its value
    /// is taken, so it is left intact in the boundary's `metrics_collection`.
    fn f_down(&mut self, plan: Self::Node) -> Result<Transformed<Self::Node>> {
        // `metrics()` already yields an owned value; fill the slot with an empty set when absent.
        self.task_metrics
            .push(plan.metrics().unwrap_or_else(MetricsSet::new));

        let metrics_collection =
            if let Some(node) = plan.as_any().downcast_ref::<NetworkShuffleExec>() {
                Some(Arc::clone(&node.metrics_collection))
            } else if let Some(node) = plan.as_any().downcast_ref::<NetworkCoalesceExec>() {
                Some(Arc::clone(&node.metrics_collection))
            } else if let Some(node) = plan.as_any().downcast_ref::<NetworkBroadcastExec>() {
                Some(Arc::clone(&node.metrics_collection))
            } else {
                None
            };

        let Some(metrics_collection) = metrics_collection else {
            // Not a network boundary: keep walking into the children.
            return Ok(Transformed::new(plan, false, TreeNodeRecursion::Continue));
        };

        for mut entry in metrics_collection.iter_mut() {
            // Validate before taking the value, so a duplicate does not drain the source entry.
            if self.input_task_metrics.contains_key(entry.key()) {
                return internal_err!(
                    "duplicate task metrics for key {:?} during metrics collection",
                    entry.key()
                );
            }
            let task_key = entry.key().clone();
            let task_metrics = std::mem::take(entry.value_mut());
            self.input_task_metrics.insert(task_key, task_metrics);
        }

        // Do not descend below the boundary: the metrics for what lies beyond it were just
        // taken from `metrics_collection`.
        Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
    }
}
impl TaskMetricsCollector {
    /// Creates a collector that has not recorded any metrics yet.
    pub fn new() -> Self {
        Self {
            task_metrics: Vec::new(),
            input_task_metrics: HashMap::new(),
        }
    }

    /// Walks `plan` top-down and returns the metrics gathered from it.
    ///
    /// `task_metrics` in the result holds one [`MetricsSet`] per node visited from the root
    /// of `plan`. The walk stops at, and includes, network boundary nodes. `input_task_metrics`
    /// holds the per-task metrics taken from those boundary nodes, keyed by [`TaskKey`].
    ///
    /// Collection moves the per-task metrics out of the boundary nodes (`std::mem::take`).
    /// A second collection over the same plan instance would therefore see empty vectors.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the same [`TaskKey`] is encountered more than once while
    /// walking `plan`.
    pub fn collect(
        mut self,
        plan: Arc<dyn ExecutionPlan>,
    ) -> Result<MetricsCollectorResult, DataFusionError> {
        // The rewriter never transforms the plan; only its side effects on `self` matter.
        plan.rewrite(&mut self)?;
        Ok(MetricsCollectorResult {
            task_metrics: self.task_metrics,
            input_task_metrics: self.input_task_metrics,
        })
    }
}

/// Equivalent to [`TaskMetricsCollector::new`]; provided so the type satisfies clippy's
/// `new_without_default` and can be used wherever `Default` is expected.
impl Default for TaskMetricsCollector {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::UInt16Type;
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::record_batch::RecordBatch;
    use futures::StreamExt;

    use crate::execution_plans::DistributedExec;
    use crate::test_utils::in_memory_channel_resolver::{
        InMemoryChannelResolver, InMemoryWorkerResolver,
    };
    use crate::test_utils::parquet::register_parquet_tables;
    use crate::test_utils::plans::{
        count_plan_nodes_up_to_network_boundary, get_stages_and_task_keys,
    };
    use crate::test_utils::session_context::register_temp_parquet_table;
    use crate::{DistributedExt, DistributedPhysicalOptimizerRule};
    use datafusion::execution::{SessionStateBuilder, context::SessionContext};
    use datafusion::prelude::SessionConfig;
    use datafusion::{
        arrow::datatypes::{DataType, Field, Schema},
        physical_plan::display::DisplayableExecutionPlan,
    };
    use std::sync::Arc;

    /// Builds a distributed test session and registers two small parquet tables.
    ///
    /// The session uses 2 target partitions, an in-memory worker resolver created with `10`,
    /// the distributed optimizer rule, a task estimator of `2`, and metrics collection enabled.
    /// The tables are `table1(id, name)` and `table2(id, name, phone, balance, company)`, each
    /// with three rows.
    async fn make_test_ctx() -> SessionContext {
        let config = SessionConfig::new().with_target_partitions(2);
        let state = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config)
            .with_distributed_worker_resolver(InMemoryWorkerResolver::new(10))
            .with_distributed_channel_resolver(InMemoryChannelResolver::default())
            .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
            .with_distributed_task_estimator(2)
            .with_distributed_metrics_collection(true)
            .unwrap()
            .build();
        let ctx = SessionContext::from(state);

        let schema1 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let batches1 = vec![
            RecordBatch::try_new(
                schema1.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec!["a", "b", "c"])),
                ],
            )
            .unwrap(),
        ];

        let schema2 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("phone", DataType::Utf8, false),
            Field::new("balance", DataType::Float64, false),
            Field::new(
                "company",
                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
                false,
            ),
        ]));
        let batches2 = vec![
            RecordBatch::try_new(
                schema2.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec![
                        "customer1",
                        "customer2",
                        "customer3",
                    ])),
                    Arc::new(StringArray::from(vec![
                        "13-123-4567",
                        "31-456-7890",
                        "23-789-0123",
                    ])),
                    Arc::new(datafusion::arrow::array::Float64Array::from(vec![
                        100.5, 250.0, 50.25,
                    ])),
                    Arc::new(
                        vec!["company1", "company1", "company1"]
                            .into_iter()
                            .collect::<arrow::array::DictionaryArray<UInt16Type>>(),
                    ),
                ],
            )
            .unwrap(),
        ];

        let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx)
            .await
            .unwrap();
        let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx)
            .await
            .unwrap();
        ctx
    }

    /// Drains partition 0 of `stage_exec`, asserting every batch carries the stream's schema.
    ///
    /// NOTE(review): only partition 0 is executed. If the plan exposes more than one output
    /// partition, the others are never driven — confirm that this is intended.
    async fn execute_plan(stage_exec: Arc<dyn ExecutionPlan>, ctx: &SessionContext) {
        let mut stream = stage_exec.execute(0, ctx.task_ctx()).unwrap();
        let schema = stream.schema();
        while let Some(batch) = stream.next().await {
            assert_eq!(schema, batch.unwrap().schema());
        }
    }

    /// Runs `sql` end to end, then checks that metrics were collected for every task.
    ///
    /// For each expected task key, the number of collected metric sets must equal the number
    /// of plan nodes in that task's stage, counted up to the network boundary.
    async fn run_metrics_collection_e2e_test(sql: &str) {
        let ctx = make_test_ctx().await;
        let df = ctx.sql(sql).await.unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        execute_plan(plan.clone(), &ctx).await;

        let dist_exec = plan
            .as_any()
            .downcast_ref::<DistributedExec>()
            .expect("expected DistributedExec");
        let (stages, expected_task_keys) = get_stages_and_task_keys(dist_exec);
        assert!(
            expected_task_keys.len() > 1,
            "expected more than 1 task key in test. the plan was not distributed:\n{}",
            DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
        );

        let collector = TaskMetricsCollector::new();
        let result = collector.collect(dist_exec.plan.clone()).unwrap();
        for expected_task_key in expected_task_keys {
            let actual_metrics = result.input_task_metrics.get(&expected_task_key).unwrap();
            let stage = stages.get(&(expected_task_key.stage_id as usize)).unwrap();
            let stage_plan = stage.plan.as_ref().unwrap();
            assert_eq!(
                actual_metrics.len(),
                count_plan_nodes_up_to_network_boundary(stage_plan),
                "Mismatch between collected metrics and actual nodes for {expected_task_key:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_metrics_collection_e2e_1() {
        run_metrics_collection_e2e_test(
            "SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10",
        )
        .await;
    }

    #[tokio::test]
    async fn test_metrics_collection_e2e_2() {
        run_metrics_collection_e2e_test(
            "SELECT sum(balance) / 7.0 as avg_yearly
FROM table2
WHERE name LIKE 'customer%'
AND balance < (
SELECT 0.2 * avg(balance)
FROM table2 t2_inner
WHERE t2_inner.id = table2.id
)",
        )
        .await;
    }

    #[tokio::test]
    async fn test_metrics_collection_e2e_3() {
        run_metrics_collection_e2e_test(
            "SELECT
substring(phone, 1, 2) as country_code,
count(*) as num_customers,
sum(balance) as total_balance
FROM table2
WHERE substring(phone, 1, 2) IN ('13', '31', '23', '29', '30', '18')
AND balance > (
SELECT avg(balance)
FROM table2
WHERE balance > 0.00
)
GROUP BY substring(phone, 1, 2)
ORDER BY country_code",
        )
        .await;
    }

    #[tokio::test]
    #[ignore]
    async fn test_metrics_collection_e2e_4() {
        run_metrics_collection_e2e_test("SELECT distinct company from table2").await;
    }

    /// A `LIMIT` may stop consuming the stream early. Metrics for every task must still be
    /// collected.
    #[tokio::test]
    #[ignore]
    async fn test_metrics_collection_with_limit_causing_early_stream_termination() {
        let ctx = make_test_ctx().await;
        register_parquet_tables(&ctx).await.unwrap();
        let sql =
            "SELECT \"FL_DATE\", COUNT(*) as cnt FROM flights_1m GROUP BY \"FL_DATE\" LIMIT 1";
        let df = ctx.sql(sql).await.unwrap();
        let plan = df.create_physical_plan().await.unwrap();

        let dist_exec = plan
            .as_any()
            .downcast_ref::<DistributedExec>()
            .expect("expected DistributedExec");
        let (stages, expected_task_keys) = get_stages_and_task_keys(dist_exec);
        assert!(
            expected_task_keys.len() > 1,
            "expected more than 1 task key. Plan was not distributed:\n{}",
            DisplayableExecutionPlan::new(plan.as_ref()).indent(true)
        );

        execute_plan(plan.clone(), &ctx).await;

        let collector = TaskMetricsCollector::new();
        let result = collector.collect(dist_exec.plan.clone()).unwrap();
        for expected_task_key in expected_task_keys {
            let actual_metrics = result
                .input_task_metrics
                .get(&expected_task_key)
                .unwrap_or_else(|| {
                    panic!(
                        "Missing metrics for task key {expected_task_key:?}. \
                         The LIMIT caused the stream to be dropped before the worker \
                         sent the last FlightData message with metrics."
                    )
                });
            let stage = stages.get(&(expected_task_key.stage_id as usize)).unwrap();
            let stage_plan = stage.plan.as_ref().unwrap();
            assert_eq!(
                actual_metrics.len(),
                count_plan_nodes_up_to_network_boundary(stage_plan),
                "Mismatch between collected metrics and actual nodes for {expected_task_key:?}"
            );
        }
    }

    /// Checks metrics collection when `distributed.children_isolator_unions` is enabled for a
    /// `UNION ALL` query.
    #[tokio::test]
    async fn test_metrics_collection_with_partition_isolator() {
        let ctx = make_test_ctx().await;
        ctx.sql("SET distributed.children_isolator_unions=true;")
            .await
            .unwrap();
        register_parquet_tables(&ctx).await.unwrap();
        let query = r#"
SELECT "MinTemp" FROM weather WHERE "RainToday" = 'yes'
UNION ALL
SELECT "MaxTemp" FROM weather WHERE "RainToday" = 'no'
"#;
        let df = ctx.sql(query).await.unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        execute_plan(plan.clone(), &ctx).await;

        let dist_exec = plan
            .as_any()
            .downcast_ref::<DistributedExec>()
            .expect("expected DistributedExec");
        let (stages, expected_task_keys) = get_stages_and_task_keys(dist_exec);

        let collector = TaskMetricsCollector::new();
        let result = collector.collect(dist_exec.plan.clone()).unwrap();
        for expected_task_key in expected_task_keys {
            let actual_metrics = result.input_task_metrics.get(&expected_task_key).unwrap();
            let stage = stages.get(&(expected_task_key.stage_id as usize)).unwrap();
            let stage_plan = stage.plan.as_ref().unwrap();
            assert_eq!(
                actual_metrics.len(),
                count_plan_nodes_up_to_network_boundary(stage_plan),
                "Metrics count must match plan nodes for stage {expected_task_key:?}"
            );
        }
    }
}