Skip to main content

datafusion_distributed/metrics/
task_metrics_rewriter.rs

1use crate::DistributedTaskContext;
2use crate::common::{TreeNodeExt, serialize_uuid};
3use crate::coordinator::{DistributedExec, MetricsStore};
4use crate::distributed_planner::NetworkBoundaryExt;
5use crate::execution_plans::MetricsWrapperExec;
6use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
7use crate::metrics::collect_plan_metrics;
8use crate::metrics::proto::metrics_set_proto_to_df;
9use crate::stage::{LocalStage, Stage};
10use crate::worker::generated::worker::TaskKey;
11use datafusion::common::HashMap;
12use datafusion::common::plan_err;
13use datafusion::common::tree_node::Transformed;
14use datafusion::common::tree_node::TreeNode;
15use datafusion::common::tree_node::TreeNodeRecursion;
16use datafusion::error::Result;
17use datafusion::physical_plan::ExecutionPlan;
18use datafusion::physical_plan::internal_err;
19use datafusion::physical_plan::metrics::{Label, Metric, MetricsSet};
20use std::sync::Arc;
21use std::vec;
22
/// Selects how the metrics of a distributed plan are presented when displayed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DistributedMetricsFormat {
    /// One combined view across all tasks, ex. a single `output_rows=X` stands for the
    /// output rows of every task.
    Aggregated,

    /// One view per task: metrics carry the id of the task they came from,
    /// ex. `output_rows` -> `output_rows_0`, `output_rows_1` etc.
    PerTask,
}
32
33impl DistributedMetricsFormat {
34    pub(crate) fn to_rewrite_ctx(self, task_id: u64) -> RewriteCtx {
35        match self {
36            DistributedMetricsFormat::Aggregated => RewriteCtx::default(),
37            DistributedMetricsFormat::PerTask => RewriteCtx::from_task_id(task_id),
38        }
39    }
40}
41
42/// Rewrites a distributed plan with metrics. Does nothing if the root node is not a [DistributedExec].
43/// Returns an error if the distributed plan was not executed.
44///
45/// Waits for all worker task metrics to arrive before rewriting, so the result is always complete.
46pub async fn rewrite_distributed_plan_with_metrics(
47    plan: Arc<dyn ExecutionPlan>,
48    format: DistributedMetricsFormat,
49) -> Result<Arc<dyn ExecutionPlan>> {
50    let Some(distributed_exec) = plan.downcast_ref::<DistributedExec>() else {
51        return Ok(plan);
52    };
53
54    // Check that the plan was executed before waiting — if not, prepared_plan() returns an
55    // error immediately rather than waiting forever for metrics that will never arrive.
56    let prepared = distributed_exec.prepared_plan()?;
57
58    distributed_exec.wait_for_metrics().await;
59
60    let Some(metrics_collection) = distributed_exec.metrics_store.clone() else {
61        return Ok(plan);
62    };
63
64    let task_metrics = collect_plan_metrics(&prepared)?;
65
66    // Rewrite the DistributedExec's child plan with metrics.
67    let dist_exec_plan_with_metrics = rewrite_local_plan_with_metrics(
68        format.to_rewrite_ctx(0), // Task id is 0 for the DistributedExec plan
69        plan.children()[0].clone(),
70        task_metrics,
71    )?;
72    let plan = plan.with_new_children(vec![dist_exec_plan_with_metrics])?;
73
74    let transformed = plan.transform_down(|plan| {
75        // After `rewrite_local_plan_with_metrics` above, every node (including network
76        // boundaries) is wrapped in a `MetricsWrapperExec`. Peek through the wrapper so we
77        // can still recognize a network boundary by downcasting the inner node.
78        let inner = plan
79            .downcast_ref::<MetricsWrapperExec>()
80            .map(|w| w.inner_arc())
81            .unwrap_or_else(|| Arc::clone(&plan));
82
83        // Transform all stages using NetworkShuffleExec and NetworkCoalesceExec as barriers.
84        if let Some(network_boundary) = inner.as_network_boundary() {
85            let Stage::Local(stage) = network_boundary.input_stage() else {
86                return plan_err!("Stage was not in Local state");
87            };
88            // This transform is a bit inefficient because we traverse the plan nodes twice
89            // For now, we are okay with trading off performance for simplicity.
90            let plan_with_metrics =
91                stage_metrics_rewriter(stage, Arc::clone(&metrics_collection), format)?;
92            let network_boundary = network_boundary.with_input_stage(Stage::Local(LocalStage {
93                query_id: stage.query_id,
94                num: stage.num,
95                plan: plan_with_metrics,
96                tasks: stage.tasks,
97            }))?;
98            let network_boundary =
99                MetricsWrapperExec::new(network_boundary, plan.metrics().unwrap_or_default());
100            return Ok(Transformed::yes(Arc::new(network_boundary)));
101        }
102
103        Ok(Transformed::no(plan))
104    })?;
105    Ok(transformed.data)
106}
107
/// Extra information for rewriting local plans.
///
/// The default context carries no task id.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RewriteCtx {
    /// Used to rename metrics for the current task. `None` means there is no task to
    /// attribute the metrics to.
    pub task_id: Option<u64>,
}
114
115impl RewriteCtx {
116    pub(crate) fn from_task_id(task_id: u64) -> RewriteCtx {
117        RewriteCtx {
118            task_id: Some(task_id),
119        }
120    }
121
122    /// Rewrites the [MetricsSet] depending on the context.
123    pub(crate) fn maybe_rewrite_node_metics(&self, node_metrics: MetricsSet) -> MetricsSet {
124        if let Some(task_id) = self.task_id {
125            return annotate_metrics_set_with_task_id(node_metrics, task_id);
126        }
127        node_metrics
128    }
129}
130
131/// Adds task id labels to all metrics in the provided [MetricsSet].
132///
133/// TODO: This re-allocates the vec of metrics by creating a new [MetricsSet]. It also
134/// reallocates the labels vec for each metric. Can we avoid this?
135/// See https://github.com/apache/datafusion/issues/19959
136pub fn annotate_metrics_set_with_task_id(metrics_set: MetricsSet, task_id: u64) -> MetricsSet {
137    let mut result = MetricsSet::new();
138
139    for metric in metrics_set.iter() {
140        let mut labels = metric.labels().to_vec();
141        labels.push(Label::new(
142            DISTRIBUTED_DATAFUSION_TASK_ID_LABEL,
143            task_id.to_string(),
144        ));
145        result.push(Arc::new(Metric::new_with_labels(
146            metric.value().clone(),
147            metric.partition(),
148            labels,
149        )));
150    }
151
152    result
153}
154
155/// Rewrites a local plan with metrics, stopping at network boundaries.
156///
157/// Example:
158///
159/// AggregateExec [output_rows = 1, elapsed_compute = 100]
160///  └── ProjectionExec [output_rows = 2, elapsed_compute = 200]
161///      └── NetworkShuffleExec [bytes_transferred = 100, max_mem_used = 100]
162///
163/// The result will be:
164///
165/// MetricsWrapperExec (wrapped: AggregateExec) [output_rows = 1, elapsed_compute = 100]
166///  └── MetricsWrapperExec (wrapped: ProjectionExec) [output_rows = 2, elapsed_compute = 200]
167///      └── MetricsWrapperExec (wrapped: NetworkShuffleExec) [bytes_transferred = 100, max_mem_used = 100]
168pub fn rewrite_local_plan_with_metrics(
169    ctx: RewriteCtx,
170    plan: Arc<dyn ExecutionPlan>,
171    metrics: Vec<MetricsSet>,
172) -> Result<Arc<dyn ExecutionPlan>> {
173    let mut idx = 0;
174    Ok(plan
175        .transform_down(|node| {
176            if idx >= metrics.len() {
177                return internal_err!("not enough metrics provided to rewrite plan");
178            }
179            let mut node_metrics = metrics[idx].clone();
180
181            node_metrics = ctx.maybe_rewrite_node_metics(node_metrics);
182
183            idx += 1;
184            Ok(Transformed::new(
185                Arc::new(MetricsWrapperExec::new(node.clone(), node_metrics)),
186                true,
187                if node.is_network_boundary() {
188                    TreeNodeRecursion::Jump
189                } else {
190                    TreeNodeRecursion::Continue
191                },
192            ))
193        })?
194        .data)
195}
196
197/// Enriches a stage with metrics from each task by re-writing the plan using
198/// [MetricsWrapperExec] nodes.
199///
200/// Example:
201///
202/// For a stage with 2 tasks:
203///
204/// Task 1:
205/// AggregateExec [output_rows = 1, elapsed_compute = 100]
206///  └── ProjectionExec [output_rows = 2, elapsed_compute = 200]
207///      └── NetworkShuffleExec [bytes_transferred = 100, max_mem_used = 100]
208///
209/// Task 2:
210/// AggregateExec [output_rows = 3, elapsed_compute = 300]
211///  └── ProjectionExec [output_rows = 4, elapsed_compute = 400]
212///      └── NetworkShuffleExec [bytes_transferred = 200, max_mem_used = 200]
213///
214/// The result will be:
215///
216/// MetricsWrapperExec (wrapped: AggregateExec) [output_rows = 1, output_rows = 3, elapsed_compute = 100, elapsed_compute = 300]
217///  └── MetricsWrapperExec (wrapped: ProjectionExec) [output_rows = 2, output_rows = 4, elapsed_compute = 200, elapsed_compute = 400]
218///      └── MetricsWrapperExec (wrapped: NetworkShuffleExec) [bytes_transferred = 100, bytes_transferred = 200, max_mem_used = 100, max_mem_used = 200]
219///
220/// Note: Metrics may be aggregated by name (ex. output_rows) automatically by various datafusion utils.
221pub fn stage_metrics_rewriter(
222    stage: &LocalStage,
223    metrics_collection: Arc<MetricsStore>,
224    format: DistributedMetricsFormat,
225) -> Result<Arc<dyn ExecutionPlan>> {
226    // Phase 1 — accumulate per-task metrics into a map keyed by node identity.
227    //
228    // For each task, the plan is traversed with `apply_with_dt_ctx`, which visits nodes in pre-order
229    // traversal, ignoring branches that do not belong to the recursed DistributedTaskContext
230    // (e.g., because of the presence of ChildrenIsolatorUnionExec).
231    //
232    // The raw allocation address of each `Arc<dyn ExecutionPlan>` as the node key.
233    // The planning plan is not modified between traversals, so these addresses are stable.
234    let mut node_metrics_map: HashMap<usize, MetricsSet> = HashMap::new();
235
236    for task_id in 0..stage.tasks {
237        let d_ctx = DistributedTaskContext {
238            task_index: task_id,
239            task_count: stage.tasks,
240        };
241        let task_key = TaskKey {
242            query_id: serialize_uuid(&stage.query_id),
243            stage_id: stage.num as u64,
244            task_number: task_id as u64,
245        };
246        let Some(task_metrics) = metrics_collection.get(&task_key) else {
247            return internal_err!(
248                "not enough metrics provided to rewrite task: missing metrics for task {} in stage {}",
249                task_id,
250                stage.num
251            );
252        };
253
254        let mut per_task_counter = 0usize;
255        stage.plan.apply_with_dt_ctx(d_ctx, |node, _ctx| {
256            if per_task_counter >= task_metrics.pre_order_plan_metrics.len() {
257                return internal_err!(
258                    "not enough metrics provided to rewrite task: {} metrics provided",
259                    task_metrics.pre_order_plan_metrics.len()
260                );
261            }
262
263            let node_metrics_protos = task_metrics.pre_order_plan_metrics[per_task_counter].clone();
264            let mut node_metrics = metrics_set_proto_to_df(&node_metrics_protos)?;
265            let rewrite_ctx = format.to_rewrite_ctx(task_id as u64);
266            node_metrics = rewrite_ctx.maybe_rewrite_node_metics(node_metrics);
267
268            let id = Arc::as_ptr(node) as *const () as usize;
269            let entry = node_metrics_map.entry(id).or_default();
270            for metric in node_metrics.iter().map(Arc::clone) {
271                entry.push(metric);
272            }
273
274            per_task_counter += 1;
275            Ok(TreeNodeRecursion::Continue)
276        })?;
277    }
278
279    // Phase 2 — rewrite: wrap every node with its accumulated metrics.
280    // Nodes that were inactive for all tasks (never visited in phase 1) get empty metrics.
281    Arc::clone(&stage.plan)
282        .transform_down(|plan| {
283            let id = Arc::as_ptr(&plan) as *const () as usize;
284            let metrics = node_metrics_map.remove(&id).unwrap_or_default();
285            Ok(Transformed::new(
286                Arc::new(MetricsWrapperExec::new(plan.clone(), metrics)),
287                true,
288                match plan.is_network_boundary() {
289                    true => TreeNodeRecursion::Jump,
290                    false => TreeNodeRecursion::Continue,
291                },
292            ))
293        })
294        .map(|v| v.data)
295}
296
#[cfg(test)]
mod tests {
    use crate::coordinator::MetricsStore;
    use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
    use crate::metrics::proto::{df_metrics_set_to_proto, metrics_set_proto_to_df};
    use crate::metrics::task_metrics_rewriter::{
        annotate_metrics_set_with_task_id, stage_metrics_rewriter,
    };
    use crate::metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics};
    use crate::test_utils::in_memory_channel_resolver::{
        InMemoryChannelResolver, InMemoryWorkerResolver,
    };
    use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
    use crate::test_utils::plans::count_plan_nodes_up_to_network_boundary;
    use crate::test_utils::session_context::register_temp_parquet_table;
    use crate::worker::generated::worker as pb;
    use crate::{DistributedExec, SessionStateBuilderExt};
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::physical_plan::metrics::{Count, Label, Metric, MetricValue, MetricsSet};
    use test_case::test_case;

    use datafusion::physical_plan::{ExecutionPlan, collect};
    use itertools::Itertools;
    use uuid::Uuid;

    use crate::DistributedExt;
    use crate::common::serialize_uuid;
    use crate::metrics::task_metrics_rewriter::MetricsWrapperExec;
    use crate::stage::LocalStage;
    use crate::worker::generated::worker::TaskKey;
    use datafusion::physical_plan::empty::EmptyExec;
    use datafusion::prelude::SessionConfig;
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Session context without the distributed planner; see [make_test_ctx_inner].
    async fn make_test_ctx() -> SessionContext {
        make_test_ctx_inner(false).await
    }

    /// Session context with the distributed planner enabled; see [make_test_ctx_inner].
    async fn make_test_distributed_ctx() -> SessionContext {
        make_test_ctx_inner(true).await
    }

    /// Creates a session context (set up for distributed execution with metrics collection
    /// when `distributed` is true) and registers two tables:
    /// - table1 (id: int, name: string)
    /// - table2 (id: int, name: string, phone: string, balance: float64)
    async fn make_test_ctx_inner(distributed: bool) -> SessionContext {
        let config = SessionConfig::new().with_target_partitions(4);
        let mut builder = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config);

        if distributed {
            builder = builder
                .with_distributed_worker_resolver(InMemoryWorkerResolver::new(10))
                .with_distributed_channel_resolver(InMemoryChannelResolver::default())
                .with_distributed_metrics_collection(true)
                .unwrap()
                .with_distributed_planner()
                .with_distributed_task_estimator(2)
        }

        let state = builder.build();
        let ctx = SessionContext::from(state);

        // Create test data for table1
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let batches1 = vec![
            RecordBatch::try_new(
                schema1.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec!["a", "b", "c"])),
                ],
            )
            .unwrap(),
        ];

        // Create test data for table2 with extended schema
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("phone", DataType::Utf8, false),
            Field::new("balance", DataType::Float64, false),
        ]));

        let batches2 = vec![
            RecordBatch::try_new(
                schema2.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec![
                        "customer1",
                        "customer2",
                        "customer3",
                    ])),
                    Arc::new(StringArray::from(vec![
                        "13-123-4567",
                        "31-456-7890",
                        "23-789-0123",
                    ])),
                    Arc::new(datafusion::arrow::array::Float64Array::from(vec![
                        100.5, 250.0, 50.25,
                    ])),
                ],
            )
            .unwrap(),
        ];

        // Register the test data as parquet tables
        let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx)
            .await
            .unwrap();

        let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx)
            .await
            .unwrap();

        ctx
    }

    /// Wraps `plan` in a [LocalStage] with a random query id, stage number 2 and 4 tasks.
    fn make_test_stage(plan: Arc<dyn ExecutionPlan>) -> LocalStage {
        LocalStage {
            query_id: Uuid::new_v4(),
            num: 2,
            plan,
            tasks: 4,
        }
    }

    /// Appends the metrics of `plan` and of all its descendants to `metrics`, parents before
    /// children. Nodes whose `metrics()` is `None` contribute nothing.
    fn collect_metrics_from_plan(plan: &Arc<dyn ExecutionPlan>, metrics: &mut Vec<MetricsSet>) {
        metrics.extend(plan.metrics());
        for child in plan.children() {
            collect_metrics_from_plan(child, metrics);
        }
    }

    /// Compares two metrics sets through their proto representation, printing both first so
    /// a failing assertion can be diagnosed from the test output.
    fn metrics_set_eq(a: &MetricsSet, b: &MetricsSet) -> bool {
        println!("a: {a:?}");
        println!("b: {b:?}");
        // Check equality by converting to proto representation.
        df_metrics_set_to_proto(a).unwrap() == df_metrics_set_to_proto(b).unwrap()
    }

    /// Asserts that we successfully re-write the metrics of a plan generated from the provided SQL query.
    /// Also asserts that the order which metrics are collected from a plan matches the order which
    /// they are re-written (ie. ensures we don't assign metrics to the wrong nodes)
    ///
    /// Only tests single node plans since `stage_metrics_rewriter` stops at network boundaries.
    async fn run_stage_metrics_rewriter_test(sql: &str, format: DistributedMetricsFormat) {
        // Generate the plan
        let ctx = make_test_ctx().await;
        let plan = ctx
            .sql(sql)
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();

        let stage = make_test_stage(plan.clone());

        let num_metrics_per_task_per_node = 4;

        // Generate metrics for each task and store them in the map.
        let metrics_collection = MetricsStore::from_entries((0..stage.tasks).map(|task_id| {
            let task_key = TaskKey {
                query_id: serialize_uuid(&stage.query_id),
                stage_id: stage.num as u64,
                task_number: task_id as u64,
            };
            let metrics = (0..count_plan_nodes_up_to_network_boundary(&plan))
                .map(|node_id| {
                    // Use a seed that is unique per (node, task) pair. A product such as
                    // `node_id * task_id` collides (it is 0 for every node of task 0 and for
                    // every task of node 0), which would let metrics attached to the wrong
                    // node or task go unnoticed.
                    make_test_metrics_set_proto_from_seed(
                        (node_id * stage.tasks + task_id) as u64,
                        num_metrics_per_task_per_node,
                    )
                })
                .collect::<Vec<pb::MetricsSet>>();
            let task_metrics = pb::TaskMetrics {
                task_metrics: None,
                pre_order_plan_metrics: metrics,
            };
            (task_key, task_metrics)
        }));
        let metrics_collection = Arc::new(metrics_collection);

        // Rewrite the plan.
        let rewritten_plan =
            stage_metrics_rewriter(&stage, metrics_collection.clone(), format).unwrap();

        // Collect metrics from the plan.
        let mut actual_metrics = vec![];
        collect_metrics_from_plan(&rewritten_plan, &mut actual_metrics);
        assert_eq!(
            actual_metrics.len(),
            count_plan_nodes_up_to_network_boundary(&plan)
        );

        // Assert that metrics from all tasks are present.
        // actual_stage_node_metrics_set contains metrics for all task ex. [output_rows=1, elapsed_compute=1, output_rows=2, elapsed_compute=2...]
        for (node_id, actual_stage_node_metrics_set) in actual_metrics.iter().enumerate() {
            // actual_task_node_metrics_set contains metrics for one task ex. [output_rows=1, elapsed_compute=1]
            for (task_id, actual_task_node_metrics_set) in actual_stage_node_metrics_set
                .iter()
                .chunks(num_metrics_per_task_per_node)
                .into_iter()
                .enumerate()
            {
                let expected_task_node_metrics = metrics_collection
                    .get(&TaskKey {
                        query_id: serialize_uuid(&stage.query_id),
                        stage_id: stage.num as u64,
                        task_number: task_id as u64,
                    })
                    .unwrap()
                    .pre_order_plan_metrics[node_id]
                    .clone();

                let mut actual_metrics_set = MetricsSet::new();
                actual_task_node_metrics_set
                    .for_each(|metric| actual_metrics_set.push(metric.clone()));

                // Convert from proto to check for equality.
                let mut expected_metrics_set =
                    metrics_set_proto_to_df(&expected_task_node_metrics).unwrap();

                if format == DistributedMetricsFormat::PerTask {
                    // Add task ids labels. We expect the actual metrics to be annotated by the
                    // rewriter when using DistributedMetricsFormat::PerTask
                    expected_metrics_set =
                        annotate_metrics_set_with_task_id(expected_metrics_set, task_id as u64);
                }
                assert!(metrics_set_eq(&actual_metrics_set, &expected_metrics_set));
            }
        }
    }

    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
    #[tokio::test]
    async fn test_stage_metrics_rewriter_1(format: DistributedMetricsFormat) {
        run_stage_metrics_rewriter_test(
            "SELECT sum(balance) / 7.0 as avg_yearly from table2 group by name",
            format,
        )
        .await;
    }

    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
    #[tokio::test]
    async fn test_stage_metrics_rewriter_2(format: DistributedMetricsFormat) {
        run_stage_metrics_rewriter_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10", format).await;
    }

    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
    #[tokio::test]
    async fn test_stage_metrics_rewriter_3(format: DistributedMetricsFormat) {
        run_stage_metrics_rewriter_test(
            "SELECT sum(balance) / 7.0 as avg_yearly
            FROM table2
            WHERE name LIKE 'customer%'
              AND balance < (
                SELECT 0.2 * avg(balance)
                FROM table2 t2_inner
                WHERE t2_inner.id = table2.id
              )",
            format,
        )
        .await;
    }

    // Rewriting a distributed plan that was planned but never executed must fail.
    #[tokio::test]
    async fn test_rewrite_unexecuted_distributed_plan_with_metrics_err() {
        let ctx = make_test_distributed_ctx().await;
        let plan = ctx
            .sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();
        assert!(plan.is::<DistributedExec>());
        assert!(
            rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
                .await
                .is_err()
        );
    }

    // Asserts that every node reachable through `children()` that reports a metrics set
    // reports a non-empty one. The only node allowed to report no metrics set at all is a
    // DistributedExec.
    fn assert_metrics_present_in_plan(plan: &Arc<dyn ExecutionPlan>) {
        if let Some(metrics) = plan.metrics() {
            assert!(metrics.iter().count() > 0);
        } else {
            assert!(plan.is::<DistributedExec>());
        }
        for child in plan.children() {
            assert_metrics_present_in_plan(child);
        }
    }

    // Rewriting an executed distributed plan must leave metrics on the rewritten nodes.
    #[tokio::test]
    async fn test_executed_distributed_plan_has_metrics() {
        let ctx = make_test_distributed_ctx().await;
        let plan = ctx
            .sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();
        collect(plan.clone(), ctx.task_ctx()).await.unwrap();
        assert!(plan.is::<DistributedExec>());
        let rewritten_plan =
            rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
                .await
                .unwrap();
        assert_metrics_present_in_plan(&rewritten_plan);
    }

    #[test]
    // An important feature of DF execution plans which we want to preserve is the ability
    // to traverse a plan and collect metrics from specific nodes. To do this, the wrapper must
    // allow access to the inner node. This test asserts that we support this.
    fn test_wrapped_node_is_accessible() {
        let example_node = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![Field::new(
            "id",
            DataType::Int32,
            false,
        )]))));

        let wrapped = MetricsWrapperExec::new(example_node, MetricsSet::new());
        assert_eq!(wrapped.name(), "EmptyExec");
        assert!(wrapped.inner().is::<EmptyExec>());
    }

    #[test]
    fn test_annotate_metrics_set_with_task_id_output_rows() {
        // Create a MetricsSet with an OutputRows metric
        let mut metrics_set = MetricsSet::new();
        let count = Count::new();
        count.add(1234);
        let labels = vec![Label::new("operator", "scan")];
        metrics_set.push(Arc::new(Metric::new_with_labels(
            MetricValue::OutputRows(count),
            Some(0),
            labels,
        )));

        let task_id = 42;
        let annotated = annotate_metrics_set_with_task_id(metrics_set, task_id);

        // Verify we have one metric
        assert_eq!(annotated.iter().count(), 1);

        let metric = annotated.iter().next().unwrap();

        // Verify metric type is preserved (OutputRows)
        match metric.value() {
            MetricValue::OutputRows(count) => {
                assert_eq!(count.value(), 1234);
            }
            other => panic!("Expected OutputRows, got {:?}", other.name()),
        }

        // Verify partition is preserved
        assert_eq!(metric.partition(), Some(0));

        // Verify original labels are preserved and task_id label is added
        let labels: Vec<_> = metric.labels().iter().collect();
        assert_eq!(labels.len(), 2);
        assert_eq!(labels[0].name(), "operator");
        assert_eq!(labels[0].value(), "scan");
        assert_eq!(labels[1].name(), DISTRIBUTED_DATAFUSION_TASK_ID_LABEL);
        assert_eq!(labels[1].value(), "42");
    }
}