Skip to main content

datafusion_distributed/metrics/
task_metrics_rewriter.rs

1use crate::distributed_planner::NetworkBoundaryExt;
2use crate::execution_plans::DistributedExec;
3use crate::execution_plans::MetricsWrapperExec;
4use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
5use crate::metrics::MetricsCollectorResult;
6use crate::metrics::TaskMetricsCollector;
7use crate::metrics::proto::metrics_set_proto_to_df;
8use crate::stage::Stage;
9use crate::worker::generated::worker as pb;
10use crate::worker::generated::worker::TaskKey;
11use datafusion::common::HashMap;
12use datafusion::common::tree_node::Transformed;
13use datafusion::common::tree_node::TreeNode;
14use datafusion::common::tree_node::TreeNodeRecursion;
15use datafusion::error::Result;
16use datafusion::physical_plan::ExecutionPlan;
17use datafusion::physical_plan::internal_err;
18use datafusion::physical_plan::metrics::{Label, Metric, MetricsSet};
19use std::sync::Arc;
20use std::vec;
21
/// Format to use when displaying metrics for a distributed plan.
///
/// The format decides whether the metrics reported by the different tasks of a stage are
/// kept as they were reported, or are tagged with the id of the task that produced them
/// (see `DistributedMetricsFormat::to_rewrite_ctx`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributedMetricsFormat {
    /// Metrics are aggregated across all tasks. ex. a `output_rows=X` represents the output rows for all tasks.
    ///
    /// With this variant no task id is attached to the metrics when they are rewritten.
    Aggregated,

    /// Metrics are annotated with the id of the task that produced them, so values coming from
    /// different tasks can be told apart. ex. `output_rows` -> `output_rows_0`, `output_rows_1` etc.
    ///
    /// NOTE(review): in this file the task id is attached as a label on every metric; the
    /// `output_rows_0` style rendering presumably happens wherever metrics are displayed — confirm.
    PerTask,
}
31
32impl DistributedMetricsFormat {
33    pub(crate) fn to_rewrite_ctx(self, task_id: u64) -> RewriteCtx {
34        match self {
35            DistributedMetricsFormat::Aggregated => RewriteCtx::default(),
36            DistributedMetricsFormat::PerTask => RewriteCtx::from_task_id(task_id),
37        }
38    }
39}
40
41/// Rewrites a distributed plan with metrics. Does nothing if the root node is not a [DistributedExec].
42/// Returns an error if the distributed plan was not executed.
43pub fn rewrite_distributed_plan_with_metrics(
44    plan: Arc<dyn ExecutionPlan>,
45    format: DistributedMetricsFormat,
46) -> Result<Arc<dyn ExecutionPlan>> {
47    let Some(distributed_exec) = plan.as_any().downcast_ref::<DistributedExec>() else {
48        return Ok(plan);
49    };
50
51    // Collect metrics from the DistributedExec's prepared plan.
52    let MetricsCollectorResult {
53        task_metrics,       // Metrics for the DistributedExec plan
54        input_task_metrics, // Metrics for all child stages / tasks.
55    } = TaskMetricsCollector::new().collect(distributed_exec.prepared_plan()?)?;
56
57    // Rewrite the DistributedExec's child plan with metrics.
58    let dist_exec_plan_with_metrics = rewrite_local_plan_with_metrics(
59        format.to_rewrite_ctx(0), // Task id is 0 for the DistributedExec plan
60        plan.children()[0].clone(),
61        task_metrics,
62    )?;
63    let plan = plan.with_new_children(vec![dist_exec_plan_with_metrics])?;
64
65    let metrics_collection = Arc::new(input_task_metrics);
66
67    let transformed = plan.transform_down(|plan| {
68        // Transform all stages using NetworkShuffleExec and NetworkCoalesceExec as barriers.
69        if let Some(network_boundary) = plan.as_network_boundary() {
70            let stage = network_boundary.input_stage();
71            // This transform is a bit inefficient because we traverse the plan nodes twice
72            // For now, we are okay with trading off performance for simplicity.
73            let plan_with_metrics =
74                stage_metrics_rewriter(stage, metrics_collection.clone(), format)?;
75            let network_boundary = network_boundary.with_input_stage(Stage::new(
76                stage.query_id,
77                stage.num,
78                plan_with_metrics,
79                stage.tasks.len(),
80            ))?;
81            let network_boundary =
82                MetricsWrapperExec::new(network_boundary, plan.metrics().unwrap_or_default());
83            return Ok(Transformed::yes(Arc::new(network_boundary)));
84        }
85
86        Ok(Transformed::no(plan))
87    })?;
88    Ok(transformed.data)
89}
90
/// Extra information for rewriting local plans.
///
/// The default context carries no task id.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct RewriteCtx {
    /// Id of the task whose metrics are being rewritten. When `Some`, the metrics are annotated
    /// with this id; when `None`, the metrics are left as they are.
    pub task_id: Option<u64>,
}
97
98impl RewriteCtx {
99    pub(crate) fn from_task_id(task_id: u64) -> RewriteCtx {
100        RewriteCtx {
101            task_id: Some(task_id),
102        }
103    }
104
105    /// Rewrites the [MetricsSet] depending on the context.
106    pub(crate) fn maybe_rewrite_node_metics(&self, node_metrics: MetricsSet) -> MetricsSet {
107        if let Some(task_id) = self.task_id {
108            return annotate_metrics_set_with_task_id(node_metrics, task_id);
109        }
110        node_metrics
111    }
112}
113
114/// Adds task id labels to all metrics in the provided [MetricsSet].
115///
116/// TODO: This re-allocates the vec of metrics by creating a new [MetricsSet]. It also
117/// reallocates the labels vec for each metric. Can we avoid this?
118/// See https://github.com/apache/datafusion/issues/19959
119pub fn annotate_metrics_set_with_task_id(metrics_set: MetricsSet, task_id: u64) -> MetricsSet {
120    let mut result = MetricsSet::new();
121
122    for metric in metrics_set.iter() {
123        let mut labels = metric.labels().to_vec();
124        labels.push(Label::new(
125            DISTRIBUTED_DATAFUSION_TASK_ID_LABEL,
126            task_id.to_string(),
127        ));
128        result.push(Arc::new(Metric::new_with_labels(
129            metric.value().clone(),
130            metric.partition(),
131            labels,
132        )));
133    }
134
135    result
136}
137
138/// Rewrites a local plan with metrics, stopping at network boundaries.
139///
140/// Example:
141///
142/// AggregateExec [output_rows = 1, elapsed_compute = 100]
143///  └── ProjectionExec [output_rows = 2, elapsed_compute = 200]
144///      └── NetworkShuffleExec [bytes_transferred = 100, max_mem_used = 100]
145///
146/// The result will be:
147///
148/// MetricsWrapperExec (wrapped: AggregateExec) [output_rows = 1, elapsed_compute = 100]
149///  └── MetricsWrapperExec (wrapped: ProjectionExec) [output_rows = 2, elapsed_compute = 200]
150///      └── MetricsWrapperExec (wrapped: NetworkShuffleExec) [bytes_transferred = 100, max_mem_used = 100]
151pub fn rewrite_local_plan_with_metrics(
152    ctx: RewriteCtx,
153    plan: Arc<dyn ExecutionPlan>,
154    metrics: Vec<MetricsSet>,
155) -> Result<Arc<dyn ExecutionPlan>> {
156    let mut idx = 0;
157    Ok(plan
158        .transform_down(|node| {
159            if idx >= metrics.len() {
160                return internal_err!("not enough metrics provided to rewrite plan");
161            }
162            let mut node_metrics = metrics[idx].clone();
163
164            node_metrics = ctx.maybe_rewrite_node_metics(node_metrics);
165
166            idx += 1;
167            Ok(Transformed::new(
168                Arc::new(MetricsWrapperExec::new(node.clone(), node_metrics)),
169                true,
170                if node.is_network_boundary() {
171                    TreeNodeRecursion::Jump
172                } else {
173                    TreeNodeRecursion::Continue
174                },
175            ))
176        })?
177        .data)
178}
179
180/// Enriches a stage with metrics from each task by re-writing the plan using
181/// [MetricsWrapperExec] nodes.
182///
183/// Example:
184///
185/// For a stage with 2 tasks:
186///
187/// Task 1:
188/// AggregateExec [output_rows = 1, elapsed_compute = 100]
189///  └── ProjectionExec [output_rows = 2, elapsed_compute = 200]
190///      └── NetworkShuffleExec [bytes_transferred = 100, max_mem_used = 100]
191///
192/// Task 2:
193/// AggregateExec [output_rows = 3, elapsed_compute = 300]
194///  └── ProjectionExec [output_rows = 4, elapsed_compute = 400]
195///      └── NetworkShuffleExec [bytes_transferred = 200, max_mem_used = 200]
196///
197/// The result will be:
198///
199/// MetricsWrapperExec (wrapped: AggregateExec) [output_rows = 1, output_rows = 3, elapsed_compute = 100, elapsed_compute = 300]
200///  └── MetricsWrapperExec (wrapped: ProjectionExec) [output_rows = 2, output_rows = 4, elapsed_compute = 200, elapsed_compute = 400]
201///      └── MetricsWrapperExec (wrapped: NetworkShuffleExec) [bytes_transferred = 100, bytes_transferred = 200, max_mem_used = 100, max_mem_used = 200]
202///
203/// Note: Metrics may be aggregated by name (ex. output_rows) automatically by various datafusion utils.
204pub fn stage_metrics_rewriter(
205    stage: &Stage,
206    metrics_collection: Arc<HashMap<TaskKey, Vec<pb::MetricsSet>>>,
207    format: DistributedMetricsFormat,
208) -> Result<Arc<dyn ExecutionPlan>> {
209    let mut node_idx = 0;
210
211    let Some(plan) = &stage.plan else {
212        return internal_err!("The inner plan of a stage was not present");
213    };
214
215    plan.clone().transform_down(|plan| {
216        // Collect metrics for this node. It should contain metrics from each task.
217        let mut stage_metrics = MetricsSet::new();
218
219        for task_id in 0..stage.tasks.len() {
220            let task_key = TaskKey {
221                query_id: stage.query_id.as_bytes().to_vec(),
222                stage_id: stage.num as u64,
223                task_number: task_id as u64,
224            };
225            match metrics_collection.get(&task_key) {
226                Some(task_metrics) => {
227                    if node_idx >= task_metrics.len() {
228                        return internal_err!(
229                            "not enough metrics provided to rewrite task: {} metrics provided",
230                            task_metrics.len()
231                        );
232                    }
233                    let node_metrics_protos = task_metrics[node_idx].clone();
234                    let mut node_metrics = metrics_set_proto_to_df(&node_metrics_protos)?;
235
236                    let rewrite_ctx = format.to_rewrite_ctx(task_id as u64);
237                    node_metrics = rewrite_ctx.maybe_rewrite_node_metics(node_metrics);
238
239                    for metric in node_metrics.iter().map(Arc::clone) {
240                        stage_metrics.push(metric);
241                    }
242                }
243                None => {
244                    return internal_err!(
245                        "not enough metrics provided to rewrite task: missing metrics for task {} in stage {}",
246                        task_id,
247                        stage.num
248                    );
249                }
250            }
251        }
252
253        node_idx += 1;
254
255        let wrapped_plan_node: Arc<dyn ExecutionPlan> = Arc::new(MetricsWrapperExec::new(
256            plan.clone(),
257            stage_metrics,
258        ));
259        Ok(Transformed::new(
260            wrapped_plan_node,
261            true,
262            if plan.is_network_boundary() {
263                TreeNodeRecursion::Jump
264            } else {
265                TreeNodeRecursion::Continue
266            }
267        ))
268    }).map(|v| v.data)
269}
270
271#[cfg(test)]
272mod tests {
273    use crate::Stage;
274    use crate::metrics::DISTRIBUTED_DATAFUSION_TASK_ID_LABEL;
275    use crate::metrics::proto::{df_metrics_set_to_proto, metrics_set_proto_to_df};
276    use crate::metrics::task_metrics_rewriter::{
277        annotate_metrics_set_with_task_id, stage_metrics_rewriter,
278    };
279    use crate::metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics};
280    use crate::test_utils::in_memory_channel_resolver::{
281        InMemoryChannelResolver, InMemoryWorkerResolver,
282    };
283    use crate::test_utils::metrics::make_test_metrics_set_proto_from_seed;
284    use crate::test_utils::plans::count_plan_nodes_up_to_network_boundary;
285    use crate::test_utils::session_context::register_temp_parquet_table;
286    use crate::worker::generated::worker as pb;
287    use crate::{DistributedExec, DistributedPhysicalOptimizerRule};
288    use datafusion::arrow::array::{Int32Array, StringArray};
289    use datafusion::arrow::datatypes::{DataType, Field, Schema};
290    use datafusion::arrow::record_batch::RecordBatch;
291    use datafusion::common::HashMap;
292    use datafusion::execution::SessionStateBuilder;
293    use datafusion::physical_plan::metrics::{Count, Label, Metric, MetricValue, MetricsSet};
294    use test_case::test_case;
295
296    use datafusion::physical_plan::{ExecutionPlan, collect};
297    use itertools::Itertools;
298    use uuid::Uuid;
299
300    use crate::DistributedExt;
301    use crate::metrics::task_metrics_rewriter::MetricsWrapperExec;
302    use crate::worker::generated::worker::TaskKey;
303    use datafusion::physical_plan::empty::EmptyExec;
304    use datafusion::prelude::SessionConfig;
305    use datafusion::prelude::SessionContext;
306    use std::sync::Arc;
307
    /// Creates the test session context without the distributed setup
    /// (calls `make_test_ctx_inner` with `distributed = false`).
    async fn make_test_ctx() -> SessionContext {
        make_test_ctx_inner(false).await
    }
