use crate::DistributedTaskContext;
use crate::common::{TreeNodeExt, serialize_uuid};
use crate::coordinator::{DistributedExec, MetricsStore};
use crate::distributed_planner::NetworkBoundaryExt;
use crate::execution_plans::MetricsWrapperExec;
use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
use crate::metrics::collect_plan_metrics;
use crate::metrics::proto::metrics_set_proto_to_df;
use crate::stage::{LocalStage, Stage};
use crate::worker::generated::worker::TaskKey;
use datafusion::common::HashMap;
use datafusion::common::plan_err;
use datafusion::common::tree_node::Transformed;
use datafusion::common::tree_node::TreeNode;
use datafusion::common::tree_node::TreeNodeRecursion;
use datafusion::error::Result;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::internal_err;
use datafusion::physical_plan::metrics::{Label, Metric, MetricsSet};
use std::sync::Arc;
use std::vec;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributedMetricsFormat {
Aggregated,
PerTask,
}
impl DistributedMetricsFormat {
pub(crate) fn to_rewrite_ctx(self, task_id: u64) -> RewriteCtx {
match self {
DistributedMetricsFormat::Aggregated => RewriteCtx::default(),
DistributedMetricsFormat::PerTask => RewriteCtx::from_task_id(task_id),
}
}
}
pub async fn rewrite_distributed_plan_with_metrics(
plan: Arc<dyn ExecutionPlan>,
format: DistributedMetricsFormat,
) -> Result<Arc<dyn ExecutionPlan>> {
let Some(distributed_exec) = plan.downcast_ref::<DistributedExec>() else {
return Ok(plan);
};
let prepared = distributed_exec.prepared_plan()?;
distributed_exec.wait_for_metrics().await;
let Some(metrics_collection) = distributed_exec.metrics_store.clone() else {
return Ok(plan);
};
let task_metrics = collect_plan_metrics(&prepared)?;
let dist_exec_plan_with_metrics = rewrite_local_plan_with_metrics(
format.to_rewrite_ctx(0), plan.children()[0].clone(),
task_metrics,
)?;
let plan = plan.with_new_children(vec![dist_exec_plan_with_metrics])?;
let transformed = plan.transform_down(|plan| {
let inner = plan
.downcast_ref::<MetricsWrapperExec>()
.map(|w| w.inner_arc())
.unwrap_or_else(|| Arc::clone(&plan));
if let Some(network_boundary) = inner.as_network_boundary() {
let Stage::Local(stage) = network_boundary.input_stage() else {
return plan_err!("Stage was not in Local state");
};
let plan_with_metrics =
stage_metrics_rewriter(stage, Arc::clone(&metrics_collection), format)?;
let network_boundary = network_boundary.with_input_stage(Stage::Local(LocalStage {
query_id: stage.query_id,
num: stage.num,
plan: plan_with_metrics,
tasks: stage.tasks,
}))?;
let network_boundary =
MetricsWrapperExec::new(network_boundary, plan.metrics().unwrap_or_default());
return Ok(Transformed::yes(Arc::new(network_boundary)));
}
Ok(Transformed::no(plan))
})?;
Ok(transformed.data)
}
#[derive(Default)]
pub struct RewriteCtx {
pub task_id: Option<u64>,
}
impl RewriteCtx {
pub(crate) fn from_task_id(task_id: u64) -> RewriteCtx {
RewriteCtx {
task_id: Some(task_id),
}
}
pub(crate) fn maybe_rewrite_node_metics(&self, node_metrics: MetricsSet) -> MetricsSet {
if let Some(task_id) = self.task_id {
return annotate_metrics_set_with_task_id(node_metrics, task_id);
}
node_metrics
}
}
pub fn annotate_metrics_set_with_task_id(metrics_set: MetricsSet, task_id: u64) -> MetricsSet {
let mut result = MetricsSet::new();
for metric in metrics_set.iter() {
let mut labels = metric.labels().to_vec();
labels.push(Label::new(
DISTRIBUTED_DATAFUSION_TASK_ID_LABEL,
task_id.to_string(),
));
result.push(Arc::new(Metric::new_with_labels(
metric.value().clone(),
metric.partition(),
labels,
)));
}
result
}
pub fn rewrite_local_plan_with_metrics(
ctx: RewriteCtx,
plan: Arc<dyn ExecutionPlan>,
metrics: Vec<MetricsSet>,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut idx = 0;
Ok(plan
.transform_down(|node| {
if idx >= metrics.len() {
return internal_err!("not enough metrics provided to rewrite plan");
}
let mut node_metrics = metrics[idx].clone();
node_metrics = ctx.maybe_rewrite_node_metics(node_metrics);
idx += 1;
Ok(Transformed::new(
Arc::new(MetricsWrapperExec::new(node.clone(), node_metrics)),
true,
if node.is_network_boundary() {
TreeNodeRecursion::Jump
} else {
TreeNodeRecursion::Continue
},
))
})?
.data)
}
pub fn stage_metrics_rewriter(
stage: &LocalStage,
metrics_collection: Arc<MetricsStore>,
format: DistributedMetricsFormat,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut node_metrics_map: HashMap<usize, MetricsSet> = HashMap::new();
for task_id in 0..stage.tasks {
let d_ctx = DistributedTaskContext {
task_index: task_id,
task_count: stage.tasks,
};
let task_key = TaskKey {
query_id: serialize_uuid(&stage.query_id),
stage_id: stage.num as u64,
task_number: task_id as u64,
};
let Some(task_metrics) = metrics_collection.get(&task_key) else {
return internal_err!(
"not enough metrics provided to rewrite task: missing metrics for task {} in stage {}",
task_id,
stage.num
);
};
let mut per_task_counter = 0usize;
stage.plan.apply_with_dt_ctx(d_ctx, |node, _ctx| {
if per_task_counter >= task_metrics.pre_order_plan_metrics.len() {
return internal_err!(
"not enough metrics provided to rewrite task: {} metrics provided",
task_metrics.pre_order_plan_metrics.len()
);
}
let node_metrics_protos = task_metrics.pre_order_plan_metrics[per_task_counter].clone();
let mut node_metrics = metrics_set_proto_to_df(&node_metrics_protos)?;
let rewrite_ctx = format.to_rewrite_ctx(task_id as u64);
node_metrics = rewrite_ctx.maybe_rewrite_node_metics(node_metrics);
let id = Arc::as_ptr(node) as *const () as usize;
let entry = node_metrics_map.entry(id).or_default();
for metric in node_metrics.iter().map(Arc::clone) {
entry.push(metric);
}
per_task_counter += 1;
Ok(TreeNodeRecursion::Continue)
})?;
}
Arc::clone(&stage.plan)
.transform_down(|plan| {
let id = Arc::as_ptr(&plan) as *const () as usize;
let metrics = node_metrics_map.remove(&id).unwrap_or_default();
Ok(Transformed::new(
Arc::new(MetricsWrapperExec::new(plan.clone(), metrics)),
true,
match plan.is_network_boundary() {
true => TreeNodeRecursion::Jump,
false => TreeNodeRecursion::Continue,
},
))
})
.map(|v| v.data)
}
#[cfg(test)]
mod tests {
use crate::coordinator::MetricsStore;
use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
use crate::metrics::proto::{df_metrics_set_to_proto, metrics_set_proto_to_df};
use crate::metrics::task_metrics_rewriter::{
annotate_metrics_set_with_task_id, stage_metrics_rewriter,
};
use crate::metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics};
use crate::test_utils::in_memory_channel_resolver::{
InMemoryChannelResolver, InMemoryWorkerResolver,
};
use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
use crate::test_utils::plans::count_plan_nodes_up_to_network_boundary;
use crate::test_utils::session_context::register_temp_parquet_table;
use crate::worker::generated::worker as pb;
use crate::{DistributedExec, SessionStateBuilderExt};
use datafusion::arrow::array::{Int32Array, StringArray};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::execution::SessionStateBuilder;
use datafusion::physical_plan::metrics::{Count, Label, Metric, MetricValue, MetricsSet};
use test_case::test_case;
use datafusion::physical_plan::{ExecutionPlan, collect};
use itertools::Itertools;
use uuid::Uuid;
use crate::DistributedExt;
use crate::common::serialize_uuid;
use crate::metrics::task_metrics_rewriter::MetricsWrapperExec;
use crate::stage::LocalStage;
use crate::worker::generated::worker::TaskKey;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::prelude::SessionConfig;
use datafusion::prelude::SessionContext;
use std::sync::Arc;
async fn make_test_ctx() -> SessionContext {
make_test_ctx_inner(false).await
}
async fn make_test_distributed_ctx() -> SessionContext {
make_test_ctx_inner(true).await
}
async fn make_test_ctx_inner(distributed: bool) -> SessionContext {
let config = SessionConfig::new().with_target_partitions(4);
let mut builder = SessionStateBuilder::new()
.with_default_features()
.with_config(config);
if distributed {
builder = builder
.with_distributed_worker_resolver(InMemoryWorkerResolver::new(10))
.with_distributed_channel_resolver(InMemoryChannelResolver::default())
.with_distributed_metrics_collection(true)
.unwrap()
.with_distributed_planner()
.with_distributed_task_estimator(2)
}
let state = builder.build();
let ctx = SessionContext::from(state);
let schema1 = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
]));
let batches1 = vec![
RecordBatch::try_new(
schema1.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec!["a", "b", "c"])),
],
)
.unwrap(),
];
let schema2 = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
Field::new("phone", DataType::Utf8, false),
Field::new("balance", DataType::Float64, false),
]));
let batches2 = vec![
RecordBatch::try_new(
schema2.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec![
"customer1",
"customer2",
"customer3",
])),
Arc::new(StringArray::from(vec![
"13-123-4567",
"31-456-7890",
"23-789-0123",
])),
Arc::new(datafusion::arrow::array::Float64Array::from(vec![
100.5, 250.0, 50.25,
])),
],
)
.unwrap(),
];
let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx)
.await
.unwrap();
let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx)
.await
.unwrap();
ctx
}
fn make_test_stage(plan: Arc<dyn ExecutionPlan>) -> LocalStage {
LocalStage {
query_id: Uuid::new_v4(),
num: 2,
plan,
tasks: 4,
}
}
fn collect_metrics_from_plan(plan: &Arc<dyn ExecutionPlan>, metrics: &mut Vec<MetricsSet>) {
metrics.extend(plan.metrics());
for child in plan.children() {
collect_metrics_from_plan(child, metrics);
}
}
fn metrics_set_eq(a: &MetricsSet, b: &MetricsSet) -> bool {
println!("a: {a:?}");
println!("b: {b:?}");
df_metrics_set_to_proto(a).unwrap() == df_metrics_set_to_proto(b).unwrap()
}
async fn run_stage_metrics_rewriter_test(sql: &str, format: DistributedMetricsFormat) {
let ctx = make_test_ctx().await;
let plan = ctx
.sql(sql)
.await
.unwrap()
.create_physical_plan()
.await
.unwrap();
let stage = make_test_stage(plan.clone());
let num_metrics_per_task_per_node = 4;
let metrics_collection = MetricsStore::from_entries((0..stage.tasks).map(|task_id| {
let task_key = TaskKey {
query_id: serialize_uuid(&stage.query_id),
stage_id: stage.num as u64,
task_number: task_id as u64,
};
let metrics = (0..count_plan_nodes_up_to_network_boundary(&plan))
.map(|node_id| {
make_test_metrics_set_proto_from_seed(
(node_id * task_id) as u64,
num_metrics_per_task_per_node,
)
})
.collect::<Vec<pb::MetricsSet>>();
let task_metrics = pb::TaskMetrics {
task_metrics: None,
pre_order_plan_metrics: metrics,
};
(task_key, task_metrics)
}));
let metrics_collection = Arc::new(metrics_collection);
let rewritten_plan =
stage_metrics_rewriter(&stage, metrics_collection.clone(), format).unwrap();
let mut actual_metrics = vec![];
collect_metrics_from_plan(&rewritten_plan, &mut actual_metrics);
assert_eq!(
actual_metrics.len(),
count_plan_nodes_up_to_network_boundary(&plan)
);
for (node_id, actual_stage_node_metrics_set) in actual_metrics.iter().enumerate() {
for (task_id, actual_task_node_metrics_set) in actual_stage_node_metrics_set
.iter()
.chunks(num_metrics_per_task_per_node)
.into_iter()
.enumerate()
{
let expected_task_node_metrics = metrics_collection
.get(&TaskKey {
query_id: serialize_uuid(&stage.query_id),
stage_id: stage.num as u64,
task_number: task_id as u64,
})
.unwrap()
.pre_order_plan_metrics[node_id]
.clone();
let mut actual_metrics_set = MetricsSet::new();
actual_task_node_metrics_set
.for_each(|metric| actual_metrics_set.push(metric.clone()));
let mut expected_metrics_set =
metrics_set_proto_to_df(&expected_task_node_metrics).unwrap();
if format == DistributedMetricsFormat::PerTask {
expected_metrics_set =
annotate_metrics_set_with_task_id(expected_metrics_set, task_id as u64);
}
assert!(metrics_set_eq(&actual_metrics_set, &expected_metrics_set));
}
}
}
#[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
#[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
#[tokio::test]
async fn test_stage_metrics_rewriter_1(format: DistributedMetricsFormat) {
run_stage_metrics_rewriter_test(
"SELECT sum(balance) / 7.0 as avg_yearly from table2 group by name",
format,
)
.await;
}
#[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
#[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
#[tokio::test]
async fn test_stage_metrics_rewriter_2(format: DistributedMetricsFormat) {
run_stage_metrics_rewriter_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10", format).await;
}
#[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
#[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
#[tokio::test]
async fn test_stage_metrics_rewriter_3(format: DistributedMetricsFormat) {
run_stage_metrics_rewriter_test(
"SELECT sum(balance) / 7.0 as avg_yearly
FROM table2
WHERE name LIKE 'customer%'
AND balance < (
SELECT 0.2 * avg(balance)
FROM table2 t2_inner
WHERE t2_inner.id = table2.id
)",
format,
)
.await;
}
#[tokio::test]
async fn test_rewrite_unexecuted_distributed_plan_with_metrics_err() {
let ctx = make_test_distributed_ctx().await;
let plan = ctx
.sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
.await
.unwrap()
.create_physical_plan()
.await
.unwrap();
assert!(plan.is::<DistributedExec>());
assert!(
rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
.await
.is_err()
);
}
fn assert_metrics_present_in_plan(plan: &Arc<dyn ExecutionPlan>) {
if let Some(metrics) = plan.metrics() {
assert!(metrics.iter().count() > 0);
} else {
assert!(plan.is::<DistributedExec>());
}
for child in plan.children() {
assert_metrics_present_in_plan(child);
}
}
#[tokio::test]
async fn test_executed_distributed_plan_has_metrics() {
let ctx = make_test_distributed_ctx().await;
let plan = ctx
.sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
.await
.unwrap()
.create_physical_plan()
.await
.unwrap();
collect(plan.clone(), ctx.task_ctx()).await.unwrap();
assert!(plan.is::<DistributedExec>());
let rewritten_plan =
rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
.await
.unwrap();
assert_metrics_present_in_plan(&rewritten_plan);
}
#[test]
fn test_wrapped_node_is_accessible() {
let example_node = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![Field::new(
"id",
DataType::Int32,
false,
)]))));
let wrapped = MetricsWrapperExec::new(example_node, MetricsSet::new());
assert_eq!(wrapped.name(), "EmptyExec");
assert!(wrapped.inner().is::<EmptyExec>());
}
#[test]
fn test_annotate_metrics_set_with_task_id_output_rows() {
let mut metrics_set = MetricsSet::new();
let count = Count::new();
count.add(1234);
let labels = vec![Label::new("operator", "scan")];
metrics_set.push(Arc::new(Metric::new_with_labels(
MetricValue::OutputRows(count),
Some(0),
labels,
)));
let task_id = 42;
let annotated = annotate_metrics_set_with_task_id(metrics_set, task_id);
assert_eq!(annotated.iter().count(), 1);
let metric = annotated.iter().next().unwrap();
match metric.value() {
MetricValue::OutputRows(count) => {
assert_eq!(count.value(), 1234);
}
other => panic!("Expected OutputRows, got {:?}", other.name()),
}
assert_eq!(metric.partition(), Some(0));
let labels: Vec<_> = metric.labels().iter().collect();
assert_eq!(labels.len(), 2);
assert_eq!(labels[0].name(), "operator");
assert_eq!(labels[0].value(), "scan");
assert_eq!(labels[1].name(), DISTRIBUTED_DATAFUSION_TASK_ID_LABEL);
assert_eq!(labels[1].value(), "42");
}
}