1use crate::distributed_planner::NetworkBoundaryExt;
2use crate::execution_plans::DistributedExec;
3use crate::execution_plans::MetricsWrapperExec;
4use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
5use crate::metrics::MetricsCollectorResult;
6use crate::metrics::TaskMetricsCollector;
7use crate::metrics::proto::metrics_set_proto_to_df;
8use crate::stage::Stage;
9use crate::worker::generated::worker as pb;
10use crate::worker::generated::worker::TaskKey;
11use datafusion::common::HashMap;
12use datafusion::common::tree_node::Transformed;
13use datafusion::common::tree_node::TreeNode;
14use datafusion::common::tree_node::TreeNodeRecursion;
15use datafusion::error::Result;
16use datafusion::physical_plan::ExecutionPlan;
17use datafusion::physical_plan::internal_err;
18use datafusion::physical_plan::metrics::{Label, Metric, MetricsSet};
19use std::sync::Arc;
20use std::vec;
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq)]
24pub enum DistributedMetricsFormat {
25 Aggregated,
27
28 PerTask,
30}
31
32impl DistributedMetricsFormat {
33 pub(crate) fn to_rewrite_ctx(self, task_id: u64) -> RewriteCtx {
34 match self {
35 DistributedMetricsFormat::Aggregated => RewriteCtx::default(),
36 DistributedMetricsFormat::PerTask => RewriteCtx::from_task_id(task_id),
37 }
38 }
39}
40
41pub fn rewrite_distributed_plan_with_metrics(
44 plan: Arc<dyn ExecutionPlan>,
45 format: DistributedMetricsFormat,
46) -> Result<Arc<dyn ExecutionPlan>> {
47 let Some(distributed_exec) = plan.as_any().downcast_ref::<DistributedExec>() else {
48 return Ok(plan);
49 };
50
51 let MetricsCollectorResult {
53 task_metrics, input_task_metrics, } = TaskMetricsCollector::new().collect(distributed_exec.prepared_plan()?)?;
56
57 let dist_exec_plan_with_metrics = rewrite_local_plan_with_metrics(
59 format.to_rewrite_ctx(0), plan.children()[0].clone(),
61 task_metrics,
62 )?;
63 let plan = plan.with_new_children(vec![dist_exec_plan_with_metrics])?;
64
65 let metrics_collection = Arc::new(input_task_metrics);
66
67 let transformed = plan.transform_down(|plan| {
68 if let Some(network_boundary) = plan.as_network_boundary() {
70 let stage = network_boundary.input_stage();
71 let plan_with_metrics =
74 stage_metrics_rewriter(stage, metrics_collection.clone(), format)?;
75 let network_boundary = network_boundary.with_input_stage(Stage::new(
76 stage.query_id,
77 stage.num,
78 plan_with_metrics,
79 stage.tasks.len(),
80 ))?;
81 let network_boundary =
82 MetricsWrapperExec::new(network_boundary, plan.metrics().unwrap_or_default());
83 return Ok(Transformed::yes(Arc::new(network_boundary)));
84 }
85
86 Ok(Transformed::no(plan))
87 })?;
88 Ok(transformed.data)
89}
90
91#[derive(Default)]
93pub struct RewriteCtx {
94 pub task_id: Option<u64>,
96}
97
98impl RewriteCtx {
99 pub(crate) fn from_task_id(task_id: u64) -> RewriteCtx {
100 RewriteCtx {
101 task_id: Some(task_id),
102 }
103 }
104
105 pub(crate) fn maybe_rewrite_node_metics(&self, node_metrics: MetricsSet) -> MetricsSet {
107 if let Some(task_id) = self.task_id {
108 return annotate_metrics_set_with_task_id(node_metrics, task_id);
109 }
110 node_metrics
111 }
112}
113
114pub fn annotate_metrics_set_with_task_id(metrics_set: MetricsSet, task_id: u64) -> MetricsSet {
120 let mut result = MetricsSet::new();
121
122 for metric in metrics_set.iter() {
123 let mut labels = metric.labels().to_vec();
124 labels.push(Label::new(
125 DISTRIBUTED_DATAFUSION_TASK_ID_LABEL,
126 task_id.to_string(),
127 ));
128 result.push(Arc::new(Metric::new_with_labels(
129 metric.value().clone(),
130 metric.partition(),
131 labels,
132 )));
133 }
134
135 result
136}
137
138pub fn rewrite_local_plan_with_metrics(
152 ctx: RewriteCtx,
153 plan: Arc<dyn ExecutionPlan>,
154 metrics: Vec<MetricsSet>,
155) -> Result<Arc<dyn ExecutionPlan>> {
156 let mut idx = 0;
157 Ok(plan
158 .transform_down(|node| {
159 if idx >= metrics.len() {
160 return internal_err!("not enough metrics provided to rewrite plan");
161 }
162 let mut node_metrics = metrics[idx].clone();
163
164 node_metrics = ctx.maybe_rewrite_node_metics(node_metrics);
165
166 idx += 1;
167 Ok(Transformed::new(
168 Arc::new(MetricsWrapperExec::new(node.clone(), node_metrics)),
169 true,
170 if node.is_network_boundary() {
171 TreeNodeRecursion::Jump
172 } else {
173 TreeNodeRecursion::Continue
174 },
175 ))
176 })?
177 .data)
178}
179
180pub fn stage_metrics_rewriter(
205 stage: &Stage,
206 metrics_collection: Arc<HashMap<TaskKey, Vec<pb::MetricsSet>>>,
207 format: DistributedMetricsFormat,
208) -> Result<Arc<dyn ExecutionPlan>> {
209 let mut node_idx = 0;
210
211 let Some(plan) = &stage.plan else {
212 return internal_err!("The inner plan of a stage was not present");
213 };
214
215 plan.clone().transform_down(|plan| {
216 let mut stage_metrics = MetricsSet::new();
218
219 for task_id in 0..stage.tasks.len() {
220 let task_key = TaskKey {
221 query_id: stage.query_id.as_bytes().to_vec(),
222 stage_id: stage.num as u64,
223 task_number: task_id as u64,
224 };
225 match metrics_collection.get(&task_key) {
226 Some(task_metrics) => {
227 if node_idx >= task_metrics.len() {
228 return internal_err!(
229 "not enough metrics provided to rewrite task: {} metrics provided",
230 task_metrics.len()
231 );
232 }
233 let node_metrics_protos = task_metrics[node_idx].clone();
234 let mut node_metrics = metrics_set_proto_to_df(&node_metrics_protos)?;
235
236 let rewrite_ctx = format.to_rewrite_ctx(task_id as u64);
237 node_metrics = rewrite_ctx.maybe_rewrite_node_metics(node_metrics);
238
239 for metric in node_metrics.iter().map(Arc::clone) {
240 stage_metrics.push(metric);
241 }
242 }
243 None => {
244 return internal_err!(
245 "not enough metrics provided to rewrite task: missing metrics for task {} in stage {}",
246 task_id,
247 stage.num
248 );
249 }
250 }
251 }
252
253 node_idx += 1;
254
255 let wrapped_plan_node: Arc<dyn ExecutionPlan> = Arc::new(MetricsWrapperExec::new(
256 plan.clone(),
257 stage_metrics,
258 ));
259 Ok(Transformed::new(
260 wrapped_plan_node,
261 true,
262 if plan.is_network_boundary() {
263 TreeNodeRecursion::Jump
264 } else {
265 TreeNodeRecursion::Continue
266 }
267 ))
268 }).map(|v| v.data)
269}
270
#[cfg(test)]
mod tests {
    use crate::Stage;
    use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
    use crate::metrics::proto::{df_metrics_set_to_proto, metrics_set_proto_to_df};
    use crate::metrics::task_metrics_rewriter::{
        annotate_metrics_set_with_task_id, stage_metrics_rewriter,
    };
    use crate::metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics};
    use crate::test_utils::in_memory_channel_resolver::{
        InMemoryChannelResolver, InMemoryWorkerResolver,
    };
    use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
    use crate::test_utils::plans::count_plan_nodes_up_to_network_boundary;
    use crate::test_utils::session_context::register_temp_parquet_table;
    use crate::worker::generated::worker as pb;
    use crate::{DistributedExec, DistributedPhysicalOptimizerRule};
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::common::HashMap;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::physical_plan::metrics::{Count, Label, Metric, MetricValue, MetricsSet};
    use test_case::test_case;

    use datafusion::physical_plan::{ExecutionPlan, collect};
    use itertools::Itertools;
    use uuid::Uuid;

    use crate::DistributedExt;
    use crate::metrics::task_metrics_rewriter::MetricsWrapperExec;
    use crate::worker::generated::worker::TaskKey;
    use datafusion::physical_plan::empty::EmptyExec;
    use datafusion::prelude::SessionConfig;
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Session with `table1`/`table2` registered and no distributed planning.
    async fn make_test_ctx() -> SessionContext {
        make_test_ctx_inner(false).await
    }

    /// Session with `table1`/`table2` registered, distributed planning enabled
    /// over in-memory workers, and metrics collection turned on.
    async fn make_test_distributed_ctx() -> SessionContext {
        make_test_ctx_inner(true).await
    }

    async fn make_test_ctx_inner(distributed: bool) -> SessionContext {
        let config = SessionConfig::new().with_target_partitions(4);
        let mut builder = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config);

        if distributed {
            builder = builder
                .with_distributed_worker_resolver(InMemoryWorkerResolver::new(10))
                .with_distributed_channel_resolver(InMemoryChannelResolver::default())
                .with_distributed_metrics_collection(true)
                .unwrap()
                .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
                .with_distributed_task_estimator(2)
        }

        let state = builder.build();
        let ctx = SessionContext::from(state);

        let schema1 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let batches1 = vec![
            RecordBatch::try_new(
                schema1.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec!["a", "b", "c"])),
                ],
            )
            .unwrap(),
        ];

        let schema2 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("phone", DataType::Utf8, false),
            Field::new("balance", DataType::Float64, false),
        ]));

        let batches2 = vec![
            RecordBatch::try_new(
                schema2.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec![
                        "customer1",
                        "customer2",
                        "customer3",
                    ])),
                    Arc::new(StringArray::from(vec![
                        "13-123-4567",
                        "31-456-7890",
                        "23-789-0123",
                    ])),
                    Arc::new(datafusion::arrow::array::Float64Array::from(vec![
                        100.5, 250.0, 50.25,
                    ])),
                ],
            )
            .unwrap(),
        ];

        let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx)
            .await
            .unwrap();

        let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx)
            .await
            .unwrap();

        ctx
    }

    /// Wraps `plan` in a stage with a fresh query id, stage number 2 and 4 tasks.
    fn make_test_stage(plan: Arc<dyn ExecutionPlan>) -> Stage {
        Stage::new(Uuid::new_v4(), 2, plan, 4)
    }

    /// Appends the metrics of every node in `plan`, in pre-order, to `metrics`.
    fn collect_metrics_from_plan(plan: &Arc<dyn ExecutionPlan>, metrics: &mut Vec<MetricsSet>) {
        metrics.extend(plan.metrics());
        for child in plan.children() {
            collect_metrics_from_plan(child, metrics);
        }
    }

    /// Compares two metric sets through their protobuf encoding.
    fn metrics_set_eq(a: &MetricsSet, b: &MetricsSet) -> bool {
        df_metrics_set_to_proto(a).unwrap() == df_metrics_set_to_proto(b).unwrap()
    }

    /// Builds synthetic per-task metrics for `sql`'s plan, rewrites the plan
    /// with `stage_metrics_rewriter`, and checks every node carries exactly
    /// the metrics of every task, in task order.
    async fn run_stage_metrics_rewriter_test(sql: &str, format: DistributedMetricsFormat) {
        let ctx = make_test_ctx().await;
        let plan = ctx
            .sql(sql)
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();

        let stage = make_test_stage(plan.clone());
        let num_tasks = stage.tasks.len();
        let num_nodes = count_plan_nodes_up_to_network_boundary(&plan);
        let num_metrics_per_task_per_node = 4;

        let task_key = |task_id: usize| TaskKey {
            query_id: stage.query_id.as_bytes().to_vec(),
            stage_id: stage.num as u64,
            task_number: task_id as u64,
        };

        // Every (node, task) pair gets its own seed. A product-based seed would
        // give all nodes of task 0 (and all tasks of node 0) identical metrics,
        // hiding node/task ordering bugs.
        let mut metrics_collection = HashMap::new();
        for task_id in 0..num_tasks {
            let metrics = (0..num_nodes)
                .map(|node_id| {
                    make_test_metrics_set_proto_from_seed(
                        (node_id * num_tasks + task_id) as u64,
                        num_metrics_per_task_per_node,
                    )
                })
                .collect::<Vec<pb::MetricsSet>>();

            metrics_collection.insert(task_key(task_id), metrics);
        }
        let metrics_collection = Arc::new(metrics_collection);

        let rewritten_plan =
            stage_metrics_rewriter(&stage, metrics_collection.clone(), format).unwrap();

        let mut actual_metrics = vec![];
        collect_metrics_from_plan(&rewritten_plan, &mut actual_metrics);
        assert_eq!(actual_metrics.len(), num_nodes);

        for (node_id, actual_stage_node_metrics_set) in actual_metrics.iter().enumerate() {
            // Without this, a node missing some tasks' metrics would simply
            // produce fewer chunks below and still pass.
            assert_eq!(
                actual_stage_node_metrics_set.iter().count(),
                num_tasks * num_metrics_per_task_per_node,
                "node {node_id} should carry the metrics of every task"
            );

            for (task_id, actual_task_node_metrics_set) in actual_stage_node_metrics_set
                .iter()
                .chunks(num_metrics_per_task_per_node)
                .into_iter()
                .enumerate()
            {
                let expected_task_node_metrics = metrics_collection
                    .get(&task_key(task_id))
                    .unwrap()[node_id]
                    .clone();

                let mut actual_metrics_set = MetricsSet::new();
                actual_task_node_metrics_set
                    .for_each(|metric| actual_metrics_set.push(metric.clone()));

                let mut expected_metrics_set =
                    metrics_set_proto_to_df(&expected_task_node_metrics).unwrap();

                if format == DistributedMetricsFormat::PerTask {
                    expected_metrics_set =
                        annotate_metrics_set_with_task_id(expected_metrics_set, task_id as u64);
                }
                assert!(
                    metrics_set_eq(&actual_metrics_set, &expected_metrics_set),
                    "metrics mismatch for node {node_id}, task {task_id}:\nactual: {actual_metrics_set:?}\nexpected: {expected_metrics_set:?}"
                );
            }
        }
    }

    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
    #[tokio::test]
    async fn test_stage_metrics_rewriter_1(format: DistributedMetricsFormat) {
        run_stage_metrics_rewriter_test(
            "SELECT sum(balance) / 7.0 as avg_yearly from table2 group by name",
            format,
        )
        .await;
    }

    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
    #[tokio::test]
    async fn test_stage_metrics_rewriter_2(format: DistributedMetricsFormat) {
        run_stage_metrics_rewriter_test(
            "SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10",
            format,
        )
        .await;
    }

    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
    #[tokio::test]
    async fn test_stage_metrics_rewriter_3(format: DistributedMetricsFormat) {
        run_stage_metrics_rewriter_test(
            "SELECT sum(balance) / 7.0 as avg_yearly
            FROM table2
            WHERE name LIKE 'customer%'
              AND balance < (
                SELECT 0.2 * avg(balance)
                FROM table2 t2_inner
                WHERE t2_inner.id = table2.id
              )",
            format,
        )
        .await;
    }

    #[tokio::test]
    async fn test_rewrite_unexecuted_distributed_plan_with_metrics_err() {
        let ctx = make_test_distributed_ctx().await;
        let plan = ctx
            .sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();
        assert!(plan.as_any().is::<DistributedExec>());
        assert!(
            rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
                .is_err()
        );
    }

    /// Asserts that every node except the `DistributedExec` root reports a
    /// non-empty metrics set.
    fn assert_metrics_present_in_plan(plan: &Arc<dyn ExecutionPlan>) {
        if let Some(metrics) = plan.metrics() {
            assert!(metrics.iter().count() > 0);
        } else {
            assert!(plan.as_any().is::<DistributedExec>());
        }
        for child in plan.children() {
            assert_metrics_present_in_plan(child);
        }
    }

    #[tokio::test]
    async fn test_executed_distributed_plan_has_metrics() {
        let ctx = make_test_distributed_ctx().await;
        let plan = ctx
            .sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();
        collect(plan.clone(), ctx.task_ctx()).await.unwrap();
        assert!(plan.as_any().is::<DistributedExec>());
        let rewritten_plan =
            rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
                .unwrap();
        assert_metrics_present_in_plan(&rewritten_plan);
    }

    #[test]
    fn test_wrapped_node_is_accessible() {
        let example_node = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![Field::new(
            "id",
            DataType::Int32,
            false,
        )]))));

        let wrapped = MetricsWrapperExec::new(example_node, MetricsSet::new());
        assert_eq!(wrapped.name(), "EmptyExec");
        assert!(wrapped.as_any().is::<EmptyExec>());
    }

    #[test]
    fn test_annotate_metrics_set_with_task_id_output_rows() {
        let mut metrics_set = MetricsSet::new();
        let count = Count::new();
        count.add(1234);
        let labels = vec![Label::new("operator", "scan")];
        metrics_set.push(Arc::new(Metric::new_with_labels(
            MetricValue::OutputRows(count),
            Some(0),
            labels,
        )));

        let task_id = 42;
        let annotated = annotate_metrics_set_with_task_id(metrics_set, task_id);

        assert_eq!(annotated.iter().count(), 1);

        let metric = annotated.iter().next().unwrap();

        match metric.value() {
            MetricValue::OutputRows(count) => {
                assert_eq!(count.value(), 1234);
            }
            other => panic!("Expected OutputRows, got {:?}", other.name()),
        }

        assert_eq!(metric.partition(), Some(0));

        let labels: Vec<_> = metric.labels().iter().collect();
        assert_eq!(labels.len(), 2);
        assert_eq!(labels[0].name(), "operator");
        assert_eq!(labels[0].value(), "scan");
        assert_eq!(labels[1].name(), DISTRIBUTED_DATAFUSION_TASK_ID_LABEL);
        assert_eq!(labels[1].value(), "42");
    }
}