311
    /// Creates the test session context with the distributed setup enabled
    /// (calls `make_test_ctx_inner` with `distributed = true`).
    async fn make_test_distributed_ctx() -> SessionContext {
        make_test_ctx_inner(true).await
    }
315
    /// Creates a session context (distributed when `distributed` is true, plain otherwise) and
    /// registers two tables:
    /// - table1 (id: int, name: string)
    /// - table2 (id: int, name: string, phone: string, balance: float64)
    ///
    /// Both tables hold three rows. When `distributed` is true, the in-memory worker and
    /// channel resolvers, distributed metrics collection, the
    /// [DistributedPhysicalOptimizerRule] and a task estimator are added to the session state.
    async fn make_test_ctx_inner(distributed: bool) -> SessionContext {
        let config = SessionConfig::new().with_target_partitions(4);
        let mut builder = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config);

        if distributed {
            // NOTE(review): `10` is presumably the number of in-memory workers and `2` the
            // task estimate handed to the planner — confirm against the helpers' definitions.
            builder = builder
                .with_distributed_worker_resolver(InMemoryWorkerResolver::new(10))
                .with_distributed_channel_resolver(InMemoryChannelResolver::default())
                .with_distributed_metrics_collection(true)
                .unwrap()
                .with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule))
                .with_distributed_task_estimator(2)
        }

        let state = builder.build();
        let ctx = SessionContext::from(state);

        // Create test data for table1
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let batches1 = vec![
            RecordBatch::try_new(
                schema1.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec!["a", "b", "c"])),
                ],
            )
            .unwrap(),
        ];

        // Create test data for table2 with extended schema
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("phone", DataType::Utf8, false),
            Field::new("balance", DataType::Float64, false),
        ]));

        let batches2 = vec![
            RecordBatch::try_new(
                schema2.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec![
                        "customer1",
                        "customer2",
                        "customer3",
                    ])),
                    Arc::new(StringArray::from(vec![
                        "13-123-4567",
                        "31-456-7890",
                        "23-789-0123",
                    ])),
                    Arc::new(datafusion::arrow::array::Float64Array::from(vec![
                        100.5, 250.0, 50.25,
                    ])),
                ],
            )
            .unwrap(),
        ];

        // Register the test data as parquet tables. The values returned by the helper are
        // deliberately discarded.
        let _ = register_temp_parquet_table("table1", schema1, batches1, &ctx)
            .await
            .unwrap();

        let _ = register_temp_parquet_table("table2", schema2, batches2, &ctx)
            .await
            .unwrap();

        ctx
    }
