1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::aggregation::{AggregatedPredictionBlock, PredictionUnitId};
6use crate::error::{DagMlError, Result};
7use crate::ids::{FoldId, NodeId};
8use crate::oof::{PredictionBlock, PredictionPartition};
9use crate::policy::PredictionLevel;
10use crate::selection::{CandidateScore, MetricObjective};
11
/// Regression metrics that this module knows how to compute.
///
/// Variants serialize as their snake_case names (`mse`, `rmse`, `mae`, `r2`), the same
/// strings returned by `RegressionMetricKind::name`. The derived `Ord` is used here to
/// de-duplicate metric requests via a `BTreeSet`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RegressionMetricKind {
    /// Mean squared error.
    Mse,
    /// Square root of the mean squared error.
    Rmse,
    /// Mean absolute error.
    Mae,
    /// Coefficient of determination (R²).
    R2,
}
20
21impl RegressionMetricKind {
22 pub fn name(self) -> &'static str {
23 match self {
24 Self::Mse => "mse",
25 Self::Rmse => "rmse",
26 Self::Mae => "mae",
27 Self::R2 => "r2",
28 }
29 }
30
31 pub fn objective(self) -> MetricObjective {
32 match self {
33 Self::Mse | Self::Rmse | Self::Mae => MetricObjective::Minimize,
34 Self::R2 => MetricObjective::Maximize,
35 }
36 }
37}
38
/// Ground-truth regression targets keyed by prediction unit.
///
/// Row `i` of `values` holds the targets for `unit_ids[i]`. Use
/// `RegressionTargetBlock::validate_shape` to check the invariants described on each field.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RegressionTargetBlock {
    /// Level that every entry of `unit_ids` is required to belong to.
    pub level: PredictionLevel,
    /// Unit identifiers, one per row of `values`; duplicates are rejected by validation.
    pub unit_ids: Vec<PredictionUnitId>,
    /// Target rows. Validation requires a non-zero common width and finite values.
    pub values: Vec<Vec<f64>>,
    /// Optional target column names. When non-empty, validation requires exactly one name
    /// per target column. Absent in serialized input means "no names".
    #[serde(default)]
    pub target_names: Vec<String>,
}
47
48impl RegressionTargetBlock {
49 pub fn validate_shape(&self) -> Result<usize> {
50 if self.unit_ids.len() != self.values.len() {
51 return Err(DagMlError::OofValidation(format!(
52 "target block has {} unit ids but {} target rows",
53 self.unit_ids.len(),
54 self.values.len()
55 )));
56 }
57 if self
58 .unit_ids
59 .iter()
60 .any(|unit_id| unit_id.level() != self.level)
61 {
62 return Err(DagMlError::OofValidation(format!(
63 "target block contains units outside level {:?}",
64 self.level
65 )));
66 }
67 let unique = self.unit_ids.iter().collect::<BTreeSet<_>>();
68 if unique.len() != self.unit_ids.len() {
69 return Err(DagMlError::OofValidation(
70 "target block contains duplicate unit ids".to_string(),
71 ));
72 }
73 let width = self.values.first().map_or(0, Vec::len);
74 if width == 0 {
75 return Err(DagMlError::OofValidation(
76 "target block has empty target rows".to_string(),
77 ));
78 }
79 if self.values.iter().any(|row| row.len() != width) {
80 return Err(DagMlError::OofValidation(
81 "target block has ragged target rows".to_string(),
82 ));
83 }
84 if self.values.iter().flatten().any(|value| !value.is_finite()) {
85 return Err(DagMlError::OofValidation(
86 "target block contains non-finite values".to_string(),
87 ));
88 }
89 if !self.target_names.is_empty() && self.target_names.len() != width {
90 return Err(DagMlError::OofValidation(format!(
91 "target block has {} target names for width {}",
92 self.target_names.len(),
93 width
94 )));
95 }
96 Ok(width)
97 }
98}
99
/// Regression metrics computed for one prediction block, together with the provenance of
/// the scored predictions.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RegressionMetricReport {
    /// Identifier of the scored prediction block, if it had one.
    #[serde(default)]
    pub prediction_id: Option<String>,
    /// Node that produced the scored predictions.
    pub producer_node: NodeId,
    /// Partition of the scored predictions.
    pub partition: PredictionPartition,
    /// Fold of the scored predictions, if any.
    pub fold_id: Option<FoldId>,
    /// Level at which predictions and targets were aligned.
    pub level: PredictionLevel,
    /// Number of scored rows; `validate` rejects zero.
    pub row_count: usize,
    /// Number of target columns; `validate` rejects zero.
    pub target_width: usize,
    /// Optional target column names; when non-empty, `validate` requires one per column.
    #[serde(default)]
    pub target_names: Vec<String>,
    /// Metric values keyed by name; `validate` requires at least one entry, non-blank names
    /// and finite values. The scoring functions in this module store the macro average
    /// under the bare metric name (e.g. `rmse`) and per-target values under
    /// `<metric>:<target>`.
    pub metrics: BTreeMap<String, f64>,
}
114
115impl RegressionMetricReport {
116 pub fn validate(&self) -> Result<()> {
117 if self.row_count == 0 {
118 return Err(DagMlError::OofValidation(
119 "regression metric report has zero rows".to_string(),
120 ));
121 }
122 if self.target_width == 0 {
123 return Err(DagMlError::OofValidation(
124 "regression metric report has zero target width".to_string(),
125 ));
126 }
127 if !self.target_names.is_empty() && self.target_names.len() != self.target_width {
128 return Err(DagMlError::OofValidation(format!(
129 "regression metric report has {} target names for width {}",
130 self.target_names.len(),
131 self.target_width
132 )));
133 }
134 if self.metrics.is_empty() {
135 return Err(DagMlError::OofValidation(
136 "regression metric report has no metrics".to_string(),
137 ));
138 }
139 for (name, value) in &self.metrics {
140 if name.trim().is_empty() {
141 return Err(DagMlError::OofValidation(
142 "regression metric report contains an empty metric name".to_string(),
143 ));
144 }
145 if !value.is_finite() {
146 return Err(DagMlError::OofValidation(format!(
147 "regression metric `{name}` is not finite"
148 )));
149 }
150 }
151 Ok(())
152 }
153
154 pub fn into_candidate_score(self, candidate_id: impl Into<String>) -> Result<CandidateScore> {
155 self.validate()?;
156 let mut metadata = BTreeMap::from([
157 (
158 "producer_node".to_string(),
159 serde_json::json!(self.producer_node),
160 ),
161 ("partition".to_string(), serde_json::json!(self.partition)),
162 (
163 "metric_level".to_string(),
164 serde_json::json!(prediction_level_name(self.level)),
165 ),
166 ("row_count".to_string(), serde_json::json!(self.row_count)),
167 (
168 "target_width".to_string(),
169 serde_json::json!(self.target_width),
170 ),
171 ]);
172 if let Some(prediction_id) = self.prediction_id {
173 metadata.insert(
174 "prediction_id".to_string(),
175 serde_json::json!(prediction_id),
176 );
177 }
178 if let Some(fold_id) = self.fold_id {
179 metadata.insert("fold_id".to_string(), serde_json::json!(fold_id));
180 }
181 if !self.target_names.is_empty() {
182 metadata.insert(
183 "target_names".to_string(),
184 serde_json::json!(self.target_names),
185 );
186 }
187 let score = CandidateScore {
188 candidate_id: candidate_id.into(),
189 metrics: self.metrics,
190 metadata,
191 };
192 score.validate()?;
193 Ok(score)
194 }
195}
196
197pub fn regression_report_to_candidate_score(
198 candidate_id: impl Into<String>,
199 report: RegressionMetricReport,
200) -> Result<CandidateScore> {
201 report.into_candidate_score(candidate_id)
202}
203
204pub fn score_regression_prediction_block(
205 predictions: &PredictionBlock,
206 targets: &RegressionTargetBlock,
207 metrics: &[RegressionMetricKind],
208) -> Result<RegressionMetricReport> {
209 let width = validate_sample_prediction_block(predictions)?;
210 let prediction_units = predictions
211 .sample_ids
212 .iter()
213 .cloned()
214 .map(PredictionUnitId::Sample)
215 .collect::<Vec<_>>();
216 score_regression_rows(
217 PredictionRows {
218 level: PredictionLevel::Sample,
219 unit_ids: &prediction_units,
220 values: &predictions.values,
221 target_names: &predictions.target_names,
222 width,
223 origin: PredictionReportOrigin {
224 prediction_id: predictions.prediction_id.clone(),
225 producer_node: predictions.producer_node.clone(),
226 partition: predictions.partition.clone(),
227 fold_id: predictions.fold_id.clone(),
228 },
229 },
230 targets,
231 metrics,
232 )
233}
234
235pub fn score_regression_aggregated_block(
236 predictions: &AggregatedPredictionBlock,
237 targets: &RegressionTargetBlock,
238 metrics: &[RegressionMetricKind],
239) -> Result<RegressionMetricReport> {
240 let width = predictions.validate_shape()?;
241 score_regression_rows(
242 PredictionRows {
243 level: predictions.level,
244 unit_ids: &predictions.unit_ids,
245 values: &predictions.values,
246 target_names: &predictions.target_names,
247 width,
248 origin: PredictionReportOrigin {
249 prediction_id: predictions.prediction_id.clone(),
250 producer_node: predictions.producer_node.clone(),
251 partition: predictions.partition.clone(),
252 fold_id: predictions.fold_id.clone(),
253 },
254 },
255 targets,
256 metrics,
257 )
258}
259
/// Provenance of a prediction block, copied verbatim into the resulting
/// `RegressionMetricReport`.
#[derive(Clone, Debug)]
struct PredictionReportOrigin {
    /// Identifier of the scored prediction block, if any.
    prediction_id: Option<String>,
    /// Node that produced the predictions.
    producer_node: NodeId,
    /// Partition the predictions belong to.
    partition: PredictionPartition,
    /// Fold the predictions belong to, if any.
    fold_id: Option<FoldId>,
}
267
/// Level-agnostic, borrowed view of a prediction block, so that sample-level and aggregated
/// blocks can share one scoring path.
#[derive(Clone, Debug)]
struct PredictionRows<'a> {
    /// Level of every unit in `unit_ids`.
    level: PredictionLevel,
    /// Unit identifiers; `unit_ids[i]` labels `values[i]`.
    unit_ids: &'a [PredictionUnitId],
    /// Prediction rows, expected to be `width` columns wide.
    values: &'a [Vec<f64>],
    /// Optional target column names (may be empty).
    target_names: &'a [String],
    /// Target width reported by the source block's validation.
    width: usize,
    /// Provenance forwarded to the report.
    origin: PredictionReportOrigin,
}
277
278fn score_regression_rows(
279 predictions: PredictionRows<'_>,
280 targets: &RegressionTargetBlock,
281 metrics: &[RegressionMetricKind],
282) -> Result<RegressionMetricReport> {
283 if metrics.is_empty() {
284 return Err(DagMlError::OofValidation(
285 "no regression metrics requested".to_string(),
286 ));
287 }
288 let mut requested_metrics = BTreeSet::new();
289 for metric in metrics {
290 if !requested_metrics.insert(*metric) {
291 return Err(DagMlError::OofValidation(format!(
292 "duplicate regression metric `{}` requested",
293 metric.name()
294 )));
295 }
296 }
297
298 let target_width = targets.validate_shape()?;
299 if predictions.width != target_width {
300 return Err(DagMlError::OofValidation(format!(
301 "prediction width {} does not match target width {target_width}",
302 predictions.width
303 )));
304 }
305 if predictions.level != targets.level {
306 return Err(DagMlError::OofValidation(format!(
307 "prediction level {:?} does not match target level {:?}",
308 predictions.level, targets.level
309 )));
310 }
311 if !predictions.target_names.is_empty()
312 && !targets.target_names.is_empty()
313 && predictions.target_names != targets.target_names
314 {
315 return Err(DagMlError::OofValidation(
316 "prediction target names do not match target block names".to_string(),
317 ));
318 }
319
320 let target_by_unit = targets
321 .unit_ids
322 .iter()
323 .zip(targets.values.iter().map(Vec::as_slice))
324 .collect::<BTreeMap<_, _>>();
325 let mut aligned_predictions = Vec::with_capacity(predictions.unit_ids.len());
326 let mut aligned_targets = Vec::with_capacity(predictions.unit_ids.len());
327 for (unit_id, prediction_row) in predictions.unit_ids.iter().zip(predictions.values.iter()) {
328 let target_row = target_by_unit.get(unit_id).ok_or_else(|| {
329 DagMlError::OofValidation(format!(
330 "prediction unit `{unit_id}` is missing from target block"
331 ))
332 })?;
333 aligned_predictions.push(prediction_row.as_slice());
334 aligned_targets.push(*target_row);
335 }
336 if aligned_predictions.len() != target_by_unit.len() {
337 return Err(DagMlError::OofValidation(
338 "target block contains units not present in predictions".to_string(),
339 ));
340 }
341
342 let target_names = if !predictions.target_names.is_empty() {
343 predictions.target_names.to_vec()
344 } else {
345 targets.target_names.clone()
346 };
347 let metric_suffixes = target_metric_names(predictions.width, &target_names);
348 let mut values = BTreeMap::new();
349 for metric in metrics {
350 let per_target = compute_metric_per_target(
351 *metric,
352 predictions.width,
353 &aligned_predictions,
354 &aligned_targets,
355 );
356 values.insert(metric.name().to_string(), macro_mean(&per_target));
357 for (name, value) in metric_suffixes.iter().zip(per_target) {
358 values.insert(format!("{}:{name}", metric.name()), value);
359 }
360 }
361
362 let report = RegressionMetricReport {
363 prediction_id: predictions.origin.prediction_id,
364 producer_node: predictions.origin.producer_node,
365 partition: predictions.origin.partition,
366 fold_id: predictions.origin.fold_id,
367 level: predictions.level,
368 row_count: predictions.unit_ids.len(),
369 target_width: predictions.width,
370 target_names,
371 metrics: values,
372 };
373 report.validate()?;
374 Ok(report)
375}
376
377fn validate_sample_prediction_block(block: &PredictionBlock) -> Result<usize> {
378 let width = block.validate_shape()?;
379 if block
380 .values
381 .iter()
382 .flatten()
383 .any(|value| !value.is_finite())
384 {
385 return Err(DagMlError::OofValidation(format!(
386 "producer `{}` emitted non-finite sample prediction values",
387 block.producer_node
388 )));
389 }
390 let unique = block.sample_ids.iter().collect::<BTreeSet<_>>();
391 if unique.len() != block.sample_ids.len() {
392 return Err(DagMlError::OofValidation(format!(
393 "producer `{}` emitted duplicate sample predictions",
394 block.producer_node
395 )));
396 }
397 Ok(width)
398}
399
400fn compute_metric_per_target(
401 metric: RegressionMetricKind,
402 width: usize,
403 predictions: &[&[f64]],
404 targets: &[&[f64]],
405) -> Vec<f64> {
406 (0..width)
407 .map(|target_idx| match metric {
408 RegressionMetricKind::Mse => {
409 predictions
410 .iter()
411 .zip(targets.iter())
412 .map(|(prediction, target)| {
413 let error = prediction[target_idx] - target[target_idx];
414 error * error
415 })
416 .sum::<f64>()
417 / predictions.len() as f64
418 }
419 RegressionMetricKind::Rmse => (predictions
420 .iter()
421 .zip(targets.iter())
422 .map(|(prediction, target)| {
423 let error = prediction[target_idx] - target[target_idx];
424 error * error
425 })
426 .sum::<f64>()
427 / predictions.len() as f64)
428 .sqrt(),
429 RegressionMetricKind::Mae => {
430 predictions
431 .iter()
432 .zip(targets.iter())
433 .map(|(prediction, target)| (prediction[target_idx] - target[target_idx]).abs())
434 .sum::<f64>()
435 / predictions.len() as f64
436 }
437 RegressionMetricKind::R2 => r2_for_target(target_idx, predictions, targets),
438 })
439 .collect()
440}
441
/// Coefficient of determination (R²) of column `target_idx` over the aligned rows.
///
/// Computes `1 - SS_res / SS_tot`. R² is undefined when the target column is constant or
/// there are no rows. In that case the result is `1.0` for a perfect fit and `0.0`
/// otherwise, matching scikit-learn's `r2_score` with `force_finite=True`.
///
/// Constancy is detected by comparing the raw target values, not by testing
/// `SS_tot == 0.0`. For a constant column whose floating-point mean is inexact (e.g. three
/// copies of `0.1`), `SS_tot` is a tiny rounding residue rather than zero. Dividing by it
/// would drive R² to an enormous negative number.
fn r2_for_target(target_idx: usize, predictions: &[&[f64]], targets: &[&[f64]]) -> f64 {
    let ss_res = predictions
        .iter()
        .zip(targets.iter())
        .map(|(prediction, target)| {
            let error = prediction[target_idx] - target[target_idx];
            error * error
        })
        .sum::<f64>();

    let is_constant = match targets.first() {
        Some(first) => targets
            .iter()
            .all(|target| target[target_idx] == first[target_idx]),
        None => true,
    };

    let mean = targets.iter().map(|row| row[target_idx]).sum::<f64>() / targets.len() as f64;
    let ss_tot = targets
        .iter()
        .map(|target| {
            let centered = target[target_idx] - mean;
            centered * centered
        })
        .sum::<f64>();

    // `ss_tot` can still underflow to exactly zero for non-constant columns of extremely
    // small magnitude; treat that like the constant case instead of dividing by zero.
    if is_constant || ss_tot == 0.0 {
        if ss_res == 0.0 {
            1.0
        } else {
            0.0
        }
    } else {
        1.0 - ss_res / ss_tot
    }
}
469
/// Unweighted mean of `values`, i.e. the macro average across target columns.
///
/// Yields NaN for an empty slice. Callers pass one value per target column, and the
/// target width is validated to be non-zero.
fn macro_mean(values: &[f64]) -> f64 {
    let count = values.len() as f64;
    let total: f64 = values.iter().copied().sum();
    total / count
}
473
/// Per-target suffixes for metric keys. Returns the given target names when there are any,
/// otherwise the generated names `target_0` … `target_{width - 1}`.
fn target_metric_names(width: usize, target_names: &[String]) -> Vec<String> {
    match target_names {
        [] => (0..width).map(|index| format!("target_{index}")).collect(),
        named => named.to_vec(),
    }
}
481
482fn prediction_level_name(level: PredictionLevel) -> &'static str {
483 match level {
484 PredictionLevel::Observation => "observation",
485 PredictionLevel::Sample => "sample",
486 PredictionLevel::Target => "target",
487 PredictionLevel::Group => "group",
488 }
489}
490
// Unit tests for regression scoring: metric directions, sample- and aggregated-level
// scoring, candidate-score export, and rejection of mismatched or malformed inputs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::{GroupId, NodeId, SampleId, TargetId};
    use crate::oof::PredictionPartition;

    // Builds a `SampleId`, panicking on an invalid literal (test fixtures only).
    fn sid(value: &str) -> SampleId {
        SampleId::new(value).unwrap()
    }

    // Wraps a sample id as a sample-level prediction unit.
    fn sample_unit(value: &str) -> PredictionUnitId {
        PredictionUnitId::Sample(sid(value))
    }

    // Builds a target-level prediction unit.
    fn target_unit(value: &str) -> PredictionUnitId {
        PredictionUnitId::Target(TargetId::new(value).unwrap())
    }

    // Builds a group-level prediction unit.
    fn group_unit(value: &str) -> PredictionUnitId {
        PredictionUnitId::Group(GroupId::new(value).unwrap())
    }

    // Floating-point comparison with a tight absolute tolerance.
    fn assert_close(left: f64, right: f64) {
        assert!((left - right).abs() < 1e-12, "expected {right}, got {left}");
    }

    // Error metrics must be minimized and R² maximized by candidate selection.
    #[test]
    fn metric_objectives_match_selection_direction() {
        assert_eq!(
            RegressionMetricKind::Rmse.objective(),
            MetricObjective::Minimize
        );
        assert_eq!(
            RegressionMetricKind::Mae.objective(),
            MetricObjective::Minimize
        );
        assert_eq!(
            RegressionMetricKind::Mse.objective(),
            MetricObjective::Minimize
        );
        assert_eq!(
            RegressionMetricKind::R2.objective(),
            MetricObjective::Maximize
        );
    }

    // Targets are listed in a different unit order than predictions to exercise alignment
    // by unit id. Aligned errors are +1 and -1, so RMSE = MAE = 1. The target mean is 3,
    // so SS_tot = 8 and SS_res = 2, giving R² = 0.75.
    #[test]
    fn scores_sample_predictions_and_exports_candidate_metrics() {
        let predictions = PredictionBlock {
            prediction_id: Some("pred:sample".to_string()),
            producer_node: NodeId::new("model:pls").unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: None,
            sample_ids: vec![sid("sample:1"), sid("sample:2")],
            values: vec![vec![2.0], vec![4.0]],
            target_names: vec!["y".to_string()],
        };
        let targets = RegressionTargetBlock {
            level: PredictionLevel::Sample,
            unit_ids: vec![sample_unit("sample:2"), sample_unit("sample:1")],
            values: vec![vec![5.0], vec![1.0]],
            target_names: vec!["y".to_string()],
        };

        let report = score_regression_prediction_block(
            &predictions,
            &targets,
            &[
                RegressionMetricKind::Rmse,
                RegressionMetricKind::Mae,
                RegressionMetricKind::R2,
            ],
        )
        .unwrap();

        assert_eq!(report.level, PredictionLevel::Sample);
        assert_close(report.metrics["rmse"], 1.0);
        assert_close(report.metrics["rmse:y"], 1.0);
        assert_close(report.metrics["mae"], 1.0);
        assert_close(report.metrics["r2"], 0.75);
        // The candidate export carries metrics unchanged plus provenance metadata.
        let candidate = regression_report_to_candidate_score("model:pls", report).unwrap();
        assert_eq!(candidate.metrics["rmse"], 1.0);
        assert_eq!(candidate.metadata["metric_level"], "sample");
        assert_eq!(candidate.metadata["producer_node"], "model:pls");
        assert_eq!(candidate.metadata["partition"], "validation");
        assert_eq!(candidate.metadata["prediction_id"], "pred:sample");
        assert_eq!(candidate.metadata["target_names"], serde_json::json!(["y"]));
    }

    // Two-target, target-level block: y1 errors are ±1 (MSE 1, RMSE 1) and y2 errors are
    // ±2 (MSE 4, RMSE 2). Bare metric names hold the macro average across targets.
    // A single-row group-level block then checks MAE at the group level.
    #[test]
    fn scores_target_and_group_prediction_blocks() {
        let predictions = AggregatedPredictionBlock {
            prediction_id: Some("pred:target".to_string()),
            producer_node: NodeId::new("model:pls").unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: None,
            level: PredictionLevel::Target,
            unit_ids: vec![target_unit("target:a"), target_unit("target:b")],
            values: vec![vec![1.0, 10.0], vec![3.0, 30.0]],
            target_names: vec!["y1".to_string(), "y2".to_string()],
        };
        let targets = RegressionTargetBlock {
            level: PredictionLevel::Target,
            unit_ids: vec![target_unit("target:b"), target_unit("target:a")],
            values: vec![vec![2.0, 28.0], vec![2.0, 12.0]],
            target_names: vec!["y1".to_string(), "y2".to_string()],
        };
        let report = score_regression_aggregated_block(
            &predictions,
            &targets,
            &[RegressionMetricKind::Mse, RegressionMetricKind::Rmse],
        )
        .unwrap();

        assert_eq!(report.level, PredictionLevel::Target);
        assert_close(report.metrics["mse:y1"], 1.0);
        assert_close(report.metrics["mse:y2"], 4.0);
        assert_close(report.metrics["mse"], 2.5);
        assert_close(report.metrics["rmse:y1"], 1.0);
        assert_close(report.metrics["rmse:y2"], 2.0);
        assert_close(report.metrics["rmse"], 1.5);

        let group_predictions = AggregatedPredictionBlock {
            prediction_id: Some("pred:group".to_string()),
            producer_node: NodeId::new("model:pls").unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: None,
            level: PredictionLevel::Group,
            unit_ids: vec![group_unit("group:a")],
            values: vec![vec![3.0]],
            target_names: vec!["y".to_string()],
        };
        let group_targets = RegressionTargetBlock {
            level: PredictionLevel::Group,
            unit_ids: vec![group_unit("group:a")],
            values: vec![vec![1.0]],
            target_names: vec!["y".to_string()],
        };
        let group_report = score_regression_aggregated_block(
            &group_predictions,
            &group_targets,
            &[RegressionMetricKind::Mae],
        )
        .unwrap();
        assert_eq!(group_report.level, PredictionLevel::Group);
        assert_close(group_report.metrics["mae"], 2.0);
    }

    // Each call below violates exactly one contract: the prediction unit is missing from
    // the targets, the levels differ, no metrics are requested, the target names differ,
    // or a metric is requested twice.
    #[test]
    fn refuses_metric_alignment_and_contract_mismatches() {
        let predictions = AggregatedPredictionBlock {
            prediction_id: None,
            producer_node: NodeId::new("model:pls").unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: None,
            level: PredictionLevel::Target,
            unit_ids: vec![target_unit("target:a")],
            values: vec![vec![1.0]],
            target_names: vec!["y".to_string()],
        };
        let missing_target = RegressionTargetBlock {
            level: PredictionLevel::Target,
            unit_ids: vec![target_unit("target:b")],
            values: vec![vec![1.0]],
            target_names: vec!["y".to_string()],
        };
        assert!(score_regression_aggregated_block(
            &predictions,
            &missing_target,
            &[RegressionMetricKind::Rmse],
        )
        .is_err());

        let wrong_level = RegressionTargetBlock {
            level: PredictionLevel::Group,
            unit_ids: vec![group_unit("group:a")],
            values: vec![vec![1.0]],
            target_names: vec!["y".to_string()],
        };
        assert!(score_regression_aggregated_block(
            &predictions,
            &wrong_level,
            &[RegressionMetricKind::Rmse],
        )
        .is_err());

        assert!(score_regression_aggregated_block(&predictions, &missing_target, &[]).is_err());
        assert!(score_regression_aggregated_block(
            &predictions,
            &RegressionTargetBlock {
                level: PredictionLevel::Target,
                unit_ids: vec![target_unit("target:a")],
                values: vec![vec![1.0]],
                target_names: vec!["other".to_string()],
            },
            &[RegressionMetricKind::Rmse],
        )
        .is_err());
        assert!(score_regression_aggregated_block(
            &predictions,
            &RegressionTargetBlock {
                level: PredictionLevel::Target,
                unit_ids: vec![target_unit("target:a")],
                values: vec![vec![1.0]],
                target_names: vec!["y".to_string()],
            },
            &[RegressionMetricKind::Rmse, RegressionMetricKind::Rmse],
        )
        .is_err());
    }

    // Sample-level blocks are rejected for non-finite predictions and for repeated
    // sample ids.
    #[test]
    fn refuses_duplicate_and_non_finite_sample_predictions() {
        let targets = RegressionTargetBlock {
            level: PredictionLevel::Sample,
            unit_ids: vec![sample_unit("sample:1")],
            values: vec![vec![1.0]],
            target_names: vec!["y".to_string()],
        };
        let mut predictions = PredictionBlock {
            prediction_id: None,
            producer_node: NodeId::new("model:pls").unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: None,
            sample_ids: vec![sid("sample:1")],
            values: vec![vec![f64::INFINITY]],
            target_names: vec!["y".to_string()],
        };
        assert!(score_regression_prediction_block(
            &predictions,
            &targets,
            &[RegressionMetricKind::Rmse],
        )
        .is_err());

        predictions.values = vec![vec![1.0], vec![1.0]];
        predictions.sample_ids = vec![sid("sample:1"), sid("sample:1")];
        assert!(score_regression_prediction_block(
            &predictions,
            &targets,
            &[RegressionMetricKind::Rmse],
        )
        .is_err());
    }

    // With a constant target column R² is undefined. A perfect fit must score 1.0 and any
    // other fit 0.0, rather than NaN or infinity.
    #[test]
    fn constant_target_r2_is_finite_and_deterministic() {
        let targets = RegressionTargetBlock {
            level: PredictionLevel::Sample,
            unit_ids: vec![sample_unit("sample:1"), sample_unit("sample:2")],
            values: vec![vec![2.0], vec![2.0]],
            target_names: vec!["y".to_string()],
        };
        let exact_predictions = PredictionBlock {
            prediction_id: None,
            producer_node: NodeId::new("model:exact").unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: None,
            sample_ids: vec![sid("sample:1"), sid("sample:2")],
            values: vec![vec![2.0], vec![2.0]],
            target_names: vec!["y".to_string()],
        };
        let exact_report = score_regression_prediction_block(
            &exact_predictions,
            &targets,
            &[RegressionMetricKind::R2],
        )
        .unwrap();
        assert_close(exact_report.metrics["r2"], 1.0);

        let off_predictions = PredictionBlock {
            values: vec![vec![2.0], vec![3.0]],
            ..exact_predictions
        };
        let off_report = score_regression_prediction_block(
            &off_predictions,
            &targets,
            &[RegressionMetricKind::R2],
        )
        .unwrap();
        assert_close(off_report.metrics["r2"], 0.0);
    }
}