1use crate::DistributedTaskContext;
2use crate::common::{TreeNodeExt, serialize_uuid};
3use crate::coordinator::{DistributedExec, MetricsStore};
4use crate::distributed_planner::NetworkBoundaryExt;
5use crate::execution_plans::MetricsWrapperExec;
6use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
7use crate::metrics::collect_plan_metrics;
8use crate::metrics::proto::metrics_set_proto_to_df;
9use crate::stage::{LocalStage, Stage};
10use crate::worker::generated::worker::TaskKey;
11use datafusion::common::HashMap;
12use datafusion::common::plan_err;
13use datafusion::common::tree_node::Transformed;
14use datafusion::common::tree_node::TreeNode;
15use datafusion::common::tree_node::TreeNodeRecursion;
16use datafusion::error::Result;
17use datafusion::physical_plan::ExecutionPlan;
18use datafusion::physical_plan::internal_err;
19use datafusion::physical_plan::metrics::{Label, Metric, MetricsSet};
20use std::sync::Arc;
21use std::vec;
22
/// Selects how metrics are labeled when they are attached to a distributed plan.
///
/// The variant decides which [`RewriteCtx`] is produced by
/// `DistributedMetricsFormat::to_rewrite_ctx`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributedMetricsFormat {
    /// Metrics are attached unchanged: this maps to the default [`RewriteCtx`], which carries no
    /// task id, so no task label is added.
    Aggregated,

    /// Every metric is annotated with a task id label: this maps to
    /// [`RewriteCtx::from_task_id`].
    PerTask,
}
32
33impl DistributedMetricsFormat {
34 pub(crate) fn to_rewrite_ctx(self, task_id: u64) -> RewriteCtx {
35 match self {
36 DistributedMetricsFormat::Aggregated => RewriteCtx::default(),
37 DistributedMetricsFormat::PerTask => RewriteCtx::from_task_id(task_id),
38 }
39 }
40}
41
42pub async fn rewrite_distributed_plan_with_metrics(
47 plan: Arc<dyn ExecutionPlan>,
48 format: DistributedMetricsFormat,
49) -> Result<Arc<dyn ExecutionPlan>> {
50 let Some(distributed_exec) = plan.downcast_ref::<DistributedExec>() else {
51 return Ok(plan);
52 };
53
54 let prepared = distributed_exec.prepared_plan()?;
57
58 distributed_exec.wait_for_metrics().await;
59
60 let Some(metrics_collection) = distributed_exec.metrics_store.clone() else {
61 return Ok(plan);
62 };
63
64 let task_metrics = collect_plan_metrics(&prepared)?;
65
66 let dist_exec_plan_with_metrics = rewrite_local_plan_with_metrics(
68 format.to_rewrite_ctx(0), plan.children()[0].clone(),
70 task_metrics,
71 )?;
72 let plan = plan.with_new_children(vec![dist_exec_plan_with_metrics])?;
73
74 let transformed = plan.transform_down(|plan| {
75 let inner = plan
79 .downcast_ref::<MetricsWrapperExec>()
80 .map(|w| w.inner_arc())
81 .unwrap_or_else(|| Arc::clone(&plan));
82
83 if let Some(network_boundary) = inner.as_network_boundary() {
85 let Stage::Local(stage) = network_boundary.input_stage() else {
86 return plan_err!("Stage was not in Local state");
87 };
88 let plan_with_metrics =
91 stage_metrics_rewriter(stage, Arc::clone(&metrics_collection), format)?;
92 let network_boundary = network_boundary.with_input_stage(Stage::Local(LocalStage {
93 query_id: stage.query_id,
94 num: stage.num,
95 plan: plan_with_metrics,
96 tasks: stage.tasks,
97 }))?;
98 let network_boundary =
99 MetricsWrapperExec::new(network_boundary, plan.metrics().unwrap_or_default());
100 return Ok(Transformed::yes(Arc::new(network_boundary)));
101 }
102
103 Ok(Transformed::no(plan))
104 })?;
105 Ok(transformed.data)
106}
107
/// Context controlling how node metrics are rewritten before being attached to a plan.
///
/// The default value carries no task id.
#[derive(Default)]
pub struct RewriteCtx {
    /// When `Some`, metrics passed through `RewriteCtx::maybe_rewrite_node_metics` are annotated
    /// with this task id; when `None`, they are returned unchanged.
    pub task_id: Option<u64>,
}
114
115impl RewriteCtx {
116 pub(crate) fn from_task_id(task_id: u64) -> RewriteCtx {
117 RewriteCtx {
118 task_id: Some(task_id),
119 }
120 }
121
122 pub(crate) fn maybe_rewrite_node_metics(&self, node_metrics: MetricsSet) -> MetricsSet {
124 if let Some(task_id) = self.task_id {
125 return annotate_metrics_set_with_task_id(node_metrics, task_id);
126 }
127 node_metrics
128 }
129}
130
131pub fn annotate_metrics_set_with_task_id(metrics_set: MetricsSet, task_id: u64) -> MetricsSet {
137 let mut result = MetricsSet::new();
138
139 for metric in metrics_set.iter() {
140 let mut labels = metric.labels().to_vec();
141 labels.push(Label::new(
142 DISTRIBUTED_DATAFUSION_TASK_ID_LABEL,
143 task_id.to_string(),
144 ));
145 result.push(Arc::new(Metric::new_with_labels(
146 metric.value().clone(),
147 metric.partition(),
148 labels,
149 )));
150 }
151
152 result
153}
154
155pub fn rewrite_local_plan_with_metrics(
169 ctx: RewriteCtx,
170 plan: Arc<dyn ExecutionPlan>,
171 metrics: Vec<MetricsSet>,
172) -> Result<Arc<dyn ExecutionPlan>> {
173 let mut idx = 0;
174 Ok(plan
175 .transform_down(|node| {
176 if idx >= metrics.len() {
177 return internal_err!("not enough metrics provided to rewrite plan");
178 }
179 let mut node_metrics = metrics[idx].clone();
180
181 node_metrics = ctx.maybe_rewrite_node_metics(node_metrics);
182
183 idx += 1;
184 Ok(Transformed::new(
185 Arc::new(MetricsWrapperExec::new(node.clone(), node_metrics)),
186 true,
187 if node.is_network_boundary() {
188 TreeNodeRecursion::Jump
189 } else {
190 TreeNodeRecursion::Continue
191 },
192 ))
193 })?
194 .data)
195}
196
197pub fn stage_metrics_rewriter(
222 stage: &LocalStage,
223 metrics_collection: Arc<MetricsStore>,
224 format: DistributedMetricsFormat,
225) -> Result<Arc<dyn ExecutionPlan>> {
226 let mut node_metrics_map: HashMap<usize, MetricsSet> = HashMap::new();
235
236 for task_id in 0..stage.tasks {
237 let d_ctx = DistributedTaskContext {
238 task_index: task_id,
239 task_count: stage.tasks,
240 };
241 let task_key = TaskKey {
242 query_id: serialize_uuid(&stage.query_id),
243 stage_id: stage.num as u64,
244 task_number: task_id as u64,
245 };
246 let Some(task_metrics) = metrics_collection.get(&task_key) else {
247 return internal_err!(
248 "not enough metrics provided to rewrite task: missing metrics for task {} in stage {}",
249 task_id,
250 stage.num
251 );
252 };
253
254 let mut per_task_counter = 0usize;
255 stage.plan.apply_with_dt_ctx(d_ctx, |node, _ctx| {
256 if per_task_counter >= task_metrics.pre_order_plan_metrics.len() {
257 return internal_err!(
258 "not enough metrics provided to rewrite task: {} metrics provided",
259 task_metrics.pre_order_plan_metrics.len()
260 );
261 }
262
263 let node_metrics_protos = task_metrics.pre_order_plan_metrics[per_task_counter].clone();
264 let mut node_metrics = metrics_set_proto_to_df(&node_metrics_protos)?;
265 let rewrite_ctx = format.to_rewrite_ctx(task_id as u64);
266 node_metrics = rewrite_ctx.maybe_rewrite_node_metics(node_metrics);
267
268 let id = Arc::as_ptr(node) as *const () as usize;
269 let entry = node_metrics_map.entry(id).or_default();
270 for metric in node_metrics.iter().map(Arc::clone) {
271 entry.push(metric);
272 }
273
274 per_task_counter += 1;
275 Ok(TreeNodeRecursion::Continue)
276 })?;
277 }
278
279 Arc::clone(&stage.plan)
282 .transform_down(|plan| {
283 let id = Arc::as_ptr(&plan) as *const () as usize;
284 let metrics = node_metrics_map.remove(&id).unwrap_or_default();
285 Ok(Transformed::new(
286 Arc::new(MetricsWrapperExec::new(plan.clone(), metrics)),
287 true,
288 match plan.is_network_boundary() {
289 true => TreeNodeRecursion::Jump,
290 false => TreeNodeRecursion::Continue,
291 },
292 ))
293 })
294 .map(|v| v.data)
295}
296
// Tests for the metrics rewriting helpers defined in this file.
#[cfg(test)]
mod tests {
    use crate::coordinator::MetricsStore;
    use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
    use crate::metrics::proto::{df_metrics_set_to_proto, metrics_set_proto_to_df};
    use crate::metrics::task_metrics_rewriter::{
        annotate_metrics_set_with_task_id, stage_metrics_rewriter,
    };
    use crate::metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics};
    use crate::test_utils::in_memory_channel_resolver::{
        InMemoryChannelResolver, InMemoryWorkerResolver,
    };
    use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
    use crate::test_utils::plans::count_plan_nodes_up_to_network_boundary;
    use crate::test_utils::session_context::register_temp_parquet_table;
    use crate::worker::generated::worker as pb;
    use crate::{DistributedExec, SessionStateBuilderExt};
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::physical_plan::metrics::{Count, Label, Metric, MetricValue, MetricsSet};
    use test_case::test_case;

    use datafusion::physical_plan::{ExecutionPlan, collect};
    use itertools::Itertools;
    use uuid::Uuid;

    use crate::DistributedExt;
    use crate::common::serialize_uuid;
    use crate::metrics::task_metrics_rewriter::MetricsWrapperExec;
    use crate::stage::LocalStage;
    use crate::worker::generated::worker::TaskKey;
    use datafusion::physical_plan::empty::EmptyExec;
    use datafusion::prelude::SessionConfig;
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Builds a session context without the distributed extensions.
    async fn make_test_ctx() -> SessionContext {
        make_test_ctx_inner(false).await
    }

    /// Builds a session context with the distributed extensions enabled.
    async fn make_test_distributed_ctx() -> SessionContext {
        make_test_ctx_inner(true).await
    }

    /// Builds a session context with 4 target partitions and registers two tables, `table1`
    /// (`id`, `name`) and `table2` (`id`, `name`, `phone`, `balance`), each with three rows.
    ///
    /// When `distributed` is true, the in-memory worker/channel resolvers, metrics collection,
    /// the distributed planner and a task estimator are configured on the session state.
    async fn make_test_ctx_inner(distributed: bool) -> SessionContext {
        let config = SessionConfig::new().with_target_partitions(4);
        let mut builder = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config);

        if distributed {
            builder = builder
                .with_distributed_worker_resolver(InMemoryWorkerResolver::new(10))
                .with_distributed_channel_resolver(InMemoryChannelResolver::default())
                .with_distributed_metrics_collection(true)
                .unwrap()
                .with_distributed_planner()
                .with_distributed_task_estimator(2)
        }

        let state = builder.build();
        let ctx = SessionContext::from(state);

        // Schema and data for `table1`.
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let batches1 = vec![
            RecordBatch::try_new(
                schema1.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec!["a", "b", "c"])),
                ],
            )
            .unwrap(),
        ];

        // Schema and data for `table2`.
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("phone", DataType::Utf8, false),
            Field::new("balance", DataType::Float64, false),
        ]));

        let batches2 = vec![
            RecordBatch::try_new(
                schema2.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec![
                        "customer1",
                        "customer2",
                        "customer3",
                    ])),
                    Arc::new(StringArray::from(vec![
                        "13-123-4567",
                        "31-456-7890",
                        "23-789-0123",
                    ])),
                    Arc::new(datafusion::arrow::array::Float64Array::from(vec![
                        100.5, 250.0, 50.25,
                    ])),
                ],
            )
            .unwrap(),
        ];

        // The values returned by the registration helper are intentionally discarded.
        let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx)
            .await
            .unwrap();

        let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx)
            .await
            .unwrap();

        ctx
    }

    /// Wraps `plan` in a [`LocalStage`] with a random query id, stage number 2 and 4 tasks.
    fn make_test_stage(plan: Arc<dyn ExecutionPlan>) -> LocalStage {
        LocalStage {
            query_id: Uuid::new_v4(),
            num: 2,
            plan,
            tasks: 4,
        }
    }

    /// Recursively appends the metrics of `plan` and then of its children to `metrics`.
    /// Nodes whose `metrics()` is `None` contribute nothing.
    fn collect_metrics_from_plan(plan: &Arc<dyn ExecutionPlan>, metrics: &mut Vec<MetricsSet>) {
        metrics.extend(plan.metrics());
        for child in plan.children() {
            collect_metrics_from_plan(child, metrics);
        }
    }

    /// Compares two metrics sets through their protobuf representation, printing both sets
    /// first to help diagnose failures.
    fn metrics_set_eq(a: &MetricsSet, b: &MetricsSet) -> bool {
        println!("a: {a:?}");
        println!("b: {b:?}");
        df_metrics_set_to_proto(a).unwrap() == df_metrics_set_to_proto(b).unwrap()
    }

    /// Plans `sql` on a non-distributed context, fabricates metrics for every task and node of
    /// a test stage built from that plan, runs [`stage_metrics_rewriter`] and checks that each
    /// node of the rewritten plan holds the expected metrics for every task, in task order.
    async fn run_stage_metrics_rewriter_test(sql: &str, format: DistributedMetricsFormat) {
        let ctx = make_test_ctx().await;
        let plan = ctx
            .sql(sql)
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();

        let stage = make_test_stage(plan.clone());

        let num_metrics_per_task_per_node = 4;

        // One entry per task; each entry holds one seeded metrics set per plan node.
        let metrics_collection = MetricsStore::from_entries((0..stage.tasks).map(|task_id| {
            let task_key = TaskKey {
                query_id: serialize_uuid(&stage.query_id),
                stage_id: stage.num as u64,
                task_number: task_id as u64,
            };
            let metrics = (0..count_plan_nodes_up_to_network_boundary(&plan))
                .map(|node_id| {
                    make_test_metrics_set_proto_from_seed(
                        (node_id * task_id) as u64,
                        num_metrics_per_task_per_node,
                    )
                })
                .collect::<Vec<pb::MetricsSet>>();
            let task_metrics = pb::TaskMetrics {
                task_metrics: None,
                pre_order_plan_metrics: metrics,
            };
            (task_key, task_metrics)
        }));
        let metrics_collection = Arc::new(metrics_collection);

        let rewritten_plan =
            stage_metrics_rewriter(&stage, metrics_collection.clone(), format).unwrap();

        // Every node up to the network boundary must have produced a metrics set.
        let mut actual_metrics = vec![];
        collect_metrics_from_plan(&rewritten_plan, &mut actual_metrics);
        assert_eq!(
            actual_metrics.len(),
            count_plan_nodes_up_to_network_boundary(&plan)
        );

        for (node_id, actual_stage_node_metrics_set) in actual_metrics.iter().enumerate() {
            // Each node holds the metrics of all tasks back to back, so consecutive chunks of
            // `num_metrics_per_task_per_node` metrics correspond to consecutive task ids.
            for (task_id, actual_task_node_metrics_set) in actual_stage_node_metrics_set
                .iter()
                .chunks(num_metrics_per_task_per_node)
                .into_iter()
                .enumerate()
            {
                let expected_task_node_metrics = metrics_collection
                    .get(&TaskKey {
                        query_id: serialize_uuid(&stage.query_id),
                        stage_id: stage.num as u64,
                        task_number: task_id as u64,
                    })
                    .unwrap()
                    .pre_order_plan_metrics[node_id]
                    .clone();

                let mut actual_metrics_set = MetricsSet::new();
                actual_task_node_metrics_set
                    .for_each(|metric| actual_metrics_set.push(metric.clone()));

                let mut expected_metrics_set =
                    metrics_set_proto_to_df(&expected_task_node_metrics).unwrap();

                // In per-task mode the rewriter is expected to have added the task id label.
                if format == DistributedMetricsFormat::PerTask {
                    expected_metrics_set =
                        annotate_metrics_set_with_task_id(expected_metrics_set, task_id as u64);
                }
                assert!(metrics_set_eq(&actual_metrics_set, &expected_metrics_set));
            }
        }
    }

    // Stage rewriting on an aggregation with GROUP BY, in both formats.
    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
    #[tokio::test]
    async fn test_stage_metrics_rewriter_1(format: DistributedMetricsFormat) {
        run_stage_metrics_rewriter_test(
            "SELECT sum(balance) / 7.0 as avg_yearly from table2 group by name",
            format,
        )
        .await;
    }

    // Stage rewriting on a filter + GROUP BY + ORDER BY + LIMIT query, in both formats.
    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
    #[tokio::test]
    async fn test_stage_metrics_rewriter_2(format: DistributedMetricsFormat) {
        run_stage_metrics_rewriter_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10", format).await;
    }

    // Stage rewriting on a query with a correlated subquery, in both formats.
    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
    #[tokio::test]
    async fn test_stage_metrics_rewriter_3(format: DistributedMetricsFormat) {
        run_stage_metrics_rewriter_test(
            "SELECT sum(balance) / 7.0 as avg_yearly
 FROM table2
 WHERE name LIKE 'customer%'
 AND balance < (
 SELECT 0.2 * avg(balance)
 FROM table2 t2_inner
 WHERE t2_inner.id = table2.id
 )",
            format,
        )
        .await;
    }

    // Rewriting a distributed plan that was never executed must return an error.
    #[tokio::test]
    async fn test_rewrite_unexecuted_distributed_plan_with_metrics_err() {
        let ctx = make_test_distributed_ctx().await;
        let plan = ctx
            .sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();
        assert!(plan.is::<DistributedExec>());
        assert!(
            rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
                .await
                .is_err()
        );
    }

    /// Recursively asserts that every node reporting metrics has at least one metric; a node
    /// without metrics is only accepted if it is the [`DistributedExec`] node.
    fn assert_metrics_present_in_plan(plan: &Arc<dyn ExecutionPlan>) {
        if let Some(metrics) = plan.metrics() {
            assert!(metrics.iter().count() > 0);
        } else {
            assert!(plan.is::<DistributedExec>());
        }
        for child in plan.children() {
            assert_metrics_present_in_plan(child);
        }
    }

    // After executing a distributed plan, the rewritten plan must expose metrics on its nodes.
    #[tokio::test]
    async fn test_executed_distributed_plan_has_metrics() {
        let ctx = make_test_distributed_ctx().await;
        let plan = ctx
            .sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();
        collect(plan.clone(), ctx.task_ctx()).await.unwrap();
        assert!(plan.is::<DistributedExec>());
        let rewritten_plan =
            rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
                .await
                .unwrap();
        assert_metrics_present_in_plan(&rewritten_plan);
    }

    // A `MetricsWrapperExec` reports the wrapped node's name and exposes the wrapped node.
    #[test]
    fn test_wrapped_node_is_accessible() {
        let example_node = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![Field::new(
            "id",
            DataType::Int32,
            false,
        )]))));

        let wrapped = MetricsWrapperExec::new(example_node, MetricsSet::new());
        assert_eq!(wrapped.name(), "EmptyExec");
        assert!(wrapped.inner().is::<EmptyExec>());
    }

    // Annotating a metrics set keeps value, partition and existing labels, and appends the
    // task id label at the end.
    #[test]
    fn test_annotate_metrics_set_with_task_id_output_rows() {
        let mut metrics_set = MetricsSet::new();
        let count = Count::new();
        count.add(1234);
        let labels = vec![Label::new("operator", "scan")];
        metrics_set.push(Arc::new(Metric::new_with_labels(
            MetricValue::OutputRows(count),
            Some(0),
            labels,
        )));

        let task_id = 42;
        let annotated = annotate_metrics_set_with_task_id(metrics_set, task_id);

        assert_eq!(annotated.iter().count(), 1);

        let metric = annotated.iter().next().unwrap();

        // The metric value must be preserved.
        match metric.value() {
            MetricValue::OutputRows(count) => {
                assert_eq!(count.value(), 1234);
            }
            other => panic!("Expected OutputRows, got {:?}", other.name()),
        }

        // The partition must be preserved.
        assert_eq!(metric.partition(), Some(0));

        // The original label comes first, followed by the task id label.
        let labels: Vec<_> = metric.labels().iter().collect();
        assert_eq!(labels.len(), 2);
        assert_eq!(labels[0].name(), "operator");
        assert_eq!(labels[0].value(), "scan");
        assert_eq!(labels[1].name(), DISTRIBUTED_DATAFUSION_TASK_ID_LABEL);
        assert_eq!(labels[1].value(), "42");
    }
}