397
398    fn make_test_stage(plan: Arc<dyn ExecutionPlan>) -> Stage {
399        Stage::new(Uuid::new_v4(), 2, plan, 4)
400    }
401
402    fn collect_metrics_from_plan(plan: &Arc<dyn ExecutionPlan>, metrics: &mut Vec<MetricsSet>) {
403        metrics.extend(plan.metrics());
404        for child in plan.children() {
405            collect_metrics_from_plan(child, metrics);
406        }
407    }
408
409    fn metrics_set_eq(a: &MetricsSet, b: &MetricsSet) -> bool {
410        println!("a: {a:?}");
411        println!("b: {b:?}");
412        // Check equality by converting to proto representation.
413        df_metrics_set_to_proto(a).unwrap() == df_metrics_set_to_proto(b).unwrap()
414    }
415
    /// Asserts that we successfully re-write the metrics of a plan generated from the provided SQL query.
    /// Also asserts that the order which metrics are collected from a plan matches the order which
    /// they are re-written (ie. ensures we don't assign metrics to the wrong nodes)
    ///
    /// Only tests single node plans (the context is created with `make_test_ctx`), since
    /// `stage_metrics_rewriter` stops at network boundaries.
    ///
    /// Steps:
    /// 1. Plan `sql` and wrap the plan in a test stage.
    /// 2. Generate one metrics set per (task, node) pair and store them keyed by [TaskKey].
    /// 3. Rewrite the stage plan with `stage_metrics_rewriter` using `format`.
    /// 4. Collect the metrics back from the rewritten plan and compare them, task by task,
    ///    with the generated ones.
    async fn run_stage_metrics_rewriter_test(sql: &str, format: DistributedMetricsFormat) {
        // Generate the plan
        let ctx = make_test_ctx().await;
        let plan = ctx
            .sql(sql)
            .await
            .unwrap()
            .create_physical_plan()
            .await
            .unwrap();

        let stage = make_test_stage(plan.clone());

        // Number of metrics generated for every (task, node) pair. The verification below
        // relies on it to split a node's metrics back into per-task chunks.
        let num_metrics_per_task_per_node = 4;

        // Generate metrics for each task and store them in the map.
        let mut metrics_collection = HashMap::new();
        for task_id in 0..stage.tasks.len() {
            let task_key = TaskKey {
                query_id: stage.query_id.as_bytes().to_vec(),
                stage_id: stage.num as u64,
                task_number: task_id as u64,
            };
            // NOTE(review): the seed `node_id * task_id` is 0 for every node of task 0 and for
            // node 0 of every task, so several generated sets share a seed — confirm this is
            // intended.
            let metrics = (0..count_plan_nodes_up_to_network_boundary(&plan))
                .map(|node_id| {
                    make_test_metrics_set_proto_from_seed(
                        (node_id * task_id) as u64,
                        num_metrics_per_task_per_node,
                    )
                })
                .collect::<Vec<pb::MetricsSet>>();

            metrics_collection.insert(task_key, metrics);
        }
        let metrics_collection = Arc::new(metrics_collection);

        // Rewrite the plan.
        let rewritten_plan =
            stage_metrics_rewriter(&stage, metrics_collection.clone(), format).unwrap();

        // Collect metrics from the plan.
        let mut actual_metrics = vec![];
        collect_metrics_from_plan(&rewritten_plan, &mut actual_metrics);
        // Exactly one metrics set is expected per plan node.
        assert_eq!(
            actual_metrics.len(),
            count_plan_nodes_up_to_network_boundary(&plan)
        );

        // Assert that metrics from all tasks are present.
        // actual_stage_node_metrics_set contains metrics for all task ex. [output_rows=1, elapsed_compute=1, output_rows=2, elapsed_compute=2...]
        for (node_id, actual_stage_node_metrics_set) in actual_metrics.iter().enumerate() {
            // actual_task_node_metrics_set contains metrics for one task ex. [output_rows=1, elapsed_compute=1]
            // The position of a chunk is used as the task id, which relies on the rewriter
            // appending the metrics of the tasks in task order.
            for (task_id, actual_task_node_metrics_set) in actual_stage_node_metrics_set
                .iter()
                .chunks(num_metrics_per_task_per_node)
                .into_iter()
                .enumerate()
            {
                let expected_task_node_metrics = metrics_collection
                    .get(&TaskKey {
                        query_id: stage.query_id.as_bytes().to_vec(),
                        stage_id: stage.num as u64,
                        task_number: task_id as u64,
                    })
                    .unwrap()[node_id]
                    .clone();

                // Rebuild a metrics set out of the metrics in this chunk.
                let mut actual_metrics_set = MetricsSet::new();
                actual_task_node_metrics_set
                    .for_each(|metric| actual_metrics_set.push(metric.clone()));

                // Convert from proto to check for equality.
                let mut expected_metrics_set =
                    metrics_set_proto_to_df(&expected_task_node_metrics).unwrap();

                if format == DistributedMetricsFormat::PerTask {
                    // Add task ids labels. We expect the actual metrics to be annotated by the
                    // rewriter when using DistributedMetricsFormat::PerTask
                    expected_metrics_set =
                        annotate_metrics_set_with_task_id(expected_metrics_set, task_id as u64);
                }
                assert!(metrics_set_eq(&actual_metrics_set, &expected_metrics_set));
            }
        }
    }
