Skip to main content

dag_ml_core/
metrics.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::aggregation::{AggregatedPredictionBlock, PredictionUnitId};
6use crate::error::{DagMlError, Result};
7use crate::ids::{FoldId, NodeId};
8use crate::oof::{PredictionBlock, PredictionPartition};
9use crate::policy::PredictionLevel;
10use crate::selection::{CandidateScore, MetricObjective};
11
12#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum RegressionMetricKind {
15    Mse,
16    Rmse,
17    Mae,
18    R2,
19}
20
21impl RegressionMetricKind {
22    pub fn name(self) -> &'static str {
23        match self {
24            Self::Mse => "mse",
25            Self::Rmse => "rmse",
26            Self::Mae => "mae",
27            Self::R2 => "r2",
28        }
29    }
30
31    pub fn objective(self) -> MetricObjective {
32        match self {
33            Self::Mse | Self::Rmse | Self::Mae => MetricObjective::Minimize,
34            Self::R2 => MetricObjective::Maximize,
35        }
36    }
37}
38
39#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
40pub struct RegressionTargetBlock {
41    pub level: PredictionLevel,
42    pub unit_ids: Vec<PredictionUnitId>,
43    pub values: Vec<Vec<f64>>,
44    #[serde(default)]
45    pub target_names: Vec<String>,
46}
47
48impl RegressionTargetBlock {
49    pub fn validate_shape(&self) -> Result<usize> {
50        if self.unit_ids.len() != self.values.len() {
51            return Err(DagMlError::OofValidation(format!(
52                "target block has {} unit ids but {} target rows",
53                self.unit_ids.len(),
54                self.values.len()
55            )));
56        }
57        if self
58            .unit_ids
59            .iter()
60            .any(|unit_id| unit_id.level() != self.level)
61        {
62            return Err(DagMlError::OofValidation(format!(
63                "target block contains units outside level {:?}",
64                self.level
65            )));
66        }
67        let unique = self.unit_ids.iter().collect::<BTreeSet<_>>();
68        if unique.len() != self.unit_ids.len() {
69            return Err(DagMlError::OofValidation(
70                "target block contains duplicate unit ids".to_string(),
71            ));
72        }
73        let width = self.values.first().map_or(0, Vec::len);
74        if width == 0 {
75            return Err(DagMlError::OofValidation(
76                "target block has empty target rows".to_string(),
77            ));
78        }
79        if self.values.iter().any(|row| row.len() != width) {
80            return Err(DagMlError::OofValidation(
81                "target block has ragged target rows".to_string(),
82            ));
83        }
84        if self.values.iter().flatten().any(|value| !value.is_finite()) {
85            return Err(DagMlError::OofValidation(
86                "target block contains non-finite values".to_string(),
87            ));
88        }
89        if !self.target_names.is_empty() && self.target_names.len() != width {
90            return Err(DagMlError::OofValidation(format!(
91                "target block has {} target names for width {}",
92                self.target_names.len(),
93                width
94            )));
95        }
96        Ok(width)
97    }
98}
99
100#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
101pub struct RegressionMetricReport {
102    #[serde(default)]
103    pub prediction_id: Option<String>,
104    pub producer_node: NodeId,
105    pub partition: PredictionPartition,
106    pub fold_id: Option<FoldId>,
107    pub level: PredictionLevel,
108    pub row_count: usize,
109    pub target_width: usize,
110    #[serde(default)]
111    pub target_names: Vec<String>,
112    pub metrics: BTreeMap<String, f64>,
113}
114
115impl RegressionMetricReport {
116    pub fn validate(&self) -> Result<()> {
117        if self.row_count == 0 {
118            return Err(DagMlError::OofValidation(
119                "regression metric report has zero rows".to_string(),
120            ));
121        }
122        if self.target_width == 0 {
123            return Err(DagMlError::OofValidation(
124                "regression metric report has zero target width".to_string(),
125            ));
126        }
127        if !self.target_names.is_empty() && self.target_names.len() != self.target_width {
128            return Err(DagMlError::OofValidation(format!(
129                "regression metric report has {} target names for width {}",
130                self.target_names.len(),
131                self.target_width
132            )));
133        }
134        if self.metrics.is_empty() {
135            return Err(DagMlError::OofValidation(
136                "regression metric report has no metrics".to_string(),
137            ));
138        }
139        for (name, value) in &self.metrics {
140            if name.trim().is_empty() {
141                return Err(DagMlError::OofValidation(
142                    "regression metric report contains an empty metric name".to_string(),
143                ));
144            }
145            if !value.is_finite() {
146                return Err(DagMlError::OofValidation(format!(
147                    "regression metric `{name}` is not finite"
148                )));
149            }
150        }
151        Ok(())
152    }
153
154    pub fn into_candidate_score(self, candidate_id: impl Into<String>) -> Result<CandidateScore> {
155        self.validate()?;
156        let mut metadata = BTreeMap::from([
157            (
158                "producer_node".to_string(),
159                serde_json::json!(self.producer_node),
160            ),
161            ("partition".to_string(), serde_json::json!(self.partition)),
162            (
163                "metric_level".to_string(),
164                serde_json::json!(prediction_level_name(self.level)),
165            ),
166            ("row_count".to_string(), serde_json::json!(self.row_count)),
167            (
168                "target_width".to_string(),
169                serde_json::json!(self.target_width),
170            ),
171        ]);
172        if let Some(prediction_id) = self.prediction_id {
173            metadata.insert(
174                "prediction_id".to_string(),
175                serde_json::json!(prediction_id),
176            );
177        }
178        if let Some(fold_id) = self.fold_id {
179            metadata.insert("fold_id".to_string(), serde_json::json!(fold_id));
180        }
181        if !self.target_names.is_empty() {
182            metadata.insert(
183                "target_names".to_string(),
184                serde_json::json!(self.target_names),
185            );
186        }
187        let score = CandidateScore {
188            candidate_id: candidate_id.into(),
189            metrics: self.metrics,
190            metadata,
191        };
192        score.validate()?;
193        Ok(score)
194    }
195}
196
197pub fn regression_report_to_candidate_score(
198    candidate_id: impl Into<String>,
199    report: RegressionMetricReport,
200) -> Result<CandidateScore> {
201    report.into_candidate_score(candidate_id)
202}
203
204pub fn score_regression_prediction_block(
205    predictions: &PredictionBlock,
206    targets: &RegressionTargetBlock,
207    metrics: &[RegressionMetricKind],
208) -> Result<RegressionMetricReport> {
209    let width = validate_sample_prediction_block(predictions)?;
210    let prediction_units = predictions
211        .sample_ids
212        .iter()
213        .cloned()
214        .map(PredictionUnitId::Sample)
215        .collect::<Vec<_>>();
216    score_regression_rows(
217        PredictionRows {
218            level: PredictionLevel::Sample,
219            unit_ids: &prediction_units,
220            values: &predictions.values,
221            target_names: &predictions.target_names,
222            width,
223            origin: PredictionReportOrigin {
224                prediction_id: predictions.prediction_id.clone(),
225                producer_node: predictions.producer_node.clone(),
226                partition: predictions.partition.clone(),
227                fold_id: predictions.fold_id.clone(),
228            },
229        },
230        targets,
231        metrics,
232    )
233}
234
235pub fn score_regression_aggregated_block(
236    predictions: &AggregatedPredictionBlock,
237    targets: &RegressionTargetBlock,
238    metrics: &[RegressionMetricKind],
239) -> Result<RegressionMetricReport> {
240    let width = predictions.validate_shape()?;
241    score_regression_rows(
242        PredictionRows {
243            level: predictions.level,
244            unit_ids: &predictions.unit_ids,
245            values: &predictions.values,
246            target_names: &predictions.target_names,
247            width,
248            origin: PredictionReportOrigin {
249                prediction_id: predictions.prediction_id.clone(),
250                producer_node: predictions.producer_node.clone(),
251                partition: predictions.partition.clone(),
252                fold_id: predictions.fold_id.clone(),
253            },
254        },
255        targets,
256        metrics,
257    )
258}
259
260#[derive(Clone, Debug)]
261struct PredictionReportOrigin {
262    prediction_id: Option<String>,
263    producer_node: NodeId,
264    partition: PredictionPartition,
265    fold_id: Option<FoldId>,
266}
267
268#[derive(Clone, Debug)]
269struct PredictionRows<'a> {
270    level: PredictionLevel,
271    unit_ids: &'a [PredictionUnitId],
272    values: &'a [Vec<f64>],
273    target_names: &'a [String],
274    width: usize,
275    origin: PredictionReportOrigin,
276}
277
278fn score_regression_rows(
279    predictions: PredictionRows<'_>,
280    targets: &RegressionTargetBlock,
281    metrics: &[RegressionMetricKind],
282) -> Result<RegressionMetricReport> {
283    if metrics.is_empty() {
284        return Err(DagMlError::OofValidation(
285            "no regression metrics requested".to_string(),
286        ));
287    }
288    let mut requested_metrics = BTreeSet::new();
289    for metric in metrics {
290        if !requested_metrics.insert(*metric) {
291            return Err(DagMlError::OofValidation(format!(
292                "duplicate regression metric `{}` requested",
293                metric.name()
294            )));
295        }
296    }
297
298    let target_width = targets.validate_shape()?;
299    if predictions.width != target_width {
300        return Err(DagMlError::OofValidation(format!(
301            "prediction width {} does not match target width {target_width}",
302            predictions.width
303        )));
304    }
305    if predictions.level != targets.level {
306        return Err(DagMlError::OofValidation(format!(
307            "prediction level {:?} does not match target level {:?}",
308            predictions.level, targets.level
309        )));
310    }
311    if !predictions.target_names.is_empty()
312        && !targets.target_names.is_empty()
313        && predictions.target_names != targets.target_names
314    {
315        return Err(DagMlError::OofValidation(
316            "prediction target names do not match target block names".to_string(),
317        ));
318    }
319
320    let target_by_unit = targets
321        .unit_ids
322        .iter()
323        .zip(targets.values.iter().map(Vec::as_slice))
324        .collect::<BTreeMap<_, _>>();
325    let mut aligned_predictions = Vec::with_capacity(predictions.unit_ids.len());
326    let mut aligned_targets = Vec::with_capacity(predictions.unit_ids.len());
327    for (unit_id, prediction_row) in predictions.unit_ids.iter().zip(predictions.values.iter()) {
328        let target_row = target_by_unit.get(unit_id).ok_or_else(|| {
329            DagMlError::OofValidation(format!(
330                "prediction unit `{unit_id}` is missing from target block"
331            ))
332        })?;
333        aligned_predictions.push(prediction_row.as_slice());
334        aligned_targets.push(*target_row);
335    }
336    if aligned_predictions.len() != target_by_unit.len() {
337        return Err(DagMlError::OofValidation(
338            "target block contains units not present in predictions".to_string(),
339        ));
340    }
341
342    let target_names = if !predictions.target_names.is_empty() {
343        predictions.target_names.to_vec()
344    } else {
345        targets.target_names.clone()
346    };
347    let metric_suffixes = target_metric_names(predictions.width, &target_names);
348    let mut values = BTreeMap::new();
349    for metric in metrics {
350        let per_target = compute_metric_per_target(
351            *metric,
352            predictions.width,
353            &aligned_predictions,
354            &aligned_targets,
355        );
356        values.insert(metric.name().to_string(), macro_mean(&per_target));
357        for (name, value) in metric_suffixes.iter().zip(per_target) {
358            values.insert(format!("{}:{name}", metric.name()), value);
359        }
360    }
361
362    let report = RegressionMetricReport {
363        prediction_id: predictions.origin.prediction_id,
364        producer_node: predictions.origin.producer_node,
365        partition: predictions.origin.partition,
366        fold_id: predictions.origin.fold_id,
367        level: predictions.level,
368        row_count: predictions.unit_ids.len(),
369        target_width: predictions.width,
370        target_names,
371        metrics: values,
372    };
373    report.validate()?;
374    Ok(report)
375}
376
377fn validate_sample_prediction_block(block: &PredictionBlock) -> Result<usize> {
378    let width = block.validate_shape()?;
379    if block
380        .values
381        .iter()
382        .flatten()
383        .any(|value| !value.is_finite())
384    {
385        return Err(DagMlError::OofValidation(format!(
386            "producer `{}` emitted non-finite sample prediction values",
387            block.producer_node
388        )));
389    }
390    let unique = block.sample_ids.iter().collect::<BTreeSet<_>>();
391    if unique.len() != block.sample_ids.len() {
392        return Err(DagMlError::OofValidation(format!(
393            "producer `{}` emitted duplicate sample predictions",
394            block.producer_node
395        )));
396    }
397    Ok(width)
398}
399
400fn compute_metric_per_target(
401    metric: RegressionMetricKind,
402    width: usize,
403    predictions: &[&[f64]],
404    targets: &[&[f64]],
405) -> Vec<f64> {
406    (0..width)
407        .map(|target_idx| match metric {
408            RegressionMetricKind::Mse => {
409                predictions
410                    .iter()
411                    .zip(targets.iter())
412                    .map(|(prediction, target)| {
413                        let error = prediction[target_idx] - target[target_idx];
414                        error * error
415                    })
416                    .sum::<f64>()
417                    / predictions.len() as f64
418            }
419            RegressionMetricKind::Rmse => (predictions
420                .iter()
421                .zip(targets.iter())
422                .map(|(prediction, target)| {
423                    let error = prediction[target_idx] - target[target_idx];
424                    error * error
425                })
426                .sum::<f64>()
427                / predictions.len() as f64)
428                .sqrt(),
429            RegressionMetricKind::Mae => {
430                predictions
431                    .iter()
432                    .zip(targets.iter())
433                    .map(|(prediction, target)| (prediction[target_idx] - target[target_idx]).abs())
434                    .sum::<f64>()
435                    / predictions.len() as f64
436            }
437            RegressionMetricKind::R2 => r2_for_target(target_idx, predictions, targets),
438        })
439        .collect()
440}
441
/// Computes the coefficient of determination (R²) for one target column.
///
/// `predictions` and `targets` are paired by position and `target_idx`
/// selects the column. For a target column with no variance the ratio
/// `ss_res / ss_tot` is undefined, so the result is defined as `1.0` when
/// the predictions are exact and `0.0` otherwise.
fn r2_for_target(target_idx: usize, predictions: &[&[f64]], targets: &[&[f64]]) -> f64 {
    let ss_res = predictions
        .iter()
        .zip(targets.iter())
        .map(|(prediction, target)| {
            let error = prediction[target_idx] - target[target_idx];
            error * error
        })
        .sum::<f64>();

    // Detect a constant column by comparing the values directly. Relying on
    // `ss_tot == 0.0` alone is fragile: the computed mean of identical values
    // can round away from them (e.g. three times 0.1), which leaves a tiny
    // non-zero `ss_tot` and turns any prediction error into an astronomically
    // negative score.
    let is_constant = match targets.split_first() {
        Some((first, rest)) => {
            let reference = first[target_idx];
            rest.iter().all(|row| row[target_idx] == reference)
        }
        None => true,
    };

    let ss_tot = if is_constant {
        0.0
    } else {
        let mean = targets.iter().map(|row| row[target_idx]).sum::<f64>() / targets.len() as f64;
        targets
            .iter()
            .map(|target| {
                let centered = target[target_idx] - mean;
                centered * centered
            })
            .sum::<f64>()
    };

    // `ss_tot` can still be zero for a non-constant column when the squared
    // deviations underflow, so the zero check is kept as a fallback.
    if ss_tot == 0.0 {
        if ss_res == 0.0 {
            1.0
        } else {
            0.0
        }
    } else {
        1.0 - ss_res / ss_tot
    }
}
469
/// Returns the unweighted arithmetic mean of `values`.
///
/// An empty slice yields NaN, because the sum is divided by a length of zero.
fn macro_mean(values: &[f64]) -> f64 {
    let count = values.len() as f64;
    let total = values.iter().sum::<f64>();
    total / count
}
473
/// Returns the per-target suffixes used in metric keys.
///
/// When `target_names` is non-empty it is returned as-is (regardless of
/// `width`); otherwise `width` placeholder names `target_0`, `target_1`, …
/// are generated.
fn target_metric_names(width: usize, target_names: &[String]) -> Vec<String> {
    if !target_names.is_empty() {
        return target_names.to_vec();
    }
    let mut generated = Vec::with_capacity(width);
    for idx in 0..width {
        generated.push(format!("target_{idx}"));
    }
    generated
}
481
482fn prediction_level_name(level: PredictionLevel) -> &'static str {
483    match level {
484        PredictionLevel::Observation => "observation",
485        PredictionLevel::Sample => "sample",
486        PredictionLevel::Target => "target",
487        PredictionLevel::Group => "group",
488    }
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::{GroupId, NodeId, SampleId, TargetId};
    use crate::oof::PredictionPartition;

    // ---- fixture helpers -------------------------------------------------

    fn sid(value: &str) -> SampleId {
        SampleId::new(value).unwrap()
    }

    fn sample_unit(value: &str) -> PredictionUnitId {
        PredictionUnitId::Sample(sid(value))
    }

    fn target_unit(value: &str) -> PredictionUnitId {
        PredictionUnitId::Target(TargetId::new(value).unwrap())
    }

    fn group_unit(value: &str) -> PredictionUnitId {
        PredictionUnitId::Group(GroupId::new(value).unwrap())
    }

    fn names(values: &[&str]) -> Vec<String> {
        values.iter().map(|value| value.to_string()).collect()
    }

    /// Validation-partition sample block with a single target column `y`.
    fn sample_predictions(
        producer: &str,
        prediction_id: Option<&str>,
        sample_ids: &[&str],
        values: Vec<Vec<f64>>,
    ) -> PredictionBlock {
        PredictionBlock {
            prediction_id: prediction_id.map(|id| id.to_string()),
            producer_node: NodeId::new(producer).unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: None,
            sample_ids: sample_ids.iter().map(|id| sid(id)).collect(),
            values,
            target_names: names(&["y"]),
        }
    }

    /// Validation-partition aggregated block produced by `model:pls`.
    fn aggregated_predictions(
        prediction_id: Option<&str>,
        level: PredictionLevel,
        unit_ids: Vec<PredictionUnitId>,
        values: Vec<Vec<f64>>,
        target_names: &[&str],
    ) -> AggregatedPredictionBlock {
        AggregatedPredictionBlock {
            prediction_id: prediction_id.map(|id| id.to_string()),
            producer_node: NodeId::new("model:pls").unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: None,
            level,
            unit_ids,
            values,
            target_names: names(target_names),
        }
    }

    fn regression_targets(
        level: PredictionLevel,
        unit_ids: Vec<PredictionUnitId>,
        values: Vec<Vec<f64>>,
        target_names: &[&str],
    ) -> RegressionTargetBlock {
        RegressionTargetBlock {
            level,
            unit_ids,
            values,
            target_names: names(target_names),
        }
    }

    fn assert_close(left: f64, right: f64) {
        assert!((left - right).abs() < 1e-12, "expected {right}, got {left}");
    }

    // ---- tests -----------------------------------------------------------

    #[test]
    fn metric_objectives_match_selection_direction() {
        let expectations = [
            (RegressionMetricKind::Rmse, MetricObjective::Minimize),
            (RegressionMetricKind::Mae, MetricObjective::Minimize),
            (RegressionMetricKind::Mse, MetricObjective::Minimize),
            (RegressionMetricKind::R2, MetricObjective::Maximize),
        ];
        for (metric, objective) in expectations {
            assert_eq!(metric.objective(), objective);
        }
    }

    #[test]
    fn scores_sample_predictions_and_exports_candidate_metrics() {
        let predictions = sample_predictions(
            "model:pls",
            Some("pred:sample"),
            &["sample:1", "sample:2"],
            vec![vec![2.0], vec![4.0]],
        );
        // Targets are listed in the opposite order to exercise id alignment.
        let targets = regression_targets(
            PredictionLevel::Sample,
            vec![sample_unit("sample:2"), sample_unit("sample:1")],
            vec![vec![5.0], vec![1.0]],
            &["y"],
        );
        let requested = [
            RegressionMetricKind::Rmse,
            RegressionMetricKind::Mae,
            RegressionMetricKind::R2,
        ];

        let report = score_regression_prediction_block(&predictions, &targets, &requested).unwrap();

        assert_eq!(report.level, PredictionLevel::Sample);
        assert_close(report.metrics["rmse"], 1.0);
        assert_close(report.metrics["rmse:y"], 1.0);
        assert_close(report.metrics["mae"], 1.0);
        assert_close(report.metrics["r2"], 0.75);

        let candidate = regression_report_to_candidate_score("model:pls", report).unwrap();
        assert_eq!(candidate.metrics["rmse"], 1.0);
        assert_eq!(candidate.metadata["metric_level"], "sample");
        assert_eq!(candidate.metadata["producer_node"], "model:pls");
        assert_eq!(candidate.metadata["partition"], "validation");
        assert_eq!(candidate.metadata["prediction_id"], "pred:sample");
        assert_eq!(candidate.metadata["target_names"], serde_json::json!(["y"]));
    }

    #[test]
    fn scores_target_and_group_prediction_blocks() {
        let predictions = aggregated_predictions(
            Some("pred:target"),
            PredictionLevel::Target,
            vec![target_unit("target:a"), target_unit("target:b")],
            vec![vec![1.0, 10.0], vec![3.0, 30.0]],
            &["y1", "y2"],
        );
        let targets = regression_targets(
            PredictionLevel::Target,
            vec![target_unit("target:b"), target_unit("target:a")],
            vec![vec![2.0, 28.0], vec![2.0, 12.0]],
            &["y1", "y2"],
        );

        let report = score_regression_aggregated_block(
            &predictions,
            &targets,
            &[RegressionMetricKind::Mse, RegressionMetricKind::Rmse],
        )
        .unwrap();

        assert_eq!(report.level, PredictionLevel::Target);
        assert_close(report.metrics["mse:y1"], 1.0);
        assert_close(report.metrics["mse:y2"], 4.0);
        assert_close(report.metrics["mse"], 2.5);
        assert_close(report.metrics["rmse:y1"], 1.0);
        assert_close(report.metrics["rmse:y2"], 2.0);
        assert_close(report.metrics["rmse"], 1.5);

        let group_predictions = aggregated_predictions(
            Some("pred:group"),
            PredictionLevel::Group,
            vec![group_unit("group:a")],
            vec![vec![3.0]],
            &["y"],
        );
        let group_targets = regression_targets(
            PredictionLevel::Group,
            vec![group_unit("group:a")],
            vec![vec![1.0]],
            &["y"],
        );

        let group_report = score_regression_aggregated_block(
            &group_predictions,
            &group_targets,
            &[RegressionMetricKind::Mae],
        )
        .unwrap();

        assert_eq!(group_report.level, PredictionLevel::Group);
        assert_close(group_report.metrics["mae"], 2.0);
    }

    #[test]
    fn refuses_metric_alignment_and_contract_mismatches() {
        let rmse = [RegressionMetricKind::Rmse];
        let predictions = aggregated_predictions(
            None,
            PredictionLevel::Target,
            vec![target_unit("target:a")],
            vec![vec![1.0]],
            &["y"],
        );

        // Target block does not contain the predicted unit.
        let missing_target = regression_targets(
            PredictionLevel::Target,
            vec![target_unit("target:b")],
            vec![vec![1.0]],
            &["y"],
        );
        assert!(score_regression_aggregated_block(&predictions, &missing_target, &rmse).is_err());

        // Target block lives at a different prediction level.
        let wrong_level = regression_targets(
            PredictionLevel::Group,
            vec![group_unit("group:a")],
            vec![vec![1.0]],
            &["y"],
        );
        assert!(score_regression_aggregated_block(&predictions, &wrong_level, &rmse).is_err());

        // No metrics requested.
        assert!(score_regression_aggregated_block(&predictions, &missing_target, &[]).is_err());

        // Target names disagree.
        let renamed_target = regression_targets(
            PredictionLevel::Target,
            vec![target_unit("target:a")],
            vec![vec![1.0]],
            &["other"],
        );
        assert!(score_regression_aggregated_block(&predictions, &renamed_target, &rmse).is_err());

        // Same metric requested twice.
        let matching_target = regression_targets(
            PredictionLevel::Target,
            vec![target_unit("target:a")],
            vec![vec![1.0]],
            &["y"],
        );
        assert!(score_regression_aggregated_block(
            &predictions,
            &matching_target,
            &[RegressionMetricKind::Rmse, RegressionMetricKind::Rmse],
        )
        .is_err());
    }

    #[test]
    fn refuses_duplicate_and_non_finite_sample_predictions() {
        let rmse = [RegressionMetricKind::Rmse];
        let targets = regression_targets(
            PredictionLevel::Sample,
            vec![sample_unit("sample:1")],
            vec![vec![1.0]],
            &["y"],
        );

        let non_finite = sample_predictions(
            "model:pls",
            None,
            &["sample:1"],
            vec![vec![f64::INFINITY]],
        );
        assert!(score_regression_prediction_block(&non_finite, &targets, &rmse).is_err());

        let duplicated = sample_predictions(
            "model:pls",
            None,
            &["sample:1", "sample:1"],
            vec![vec![1.0], vec![1.0]],
        );
        assert!(score_regression_prediction_block(&duplicated, &targets, &rmse).is_err());
    }

    #[test]
    fn constant_target_r2_is_finite_and_deterministic() {
        let r2 = [RegressionMetricKind::R2];
        let sample_ids = ["sample:1", "sample:2"];
        let targets = regression_targets(
            PredictionLevel::Sample,
            vec![sample_unit("sample:1"), sample_unit("sample:2")],
            vec![vec![2.0], vec![2.0]],
            &["y"],
        );

        let exact_predictions = sample_predictions(
            "model:exact",
            None,
            &sample_ids,
            vec![vec![2.0], vec![2.0]],
        );
        let exact_report =
            score_regression_prediction_block(&exact_predictions, &targets, &r2).unwrap();
        assert_close(exact_report.metrics["r2"], 1.0);

        let off_predictions = sample_predictions(
            "model:exact",
            None,
            &sample_ids,
            vec![vec![2.0], vec![3.0]],
        );
        let off_report =
            score_regression_prediction_block(&off_predictions, &targets, &r2).unwrap();
        assert_close(off_report.metrics["r2"], 0.0);
    }
}