506
507    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
508    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
509    #[tokio::test]
510    async fn test_stage_metrics_rewriter_1(format: DistributedMetricsFormat) {
511        run_stage_metrics_rewriter_test(
512            "SELECT sum(balance) / 7.0 as avg_yearly from table2 group by name",
513            format,
514        )
515        .await;
516    }
517
518    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
519    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
520    #[tokio::test]
521    async fn test_stage_metrics_rewriter_2(format: DistributedMetricsFormat) {
522        run_stage_metrics_rewriter_test("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10", format).await;
523    }
524
525    #[test_case(DistributedMetricsFormat::Aggregated ; "aggregated_metrics")]
526    #[test_case(DistributedMetricsFormat::PerTask ; "per_task_metrics")]
527    #[tokio::test]
528    async fn test_stage_metrics_rewriter_3(format: DistributedMetricsFormat) {
529        run_stage_metrics_rewriter_test(
530            "SELECT sum(balance) / 7.0 as avg_yearly
531            FROM table2
532            WHERE name LIKE 'customer%'
533              AND balance < (
534                SELECT 0.2 * avg(balance)
535                FROM table2 t2_inner
536                WHERE t2_inner.id = table2.id
537              )",
538            format,
539        )
540        .await;
541    }
542
543    #[tokio::test]
544    async fn test_rewrite_unexecuted_distributed_plan_with_metrics_err() {
545        let ctx = make_test_distributed_ctx().await;
546        let plan = ctx
547            .sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
548            .await
549            .unwrap()
550            .create_physical_plan()
551            .await
552            .unwrap();
553        assert!(plan.as_any().is::<DistributedExec>());
554        assert!(
555            rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
556                .is_err()
557        );
558    }
559
560    // Assert every plan node has at least one metric except partition isolators, network boundary nodes, and the root DistributedExec node.
561    fn assert_metrics_present_in_plan(plan: &Arc<dyn ExecutionPlan>) {
562        if let Some(metrics) = plan.metrics() {
563            assert!(metrics.iter().count() > 0);
564        } else {
565            assert!(plan.as_any().is::<DistributedExec>());
566        }
567        for child in plan.children() {
568            assert_metrics_present_in_plan(child);
569        }
570    }
571
572    #[tokio::test]
573    async fn test_executed_distributed_plan_has_metrics() {
574        let ctx = make_test_distributed_ctx().await;
575        let plan = ctx
576            .sql("SELECT id, COUNT(*) as count FROM table1 WHERE id > 1 GROUP BY id ORDER BY id LIMIT 10")
577            .await
578            .unwrap()
579            .create_physical_plan()
580            .await
581            .unwrap();
582        collect(plan.clone(), ctx.task_ctx()).await.unwrap();
583        assert!(plan.as_any().is::<DistributedExec>());
584        let rewritten_plan =
585            rewrite_distributed_plan_with_metrics(plan, DistributedMetricsFormat::Aggregated)
586                .unwrap();
587        assert_metrics_present_in_plan(&rewritten_plan);
588    }
589
590    #[test]
591    // An important feature of DF execution plans which we want to preserve is the ability
592    // to traverse a plan and collect metrics from specific nodes. To do this, the wrapper must
593    // allow access to the inner node. This test asserts that we support this.
594    fn test_wrapped_node_is_accessible() {
595        let example_node = Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![Field::new(
596            "id",
597            DataType::Int32,
598            false,
599        )]))));
600
601        let wrapped = MetricsWrapperExec::new(example_node, MetricsSet::new());
602        assert_eq!(wrapped.name(), "EmptyExec");
603        assert!(wrapped.as_any().is::<EmptyExec>());
604    }
605
606    #[test]
607    fn test_annotate_metrics_set_with_task_id_output_rows() {
608        // Create a MetricsSet with an OutputRows metric
609        let mut metrics_set = MetricsSet::new();
610        let count = Count::new();
611        count.add(1234);
612        let labels = vec![Label::new("operator", "scan")];
613        metrics_set.push(Arc::new(Metric::new_with_labels(
614            MetricValue::OutputRows(count),
615            Some(0),
616            labels,
617        )));
618
619        let task_id = 42;
620        let annotated = annotate_metrics_set_with_task_id(metrics_set, task_id);
621
622        // Verify we have one metric
623        assert_eq!(annotated.iter().count(), 1);
624
625        let metric = annotated.iter().next().unwrap();
626
627        // Verify metric type is preserved (OutputRows)
628        match metric.value() {
629            MetricValue::OutputRows(count) => {
630                assert_eq!(count.value(), 1234);
631            }
632            other => panic!("Expected OutputRows, got {:?}", other.name()),
633        }
634
635        // Verify partition is preserved
636        assert_eq!(metric.partition(), Some(0));
637
638        // Verify original labels are preserved and task_id label is added
639        let labels: Vec<_> = metric.labels().iter().collect();
640        assert_eq!(labels.len(), 2);
641        assert_eq!(labels[0].name(), "operator");
642        assert_eq!(labels[0].value(), "scan");
643        assert_eq!(labels[1].name(), DISTRIBUTED_DATAFUSION_TASK_ID_LABEL);
644        assert_eq!(labels[1].value(), "42");
645    }
646}